Diffstat (limited to 'contrib/llvm/lib/Transforms')
181 files changed, 29918 insertions, 19545 deletions
diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index a97db6f..3598766 100644 --- a/contrib/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -101,7 +101,9 @@ namespace { struct CoroCleanup : FunctionPass { static char ID; // Pass identification, replacement for typeid - CoroCleanup() : FunctionPass(ID) {} + CoroCleanup() : FunctionPass(ID) { + initializeCoroCleanupPass(*PassRegistry::getPassRegistry()); + } std::unique_ptr<Lowerer> L; @@ -124,6 +126,7 @@ struct CoroCleanup : FunctionPass { if (!L) AU.setPreservesAll(); } + StringRef getPassName() const override { return "Coroutine Cleanup"; } }; } diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroEarly.cpp index e8bb0ca..ba05896 100644 --- a/contrib/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -183,7 +183,9 @@ namespace { struct CoroEarly : public FunctionPass { static char ID; // Pass identification, replacement for typeid. - CoroEarly() : FunctionPass(ID) {} + CoroEarly() : FunctionPass(ID) { + initializeCoroEarlyPass(*PassRegistry::getPassRegistry()); + } std::unique_ptr<Lowerer> L; @@ -208,6 +210,9 @@ struct CoroEarly : public FunctionPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); } + StringRef getPassName() const override { + return "Lower early coroutine intrinsics"; + } }; } diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 99974d8..42fd6d7 100644 --- a/contrib/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -92,7 +92,7 @@ static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) { // Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type. static Type *getFrameType(Function *Resume) { - auto *ArgType = Resume->getArgumentList().front().getType(); + auto *ArgType = Resume->arg_begin()->getType(); return cast<PointerType>(ArgType)->getElementType(); } @@ -127,7 +127,8 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) { // is spilled into the coroutine frame and recreate the alignment information // here. Possibly we will need to do a mini SROA here and break the coroutine // frame into individual AllocaInst recreating the original alignment. 
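The CoroCleanup, CoroEarly and CoroElide hunks above all apply the same two-line idiom: self-registration in the constructor plus an overridden getPassName(). A minimal sketch of that idiom, using a hypothetical HelloCoro pass (not part of the patch); the INITIALIZE_PASS line is what defines the initialize function the constructor calls:

#include "llvm/IR/Function.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"

using namespace llvm;

namespace llvm {
void initializeHelloCoroPass(PassRegistry &); // defined by INITIALIZE_PASS
}

namespace {
struct HelloCoro : FunctionPass {
  static char ID;
  HelloCoro() : FunctionPass(ID) {
    // Registering from the constructor makes the pass known to the registry
    // even when it is constructed directly, not only via static initializers.
    initializeHelloCoroPass(*PassRegistry::getPassRegistry());
  }
  bool runOnFunction(Function &) override { return false; }
  // Readable name for -debug-pass, -print-after-all and timing reports;
  // without it the default is the raw type name.
  StringRef getPassName() const override { return "Hello Coroutine"; }
};
} // namespace

char HelloCoro::ID = 0;
INITIALIZE_PASS(HelloCoro, "hello-coro", "Hello Coroutine", false, false)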
- auto *Frame = new AllocaInst(FrameTy, "", InsertPt); + const DataLayout &DL = F->getParent()->getDataLayout(); + auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); auto *FrameVoidPtr = new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt); @@ -257,7 +258,9 @@ static bool replaceDevirtTrigger(Function &F) { namespace { struct CoroElide : FunctionPass { static char ID; - CoroElide() : FunctionPass(ID) {} + CoroElide() : FunctionPass(ID) { + initializeCoroElidePass(*PassRegistry::getPassRegistry()); + } std::unique_ptr<Lowerer> L; @@ -300,6 +303,7 @@ struct CoroElide : FunctionPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AAResultsWrapperPass>(); } + StringRef getPassName() const override { return "Coroutine Elision"; } }; } diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index bb28558a..85e9003 100644 --- a/contrib/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -133,6 +133,7 @@ struct SuspendCrossingInfo { }; } // end anonymous namespace +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV) const { dbgs() << Label << ":"; @@ -151,6 +152,7 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const { } dbgs() << "\n"; } +#endif SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) : Mapping(F) { @@ -175,7 +177,7 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) // consume. Note, that crossing coro.save also requires a spill, as any code // between coro.save and coro.suspend may resume the coroutine and all of the // state needs to be saved by that time. - auto markSuspendBlock = [&](IntrinsicInst* BarrierInst) { + auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) { BasicBlock *SuspendBlock = BarrierInst->getParent(); auto &B = getBlockData(SuspendBlock); B.Suspend = true; @@ -345,6 +347,27 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, return FrameTy; } +// We need to make room to insert a spill after initial PHIs, but before +// catchswitch instruction. Placing it before violates the requirement that +// catchswitch, like all other EHPads must be the first nonPHI in a block. +// +// Split away catchswitch into a separate block and insert in its place: +// +// cleanuppad <InsertPt> cleanupret. +// +// cleanupret instruction will act as an insert point for the spill. +static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) { + BasicBlock *CurrentBlock = CatchSwitch->getParent(); + BasicBlock *NewBlock = CurrentBlock->splitBasicBlock(CatchSwitch); + CurrentBlock->getTerminator()->eraseFromParent(); + + auto *CleanupPad = + CleanupPadInst::Create(CatchSwitch->getParentPad(), {}, "", CurrentBlock); + auto *CleanupRet = + CleanupReturnInst::Create(CleanupPad, NewBlock, CurrentBlock); + return CleanupRet; +} + // Replace all alloca and SSA values that are accessed across suspend points // with GetElementPointer from coroutine frame + loads and stores. Create an // AllocaSpillBB that will become the new entry block for the resume parts of @@ -420,15 +443,34 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { report_fatal_error("Coroutines cannot handle non static allocas yet"); } else { // Otherwise, create a store instruction storing the value into the - // coroutine frame. 
For, argument, we will place the store instruction - // right after the coroutine frame pointer instruction, i.e. bitcase of - // coro.begin from i8* to %f.frame*. For all other values, the spill is - // placed immediately after the definition. - Builder.SetInsertPoint( - isa<Argument>(CurrentValue) - ? FramePtr->getNextNode() - : dyn_cast<Instruction>(E.def())->getNextNode()); + // coroutine frame. + + Instruction *InsertPt = nullptr; + if (isa<Argument>(CurrentValue)) { + // For arguments, we will place the store instruction right after + // the coroutine frame pointer instruction, i.e. bitcast of + // coro.begin from i8* to %f.frame*. + InsertPt = FramePtr->getNextNode(); + } else if (auto *II = dyn_cast<InvokeInst>(CurrentValue)) { + // If we are spilling the result of the invoke instruction, split the + // normal edge and insert the spill in the new block. + auto NewBB = SplitEdge(II->getParent(), II->getNormalDest()); + InsertPt = NewBB->getTerminator(); + } else if (dyn_cast<PHINode>(CurrentValue)) { + // Skip the PHINodes and EH pads instructions. + BasicBlock *DefBlock = cast<Instruction>(E.def())->getParent(); + if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator())) + InsertPt = splitBeforeCatchSwitch(CSI); + else + InsertPt = &*DefBlock->getFirstInsertionPt(); + } else { + // For all other values, the spill is placed immediately after + // the definition. + assert(!isa<TerminatorInst>(E.def()) && "unexpected terminator"); + InsertPt = cast<Instruction>(E.def())->getNextNode(); + } + Builder.SetInsertPoint(InsertPt); auto *G = Builder.CreateConstInBoundsGEP2_32( FrameTy, FramePtr, 0, Index, CurrentValue->getName() + Twine(".spill.addr")); @@ -477,6 +519,78 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { return FramePtr; } +// Sets the unwind edge of an instruction to a particular successor. +static void setUnwindEdgeTo(TerminatorInst *TI, BasicBlock *Succ) { + if (auto *II = dyn_cast<InvokeInst>(TI)) + II->setUnwindDest(Succ); + else if (auto *CS = dyn_cast<CatchSwitchInst>(TI)) + CS->setUnwindDest(Succ); + else if (auto *CR = dyn_cast<CleanupReturnInst>(TI)) + CR->setUnwindDest(Succ); + else + llvm_unreachable("unexpected terminator instruction"); +} + +// Replaces all uses of OldPred with the NewPred block in all PHINodes in a +// block. +static void updatePhiNodes(BasicBlock *DestBB, BasicBlock *OldPred, + BasicBlock *NewPred, + PHINode *LandingPadReplacement) { + unsigned BBIdx = 0; + for (BasicBlock::iterator I = DestBB->begin(); isa<PHINode>(I); ++I) { + PHINode *PN = cast<PHINode>(I); + + // We manually update the LandingPadReplacement PHINode and it is the last + // PHI Node. So, if we find it, we are done. + if (LandingPadReplacement == PN) + break; + + // Reuse the previous value of BBIdx if it lines up. In cases where we + // have multiple phi nodes with *lots* of predecessors, this is a speed + // win because we don't have to scan the PHI looking for TIBB. This + // happens because the BB list of PHI nodes are usually in the same + // order. + if (PN->getIncomingBlock(BBIdx) != OldPred) + BBIdx = PN->getBasicBlockIndex(OldPred); + + assert(BBIdx != (unsigned)-1 && "Invalid PHI Index!"); + PN->setIncomingBlock(BBIdx, NewPred); + } +} + +// Uses SplitEdge unless the successor block is an EHPad, in which case do EH +// specific handling. 
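The insertion-point selection above distinguishes four kinds of spilled definitions. A condensed restatement as a hypothetical helper, assuming splitBeforeCatchSwitch from earlier in this hunk and SplitEdge from BasicBlockUtils; the ehAwareSplitEdge definition that the comment above introduces continues right after this sketch:

static Instruction *chooseSpillInsertPt(Value *Def, Instruction *FramePtr) {
  if (isa<Argument>(Def))
    // Arguments are stored once the frame pointer has been materialized.
    return FramePtr->getNextNode();
  if (auto *II = dyn_cast<InvokeInst>(Def))
    // An invoke's result is only available on the normal edge; split it
    // and spill in the new block.
    return SplitEdge(II->getParent(), II->getNormalDest())->getTerminator();
  if (auto *PN = dyn_cast<PHINode>(Def)) {
    // Spill after all PHIs; a catchswitch terminator leaves no legal slot,
    // so carve out a cleanuppad/cleanupret pair to host the store.
    BasicBlock *BB = PN->getParent();
    if (auto *CSI = dyn_cast<CatchSwitchInst>(BB->getTerminator()))
      return splitBeforeCatchSwitch(CSI);
    return &*BB->getFirstInsertionPt();
  }
  // Everything else: spill immediately after the definition.
  return cast<Instruction>(Def)->getNextNode();
}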
+static BasicBlock *ehAwareSplitEdge(BasicBlock *BB, BasicBlock *Succ, + LandingPadInst *OriginalPad, + PHINode *LandingPadReplacement) { + auto *PadInst = Succ->getFirstNonPHI(); + if (!LandingPadReplacement && !PadInst->isEHPad()) + return SplitEdge(BB, Succ); + + auto *NewBB = BasicBlock::Create(BB->getContext(), "", BB->getParent(), Succ); + setUnwindEdgeTo(BB->getTerminator(), NewBB); + updatePhiNodes(Succ, BB, NewBB, LandingPadReplacement); + + if (LandingPadReplacement) { + auto *NewLP = OriginalPad->clone(); + auto *Terminator = BranchInst::Create(Succ, NewBB); + NewLP->insertBefore(Terminator); + LandingPadReplacement->addIncoming(NewLP, NewBB); + return NewBB; + } + Value *ParentPad = nullptr; + if (auto *FuncletPad = dyn_cast<FuncletPadInst>(PadInst)) + ParentPad = FuncletPad->getParentPad(); + else if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(PadInst)) + ParentPad = CatchSwitch->getParentPad(); + else + llvm_unreachable("handling for other EHPads not implemented yet"); + + auto *NewCleanupPad = CleanupPadInst::Create(ParentPad, {}, "", NewBB); + CleanupReturnInst::Create(NewCleanupPad, Succ, NewBB); + return NewBB; +} + static void rewritePHIs(BasicBlock &BB) { // For every incoming edge we will create a block holding all // incoming values in a single PHI nodes. @@ -499,9 +613,22 @@ static void rewritePHIs(BasicBlock &BB) { // TODO: Simplify PHINodes in the basic block to remove duplicate // predecessors. + LandingPadInst *LandingPad = nullptr; + PHINode *ReplPHI = nullptr; + if ((LandingPad = dyn_cast_or_null<LandingPadInst>(BB.getFirstNonPHI()))) { + // ehAwareSplitEdge will clone the LandingPad in all the edge blocks. + // We replace the original landing pad with a PHINode that will collect the + // results from all of them. + ReplPHI = PHINode::Create(LandingPad->getType(), 1, "", LandingPad); + ReplPHI->takeName(LandingPad); + LandingPad->replaceAllUsesWith(ReplPHI); + // We will erase the original landing pad at the end of this function after + // ehAwareSplitEdge cloned it in the transition blocks. + } + SmallVector<BasicBlock *, 8> Preds(pred_begin(&BB), pred_end(&BB)); for (BasicBlock *Pred : Preds) { - auto *IncomingBB = SplitEdge(Pred, &BB); + auto *IncomingBB = ehAwareSplitEdge(Pred, &BB, LandingPad, ReplPHI); IncomingBB->setName(BB.getName() + Twine(".from.") + Pred->getName()); auto *PN = cast<PHINode>(&BB.front()); do { @@ -513,7 +640,14 @@ static void rewritePHIs(BasicBlock &BB) { InputV->addIncoming(V, Pred); PN->setIncomingValue(Index, InputV); PN = dyn_cast<PHINode>(PN->getNextNode()); - } while (PN); + } while (PN != ReplPHI); // ReplPHI is either null or the PHI that replaced + // the landing pad. + } + + if (LandingPad) { + // Calls to ehAwareSplitEdge function cloned the original lading pad. + // No longer need it. + LandingPad->eraseFromParent(); } } @@ -665,9 +799,9 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { splitAround(CSI, "CoroSuspend"); } - // Put fallthrough CoroEnd into its own block. Note: Shape::buildFrom places - // the fallthrough coro.end as the first element of CoroEnds array. - splitAround(Shape.CoroEnds.front(), "CoroEnd"); + // Put CoroEnds into their own blocks. + for (CoroEndInst *CE : Shape.CoroEnds) + splitAround(CE, "CoroEnd"); // Transforms multi-edge PHI Nodes, so that any value feeding into a PHI will // never has its definition separated from the PHI by the suspend point. 
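rewritePHIs above coordinates with ehAwareSplitEdge through a small protocol for landing pads; a condensed paraphrase of the hunk, not new functionality. Note that the predecessor list is snapshotted first, because splitting edges mutates the predecessor range while it is being walked:

if (auto *LP = dyn_cast_or_null<LandingPadInst>(BB.getFirstNonPHI())) {
  // One PHI will merge the per-edge clones of the original landingpad.
  PHINode *Repl = PHINode::Create(LP->getType(), /*NumReservedValues=*/1,
                                  "", LP);
  Repl->takeName(LP);
  LP->replaceAllUsesWith(Repl);

  // Snapshot: ehAwareSplitEdge adds predecessors while we iterate.
  SmallVector<BasicBlock *, 8> Preds(pred_begin(&BB), pred_end(&BB));
  for (BasicBlock *Pred : Preds)
    // Clones LP into the new edge block and registers the clone as an
    // incoming value of Repl.
    ehAwareSplitEdge(Pred, &BB, LP, Repl);

  LP->eraseFromParent(); // every use now flows through Repl
}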
@@ -679,21 +813,25 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { IRBuilder<> Builder(F.getContext()); SpillInfo Spills; - // See if there are materializable instructions across suspend points. - for (Instruction &I : instructions(F)) - if (materializable(I)) - for (User *U : I.users()) - if (Checker.isDefinitionAcrossSuspend(I, U)) - Spills.emplace_back(&I, U); - - // Rewrite materializable instructions to be materialized at the use point. - std::sort(Spills.begin(), Spills.end()); - DEBUG(dump("Materializations", Spills)); - rewriteMaterializableInstructions(Builder, Spills); + for (int Repeat = 0; Repeat < 4; ++Repeat) { + // See if there are materializable instructions across suspend points. + for (Instruction &I : instructions(F)) + if (materializable(I)) + for (User *U : I.users()) + if (Checker.isDefinitionAcrossSuspend(I, U)) + Spills.emplace_back(&I, U); + + if (Spills.empty()) + break; + + // Rewrite materializable instructions to be materialized at the use point. + DEBUG(dump("Materializations", Spills)); + rewriteMaterializableInstructions(Builder, Spills); + Spills.clear(); + } // Collect the spills for arguments and other not-materializable values. - Spills.clear(); - for (Argument &A : F.getArgumentList()) + for (Argument &A : F.args()) for (User *U : A.users()) if (Checker.isDefinitionAcrossSuspend(A, U)) Spills.emplace_back(&A, U); @@ -714,12 +852,9 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { if (I.getType()->isTokenTy()) report_fatal_error( "token definition is separated from the use by a suspend point"); - assert(!materializable(I) && - "rewriteMaterializable did not do its job"); Spills.emplace_back(&I, U); } } - std::sort(Spills.begin(), Spills.end()); DEBUG(dump("Spills", Spills)); moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin); Shape.FrameTy = buildFrameType(F, Shape, Spills); diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroInstr.h b/contrib/llvm/lib/Transforms/Coroutines/CoroInstr.h index e03cef4..9a8cc5a 100644 --- a/contrib/llvm/lib/Transforms/Coroutines/CoroInstr.h +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroInstr.h @@ -23,6 +23,9 @@ // the Coroutine library. 
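The bounded loop added to buildCoroutineFrame above replaces the old single materialization pass: rewriting one materializable value next to its cross-suspend use clones it, and the clone's operands may themselves now live across the suspend, so collection and rewrite repeat, at most four times. A standalone, runnable toy of that chase, with hypothetical values a -> b -> use:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Round 1 rematerializes b at its cross-suspend use; the clone of b uses a,
// which now also crosses the suspend, so round 2 rematerializes a. Four
// rounds bound the chase instead of computing a true fixed point.
int main() {
  std::vector<std::pair<std::string, std::string>> Deps = {{"a", "b"}};
  std::vector<std::pair<std::string, std::string>> Crossing = {{"b", "use"}};
  for (int Repeat = 0; Repeat < 4; ++Repeat) {
    if (Crossing.empty())
      break;
    std::vector<std::pair<std::string, std::string>> Next;
    for (auto &Spill : Crossing) {
      std::printf("round %d: rematerialize %s at %s\n", Repeat + 1,
                  Spill.first.c_str(), Spill.second.c_str());
      for (auto &D : Deps)
        if (D.second == Spill.first) // the clone's operand crosses too
          Next.emplace_back(D.first, Spill.first + ".remat");
    }
    Crossing = std::move(Next);
  }
}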
//===----------------------------------------------------------------------===// +#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H +#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H + #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IntrinsicInst.h" @@ -55,10 +58,10 @@ public: } // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_subfn_addr; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -67,10 +70,10 @@ public: class LLVM_LIBRARY_VISIBILITY CoroAllocInst : public IntrinsicInst { public: // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_alloc; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -172,10 +175,10 @@ public: } // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_id; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -184,10 +187,10 @@ public: class LLVM_LIBRARY_VISIBILITY CoroFrameInst : public IntrinsicInst { public: // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_frame; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -200,10 +203,10 @@ public: Value *getFrame() const { return getArgOperand(FrameArg); } // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_free; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -218,10 +221,10 @@ public: Value *getMem() const { return getArgOperand(MemArg); } // Methods for support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_begin; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -230,10 +233,10 @@ public: class LLVM_LIBRARY_VISIBILITY CoroSaveInst : public IntrinsicInst { public: // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_save; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -251,10 +254,10 @@ public: } // Methods to support type 
inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_promise; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -276,10 +279,10 @@ public: } // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_suspend; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -288,10 +291,10 @@ public: class LLVM_LIBRARY_VISIBILITY CoroSizeInst : public IntrinsicInst { public: // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_size; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; @@ -307,12 +310,14 @@ public: } // Methods to support type inquiry through isa, cast, and dyn_cast: - static inline bool classof(const IntrinsicInst *I) { + static bool classof(const IntrinsicInst *I) { return I->getIntrinsicID() == Intrinsic::coro_end; } - static inline bool classof(const Value *V) { + static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); } }; } // End namespace llvm. + +#endif diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 7a3f4f6..173dc05 100644 --- a/contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Scalar.h" @@ -144,6 +145,33 @@ static void replaceFallthroughCoroEnd(IntrinsicInst *End, BB->getTerminator()->eraseFromParent(); } +// In Resumers, we replace unwind coro.end with True to force the immediate +// unwind to caller. +static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) { + if (Shape.CoroEnds.empty()) + return; + + LLVMContext &Context = Shape.CoroEnds.front()->getContext(); + auto *True = ConstantInt::getTrue(Context); + for (CoroEndInst *CE : Shape.CoroEnds) { + if (!CE->isUnwind()) + continue; + + auto *NewCE = cast<IntrinsicInst>(VMap[CE]); + + // If coro.end has an associated bundle, add cleanupret instruction. + if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) { + Value *FromPad = Bundle->Inputs[0]; + auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE); + NewCE->getParent()->splitBasicBlock(NewCE); + CleanupRet->getParent()->getTerminator()->eraseFromParent(); + } + + NewCE->replaceAllUsesWith(True); + NewCE->eraseFromParent(); + } +} + // Rewrite final suspend point handling. We do not use suspend index to // represent the final suspend point. 
Instead we zero-out ResumeFnAddr in the // coroutine frame, since it is undefined behavior to resume a coroutine @@ -157,9 +185,9 @@ static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr, coro::Shape &Shape, SwitchInst *Switch, bool IsDestroy) { assert(Shape.HasFinalSuspend); - auto FinalCase = --Switch->case_end(); - BasicBlock *ResumeBB = FinalCase.getCaseSuccessor(); - Switch->removeCase(FinalCase); + auto FinalCaseIt = std::prev(Switch->case_end()); + BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor(); + Switch->removeCase(FinalCaseIt); if (IsDestroy) { BasicBlock *OldSwitchBB = Switch->getParent(); auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch"); @@ -188,26 +216,18 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, Function *NewF = Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage, F.getName() + Suffix, M); - NewF->addAttribute(1, Attribute::NonNull); - NewF->addAttribute(1, Attribute::NoAlias); + NewF->addParamAttr(0, Attribute::NonNull); + NewF->addParamAttr(0, Attribute::NoAlias); ValueToValueMapTy VMap; // Replace all args with undefs. The buildCoroutineFrame algorithm already // rewritten access to the args that occurs after suspend points with loads // and stores to/from the coroutine frame. - for (Argument &A : F.getArgumentList()) + for (Argument &A : F.args()) VMap[&A] = UndefValue::get(A.getType()); SmallVector<ReturnInst *, 4> Returns; - if (DISubprogram *SP = F.getSubprogram()) { - // If we have debug info, add mapping for the metadata nodes that should not - // be cloned by CloneFunctionInfo. - auto &MD = VMap.MD(); - MD[SP->getUnit()].reset(SP->getUnit()); - MD[SP->getType()].reset(SP->getType()); - MD[SP->getFile()].reset(SP->getFile()); - } CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns); // Remove old returns. @@ -216,10 +236,8 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, // Remove old return attributes. NewF->removeAttributes( - AttributeSet::ReturnIndex, - AttributeSet::get( - NewF->getContext(), AttributeSet::ReturnIndex, - AttributeFuncs::typeIncompatible(NewF->getReturnType()))); + AttributeList::ReturnIndex, + AttributeFuncs::typeIncompatible(NewF->getReturnType())); // Make AllocaSpillBlock the new entry block. auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]); @@ -236,7 +254,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, IRBuilder<> Builder(&NewF->getEntryBlock().front()); // Remap frame pointer. - Argument *NewFramePtr = &NewF->getArgumentList().front(); + Argument *NewFramePtr = &*NewF->arg_begin(); Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]); NewFramePtr->takeName(OldFramePtr); OldFramePtr->replaceAllUsesWith(NewFramePtr); @@ -270,9 +288,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, // Remove coro.end intrinsics. replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap); - // FIXME: coming in upcoming patches: - // replaceUnwindCoroEnds(Shape.CoroEnds, VMap); - + replaceUnwindCoroEnds(Shape, VMap); // Eliminate coro.free from the clones, replacing it with 'null' in cleanup, // to suppress deallocation code. 
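replaceUnwindCoroEnds above handles the resume/destroy clones: an unwind coro.end folds to true ("the coroutine body is done; continue unwinding in the caller"), and when the intrinsic sits inside an EH funclet, a cleanupret is synthesized so the funclet gets a legal exit. The funclet case again, with explanatory comments added:

if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
  Value *FromPad = Bundle->Inputs[0];
  // "unwind to caller" is expressed as a null unwind destination.
  auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
  // Split the block at coro.end; everything after it is now unreachable.
  NewCE->getParent()->splitBasicBlock(NewCE);
  // Drop the branch that splitBasicBlock appended after CleanupRet;
  // cleanupret is already a terminator.
  CleanupRet->getParent()->getTerminator()->eraseFromParent();
}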
coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]), @@ -284,8 +300,16 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, } static void removeCoroEnds(coro::Shape &Shape) { - for (CoroEndInst *CE : Shape.CoroEnds) + if (Shape.CoroEnds.empty()) + return; + + LLVMContext &Context = Shape.CoroEnds.front()->getContext(); + auto *False = ConstantInt::getFalse(Context); + + for (CoroEndInst *CE : Shape.CoroEnds) { + CE->replaceAllUsesWith(False); CE->eraseFromParent(); + } } static void replaceFrameSize(coro::Shape &Shape) { @@ -477,12 +501,87 @@ static void simplifySuspendPoints(coro::Shape &Shape) { S.resize(N); } +static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) { + // Collect all blocks that we need to look for instructions to relocate. + SmallPtrSet<BasicBlock *, 4> RelocBlocks; + SmallVector<BasicBlock *, 4> Work; + Work.push_back(CB->getParent()); + + do { + BasicBlock *Current = Work.pop_back_val(); + for (BasicBlock *BB : predecessors(Current)) + if (RelocBlocks.count(BB) == 0) { + RelocBlocks.insert(BB); + Work.push_back(BB); + } + } while (!Work.empty()); + return RelocBlocks; +} + +static SmallPtrSet<Instruction *, 8> +getNotRelocatableInstructions(CoroBeginInst *CoroBegin, + SmallPtrSetImpl<BasicBlock *> &RelocBlocks) { + SmallPtrSet<Instruction *, 8> DoNotRelocate; + // Collect all instructions that we should not relocate + SmallVector<Instruction *, 8> Work; + + // Start with CoroBegin and terminators of all preceding blocks. + Work.push_back(CoroBegin); + BasicBlock *CoroBeginBB = CoroBegin->getParent(); + for (BasicBlock *BB : RelocBlocks) + if (BB != CoroBeginBB) + Work.push_back(BB->getTerminator()); + + // For every instruction in the Work list, place its operands in DoNotRelocate + // set. + do { + Instruction *Current = Work.pop_back_val(); + DoNotRelocate.insert(Current); + for (Value *U : Current->operands()) { + auto *I = dyn_cast<Instruction>(U); + if (!I) + continue; + if (isa<AllocaInst>(U)) + continue; + if (DoNotRelocate.count(I) == 0) { + Work.push_back(I); + DoNotRelocate.insert(I); + } + } + } while (!Work.empty()); + return DoNotRelocate; +} + +static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) { + // Analyze which non-alloca instructions are needed for allocation and + // relocate the rest to after coro.begin. We need to do it, since some of the + // targets of those instructions may be placed into coroutine frame memory + // for which becomes available after coro.begin intrinsic. + + auto BlockSet = getCoroBeginPredBlocks(CoroBegin); + auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet); + + Instruction *InsertPt = CoroBegin->getNextNode(); + BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well. 
+ for (auto B = BB.begin(), E = BB.end(); B != E;) { + Instruction &I = *B++; + if (isa<AllocaInst>(&I)) + continue; + if (&I == CoroBegin) + break; + if (DoNotRelocateSet.count(&I)) + continue; + I.moveBefore(InsertPt); + } +} + static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) { coro::Shape Shape(F); if (!Shape.CoroBegin) return; simplifySuspendPoints(Shape); + relocateInstructionBefore(Shape.CoroBegin, F); buildCoroutineFrame(F, Shape); replaceFrameSize(Shape); @@ -582,7 +681,9 @@ namespace { struct CoroSplit : public CallGraphSCCPass { static char ID; // Pass identification, replacement for typeid - CoroSplit() : CallGraphSCCPass(ID) {} + CoroSplit() : CallGraphSCCPass(ID) { + initializeCoroSplitPass(*PassRegistry::getPassRegistry()); + } bool Run = false; @@ -628,6 +729,7 @@ struct CoroSplit : public CallGraphSCCPass { void getAnalysisUsage(AnalysisUsage &AU) const override { CallGraphSCCPass::getAnalysisUsage(AU); } + StringRef getPassName() const override { return "Coroutine Splitting"; } }; } diff --git a/contrib/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/contrib/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 877ec34..44e1f9b 100644 --- a/contrib/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/contrib/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -218,6 +218,8 @@ void coro::Shape::buildFrom(Function &F) { size_t FinalSuspendIndex = 0; clear(*this); SmallVector<CoroFrameInst *, 8> CoroFrames; + SmallVector<CoroSaveInst *, 2> UnusedCoroSaves; + for (Instruction &I : instructions(F)) { if (auto II = dyn_cast<IntrinsicInst>(&I)) { switch (II->getIntrinsicID()) { @@ -229,6 +231,12 @@ void coro::Shape::buildFrom(Function &F) { case Intrinsic::coro_frame: CoroFrames.push_back(cast<CoroFrameInst>(II)); break; + case Intrinsic::coro_save: + // After optimizations, coro_suspends using this coro_save might have + // been removed, remember orphaned coro_saves to remove them later. + if (II->use_empty()) + UnusedCoroSaves.push_back(cast<CoroSaveInst>(II)); + break; case Intrinsic::coro_suspend: CoroSuspends.push_back(cast<CoroSuspendInst>(II)); if (CoroSuspends.back()->isFinal()) { @@ -245,9 +253,9 @@ void coro::Shape::buildFrom(Function &F) { if (CoroBegin) report_fatal_error( "coroutine should have exactly one defining @llvm.coro.begin"); - CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); - CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NoAlias); - CB->removeAttribute(AttributeSet::FunctionIndex, + CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); + CB->removeAttribute(AttributeList::FunctionIndex, Attribute::NoDuplicate); CoroBegin = CB; } @@ -311,4 +319,8 @@ void coro::Shape::buildFrom(Function &F) { if (HasFinalSuspend && FinalSuspendIndex != CoroSuspends.size() - 1) std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back()); + + // Remove orphaned coro.saves. 
+ for (CoroSaveInst *CoroSave : UnusedCoroSaves) + CoroSave->eraseFromParent(); } diff --git a/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index 65b7bad..72bae20 100644 --- a/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -29,8 +29,9 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/ArgumentPromotion.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -38,6 +39,7 @@ #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CFG.h" @@ -51,323 +53,404 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include <set> using namespace llvm; #define DEBUG_TYPE "argpromotion" -STATISTIC(NumArgumentsPromoted , "Number of pointer arguments promoted"); +STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted"); STATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted"); -STATISTIC(NumByValArgsPromoted , "Number of byval arguments promoted"); -STATISTIC(NumArgumentsDead , "Number of dead pointer args eliminated"); +STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted"); +STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated"); -namespace { - /// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. - /// - struct ArgPromotion : public CallGraphSCCPass { - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - getAAResultsAnalysisUsage(AU); - CallGraphSCCPass::getAnalysisUsage(AU); - } +/// A vector used to hold the indices of a single GEP instruction +typedef std::vector<uint64_t> IndicesVector; - bool runOnSCC(CallGraphSCC &SCC) override; - static char ID; // Pass identification, replacement for typeid - explicit ArgPromotion(unsigned maxElements = 3) - : CallGraphSCCPass(ID), maxElements(maxElements) { - initializeArgPromotionPass(*PassRegistry::getPassRegistry()); - } +/// DoPromotion - This method actually performs the promotion of the specified +/// arguments, and returns the new function. At this point, we know that it's +/// safe to do so. +static Function * +doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, + SmallPtrSetImpl<Argument *> &ByValArgsToTransform, + Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>> + ReplaceCallSite) { - private: + // Start by computing a new prototype for the function, which is the same as + // the old function, but has modified arguments. + FunctionType *FTy = F->getFunctionType(); + std::vector<Type *> Params; - using llvm::Pass::doInitialization; - bool doInitialization(CallGraph &CG) override; - /// The maximum number of elements to expand, or 0 for unlimited. 
- unsigned maxElements; - }; -} + typedef std::set<std::pair<Type *, IndicesVector>> ScalarizeTable; -/// A vector used to hold the indices of a single GEP instruction -typedef std::vector<uint64_t> IndicesVector; + // ScalarizedElements - If we are promoting a pointer that has elements + // accessed out of it, keep track of which elements are accessed so that we + // can add one argument for each. + // + // Arguments that are directly loaded will have a zero element value here, to + // handle cases where there are both a direct load and GEP accesses. + // + std::map<Argument *, ScalarizeTable> ScalarizedElements; -static CallGraphNode * -PromoteArguments(CallGraphNode *CGN, CallGraph &CG, - function_ref<AAResults &(Function &F)> AARGetter, - unsigned MaxElements); -static bool isDenselyPacked(Type *type, const DataLayout &DL); -static bool canPaddingBeAccessed(Argument *Arg); -static bool isSafeToPromoteArgument(Argument *Arg, bool isByVal, AAResults &AAR, - unsigned MaxElements); -static CallGraphNode * -DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, - SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG); + // OriginalLoads - Keep track of a representative load instruction from the + // original function so that we can tell the alias analysis implementation + // what the new GEP/Load instructions we are inserting look like. + // We need to keep the original loads for each argument and the elements + // of the argument that are accessed. + std::map<std::pair<Argument *, IndicesVector>, LoadInst *> OriginalLoads; -char ArgPromotion::ID = 0; -INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", - "Promote 'by reference' arguments to scalars", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(ArgPromotion, "argpromotion", - "Promote 'by reference' arguments to scalars", false, false) + // Attribute - Keep track of the parameter attributes for the arguments + // that we are *not* promoting. For the ones that we do promote, the parameter + // attributes are lost + SmallVector<AttributeSet, 8> ArgAttrVec; + AttributeList PAL = F->getAttributes(); -Pass *llvm::createArgumentPromotionPass(unsigned maxElements) { - return new ArgPromotion(maxElements); -} + // First, determine the new argument list + unsigned ArgNo = 0; + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; + ++I, ++ArgNo) { + if (ByValArgsToTransform.count(&*I)) { + // Simple byval argument? Just add all the struct element types. + Type *AgTy = cast<PointerType>(I->getType())->getElementType(); + StructType *STy = cast<StructType>(AgTy); + Params.insert(Params.end(), STy->element_begin(), STy->element_end()); + ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(), + AttributeSet()); + ++NumByValArgsPromoted; + } else if (!ArgsToPromote.count(&*I)) { + // Unchanged argument + Params.push_back(I->getType()); + ArgAttrVec.push_back(PAL.getParamAttributes(ArgNo)); + } else if (I->use_empty()) { + // Dead argument (which are always marked as promotable) + ++NumArgumentsDead; -static bool runImpl(CallGraphSCC &SCC, CallGraph &CG, - function_ref<AAResults &(Function &F)> AARGetter, - unsigned MaxElements) { - bool Changed = false, LocalChange; + // There may be remaining metadata uses of the argument for things like + // llvm.dbg.value. Replace them with undef. 
+ I->replaceAllUsesWith(UndefValue::get(I->getType())); + } else { + // Okay, this is being promoted. This means that the only uses are loads + // or GEPs which are only used by loads - do { // Iterate until we stop promoting from this SCC. - LocalChange = false; - // Attempt to promote arguments from all functions in this SCC. - for (CallGraphNode *OldNode : SCC) { - if (CallGraphNode *NewNode = - PromoteArguments(OldNode, CG, AARGetter, MaxElements)) { - LocalChange = true; - SCC.ReplaceNode(OldNode, NewNode); + // In this table, we will track which indices are loaded from the argument + // (where direct loads are tracked as no indices). + ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; + for (User *U : I->users()) { + Instruction *UI = cast<Instruction>(U); + Type *SrcTy; + if (LoadInst *L = dyn_cast<LoadInst>(UI)) + SrcTy = L->getType(); + else + SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType(); + IndicesVector Indices; + Indices.reserve(UI->getNumOperands() - 1); + // Since loads will only have a single operand, and GEPs only a single + // non-index operand, this will record direct loads without any indices, + // and gep+loads with the GEP indices. + for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end(); + II != IE; ++II) + Indices.push_back(cast<ConstantInt>(*II)->getSExtValue()); + // GEPs with a single 0 index can be merged with direct loads + if (Indices.size() == 1 && Indices.front() == 0) + Indices.clear(); + ArgIndices.insert(std::make_pair(SrcTy, Indices)); + LoadInst *OrigLoad; + if (LoadInst *L = dyn_cast<LoadInst>(UI)) + OrigLoad = L; + else + // Take any load, we will use it only to update Alias Analysis + OrigLoad = cast<LoadInst>(UI->user_back()); + OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad; } - } - Changed |= LocalChange; // Remember that we changed something. - } while (LocalChange); - - return Changed; -} -bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { - if (skipSCC(SCC)) - return false; + // Add a parameter to the function for each element passed in. + for (const auto &ArgIndex : ArgIndices) { + // not allowed to dereference ->begin() if size() is 0 + Params.push_back(GetElementPtrInst::getIndexedType( + cast<PointerType>(I->getType()->getScalarType())->getElementType(), + ArgIndex.second)); + ArgAttrVec.push_back(AttributeSet()); + assert(Params.back()); + } - // Get the callgraph information that we need to update to reflect our - // changes. - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty()) + ++NumArgumentsPromoted; + else + ++NumAggregatesPromoted; + } + } - // We compute dedicated AA results for each function in the SCC as needed. We - // use a lambda referencing external objects so that they live long enough to - // be queried, but we re-use them each time. - Optional<BasicAAResult> BAR; - Optional<AAResults> AAR; - auto AARGetter = [&](Function &F) -> AAResults & { - BAR.emplace(createLegacyPMBasicAAResult(*this, F)); - AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); - return *AAR; - }; - - return runImpl(SCC, CG, AARGetter, maxElements); -} + Type *RetTy = FTy->getReturnType(); -/// \brief Checks if a type could have padding bytes. -static bool isDenselyPacked(Type *type, const DataLayout &DL) { + // Construct the new function type using the new arguments. + FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg()); - // There is no size information, so be conservative. 
- if (!type->isSized()) - return false; + // Create the new function body and insert it into the module. + Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName()); + NF->copyAttributesFrom(F); - // If the alloc size is not equal to the storage size, then there are padding - // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128. - if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type)) - return false; + // Patch the pointer to LLVM function in debug info descriptor. + NF->setSubprogram(F->getSubprogram()); + F->setSubprogram(nullptr); - if (!isa<CompositeType>(type)) - return true; + DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n" + << "From: " << *F); - // For homogenous sequential types, check for padding within members. - if (SequentialType *seqTy = dyn_cast<SequentialType>(type)) - return isDenselyPacked(seqTy->getElementType(), DL); + // Recompute the parameter attributes list based on the new arguments for + // the function. + NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttributes(), + PAL.getRetAttributes(), ArgAttrVec)); + ArgAttrVec.clear(); - // Check for padding within and between elements of a struct. - StructType *StructTy = cast<StructType>(type); - const StructLayout *Layout = DL.getStructLayout(StructTy); - uint64_t StartPos = 0; - for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) { - Type *ElTy = StructTy->getElementType(i); - if (!isDenselyPacked(ElTy, DL)) - return false; - if (StartPos != Layout->getElementOffsetInBits(i)) - return false; - StartPos += DL.getTypeAllocSizeInBits(ElTy); - } + F->getParent()->getFunctionList().insert(F->getIterator(), NF); + NF->takeName(F); - return true; -} + // Loop over all of the callers of the function, transforming the call sites + // to pass in the loaded pointers. + // + SmallVector<Value *, 16> Args; + while (!F->use_empty()) { + CallSite CS(F->user_back()); + assert(CS.getCalledFunction() == F); + Instruction *Call = CS.getInstruction(); + const AttributeList &CallPAL = CS.getAttributes(); -/// \brief Checks if the padding bytes of an argument could be accessed. -static bool canPaddingBeAccessed(Argument *arg) { + // Loop over the operands, inserting GEP and loads in the caller as + // appropriate. + CallSite::arg_iterator AI = CS.arg_begin(); + ArgNo = 0; + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; + ++I, ++AI, ++ArgNo) + if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { + Args.push_back(*AI); // Unmodified argument + ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); + } else if (ByValArgsToTransform.count(&*I)) { + // Emit a GEP and load for each element of the struct. + Type *AgTy = cast<PointerType>(I->getType())->getElementType(); + StructType *STy = cast<StructType>(AgTy); + Value *Idxs[2] = { + ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr}; + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); + Value *Idx = GetElementPtrInst::Create( + STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i), Call); + // TODO: Tell AA about the new values? + Args.push_back(new LoadInst(Idx, Idx->getName() + ".val", Call)); + ArgAttrVec.push_back(AttributeSet()); + } + } else if (!I->use_empty()) { + // Non-dead argument: insert GEPs and loads as appropriate. + ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; + // Store the Value* version of the indices in here, but declare it now + // for reuse. 
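The ScalarizeTable bookkeeping above reduces each access to a promoted pointer to a vector of constant GEP indices: a direct load is the empty vector, and a GEP with a single zero index folds into the direct-load entry, so both produce one parameter. A standalone, runnable toy of that normalization:

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

using IndicesVector = std::vector<uint64_t>;

int main() {
  std::set<IndicesVector> ArgIndices; // one new parameter per entry
  auto record = [&](IndicesVector Idx) {
    if (Idx.size() == 1 && Idx.front() == 0)
      Idx.clear(); // "gep %arg, 0" then load == direct load of %arg
    ArgIndices.insert(Idx);
  };
  record({});     // load %arg
  record({0});    // gep %arg, 0 ; load   -> merged with the direct load
  record({0, 1}); // gep %arg, 0, 1 ; load -> one extra scalar parameter
  std::printf("parameters after promotion: %zu\n", ArgIndices.size()); // 2
}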
+ std::vector<Value *> Ops; + for (const auto &ArgIndex : ArgIndices) { + Value *V = *AI; + LoadInst *OrigLoad = + OriginalLoads[std::make_pair(&*I, ArgIndex.second)]; + if (!ArgIndex.second.empty()) { + Ops.reserve(ArgIndex.second.size()); + Type *ElTy = V->getType(); + for (auto II : ArgIndex.second) { + // Use i32 to index structs, and i64 for others (pointers/arrays). + // This satisfies GEP constraints. + Type *IdxTy = + (ElTy->isStructTy() ? Type::getInt32Ty(F->getContext()) + : Type::getInt64Ty(F->getContext())); + Ops.push_back(ConstantInt::get(IdxTy, II)); + // Keep track of the type we're currently indexing. + if (auto *ElPTy = dyn_cast<PointerType>(ElTy)) + ElTy = ElPTy->getElementType(); + else + ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II); + } + // And create a GEP to extract those indices. + V = GetElementPtrInst::Create(ArgIndex.first, V, Ops, + V->getName() + ".idx", Call); + Ops.clear(); + } + // Since we're replacing a load make sure we take the alignment + // of the previous load. + LoadInst *newLoad = new LoadInst(V, V->getName() + ".val", Call); + newLoad->setAlignment(OrigLoad->getAlignment()); + // Transfer the AA info too. + AAMDNodes AAInfo; + OrigLoad->getAAMetadata(AAInfo); + newLoad->setAAMetadata(AAInfo); - assert(arg->hasByValAttr()); + Args.push_back(newLoad); + ArgAttrVec.push_back(AttributeSet()); + } + } - // Track all the pointers to the argument to make sure they are not captured. - SmallPtrSet<Value *, 16> PtrValues; - PtrValues.insert(arg); + // Push any varargs arguments on the list. + for (; AI != CS.arg_end(); ++AI, ++ArgNo) { + Args.push_back(*AI); + ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); + } - // Track all of the stores. - SmallVector<StoreInst *, 16> Stores; + SmallVector<OperandBundleDef, 1> OpBundles; + CS.getOperandBundlesAsDefs(OpBundles); - // Scan through the uses recursively to make sure the pointer is always used - // sanely. - SmallVector<Value *, 16> WorkList; - WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end()); - while (!WorkList.empty()) { - Value *V = WorkList.back(); - WorkList.pop_back(); - if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) { - if (PtrValues.insert(V).second) - WorkList.insert(WorkList.end(), V->user_begin(), V->user_end()); - } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) { - Stores.push_back(Store); - } else if (!isa<LoadInst>(V)) { - return true; + CallSite NewCS; + if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { + NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), + Args, OpBundles, "", Call); + } else { + auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call); + NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind()); + NewCS = NewCall; } - } + NewCS.setCallingConv(CS.getCallingConv()); + NewCS.setAttributes( + AttributeList::get(F->getContext(), CallPAL.getFnAttributes(), + CallPAL.getRetAttributes(), ArgAttrVec)); + NewCS->setDebugLoc(Call->getDebugLoc()); + uint64_t W; + if (Call->extractProfTotalWeight(W)) + NewCS->setProfWeight(W); + Args.clear(); + ArgAttrVec.clear(); -// Check to make sure the pointers aren't captured - for (StoreInst *Store : Stores) - if (PtrValues.count(Store->getValueOperand())) - return true; + // Update the callgraph to know that the callsite has been transformed. 
+ if (ReplaceCallSite) + (*ReplaceCallSite)(CS, NewCS); - return false; -} + if (!Call->use_empty()) { + Call->replaceAllUsesWith(NewCS.getInstruction()); + NewCS->takeName(Call); + } -/// PromoteArguments - This method checks the specified function to see if there -/// are any promotable arguments and if it is safe to promote the function (for -/// example, all callers are direct). If safe to promote some arguments, it -/// calls the DoPromotion method. -/// -static CallGraphNode * -PromoteArguments(CallGraphNode *CGN, CallGraph &CG, - function_ref<AAResults &(Function &F)> AARGetter, - unsigned MaxElements) { - Function *F = CGN->getFunction(); + // Finally, remove the old call from the program, reducing the use-count of + // F. + Call->eraseFromParent(); + } - // Make sure that it is local to this module. - if (!F || !F->hasLocalLinkage()) return nullptr; + const DataLayout &DL = F->getParent()->getDataLayout(); - // Don't promote arguments for variadic functions. Adding, removing, or - // changing non-pack parameters can change the classification of pack - // parameters. Frontends encode that classification at the call site in the - // IR, while in the callee the classification is determined dynamically based - // on the number of registers consumed so far. - if (F->isVarArg()) return nullptr; + // Since we have now created the new function, splice the body of the old + // function right into the new function, leaving the old rotting hulk of the + // function empty. + NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); - // First check: see if there are any pointer arguments! If not, quick exit. - SmallVector<Argument*, 16> PointerArgs; - for (Argument &I : F->args()) - if (I.getType()->isPointerTy()) - PointerArgs.push_back(&I); - if (PointerArgs.empty()) return nullptr; + // Loop over the argument list, transferring uses of the old arguments over to + // the new arguments, also transferring over the names as well. + // + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), + I2 = NF->arg_begin(); + I != E; ++I) { + if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { + // If this is an unmodified argument, move the name and users over to the + // new version. + I->replaceAllUsesWith(&*I2); + I2->takeName(&*I); + ++I2; + continue; + } - // Second check: make sure that all callers are direct callers. We can't - // transform functions that have indirect callers. Also see if the function - // is self-recursive. - bool isSelfRecursive = false; - for (Use &U : F->uses()) { - CallSite CS(U.getUser()); - // Must be a direct call. - if (CS.getInstruction() == nullptr || !CS.isCallee(&U)) return nullptr; - - if (CS.getInstruction()->getParent()->getParent() == F) - isSelfRecursive = true; - } - - const DataLayout &DL = F->getParent()->getDataLayout(); + if (ByValArgsToTransform.count(&*I)) { + // In the callee, we create an alloca, and store each of the new incoming + // arguments into the alloca. + Instruction *InsertPt = &NF->begin()->front(); - AAResults &AAR = AARGetter(*F); + // Just add all the struct element types. + Type *AgTy = cast<PointerType>(I->getType())->getElementType(); + Value *TheAlloca = new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr, + I->getParamAlignment(), "", InsertPt); + StructType *STy = cast<StructType>(AgTy); + Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), + nullptr}; - // Check to see which arguments are promotable. If an argument is promotable, - // add it to ArgsToPromote. 
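The ReplaceCallSite hook used above is what lets doPromotion serve both pass managers: the legacy CGSCC pass supplies a callback that patches the CallGraph, while the new pass manager can pass None and rely on its own graph updates. Roughly how the legacy wrapper is expected to wire it; a sketch, with CG, F and the argument sets taken from the surrounding pass:

auto ReplaceCallSite = [&](CallSite OldCS, CallSite NewCS) {
  Function *Caller = OldCS.getInstruction()->getParent()->getParent();
  // Redirect the caller's edge from the old call site to the new one,
  // pointing at the node for the freshly created function.
  CallGraphNode *NewCalleeNode =
      CG.getOrInsertFunction(NewCS.getCalledFunction());
  CG[Caller]->replaceCallEdge(OldCS, NewCS, NewCalleeNode);
};
Function *NF = doPromotion(F, ArgsToPromote, ByValArgsToTransform,
                           {ReplaceCallSite});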
- SmallPtrSet<Argument*, 8> ArgsToPromote; - SmallPtrSet<Argument*, 8> ByValArgsToTransform; - for (Argument *PtrArg : PointerArgs) { - Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType(); + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); + Value *Idx = GetElementPtrInst::Create( + AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i), + InsertPt); + I2->setName(I->getName() + "." + Twine(i)); + new StoreInst(&*I2++, Idx, InsertPt); + } - // Replace sret attribute with noalias. This reduces register pressure by - // avoiding a register copy. - if (PtrArg->hasStructRetAttr()) { - unsigned ArgNo = PtrArg->getArgNo(); - F->setAttributes( - F->getAttributes() - .removeAttribute(F->getContext(), ArgNo + 1, Attribute::StructRet) - .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias)); - for (Use &U : F->uses()) { - CallSite CS(U.getUser()); - CS.setAttributes( - CS.getAttributes() - .removeAttribute(F->getContext(), ArgNo + 1, - Attribute::StructRet) - .addAttribute(F->getContext(), ArgNo + 1, Attribute::NoAlias)); + // Anything that used the arg should now use the alloca. + I->replaceAllUsesWith(TheAlloca); + TheAlloca->takeName(&*I); + + // If the alloca is used in a call, we must clear the tail flag since + // the callee now uses an alloca from the caller. + for (User *U : TheAlloca->users()) { + CallInst *Call = dyn_cast<CallInst>(U); + if (!Call) + continue; + Call->setTailCall(false); } + continue; } - // If this is a byval argument, and if the aggregate type is small, just - // pass the elements, which is always safe, if the passed value is densely - // packed or if we can prove the padding bytes are never accessed. This does - // not apply to inalloca. - bool isSafeToPromote = - PtrArg->hasByValAttr() && - (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg)); - if (isSafeToPromote) { - if (StructType *STy = dyn_cast<StructType>(AgTy)) { - if (MaxElements > 0 && STy->getNumElements() > MaxElements) { - DEBUG(dbgs() << "argpromotion disable promoting argument '" - << PtrArg->getName() << "' because it would require adding more" - << " than " << MaxElements << " arguments to the function.\n"); - continue; - } - - // If all the elements are single-value types, we can promote it. - bool AllSimple = true; - for (const auto *EltTy : STy->elements()) { - if (!EltTy->isSingleValueType()) { - AllSimple = false; - break; - } + if (I->use_empty()) + continue; + + // Otherwise, if we promoted this argument, then all users are load + // instructions (or GEPs with only load users), and all loads should be + // using the new argument that we added. 
+ ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; + + while (!I->use_empty()) { + if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) { + assert(ArgIndices.begin()->second.empty() && + "Load element should sort to front!"); + I2->setName(I->getName() + ".val"); + LI->replaceAllUsesWith(&*I2); + LI->eraseFromParent(); + DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName() + << "' in function '" << F->getName() << "'\n"); + } else { + GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back()); + IndicesVector Operands; + Operands.reserve(GEP->getNumIndices()); + for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end(); + II != IE; ++II) + Operands.push_back(cast<ConstantInt>(*II)->getSExtValue()); + + // GEPs with a single 0 index can be merged with direct loads + if (Operands.size() == 1 && Operands.front() == 0) + Operands.clear(); + + Function::arg_iterator TheArg = I2; + for (ScalarizeTable::iterator It = ArgIndices.begin(); + It->second != Operands; ++It, ++TheArg) { + assert(It != ArgIndices.end() && "GEP not handled??"); } - // Safe to transform, don't even bother trying to "promote" it. - // Passing the elements as a scalar will allow sroa to hack on - // the new alloca we introduce. - if (AllSimple) { - ByValArgsToTransform.insert(PtrArg); - continue; + std::string NewName = I->getName(); + for (unsigned i = 0, e = Operands.size(); i != e; ++i) { + NewName += "." + utostr(Operands[i]); } - } - } + NewName += ".val"; + TheArg->setName(NewName); - // If the argument is a recursive type and we're in a recursive - // function, we could end up infinitely peeling the function argument. - if (isSelfRecursive) { - if (StructType *STy = dyn_cast<StructType>(AgTy)) { - bool RecursiveType = false; - for (const auto *EltTy : STy->elements()) { - if (EltTy == PtrArg->getType()) { - RecursiveType = true; - break; - } + DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName() + << "' of function '" << NF->getName() << "'\n"); + + // All of the uses must be load instructions. Replace them all with + // the argument specified by ArgNo. + while (!GEP->use_empty()) { + LoadInst *L = cast<LoadInst>(GEP->user_back()); + L->replaceAllUsesWith(&*TheArg); + L->eraseFromParent(); } - if (RecursiveType) - continue; + GEP->eraseFromParent(); } } - - // Otherwise, see if we can promote the pointer to its value. - if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR, - MaxElements)) - ArgsToPromote.insert(PtrArg); - } - // No promotable pointer arguments. - if (ArgsToPromote.empty() && ByValArgsToTransform.empty()) - return nullptr; + // Increment I2 past all of the arguments added for this promoted pointer. + std::advance(I2, ArgIndices.size()); + } - return DoPromotion(F, ArgsToPromote, ByValArgsToTransform, CG); + return NF; } /// AllCallersPassInValidPointerForArgument - Return true if we can prove that /// all callees pass in a valid pointer for the specified function argument. -static bool AllCallersPassInValidPointerForArgument(Argument *Arg) { +static bool allCallersPassInValidPointerForArgument(Argument *Arg) { Function *Callee = Arg->getParent(); const DataLayout &DL = Callee->getParent()->getDataLayout(); @@ -390,26 +473,25 @@ static bool AllCallersPassInValidPointerForArgument(Argument *Arg) { /// elements in Prefix is the same as the corresponding elements in Longer. /// /// This means it also returns true when Prefix and Longer are equal! 
-static bool IsPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) {
+static bool isPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) {
   if (Prefix.size() > Longer.size())
     return false;
   return std::equal(Prefix.begin(), Prefix.end(), Longer.begin());
 }
 
-
 /// Checks if Indices, or a prefix of Indices, is in Set.
-static bool PrefixIn(const IndicesVector &Indices,
+static bool prefixIn(const IndicesVector &Indices,
                      std::set<IndicesVector> &Set) {
-  std::set<IndicesVector>::iterator Low;
-  Low = Set.upper_bound(Indices);
-  if (Low != Set.begin())
-    Low--;
-  // Low is now the last element smaller than or equal to Indices. This means
-  // it points to a prefix of Indices (possibly Indices itself), if such
-  // prefix exists.
-  //
-  // This load is safe if any prefix of its operands is safe to load.
-  return Low != Set.end() && IsPrefix(*Low, Indices);
+  std::set<IndicesVector>::iterator Low;
+  Low = Set.upper_bound(Indices);
+  if (Low != Set.begin())
+    Low--;
+  // Low is now the last element smaller than or equal to Indices. This means
+  // it points to a prefix of Indices (possibly Indices itself), if such
+  // prefix exists.
+  //
+  // This load is safe if any prefix of its operands is safe to load.
+  return Low != Set.end() && isPrefix(*Low, Indices);
 }
 
 /// Mark the given indices (ToMark) as safe in the given set of indices
@@ -417,7 +499,7 @@ static bool PrefixIn(const IndicesVector &Indices,
 /// is already a prefix of Indices in Safe, Indices are implicitly marked safe
 /// already. Furthermore, any indices that Indices is itself a prefix of, are
 /// removed from Safe (since they are implicitly safe because of Indices now).
-static void MarkIndicesSafe(const IndicesVector &ToMark,
+static void markIndicesSafe(const IndicesVector &ToMark,
                             std::set<IndicesVector> &Safe) {
   std::set<IndicesVector>::iterator Low;
   Low = Safe.upper_bound(ToMark);
@@ -428,7 +510,7 @@ static void MarkIndicesSafe(const IndicesVector &ToMark,
   // means it points to a prefix of Indices (possibly Indices itself), if
   // such prefix exists.
   if (Low != Safe.end()) {
-    if (IsPrefix(*Low, ToMark))
+    if (isPrefix(*Low, ToMark))
       // If there is already a prefix of these indices (or exactly these
       // indices) marked as safe, don't bother adding these indices
       return;
@@ -441,7 +523,7 @@ static void MarkIndicesSafe(const IndicesVector &ToMark,
     ++Low;
   // If we are a prefix of longer index list(s), remove those
   std::set<IndicesVector>::iterator End = Safe.end();
-  while (Low != End && IsPrefix(ToMark, *Low)) {
+  while (Low != End && isPrefix(ToMark, *Low)) {
     std::set<IndicesVector>::iterator Remove = Low;
     ++Low;
     Safe.erase(Remove);
@@ -486,7 +568,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
   GEPIndicesSet ToPromote;
 
   // If the pointer is always valid, any load with first index 0 is valid.
-  if (isByValOrInAlloca || AllCallersPassInValidPointerForArgument(Arg))
+  if (isByValOrInAlloca || allCallersPassInValidPointerForArgument(Arg))
     SafeToUnconditionallyLoad.insert(IndicesVector(1, 0));
 
   // First, iterate the entry block and mark loads of (geps of) arguments as
@@ -512,25 +594,26 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
         return false;
 
       // Indices checked out, mark them as safe
-      MarkIndicesSafe(Indices, SafeToUnconditionallyLoad);
+      markIndicesSafe(Indices, SafeToUnconditionallyLoad);
       Indices.clear();
     }
   } else if (V == Arg) {
     // Direct loads are equivalent to a GEP with a single 0 index.
- MarkIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad); + markIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad); } } // Now, iterate all uses of the argument to see if there are any uses that are // not (GEP+)loads, or any (GEP+)loads that are not safe to promote. - SmallVector<LoadInst*, 16> Loads; + SmallVector<LoadInst *, 16> Loads; IndicesVector Operands; for (Use &U : Arg->uses()) { User *UR = U.getUser(); Operands.clear(); if (LoadInst *LI = dyn_cast<LoadInst>(UR)) { // Don't hack volatile/atomic loads - if (!LI->isSimple()) return false; + if (!LI->isSimple()) + return false; Loads.push_back(LI); // Direct loads are equivalent to a GEP with a zero index and then a load. Operands.push_back(0); @@ -547,30 +630,31 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, } // Ensure that all of the indices are constants. - for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); - i != e; ++i) + for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); i != e; + ++i) if (ConstantInt *C = dyn_cast<ConstantInt>(*i)) Operands.push_back(C->getSExtValue()); else - return false; // Not a constant operand GEP! + return false; // Not a constant operand GEP! // Ensure that the only users of the GEP are load instructions. for (User *GEPU : GEP->users()) if (LoadInst *LI = dyn_cast<LoadInst>(GEPU)) { // Don't hack volatile/atomic loads - if (!LI->isSimple()) return false; + if (!LI->isSimple()) + return false; Loads.push_back(LI); } else { // Other uses than load? return false; } } else { - return false; // Not a load or a GEP. + return false; // Not a load or a GEP. } // Now, see if it is safe to promote this load / loads of this GEP. Loading // is safe if Operands, or a prefix of Operands, is marked as safe. - if (!PrefixIn(Operands, SafeToUnconditionallyLoad)) + if (!prefixIn(Operands, SafeToUnconditionallyLoad)) return false; // See if we are already promoting a load with these indices. If not, check @@ -579,8 +663,10 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, if (ToPromote.find(Operands) == ToPromote.end()) { if (MaxElements > 0 && ToPromote.size() == MaxElements) { DEBUG(dbgs() << "argpromotion not promoting argument '" - << Arg->getName() << "' because it would require adding more " - << "than " << MaxElements << " arguments to the function.\n"); + << Arg->getName() + << "' because it would require adding more " + << "than " << MaxElements + << " arguments to the function.\n"); // We limit aggregate promotion to only promoting up to a fixed number // of elements of the aggregate. return false; @@ -589,7 +675,8 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, } } - if (Loads.empty()) return true; // No users, this is a dead argument. + if (Loads.empty()) + return true; // No users, this is a dead argument. // Okay, now we know that the argument is only used by load instructions and // it is safe to unconditionally perform all of them. Use alias analysis to @@ -598,7 +685,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, // Because there could be several/many load instructions, remember which // blocks we know to be transparent to the load. 
-  df_iterator_default_set<BasicBlock*, 16> TranspBlocks;
+  df_iterator_default_set<BasicBlock *, 16> TranspBlocks;
 
   for (LoadInst *Load : Loads) {
     // Check to see if the load is invalidated from the start of the block to
@@ -607,7 +694,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
     MemoryLocation Loc = MemoryLocation::get(Load);
 
     if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, MRI_Mod))
-      return false; // Pointer is invalidated!
+      return false; // Pointer is invalidated!
 
     // Now check every path from the entry block to the load for transparency.
     // To do this, we perform a depth first search on the inverse CFG from the
@@ -625,416 +712,347 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca,
   return true;
 }
 
-/// DoPromotion - This method actually performs the promotion of the specified
-/// arguments, and returns the new function. At this point, we know that it's
-/// safe to do so.
-static CallGraphNode *
-DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote,
-            SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG) {
+/// \brief Checks if a type could have padding bytes.
+static bool isDenselyPacked(Type *type, const DataLayout &DL) {
 
-  // Start by computing a new prototype for the function, which is the same as
-  // the old function, but has modified arguments.
-  FunctionType *FTy = F->getFunctionType();
-  std::vector<Type*> Params;
+  // There is no size information, so be conservative.
+  if (!type->isSized())
+    return false;
 
-  typedef std::set<std::pair<Type *, IndicesVector>> ScalarizeTable;
+  // If the alloc size is not equal to the storage size, then there are padding
+  // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128.
+  if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type))
+    return false;
 
-  // ScalarizedElements - If we are promoting a pointer that has elements
-  // accessed out of it, keep track of which elements are accessed so that we
-  // can add one argument for each.
-  //
-  // Arguments that are directly loaded will have a zero element value here, to
-  // handle cases where there are both a direct load and GEP accesses.
-  //
-  std::map<Argument*, ScalarizeTable> ScalarizedElements;
+  if (!isa<CompositeType>(type))
+    return true;
 
-  // OriginalLoads - Keep track of a representative load instruction from the
-  // original function so that we can tell the alias analysis implementation
-  // what the new GEP/Load instructions we are inserting look like.
-  // We need to keep the original loads for each argument and the elements
-  // of the argument that are accessed.
-  std::map<std::pair<Argument*, IndicesVector>, LoadInst*> OriginalLoads;
+  // For homogeneous sequential types, check for padding within members.
+  if (SequentialType *seqTy = dyn_cast<SequentialType>(type))
+    return isDenselyPacked(seqTy->getElementType(), DL);
 
-  // Attribute - Keep track of the parameter attributes for the arguments
-  // that we are *not* promoting. For the ones that we do promote, the parameter
-  // attributes are lost
-  SmallVector<AttributeSet, 8> AttributesVec;
-  const AttributeSet &PAL = F->getAttributes();
+  // Check for padding within and between elements of a struct.
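  // Illustrative examples, assuming a typical 64-bit data layout (these
  // comments are explanatory and not part of the patch):
  //
  //   struct A { int x; int y; };   // densely packed: 8 bytes, no padding
  //   struct B { char c; int x; };  // not densely packed: 3 padding bytes
  //                                 // between 'c' and 'x'
  //   x86_fp80 (long double)        // not densely packed: 80 bits of data
  //                                 // in a 128-bit allocation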
+ StructType *StructTy = cast<StructType>(type); + const StructLayout *Layout = DL.getStructLayout(StructTy); + uint64_t StartPos = 0; + for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) { + Type *ElTy = StructTy->getElementType(i); + if (!isDenselyPacked(ElTy, DL)) + return false; + if (StartPos != Layout->getElementOffsetInBits(i)) + return false; + StartPos += DL.getTypeAllocSizeInBits(ElTy); + } - // Add any return attributes. - if (PAL.hasAttributes(AttributeSet::ReturnIndex)) - AttributesVec.push_back(AttributeSet::get(F->getContext(), - PAL.getRetAttributes())); + return true; +} - // First, determine the new argument list - unsigned ArgIndex = 1; - for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; - ++I, ++ArgIndex) { - if (ByValArgsToTransform.count(&*I)) { - // Simple byval argument? Just add all the struct element types. - Type *AgTy = cast<PointerType>(I->getType())->getElementType(); - StructType *STy = cast<StructType>(AgTy); - Params.insert(Params.end(), STy->element_begin(), STy->element_end()); - ++NumByValArgsPromoted; - } else if (!ArgsToPromote.count(&*I)) { - // Unchanged argument - Params.push_back(I->getType()); - AttributeSet attrs = PAL.getParamAttributes(ArgIndex); - if (attrs.hasAttributes(ArgIndex)) { - AttrBuilder B(attrs, ArgIndex); - AttributesVec. - push_back(AttributeSet::get(F->getContext(), Params.size(), B)); - } - } else if (I->use_empty()) { - // Dead argument (which are always marked as promotable) - ++NumArgumentsDead; - } else { - // Okay, this is being promoted. This means that the only uses are loads - // or GEPs which are only used by loads +/// \brief Checks if the padding bytes of an argument could be accessed. +static bool canPaddingBeAccessed(Argument *arg) { - // In this table, we will track which indices are loaded from the argument - // (where direct loads are tracked as no indices). - ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; - for (User *U : I->users()) { - Instruction *UI = cast<Instruction>(U); - Type *SrcTy; - if (LoadInst *L = dyn_cast<LoadInst>(UI)) - SrcTy = L->getType(); - else - SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType(); - IndicesVector Indices; - Indices.reserve(UI->getNumOperands() - 1); - // Since loads will only have a single operand, and GEPs only a single - // non-index operand, this will record direct loads without any indices, - // and gep+loads with the GEP indices. - for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end(); - II != IE; ++II) - Indices.push_back(cast<ConstantInt>(*II)->getSExtValue()); - // GEPs with a single 0 index can be merged with direct loads - if (Indices.size() == 1 && Indices.front() == 0) - Indices.clear(); - ArgIndices.insert(std::make_pair(SrcTy, Indices)); - LoadInst *OrigLoad; - if (LoadInst *L = dyn_cast<LoadInst>(UI)) - OrigLoad = L; - else - // Take any load, we will use it only to update Alias Analysis - OrigLoad = cast<LoadInst>(UI->user_back()); - OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad; - } + assert(arg->hasByValAttr()); - // Add a parameter to the function for each element passed in. - for (const auto &ArgIndex : ArgIndices) { - // not allowed to dereference ->begin() if size() is 0 - Params.push_back(GetElementPtrInst::getIndexedType( - cast<PointerType>(I->getType()->getScalarType())->getElementType(), - ArgIndex.second)); - assert(Params.back()); - } + // Track all the pointers to the argument to make sure they are not captured. 
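  // For example (hypothetical, not from this patch), the walk below gives up
  // on
  //
  //   static void g(struct B *b) {  // byval argument
  //     saved = b;                  // the argument pointer itself escapes
  //   }
  //
  // because once the pointer is stored somewhere, the padding bytes could
  // later be read through that copy (e.g. with memcpy), so promotion must
  // assume they are observable.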
+  SmallPtrSet<Value *, 16> PtrValues;
+  PtrValues.insert(arg);
 
-  if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty())
-    ++NumArgumentsPromoted;
-  else
-    ++NumAggregatesPromoted;
+  // Track all of the stores.
+  SmallVector<StoreInst *, 16> Stores;
+
+  // Scan through the uses recursively to make sure the pointer is always used
+  // sanely.
+  SmallVector<Value *, 16> WorkList;
+  WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end());
+  while (!WorkList.empty()) {
+    Value *V = WorkList.back();
+    WorkList.pop_back();
+    if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) {
+      if (PtrValues.insert(V).second)
+        WorkList.insert(WorkList.end(), V->user_begin(), V->user_end());
+    } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
+      Stores.push_back(Store);
+    } else if (!isa<LoadInst>(V)) {
+      return true;
     }
   }
 
-  // Add any function attributes.
-  if (PAL.hasAttributes(AttributeSet::FunctionIndex))
-    AttributesVec.push_back(AttributeSet::get(FTy->getContext(),
-                                              PAL.getFnAttributes()));
+  // Check to make sure the pointers aren't captured
+  for (StoreInst *Store : Stores)
+    if (PtrValues.count(Store->getValueOperand()))
+      return true;
 
-  Type *RetTy = FTy->getReturnType();
+  return false;
+}
 
-  // Construct the new function type using the new arguments.
-  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());
+/// promoteArguments - This method checks the specified function to see if
+/// there are any promotable arguments and if it is safe to promote the
+/// function (for example, all callers are direct). If safe to promote some
+/// arguments, it calls the doPromotion method.
+///
+static Function *
+promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
+                 unsigned MaxElements,
+                 Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>>
+                     ReplaceCallSite) {
+  // Make sure that it is local to this module.
+  if (!F->hasLocalLinkage())
+    return nullptr;
 
-  // Create the new function body and insert it into the module.
-  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());
-  NF->copyAttributesFrom(F);
+  // Don't promote arguments for variadic functions. Adding, removing, or
+  // changing non-pack parameters can change the classification of pack
+  // parameters. Frontends encode that classification at the call site in the
+  // IR, while in the callee the classification is determined dynamically based
+  // on the number of registers consumed so far.
+  if (F->isVarArg())
+    return nullptr;
 
-  // Patch the pointer to LLVM function in debug info descriptor.
-  NF->setSubprogram(F->getSubprogram());
-  F->setSubprogram(nullptr);
+  // First check: see if there are any pointer arguments! If not, quick exit.
+  SmallVector<Argument *, 16> PointerArgs;
+  for (Argument &I : F->args())
+    if (I.getType()->isPointerTy())
+      PointerArgs.push_back(&I);
+  if (PointerArgs.empty())
+    return nullptr;
 
-  DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n"
-        << "From: " << *F);
-  
-  // Recompute the parameter attributes list based on the new arguments for
-  // the function.
-  NF->setAttributes(AttributeSet::get(F->getContext(), AttributesVec));
-  AttributesVec.clear();
+  // Second check: make sure that all callers are direct callers. We can't
+  // transform functions that have indirect callers. Also see if the function
+  // is self-recursive.
+  bool isSelfRecursive = false;
+  for (Use &U : F->uses()) {
+    CallSite CS(U.getUser());
+    // Must be a direct call.
+ if (CS.getInstruction() == nullptr || !CS.isCallee(&U)) + return nullptr; - F->getParent()->getFunctionList().insert(F->getIterator(), NF); - NF->takeName(F); + if (CS.getInstruction()->getParent()->getParent() == F) + isSelfRecursive = true; + } - // Get a new callgraph node for NF. - CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF); + const DataLayout &DL = F->getParent()->getDataLayout(); - // Loop over all of the callers of the function, transforming the call sites - // to pass in the loaded pointers. - // - SmallVector<Value*, 16> Args; - while (!F->use_empty()) { - CallSite CS(F->user_back()); - assert(CS.getCalledFunction() == F); - Instruction *Call = CS.getInstruction(); - const AttributeSet &CallPAL = CS.getAttributes(); + AAResults &AAR = AARGetter(*F); - // Add any return attributes. - if (CallPAL.hasAttributes(AttributeSet::ReturnIndex)) - AttributesVec.push_back(AttributeSet::get(F->getContext(), - CallPAL.getRetAttributes())); + // Check to see which arguments are promotable. If an argument is promotable, + // add it to ArgsToPromote. + SmallPtrSet<Argument *, 8> ArgsToPromote; + SmallPtrSet<Argument *, 8> ByValArgsToTransform; + for (Argument *PtrArg : PointerArgs) { + Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType(); - // Loop over the operands, inserting GEP and loads in the caller as - // appropriate. - CallSite::arg_iterator AI = CS.arg_begin(); - ArgIndex = 1; - for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); - I != E; ++I, ++AI, ++ArgIndex) - if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { - Args.push_back(*AI); // Unmodified argument + // Replace sret attribute with noalias. This reduces register pressure by + // avoiding a register copy. + if (PtrArg->hasStructRetAttr()) { + unsigned ArgNo = PtrArg->getArgNo(); + F->removeParamAttr(ArgNo, Attribute::StructRet); + F->addParamAttr(ArgNo, Attribute::NoAlias); + for (Use &U : F->uses()) { + CallSite CS(U.getUser()); + CS.removeParamAttr(ArgNo, Attribute::StructRet); + CS.addParamAttr(ArgNo, Attribute::NoAlias); + } + } - if (CallPAL.hasAttributes(ArgIndex)) { - AttrBuilder B(CallPAL, ArgIndex); - AttributesVec. - push_back(AttributeSet::get(F->getContext(), Args.size(), B)); - } - } else if (ByValArgsToTransform.count(&*I)) { - // Emit a GEP and load for each element of the struct. - Type *AgTy = cast<PointerType>(I->getType())->getElementType(); - StructType *STy = cast<StructType>(AgTy); - Value *Idxs[2] = { - ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr }; - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); - Value *Idx = GetElementPtrInst::Create( - STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i), Call); - // TODO: Tell AA about the new values? - Args.push_back(new LoadInst(Idx, Idx->getName()+".val", Call)); + // If this is a byval argument, and if the aggregate type is small, just + // pass the elements, which is always safe, if the passed value is densely + // packed or if we can prove the padding bytes are never accessed. This does + // not apply to inalloca. 
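    // Sketch of that byval case (hypothetical C-level view, not code from
    // this patch): a small, densely packed aggregate passed by value,
    //
    //   static long h(struct P s);      // struct P { int a; int b; };
    //
    // can instead be passed as its elements,
    //
    //   static long h(int s_a, int s_b);
    //
    // after which SROA can take apart the alloca that the callee re-creates.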
+    bool isSafeToPromote =
+        PtrArg->hasByValAttr() &&
+        (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg));
+    if (isSafeToPromote) {
+      if (StructType *STy = dyn_cast<StructType>(AgTy)) {
+        if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
+          DEBUG(dbgs() << "argpromotion disabled promoting argument '"
+                       << PtrArg->getName()
+                       << "' because it would require adding more"
+                       << " than " << MaxElements
+                       << " arguments to the function.\n");
+          continue;
         }
-      } else if (!I->use_empty()) {
-        // Non-dead argument: insert GEPs and loads as appropriate.
-        ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
-        // Store the Value* version of the indices in here, but declare it now
-        // for reuse.
-        std::vector<Value*> Ops;
-        for (const auto &ArgIndex : ArgIndices) {
-          Value *V = *AI;
-          LoadInst *OrigLoad =
-            OriginalLoads[std::make_pair(&*I, ArgIndex.second)];
-          if (!ArgIndex.second.empty()) {
-            Ops.reserve(ArgIndex.second.size());
-            Type *ElTy = V->getType();
-            for (unsigned long II : ArgIndex.second) {
-              // Use i32 to index structs, and i64 for others (pointers/arrays).
-              // This satisfies GEP constraints.
-              Type *IdxTy = (ElTy->isStructTy() ?
-                    Type::getInt32Ty(F->getContext()) :
-                    Type::getInt64Ty(F->getContext()));
-              Ops.push_back(ConstantInt::get(IdxTy, II));
-              // Keep track of the type we're currently indexing.
-              if (auto *ElPTy = dyn_cast<PointerType>(ElTy))
-                ElTy = ElPTy->getElementType();
-              else
-                ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II);
-            }
-            // And create a GEP to extract those indices.
-            V = GetElementPtrInst::Create(ArgIndex.first, V, Ops,
-                                          V->getName() + ".idx", Call);
-            Ops.clear();
+
+        // If all the elements are single-value types, we can promote it.
+        bool AllSimple = true;
+        for (const auto *EltTy : STy->elements()) {
+          if (!EltTy->isSingleValueType()) {
+            AllSimple = false;
+            break;
           }
-          // Since we're replacing a load make sure we take the alignment
-          // of the previous load.
-          LoadInst *newLoad = new LoadInst(V, V->getName()+".val", Call);
-          newLoad->setAlignment(OrigLoad->getAlignment());
-          // Transfer the AA info too.
-          AAMDNodes AAInfo;
-          OrigLoad->getAAMetadata(AAInfo);
-          newLoad->setAAMetadata(AAInfo);
+        }
 
-          Args.push_back(newLoad);
+        // Safe to transform, don't even bother trying to "promote" it.
+        // Passing the elements as a scalar will allow sroa to hack on
+        // the new alloca we introduce.
+        if (AllSimple) {
+          ByValArgsToTransform.insert(PtrArg);
+          continue;
         }
       }
+    }
 
-    // Push any varargs arguments on the list.
-    for (; AI != CS.arg_end(); ++AI, ++ArgIndex) {
-      Args.push_back(*AI);
-      if (CallPAL.hasAttributes(ArgIndex)) {
-        AttrBuilder B(CallPAL, ArgIndex);
-        AttributesVec.
-          push_back(AttributeSet::get(F->getContext(), Args.size(), B));
+    // If the argument is a recursive type and we're in a recursive
+    // function, we could end up infinitely peeling the function argument.
+    if (isSelfRecursive) {
+      if (StructType *STy = dyn_cast<StructType>(AgTy)) {
+        bool RecursiveType = false;
+        for (const auto *EltTy : STy->elements()) {
+          if (EltTy == PtrArg->getType()) {
+            RecursiveType = true;
+            break;
+          }
+        }
+        if (RecursiveType)
+          continue;
       }
     }
 
-    // Add any function attributes.
-    if (CallPAL.hasAttributes(AttributeSet::FunctionIndex))
-      AttributesVec.push_back(AttributeSet::get(Call->getContext(),
-                                                CallPAL.getFnAttributes()));
+    // Otherwise, see if we can promote the pointer to its value.
+ if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR, + MaxElements)) + ArgsToPromote.insert(PtrArg); + } - SmallVector<OperandBundleDef, 1> OpBundles; - CS.getOperandBundlesAsDefs(OpBundles); + // No promotable pointer arguments. + if (ArgsToPromote.empty() && ByValArgsToTransform.empty()) + return nullptr; - Instruction *New; - if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { - New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, OpBundles, "", Call); - cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); - cast<InvokeInst>(New)->setAttributes(AttributeSet::get(II->getContext(), - AttributesVec)); - } else { - New = CallInst::Create(NF, Args, OpBundles, "", Call); - cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); - cast<CallInst>(New)->setAttributes(AttributeSet::get(New->getContext(), - AttributesVec)); - cast<CallInst>(New)->setTailCallKind( - cast<CallInst>(Call)->getTailCallKind()); - } - New->setDebugLoc(Call->getDebugLoc()); - Args.clear(); - AttributesVec.clear(); + return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite); +} - // Update the callgraph to know that the callsite has been transformed. - CallGraphNode *CalleeNode = CG[Call->getParent()->getParent()]; - CalleeNode->replaceCallEdge(CS, CallSite(New), NF_CGN); +PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, + CGSCCAnalysisManager &AM, + LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + bool Changed = false, LocalChange; - if (!Call->use_empty()) { - Call->replaceAllUsesWith(New); - New->takeName(Call); + // Iterate until we stop promoting from this SCC. + do { + LocalChange = false; + + for (LazyCallGraph::Node &N : C) { + Function &OldF = N.getFunction(); + + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); + // FIXME: This lambda must only be used with this function. We should + // skip the lambda and just get the AA results directly. + auto AARGetter = [&](Function &F) -> AAResults & { + assert(&F == &OldF && "Called with an unexpected function!"); + return FAM.getResult<AAManager>(F); + }; + + Function *NewF = promoteArguments(&OldF, AARGetter, 3u, None); + if (!NewF) + continue; + LocalChange = true; + + // Directly substitute the functions in the call graph. Note that this + // requires the old function to be completely dead and completely + // replaced by the new function. It does no call graph updates, it merely + // swaps out the particular function mapped to a particular node in the + // graph. + C.getOuterRefSCC().replaceNodeFunction(N, *NewF); + OldF.eraseFromParent(); } - // Finally, remove the old call from the program, reducing the use-count of - // F. - Call->eraseFromParent(); - } + Changed |= LocalChange; + } while (LocalChange); - // Since we have now created the new function, splice the body of the old - // function right into the new function, leaving the old rotting hulk of the - // function empty. - NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + if (!Changed) + return PreservedAnalyses::all(); - // Loop over the argument list, transferring uses of the old arguments over to - // the new arguments, also transferring over the names as well. - // - for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), - I2 = NF->arg_begin(); I != E; ++I) { - if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { - // If this is an unmodified argument, move the name and users over to the - // new version. 
- I->replaceAllUsesWith(&*I2); - I2->takeName(&*I); - ++I2; - continue; - } - - if (ByValArgsToTransform.count(&*I)) { - // In the callee, we create an alloca, and store each of the new incoming - // arguments into the alloca. - Instruction *InsertPt = &NF->begin()->front(); + return PreservedAnalyses::none(); +} - // Just add all the struct element types. - Type *AgTy = cast<PointerType>(I->getType())->getElementType(); - Value *TheAlloca = new AllocaInst(AgTy, nullptr, "", InsertPt); - StructType *STy = cast<StructType>(AgTy); - Value *Idxs[2] = { - ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr }; +namespace { +/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. +/// +struct ArgPromotion : public CallGraphSCCPass { + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + getAAResultsAnalysisUsage(AU); + CallGraphSCCPass::getAnalysisUsage(AU); + } - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); - Value *Idx = GetElementPtrInst::Create( - AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i), - InsertPt); - I2->setName(I->getName()+"."+Twine(i)); - new StoreInst(&*I2++, Idx, InsertPt); - } + bool runOnSCC(CallGraphSCC &SCC) override; + static char ID; // Pass identification, replacement for typeid + explicit ArgPromotion(unsigned MaxElements = 3) + : CallGraphSCCPass(ID), MaxElements(MaxElements) { + initializeArgPromotionPass(*PassRegistry::getPassRegistry()); + } - // Anything that used the arg should now use the alloca. - I->replaceAllUsesWith(TheAlloca); - TheAlloca->takeName(&*I); +private: + using llvm::Pass::doInitialization; + bool doInitialization(CallGraph &CG) override; + /// The maximum number of elements to expand, or 0 for unlimited. + unsigned MaxElements; +}; +} - // If the alloca is used in a call, we must clear the tail flag since - // the callee now uses an alloca from the caller. - for (User *U : TheAlloca->users()) { - CallInst *Call = dyn_cast<CallInst>(U); - if (!Call) - continue; - Call->setTailCall(false); - } - continue; - } +char ArgPromotion::ID = 0; +INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", + "Promote 'by reference' arguments to scalars", false, + false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(ArgPromotion, "argpromotion", + "Promote 'by reference' arguments to scalars", false, false) - if (I->use_empty()) - continue; +Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) { + return new ArgPromotion(MaxElements); +} - // Otherwise, if we promoted this argument, then all users are load - // instructions (or GEPs with only load users), and all loads should be - // using the new argument that we added. 
-    ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
+bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
+  if (skipSCC(SCC))
+    return false;
 
-    while (!I->use_empty()) {
-      if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) {
-        assert(ArgIndices.begin()->second.empty() &&
-               "Load element should sort to front!");
-        I2->setName(I->getName()+".val");
-        LI->replaceAllUsesWith(&*I2);
-        LI->eraseFromParent();
-        DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName()
-              << "' in function '" << F->getName() << "'\n");
-      } else {
-        GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back());
-        IndicesVector Operands;
-        Operands.reserve(GEP->getNumIndices());
-        for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end();
-             II != IE; ++II)
-          Operands.push_back(cast<ConstantInt>(*II)->getSExtValue());
+  // Get the callgraph information that we need to update to reflect our
+  // changes.
+  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
 
-        // GEPs with a single 0 index can be merged with direct loads
-        if (Operands.size() == 1 && Operands.front() == 0)
-          Operands.clear();
+  LegacyAARGetter AARGetter(*this);
 
-        Function::arg_iterator TheArg = I2;
-        for (ScalarizeTable::iterator It = ArgIndices.begin();
-             It->second != Operands; ++It, ++TheArg) {
-          assert(It != ArgIndices.end() && "GEP not handled??");
-        }
+  bool Changed = false, LocalChange;
 
-        std::string NewName = I->getName();
-        for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
-          NewName += "." + utostr(Operands[i]);
-        }
-        NewName += ".val";
-        TheArg->setName(NewName);
+  // Iterate until we stop promoting from this SCC.
+  do {
+    LocalChange = false;
+    // Attempt to promote arguments from all functions in this SCC.
+    for (CallGraphNode *OldNode : SCC) {
+      Function *OldF = OldNode->getFunction();
+      if (!OldF)
+        continue;
+
+      auto ReplaceCallSite = [&](CallSite OldCS, CallSite NewCS) {
+        Function *Caller = OldCS.getInstruction()->getParent()->getParent();
+        CallGraphNode *NewCalleeNode =
+            CG.getOrInsertFunction(NewCS.getCalledFunction());
+        CallGraphNode *CallerNode = CG[Caller];
+        CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode);
+      };
+
+      if (Function *NewF = promoteArguments(OldF, AARGetter, MaxElements,
+                                            {ReplaceCallSite})) {
+        LocalChange = true;
 
-        DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName()
-              << "' of function '" << NF->getName() << "'\n");
+        // Update the call graph for the newly promoted function.
+        CallGraphNode *NewNode = CG.getOrInsertFunction(NewF);
+        NewNode->stealCalledFunctionsFrom(OldNode);
+        if (OldNode->getNumReferences() == 0)
+          delete CG.removeFunctionFromModule(OldNode);
+        else
+          OldF->setLinkage(Function::ExternalLinkage);
 
-        // All of the uses must be load instructions.  Replace them all with
-        // the argument specified by ArgNo.
-        while (!GEP->use_empty()) {
-          LoadInst *L = cast<LoadInst>(GEP->user_back());
-          L->replaceAllUsesWith(&*TheArg);
-          L->eraseFromParent();
-        }
-        GEP->eraseFromParent();
+        // And update the SCC we're iterating as well.
+        SCC.ReplaceNode(OldNode, NewNode);
       }
     }
+    // Remember that we changed something.
+    Changed |= LocalChange;
  } while (LocalChange);
 
-    // Increment I2 past all of the arguments added for this promoted pointer.
-    std::advance(I2, ArgIndices.size());
-  }
-
-  NF_CGN->stealCalledFunctionsFrom(CG[F]);
-
-  // Now that the old function is dead, delete it. If there is a dangling
-  // reference to the CallgraphNode, just leave the dead function around for
-  // someone else to nuke.
-  CallGraphNode *CGN = CG[F];
-  if (CGN->getNumReferences() == 0)
-    delete CG.removeFunctionFromModule(CGN);
-  else
-    F->setLinkage(Function::ExternalLinkage);
-
-  return NF_CGN;
+  return Changed;
 }
 
 bool ArgPromotion::doInitialization(CallGraph &CG) {
diff --git a/contrib/llvm/lib/Transforms/IPO/ConstantMerge.cpp b/contrib/llvm/lib/Transforms/IPO/ConstantMerge.cpp
index d75ed20..62b5a9c 100644
--- a/contrib/llvm/lib/Transforms/IPO/ConstantMerge.cpp
+++ b/contrib/llvm/lib/Transforms/IPO/ConstantMerge.cpp
@@ -60,6 +60,23 @@ static bool IsBetterCanonical(const GlobalVariable &A,
   return A.hasGlobalUnnamedAddr();
 }
 
+static bool hasMetadataOtherThanDebugLoc(const GlobalVariable *GV) {
+  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
+  GV->getAllMetadata(MDs);
+  for (const auto &V : MDs)
+    if (V.first != LLVMContext::MD_dbg)
+      return true;
+  return false;
+}
+
+static void copyDebugLocMetadata(const GlobalVariable *From,
+                                 GlobalVariable *To) {
+  SmallVector<DIGlobalVariableExpression *, 1> MDs;
+  From->getDebugInfo(MDs);
+  for (auto MD : MDs)
+    To->addDebugInfo(MD);
+}
+
 static unsigned getAlignment(GlobalVariable *GV) {
   unsigned Align = GV->getAlignment();
   if (Align)
@@ -113,6 +130,10 @@ static bool mergeConstants(Module &M) {
       if (GV->isWeakForLinker())
         continue;
 
+      // Don't touch globals with metadata other than !dbg.
+      if (hasMetadataOtherThanDebugLoc(GV))
+        continue;
+
       Constant *Init = GV->getInitializer();
 
       // Check to see if the initializer is already known.
@@ -155,6 +176,9 @@ static bool mergeConstants(Module &M) {
       if (!Slot->hasGlobalUnnamedAddr() && !GV->hasGlobalUnnamedAddr())
         continue;
 
+      if (hasMetadataOtherThanDebugLoc(GV))
+        continue;
+
       if (!GV->hasGlobalUnnamedAddr())
         Slot->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
 
@@ -178,6 +202,8 @@ static bool mergeConstants(Module &M) {
                      getAlignment(Replacements[i].second)));
     }
 
+    copyDebugLocMetadata(Replacements[i].first, Replacements[i].second);
+
     // Eliminate any uses of the dead global.
     Replacements[i].first->replaceAllUsesWith(Replacements[i].second);
 
diff --git a/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
index ba2e60d..d94aa5d 100644
--- a/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
+++ b/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
@@ -95,11 +95,25 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
     }
   }
 
+  NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions");
+  if (CfiFunctionsMD) {
+    for (auto Func : CfiFunctionsMD->operands()) {
+      assert(Func->getNumOperands() >= 2);
+      for (unsigned I = 2; I < Func->getNumOperands(); ++I)
+        if (ConstantInt *TypeId =
+                extractNumericTypeId(cast<MDNode>(Func->getOperand(I).get())))
+          TypeIds.insert(TypeId->getZExtValue());
+    }
+  }
+
   LLVMContext &Ctx = M.getContext();
   Constant *C = M.getOrInsertFunction(
       "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),
-      Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx), nullptr);
+      Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));
   Function *F = dyn_cast<Function>(C);
+  // Take over the existing function. The frontend emits a weak stub so that
+  // the linker knows about the symbol; this pass replaces the function body.
+ F->deleteBody(); F->setAlignment(4096); auto args = F->arg_begin(); Value &CallSiteTypeId = *(args++); @@ -117,7 +131,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) { IRBuilder<> IRBFail(TrapBB); Constant *CFICheckFailFn = M.getOrInsertFunction( "__cfi_check_fail", Type::getVoidTy(Ctx), Type::getInt8PtrTy(Ctx), - Type::getInt8PtrTy(Ctx), nullptr); + Type::getInt8PtrTy(Ctx)); IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr}); IRBFail.CreateBr(ExitBB); diff --git a/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index 1a5ed46..8e26849 100644 --- a/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -166,41 +166,40 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { Args.assign(CS.arg_begin(), CS.arg_begin() + NumArgs); // Drop any attributes that were on the vararg arguments. - AttributeSet PAL = CS.getAttributes(); - if (!PAL.isEmpty() && PAL.getSlotIndex(PAL.getNumSlots() - 1) > NumArgs) { - SmallVector<AttributeSet, 8> AttributesVec; - for (unsigned i = 0; PAL.getSlotIndex(i) <= NumArgs; ++i) - AttributesVec.push_back(PAL.getSlotAttributes(i)); - if (PAL.hasAttributes(AttributeSet::FunctionIndex)) - AttributesVec.push_back(AttributeSet::get(Fn.getContext(), - PAL.getFnAttributes())); - PAL = AttributeSet::get(Fn.getContext(), AttributesVec); + AttributeList PAL = CS.getAttributes(); + if (!PAL.isEmpty()) { + SmallVector<AttributeSet, 8> ArgAttrs; + for (unsigned ArgNo = 0; ArgNo < NumArgs; ++ArgNo) + ArgAttrs.push_back(PAL.getParamAttributes(ArgNo)); + PAL = AttributeList::get(Fn.getContext(), PAL.getFnAttributes(), + PAL.getRetAttributes(), ArgAttrs); } SmallVector<OperandBundleDef, 1> OpBundles; CS.getOperandBundlesAsDefs(OpBundles); - Instruction *New; + CallSite NewCS; if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { - New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, OpBundles, "", Call); - cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); - cast<InvokeInst>(New)->setAttributes(PAL); + NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), + Args, OpBundles, "", Call); } else { - New = CallInst::Create(NF, Args, OpBundles, "", Call); - cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); - cast<CallInst>(New)->setAttributes(PAL); - cast<CallInst>(New)->setTailCallKind( - cast<CallInst>(Call)->getTailCallKind()); + NewCS = CallInst::Create(NF, Args, OpBundles, "", Call); + cast<CallInst>(NewCS.getInstruction()) + ->setTailCallKind(cast<CallInst>(Call)->getTailCallKind()); } - New->setDebugLoc(Call->getDebugLoc()); + NewCS.setCallingConv(CS.getCallingConv()); + NewCS.setAttributes(PAL); + NewCS->setDebugLoc(Call->getDebugLoc()); + uint64_t W; + if (Call->extractProfTotalWeight(W)) + NewCS->setProfWeight(W); Args.clear(); if (!Call->use_empty()) - Call->replaceAllUsesWith(New); + Call->replaceAllUsesWith(NewCS.getInstruction()); - New->takeName(Call); + NewCS->takeName(Call); // Finally, remove the old call from the program, reducing the use-count of // F. @@ -681,8 +680,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { bool HasLiveReturnedArg = false; // Set up to build a new list of parameter attributes. 
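    // Illustrative invariant for the rebuild below (the example signature is
    // hypothetical): if void f(i32 %dead, i8* nonnull %live) loses %dead,
    // then Params must become {i8*} and ArgAttrVec {nonnull}, so that
    // attribute slot i always describes Params[i]; the loop that follows
    // pushes both vectors in lockstep.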
- SmallVector<AttributeSet, 8> AttributesVec; - const AttributeSet &PAL = F->getAttributes(); + SmallVector<AttributeSet, 8> ArgAttrVec; + const AttributeList &PAL = F->getAttributes(); // Remember which arguments are still alive. SmallVector<bool, 10> ArgAlive(FTy->getNumParams(), false); @@ -696,16 +695,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { if (LiveValues.erase(Arg)) { Params.push_back(I->getType()); ArgAlive[i] = true; - - // Get the original parameter attributes (skipping the first one, that is - // for the return value. - if (PAL.hasAttributes(i + 1)) { - AttrBuilder B(PAL, i + 1); - if (B.contains(Attribute::Returned)) - HasLiveReturnedArg = true; - AttributesVec. - push_back(AttributeSet::get(F->getContext(), Params.size(), B)); - } + ArgAttrVec.push_back(PAL.getParamAttributes(i)); + HasLiveReturnedArg |= PAL.hasParamAttribute(i, Attribute::Returned); } else { ++NumArgumentsEliminated; DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing argument " << i @@ -779,30 +770,24 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { assert(NRetTy && "No new return type found?"); // The existing function return attributes. - AttributeSet RAttrs = PAL.getRetAttributes(); + AttrBuilder RAttrs(PAL.getRetAttributes()); // Remove any incompatible attributes, but only if we removed all return // values. Otherwise, ensure that we don't have any conflicting attributes // here. Currently, this should not be possible, but special handling might be // required when new return value attributes are added. if (NRetTy->isVoidTy()) - RAttrs = RAttrs.removeAttributes(NRetTy->getContext(), - AttributeSet::ReturnIndex, - AttributeFuncs::typeIncompatible(NRetTy)); + RAttrs.remove(AttributeFuncs::typeIncompatible(NRetTy)); else - assert(!AttrBuilder(RAttrs, AttributeSet::ReturnIndex). - overlaps(AttributeFuncs::typeIncompatible(NRetTy)) && + assert(!RAttrs.overlaps(AttributeFuncs::typeIncompatible(NRetTy)) && "Return attributes no longer compatible?"); - if (RAttrs.hasAttributes(AttributeSet::ReturnIndex)) - AttributesVec.push_back(AttributeSet::get(NRetTy->getContext(), RAttrs)); - - if (PAL.hasAttributes(AttributeSet::FunctionIndex)) - AttributesVec.push_back(AttributeSet::get(F->getContext(), - PAL.getFnAttributes())); + AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs); // Reconstruct the AttributesList based on the vector we constructed. - AttributeSet NewPAL = AttributeSet::get(F->getContext(), AttributesVec); + assert(ArgAttrVec.size() == Params.size()); + AttributeList NewPAL = AttributeList::get( + F->getContext(), PAL.getFnAttributes(), RetAttrs, ArgAttrVec); // Create the new function type based on the recomputed parameters. FunctionType *NFTy = FunctionType::get(NRetTy, Params, FTy->isVarArg()); @@ -829,18 +814,14 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { CallSite CS(F->user_back()); Instruction *Call = CS.getInstruction(); - AttributesVec.clear(); - const AttributeSet &CallPAL = CS.getAttributes(); - - // The call return attributes. - AttributeSet RAttrs = CallPAL.getRetAttributes(); + ArgAttrVec.clear(); + const AttributeList &CallPAL = CS.getAttributes(); - // Adjust in case the function was changed to return void. 
- RAttrs = RAttrs.removeAttributes(NRetTy->getContext(), - AttributeSet::ReturnIndex, - AttributeFuncs::typeIncompatible(NF->getReturnType())); - if (RAttrs.hasAttributes(AttributeSet::ReturnIndex)) - AttributesVec.push_back(AttributeSet::get(NF->getContext(), RAttrs)); + // Adjust the call return attributes in case the function was changed to + // return void. + AttrBuilder RAttrs(CallPAL.getRetAttributes()); + RAttrs.remove(AttributeFuncs::typeIncompatible(NRetTy)); + AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs); // Declare these outside of the loops, so we can reuse them for the second // loop, which loops the varargs. @@ -852,57 +833,55 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { if (ArgAlive[i]) { Args.push_back(*I); // Get original parameter attributes, but skip return attributes. - if (CallPAL.hasAttributes(i + 1)) { - AttrBuilder B(CallPAL, i + 1); + AttributeSet Attrs = CallPAL.getParamAttributes(i); + if (NRetTy != RetTy && Attrs.hasAttribute(Attribute::Returned)) { // If the return type has changed, then get rid of 'returned' on the // call site. The alternative is to make all 'returned' attributes on // call sites keep the return value alive just like 'returned' - // attributes on function declaration but it's less clearly a win - // and this is not an expected case anyway - if (NRetTy != RetTy && B.contains(Attribute::Returned)) - B.removeAttribute(Attribute::Returned); - AttributesVec. - push_back(AttributeSet::get(F->getContext(), Args.size(), B)); + // attributes on function declaration but it's less clearly a win and + // this is not an expected case anyway + ArgAttrVec.push_back(AttributeSet::get( + F->getContext(), + AttrBuilder(Attrs).removeAttribute(Attribute::Returned))); + } else { + // Otherwise, use the original attributes. + ArgAttrVec.push_back(Attrs); } } // Push any varargs arguments on the list. Don't forget their attributes. for (CallSite::arg_iterator E = CS.arg_end(); I != E; ++I, ++i) { Args.push_back(*I); - if (CallPAL.hasAttributes(i + 1)) { - AttrBuilder B(CallPAL, i + 1); - AttributesVec. - push_back(AttributeSet::get(F->getContext(), Args.size(), B)); - } + ArgAttrVec.push_back(CallPAL.getParamAttributes(i)); } - if (CallPAL.hasAttributes(AttributeSet::FunctionIndex)) - AttributesVec.push_back(AttributeSet::get(Call->getContext(), - CallPAL.getFnAttributes())); - // Reconstruct the AttributesList based on the vector we constructed. 
- AttributeSet NewCallPAL = AttributeSet::get(F->getContext(), AttributesVec); + assert(ArgAttrVec.size() == Args.size()); + AttributeList NewCallPAL = AttributeList::get( + F->getContext(), CallPAL.getFnAttributes(), RetAttrs, ArgAttrVec); SmallVector<OperandBundleDef, 1> OpBundles; CS.getOperandBundlesAsDefs(OpBundles); - Instruction *New; + CallSite NewCS; if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { - New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, OpBundles, "", Call->getParent()); - cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); - cast<InvokeInst>(New)->setAttributes(NewCallPAL); + NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), + Args, OpBundles, "", Call->getParent()); } else { - New = CallInst::Create(NF, Args, OpBundles, "", Call); - cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); - cast<CallInst>(New)->setAttributes(NewCallPAL); - cast<CallInst>(New)->setTailCallKind( - cast<CallInst>(Call)->getTailCallKind()); + NewCS = CallInst::Create(NF, Args, OpBundles, "", Call); + cast<CallInst>(NewCS.getInstruction()) + ->setTailCallKind(cast<CallInst>(Call)->getTailCallKind()); } - New->setDebugLoc(Call->getDebugLoc()); - + NewCS.setCallingConv(CS.getCallingConv()); + NewCS.setAttributes(NewCallPAL); + NewCS->setDebugLoc(Call->getDebugLoc()); + uint64_t W; + if (Call->extractProfTotalWeight(W)) + NewCS->setProfWeight(W); Args.clear(); + ArgAttrVec.clear(); + Instruction *New = NewCS.getInstruction(); if (!Call->use_empty()) { if (New->getType() == Call->getType()) { // Return type not changed? Just replace users then. diff --git a/contrib/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp b/contrib/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp index 98c4b17..ecff88c 100644 --- a/contrib/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp +++ b/contrib/llvm/lib/Transforms/IPO/ElimAvailExtern.cpp @@ -17,9 +17,9 @@ #include "llvm/ADT/Statistic.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Module.h" +#include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/GlobalStatus.h" -#include "llvm/Pass.h" using namespace llvm; #define DEBUG_TYPE "elim-avail-extern" diff --git a/contrib/llvm/lib/Transforms/IPO/ExtractGV.cpp b/contrib/llvm/lib/Transforms/IPO/ExtractGV.cpp index 479fd18..d1147f7 100644 --- a/contrib/llvm/lib/Transforms/IPO/ExtractGV.cpp +++ b/contrib/llvm/lib/Transforms/IPO/ExtractGV.cpp @@ -11,13 +11,13 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/SetVector.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" #include <algorithm> using namespace llvm; @@ -53,18 +53,18 @@ static void makeVisible(GlobalValue &GV, bool Delete) { } namespace { - /// @brief A pass to extract specific functions and their dependencies. + /// @brief A pass to extract specific global values and their dependencies. class GVExtractorPass : public ModulePass { SetVector<GlobalValue *> Named; bool deleteStuff; public: static char ID; // Pass identification, replacement for typeid - /// FunctionExtractorPass - If deleteFn is true, this pass deletes as the - /// specified function. Otherwise, it deletes as much of the module as - /// possible, except for the function specified. 
- /// - explicit GVExtractorPass(std::vector<GlobalValue*>& GVs, bool deleteS = true) + /// If deleteS is true, this pass deletes the specified global values. + /// Otherwise, it deletes as much of the module as possible, except for the + /// global values specified. + explicit GVExtractorPass(std::vector<GlobalValue*> &GVs, + bool deleteS = true) : ModulePass(ID), Named(GVs.begin(), GVs.end()), deleteStuff(deleteS) {} bool runOnModule(Module &M) override { diff --git a/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 402a665..813a4b6 100644 --- a/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -14,7 +14,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/FunctionAttrs.h" -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -34,7 +33,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/IPO.h" using namespace llvm; #define DEBUG_TYPE "functionattrs" @@ -49,31 +48,35 @@ STATISTIC(NumNoAlias, "Number of function returns marked noalias"); STATISTIC(NumNonNullReturn, "Number of function returns marked nonnull"); STATISTIC(NumNoRecurse, "Number of functions marked as norecurse"); -namespace { -typedef SmallSetVector<Function *, 8> SCCNodeSet; -} +// FIXME: This is disabled by default to avoid exposing security vulnerabilities +// in C/C++ code compiled by clang: +// http://lists.llvm.org/pipermail/cfe-dev/2017-January/052066.html +static cl::opt<bool> EnableNonnullArgPropagation( + "enable-nonnull-arg-prop", cl::Hidden, + cl::desc("Try to propagate nonnull argument attributes from callsites to " + "caller functions.")); namespace { -/// The three kinds of memory access relevant to 'readonly' and -/// 'readnone' attributes. -enum MemoryAccessKind { - MAK_ReadNone = 0, - MAK_ReadOnly = 1, - MAK_MayWrite = 2 -}; +typedef SmallSetVector<Function *, 8> SCCNodeSet; } -static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR, +/// Returns the memory access attribute for function F using AAR for AA results, +/// where SCCNodes is the current SCC. +/// +/// If ThisBody is true, this function may examine the function body and will +/// return a result pertaining to this copy of the function. If it is false, the +/// result will be based only on AA results for the function declaration; it +/// will be assumed that some other (perhaps less optimized) version of the +/// function may be selected at link time. +static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, + AAResults &AAR, const SCCNodeSet &SCCNodes) { FunctionModRefBehavior MRB = AAR.getModRefBehavior(&F); if (MRB == FMRB_DoesNotAccessMemory) // Already perfect! return MAK_ReadNone; - // Non-exact function definitions may not be selected at link time, and an - // alternative version that writes to memory may be selected. See the comment - // on GlobalValue::isDefinitionExact for more details. - if (!F.hasExactDefinition()) { + if (!ThisBody) { if (AliasAnalysis::onlyReadsMemory(MRB)) return MAK_ReadOnly; @@ -172,9 +175,14 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR, return ReadsMemory ? 
MAK_ReadOnly : MAK_ReadNone; } +MemoryAccessKind llvm::computeFunctionBodyMemoryAccess(Function &F, + AAResults &AAR) { + return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}); +} + /// Deduce readonly/readnone attributes for the SCC. template <typename AARGetterT> -static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) { +static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { // Check if any of the functions in the SCC read or write memory. If they // write memory then they can't be marked readnone or readonly. bool ReadsMemory = false; @@ -182,7 +190,11 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) { // Call the callable parameter to look up AA results for this function. AAResults &AAR = AARGetter(*F); - switch (checkFunctionMemoryAccess(*F, AAR, SCCNodes)) { + // Non-exact function definitions may not be selected at link time, and an + // alternative version that writes to memory may be selected. See the + // comment on GlobalValue::isDefinitionExact for more details. + switch (checkFunctionMemoryAccess(*F, F->hasExactDefinition(), + AAR, SCCNodes)) { case MAK_MayWrite: return false; case MAK_ReadOnly: @@ -209,15 +221,11 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT AARGetter) { MadeChange = true; // Clear out any existing attributes. - AttrBuilder B; - B.addAttribute(Attribute::ReadOnly).addAttribute(Attribute::ReadNone); - F->removeAttributes( - AttributeSet::FunctionIndex, - AttributeSet::get(F->getContext(), AttributeSet::FunctionIndex, B)); + F->removeFnAttr(Attribute::ReadOnly); + F->removeFnAttr(Attribute::ReadNone); // Add in the new attribute. - F->addAttribute(AttributeSet::FunctionIndex, - ReadsMemory ? Attribute::ReadOnly : Attribute::ReadNone); + F->addFnAttr(ReadsMemory ? Attribute::ReadOnly : Attribute::ReadNone); if (ReadsMemory) ++NumReadOnly; @@ -482,9 +490,6 @@ determinePointerReadAttrs(Argument *A, static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) { bool Changed = false; - AttrBuilder B; - B.addAttribute(Attribute::Returned); - // Check each function in turn, determining if an argument is always returned. for (Function *F : SCCNodes) { // We can infer and propagate function attributes only when we know that the @@ -522,7 +527,7 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) { if (Value *RetArg = FindRetArg()) { auto *A = cast<Argument>(RetArg); - A->addAttr(AttributeSet::get(F->getContext(), A->getArgNo() + 1, B)); + A->addAttr(Attribute::Returned); ++NumReturned; Changed = true; } @@ -531,15 +536,55 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) { return Changed; } +/// If a callsite has arguments that are also arguments to the parent function, +/// try to propagate attributes from the callsite's arguments to the parent's +/// arguments. This may be important because inlining can cause information loss +/// when attribute knowledge disappears with the inlined call. +static bool addArgumentAttrsFromCallsites(Function &F) { + if (!EnableNonnullArgPropagation) + return false; + + bool Changed = false; + + // For an argument attribute to transfer from a callsite to the parent, the + // call must be guaranteed to execute every time the parent is called. + // Conservatively, just check for calls in the entry block that are guaranteed + // to execute. 
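  // Example of what this enables (illustrative only):
  //
  //   void callee(int *p) __attribute__((nonnull));
  //   void caller(int *q) {
  //     callee(q);  // in the entry block and guaranteed to execute,
  //     ...         // so 'q' must be nonnull throughout caller as well
  //   }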
+ // TODO: This could be enhanced by testing if the callsite post-dominates the + // entry block or by doing simple forward walks or backward walks to the + // callsite. + BasicBlock &Entry = F.getEntryBlock(); + for (Instruction &I : Entry) { + if (auto CS = CallSite(&I)) { + if (auto *CalledFunc = CS.getCalledFunction()) { + for (auto &CSArg : CalledFunc->args()) { + if (!CSArg.hasNonNullAttr()) + continue; + + // If the non-null callsite argument operand is an argument to 'F' + // (the caller) and the call is guaranteed to execute, then the value + // must be non-null throughout 'F'. + auto *FArg = dyn_cast<Argument>(CS.getArgOperand(CSArg.getArgNo())); + if (FArg && !FArg->hasNonNullAttr()) { + FArg->addAttr(Attribute::NonNull); + Changed = true; + } + } + } + } + if (!isGuaranteedToTransferExecutionToSuccessor(&I)) + break; + } + + return Changed; +} + /// Deduce nocapture attributes for the SCC. static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { bool Changed = false; ArgumentGraph AG; - AttrBuilder B; - B.addAttribute(Attribute::NoCapture); - // Check each function in turn, determining which pointer arguments are not // captured. for (Function *F : SCCNodes) { @@ -549,6 +594,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { if (!F->hasExactDefinition()) continue; + Changed |= addArgumentAttrsFromCallsites(*F); + // Functions that are readonly (or readnone) and nounwind and don't return // a value can't capture arguments. Don't analyze them. if (F->onlyReadsMemory() && F->doesNotThrow() && @@ -556,7 +603,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { for (Function::arg_iterator A = F->arg_begin(), E = F->arg_end(); A != E; ++A) { if (A->getType()->isPointerTy() && !A->hasNoCaptureAttr()) { - A->addAttr(AttributeSet::get(F->getContext(), A->getArgNo() + 1, B)); + A->addAttr(Attribute::NoCapture); ++NumNoCapture; Changed = true; } @@ -575,8 +622,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { if (!Tracker.Captured) { if (Tracker.Uses.empty()) { // If it's trivially not captured, mark it nocapture now. - A->addAttr( - AttributeSet::get(F->getContext(), A->getArgNo() + 1, B)); + A->addAttr(Attribute::NoCapture); ++NumNoCapture; Changed = true; } else { @@ -602,9 +648,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { Self.insert(&*A); Attribute::AttrKind R = determinePointerReadAttrs(&*A, Self); if (R != Attribute::None) { - AttrBuilder B; - B.addAttribute(R); - A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); + A->addAttr(R); Changed = true; R == Attribute::ReadOnly ? 
++NumReadOnlyArg : ++NumReadNoneArg; } @@ -629,7 +673,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { if (ArgumentSCC[0]->Uses.size() == 1 && ArgumentSCC[0]->Uses[0] == ArgumentSCC[0]) { Argument *A = ArgumentSCC[0]->Definition; - A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); + A->addAttr(Attribute::NoCapture); ++NumNoCapture; Changed = true; } @@ -671,7 +715,7 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) { Argument *A = ArgumentSCC[i]->Definition; - A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); + A->addAttr(Attribute::NoCapture); ++NumNoCapture; Changed = true; } @@ -702,14 +746,12 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { } if (ReadAttr != Attribute::None) { - AttrBuilder B, R; - B.addAttribute(ReadAttr); - R.addAttribute(Attribute::ReadOnly).addAttribute(Attribute::ReadNone); for (unsigned i = 0, e = ArgumentSCC.size(); i != e; ++i) { Argument *A = ArgumentSCC[i]->Definition; // Clear out existing readonly/readnone attributes - A->removeAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, R)); - A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); + A->removeAttr(Attribute::ReadOnly); + A->removeAttr(Attribute::ReadNone); + A->addAttr(ReadAttr); ReadAttr == Attribute::ReadOnly ? ++NumReadOnlyArg : ++NumReadNoneArg; Changed = true; } @@ -769,7 +811,7 @@ static bool isFunctionMallocLike(Function *F, const SCCNodeSet &SCCNodes) { case Instruction::Call: case Instruction::Invoke: { CallSite CS(RVI); - if (CS.paramHasAttr(0, Attribute::NoAlias)) + if (CS.hasRetAttr(Attribute::NoAlias)) break; if (CS.getCalledFunction() && SCCNodes.count(CS.getCalledFunction())) break; @@ -792,7 +834,7 @@ static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { // pointers. for (Function *F : SCCNodes) { // Already noalias. - if (F->doesNotAlias(0)) + if (F->returnDoesNotAlias()) continue; // We can infer and propagate function attributes only when we know that the @@ -812,10 +854,11 @@ static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { bool MadeChange = false; for (Function *F : SCCNodes) { - if (F->doesNotAlias(0) || !F->getReturnType()->isPointerTy()) + if (F->returnDoesNotAlias() || + !F->getReturnType()->isPointerTy()) continue; - F->setDoesNotAlias(0); + F->setReturnDoesNotAlias(); ++NumNoAlias; MadeChange = true; } @@ -905,7 +948,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { // pointers. for (Function *F : SCCNodes) { // Already nonnull. 
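// A minimal editorial sketch (not part of this patch) of the attribute-API simplification applied throughout these hunks: per-argument attributes are now added directly on the Argument, instead of going through an AttrBuilder and an AttributeSet keyed by ArgNo + 1 (index 0 being the return value). Both forms below appear verbatim in the surrounding diff: // Before: // AttrBuilder B; // B.addAttribute(Attribute::NoCapture); // A->addAttr(AttributeSet::get(A->getContext(), A->getArgNo() + 1, B)); // After: // A->addAttr(Attribute::NoCapture);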
- if (F->getAttributes().hasAttribute(AttributeSet::ReturnIndex, + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, Attribute::NonNull)) continue; @@ -926,7 +969,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { // Mark the function eagerly since we may discover a function // which prevents us from speculating about the entire SCC DEBUG(dbgs() << "Eagerly marking " << F->getName() << " as nonnull\n"); - F->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); + F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); ++NumNonNullReturn; MadeChange = true; } @@ -939,13 +982,13 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { if (SCCReturnsNonNull) { for (Function *F : SCCNodes) { - if (F->getAttributes().hasAttribute(AttributeSet::ReturnIndex, + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, Attribute::NonNull) || !F->getReturnType()->isPointerTy()) continue; DEBUG(dbgs() << "SCC marking " << F->getName() << " as nonnull\n"); - F->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); + F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); ++NumNonNullReturn; MadeChange = true; } @@ -1144,6 +1187,10 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { SCCNodes.insert(F); } + // Skip it if the SCC only contains optnone functions. + if (SCCNodes.empty()) + return Changed; + Changed |= addArgumentReturnedAttrs(SCCNodes); Changed |= addReadAttrs(SCCNodes, AARGetter); Changed |= addArgumentAttrs(SCCNodes); @@ -1163,19 +1210,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) { if (skipSCC(SCC)) return false; - - // We compute dedicated AA results for each function in the SCC as needed. We - // use a lambda referencing external objects so that they live long enough to - // be queried, but we re-use them each time. - Optional<BasicAAResult> BAR; - Optional<AAResults> AAR; - auto AARGetter = [&](Function &F) -> AAResults & { - BAR.emplace(createLegacyPMBasicAAResult(*this, F)); - AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); - return *AAR; - }; - - return runImpl(SCC, AARGetter); + return runImpl(SCC, LegacyAARGetter(*this)); } namespace { @@ -1275,16 +1310,9 @@ PreservedAnalyses ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) { auto &CG = AM.getResult<CallGraphAnalysis>(M); - bool Changed = deduceFunctionAttributeInRPO(M, CG); - - // CallGraphAnalysis holds AssertingVH and must be invalidated eagerly so - // that other passes don't delete stuff from under it. - // FIXME: We need to invalidate this to avoid PR28400. Is there a better - // solution? 
- AM.invalidate<CallGraphAnalysis>(M); - - if (!Changed) + if (!deduceFunctionAttributeInRPO(M, CG)) return PreservedAnalyses::all(); + PreservedAnalyses PA; PA.preserve<CallGraphAnalysis>(); return PA; diff --git a/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp b/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp index 6b32f6c..233a36d 100644 --- a/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/Triple.h" +#include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/AutoUpgrade.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/IntrinsicInst.h" @@ -25,7 +26,6 @@ #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" #include "llvm/Object/IRObjectFile.h" -#include "llvm/Object/ModuleSummaryIndexObjectFile.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/SourceMgr.h" @@ -64,6 +64,12 @@ static cl::opt<float> ImportHotMultiplier( "import-hot-multiplier", cl::init(3.0), cl::Hidden, cl::value_desc("x"), cl::desc("Multiply the `import-instr-limit` threshold for hot callsites")); +static cl::opt<float> ImportCriticalMultiplier( + "import-critical-multiplier", cl::init(100.0), cl::Hidden, + cl::value_desc("x"), + cl::desc( + "Multiply the `import-instr-limit` threshold for critical callsites")); + // FIXME: This multiplier was not really tuned up. static cl::opt<float> ImportColdMultiplier( "import-cold-multiplier", cl::init(0), cl::Hidden, cl::value_desc("N"), @@ -75,12 +81,6 @@ static cl::opt<bool> PrintImports("print-imports", cl::init(false), cl::Hidden, static cl::opt<bool> ComputeDead("compute-dead", cl::init(true), cl::Hidden, cl::desc("Compute dead symbols")); -// Temporary allows the function import pass to disable always linking -// referenced discardable symbols. -static cl::opt<bool> - DontForceImportReferencedDiscardableSymbols("disable-force-link-odr", - cl::init(false), cl::Hidden); - static cl::opt<bool> EnableImportMetadata( "enable-import-metadata", cl::init( #if !defined(NDEBUG) @@ -123,8 +123,8 @@ namespace { /// - [insert your fancy metric here] static const GlobalValueSummary * selectCallee(const ModuleSummaryIndex &Index, - const GlobalValueSummaryList &CalleeSummaryList, - unsigned Threshold) { + ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList, + unsigned Threshold, StringRef CallerModulePath) { auto It = llvm::find_if( CalleeSummaryList, [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { @@ -145,6 +145,21 @@ selectCallee(const ModuleSummaryIndex &Index, auto *Summary = cast<FunctionSummary>(GVSummary); + // If this is a local function, make sure we import the copy + // in the caller's module. The only time a local function can + // share an entry in the index is if there is a local with the same name + // in another module that had the same source file name (in a different + // directory), where each was compiled in its own directory so there + // was no distinguishing path. + // However, do the import from another module if there is only one + // entry in the list - in that case this must be a reference due + // to indirect call profile data, since a function pointer can point to + // a local in another module.
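// An editorial distillation (hypothetical helper, not part of this patch) of the local-linkage rule described above; the real check lives inline in selectCallee. A local-linkage summary is only usable when it comes from the caller's own module, unless it is the lone entry in the list (an indirect-call profile reference): static bool isUsableSummary(const GlobalValueSummary &S, size_t ListSize, StringRef CallerModulePath) { if (GlobalValue::isLocalLinkage(S.linkage()) && ListSize > 1 && S.modulePath() != CallerModulePath) return false; return true; }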
+ if (GlobalValue::isLocalLinkage(Summary->linkage()) && + CalleeSummaryList.size() > 1 && + Summary->modulePath() != CallerModulePath) + return false; + if (Summary->instCount() > Threshold) return false; @@ -159,17 +174,6 @@ selectCallee(const ModuleSummaryIndex &Index, return cast<GlobalValueSummary>(It->get()); } -/// Return the summary for the function \p GUID that fits the \p Threshold, or -/// null if there's no match. -static const GlobalValueSummary *selectCallee(GlobalValue::GUID GUID, - unsigned Threshold, - const ModuleSummaryIndex &Index) { - auto CalleeSummaryList = Index.findGlobalValueSummaryList(GUID); - if (CalleeSummaryList == Index.end()) - return nullptr; // This function does not have a summary - return selectCallee(Index, CalleeSummaryList->second, Threshold); -} - using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */, GlobalValue::GUID>; @@ -183,10 +187,23 @@ static void computeImportForFunction( FunctionImporter::ImportMapTy &ImportList, StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { for (auto &Edge : Summary.calls()) { - auto GUID = Edge.first.getGUID(); - DEBUG(dbgs() << " edge -> " << GUID << " Threshold:" << Threshold << "\n"); + ValueInfo VI = Edge.first; + DEBUG(dbgs() << " edge -> " << VI.getGUID() << " Threshold:" << Threshold + << "\n"); + + if (VI.getSummaryList().empty()) { + // For SamplePGO, the indirect call targets for local functions will + // have its original name annotated in profile. We try to find the + // corresponding PGOFuncName as the GUID. + auto GUID = Index.getGUIDFromOriginalID(VI.getGUID()); + if (GUID == 0) + continue; + VI = Index.getValueInfo(GUID); + if (!VI) + continue; + } - if (DefinedGVSummaries.count(GUID)) { + if (DefinedGVSummaries.count(VI.getGUID())) { DEBUG(dbgs() << "ignored! Target already in destination module.\n"); continue; } @@ -196,13 +213,16 @@ static void computeImportForFunction( return ImportHotMultiplier; if (Hotness == CalleeInfo::HotnessType::Cold) return ImportColdMultiplier; + if (Hotness == CalleeInfo::HotnessType::Critical) + return ImportCriticalMultiplier; return 1.0; }; const auto NewThreshold = Threshold * GetBonusMultiplier(Edge.second.Hotness); - auto *CalleeSummary = selectCallee(GUID, NewThreshold, Index); + auto *CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, + Summary.modulePath()); if (!CalleeSummary) { DEBUG(dbgs() << "ignored! No qualifying callee with summary found.\n"); continue; @@ -234,7 +254,7 @@ static void computeImportForFunction( const auto AdjThreshold = GetAdjustedThreshold(Threshold, IsHotCallsite); auto ExportModulePath = ResolvedCalleeSummary->modulePath(); - auto &ProcessedThreshold = ImportList[ExportModulePath][GUID]; + auto &ProcessedThreshold = ImportList[ExportModulePath][VI.getGUID()]; /// Since the traversal of the call graph is DFS, we can revisit a function /// a second time with a higher threshold. In this case, it is added back to /// the worklist with the new threshold. @@ -250,7 +270,7 @@ static void computeImportForFunction( // Make exports in the source module. if (ExportLists) { auto &ExportList = (*ExportLists)[ExportModulePath]; - ExportList.insert(GUID); + ExportList.insert(VI.getGUID()); if (!PreviouslyImported) { // This is the first time this function was exported from its source // module, so mark all functions and globals it references as exported @@ -270,7 +290,7 @@ static void computeImportForFunction( } // Insert the newly imported function to the worklist. 
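// An editorial worked example (numbers hypothetical) of the GetBonusMultiplier logic above: with an instruction-count threshold of 100, the multiplier scales the limit per edge hotness before selectCallee compares it against the callee's instCount(): // hot edge: 100 * 3.0 -> import callees up to 300 instructions // critical edge: 100 * 100.0 -> import callees up to 10000 instructions // cold edge: 100 * 0 -> import nothing along this edge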
- Worklist.emplace_back(ResolvedCalleeSummary, AdjThreshold, GUID); + Worklist.emplace_back(ResolvedCalleeSummary, AdjThreshold, VI.getGUID()); } } @@ -280,8 +300,7 @@ static void computeImportForFunction( static void ComputeImportForModule( const GVSummaryMapTy &DefinedGVSummaries, const ModuleSummaryIndex &Index, FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr, - const DenseSet<GlobalValue::GUID> *DeadSymbols = nullptr) { + StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { // Worklist contains the list of functions imported in this module, for which // we will analyse the callees and may import further down the callgraph. SmallVector<EdgeInfo, 128> Worklist; @@ -289,7 +308,7 @@ static void ComputeImportForModule( // Populate the worklist with the import for the functions in the current // module for (auto &GVSummary : DefinedGVSummaries) { - if (DeadSymbols && DeadSymbols->count(GVSummary.first)) { + if (!Index.isGlobalValueLive(GVSummary.second)) { DEBUG(dbgs() << "Ignores Dead GUID: " << GVSummary.first << "\n"); continue; } @@ -332,15 +351,14 @@ void llvm::ComputeCrossModuleImport( const ModuleSummaryIndex &Index, const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, StringMap<FunctionImporter::ImportMapTy> &ImportLists, - StringMap<FunctionImporter::ExportSetTy> &ExportLists, - const DenseSet<GlobalValue::GUID> *DeadSymbols) { + StringMap<FunctionImporter::ExportSetTy> &ExportLists) { // For each module that has a function defined, compute the import/export lists. for (auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { auto &ImportList = ImportLists[DefinedGVSummaries.first()]; DEBUG(dbgs() << "Computing import for Module '" << DefinedGVSummaries.first() << "'\n"); ComputeImportForModule(DefinedGVSummaries.second, Index, ImportList, - &ExportLists, DeadSymbols); + &ExportLists); } // When computing imports we added all GUIDs referenced by anything @@ -402,84 +420,71 @@ void llvm::ComputeCrossModuleImportForModule( #endif } -DenseSet<GlobalValue::GUID> llvm::computeDeadSymbols( - const ModuleSummaryIndex &Index, +void llvm::computeDeadSymbols( + ModuleSummaryIndex &Index, const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols) { + assert(!Index.withGlobalValueDeadStripping()); if (!ComputeDead) - return DenseSet<GlobalValue::GUID>(); + return; if (GUIDPreservedSymbols.empty()) // Don't do anything when nothing is live, this is friendly with tests. - return DenseSet<GlobalValue::GUID>(); - DenseSet<GlobalValue::GUID> LiveSymbols = GUIDPreservedSymbols; - SmallVector<GlobalValue::GUID, 128> Worklist; - Worklist.reserve(LiveSymbols.size() * 2); - for (auto GUID : LiveSymbols) { - DEBUG(dbgs() << "Live root: " << GUID << "\n"); - Worklist.push_back(GUID); - } - // Add values flagged in the index as live roots to the worklist.
- for (const auto &Entry : Index) { - bool IsLiveRoot = llvm::any_of( - Entry.second, - [&](const std::unique_ptr<llvm::GlobalValueSummary> &Summary) { - return Summary->liveRoot(); - }); - if (!IsLiveRoot) + return; + unsigned LiveSymbols = 0; + SmallVector<ValueInfo, 128> Worklist; + Worklist.reserve(GUIDPreservedSymbols.size() * 2); + for (auto GUID : GUIDPreservedSymbols) { + ValueInfo VI = Index.getValueInfo(GUID); + if (!VI) continue; - DEBUG(dbgs() << "Live root (summary): " << Entry.first << "\n"); - Worklist.push_back(Entry.first); + for (auto &S : VI.getSummaryList()) + S->setLive(true); } - while (!Worklist.empty()) { - auto GUID = Worklist.pop_back_val(); - auto It = Index.findGlobalValueSummaryList(GUID); - if (It == Index.end()) { - DEBUG(dbgs() << "Not in index: " << GUID << "\n"); - continue; - } - - // FIXME: we should only make the prevailing copy live here - for (auto &Summary : It->second) { - for (auto Ref : Summary->refs()) { - auto RefGUID = Ref.getGUID(); - if (LiveSymbols.insert(RefGUID).second) { - DEBUG(dbgs() << "Marking live (ref): " << RefGUID << "\n"); - Worklist.push_back(RefGUID); - } - } - if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) { - for (auto Call : FS->calls()) { - auto CallGUID = Call.first.getGUID(); - if (LiveSymbols.insert(CallGUID).second) { - DEBUG(dbgs() << "Marking live (call): " << CallGUID << "\n"); - Worklist.push_back(CallGUID); - } - } + // Add values flagged in the index as live roots to the worklist. + for (const auto &Entry : Index) + for (auto &S : Entry.second.SummaryList) + if (S->isLive()) { + DEBUG(dbgs() << "Live root: " << Entry.first << "\n"); + Worklist.push_back(ValueInfo(&Entry)); + ++LiveSymbols; + break; } + + // Make value live and add it to the worklist if it was not live before. 
+ // FIXME: we should only make the prevailing copy live here + auto visit = [&](ValueInfo VI) { + for (auto &S : VI.getSummaryList()) + if (S->isLive()) + return; + for (auto &S : VI.getSummaryList()) + S->setLive(true); + ++LiveSymbols; + Worklist.push_back(VI); + }; + + while (!Worklist.empty()) { + auto VI = Worklist.pop_back_val(); + for (auto &Summary : VI.getSummaryList()) { + for (auto Ref : Summary->refs()) + visit(Ref); + if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) + for (auto Call : FS->calls()) + visit(Call.first); if (auto *AS = dyn_cast<AliasSummary>(Summary.get())) { auto AliaseeGUID = AS->getAliasee().getOriginalName(); - if (LiveSymbols.insert(AliaseeGUID).second) { - DEBUG(dbgs() << "Marking live (alias): " << AliaseeGUID << "\n"); - Worklist.push_back(AliaseeGUID); - } + ValueInfo AliaseeVI = Index.getValueInfo(AliaseeGUID); + if (AliaseeVI) + visit(AliaseeVI); } } } - DenseSet<GlobalValue::GUID> DeadSymbols; - DeadSymbols.reserve( - std::min(Index.size(), Index.size() - LiveSymbols.size())); - for (auto &Entry : Index) { - auto GUID = Entry.first; - if (!LiveSymbols.count(GUID)) { - DEBUG(dbgs() << "Marking dead: " << GUID << "\n"); - DeadSymbols.insert(GUID); - } - } - DEBUG(dbgs() << LiveSymbols.size() << " symbols Live, and " - << DeadSymbols.size() << " symbols Dead \n"); - NumDeadSymbols += DeadSymbols.size(); - NumLiveSymbols += LiveSymbols.size(); - return DeadSymbols; + Index.setWithGlobalValueDeadStripping(); + + unsigned DeadSymbols = Index.size() - LiveSymbols; + DEBUG(dbgs() << LiveSymbols << " symbols Live, and " << DeadSymbols + << " symbols Dead \n"); + NumDeadSymbols += DeadSymbols; + NumLiveSymbols += LiveSymbols; } /// Compute the set of summaries needed for a ThinLTO backend compilation of @@ -522,9 +527,24 @@ llvm::EmitImportsFiles(StringRef ModulePath, StringRef OutputFilename, /// Fixup WeakForLinker linkages in \p TheModule based on summary analysis. void llvm::thinLTOResolveWeakForLinkerModule( Module &TheModule, const GVSummaryMapTy &DefinedGlobals) { + auto ConvertToDeclaration = [](GlobalValue &GV) { + DEBUG(dbgs() << "Converting to a declaration: `" << GV.getName() << "\n"); + if (Function *F = dyn_cast<Function>(&GV)) { + F->deleteBody(); + F->clearMetadata(); + } else if (GlobalVariable *V = dyn_cast<GlobalVariable>(&GV)) { + V->setInitializer(nullptr); + V->setLinkage(GlobalValue::ExternalLinkage); + V->clearMetadata(); + } else + // For now we don't resolve or drop aliases. Once we do we'll + // need to add support here for creating either a function or + // variable declaration, and return the new GlobalValue* for + // the caller to use. + llvm_unreachable("Expected function or variable"); + }; + auto updateLinkage = [&](GlobalValue &GV) { - if (!GlobalValue::isWeakForLinker(GV.getLinkage())) - return; // See if the global summary analysis computed a new resolved linkage. const auto &GS = DefinedGlobals.find(GV.getGUID()); if (GS == DefinedGlobals.end()) @@ -532,18 +552,40 @@ void llvm::thinLTOResolveWeakForLinkerModule( auto NewLinkage = GS->second->linkage(); if (NewLinkage == GV.getLinkage()) return; - DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from " - << GV.getLinkage() << " to " << NewLinkage << "\n"); - GV.setLinkage(NewLinkage); - // Remove functions converted to available_externally from comdats, + + // Switch the linkage to weakany if asked for, e.g. we do this for + // linker redefined symbols (via --wrap or --defsym). 
+ // We record that the visibility should be changed here in `addThinLTO` + // as we need access to the resolution vectors for each input file in + // order to find which symbols have been redefined. + // We may consider reorganizing this code and moving the linkage recording + // somewhere else, e.g. in thinLTOResolveWeakForLinkerInIndex. + if (NewLinkage == GlobalValue::WeakAnyLinkage) { + GV.setLinkage(NewLinkage); + return; + } + + if (!GlobalValue::isWeakForLinker(GV.getLinkage())) + return; + // Check for a non-prevailing def that has interposable linkage + // (e.g. non-odr weak or linkonce). In that case we can't simply + // convert to available_externally, since it would lose the + // interposable property and possibly get inlined. Simply drop + // the definition in that case. + if (GlobalValue::isAvailableExternallyLinkage(NewLinkage) && + GlobalValue::isInterposableLinkage(GV.getLinkage())) + ConvertToDeclaration(GV); + else { + DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from " + << GV.getLinkage() << " to " << NewLinkage << "\n"); + GV.setLinkage(NewLinkage); + } + // Remove declarations from comdats, including available_externally // as this is a declaration for the linker, and will be dropped eventually. // It is illegal for comdats to contain declarations. auto *GO = dyn_cast_or_null<GlobalObject>(&GV); - if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) { - assert(GO->hasAvailableExternallyLinkage() && - "Expected comdat on definition (possibly available external)"); + if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) GO->setComdat(nullptr); - } }; // Process functions and global now @@ -562,7 +604,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, // the current module. StringSet<> AsmUndefinedRefs; ModuleSymbolTable::CollectAsmSymbols( - Triple(TheModule.getTargetTriple()), TheModule.getModuleInlineAsm(), + TheModule, [&AsmUndefinedRefs](StringRef Name, object::BasicSymbolRef::Flags Flags) { if (Flags & object::BasicSymbolRef::SF_Undefined) AsmUndefinedRefs.insert(Name); @@ -576,8 +618,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, return true; // Lookup the linkage recorded in the summaries during global analysis. - const auto &GS = DefinedGlobals.find(GV.getGUID()); - GlobalValue::LinkageTypes Linkage; + auto GS = DefinedGlobals.find(GV.getGUID()); if (GS == DefinedGlobals.end()) { // Must have been promoted (possibly conservatively). Find original // name so that we can access the correct summary and see if it can @@ -589,7 +630,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, std::string OrigId = GlobalValue::getGlobalIdentifier( OrigName, GlobalValue::InternalLinkage, TheModule.getSourceFileName()); - const auto &GS = DefinedGlobals.find(GlobalValue::getGUID(OrigId)); + GS = DefinedGlobals.find(GlobalValue::getGUID(OrigId)); if (GS == DefinedGlobals.end()) { // Also check the original non-promoted non-globalized name. In some // cases a preempted weak value is linked in as a local copy because @@ -597,15 +638,11 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, // In that case, since it was originally not a local value, it was // recorded in the index using the original name. // FIXME: This may not be needed once PR27866 is fixed. 
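// An editorial distillation (hypothetical helper, not part of this patch) of the decision implemented above: a non-prevailing definition whose current linkage is interposable (non-ODR weak or linkonce) must be dropped to a declaration, because converting it to available_externally would lose interposability and allow it to be inlined: static bool mustDropDefinition(GlobalValue::LinkageTypes Current, GlobalValue::LinkageTypes Resolved) { return GlobalValue::isAvailableExternallyLinkage(Resolved) && GlobalValue::isInterposableLinkage(Current); }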
- const auto &GS = DefinedGlobals.find(GlobalValue::getGUID(OrigName)); + GS = DefinedGlobals.find(GlobalValue::getGUID(OrigName)); assert(GS != DefinedGlobals.end()); - Linkage = GS->second->linkage(); - } else { - Linkage = GS->second->linkage(); } - } else - Linkage = GS->second->linkage(); - return !GlobalValue::isLocalLinkage(Linkage); + } + return !GlobalValue::isLocalLinkage(GS->second->linkage()); }; // FIXME: See if we can just internalize directly here via linkage changes @@ -617,14 +654,12 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, // index. // Expected<bool> FunctionImporter::importFunctions( - Module &DestModule, const FunctionImporter::ImportMapTy &ImportList, - bool ForceImportReferencedDiscardableSymbols) { + Module &DestModule, const FunctionImporter::ImportMapTy &ImportList) { DEBUG(dbgs() << "Starting import for Module " << DestModule.getModuleIdentifier() << "\n"); unsigned ImportedCount = 0; - // Linker that will be used for importing function - Linker TheLinker(DestModule); + IRMover Mover(DestModule); // Do the actual import of functions now, one Module at a time std::set<StringRef> ModuleNameOrderedList; for (auto &FunctionsToImportPerModule : ImportList) { @@ -648,7 +683,7 @@ Expected<bool> FunctionImporter::importFunctions( auto &ImportGUIDs = FunctionsToImportPerModule->second; // Find the globals to import - DenseSet<const GlobalValue *> GlobalsToImport; + SetVector<GlobalValue *> GlobalsToImport; for (Function &F : *SrcModule) { if (!F.hasName()) continue; @@ -687,6 +722,13 @@ Expected<bool> FunctionImporter::importFunctions( } } for (GlobalAlias &GA : SrcModule->aliases()) { + // FIXME: This should eventually be controlled entirely by the summary. + if (FunctionImportGlobalProcessing::doImportAsDefinition( + &GA, &GlobalsToImport)) { + GlobalsToImport.insert(&GA); + continue; + } + if (!GA.hasName()) continue; auto GUID = GA.getGUID(); @@ -731,12 +773,9 @@ Expected<bool> FunctionImporter::importFunctions( << " from " << SrcModule->getSourceFileName() << "\n"; } - // Instruct the linker that the client will take care of linkonce resolution - unsigned Flags = Linker::Flags::None; - if (!ForceImportReferencedDiscardableSymbols) - Flags |= Linker::Flags::DontForceLinkLinkonceODR; - - if (TheLinker.linkInModule(std::move(SrcModule), Flags, &GlobalsToImport)) + if (Mover.move(std::move(SrcModule), GlobalsToImport.getArrayRef(), + [](GlobalValue &, IRMover::ValueAdder) {}, + /*IsPerformingImport=*/true)) report_fatal_error("Function Import: link error"); ImportedCount += GlobalsToImport.size(); @@ -778,7 +817,7 @@ static bool doImportingForModule(Module &M) { // is only enabled when testing importing via the 'opt' tool, which does // not do the ThinLink that would normally determine what values to promote. for (auto &I : *Index) { - for (auto &S : I.second) { + for (auto &S : I.second.SummaryList) { if (GlobalValue::isLocalLinkage(S->linkage())) S->setLinkage(GlobalValue::ExternalLinkage); } @@ -796,8 +835,7 @@ static bool doImportingForModule(Module &M) { return loadFile(Identifier, M.getContext()); }; FunctionImporter Importer(*Index, ModuleLoader); - Expected<bool> Result = Importer.importFunctions( - M, ImportList, !DontForceImportReferencedDiscardableSymbols); + Expected<bool> Result = Importer.importFunctions(M, ImportList); // FIXME: Probably need to propagate Errors through the pass manager. 
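// An editorial sketch of one way the FIXME above could be addressed, assuming the caller may simply report and bail out instead of threading the Error through the pass manager: // if (!Result) { // handleAllErrors(Result.takeError(), [&](ErrorInfoBase &EIB) { // errs() << "Function import failed: " << EIB.message() << "\n"; // }); // return false; // }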
if (!Result) { diff --git a/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp index 7a04de3..c91e8b4 100644 --- a/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp +++ b/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp @@ -25,7 +25,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/CtorUtils.h" #include "llvm/Transforms/Utils/GlobalStatus.h" -#include <unordered_map> + using namespace llvm; #define DEBUG_TYPE "globaldce" @@ -50,7 +50,14 @@ namespace { if (skipModule(M)) return false; + // We need a minimally functional dummy module analysis manager. It needs + // to at least know about the possibility of proxying a function analysis + // manager. + FunctionAnalysisManager DummyFAM; ModuleAnalysisManager DummyMAM; + DummyMAM.registerPass( + [&] { return FunctionAnalysisManagerModuleProxy(DummyFAM); }); + auto PA = Impl.run(M, DummyMAM); return !PA.areAllPreserved(); } @@ -78,9 +85,67 @@ static bool isEmptyFunction(Function *F) { return RI.getReturnValue() == nullptr; } -PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { +/// Compute the set of GlobalValues that depend on V. +/// The recursion stops as soon as a GlobalValue is met. +void GlobalDCEPass::ComputeDependencies(Value *V, + SmallPtrSetImpl<GlobalValue *> &Deps) { + if (auto *I = dyn_cast<Instruction>(V)) { + Function *Parent = I->getParent()->getParent(); + Deps.insert(Parent); + } else if (auto *GV = dyn_cast<GlobalValue>(V)) { + Deps.insert(GV); + } else if (auto *CE = dyn_cast<Constant>(V)) { + // Avoid walking the whole tree of a big ConstantExpr multiple times. + auto Where = ConstantDependenciesCache.find(CE); + if (Where != ConstantDependenciesCache.end()) { + auto const &K = Where->second; + Deps.insert(K.begin(), K.end()); + } else { + SmallPtrSetImpl<GlobalValue *> &LocalDeps = ConstantDependenciesCache[CE]; + for (User *CEUser : CE->users()) + ComputeDependencies(CEUser, LocalDeps); + Deps.insert(LocalDeps.begin(), LocalDeps.end()); + } + } +} + +void GlobalDCEPass::UpdateGVDependencies(GlobalValue &GV) { + SmallPtrSet<GlobalValue *, 8> Deps; + for (User *User : GV.users()) + ComputeDependencies(User, Deps); + Deps.erase(&GV); // Remove self-reference. + for (GlobalValue *GVU : Deps) { + GVDependencies.insert(std::make_pair(GVU, &GV)); + } +} + +/// Mark Global value as Live +void GlobalDCEPass::MarkLive(GlobalValue &GV, + SmallVectorImpl<GlobalValue *> *Updates) { + auto const Ret = AliveGlobals.insert(&GV); + if (!Ret.second) + return; + + if (Updates) + Updates->push_back(&GV); + if (Comdat *C = GV.getComdat()) { + for (auto &&CM : make_range(ComdatMembers.equal_range(C))) + MarkLive(*CM.second, Updates); // Recursion depth is only two because only + // globals in the same comdat are visited. + } +} + +PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { bool Changed = false; + // The algorithm first computes the set L of global variables that are + // trivially live. Then it walks the initialization of these variables to + // compute the globals used to initialize them, which effectively builds a + // directed graph where nodes are global variables, and an edge from A to B + // means B is used to initialize A. Finally, it propagates the liveness + // information through the graph starting from the nodes in L. Nodes not + // marked as alive are discarded. + // Remove empty functions from the global ctors list.
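// An editorial sketch (hypothetical free function, not part of this patch) of the propagation step described in the algorithm comment above: GVDependencies maps each global to the globals it uses, so a plain worklist walk marks everything reachable from the trivially-live roots. static void propagateLiveness( SmallPtrSetImpl<GlobalValue *> &Alive, const std::multimap<GlobalValue *, GlobalValue *> &Uses) { SmallVector<GlobalValue *, 8> Worklist(Alive.begin(), Alive.end()); while (!Worklist.empty()) { GlobalValue *GV = Worklist.pop_back_val(); for (auto &Edge : make_range(Uses.equal_range(GV))) if (Alive.insert(Edge.second).second) Worklist.push_back(Edge.second); } }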
Changed |= optimizeGlobalCtorsList(M, isEmptyFunction); @@ -103,21 +168,39 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { // initializer. if (!GO.isDeclaration() && !GO.hasAvailableExternallyLinkage()) if (!GO.isDiscardableIfUnused()) - GlobalIsNeeded(&GO); + MarkLive(GO); + + UpdateGVDependencies(GO); } + // Compute direct dependencies of aliases. for (GlobalAlias &GA : M.aliases()) { Changed |= RemoveUnusedGlobalValue(GA); // Externally visible aliases are needed. if (!GA.isDiscardableIfUnused()) - GlobalIsNeeded(&GA); + MarkLive(GA); + + UpdateGVDependencies(GA); } + // Compute direct dependencies of ifuncs. for (GlobalIFunc &GIF : M.ifuncs()) { Changed |= RemoveUnusedGlobalValue(GIF); // Externally visible ifuncs are needed. if (!GIF.isDiscardableIfUnused()) - GlobalIsNeeded(&GIF); + MarkLive(GIF); + + UpdateGVDependencies(GIF); + } + + // Propagate liveness from collected Global Values through the computed + // dependencies. + SmallVector<GlobalValue *, 8> NewLiveGVs{AliveGlobals.begin(), + AliveGlobals.end()}; + while (!NewLiveGVs.empty()) { + GlobalValue *LGV = NewLiveGVs.pop_back_val(); + for (auto &&GVD : make_range(GVDependencies.equal_range(LGV))) + MarkLive(*GVD.second, &NewLiveGVs); } // Now that all globals which are needed are in the AliveGlobals set, we loop @@ -154,7 +237,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { GA.setAliasee(nullptr); } - // The third pass drops targets of ifuncs which are dead... + // The fourth pass drops targets of ifuncs which are dead... std::vector<GlobalIFunc*> DeadIFuncs; for (GlobalIFunc &GIF : M.ifuncs()) if (!AliveGlobals.count(&GIF)) { @@ -188,7 +271,8 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { // Make sure that all memory is released AliveGlobals.clear(); - SeenConstants.clear(); + ConstantDependenciesCache.clear(); + GVDependencies.clear(); ComdatMembers.clear(); if (Changed) @@ -196,60 +280,6 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { return PreservedAnalyses::all(); } -/// GlobalIsNeeded - the specific global value as needed, and -/// recursively mark anything that it uses as also needed. -void GlobalDCEPass::GlobalIsNeeded(GlobalValue *G) { - // If the global is already in the set, no need to reprocess it. - if (!AliveGlobals.insert(G).second) - return; - - if (Comdat *C = G->getComdat()) { - for (auto &&CM : make_range(ComdatMembers.equal_range(C))) - GlobalIsNeeded(CM.second); - } - - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(G)) { - // If this is a global variable, we must make sure to add any global values - // referenced by the initializer to the alive set. - if (GV->hasInitializer()) - MarkUsedGlobalsAsNeeded(GV->getInitializer()); - } else if (GlobalIndirectSymbol *GIS = dyn_cast<GlobalIndirectSymbol>(G)) { - // The target of a global alias or ifunc is needed. - MarkUsedGlobalsAsNeeded(GIS->getIndirectSymbol()); - } else { - // Otherwise this must be a function object. We have to scan the body of - // the function looking for constants and global values which are used as - // operands. Any operands of these types must be processed to ensure that - // any globals used will be marked as needed. 
- Function *F = cast<Function>(G); - - for (Use &U : F->operands()) - MarkUsedGlobalsAsNeeded(cast<Constant>(U.get())); - - for (BasicBlock &BB : *F) - for (Instruction &I : BB) - for (Use &U : I.operands()) - if (GlobalValue *GV = dyn_cast<GlobalValue>(U)) - GlobalIsNeeded(GV); - else if (Constant *C = dyn_cast<Constant>(U)) - MarkUsedGlobalsAsNeeded(C); - } -} - -void GlobalDCEPass::MarkUsedGlobalsAsNeeded(Constant *C) { - if (GlobalValue *GV = dyn_cast<GlobalValue>(C)) - return GlobalIsNeeded(GV); - - // Loop over all of the operands of the constant, adding any globals they - // use to the list of needed globals. - for (Use &U : C->operands()) { - // If we've already processed this constant there's no need to do it again. - Constant *Op = dyn_cast<Constant>(U); - if (Op && SeenConstants.insert(Op).second) - MarkUsedGlobalsAsNeeded(Op); - } -} - // RemoveUnusedGlobalValue - Loop over all of the uses of the specified // GlobalValue, looking for the constant pointer ref that may be pointing to it. // If found, check to see if the constant pointer ref is safe to destroy, and if diff --git a/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 5b0d5e3..93eab68 100644 --- a/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -239,7 +239,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init, // we delete a constant array, we may also be holding pointer to one of its // elements (or an element of one of its elements if we're dealing with an // array of arrays) in the worklist. - SmallVector<WeakVH, 8> WorkList(V->user_begin(), V->user_end()); + SmallVector<WeakTrackingVH, 8> WorkList(V->user_begin(), V->user_end()); while (!WorkList.empty()) { Value *UV = WorkList.pop_back_val(); if (!UV) @@ -837,7 +837,7 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, if (StoreInst *SI = dyn_cast<StoreInst>(GV->user_back())) { // The global is initialized when the store to it occurs. new StoreInst(ConstantInt::getTrue(GV->getContext()), InitBool, false, 0, - SI->getOrdering(), SI->getSynchScope(), SI); + SI->getOrdering(), SI->getSyncScopeID(), SI); SI->eraseFromParent(); continue; } @@ -854,7 +854,7 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, // Replace the cmp X, 0 with a use of the bool value. // Sink the load to where the compare was, if atomic rules allow us to. Value *LV = new LoadInst(InitBool, InitBool->getName()+".val", false, 0, - LI->getOrdering(), LI->getSynchScope(), + LI->getOrdering(), LI->getSyncScopeID(), LI->isUnordered() ? (Instruction*)ICI : LI); InitBoolUsed = true; switch (ICI->getPredicate()) { @@ -1605,7 +1605,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { assert(LI->getOperand(0) == GV && "Not a copy!"); // Insert a new load, to preserve the saved value. StoreVal = new LoadInst(NewGV, LI->getName()+".b", false, 0, - LI->getOrdering(), LI->getSynchScope(), LI); + LI->getOrdering(), LI->getSyncScopeID(), LI); } else { assert((isa<CastInst>(StoredVal) || isa<SelectInst>(StoredVal)) && "This is not a form that we understand!"); @@ -1614,12 +1614,12 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { } } new StoreInst(StoreVal, NewGV, false, 0, - SI->getOrdering(), SI->getSynchScope(), SI); + SI->getOrdering(), SI->getSyncScopeID(), SI); } else { // Change the load into a load of bool then a select. 
LoadInst *LI = cast<LoadInst>(UI); LoadInst *NLI = new LoadInst(NewGV, LI->getName()+".b", false, 0, - LI->getOrdering(), LI->getSynchScope(), LI); + LI->getOrdering(), LI->getSyncScopeID(), LI); Value *NSI; if (IsOneZero) NSI = new ZExtInst(NLI, LI->getType(), "", LI); @@ -1792,7 +1792,9 @@ static void makeAllConstantUsesInstructions(Constant *C) { NewU->insertBefore(UI); UI->replaceUsesOfWith(U, NewU); } - U->dropAllReferences(); + // We've replaced all the uses, so destroy the constant. (destroyConstant + // will update value handles and metadata.) + U->destroyConstant(); } } @@ -1819,12 +1821,14 @@ static bool processInternalGlobal( GS.AccessingFunction->doesNotRecurse() && isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV, LookupDomTree)) { + const DataLayout &DL = GV->getParent()->getDataLayout(); + DEBUG(dbgs() << "LOCALIZING GLOBAL: " << *GV << "\n"); Instruction &FirstI = const_cast<Instruction&>(*GS.AccessingFunction ->getEntryBlock().begin()); Type *ElemTy = GV->getValueType(); // FIXME: Pass Global's alignment when globals have alignment - AllocaInst *Alloca = new AllocaInst(ElemTy, nullptr, + AllocaInst *Alloca = new AllocaInst(ElemTy, DL.getAllocaAddrSpace(), nullptr, GV->getName(), &FirstI); if (!isa<UndefValue>(GV->getInitializer())) new StoreInst(GV->getInitializer(), Alloca, &FirstI); @@ -1977,16 +1981,11 @@ static void ChangeCalleesToFastCall(Function *F) { } } -static AttributeSet StripNest(LLVMContext &C, const AttributeSet &Attrs) { - for (unsigned i = 0, e = Attrs.getNumSlots(); i != e; ++i) { - unsigned Index = Attrs.getSlotIndex(i); - if (!Attrs.getSlotAttributes(i).hasAttribute(Index, Attribute::Nest)) - continue; - - // There can be only one. - return Attrs.removeAttribute(C, Index, Attribute::Nest); - } - +static AttributeList StripNest(LLVMContext &C, AttributeList Attrs) { + // There can be at most one attribute set with a nest attribute. + unsigned NestIndex; + if (Attrs.hasAttrSomewhere(Attribute::Nest, &NestIndex)) + return Attrs.removeAttribute(C, NestIndex, Attribute::Nest); return Attrs; } @@ -2027,6 +2026,24 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, continue; } + // LLVM's definition of dominance allows instructions that are cyclic + // in unreachable blocks, e.g.: + // %pat = select i1 %condition, @global, i16* %pat + // because any instruction dominates an instruction in a block that's + // not reachable from entry. + // So, remove unreachable blocks from the function, because a) there's + // no point in analyzing them and b) GlobalOpt should otherwise grow + // some more complicated logic to break these cycles. + // Removing unreachable blocks might invalidate the dominator so we + // recalculate it. + if (!F->isDeclaration()) { + if (removeUnreachableBlocks(*F)) { + auto &DT = LookupDomTree(*F); + DT.recalculate(*F); + Changed = true; + } + } + Changed |= processGlobal(*F, TLI, LookupDomTree); if (!F->hasLocalLinkage()) @@ -2387,7 +2404,7 @@ OptimizeGlobalAliases(Module &M, } static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) { - LibFunc::Func F = LibFunc::cxa_atexit; + LibFunc F = LibFunc_cxa_atexit; if (!TLI->has(F)) return nullptr; @@ -2396,7 +2413,7 @@ static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) { return nullptr; // Make sure that the function has the correct prototype. 
- if (!TLI->getLibFunc(*Fn, F) || F != LibFunc::cxa_atexit) + if (!TLI->getLibFunc(*Fn, F) || F != LibFunc_cxa_atexit) return nullptr; return Fn; diff --git a/contrib/llvm/lib/Transforms/IPO/GlobalSplit.cpp b/contrib/llvm/lib/Transforms/IPO/GlobalSplit.cpp index bbbd096..e47d881 100644 --- a/contrib/llvm/lib/Transforms/IPO/GlobalSplit.cpp +++ b/contrib/llvm/lib/Transforms/IPO/GlobalSplit.cpp @@ -14,7 +14,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/GlobalSplit.h" #include "llvm/ADT/StringExtras.h" #include "llvm/IR/Constants.h" @@ -23,6 +22,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" #include <set> @@ -85,7 +85,16 @@ bool splitGlobal(GlobalVariable &GV) { uint64_t ByteOffset = cast<ConstantInt>( cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) ->getZExtValue(); - if (ByteOffset < SplitBegin || ByteOffset >= SplitEnd) + // Type metadata may be attached one byte after the end of the vtable, for + // classes without virtual methods in Itanium ABI. AFAIK, it is never + // attached to the first byte of a vtable. Subtract one to get the right + // slice. + // This is making an assumption that vtable groups are the only kinds of + // global variables that !type metadata can be attached to, and that they + // are either Itanium ABI vtable groups or contain a single vtable (i.e. + // Microsoft ABI vtables). + uint64_t AttachedTo = (ByteOffset == 0) ? ByteOffset : ByteOffset - 1; + if (AttachedTo < SplitBegin || AttachedTo >= SplitEnd) continue; SplitGV->addMetadata( LLVMContext::MD_type, diff --git a/contrib/llvm/lib/Transforms/IPO/IPConstantPropagation.cpp b/contrib/llvm/lib/Transforms/IPO/IPConstantPropagation.cpp index 916135e..f79b610 100644 --- a/contrib/llvm/lib/Transforms/IPO/IPConstantPropagation.cpp +++ b/contrib/llvm/lib/Transforms/IPO/IPConstantPropagation.cpp @@ -15,7 +15,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ValueTracking.h" @@ -24,6 +23,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" using namespace llvm; #define DEBUG_TYPE "ipconstprop" @@ -136,7 +136,13 @@ static bool PropagateConstantReturn(Function &F) { // For more details, see GlobalValue::mayBeDerefined. if (!F.isDefinitionExact()) return false; - + // Don't touch naked functions. They may contain asm returning a + // value we don't see, so we may end up interprocedurally propagating + // the return value incorrectly. + if (F.hasFnAttribute(Attribute::Naked)) + return false; + // Check to see if this function returns a constant.
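// An editorial example (hypothetical C source, not part of this patch) of the naked-function hazard above: in IR the body has no visible return value, but at run time the asm sets the return register, so treating the IR-visible return as a propagatable constant would be wrong. // __attribute__((naked)) static int always42(void) { // __asm__ volatile("movl $42, %eax\n\tret"); // }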
SmallVector<Value *,4> RetVals; StructType *STy = dyn_cast<StructType>(F.getReturnType()); diff --git a/contrib/llvm/lib/Transforms/IPO/IPO.cpp b/contrib/llvm/lib/Transforms/IPO/IPO.cpp index 89518f3..5bb305c 100644 --- a/contrib/llvm/lib/Transforms/IPO/IPO.cpp +++ b/contrib/llvm/lib/Transforms/IPO/IPO.cpp @@ -13,10 +13,10 @@ // //===----------------------------------------------------------------------===// -#include "llvm-c/Initialization.h" #include "llvm-c/Transforms/IPO.h" -#include "llvm/InitializePasses.h" +#include "llvm-c/Initialization.h" #include "llvm/IR/LegacyPassManager.h" +#include "llvm/InitializePasses.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" diff --git a/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp b/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp index 2ef299d..15d7515 100644 --- a/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -8,8 +8,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/InferFunctionAttrs.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" diff --git a/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp b/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp index 1770445..50e7cc8 100644 --- a/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp +++ b/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp @@ -48,7 +48,7 @@ public: } explicit SimpleInliner(InlineParams Params) - : LegacyInlinerBase(ID), Params(Params) { + : LegacyInlinerBase(ID), Params(std::move(Params)) { initializeSimpleInlinerPass(*PassRegistry::getPassRegistry()); } @@ -61,7 +61,8 @@ public: [&](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; - return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, PSI); + return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, + /*GetBFI=*/None, PSI); } bool runOnSCC(CallGraphSCC &SCC) override; @@ -92,8 +93,12 @@ Pass *llvm::createFunctionInliningPass(int Threshold) { } Pass *llvm::createFunctionInliningPass(unsigned OptLevel, - unsigned SizeOptLevel) { - return new SimpleInliner(llvm::getInlineParams(OptLevel, SizeOptLevel)); + unsigned SizeOptLevel, + bool DisableInlineHotCallSite) { + auto Param = llvm::getInlineParams(OptLevel, SizeOptLevel); + if (DisableInlineHotCallSite) + Param.HotCallSiteThreshold = 0; + return new SimpleInliner(Param); } Pass *llvm::createFunctionInliningPass(InlineParams &Params) { diff --git a/contrib/llvm/lib/Transforms/IPO/Inliner.cpp b/contrib/llvm/lib/Transforms/IPO/Inliner.cpp index 3f4731c..317770d 100644 --- a/contrib/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/contrib/llvm/lib/Transforms/IPO/Inliner.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/OptimizationDiagnosticInfo.h" @@ -260,8 +261,8 @@ static bool InlineCallIfPossible( /// Return true if inlining of CS can block the caller from being /// inlined which is proved to be more beneficial. \p IC is the /// estimated inline cost associated with callsite \p CS. 
-/// \p TotalAltCost will be set to the estimated cost of inlining the caller -/// if \p CS is suppressed for inlining. +/// \p TotalSecondaryCost will be set to the estimated cost of inlining the +/// caller if \p CS is suppressed for inlining. static bool shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, int &TotalSecondaryCost, @@ -288,7 +289,7 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, // treating them as truly abstract units etc. TotalSecondaryCost = 0; // The candidate cost to be imposed upon the current function. - int CandidateCost = IC.getCost() - (InlineConstants::CallPenalty + 1); + int CandidateCost = IC.getCost() - 1; // This bool tracks what happens if we do NOT inline C into B. bool callerWillBeRemoved = Caller->hasLocalLinkage(); // This bool tracks what happens if we DO inline C into B. @@ -325,7 +326,7 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, // one is set very low by getInlineCost, in anticipation that Caller will // be removed entirely. We did not account for this above unless there // is only one caller of Caller. - if (callerWillBeRemoved && !Caller->use_empty()) + if (callerWillBeRemoved && !Caller->hasOneUse()) TotalSecondaryCost -= InlineConstants::LastCallToStaticBonus; if (inliningPreventsSomeOuterInline && TotalSecondaryCost < IC.getCost()) @@ -342,6 +343,7 @@ static bool shouldInline(CallSite CS, InlineCost IC = GetInlineCost(CS); Instruction *Call = CS.getInstruction(); Function *Callee = CS.getCalledFunction(); + Function *Caller = CS.getCaller(); if (IC.isAlways()) { DEBUG(dbgs() << " Inlining: cost=always" @@ -355,19 +357,20 @@ static bool shouldInline(CallSite CS, if (IC.isNever()) { DEBUG(dbgs() << " NOT Inlining: cost=never" << ", Call: " << *CS.getInstruction() << "\n"); - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "NeverInline", Call) - << NV("Callee", Callee) - << " should never be inlined (cost=never)"); + ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) + << NV("Callee", Callee) << " not inlined into " + << NV("Caller", Caller) + << " because it should never be inlined (cost=never)"); return false; } - Function *Caller = CS.getCaller(); if (!IC) { DEBUG(dbgs() << " NOT Inlining: cost=" << IC.getCost() << ", thres=" << (IC.getCostDelta() + IC.getCost()) << ", Call: " << *CS.getInstruction() << "\n"); - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call) - << NV("Callee", Callee) << " too costly to inline (cost=" + ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "TooCostly", Call) + << NV("Callee", Callee) << " not inlined into " + << NV("Caller", Caller) << " because too costly to inline (cost=" << NV("Cost", IC.getCost()) << ", threshold=" << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); return false; @@ -378,8 +381,8 @@ static bool shouldInline(CallSite CS, DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction() << " Cost = " << IC.getCost() << ", outer Cost = " << TotalSecondaryCost << '\n'); - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, - "IncreaseCostInOtherContexts", Call) + ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "IncreaseCostInOtherContexts", + Call) << "Not inlining. 
Cost of inlining " << NV("Callee", Callee) << " increases the cost of inlining " << NV("Caller", Caller) << " in other contexts"); @@ -499,7 +502,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, std::swap(CallSites[i--], CallSites[--FirstCallInSCC]); InlinedArrayAllocasTy InlinedArrayAllocas; - InlineFunctionInfo InlineInfo(&CG, &GetAssumptionCache); + InlineFunctionInfo InlineInfo(&CG, &GetAssumptionCache, PSI); // Now that we have all of the call sites, loop over them and inline them if // it looks profitable to do so. @@ -516,52 +519,54 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, Function *Caller = CS.getCaller(); Function *Callee = CS.getCalledFunction(); - // If this call site is dead and it is to a readonly function, we should - // just delete the call instead of trying to inline it, regardless of - // size. This happens because IPSCCP propagates the result out of the - // call and then we're left with the dead call. - if (isInstructionTriviallyDead(CS.getInstruction(), &TLI)) { - DEBUG(dbgs() << " -> Deleting dead call: " << *CS.getInstruction() - << "\n"); - // Update the call graph by deleting the edge from Callee to Caller. - CG[Caller]->removeCallEdgeFor(CS); - CS.getInstruction()->eraseFromParent(); - ++NumCallsDeleted; - } else { - // We can only inline direct calls to non-declarations. - if (!Callee || Callee->isDeclaration()) - continue; + // We can only inline direct calls to non-declarations. + if (!Callee || Callee->isDeclaration()) + continue; + Instruction *Instr = CS.getInstruction(); + + bool IsTriviallyDead = isInstructionTriviallyDead(Instr, &TLI); + + int InlineHistoryID; + if (!IsTriviallyDead) { // If this call site was obtained by inlining another function, verify // that the include path for the function did not include the callee // itself. If so, we'd be recursively inlining the same function, // which would provide the same callsites, which would cause us to // infinitely inline. - int InlineHistoryID = CallSites[CSi].second; + InlineHistoryID = CallSites[CSi].second; if (InlineHistoryID != -1 && InlineHistoryIncludes(Callee, InlineHistoryID, InlineHistory)) continue; + } + // FIXME for new PM: because of the old PM we currently generate ORE and + // in turn BFI on demand. With the new PM, the ORE dependency should + // just become a regular analysis dependency. + OptimizationRemarkEmitter ORE(Caller); + + // If the policy determines that we should inline this function, + // delete the call instead. + if (!shouldInline(CS, GetInlineCost, ORE)) + continue; + + // If this call site is dead and it is to a readonly function, we should + // just delete the call instead of trying to inline it, regardless of + // size. This happens because IPSCCP propagates the result out of the + // call and then we're left with the dead call. + if (IsTriviallyDead) { + DEBUG(dbgs() << " -> Deleting dead call: " << *Instr << "\n"); + // Update the call graph by deleting the edge from Callee to Caller. + CG[Caller]->removeCallEdgeFor(CS); + Instr->eraseFromParent(); + ++NumCallsDeleted; + } else { // Get DebugLoc to report. CS will be invalid after Inliner. - DebugLoc DLoc = CS.getInstruction()->getDebugLoc(); + DebugLoc DLoc = Instr->getDebugLoc(); BasicBlock *Block = CS.getParent(); - // FIXME for new PM: because of the old PM we currently generate ORE and - // in turn BFI on demand. With the new PM, the ORE dependency should - // just become a regular analysis dependency. 
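// An editorial worked example (numbers hypothetical) for the shouldBeDeferred accounting earlier in this hunk: if inlining callee C into caller B has IC.getCost() = 60, while suppressing it would keep B inlinable at its own call sites for an estimated TotalSecondaryCost of 40, then 40 < 60 and the call site is deferred. Note the per-outer-site candidate cost is now IC.getCost() - 1, no longer discounted by InlineConstants::CallPenalty.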
- OptimizationRemarkEmitter ORE(Caller); - - // If the policy determines that we should inline this function, - // try to do so. - using namespace ore; - if (!shouldInline(CS, GetInlineCost, ORE)) { - ORE.emit( - OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) - << NV("Callee", Callee) << " will not be inlined into " - << NV("Caller", Caller)); - continue; - } // Attempt to inline the function. + using namespace ore; if (!InlineCallIfPossible(CS, InlineInfo, InlinedArrayAllocas, InlineHistoryID, InsertLifetime, AARGetter, ImportedFunctionsStats)) { @@ -638,22 +643,12 @@ bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) { ACT = &getAnalysis<AssumptionCacheTracker>(); PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - // We compute dedicated AA results for each function in the SCC as needed. We - // use a lambda referencing external objects so that they live long enough to - // be queried, but we re-use them each time. - Optional<BasicAAResult> BAR; - Optional<AAResults> AAR; - auto AARGetter = [&](Function &F) -> AAResults & { - BAR.emplace(createLegacyPMBasicAAResult(*this, F)); - AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); - return *AAR; - }; auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; return inlineCallsImpl(SCC, CG, GetAssumptionCache, PSI, TLI, InsertLifetime, [this](CallSite CS) { return getInlineCost(CS); }, - AARGetter, ImportedFunctionsStats); + LegacyAARGetter(*this), ImportedFunctionsStats); } /// Remove now-dead linkonce functions at the end of @@ -750,9 -745,6 @@ bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG, PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { - FunctionAnalysisManager &FAM = - AM.getResult<FunctionAnalysisManagerCGSCCProxy>(InitialC, CG) - .getManager(); const ModuleAnalysisManager &MAM = AM.getResult<ModuleAnalysisManagerCGSCCProxy>(InitialC, CG).getManager(); bool Changed = false; @@ -761,35 +753,52 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Module &M = *InitialC.begin()->getFunction().getParent(); ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M); - std::function<AssumptionCache &(Function &)> GetAssumptionCache = - [&](Function &F) -> AssumptionCache & { - return FAM.getResult<AssumptionAnalysis>(F); - }; - - // Setup the data structure used to plumb customization into the - // `InlineFunction` routine. - InlineFunctionInfo IFI(/*cg=*/nullptr, &GetAssumptionCache); + // We use a single common worklist for calls across the entire SCC. We + // process these in-order and append new calls introduced during inlining to + // the end. + // + // Note that this particular order of processing is actually critical to + // avoid very bad behaviors. Consider *highly connected* call graphs where + // each function contains a small amount of code and a couple of calls to + // other functions. Because the LLVM inliner is fundamentally a bottom-up + // inliner, it can handle gracefully the fact that these all appear to be + // reasonable inlining candidates as it will flatten things until they become + // too big to inline, and then move on and flatten another batch. + // + // However, when processing call edges *within* an SCC we cannot rely on this + // bottom-up behavior.
As a consequence, with heavily connected *SCCs* of + // functions we can end up incrementally inlining N calls into each of + // N functions because each incremental inlining decision looks good and we + // don't have a topological ordering to prevent explosions. + // + // To compensate for this, we don't process transitive edges made immediate + // by inlining until we've done one pass of inlining across the entire SCC. + // Large, highly connected SCCs still lead to some amount of code bloat in + // this model, but it is uniformly spread across all the functions in the SCC + // and eventually they all become too large to inline, rather than + // incrementally making a single function grow in a super linear fashion. + SmallVector<std::pair<CallSite, int>, 16> Calls; - auto GetInlineCost = [&](CallSite CS) { - Function &Callee = *CS.getCalledFunction(); - auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee); - return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, PSI); - }; + // Populate the initial list of calls in this SCC. + for (auto &N : InitialC) { + // We want to generally process call sites top-down in order for + // simplifications stemming from replacing the call with the returned value + // after inlining to be visible to subsequent inlining decisions. + // FIXME: Using instructions sequence is a really bad way to do this. + // Instead we should do an actual RPO walk of the function body. + for (Instruction &I : instructions(N.getFunction())) + if (auto CS = CallSite(&I)) + if (Function *Callee = CS.getCalledFunction()) + if (!Callee->isDeclaration()) + Calls.push_back({CS, -1}); + } + if (Calls.empty()) + return PreservedAnalyses::all(); - // We use a worklist of nodes to process so that we can handle if the SCC - // structure changes and some nodes are no longer part of the current SCC. We - // also need to use an updatable pointer for the SCC as a consequence. - SmallVector<LazyCallGraph::Node *, 16> Nodes; - for (auto &N : InitialC) - Nodes.push_back(&N); + // Capture updatable variables for the current SCC and RefSCC. auto *C = &InitialC; auto *RC = &C->getOuterRefSCC(); - // We also use a secondary worklist of call sites within a particular node to - // allow quickly continuing to inline through newly inlined call sites where - // possible. - SmallVector<std::pair<CallSite, int>, 16> Calls; - // When inlining a callee produces new call sites, we want to keep track of // the fact that they were inlined from the callee. This allows us to avoid // infinite inlining in some obscure cases. To represent this, we use an @@ -805,34 +814,58 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // defer deleting these to make it easier to handle the call graph updates. SmallVector<Function *, 4> DeadFunctions; - do { - auto &N = *Nodes.pop_back_val(); + // Loop forward over all of the calls. Note that we cannot cache the size as + // inlining can introduce new calls that need to be processed. + for (int i = 0; i < (int)Calls.size(); ++i) { + // We expect the calls to typically be batched with sequences of calls that + // have the same caller, so we first set up some shared infrastructure for + // this caller. We also do any pruning we can at this layer on the caller + // alone.
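The worklist discipline spelled out in the comment above (process calls in order, append calls created by inlining, and refuse to revisit a callee already on the inline-history chain) can be pictured with a small self-contained sketch. Toy call records and function names are used here for illustration; they are not LLVM's CallSite or LazyCallGraph types:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

struct Call { std::string Caller, Callee; };

int main() {
  // Each pending call carries the index of the inline-history entry that
  // introduced it; -1 marks calls that existed before any inlining.
  std::vector<std::pair<Call, int>> Calls = {{{"f", "g"}, -1}};
  // A history entry records the inlined callee and its parent entry, forming
  // chains that catch (mutual) recursion across inlining steps.
  std::vector<std::pair<std::string, int>> InlineHistory;

  for (int i = 0; i < (int)Calls.size(); ++i) {
    Call C = Calls[i].first;
    int HistID = Calls[i].second;
    bool Recursive = false;
    for (int ID = HistID; ID != -1; ID = InlineHistory[ID].second)
      if (InlineHistory[ID].first == C.Callee)
        Recursive = true;
    if (Recursive)
      continue;
    std::printf("inline %s into %s\n", C.Callee.c_str(), C.Caller.c_str());
    int NewID = (int)InlineHistory.size();
    InlineHistory.push_back({C.Callee, HistID});
    // Calls the callee's body contained now appear in the caller; appending
    // them means they are only visited after this sweep over the SCC.
    if (C.Callee == "g")
      Calls.push_back({{"f", "h"}, NewID});
  }
  return 0;
}

Appending new calls to the end rather than pushing them onto a stack is what defers the transitive edges until the whole SCC has been swept once, which is the code-bloat mitigation the comment describes.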
+ Function &F = *Calls[i].first.getCaller(); + LazyCallGraph::Node &N = *CG.lookup(F); if (CG.lookupSCC(N) != C) continue; - Function &F = N.getFunction(); if (F.hasFnAttribute(Attribute::OptimizeNone)) continue; + DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n"); + + // Get a FunctionAnalysisManager via a proxy for this particular node. We + // do this each time we visit a node as the SCC may have changed and as + // we're going to mutate this particular function we want to make sure the + // proxy is in place to forward any invalidation events. We can use the + // manager we get here for looking up results for functions other than this + // node however because those functions aren't going to be mutated by this + // pass. + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(*C, CG) + .getManager(); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = + [&](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + auto GetBFI = [&](Function &F) -> BlockFrequencyInfo & { + return FAM.getResult<BlockFrequencyAnalysis>(F); + }; + + auto GetInlineCost = [&](CallSite CS) { + Function &Callee = *CS.getCalledFunction(); + auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee); + return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, {GetBFI}, + PSI); + }; + // Get the remarks emission analysis for the caller. auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - // We want to generally process call sites top-down in order for - // simplifications stemming from replacing the call with the returned value - // after inlining to be visible to subsequent inlining decisions. So we - // walk the function backwards and then process the back of the vector. - // FIXME: Using reverse is a really bad way to do this. Instead we should - // do an actual PO walk of the function body. - for (Instruction &I : reverse(instructions(F))) - if (auto CS = CallSite(&I)) - if (Function *Callee = CS.getCalledFunction()) - if (!Callee->isDeclaration()) - Calls.push_back({CS, -1}); - + // Now process as many calls as we have within this caller in the sequence. + // We bail out as soon as the caller has to change so we can update the + // call graph and prepare the context of that new caller. bool DidInline = false; - while (!Calls.empty()) { + for (; i < (int)Calls.size() && Calls[i].first.getCaller() == &F; ++i) { int InlineHistoryID; CallSite CS; - std::tie(CS, InlineHistoryID) = Calls.pop_back_val(); + std::tie(CS, InlineHistoryID) = Calls[i]; Function &Callee = *CS.getCalledFunction(); if (InlineHistoryID != -1 && @@ -843,6 +876,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (!shouldInline(CS, GetInlineCost, ORE)) continue; + // Setup the data structure used to plumb customization into the + // `InlineFunction` routine. + InlineFunctionInfo IFI( + /*cg=*/nullptr, &GetAssumptionCache, PSI, + &FAM.getResult<BlockFrequencyAnalysis>(*(CS.getCaller())), + &FAM.getResult<BlockFrequencyAnalysis>(Callee)); + if (!InlineFunction(CS, IFI)) continue; DidInline = true; @@ -869,7 +909,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // To check this we also need to nuke any dead constant uses (perhaps // made dead by this operation on other functions).
Callee.removeDeadConstantUsers(); - if (Callee.use_empty()) { + if (Callee.use_empty() && !CG.isLibFunction(Callee)) { + Calls.erase( + std::remove_if(Calls.begin() + i + 1, Calls.end(), + [&Callee](const std::pair<CallSite, int> &Call) { + return Call.first.getCaller() == &Callee; + }), + Calls.end()); // Clear the body and queue the function itself for deletion when we // finish inlining and call graph updates. // Note that after this point, it is an error to do anything other @@ -882,6 +928,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, } } + // Back the call index up by one to put us in a good position to go around + // the outer loop. + --i; + if (!DidInline) continue; Changed = true; @@ -896,8 +946,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // below. for (Function *InlinedCallee : InlinedCallees) { LazyCallGraph::Node &CalleeN = *CG.lookup(*InlinedCallee); - for (LazyCallGraph::Edge &E : CalleeN) - RC->insertTrivialRefEdge(N, *E.getNode()); + for (LazyCallGraph::Edge &E : *CalleeN) + RC->insertTrivialRefEdge(N, E.getNode()); } InlinedCallees.clear(); @@ -908,8 +958,9 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // re-use the exact same logic for updating the call graph to reflect the // change. C = &updateCGAndAnalysisManagerForFunctionPass(CG, *C, N, AM, UR); + DEBUG(dbgs() << "Updated inlining SCC: " << *C << "\n"); RC = &C->getOuterRefSCC(); - } while (!Nodes.empty()); + } // Now that we've finished inlining all of the calls across this SCC, delete // all of the trivially dead functions, updating the call graph and the CGSCC @@ -920,8 +971,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // sets. for (Function *DeadF : DeadFunctions) { // Get the necessary information out of the call graph and nuke the - // function there. + // function there. Also, clear out any cached analyses. auto &DeadC = *CG.lookupSCC(*CG.lookup(*DeadF)); + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(DeadC, CG) + .getManager(); + FAM.clear(*DeadF); + AM.clear(DeadC); auto &DeadRC = DeadC.getOuterRefSCC(); CG.removeDeadFunction(*DeadF); @@ -933,5 +989,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // And delete the actual function from the module. M.getFunctionList().erase(DeadF); } - return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); + + if (!Changed) + return PreservedAnalyses::all(); + + // Even if we change the IR, we update the core CGSCC data structures and so + // can preserve the proxy to the function analysis manager.
+ PreservedAnalyses PA; + PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); + return PA; } diff --git a/contrib/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/contrib/llvm/lib/Transforms/IPO/LoopExtractor.cpp index f898c3b..c74b0a3 100644 --- a/contrib/llvm/lib/Transforms/IPO/LoopExtractor.cpp +++ b/contrib/llvm/lib/Transforms/IPO/LoopExtractor.cpp @@ -14,7 +14,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/IR/Dominators.h" @@ -22,6 +21,7 @@ #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CodeExtractor.h" diff --git a/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index deb7e81..693df5e 100644 --- a/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -42,8 +43,6 @@ using namespace llvm; using namespace lowertypetests; -using SummaryAction = LowerTypeTestsSummaryAction; - #define DEBUG_TYPE "lowertypetests" STATISTIC(ByteArraySizeBits, "Byte array size in bits"); @@ -57,13 +56,13 @@ static cl::opt<bool> AvoidReuse( cl::desc("Try to avoid reuse of byte array addresses using aliases"), cl::Hidden, cl::init(true)); -static cl::opt<SummaryAction> ClSummaryAction( +static cl::opt<PassSummaryAction> ClSummaryAction( "lowertypetests-summary-action", cl::desc("What to do with the summary when running this pass"), - cl::values(clEnumValN(SummaryAction::None, "none", "Do nothing"), - clEnumValN(SummaryAction::Import, "import", + cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), + clEnumValN(PassSummaryAction::Import, "import", "Import typeid resolutions from summary and globals"), - clEnumValN(SummaryAction::Export, "export", + clEnumValN(PassSummaryAction::Export, "export", "Export typeid resolutions to summary and globals")), cl::Hidden); @@ -208,17 +207,26 @@ struct ByteArrayInfo { class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> { GlobalObject *GO; size_t NTypes; + // For functions: true if this is a definition (either in the merged module or + // in one of the thinlto modules). + bool IsDefinition; + // For functions: true if this function is either defined or used in a thinlto + // module and its jumptable entry needs to be exported to thinlto backends. 
+ bool IsExported; friend TrailingObjects; size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; } public: static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO, + bool IsDefinition, bool IsExported, ArrayRef<MDNode *> Types) { auto *GTM = static_cast<GlobalTypeMember *>(Alloc.Allocate( totalSizeToAlloc<MDNode *>(Types.size()), alignof(GlobalTypeMember))); GTM->GO = GO; GTM->NTypes = Types.size(); + GTM->IsDefinition = IsDefinition; + GTM->IsExported = IsExported; std::uninitialized_copy(Types.begin(), Types.end(), GTM->getTrailingObjects<MDNode *>()); return GTM; @@ -226,6 +234,12 @@ public: GlobalObject *getGlobal() const { return GO; } + bool isDefinition() const { + return IsDefinition; + } + bool isExported() const { + return IsExported; + } ArrayRef<MDNode *> types() const { return makeArrayRef(getTrailingObjects<MDNode *>(), NTypes); } @@ -234,10 +248,9 @@ public: class LowerTypeTestsModule { Module &M; - SummaryAction Action; - ModuleSummaryIndex *Summary; + ModuleSummaryIndex *ExportSummary; + const ModuleSummaryIndex *ImportSummary; - bool LinkerSubsectionsViaSymbols; Triple::ArchType Arch; Triple::OSType OS; Triple::ObjectFormatType ObjectFormat; @@ -253,15 +266,21 @@ class LowerTypeTestsModule { // Indirect function call index assignment counter for WebAssembly uint64_t IndirectIndex = 1; - // Mapping from type identifiers to the call sites that test them. - DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites; + // Mapping from type identifiers to the call sites that test them, as well as + // whether the type identifier needs to be exported to ThinLTO backends as + // part of the regular LTO phase of the ThinLTO pipeline (see exportTypeId). + struct TypeIdUserInfo { + std::vector<CallInst *> CallSites; + bool IsExported = false; + }; + DenseMap<Metadata *, TypeIdUserInfo> TypeIdUsers; /// This structure describes how to lower type tests for a particular type /// identifier. It is either built directly from the global analysis (during /// regular LTO or the regular LTO phase of ThinLTO), or indirectly using type /// identifier summaries and external symbol references (in ThinLTO backends). struct TypeIdLowering { - TypeTestResolution::Kind TheKind; + TypeTestResolution::Kind TheKind = TypeTestResolution::Unsat; /// All except Unsat: the start address within the combined global. Constant *OffsetedGlobal; @@ -274,9 +293,6 @@ class LowerTypeTestsModule { /// covering members of this type identifier as a multiple of 2^AlignLog2. Constant *SizeM1; - /// ByteArray, Inline, AllOnes: range of SizeM1 expressed as a bit width. - unsigned SizeM1BitWidth; - /// ByteArray: the byte array to test the address against. 
Constant *TheByteArray; @@ -291,6 +307,11 @@ class LowerTypeTestsModule { Function *WeakInitializerFn = nullptr; + void exportTypeId(StringRef TypeId, const TypeIdLowering &TIL); + TypeIdLowering importTypeId(StringRef TypeId); + void importTypeTest(CallInst *CI); + void importFunction(Function *F, bool isDefinition); + BitSetInfo buildBitSet(Metadata *TypeId, const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout); @@ -327,8 +348,8 @@ class LowerTypeTestsModule { void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions); public: - LowerTypeTestsModule(Module &M, SummaryAction Action, - ModuleSummaryIndex *Summary); + LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary); bool lower(); // Lower the module using the action and summary passed as command line @@ -341,15 +362,17 @@ struct LowerTypeTests : public ModulePass { bool UseCommandLine = false; - SummaryAction Action; - ModuleSummaryIndex *Summary; + ModuleSummaryIndex *ExportSummary; + const ModuleSummaryIndex *ImportSummary; LowerTypeTests() : ModulePass(ID), UseCommandLine(true) { initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry()); } - LowerTypeTests(SummaryAction Action, ModuleSummaryIndex *Summary) - : ModulePass(ID), Action(Action), Summary(Summary) { + LowerTypeTests(ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) + : ModulePass(ID), ExportSummary(ExportSummary), + ImportSummary(ImportSummary) { initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry()); } @@ -358,7 +381,7 @@ struct LowerTypeTests : public ModulePass { return false; if (UseCommandLine) return LowerTypeTestsModule::runForTesting(M); - return LowerTypeTestsModule(M, Action, Summary).lower(); + return LowerTypeTestsModule(M, ExportSummary, ImportSummary).lower(); } }; @@ -368,9 +391,10 @@ INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false, false) char LowerTypeTests::ID = 0; -ModulePass *llvm::createLowerTypeTestsPass(SummaryAction Action, - ModuleSummaryIndex *Summary) { - return new LowerTypeTests(Action, Summary); +ModulePass * +llvm::createLowerTypeTestsPass(ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) { + return new LowerTypeTests(ExportSummary, ImportSummary); } /// Build a bit set for TypeId using the object layouts in @@ -467,13 +491,9 @@ void LowerTypeTestsModule::allocateByteArrays() { // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures // that the pc-relative displacement is folded into the lea instead of the // test instruction getting another displacement. - if (LinkerSubsectionsViaSymbols) { - BAI->ByteArray->replaceAllUsesWith(GEP); - } else { - GlobalAlias *Alias = GlobalAlias::create( - Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, &M); - BAI->ByteArray->replaceAllUsesWith(Alias); - } + GlobalAlias *Alias = GlobalAlias::create( + Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, &M); + BAI->ByteArray->replaceAllUsesWith(Alias); BAI->ByteArray->eraseFromParent(); } @@ -494,10 +514,11 @@ Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B, return createMaskedBitTest(B, TIL.InlineBits, BitOffset); } else { Constant *ByteArray = TIL.TheByteArray; - if (!LinkerSubsectionsViaSymbols && AvoidReuse) { + if (AvoidReuse && !ImportSummary) { // Each use of the byte array uses a different alias. 
This makes the // backend less likely to reuse previously computed byte array addresses, // improving the security of the CFI mechanism based on this pass. + // This won't work when importing because TheByteArray is external. ByteArray = GlobalAlias::create(Int8Ty, 0, GlobalValue::PrivateLinkage, "bits_use", ByteArray, &M); } @@ -593,15 +614,31 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI, IntPtrTy)); Value *BitOffset = B.CreateOr(OffsetSHR, OffsetSHL); - Constant *BitSizeConst = ConstantExpr::getZExt(TIL.SizeM1, IntPtrTy); - Value *OffsetInRange = B.CreateICmpULE(BitOffset, BitSizeConst); + Value *OffsetInRange = B.CreateICmpULE(BitOffset, TIL.SizeM1); // If the bit set is all ones, testing against it is unnecessary. if (TIL.TheKind == TypeTestResolution::AllOnes) return OffsetInRange; - TerminatorInst *Term = SplitBlockAndInsertIfThen(OffsetInRange, CI, false); - IRBuilder<> ThenB(Term); + // See if the intrinsic is used in the following common pattern: + // br(llvm.type.test(...), thenbb, elsebb) + // where nothing happens between the type test and the br. + // If so, create slightly simpler IR. + if (CI->hasOneUse()) + if (auto *Br = dyn_cast<BranchInst>(*CI->user_begin())) + if (CI->getNextNode() == Br) { + BasicBlock *Then = InitialBB->splitBasicBlock(CI->getIterator()); + BasicBlock *Else = Br->getSuccessor(1); + BranchInst *NewBr = BranchInst::Create(Then, Else, OffsetInRange); + NewBr->setMetadata(LLVMContext::MD_prof, + Br->getMetadata(LLVMContext::MD_prof)); + ReplaceInstWithInst(InitialBB->getTerminator(), NewBr); + + IRBuilder<> ThenB(CI); + return createBitSetTest(ThenB, TIL, BitOffset); + } + + IRBuilder<> ThenB(SplitBlockAndInsertIfThen(OffsetInRange, CI, false)); // Now that we know that the offset is in range and aligned, load the // appropriate bit from the bitset. @@ -672,21 +709,174 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables( ConstantInt::get(Int32Ty, I * 2)}; Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr( NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs); - if (LinkerSubsectionsViaSymbols) { - GV->replaceAllUsesWith(CombinedGlobalElemPtr); - } else { - assert(GV->getType()->getAddressSpace() == 0); - GlobalAlias *GAlias = GlobalAlias::create(NewTy->getElementType(I * 2), 0, - GV->getLinkage(), "", - CombinedGlobalElemPtr, &M); - GAlias->setVisibility(GV->getVisibility()); - GAlias->takeName(GV); - GV->replaceAllUsesWith(GAlias); - } + assert(GV->getType()->getAddressSpace() == 0); + GlobalAlias *GAlias = + GlobalAlias::create(NewTy->getElementType(I * 2), 0, GV->getLinkage(), + "", CombinedGlobalElemPtr, &M); + GAlias->setVisibility(GV->getVisibility()); + GAlias->takeName(GV); + GV->replaceAllUsesWith(GAlias); GV->eraseFromParent(); } } +/// Export the given type identifier so that ThinLTO backends may import it. +/// Type identifiers are exported by adding coarse-grained information about how +/// to test the type identifier to the summary, and creating symbols in the +/// object file (aliases and absolute symbols) containing fine-grained +/// information about the type identifier. 
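For orientation, the membership check that lowerTypeTestCall emits (seen earlier in this hunk as the OffsetSHR/OffsetSHL pair, the ICmpULE against SizeM1, and the bit lookup) behaves like the following standalone model on a 64-bit target. GlobalAddr, AlignLog2, SizeM1 and Bits stand in for the per-type-identifier constants the pass computes, so this is a sketch rather than the pass's actual API:

#include <cstdint>
#include <vector>

bool typeTest(uint64_t Addr, uint64_t GlobalAddr, unsigned AlignLog2,
              uint64_t SizeM1, const std::vector<bool> &Bits) {
  uint64_t Diff = Addr - GlobalAddr;
  // The shift pair is a rotate right by AlignLog2: a misaligned address
  // sets high bits and therefore fails the range check below.
  uint64_t BitOffset =
      AlignLog2 ? ((Diff >> AlignLog2) | (Diff << (64 - AlignLog2))) : Diff;
  if (BitOffset > SizeM1)
    return false;                   // outside the range this type id covers
  return Bits[(size_t)BitOffset];   // the AllOnes kind elides this lookup
}

int main() {
  std::vector<bool> Bits = {true, false, true};
  // 0x1010 is 16 bytes past the combined global, 8-byte aligned: bit 2.
  return typeTest(0x1010, 0x1000, 3, 2, Bits) ? 0 : 1;
}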
+void LowerTypeTestsModule::exportTypeId(StringRef TypeId, + const TypeIdLowering &TIL) { + TypeTestResolution &TTRes = + ExportSummary->getOrInsertTypeIdSummary(TypeId).TTRes; + TTRes.TheKind = TIL.TheKind; + + auto ExportGlobal = [&](StringRef Name, Constant *C) { + GlobalAlias *GA = + GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, + "__typeid_" + TypeId + "_" + Name, C, &M); + GA->setVisibility(GlobalValue::HiddenVisibility); + }; + + if (TIL.TheKind != TypeTestResolution::Unsat) + ExportGlobal("global_addr", TIL.OffsetedGlobal); + + if (TIL.TheKind == TypeTestResolution::ByteArray || + TIL.TheKind == TypeTestResolution::Inline || + TIL.TheKind == TypeTestResolution::AllOnes) { + ExportGlobal("align", ConstantExpr::getIntToPtr(TIL.AlignLog2, Int8PtrTy)); + ExportGlobal("size_m1", ConstantExpr::getIntToPtr(TIL.SizeM1, Int8PtrTy)); + + uint64_t BitSize = cast<ConstantInt>(TIL.SizeM1)->getZExtValue() + 1; + if (TIL.TheKind == TypeTestResolution::Inline) + TTRes.SizeM1BitWidth = (BitSize <= 32) ? 5 : 6; + else + TTRes.SizeM1BitWidth = (BitSize <= 128) ? 7 : 32; + } + + if (TIL.TheKind == TypeTestResolution::ByteArray) { + ExportGlobal("byte_array", TIL.TheByteArray); + ExportGlobal("bit_mask", TIL.BitMask); + } + + if (TIL.TheKind == TypeTestResolution::Inline) + ExportGlobal("inline_bits", + ConstantExpr::getIntToPtr(TIL.InlineBits, Int8PtrTy)); +} + +LowerTypeTestsModule::TypeIdLowering +LowerTypeTestsModule::importTypeId(StringRef TypeId) { + const TypeIdSummary *TidSummary = ImportSummary->getTypeIdSummary(TypeId); + if (!TidSummary) + return {}; // Unsat: no globals match this type id. + const TypeTestResolution &TTRes = TidSummary->TTRes; + + TypeIdLowering TIL; + TIL.TheKind = TTRes.TheKind; + + auto ImportGlobal = [&](StringRef Name, unsigned AbsWidth) { + Constant *C = + M.getOrInsertGlobal(("__typeid_" + TypeId + "_" + Name).str(), Int8Ty); + auto *GV = dyn_cast<GlobalVariable>(C); + // We only need to set metadata if the global is newly created, in which + // case it would not have hidden visibility. + if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) + return C; + + GV->setVisibility(GlobalValue::HiddenVisibility); + auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { + auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); + auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); + GV->setMetadata(LLVMContext::MD_absolute_symbol, + MDNode::get(M.getContext(), {MinC, MaxC})); + }; + if (AbsWidth == IntPtrTy->getBitWidth()) + SetAbsRange(~0ull, ~0ull); // Full set. + else if (AbsWidth) + SetAbsRange(0, 1ull << AbsWidth); + return C; + }; + + if (TIL.TheKind != TypeTestResolution::Unsat) + TIL.OffsetedGlobal = ImportGlobal("global_addr", 0); + + if (TIL.TheKind == TypeTestResolution::ByteArray || + TIL.TheKind == TypeTestResolution::Inline || + TIL.TheKind == TypeTestResolution::AllOnes) { + TIL.AlignLog2 = ConstantExpr::getPtrToInt(ImportGlobal("align", 8), Int8Ty); + TIL.SizeM1 = ConstantExpr::getPtrToInt( + ImportGlobal("size_m1", TTRes.SizeM1BitWidth), IntPtrTy); + } + + if (TIL.TheKind == TypeTestResolution::ByteArray) { + TIL.TheByteArray = ImportGlobal("byte_array", 0); + TIL.BitMask = ImportGlobal("bit_mask", 8); + } + + if (TIL.TheKind == TypeTestResolution::Inline) + TIL.InlineBits = ConstantExpr::getPtrToInt( + ImportGlobal("inline_bits", 1 << TTRes.SizeM1BitWidth), + TTRes.SizeM1BitWidth <= 5 ? 
Int32Ty : Int64Ty); + + return TIL; +} + +void LowerTypeTestsModule::importTypeTest(CallInst *CI) { + auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); + if (!TypeIdMDVal) + report_fatal_error("Second argument of llvm.type.test must be metadata"); + + auto TypeIdStr = dyn_cast<MDString>(TypeIdMDVal->getMetadata()); + if (!TypeIdStr) + report_fatal_error( + "Second argument of llvm.type.test must be a metadata string"); + + TypeIdLowering TIL = importTypeId(TypeIdStr->getString()); + Value *Lowered = lowerTypeTestCall(TypeIdStr, CI, TIL); + CI->replaceAllUsesWith(Lowered); + CI->eraseFromParent(); +} + +// ThinLTO backend: the function F has a jump table entry; update this module +// accordingly. isDefinition describes the type of the jump table entry. +void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { + assert(F->getType()->getAddressSpace() == 0); + + // Declaration of a local function - nothing to do. + if (F->isDeclarationForLinker() && isDefinition) + return; + + GlobalValue::VisibilityTypes Visibility = F->getVisibility(); + std::string Name = F->getName(); + Function *FDecl; + + if (F->isDeclarationForLinker() && !isDefinition) { + // Declaration of an external function. + FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, + Name + ".cfi_jt", &M); + FDecl->setVisibility(GlobalValue::HiddenVisibility); + } else if (isDefinition) { + F->setName(Name + ".cfi"); + F->setLinkage(GlobalValue::ExternalLinkage); + F->setVisibility(GlobalValue::HiddenVisibility); + FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, + Name, &M); + FDecl->setVisibility(Visibility); + } else { + // Function definition without type metadata, where some other translation + // unit contained a declaration with type metadata. This normally happens + // during mixed CFI + non-CFI compilation. We do nothing with the function + // so that it is treated the same way as a function defined outside of the + // LTO unit. + return; + } + + if (F->isWeakForLinker()) + replaceWeakDeclarationWithJumpTablePtr(F, FDecl); + else + F->replaceAllUsesWith(FDecl); +} + void LowerTypeTestsModule::lowerTypeTestCalls( ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) { @@ -708,16 +898,12 @@ void LowerTypeTestsModule::lowerTypeTestCalls( TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr( Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)), TIL.AlignLog2 = ConstantInt::get(Int8Ty, BSI.AlignLog2); + TIL.SizeM1 = ConstantInt::get(IntPtrTy, BSI.BitSize - 1); if (BSI.isAllOnes()) { TIL.TheKind = (BSI.BitSize == 1) ? TypeTestResolution::Single : TypeTestResolution::AllOnes; - TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32; - TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? Int8Ty : Int32Ty, - BSI.BitSize - 1); } else if (BSI.BitSize <= 64) { TIL.TheKind = TypeTestResolution::Inline; - TIL.SizeM1BitWidth = (BSI.BitSize <= 32) ? 5 : 6; - TIL.SizeM1 = ConstantInt::get(Int8Ty, BSI.BitSize - 1); uint64_t InlineBits = 0; for (auto Bit : BSI.Bits) InlineBits |= uint64_t(1) << Bit; @@ -728,17 +914,19 @@ void LowerTypeTestsModule::lowerTypeTestCalls( (BSI.BitSize <= 32) ? Int32Ty : Int64Ty, InlineBits); } else { TIL.TheKind = TypeTestResolution::ByteArray; - TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32; - TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? 
Int8Ty : Int32Ty, - BSI.BitSize - 1); ++NumByteArraysCreated; ByteArrayInfo *BAI = createByteArray(BSI); TIL.TheByteArray = BAI->ByteArray; TIL.BitMask = BAI->MaskGlobal; } + TypeIdUserInfo &TIUI = TypeIdUsers[TypeId]; + + if (TIUI.IsExported) + exportTypeId(cast<MDString>(TypeId)->getString(), TIL); + // Lower each call to llvm.type.test for this type identifier. - for (CallInst *CI : TypeTestCallSites[TypeId]) { + for (CallInst *CI : TIUI.CallSites) { ++NumTypeTestCallsLowered; Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL); CI->replaceAllUsesWith(Lowered); @@ -757,9 +945,9 @@ void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { report_fatal_error( "A member of a type identifier may not have an explicit section"); - if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker()) - report_fatal_error( - "A global var member of a type identifier must be a definition"); + // FIXME: We previously checked that global var member of a type identifier + // must be a definition, but the IR linker may leave type metadata on + // declarations. We should restore this check after fixing PR31759. auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0)); if (!OffsetConstMD) @@ -1012,7 +1200,6 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( // arithmetic that we normally use for globals. // FIXME: find a better way to represent the jumptable in the IR. - assert(!Functions.empty()); // Build a simple layout based on the regular layout of jump tables. @@ -1036,6 +1223,7 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( // references to the original functions with references to the aliases. for (unsigned I = 0; I != Functions.size(); ++I) { Function *F = cast<Function>(Functions[I]->getGlobal()); + bool IsDefinition = Functions[I]->isDefinition(); Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast( ConstantExpr::getInBoundsGetElementPtr( @@ -1043,8 +1231,18 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), ConstantInt::get(IntPtrTy, I)}), F->getType()); - if (LinkerSubsectionsViaSymbols || F->isDeclarationForLinker()) { - + if (Functions[I]->isExported()) { + if (IsDefinition) { + ExportSummary->cfiFunctionDefs().insert(F->getName()); + } else { + GlobalAlias *JtAlias = GlobalAlias::create( + F->getValueType(), 0, GlobalValue::ExternalLinkage, + F->getName() + ".cfi_jt", CombinedGlobalElemPtr, &M); + JtAlias->setVisibility(GlobalValue::HiddenVisibility); + ExportSummary->cfiFunctionDecls().insert(F->getName()); + } + } + if (!IsDefinition) { if (F->isWeakForLinker()) replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr); else @@ -1052,9 +1250,8 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( } else { assert(F->getType()->getAddressSpace() == 0); - GlobalAlias *FAlias = GlobalAlias::create(F->getValueType(), 0, - F->getLinkage(), "", - CombinedGlobalElemPtr, &M); + GlobalAlias *FAlias = GlobalAlias::create( + F->getValueType(), 0, F->getLinkage(), "", CombinedGlobalElemPtr, &M); FAlias->setVisibility(F->getVisibility()); FAlias->takeName(F); if (FAlias->hasName()) @@ -1173,15 +1370,12 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet( } /// Lower all type tests in this module. -LowerTypeTestsModule::LowerTypeTestsModule(Module &M, SummaryAction Action, - ModuleSummaryIndex *Summary) - : M(M), Action(Action), Summary(Summary) { - // FIXME: Use these fields. 
- (void)this->Action; - (void)this->Summary; - +LowerTypeTestsModule::LowerTypeTestsModule( + Module &M, ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) + : M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary) { + assert(!(ExportSummary && ImportSummary)); Triple TargetTriple(M.getTargetTriple()); - LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX(); Arch = TargetTriple.getArch(); OS = TargetTriple.getOS(); ObjectFormat = TargetTriple.getObjectFormat(); @@ -1203,7 +1397,11 @@ bool LowerTypeTestsModule::runForTesting(Module &M) { ExitOnErr(errorCodeToError(In.error())); } - bool Changed = LowerTypeTestsModule(M, ClSummaryAction, &Summary).lower(); + bool Changed = + LowerTypeTestsModule( + M, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, + ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr) + .lower(); if (!ClWriteSummary.empty()) { ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary + @@ -1222,9 +1420,40 @@ bool LowerTypeTestsModule::runForTesting(Module &M) { bool LowerTypeTestsModule::lower() { Function *TypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_test)); - if (!TypeTestFunc || TypeTestFunc->use_empty()) + if ((!TypeTestFunc || TypeTestFunc->use_empty()) && !ExportSummary && + !ImportSummary) return false; + if (ImportSummary) { + if (TypeTestFunc) { + for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end(); + UI != UE;) { + auto *CI = cast<CallInst>((*UI++).getUser()); + importTypeTest(CI); + } + } + + SmallVector<Function *, 8> Defs; + SmallVector<Function *, 8> Decls; + for (auto &F : M) { + // CFI functions are either external, or promoted. A local function may + // have the same name, but it's not the one we are looking for. + if (F.hasLocalLinkage()) + continue; + if (ImportSummary->cfiFunctionDefs().count(F.getName())) + Defs.push_back(&F); + else if (ImportSummary->cfiFunctionDecls().count(F.getName())) + Decls.push_back(&F); + } + + for (auto F : Defs) + importFunction(F, /*isDefinition*/ true); + for (auto F : Decls) + importFunction(F, /*isDefinition*/ false); + + return true; + } + // Equivalence class set containing type identifiers and the globals that // reference them. This is used to partition the set of type identifiers in // the module into disjoint sets. 
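The disjoint-set partitioning just described can be pictured with a toy union-find over globals only (LLVM's EquivalenceClasses<T> additionally carries the type-identifier nodes themselves; the names and membership table here are illustrative):

#include <cstdio>
#include <map>
#include <numeric>
#include <string>
#include <vector>

struct UnionFind {
  std::vector<int> Parent;
  explicit UnionFind(int N) : Parent(N) {
    std::iota(Parent.begin(), Parent.end(), 0);
  }
  int find(int X) { return Parent[X] == X ? X : Parent[X] = find(Parent[X]); }
  void merge(int A, int B) { Parent[find(A)] = find(B); }
};

int main() {
  // Globals 0 and 1 share a type id, so their type tests must be lowered
  // against one combined layout; global 2 forms its own class.
  std::map<std::string, std::vector<int>> TypeIdMembers = {
      {"_ZTS1A", {0, 1}}, {"_ZTS1B", {1}}, {"_ZTS1C", {2}}};
  UnionFind UF(3);
  for (const auto &Entry : TypeIdMembers)
    for (size_t I = 1; I < Entry.second.size(); ++I)
      UF.merge(Entry.second[0], Entry.second[I]);
  for (int G = 0; G < 3; ++G)
    std::printf("global %d -> class %d\n", G, UF.find(G));
  return 0;
}

Each resulting class can then be laid out and lowered independently, which is why the pass computes this partition before building any combined globals or jump tables.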
@@ -1247,13 +1476,76 @@ bool LowerTypeTestsModule::lower() { llvm::DenseMap<Metadata *, TIInfo> TypeIdInfo; unsigned I = 0; SmallVector<MDNode *, 2> Types; + + struct ExportedFunctionInfo { + CfiFunctionLinkage Linkage; + MDNode *FuncMD; // {name, linkage, type[, type...]} + }; + DenseMap<StringRef, ExportedFunctionInfo> ExportedFunctions; + if (ExportSummary) { + NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions"); + if (CfiFunctionsMD) { + for (auto FuncMD : CfiFunctionsMD->operands()) { + assert(FuncMD->getNumOperands() >= 2); + StringRef FunctionName = + cast<MDString>(FuncMD->getOperand(0))->getString(); + if (!ExportSummary->isGUIDLive(GlobalValue::getGUID( + GlobalValue::dropLLVMManglingEscape(FunctionName)))) + continue; + CfiFunctionLinkage Linkage = static_cast<CfiFunctionLinkage>( + cast<ConstantAsMetadata>(FuncMD->getOperand(1)) + ->getValue() + ->getUniqueInteger() + .getZExtValue()); + auto P = ExportedFunctions.insert({FunctionName, {Linkage, FuncMD}}); + if (!P.second && P.first->second.Linkage != CFL_Definition) + P.first->second = {Linkage, FuncMD}; + } + + for (const auto &P : ExportedFunctions) { + StringRef FunctionName = P.first; + CfiFunctionLinkage Linkage = P.second.Linkage; + MDNode *FuncMD = P.second.FuncMD; + Function *F = M.getFunction(FunctionName); + if (!F) + F = Function::Create( + FunctionType::get(Type::getVoidTy(M.getContext()), false), + GlobalVariable::ExternalLinkage, FunctionName, &M); + + if (Linkage == CFL_Definition) + F->eraseMetadata(LLVMContext::MD_type); + + if (F->isDeclaration()) { + if (Linkage == CFL_WeakDeclaration) + F->setLinkage(GlobalValue::ExternalWeakLinkage); + + SmallVector<MDNode *, 2> Types; + for (unsigned I = 2; I < FuncMD->getNumOperands(); ++I) + F->addMetadata(LLVMContext::MD_type, + *cast<MDNode>(FuncMD->getOperand(I).get())); + } + } + } + } + for (GlobalObject &GO : M.global_objects()) { + if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker()) + continue; + Types.clear(); GO.getMetadata(LLVMContext::MD_type, Types); if (Types.empty()) continue; - auto *GTM = GlobalTypeMember::create(Alloc, &GO, Types); + bool IsDefinition = !GO.isDeclarationForLinker(); + bool IsExported = false; + if (isa<Function>(GO) && ExportedFunctions.count(GO.getName())) { + IsDefinition |= ExportedFunctions[GO.getName()].Linkage == CFL_Definition; + IsExported = true; + } + + auto *GTM = + GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types); for (MDNode *Type : Types) { verifyTypeMDNode(&GO, Type); auto &Info = TypeIdInfo[cast<MDNode>(Type)->getOperand(1)]; @@ -1262,33 +1554,56 @@ bool LowerTypeTestsModule::lower() { } } - for (const Use &U : TypeTestFunc->uses()) { - auto CI = cast<CallInst>(U.getUser()); + auto AddTypeIdUse = [&](Metadata *TypeId) -> TypeIdUserInfo & { + // Add the call site to the list of call sites for this type identifier. We + // also use TypeIdUsers to keep track of whether we have seen this type + // identifier before. If we have, we don't need to re-add the referenced + // globals to the equivalence class. + auto Ins = TypeIdUsers.insert({TypeId, {}}); + if (Ins.second) { + // Add the type identifier to the equivalence class. + GlobalClassesTy::iterator GCI = GlobalClasses.insert(TypeId); + GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI); + + // Add the referenced globals to the type identifier's equivalence class. 
+ for (GlobalTypeMember *GTM : TypeIdInfo[TypeId].RefGlobals) + CurSet = GlobalClasses.unionSets( + CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM))); + } + + return Ins.first->second; + }; - auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); - if (!BitSetMDVal) - report_fatal_error("Second argument of llvm.type.test must be metadata"); - auto BitSet = BitSetMDVal->getMetadata(); + if (TypeTestFunc) { + for (const Use &U : TypeTestFunc->uses()) { + auto CI = cast<CallInst>(U.getUser()); - // Add the call site to the list of call sites for this type identifier. We - // also use TypeTestCallSites to keep track of whether we have seen this - // type identifier before. If we have, we don't need to re-add the - // referenced globals to the equivalence class. - std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool> - Ins = TypeTestCallSites.insert( - std::make_pair(BitSet, std::vector<CallInst *>())); - Ins.first->second.push_back(CI); - if (!Ins.second) - continue; + auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); + if (!TypeIdMDVal) + report_fatal_error("Second argument of llvm.type.test must be metadata"); + auto TypeId = TypeIdMDVal->getMetadata(); + AddTypeIdUse(TypeId).CallSites.push_back(CI); + } + } - // Add the type identifier to the equivalence class. - GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet); - GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI); + if (ExportSummary) { + DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; + for (auto &P : TypeIdInfo) { + if (auto *TypeId = dyn_cast<MDString>(P.first)) + MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( + TypeId); + } - // Add the referenced globals to the type identifier's equivalence class. - for (GlobalTypeMember *GTM : TypeIdInfo[BitSet].RefGlobals) - CurSet = GlobalClasses.unionSets( - CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM))); + for (auto &P : *ExportSummary) { + for (auto &S : P.second.SummaryList) { + auto *FS = dyn_cast<FunctionSummary>(S.get()); + if (!FS || !ExportSummary->isGlobalValueLive(FS)) + continue; + for (GlobalValue::GUID G : FS->type_tests()) + for (Metadata *MD : MetadataByGUID[G]) + AddTypeIdUse(MD).IsExported = true; + } + } } if (GlobalClasses.empty()) @@ -1349,8 +1664,9 @@ bool LowerTypeTestsModule::lower() { PreservedAnalyses LowerTypeTestsPass::run(Module &M, ModuleAnalysisManager &AM) { - bool Changed = - LowerTypeTestsModule(M, SummaryAction::None, /*Summary=*/nullptr).lower(); + bool Changed = LowerTypeTestsModule(M, /*ExportSummary=*/nullptr, + /*ImportSummary=*/nullptr) + .lower(); if (!Changed) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp index e0bb0eb..0e478ba 100644 --- a/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -96,8 +96,10 @@ #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/ValueHandle.h" @@ -127,6 +129,26 @@ static cl::opt<unsigned> NumFunctionsForSanityCheck( "'0' disables this check. 
Works only with '-debug' key."), cl::init(0), cl::Hidden); +// Under option -mergefunc-preserve-debug-info we: +// - Do not create a new function for a thunk. +// - Retain the debug info for a thunk's parameters (and associated +// instructions for the debug info) from the entry block. +// Note: -debug will display the algorithm at work. +// - Create debug-info for the call (to the shared implementation) made by +// a thunk and its return value. +// - Erase the rest of the function, retaining the (minimally sized) entry +// block to create a thunk. +// - Preserve a thunk's call site to point to the thunk even when both occur +// within the same translation unit, to aid debuggability. Note that this +// behaviour differs from the underlying -mergefunc implementation which +// modifies the thunk's call site to point to the shared implementation +// when both occur within the same translation unit. +static cl::opt<bool> + MergeFunctionsPDI("mergefunc-preserve-debug-info", cl::Hidden, + cl::init(false), + cl::desc("Preserve debug info in thunk when mergefunc " + "transformations are made.")); + namespace { class FunctionNode { @@ -185,11 +207,13 @@ private: /// A work queue of functions that may have been modified and should be /// analyzed again. - std::vector<WeakVH> Deferred; + std::vector<WeakTrackingVH> Deferred; /// Checks the rules of order relation introduced among functions set. /// Returns true, if sanity check has been passed, and false if failed. - bool doSanityCheck(std::vector<WeakVH> &Worklist); +#ifndef NDEBUG + bool doSanityCheck(std::vector<WeakTrackingVH> &Worklist); +#endif /// Insert a ComparableFunction into the FnTree, or merge it away if it's /// equal to one that's already present. @@ -215,8 +239,21 @@ private: /// Replace G with a thunk or an alias to F. Deletes G. void writeThunkOrAlias(Function *F, Function *G); - /// Replace G with a simple tail call to bitcast(F). Also replace direct uses - /// of G with bitcast(F). Deletes G. + /// Fill PDIUnrelatedWL with instructions from the entry block that are + /// unrelated to parameter related debug info. + void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock, + std::vector<Instruction *> &PDIUnrelatedWL); + + /// Erase the rest of the CFG (i.e. barring the entry block). + void eraseTail(Function *G); + + /// Erase the instructions in PDIUnrelatedWL as they are unrelated to the + /// parameter debug info, from the entry block. + void eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL); + + /// Replace G with a simple tail call to bitcast(F). Also (unless + /// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F), + /// and delete G. void writeThunk(Function *F, Function *G); /// Replace G with an alias to F. Deletes G.
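In source-level terms, writeThunk turns the second of two identical functions into a forwarder. A plain C++ analogue (hypothetical functions; the IR version additionally bitcasts when prototypes differ):

// F is the implementation that survives the merge.
static int F(int X) { return X * X + 1; }

// G had a bit-identical body; after merging it is just a forwarding tail
// call. With -mergefunc-preserve-debug-info this shell is carved out of
// the original G (entry block kept, tail erased) rather than being
// created as a brand-new function, so G's parameter debug info survives.
int G(int X) { return F(X); }

int main() { return G(2) == 5 ? 0 : 1; }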
@@ -248,7 +285,8 @@ ModulePass *llvm::createMergeFunctionsPass() { return new MergeFunctions(); } -bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { +#ifndef NDEBUG +bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) { if (const unsigned Max = NumFunctionsForSanityCheck) { unsigned TripleNumber = 0; bool Valid = true; @@ -256,10 +294,12 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { dbgs() << "MERGEFUNC-SANITY: Started for first " << Max << " functions.\n"; unsigned i = 0; - for (std::vector<WeakVH>::iterator I = Worklist.begin(), E = Worklist.end(); + for (std::vector<WeakTrackingVH>::iterator I = Worklist.begin(), + E = Worklist.end(); I != E && i < Max; ++I, ++i) { unsigned j = i; - for (std::vector<WeakVH>::iterator J = I; J != E && j < Max; ++J, ++j) { + for (std::vector<WeakTrackingVH>::iterator J = I; J != E && j < Max; + ++J, ++j) { Function *F1 = cast<Function>(*I); Function *F2 = cast<Function>(*J); int Res1 = FunctionComparator(F1, F2, &GlobalNumbers).compare(); @@ -269,8 +309,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { if (Res1 != -Res2) { dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber << "\n"; - F1->dump(); - F2->dump(); + dbgs() << *F1 << '\n' << *F2 << '\n'; Valid = false; } @@ -278,7 +317,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { continue; unsigned k = j; - for (std::vector<WeakVH>::iterator K = J; K != E && k < Max; + for (std::vector<WeakTrackingVH>::iterator K = J; K != E && k < Max; ++k, ++K, ++TripleNumber) { if (K == J) continue; @@ -305,9 +344,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { << TripleNumber << "\n"; dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", " << Res4 << "\n"; - F1->dump(); - F2->dump(); - F3->dump(); + dbgs() << *F1 << '\n' << *F2 << '\n' << *F3 << '\n'; Valid = false; } } @@ -319,6 +356,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { } return true; } +#endif bool MergeFunctions::runOnModule(Module &M) { if (skipModule(M)) @@ -349,12 +387,12 @@ bool MergeFunctions::runOnModule(Module &M) { // consider merging it. Otherwise it is dropped and never considered again. if ((I != S && std::prev(I)->first == I->first) || (std::next(I) != IE && std::next(I)->first == I->first) ) { - Deferred.push_back(WeakVH(I->second)); + Deferred.push_back(WeakTrackingVH(I->second)); } } do { - std::vector<WeakVH> Worklist; + std::vector<WeakTrackingVH> Worklist; Deferred.swap(Worklist); DEBUG(doSanityCheck(Worklist)); @@ -363,7 +401,7 @@ bool MergeFunctions::runOnModule(Module &M) { DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n'); // Insert functions and merge them. - for (WeakVH &I : Worklist) { + for (WeakTrackingVH &I : Worklist) { if (!I) continue; Function *F = cast<Function>(I); @@ -400,19 +438,15 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { // Transferring other attributes may help other optimizations, but that // should be done uniformly and not in this ad-hoc way. 
auto &Context = New->getContext(); - auto NewFuncAttrs = New->getAttributes(); - auto CallSiteAttrs = CS.getAttributes(); - - CallSiteAttrs = CallSiteAttrs.addAttributes( - Context, AttributeSet::ReturnIndex, NewFuncAttrs.getRetAttributes()); - - for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++) { - AttributeSet Attrs = NewFuncAttrs.getParamAttributes(argIdx); - if (Attrs.getNumSlots()) - CallSiteAttrs = CallSiteAttrs.addAttributes(Context, argIdx, Attrs); - } - - CS.setAttributes(CallSiteAttrs); + auto NewPAL = New->getAttributes(); + SmallVector<AttributeSet, 4> NewArgAttrs; + for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++) + NewArgAttrs.push_back(NewPAL.getParamAttributes(argIdx)); + // Don't transfer attributes from the function to the callee. Function + // attributes typically aren't relevant to the calling convention or ABI. + CS.setAttributes(AttributeList::get(Context, /*FnAttrs=*/AttributeSet(), + NewPAL.getRetAttributes(), + NewArgAttrs)); remove(CS.getInstruction()->getParent()->getParent()); U->set(BitcastNew); @@ -461,51 +495,242 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { return Builder.CreateBitCast(V, DestTy); } -// Replace G with a simple tail call to bitcast(F). Also replace direct uses -// of G with bitcast(F). Deletes G. +// Erase the instructions in PDIUnrelatedWL as they are unrelated to the +// parameter debug info, from the entry block. +void MergeFunctions::eraseInstsUnrelatedToPDI( + std::vector<Instruction *> &PDIUnrelatedWL) { + + DEBUG(dbgs() << " Erasing instructions (in reverse order of appearance in " + "entry block) unrelated to parameter debug info from entry " + "block: {\n"); + while (!PDIUnrelatedWL.empty()) { + Instruction *I = PDIUnrelatedWL.back(); + DEBUG(dbgs() << " Deleting Instruction: "); + DEBUG(I->print(dbgs())); + DEBUG(dbgs() << "\n"); + I->eraseFromParent(); + PDIUnrelatedWL.pop_back(); + } + DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter " + "debug info from entry block. \n"); +} + +// Reduce G to its entry block. +void MergeFunctions::eraseTail(Function *G) { + + std::vector<BasicBlock *> WorklistBB; + for (Function::iterator BBI = std::next(G->begin()), BBE = G->end(); + BBI != BBE; ++BBI) { + BBI->dropAllReferences(); + WorklistBB.push_back(&*BBI); + } + while (!WorklistBB.empty()) { + BasicBlock *BB = WorklistBB.back(); + BB->eraseFromParent(); + WorklistBB.pop_back(); + } +} + +// We are interested in the following instructions from the entry block as being +// related to parameter debug info: +// - @llvm.dbg.declare +// - stores from the incoming parameters to locations on the stack-frame +// - allocas that create these locations on the stack-frame +// - @llvm.dbg.value +// - the entry block's terminator +// The rest are unrelated to debug info for the parameters; fill up +// PDIUnrelatedWL with such instructions. 
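Condensed, the classification the list above describes looks roughly like this (assuming the usual LLVM headers; simplified in that the real code, shown next, also retains the allocas and argument stores that feed a parameter dbg.declare):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IntrinsicInst.h"
using namespace llvm;

// Keep parameter dbg.value/dbg.declare intrinsics and the terminator; all
// other entry-block instructions become candidates for the deletion
// worklist (PDIUnrelatedWL).
static bool keepForPDI(const Instruction &I, const BasicBlock &Entry) {
  if (const auto *DVI = dyn_cast<DbgValueInst>(&I))
    return DVI->getVariable()->isParameter();
  if (const auto *DDI = dyn_cast<DbgDeclareInst>(&I))
    return DDI->getVariable()->isParameter();
  return &I == Entry.getTerminator();
}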
+void MergeFunctions::filterInstsUnrelatedToPDI( + BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) { + + std::set<Instruction *> PDIRelated; + for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end(); + BI != BIE; ++BI) { + if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) { + DEBUG(dbgs() << " Deciding: "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + DILocalVariable *DILocVar = DVI->getVariable(); + if (DILocVar->isParameter()) { + DEBUG(dbgs() << " Include (parameter): "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + PDIRelated.insert(&*BI); + } else { + DEBUG(dbgs() << " Delete (!parameter): "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + } + } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) { + DEBUG(dbgs() << " Deciding: "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + DILocalVariable *DILocVar = DDI->getVariable(); + if (DILocVar->isParameter()) { + DEBUG(dbgs() << " Parameter: "); + DEBUG(DILocVar->print(dbgs())); + AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress()); + if (AI) { + DEBUG(dbgs() << " Processing alloca users: "); + DEBUG(dbgs() << "\n"); + for (User *U : AI->users()) { + if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + if (Value *Arg = SI->getValueOperand()) { + if (dyn_cast<Argument>(Arg)) { + DEBUG(dbgs() << " Include: "); + DEBUG(AI->print(dbgs())); + DEBUG(dbgs() << "\n"); + PDIRelated.insert(AI); + DEBUG(dbgs() << " Include (parameter): "); + DEBUG(SI->print(dbgs())); + DEBUG(dbgs() << "\n"); + PDIRelated.insert(SI); + DEBUG(dbgs() << " Include: "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + PDIRelated.insert(&*BI); + } else { + DEBUG(dbgs() << " Delete (!parameter): "); + DEBUG(SI->print(dbgs())); + DEBUG(dbgs() << "\n"); + } + } + } else { + DEBUG(dbgs() << " Defer: "); + DEBUG(U->print(dbgs())); + DEBUG(dbgs() << "\n"); + } + } + } else { + DEBUG(dbgs() << " Delete (alloca NULL): "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + } + } else { + DEBUG(dbgs() << " Delete (!parameter): "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + } + } else if (dyn_cast<TerminatorInst>(BI) == GEntryBlock->getTerminator()) { + DEBUG(dbgs() << " Will Include Terminator: "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + PDIRelated.insert(&*BI); + } else { + DEBUG(dbgs() << " Defer: "); + DEBUG(BI->print(dbgs())); + DEBUG(dbgs() << "\n"); + } + } + DEBUG(dbgs() + << " Report parameter debug info related/unrelated instructions: {\n"); + for (BasicBlock::iterator BI = GEntryBlock->begin(), BE = GEntryBlock->end(); + BI != BE; ++BI) { + + Instruction *I = &*BI; + if (PDIRelated.find(I) == PDIRelated.end()) { + DEBUG(dbgs() << " !PDIRelated: "); + DEBUG(I->print(dbgs())); + DEBUG(dbgs() << "\n"); + PDIUnrelatedWL.push_back(I); + } else { + DEBUG(dbgs() << " PDIRelated: "); + DEBUG(I->print(dbgs())); + DEBUG(dbgs() << "\n"); + } + } + DEBUG(dbgs() << " }\n"); +} + +// Replace G with a simple tail call to bitcast(F). Also (unless +// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F), +// and delete G. Under MergeFunctionsPDI, we use G itself for creating +// the thunk as we preserve the debug info (and associated instructions) +// from G's entry block pertaining to G's incoming arguments which are +// passed on as corresponding arguments in the call that G makes to F. +// For better debuggability, under MergeFunctionsPDI, we do not modify G's +// call sites to point to F even when within the same translation unit.
void MergeFunctions::writeThunk(Function *F, Function *G) { - if (!G->isInterposable()) { - // Redirect direct callers of G to F. + if (!G->isInterposable() && !MergeFunctionsPDI) { + // Redirect direct callers of G to F. (See note on MergeFunctionsPDI + // above). replaceDirectCallers(G, F); } // If G was internal then we may have replaced all uses of G with F. If so, - // stop here and delete G. There's no need for a thunk. - if (G->hasLocalLinkage() && G->use_empty()) { + // stop here and delete G. There's no need for a thunk. (See note on + // MergeFunctionsPDI above). + if (G->hasLocalLinkage() && G->use_empty() && !MergeFunctionsPDI) { G->eraseFromParent(); return; } - Function *NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "", - G->getParent()); - BasicBlock *BB = BasicBlock::Create(F->getContext(), "", NewG); - IRBuilder<> Builder(BB); + BasicBlock *GEntryBlock = nullptr; + std::vector<Instruction *> PDIUnrelatedWL; + BasicBlock *BB = nullptr; + Function *NewG = nullptr; + if (MergeFunctionsPDI) { + DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new " + "function as thunk; retain original: " + << G->getName() << "()\n"); + GEntryBlock = &G->getEntryBlock(); + DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related " + "debug info for " + << G->getName() << "() {\n"); + filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL); + GEntryBlock->getTerminator()->eraseFromParent(); + BB = GEntryBlock; + } else { + NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "", + G->getParent()); + BB = BasicBlock::Create(F->getContext(), "", NewG); + } + IRBuilder<> Builder(BB); + Function *H = MergeFunctionsPDI ? G : NewG; SmallVector<Value *, 16> Args; unsigned i = 0; FunctionType *FFTy = F->getFunctionType(); - for (Argument & AI : NewG->args()) { + for (Argument & AI : H->args()) { Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i))); ++i; } CallInst *CI = Builder.CreateCall(F, Args); + ReturnInst *RI = nullptr; CI->setTailCall(); CI->setCallingConv(F->getCallingConv()); CI->setAttributes(F->getAttributes()); - if (NewG->getReturnType()->isVoidTy()) { - Builder.CreateRetVoid(); + if (H->getReturnType()->isVoidTy()) { + RI = Builder.CreateRetVoid(); } else { - Builder.CreateRet(createCast(Builder, CI, NewG->getReturnType())); + RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType())); } - NewG->copyAttributesFrom(G); - NewG->takeName(G); - removeUsers(G); - G->replaceAllUsesWith(NewG); - G->eraseFromParent(); + if (MergeFunctionsPDI) { + DISubprogram *DIS = G->getSubprogram(); + if (DIS) { + DebugLoc CIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS); + DebugLoc RIDbgLoc = DebugLoc::get(DIS->getScopeLine(), 0, DIS); + CI->setDebugLoc(CIDbgLoc); + RI->setDebugLoc(RIDbgLoc); + } else { + DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for " + << G->getName() << "()\n"); + } + eraseTail(G); + eraseInstsUnrelatedToPDI(PDIUnrelatedWL); + DEBUG(dbgs() << "} // End of parameter related debug info filtering for: " + << G->getName() << "()\n"); + } else { + NewG->copyAttributesFrom(G); + NewG->takeName(G); + removeUsers(G); + G->replaceAllUsesWith(NewG); + G->eraseFromParent(); + } - DEBUG(dbgs() << "writeThunk: " << NewG->getName() << '\n'); + DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n'); ++NumThunksWritten; } diff --git a/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp b/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp index 7ef3fc1..8840435 100644 --- 
a/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -16,8 +16,15 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -27,19 +34,177 @@ #include "llvm/Transforms/Utils/CodeExtractor.h" using namespace llvm; -#define DEBUG_TYPE "partialinlining" - -STATISTIC(NumPartialInlined, "Number of functions partially inlined"); +#define DEBUG_TYPE "partial-inlining" + +STATISTIC(NumPartialInlined, + "Number of callsites functions partially inlined into."); + +// Command line option to disable partial-inlining. The default is false: +static cl::opt<bool> + DisablePartialInlining("disable-partial-inlining", cl::init(false), + cl::Hidden, cl::desc("Disable partial inlining")); +// This is an option used by testing: +static cl::opt<bool> SkipCostAnalysis("skip-partial-inlining-cost-analysis", + cl::init(false), cl::ZeroOrMore, + cl::ReallyHidden, + cl::desc("Skip Cost Analysis")); + +static cl::opt<unsigned> MaxNumInlineBlocks( + "max-num-inline-blocks", cl::init(5), cl::Hidden, + cl::desc("Max Number of Blocks To be Partially Inlined")); + +// Command line option to set the maximum number of partial inlining allowed +// for the module. The default value of -1 means no limit. +static cl::opt<int> MaxNumPartialInlining( + "max-partial-inlining", cl::init(-1), cl::Hidden, cl::ZeroOrMore, + cl::desc("Max number of partial inlining. The default is unlimited")); + +// Used only when PGO or user annotated branch data is absent. It is +// the least value that is used to weigh the outline region. If BFI +// produces larger value, the BFI value will be used. +static cl::opt<int> + OutlineRegionFreqPercent("outline-region-freq-percent", cl::init(75), + cl::Hidden, cl::ZeroOrMore, + cl::desc("Relative frequency of outline region to " + "the entry block")); + +static cl::opt<unsigned> ExtraOutliningPenalty( + "partial-inlining-extra-penalty", cl::init(0), cl::Hidden, + cl::desc("A debug option to add additional penalty to the computed one.")); namespace { + +struct FunctionOutliningInfo { + FunctionOutliningInfo() + : Entries(), ReturnBlock(nullptr), NonReturnBlock(nullptr), + ReturnBlockPreds() {} + // Returns the number of blocks to be inlined including all blocks + // in Entries and one return block. + unsigned GetNumInlinedBlocks() const { return Entries.size() + 1; } + + // A set of blocks including the function entry that guard + // the region to be outlined. + SmallVector<BasicBlock *, 4> Entries; + // The return block that is not included in the outlined region. + BasicBlock *ReturnBlock; + // The dominating block of the region to be outlined.
+ BasicBlock *NonReturnBlock; + // The set of blocks in Entries that are predecessors to ReturnBlock + SmallVector<BasicBlock *, 4> ReturnBlockPreds; +}; + struct PartialInlinerImpl { - PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {} + PartialInlinerImpl( + std::function<AssumptionCache &(Function &)> *GetAC, + std::function<TargetTransformInfo &(Function &)> *GTTI, + Optional<function_ref<BlockFrequencyInfo &(Function &)>> GBFI, + ProfileSummaryInfo *ProfSI) + : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} bool run(Module &M); Function *unswitchFunction(Function *F); + // This class speculatively clones the function to be partially inlined. + // At the end of partial inlining, the remaining callsites to the cloned + // function that are not partially inlined will be fixed up to reference + // the original function, and the cloned function will be erased. + struct FunctionCloner { + FunctionCloner(Function *F, FunctionOutliningInfo *OI); + ~FunctionCloner(); + + // Prepare for function outlining: making sure there is only + // one incoming edge from the extracted/outlined region to + // the return block. + void NormalizeReturnBlock(); + + // Do function outlining: + Function *doFunctionOutlining(); + + Function *OrigFunc = nullptr; + Function *ClonedFunc = nullptr; + Function *OutlinedFunc = nullptr; + BasicBlock *OutliningCallBB = nullptr; + // ClonedFunc is inlined in one of its callers after function + // outlining. + bool IsFunctionInlined = false; + // The cost of the region to be outlined. + int OutlinedRegionCost = 0; + std::unique_ptr<FunctionOutliningInfo> ClonedOI = nullptr; + std::unique_ptr<BlockFrequencyInfo> ClonedFuncBFI = nullptr; + }; + private: - InlineFunctionInfo IFI; + int NumPartialInlining = 0; + std::function<AssumptionCache &(Function &)> *GetAssumptionCache; + std::function<TargetTransformInfo &(Function &)> *GetTTI; + Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI; + ProfileSummaryInfo *PSI; + + // Return the frequency of the OutliningBB relative to F's entry point. + // The result is no larger than 1 and is represented using BP. + // (Note that the outlined region's 'head' block can only have incoming + // edges from the guarding entry blocks). + BranchProbability getOutliningCallBBRelativeFreq(FunctionCloner &Cloner); + + // Return true if the callee of CS can be profitably partially inlined. + bool shouldPartialInline(CallSite CS, FunctionCloner &Cloner, + BlockFrequency WeightedOutliningRcost, + OptimizationRemarkEmitter &ORE); + + // Try to inline DuplicateFunction (cloned from F, with a call to + // the OutlinedFunction) into its callers. Return true + // if any inlining succeeded. + bool tryPartialInline(FunctionCloner &Cloner); + + // Compute the mapping from each use site of DuplicateFunction to the + // enclosing BB's profile count.
+ void computeCallsiteToProfCountMap(Function *DuplicateFunction, + DenseMap<User *, uint64_t> &SiteCountMap); + + bool IsLimitReached() { + return (MaxNumPartialInlining != -1 && + NumPartialInlining >= MaxNumPartialInlining); + } + + static CallSite getCallSite(User *U) { + CallSite CS; + if (CallInst *CI = dyn_cast<CallInst>(U)) + CS = CallSite(CI); + else if (InvokeInst *II = dyn_cast<InvokeInst>(U)) + CS = CallSite(II); + else + llvm_unreachable("All uses must be calls"); + return CS; + } + + static CallSite getOneCallSiteTo(Function *F) { + User *User = *F->user_begin(); + return getCallSite(User); + } + + std::tuple<DebugLoc, BasicBlock *> getOneDebugLoc(Function *F) { + CallSite CS = getOneCallSiteTo(F); + DebugLoc DLoc = CS.getInstruction()->getDebugLoc(); + BasicBlock *Block = CS.getParent(); + return std::make_tuple(DLoc, Block); + } + + // Returns the costs associated with function outlining: + // - The first value is the estimated size of the new call sequence in + // basic block Cloner.OutliningCallBB; + // - The second value is the non-weighted runtime cost for making the call + // to the outlined function, including the additional setup cost in the + // outlined function itself; + std::tuple<int, int> computeOutliningCosts(FunctionCloner &Cloner); + // Compute the 'InlineCost' of block BB. InlineCost is a proxy used to + // approximate both the size and runtime cost (note that in the current + // inline cost analysis, there is no clear distinction there either). + static int computeBBInlineCost(BasicBlock *BB); + + std::unique_ptr<FunctionOutliningInfo> computeOutliningInfo(Function *F); + }; + struct PartialInlinerLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid PartialInlinerLegacyPass() : ModulePass(ID) { @@ -48,124 +213,713 @@ struct PartialInlinerLegacyPass : public ModulePass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); } bool runOnModule(Module &M) override { if (skipModule(M)) return false; AssumptionCacheTracker *ACT = &getAnalysis<AssumptionCacheTracker>(); + TargetTransformInfoWrapperPass *TTIWP = + &getAnalysis<TargetTransformInfoWrapperPass>(); + ProfileSummaryInfo *PSI = + getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&ACT](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; - InlineFunctionInfo IFI(nullptr, &GetAssumptionCache); - return PartialInlinerImpl(IFI).run(M); + + std::function<TargetTransformInfo &(Function &)> GetTTI = + [&TTIWP](Function &F) -> TargetTransformInfo & { + return TTIWP->getTTI(F); + }; + + return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, None, PSI).run(M); } }; } -Function *PartialInlinerImpl::unswitchFunction(Function *F) { - // First, verify that this function is an unswitching candidate...
+std::unique_ptr<FunctionOutliningInfo> +PartialInlinerImpl::computeOutliningInfo(Function *F) { BasicBlock *EntryBlock = &F->front(); BranchInst *BR = dyn_cast<BranchInst>(EntryBlock->getTerminator()); if (!BR || BR->isUnconditional()) - return nullptr; + return std::unique_ptr<FunctionOutliningInfo>(); + + // Returns true if Succ is BB's successor + auto IsSuccessor = [](BasicBlock *Succ, BasicBlock *BB) { + return is_contained(successors(BB), Succ); + }; + + auto SuccSize = [](BasicBlock *BB) { + return std::distance(succ_begin(BB), succ_end(BB)); + }; + + auto IsReturnBlock = [](BasicBlock *BB) { + TerminatorInst *TI = BB->getTerminator(); + return isa<ReturnInst>(TI); + }; + + auto GetReturnBlock = [&](BasicBlock *Succ1, BasicBlock *Succ2) { + if (IsReturnBlock(Succ1)) + return std::make_tuple(Succ1, Succ2); + if (IsReturnBlock(Succ2)) + return std::make_tuple(Succ2, Succ1); + + return std::make_tuple<BasicBlock *, BasicBlock *>(nullptr, nullptr); + }; + + // Detect a triangular shape: + auto GetCommonSucc = [&](BasicBlock *Succ1, BasicBlock *Succ2) { + if (IsSuccessor(Succ1, Succ2)) + return std::make_tuple(Succ1, Succ2); + if (IsSuccessor(Succ2, Succ1)) + return std::make_tuple(Succ2, Succ1); + + return std::make_tuple<BasicBlock *, BasicBlock *>(nullptr, nullptr); + }; + + std::unique_ptr<FunctionOutliningInfo> OutliningInfo = + llvm::make_unique<FunctionOutliningInfo>(); + + BasicBlock *CurrEntry = EntryBlock; + bool CandidateFound = false; + do { + // The number of blocks to be inlined has already reached + // the limit. When MaxNumInlineBlocks is set to 0 or 1, this + // disables partial inlining for the function. + if (OutliningInfo->GetNumInlinedBlocks() >= MaxNumInlineBlocks) + break; + + if (SuccSize(CurrEntry) != 2) + break; + + BasicBlock *Succ1 = *succ_begin(CurrEntry); + BasicBlock *Succ2 = *(succ_begin(CurrEntry) + 1); + + BasicBlock *ReturnBlock, *NonReturnBlock; + std::tie(ReturnBlock, NonReturnBlock) = GetReturnBlock(Succ1, Succ2); + + if (ReturnBlock) { + OutliningInfo->Entries.push_back(CurrEntry); + OutliningInfo->ReturnBlock = ReturnBlock; + OutliningInfo->NonReturnBlock = NonReturnBlock; + CandidateFound = true; + break; + } + + BasicBlock *CommSucc; + BasicBlock *OtherSucc; + std::tie(CommSucc, OtherSucc) = GetCommonSucc(Succ1, Succ2); + + if (!CommSucc) + break; - BasicBlock *ReturnBlock = nullptr; - BasicBlock *NonReturnBlock = nullptr; - unsigned ReturnCount = 0; - for (BasicBlock *BB : successors(EntryBlock)) { - if (isa<ReturnInst>(BB->getTerminator())) { - ReturnBlock = BB; - ReturnCount++; - } else - NonReturnBlock = BB; + OutliningInfo->Entries.push_back(CurrEntry); + CurrEntry = OtherSucc; + + } while (true); + + if (!CandidateFound) + return std::unique_ptr<FunctionOutliningInfo>(); + + // Do a sanity check of the entries: there should not + // be any successors (not in the entry set) other than + // {ReturnBlock, NonReturnBlock} + assert(OutliningInfo->Entries[0] == &F->front() && + "Function Entry must be the first in Entries vector"); + DenseSet<BasicBlock *> Entries; + for (BasicBlock *E : OutliningInfo->Entries) + Entries.insert(E); + + // Returns true if BB has a predecessor that is not + // in the Entries set.
+ auto HasNonEntryPred = [Entries](BasicBlock *BB) { + for (auto Pred : predecessors(BB)) { + if (!Entries.count(Pred)) + return true; + } + return false; + }; + auto CheckAndNormalizeCandidate = + [Entries, HasNonEntryPred](FunctionOutliningInfo *OutliningInfo) { + for (BasicBlock *E : OutliningInfo->Entries) { + for (auto Succ : successors(E)) { + if (Entries.count(Succ)) + continue; + if (Succ == OutliningInfo->ReturnBlock) + OutliningInfo->ReturnBlockPreds.push_back(E); + else if (Succ != OutliningInfo->NonReturnBlock) + return false; + } + // There should not be any outside incoming edges either: + if (HasNonEntryPred(E)) + return false; + } + return true; + }; + + if (!CheckAndNormalizeCandidate(OutliningInfo.get())) + return std::unique_ptr<FunctionOutliningInfo>(); + + // Now grow the candidate's inlining region further by + // peeling off dominating blocks from the outlining region: + while (OutliningInfo->GetNumInlinedBlocks() < MaxNumInlineBlocks) { + BasicBlock *Cand = OutliningInfo->NonReturnBlock; + if (SuccSize(Cand) != 2) + break; + + if (HasNonEntryPred(Cand)) + break; + + BasicBlock *Succ1 = *succ_begin(Cand); + BasicBlock *Succ2 = *(succ_begin(Cand) + 1); + + BasicBlock *ReturnBlock, *NonReturnBlock; + std::tie(ReturnBlock, NonReturnBlock) = GetReturnBlock(Succ1, Succ2); + if (!ReturnBlock || ReturnBlock != OutliningInfo->ReturnBlock) + break; + + if (NonReturnBlock->getSinglePredecessor() != Cand) + break; + + // Now grow and update OutliningInfo: + OutliningInfo->Entries.push_back(Cand); + OutliningInfo->NonReturnBlock = NonReturnBlock; + OutliningInfo->ReturnBlockPreds.push_back(Cand); + Entries.insert(Cand); } - if (ReturnCount != 1) - return nullptr; + return OutliningInfo; +} + +// Check if there is PGO data or user annotated branch data: +static bool hasProfileData(Function *F, FunctionOutliningInfo *OI) { + if (F->getEntryCount()) + return true; + // Now check if any of the entry blocks has MD_prof data: + for (auto *E : OI->Entries) { + BranchInst *BR = dyn_cast<BranchInst>(E->getTerminator()); + if (!BR || BR->isUnconditional()) + continue; + uint64_t T, F; + if (BR->extractProfMetadata(T, F)) + return true; + } + return false; +} + +BranchProbability +PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) { + + auto EntryFreq = + Cloner.ClonedFuncBFI->getBlockFreq(&Cloner.ClonedFunc->getEntryBlock()); + auto OutliningCallFreq = + Cloner.ClonedFuncBFI->getBlockFreq(Cloner.OutliningCallBB); + + auto OutlineRegionRelFreq = + BranchProbability::getBranchProbability(OutliningCallFreq.getFrequency(), + EntryFreq.getFrequency()); + + if (hasProfileData(Cloner.OrigFunc, Cloner.ClonedOI.get())) + return OutlineRegionRelFreq; + + // When profile data is not available, we need to be conservative in + // estimating the overall savings. Static branch prediction can usually + // guess the branch direction right (taken/non-taken), but the guessed + // branch probability is usually not biased enough. When the + // outlined region is predicted to be likely, its probability needs + // to be made higher (more biased) to not under-estimate the cost of + // function outlining. On the other hand, if the outlined region + // is predicted to be less likely, the predicted probability is usually + // higher than the actual one. For instance, the actual probability of the + // less likely target is only 5%, but the guessed probability can be + // 40%. In the latter case, there is no need for further adjustment. + // FIXME: add an option for this.
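Concretely, the conservative adjustment described above behaves like the following sketch (an editor's illustration using plain percentages in place of llvm::BranchProbability; the constants 45 and 75 mirror the hard-coded threshold and the OutlineRegionFreqPercent default, and the helper name is hypothetical, not part of the patch):

#include <algorithm>

// Trust profiled or already-cold guesses; bias "likely" static guesses
// upward so the outlining cost is not underestimated.
static unsigned adjustedOutlineRegionFreqPercent(unsigned GuessedPercent,
                                                 bool HasProfileData) {
  if (HasProfileData || GuessedPercent < 45)
    return GuessedPercent;              // keep real or sufficiently cold guesses
  return std::max(GuessedPercent, 75u); // raise likely regions to >= 75%
}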
+ if (OutlineRegionRelFreq < BranchProbability(45, 100)) + return OutlineRegionRelFreq; + + OutlineRegionRelFreq = std::max( + OutlineRegionRelFreq, BranchProbability(OutlineRegionFreqPercent, 100)); + + return OutlineRegionRelFreq; +} + +bool PartialInlinerImpl::shouldPartialInline( + CallSite CS, FunctionCloner &Cloner, BlockFrequency WeightedOutliningRcost, + OptimizationRemarkEmitter &ORE) { + + using namespace ore; + if (SkipCostAnalysis) + return true; + + Instruction *Call = CS.getInstruction(); + Function *Callee = CS.getCalledFunction(); + assert(Callee == Cloner.ClonedFunc); + + Function *Caller = CS.getCaller(); + auto &CalleeTTI = (*GetTTI)(*Callee); + InlineCost IC = getInlineCost(CS, getInlineParams(), CalleeTTI, + *GetAssumptionCache, GetBFI, PSI); + + if (IC.isAlways()) { + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "AlwaysInline", Call) + << NV("Callee", Cloner.OrigFunc) + << " should always be fully inlined, not partially"); + return false; + } + + if (IC.isNever()) { + ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) + << NV("Callee", Cloner.OrigFunc) << " not partially inlined into " + << NV("Caller", Caller) + << " because it should never be inlined (cost=never)"); + return false; + } + + if (!IC) { + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call) + << NV("Callee", Cloner.OrigFunc) << " not partially inlined into " + << NV("Caller", Caller) << " because too costly to inline (cost=" + << NV("Cost", IC.getCost()) << ", threshold=" + << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); + return false; + } + const DataLayout &DL = Caller->getParent()->getDataLayout(); + + // The savings of eliminating the call: + int NonWeightedSavings = getCallsiteCost(CS, DL); + BlockFrequency NormWeightedSavings(NonWeightedSavings); + + // If the weighted savings are smaller than the weighted cost, return false. + if (NormWeightedSavings < WeightedOutliningRcost) { + ORE.emit( + OptimizationRemarkAnalysis(DEBUG_TYPE, "OutliningCallcostTooHigh", Call) + << NV("Callee", Cloner.OrigFunc) << " not partially inlined into " + << NV("Caller", Caller) << " because the runtime overhead (overhead=" + << NV("Overhead", (unsigned)WeightedOutliningRcost.getFrequency()) + << ", savings=" + << NV("Savings", (unsigned)NormWeightedSavings.getFrequency()) << ")" + << " of making the outlined call is too high"); + + return false; + } + + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "CanBePartiallyInlined", Call) + << NV("Callee", Cloner.OrigFunc) << " can be partially inlined into " + << NV("Caller", Caller) << " with cost=" << NV("Cost", IC.getCost()) + << " (threshold=" + << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); + return true; +} + +// TODO: Ideally we should share Inliner's InlineCost Analysis code. +// For now use a simplified version. The returned 'InlineCost' will be used +// to estimate the size cost as well as runtime cost of the BB.
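As a worked example of that simplified model, consider the accounting below (an editor's sketch; the constant mirrors llvm::InlineConstants::InstrCost, which is 5 at the time of writing, and the block contents are hypothetical):

// Assumed constant standing in for llvm::InlineConstants::InstrCost.
static const int InstrCost = 5;

// Cost of a hypothetical block holding one bitcast, a switch with three
// cases, and two ordinary instructions, mirroring the rules in
// computeBBInlineCost() below: casts and allocas are free, a switch is
// charged (NumCases + 1) * InstrCost, everything else InstrCost each.
static int exampleBlockCost() {
  int Cost = 0;
  Cost += 0;                   // bitcast: free
  Cost += (3 + 1) * InstrCost; // switch with three cases: 20
  Cost += 2 * InstrCost;       // two plain instructions: 10
  return Cost;                 // total: 30
}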
+int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { + int InlineCost = 0; + const DataLayout &DL = BB->getParent()->getParent()->getDataLayout(); + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + if (isa<DbgInfoIntrinsic>(I)) + continue; + + switch (I->getOpcode()) { + case Instruction::BitCast: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::Alloca: + continue; + case Instruction::GetElementPtr: + if (cast<GetElementPtrInst>(I)->hasAllZeroIndices()) + continue; + default: + break; + } + + IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(I); + if (IntrInst) { + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start || + IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + } + + if (CallInst *CI = dyn_cast<CallInst>(I)) { + InlineCost += getCallsiteCost(CallSite(CI), DL); + continue; + } + + if (InvokeInst *II = dyn_cast<InvokeInst>(I)) { + InlineCost += getCallsiteCost(CallSite(II), DL); + continue; + } + + if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { + InlineCost += (SI->getNumCases() + 1) * InlineConstants::InstrCost; + continue; + } + InlineCost += InlineConstants::InstrCost; + } + return InlineCost; +} + +std::tuple<int, int> +PartialInlinerImpl::computeOutliningCosts(FunctionCloner &Cloner) { + + // Now compute the cost of the call sequence to the outlined function + // 'OutlinedFunction' in BB 'OutliningCallBB': + int OutliningFuncCallCost = computeBBInlineCost(Cloner.OutliningCallBB); + + // Now compute the cost of the extracted/outlined function itself: + int OutlinedFunctionCost = 0; + for (BasicBlock &BB : *Cloner.OutlinedFunc) { + OutlinedFunctionCost += computeBBInlineCost(&BB); + } + + assert(OutlinedFunctionCost >= Cloner.OutlinedRegionCost && + "Outlined function cost should be no less than the outlined region"); + // The code extractor introduces a new root and exit stub blocks with + // additional unconditional branches. Those branches will be eliminated + // later with bb layout. The cost should be adjusted accordingly: + OutlinedFunctionCost -= 2 * InlineConstants::InstrCost; + + int OutliningRuntimeOverhead = + OutliningFuncCallCost + + (OutlinedFunctionCost - Cloner.OutlinedRegionCost) + + ExtraOutliningPenalty; + + return std::make_tuple(OutliningFuncCallCost, OutliningRuntimeOverhead); +} + +// Create the callsite to profile count map which is +// used to update the original function's entry count, +// after the function is partially inlined into the callsite. 
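For intuition, here is a minimal sketch of the entry-count bookkeeping this map enables (an editor's illustration with a hypothetical helper name; tryPartialInline() below performs the equivalent per-callsite subtraction):

#include <algorithm>
#include <cstdint>
#include <vector>

// After partially inlining callsites whose enclosing blocks have the given
// profile counts, the original function keeps whatever is left of its
// profiled entry count, e.g. 1000 - 600 - 300 == 100.
static uint64_t remainingEntryCount(uint64_t CalleeEntryCount,
                                    const std::vector<uint64_t> &SiteCounts) {
  for (uint64_t SiteCount : SiteCounts)
    CalleeEntryCount -= std::min(CalleeEntryCount, SiteCount);
  return CalleeEntryCount;
}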
+void PartialInlinerImpl::computeCallsiteToProfCountMap( + Function *DuplicateFunction, + DenseMap<User *, uint64_t> &CallSiteToProfCountMap) { + std::vector<User *> Users(DuplicateFunction->user_begin(), + DuplicateFunction->user_end()); + Function *CurrentCaller = nullptr; + std::unique_ptr<BlockFrequencyInfo> TempBFI; + BlockFrequencyInfo *CurrentCallerBFI = nullptr; + + auto ComputeCurrBFI = [&,this](Function *Caller) { + // For the old pass manager: + if (!GetBFI) { + DominatorTree DT(*Caller); + LoopInfo LI(DT); + BranchProbabilityInfo BPI(*Caller, LI); + TempBFI.reset(new BlockFrequencyInfo(*Caller, BPI, LI)); + CurrentCallerBFI = TempBFI.get(); + } else { + // New pass manager: + CurrentCallerBFI = &(*GetBFI)(*Caller); + } + }; + + for (User *User : Users) { + CallSite CS = getCallSite(User); + Function *Caller = CS.getCaller(); + if (CurrentCaller != Caller) { + CurrentCaller = Caller; + ComputeCurrBFI(Caller); + } else { + assert(CurrentCallerBFI && "CallerBFI is not set"); + } + BasicBlock *CallBB = CS.getInstruction()->getParent(); + auto Count = CurrentCallerBFI->getBlockProfileCount(CallBB); + if (Count) + CallSiteToProfCountMap[User] = *Count; + else + CallSiteToProfCountMap[User] = 0; + } +} + +PartialInlinerImpl::FunctionCloner::FunctionCloner(Function *F, + FunctionOutliningInfo *OI) + : OrigFunc(F) { + ClonedOI = llvm::make_unique<FunctionOutliningInfo>(); // Clone the function, so that we can hack away on it. ValueToValueMapTy VMap; - Function *DuplicateFunction = CloneFunction(F, VMap); - DuplicateFunction->setLinkage(GlobalValue::InternalLinkage); - BasicBlock *NewEntryBlock = cast<BasicBlock>(VMap[EntryBlock]); - BasicBlock *NewReturnBlock = cast<BasicBlock>(VMap[ReturnBlock]); - BasicBlock *NewNonReturnBlock = cast<BasicBlock>(VMap[NonReturnBlock]); + ClonedFunc = CloneFunction(F, VMap); + ClonedOI->ReturnBlock = cast<BasicBlock>(VMap[OI->ReturnBlock]); + ClonedOI->NonReturnBlock = cast<BasicBlock>(VMap[OI->NonReturnBlock]); + for (BasicBlock *BB : OI->Entries) { + ClonedOI->Entries.push_back(cast<BasicBlock>(VMap[BB])); + } + for (BasicBlock *E : OI->ReturnBlockPreds) { + BasicBlock *NewE = cast<BasicBlock>(VMap[E]); + ClonedOI->ReturnBlockPreds.push_back(NewE); + } // Go ahead and update all uses to the duplicate, so that we can just // use the inliner functionality when we're done hacking. - F->replaceAllUsesWith(DuplicateFunction); + F->replaceAllUsesWith(ClonedFunc); +} + +void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { + + auto getFirstPHI = [](BasicBlock *BB) { + BasicBlock::iterator I = BB->begin(); + PHINode *FirstPhi = nullptr; + while (I != BB->end()) { + PHINode *Phi = dyn_cast<PHINode>(I); + if (!Phi) + break; + if (!FirstPhi) { + FirstPhi = Phi; + break; + } + } + return FirstPhi; + }; // Special hackery is needed with PHI nodes that have inputs from more than // one extracted block. For simplicity, just split the PHIs into a two-level // sequence of PHIs, some of which will go in the extracted region, and some // of which will go outside. 
- BasicBlock *PreReturn = NewReturnBlock; - NewReturnBlock = NewReturnBlock->splitBasicBlock( - NewReturnBlock->getFirstNonPHI()->getIterator()); + BasicBlock *PreReturn = ClonedOI->ReturnBlock; + // Only split the block when necessary: + PHINode *FirstPhi = getFirstPHI(PreReturn); + unsigned NumPredsFromEntries = ClonedOI->ReturnBlockPreds.size(); + + if (!FirstPhi || FirstPhi->getNumIncomingValues() <= NumPredsFromEntries + 1) + return; + + auto IsTrivialPhi = [](PHINode *PN) -> Value * { + Value *CommonValue = PN->getIncomingValue(0); + if (all_of(PN->incoming_values(), + [&](Value *V) { return V == CommonValue; })) + return CommonValue; + return nullptr; + }; + + ClonedOI->ReturnBlock = ClonedOI->ReturnBlock->splitBasicBlock( + ClonedOI->ReturnBlock->getFirstNonPHI()->getIterator()); BasicBlock::iterator I = PreReturn->begin(); - Instruction *Ins = &NewReturnBlock->front(); + Instruction *Ins = &ClonedOI->ReturnBlock->front(); + SmallVector<Instruction *, 4> DeadPhis; while (I != PreReturn->end()) { PHINode *OldPhi = dyn_cast<PHINode>(I); if (!OldPhi) break; - PHINode *RetPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins); + PHINode *RetPhi = + PHINode::Create(OldPhi->getType(), NumPredsFromEntries + 1, "", Ins); OldPhi->replaceAllUsesWith(RetPhi); - Ins = NewReturnBlock->getFirstNonPHI(); + Ins = ClonedOI->ReturnBlock->getFirstNonPHI(); RetPhi->addIncoming(&*I, PreReturn); - RetPhi->addIncoming(OldPhi->getIncomingValueForBlock(NewEntryBlock), - NewEntryBlock); - OldPhi->removeIncomingValue(NewEntryBlock); + for (BasicBlock *E : ClonedOI->ReturnBlockPreds) { + RetPhi->addIncoming(OldPhi->getIncomingValueForBlock(E), E); + OldPhi->removeIncomingValue(E); + } + // After splitting the incoming values, the old phi may become trivial. + // Keeping the trivial phi can introduce a definition inside the outline + // region which is live-out, causing unnecessary overhead (load, store, + // arg passing etc). + if (auto *OldPhiVal = IsTrivialPhi(OldPhi)) { + OldPhi->replaceAllUsesWith(OldPhiVal); + DeadPhis.push_back(OldPhi); + } ++I; - } - NewEntryBlock->getTerminator()->replaceUsesOfWith(PreReturn, NewReturnBlock); + } + for (auto *DP : DeadPhis) + DP->eraseFromParent(); + + for (auto E : ClonedOI->ReturnBlockPreds) { + E->getTerminator()->replaceUsesOfWith(PreReturn, ClonedOI->ReturnBlock); + } +} + +Function *PartialInlinerImpl::FunctionCloner::doFunctionOutlining() { + // Returns true if the block is to be partially inlined into the caller + // (i.e. not to be extracted to the out-of-line function) + auto ToBeInlined = [&, this](BasicBlock *BB) { + return BB == ClonedOI->ReturnBlock || + (std::find(ClonedOI->Entries.begin(), ClonedOI->Entries.end(), BB) != + ClonedOI->Entries.end()); + }; // Gather up the blocks that we're going to extract. std::vector<BasicBlock *> ToExtract; - ToExtract.push_back(NewNonReturnBlock); - for (BasicBlock &BB : *DuplicateFunction) - if (&BB != NewEntryBlock && &BB != NewReturnBlock && - &BB != NewNonReturnBlock) + ToExtract.push_back(ClonedOI->NonReturnBlock); + OutlinedRegionCost += + PartialInlinerImpl::computeBBInlineCost(ClonedOI->NonReturnBlock); + for (BasicBlock &BB : *ClonedFunc) + if (!ToBeInlined(&BB) && &BB != ClonedOI->NonReturnBlock) { + ToExtract.push_back(&BB); + // FIXME: the code extractor may hoist/sink more code + // into the outlined function which may make the outlining + // overhead (the difference of the outlined function cost + // and OutliningRegionCost) look larger.
+ OutlinedRegionCost += computeBBInlineCost(&BB); + } // The CodeExtractor needs a dominator tree. DominatorTree DT; - DT.recalculate(*DuplicateFunction); + DT.recalculate(*ClonedFunc); // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo. LoopInfo LI(DT); - BranchProbabilityInfo BPI(*DuplicateFunction, LI); - BlockFrequencyInfo BFI(*DuplicateFunction, BPI, LI); + BranchProbabilityInfo BPI(*ClonedFunc, LI); + ClonedFuncBFI.reset(new BlockFrequencyInfo(*ClonedFunc, BPI, LI)); // Extract the body of the if. - Function *ExtractedFunction = - CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, &BFI, &BPI) - .extractCodeRegion(); + OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, + ClonedFuncBFI.get(), &BPI) + .extractCodeRegion(); + + if (OutlinedFunc) { + OutliningCallBB = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc) + .getInstruction() + ->getParent(); + assert(OutliningCallBB->getParent() == ClonedFunc); + } - // Inline the top-level if test into all callers. - std::vector<User *> Users(DuplicateFunction->user_begin(), - DuplicateFunction->user_end()); - for (User *User : Users) - if (CallInst *CI = dyn_cast<CallInst>(User)) - InlineFunction(CI, IFI); - else if (InvokeInst *II = dyn_cast<InvokeInst>(User)) - InlineFunction(II, IFI); + return OutlinedFunc; +} +PartialInlinerImpl::FunctionCloner::~FunctionCloner() { // Ditch the duplicate, since we're done with it, and rewrite all remaining // users (function pointers, etc.) back to the original function. - DuplicateFunction->replaceAllUsesWith(F); - DuplicateFunction->eraseFromParent(); + ClonedFunc->replaceAllUsesWith(OrigFunc); + ClonedFunc->eraseFromParent(); + if (!IsFunctionInlined) { + // Remove the function that is speculatively created if there is no + // reference. + if (OutlinedFunc) + OutlinedFunc->eraseFromParent(); + } +} - ++NumPartialInlined; +Function *PartialInlinerImpl::unswitchFunction(Function *F) { + + if (F->hasAddressTaken()) + return nullptr; + + // Let the inliner handle it. + if (F->hasFnAttribute(Attribute::AlwaysInline)) + return nullptr; + + if (F->hasFnAttribute(Attribute::NoInline)) + return nullptr; + + if (PSI->isFunctionEntryCold(F)) + return nullptr; + + if (F->user_begin() == F->user_end()) + return nullptr; + + std::unique_ptr<FunctionOutliningInfo> OI = computeOutliningInfo(F); - return ExtractedFunction; + if (!OI) + return nullptr; + + FunctionCloner Cloner(F, OI.get()); + Cloner.NormalizeReturnBlock(); + Function *OutlinedFunction = Cloner.doFunctionOutlining(); + + bool AnyInline = tryPartialInline(Cloner); + + if (AnyInline) + return OutlinedFunction; + + return nullptr; +} + +bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { + int NonWeightedRcost; + int SizeCost; + + if (Cloner.OutlinedFunc == nullptr) + return false; + + std::tie(SizeCost, NonWeightedRcost) = computeOutliningCosts(Cloner); + + auto RelativeToEntryFreq = getOutliningCallBBRelativeFreq(Cloner); + auto WeightedRcost = BlockFrequency(NonWeightedRcost) * RelativeToEntryFreq; + + // If the call sequence to the outlined function is larger than the original + // outlined region size, it does not increase the chances of inlining + // the function with outlining (the inliner uses the size increase to + // model the cost of inlining a callee).
+ if (!SkipCostAnalysis && Cloner.OutlinedRegionCost < SizeCost) { + OptimizationRemarkEmitter ORE(Cloner.OrigFunc); + DebugLoc DLoc; + BasicBlock *Block; + std::tie(DLoc, Block) = getOneDebugLoc(Cloner.ClonedFunc); + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "OutlineRegionTooSmall", + DLoc, Block) + << ore::NV("Function", Cloner.OrigFunc) + << " not partially inlined into callers (Original Size = " + << ore::NV("OutlinedRegionOriginalSize", Cloner.OutlinedRegionCost) + << ", Size of call sequence to outlined function = " + << ore::NV("NewSize", SizeCost) << ")"); + return false; + } + + assert(Cloner.OrigFunc->user_begin() == Cloner.OrigFunc->user_end() && + "F's users should all be replaced!"); + + std::vector<User *> Users(Cloner.ClonedFunc->user_begin(), + Cloner.ClonedFunc->user_end()); + + DenseMap<User *, uint64_t> CallSiteToProfCountMap; + if (Cloner.OrigFunc->getEntryCount()) + computeCallsiteToProfCountMap(Cloner.ClonedFunc, CallSiteToProfCountMap); + + auto CalleeEntryCount = Cloner.OrigFunc->getEntryCount(); + uint64_t CalleeEntryCountV = (CalleeEntryCount ? *CalleeEntryCount : 0); + + bool AnyInline = false; + for (User *User : Users) { + CallSite CS = getCallSite(User); + + if (IsLimitReached()) + continue; + + OptimizationRemarkEmitter ORE(CS.getCaller()); + + if (!shouldPartialInline(CS, Cloner, WeightedRcost, ORE)) + continue; + + ORE.emit( + OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", CS.getInstruction()) + << ore::NV("Callee", Cloner.OrigFunc) << " partially inlined into " + << ore::NV("Caller", CS.getCaller())); + + InlineFunctionInfo IFI(nullptr, GetAssumptionCache, PSI); + InlineFunction(CS, IFI); + + // Now update the entry count: + if (CalleeEntryCountV && CallSiteToProfCountMap.count(User)) { + uint64_t CallSiteCount = CallSiteToProfCountMap[User]; + CalleeEntryCountV -= std::min(CalleeEntryCountV, CallSiteCount); + } + + AnyInline = true; + NumPartialInlining++; + // Update the stats + NumPartialInlined++; + } + + if (AnyInline) { + Cloner.IsFunctionInlined = true; + if (CalleeEntryCount) + Cloner.OrigFunc->setEntryCount(CalleeEntryCountV); + } + + return AnyInline; } bool PartialInlinerImpl::run(Module &M) { + if (DisablePartialInlining) + return false; + std::vector<Function *> Worklist; Worklist.reserve(M.size()); for (Function &F : M) @@ -203,6 +957,8 @@ char PartialInlinerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner", "Partial Inliner", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner", "Partial Inliner", false, false) @@ -213,12 +969,25 @@ ModulePass *llvm::createPartialInliningPass() { PreservedAnalyses PartialInlinerPass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&FAM](Function &F) -> AssumptionCache & { return FAM.getResult<AssumptionAnalysis>(F); }; - InlineFunctionInfo IFI(nullptr, &GetAssumptionCache); - if (PartialInlinerImpl(IFI).run(M)) + + std::function<BlockFrequencyInfo &(Function &)> GetBFI = + [&FAM](Function &F) -> BlockFrequencyInfo & { + return FAM.getResult<BlockFrequencyAnalysis>(F); + }; + + std::function<TargetTransformInfo &(Function &)> GetTTI = + [&FAM](Function &F) -> TargetTransformInfo & { + return FAM.getResult<TargetIRAnalysis>(F); + }; 
+ + ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); + + if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI).run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } diff --git a/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp index 941efb2..0b319f6 100644 --- a/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -38,21 +38,22 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Vectorize.h" using namespace llvm; static cl::opt<bool> -RunLoopVectorization("vectorize-loops", cl::Hidden, - cl::desc("Run the Loop vectorization passes")); + RunPartialInlining("enable-partial-inlining", cl::init(false), cl::Hidden, + cl::ZeroOrMore, cl::desc("Run the partial inlining pass")); static cl::opt<bool> -RunSLPVectorization("vectorize-slp", cl::Hidden, - cl::desc("Run the SLP vectorization passes")); + RunLoopVectorization("vectorize-loops", cl::Hidden, + cl::desc("Run the Loop vectorization passes")); static cl::opt<bool> -RunBBVectorization("vectorize-slp-aggressive", cl::Hidden, - cl::desc("Run the BB vectorization passes")); +RunSLPVectorization("vectorize-slp", cl::Hidden, + cl::desc("Run the SLP vectorization passes")); static cl::opt<bool> UseGVNAfterVectorization("use-gvn-after-vectorization", @@ -67,10 +68,6 @@ static cl::opt<bool> RunLoopRerolling("reroll-loops", cl::Hidden, cl::desc("Run the loop rerolling pass")); -static cl::opt<bool> RunLoadCombine("combine-loads", cl::init(false), - cl::Hidden, - cl::desc("Run the load combining pass")); - static cl::opt<bool> RunNewGVN("enable-newgvn", cl::init(false), cl::Hidden, cl::desc("Run the NewGVN pass")); @@ -93,10 +90,6 @@ static cl::opt<CFLAAType> clEnumValN(CFLAAType::Both, "both", "Enable both variants of CFL-AA"))); -static cl::opt<bool> -EnableMLSM("mlsm", cl::init(true), cl::Hidden, - cl::desc("Enable motion of merged load and store")); - static cl::opt<bool> EnableLoopInterchange( "enable-loopinterchange", cl::init(false), cl::Hidden, cl::desc("Enable the new, experimental LoopInterchange Pass")); @@ -140,15 +133,28 @@ static cl::opt<int> PreInlineThreshold( cl::desc("Control the amount of inlining in pre-instrumentation inliner " "(default = 75)")); +static cl::opt<bool> EnableEarlyCSEMemSSA( + "enable-earlycse-memssa", cl::init(true), cl::Hidden, + cl::desc("Enable the EarlyCSE w/ MemorySSA pass (default = on)")); + static cl::opt<bool> EnableGVNHoist( "enable-gvn-hoist", cl::init(false), cl::Hidden, - cl::desc("Enable the GVN hoisting pass")); + cl::desc("Enable the GVN hoisting pass (default = off)")); static cl::opt<bool> DisableLibCallsShrinkWrap("disable-libcalls-shrinkwrap", cl::init(false), cl::Hidden, cl::desc("Disable shrink-wrap library calls")); +static cl::opt<bool> + EnableSimpleLoopUnswitch("enable-simple-loop-unswitch", cl::init(false), + cl::Hidden, + cl::desc("Enable the simple loop unswitch pass.")); + +static cl::opt<bool> EnableGVNSink( + "enable-gvn-sink", cl::init(false), cl::Hidden, + cl::desc("Enable the GVN sinking pass (default = off)")); + PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; @@ -156,11 +162,9 @@ PassManagerBuilder::PassManagerBuilder() { Inliner = nullptr; DisableUnitAtATime = false; DisableUnrollLoops = false; - BBVectorize =
RunBBVectorization; SLPVectorize = RunSLPVectorization; LoopVectorize = RunLoopVectorization; RerollLoops = RunLoopRerolling; - LoadCombine = RunLoadCombine; NewGVN = RunNewGVN; DisableGVNLoadPRE = false; VerifyInput = false; @@ -172,6 +176,7 @@ PassManagerBuilder::PassManagerBuilder() { PGOInstrUse = RunPGOInstrUse; PrepareForThinLTO = EnablePrepareForThinLTO; PerformThinLTO = false; + DivergentTarget = false; } PassManagerBuilder::~PassManagerBuilder() { @@ -183,6 +188,13 @@ static ManagedStatic<SmallVector<std::pair<PassManagerBuilder::ExtensionPointTy, PassManagerBuilder::ExtensionFn>, 8> > GlobalExtensions; +/// Check if GlobalExtensions is constructed and not empty. +/// Since GlobalExtensions is a managed static, calling 'empty()' will trigger +/// the construction of the object. +static bool GlobalExtensionsNotEmpty() { + return GlobalExtensions.isConstructed() && !GlobalExtensions->empty(); +} + void PassManagerBuilder::addGlobalExtension( PassManagerBuilder::ExtensionPointTy Ty, PassManagerBuilder::ExtensionFn Fn) { @@ -195,9 +207,12 @@ void PassManagerBuilder::addExtension(ExtensionPointTy Ty, ExtensionFn Fn) { void PassManagerBuilder::addExtensionsToPM(ExtensionPointTy ETy, legacy::PassManagerBase &PM) const { - for (unsigned i = 0, e = GlobalExtensions->size(); i != e; ++i) - if ((*GlobalExtensions)[i].first == ETy) - (*GlobalExtensions)[i].second(*this, PM); + if (GlobalExtensionsNotEmpty()) { + for (auto &Ext : *GlobalExtensions) { + if (Ext.first == ETy) + Ext.second(*this, PM); + } + } for (unsigned i = 0, e = Extensions.size(); i != e; ++i) if (Extensions[i].first == ETy) Extensions[i].second(*this, PM); @@ -248,18 +263,17 @@ void PassManagerBuilder::populateFunctionPassManager( FPM.add(createCFGSimplificationPass()); FPM.add(createSROAPass()); FPM.add(createEarlyCSEPass()); - if(EnableGVNHoist) - FPM.add(createGVNHoistPass()); FPM.add(createLowerExpectIntrinsicPass()); } // Add the PGO instrumentation generation or use passes, as the options specify. void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM) { - if (!EnablePGOInstrGen && PGOInstrUse.empty()) + if (!EnablePGOInstrGen && PGOInstrUse.empty() && PGOSampleUse.empty()) return; // Perform the preinline and cleanup passes for O1 and above, // and avoid doing them if optimizing for size. - if (OptLevel > 0 && SizeLevel == 0 && !DisablePreInliner) { + if (OptLevel > 0 && SizeLevel == 0 && !DisablePreInliner && + PGOSampleUse.empty()) { // Create preinline pass. We construct an InlineParams object and specify // the threshold here to keep the command line options of the regular // inliner from influencing pre-inlining. The only fields of InlineParams we @@ -283,17 +297,32 @@ void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM) { InstrProfOptions Options; if (!PGOInstrGen.empty()) Options.InstrProfileOutput = PGOInstrGen; + Options.DoCounterPromotion = true; + MPM.add(createLoopRotatePass()); MPM.add(createInstrProfilingLegacyPass(Options)); } if (!PGOInstrUse.empty()) MPM.add(createPGOInstrumentationUseLegacyPass(PGOInstrUse)); + // Indirect call promotion that promotes intra-module targets only. + // For ThinLTO this is done earlier due to interactions with globalopt + // for imported functions. We don't run this at -O0. + if (OptLevel > 0) + MPM.add( + createPGOIndirectCallPromotionLegacyPass(false, !PGOSampleUse.empty())); } void PassManagerBuilder::addFunctionSimplificationPasses( legacy::PassManagerBase &MPM) { // Start of function pass.
// Break up aggregate allocas, using SSAUpdater. MPM.add(createSROAPass()); - MPM.add(createEarlyCSEPass()); // Catch trivial redundancies + MPM.add(createEarlyCSEPass(EnableEarlyCSEMemSSA)); // Catch trivial redundancies + if (EnableGVNHoist) + MPM.add(createGVNHoistPass()); + if (EnableGVNSink) { + MPM.add(createGVNSinkPass()); + MPM.add(createCFGSimplificationPass()); + } + // Speculative execution if the target has divergent branches; otherwise nop. MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass()); MPM.add(createJumpThreadingPass()); // Thread jumps. @@ -305,29 +334,37 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createLibCallsShrinkWrapPass()); addExtensionsToPM(EP_Peephole, MPM); + // Optimize memory intrinsic calls based on the profiled size information. + if (SizeLevel == 0) + MPM.add(createPGOMemOPSizeOptLegacyPass()); + MPM.add(createTailCallEliminationPass()); // Eliminate tail calls MPM.add(createCFGSimplificationPass()); // Merge & remove BBs MPM.add(createReassociatePass()); // Reassociate expressions // Rotate Loop - disable header duplication at -Oz MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1)); MPM.add(createLICMPass()); // Hoist loop invariants - MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); + if (EnableSimpleLoopUnswitch) + MPM.add(createSimpleLoopUnswitchLegacyPass()); + else + MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget)); MPM.add(createCFGSimplificationPass()); addInstructionCombiningPass(MPM); MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. + addExtensionsToPM(EP_LateLoopOptimizations, MPM); MPM.add(createLoopDeletionPass()); // Delete dead loops + if (EnableLoopInterchange) { MPM.add(createLoopInterchangePass()); // Interchange loops MPM.add(createCFGSimplificationPass()); } if (!DisableUnrollLoops) - MPM.add(createSimpleLoopUnrollPass()); // Unroll small loops + MPM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops addExtensionsToPM(EP_LoopOptimizerEnd, MPM); if (OptLevel > 1) { - if (EnableMLSM) - MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds + MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds MPM.add(NewGVN ? createNewGVNPass() : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies } @@ -352,29 +389,8 @@ void PassManagerBuilder::addFunctionSimplificationPasses( if (RerollLoops) MPM.add(createLoopRerollPass()); - if (!RunSLPAfterLoopVectorization) { - if (SLPVectorize) - MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. - - if (BBVectorize) { - MPM.add(createBBVectorizePass()); - addInstructionCombiningPass(MPM); - addExtensionsToPM(EP_Peephole, MPM); - if (OptLevel > 1 && UseGVNAfterVectorization) - MPM.add(NewGVN - ? createNewGVNPass() - : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies - else - MPM.add(createEarlyCSEPass()); // Catch trivial redundancies - - // BBVectorize may have significantly shortened a loop body; unroll again. - if (!DisableUnrollLoops) - MPM.add(createLoopUnrollPass()); - } - } - - if (LoadCombine) - MPM.add(createLoadCombinePass()); + if (!RunSLPAfterLoopVectorization && SLPVectorize) + MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. MPM.add(createAggressiveDCEPass()); // Delete dead instructions MPM.add(createCFGSimplificationPass()); // Merge & remove BBs @@ -409,14 +425,17 @@ void PassManagerBuilder::populateModulePassManager( // builds. 
The function merging pass is if (MergeFunctions) MPM.add(createMergeFunctionsPass()); - else if (!GlobalExtensions->empty() || !Extensions.empty()) + else if (GlobalExtensionsNotEmpty() || !Extensions.empty()) MPM.add(createBarrierNoopPass()); + addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); + + // Rename anon globals to be able to export them in the summary. + // This has to be done after we add the extensions to the pass manager + // as there could be passes (e.g. Address sanitizer) which introduce + // new unnamed globals. if (PrepareForThinLTO) - // Rename anon globals to be able to export them in the summary. MPM.add(createNameAnonGlobalPass()); - - addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); return; } @@ -434,7 +453,16 @@ void PassManagerBuilder::populateModulePassManager( // earlier in the pass pipeline, here before globalopt. Otherwise imported // available_externally functions look unreferenced and are removed. if (PerformThinLTO) - MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true)); + MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true, + !PGOSampleUse.empty())); + + // For SamplePGO in the ThinLTO compile phase, we do not want to unroll loops + // as it will change the CFG so much that it makes the second profile + // annotation in the backend more difficult. + bool PrepareForThinLTOUsingPGOSampleProfile = + PrepareForThinLTO && !PGOSampleUse.empty(); + if (PrepareForThinLTOUsingPGOSampleProfile) + DisableUnrollLoops = true; if (!DisableUnitAtATime) { // Infer attributes about declarations if possible. @@ -454,15 +482,13 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createCFGSimplificationPass()); // Clean up after IPCP & DAE } - if (!PerformThinLTO) { - /// PGO instrumentation is added during the compile phase for ThinLTO, do - /// not run it a second time + // For SamplePGO in the ThinLTO compile phase, we do not want to do indirect + // call promotion as it will change the CFG so much that it makes the second + // profile annotation in the backend more difficult. + // PGO instrumentation is added during the compile phase for ThinLTO, do + // not run it a second time + if (!PerformThinLTO && !PrepareForThinLTOUsingPGOSampleProfile) addPGOInstrPasses(MPM); - // Indirect call promotion that promotes intra-module targets only. - // For ThinLTO this is done earlier due to interactions with globalopt - // for imported functions. - MPM.add(createPGOIndirectCallPromotionLegacyPass()); - } if (EnableNonLTOGlobalsModRef) // We add a module alias analysis pass here. In part due to bugs in the @@ -489,6 +515,8 @@ void PassManagerBuilder::populateModulePassManager( // pass manager that we are specifically trying to avoid. To prevent this // we must insert a no-op module pass to reset the pass manager. MPM.add(createBarrierNoopPass()); + if (RunPartialInlining) + MPM.add(createPartialInliningPass()); if (!DisableUnitAtATime && OptLevel > 1 && !PrepareForLTO && !PrepareForThinLTO) @@ -589,42 +617,24 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createCorrelatedValuePropagationPass()); addInstructionCombiningPass(MPM); MPM.add(createLICMPass()); - MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); + MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget)); MPM.add(createCFGSimplificationPass()); addInstructionCombiningPass(MPM); } - if (RunSLPAfterLoopVectorization) { - if (SLPVectorize) { - MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains.
- if (OptLevel > 1 && ExtraVectorizerPasses) { - MPM.add(createEarlyCSEPass()); - } - } - - if (BBVectorize) { - MPM.add(createBBVectorizePass()); - addInstructionCombiningPass(MPM); - addExtensionsToPM(EP_Peephole, MPM); - if (OptLevel > 1 && UseGVNAfterVectorization) - MPM.add(NewGVN - ? createNewGVNPass() - : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies - else - MPM.add(createEarlyCSEPass()); // Catch trivial redundancies - - // BBVectorize may have significantly shortened a loop body; unroll again. - if (!DisableUnrollLoops) - MPM.add(createLoopUnrollPass()); + if (RunSLPAfterLoopVectorization && SLPVectorize) { + MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. + if (OptLevel > 1 && ExtraVectorizerPasses) { + MPM.add(createEarlyCSEPass()); } } addExtensionsToPM(EP_Peephole, MPM); - MPM.add(createCFGSimplificationPass()); + MPM.add(createLateCFGSimplificationPass()); // Switches to lookup tables addInstructionCombiningPass(MPM); if (!DisableUnrollLoops) { - MPM.add(createLoopUnrollPass()); // Unroll small loops + MPM.add(createLoopUnrollPass(OptLevel)); // Unroll small loops // LoopUnroll may generate some redundancy to clean up. addInstructionCombiningPass(MPM); @@ -662,6 +672,11 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createLoopSinkPass()); // Get rid of LCSSA nodes. MPM.add(createInstructionSimplifierPass()); + + // LoopSink (and other loop passes since the last simplifyCFG) might have + // resulted in single-entry-single-exit or empty blocks. Clean up the CFG. + MPM.add(createCFGSimplificationPass()); + addExtensionsToPM(EP_OptimizerLast, MPM); } @@ -684,7 +699,8 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // left by the earlier promotion pass that promotes intra-module targets. // This two-step promotion is done to save compile time. For LTO, it should // produce the same result as if we only did promotion here. - PM.add(createPGOIndirectCallPromotionLegacyPass(true)); + PM.add( + createPGOIndirectCallPromotionLegacyPass(true, !PGOSampleUse.empty())); // Propagate constants at call sites into the functions they call. This // opens opportunities for globalopt (and inlining) by substituting function @@ -703,7 +719,7 @@ PM.add(createGlobalSplitPass()); // Apply whole-program devirtualization and virtual constant propagation. - PM.add(createWholeProgramDevirtPass()); + PM.add(createWholeProgramDevirtPass(ExportSummary, nullptr)); // That's all we need at opt level 1. if (OptLevel == 1) @@ -759,8 +775,7 @@ PM.add(createGlobalsAAWrapperPass()); // IP alias analysis. PM.add(createLICMPass()); // Hoist loop invariants. - if (EnableMLSM) - PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. + PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. PM.add(NewGVN ? createNewGVNPass() : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. PM.add(createMemCpyOptPass()); // Remove dead memcpys.
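For context, the builder these hunks modify is typically driven by client code along the following lines (an editor's sketch of ordinary legacy pass-manager usage, not part of this patch):

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"

using namespace llvm;

// Populate function- and module-level pipelines the way a typical frontend
// does; the hunks above change which passes populate*() adds.
static void buildPipelines(legacy::FunctionPassManager &FPM,
                           legacy::PassManager &MPM) {
  PassManagerBuilder PMB;
  PMB.OptLevel = 2;         // -O2
  PMB.SLPVectorize = true;  // cf. -vectorize-slp above
  PMB.LoopVectorize = true; // cf. -vectorize-loops above
  PMB.populateFunctionPassManager(FPM);
  PMB.populateModulePassManager(MPM);
}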
@@ -775,11 +790,11 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createLoopInterchangePass()); if (!DisableUnrollLoops) - PM.add(createSimpleLoopUnrollPass()); // Unroll small loops + PM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops PM.add(createLoopVectorizePass(true, LoopVectorize)); // The vectorizer may have significantly shortened a loop body; unroll again. if (!DisableUnrollLoops) - PM.add(createLoopUnrollPass()); + PM.add(createLoopUnrollPass(OptLevel)); // Now that we've optimized loops (in particular loop induction variables), // we may have exposed more scalar opportunities. Run parts of the scalar @@ -799,9 +814,6 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // alignments. PM.add(createAlignmentFromAssumptionsPass()); - if (LoadCombine) - PM.add(createLoadCombinePass()); - // Cleanup and simplify the code after the scalar optimizations. addInstructionCombiningPass(PM); addExtensionsToPM(EP_Peephole, PM); @@ -833,6 +845,23 @@ void PassManagerBuilder::populateThinLTOPassManager( if (VerifyInput) PM.add(createVerifierPass()); + if (ImportSummary) { + // These passes import type identifier resolutions for whole-program + // devirtualization and CFI. They must run early because other passes may + // disturb the specific instruction patterns that these passes look for, + // creating dependencies on resolutions that may not appear in the summary. + // + // For example, GVN may transform the pattern assume(type.test) appearing in + // two basic blocks into assume(phi(type.test, type.test)), which would + // transform a dependency on a WPD resolution into a dependency on a type + // identifier resolution for CFI. + // + // Also, WPD has access to more precise information than ICP and can + // devirtualize more effectively, so it should operate on the IR first. + PM.add(createWholeProgramDevirtPass(nullptr, ImportSummary)); + PM.add(createLowerTypeTestsPass(nullptr, ImportSummary)); + } + populateModulePassManager(PM); if (VerifyOutput) @@ -849,6 +878,12 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { if (OptLevel != 0) addLTOOptimizationPasses(PM); + else { + // The whole-program-devirt pass needs to run at -O0 because only it knows + // about the llvm.type.checked.load intrinsic: it needs to both lower the + // intrinsic itself and handle it in the summary. + PM.add(createWholeProgramDevirtPass(ExportSummary, nullptr)); + } // Create a function that performs CFI checks for cross-DSO calls with targets // in the current module. @@ -857,8 +892,7 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { // Lower type metadata and the type.test intrinsic. This pass supports Clang's // control flow integrity mechanisms (-fsanitize=cfi*) and needs to run at // link time if CFI is enabled. The pass does nothing if CFI is disabled. 
- PM.add(createLowerTypeTestsPass(LowerTypeTestsSummaryAction::None, - /*Summary=*/nullptr)); + PM.add(createLowerTypeTestsPass(ExportSummary, nullptr)); if (OptLevel != 0) addLateLTOOptimizationPasses(PM); diff --git a/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp b/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp index d9acb9b..3fd5984 100644 --- a/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp +++ b/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp @@ -14,10 +14,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/EHPersonalities.h" @@ -28,6 +26,8 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> using namespace llvm; diff --git a/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp b/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp index 6a43f8d..6baada2 100644 --- a/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -42,7 +43,9 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ValueSymbolTable.h" #include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProf.h" #include "llvm/ProfileData/SampleProfReader.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -50,6 +53,7 @@ #include "llvm/Support/Format.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/Cloning.h" #include <cctype> @@ -159,21 +163,26 @@ protected: ErrorOr<uint64_t> getInstWeight(const Instruction &I); ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB); const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const; + std::vector<const FunctionSamples *> + findIndirectCallFunctionSamples(const Instruction &I) const; const FunctionSamples *findFunctionSamples(const Instruction &I) const; - bool inlineHotFunctions(Function &F); + bool inlineHotFunctions(Function &F, + DenseSet<GlobalValue::GUID> &ImportGUIDs); void printEdgeWeight(raw_ostream &OS, Edge E); void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const; void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB); bool computeBlockWeights(Function &F); void findEquivalenceClasses(Function &F); + template <bool IsPostDom> void findEquivalencesFor(BasicBlock *BB1, ArrayRef<BasicBlock *> Descendants, - DominatorTreeBase<BasicBlock> *DomTree); + DominatorTreeBase<BasicBlock, IsPostDom> *DomTree); + void propagateWeights(Function &F); uint64_t visitEdge(Edge E, unsigned *NumUnknownEdges, Edge *UnknownEdge); void buildEdges(Function &F); bool propagateThroughEdges(Function &F, bool UpdateBlockCount); void computeDominanceAndLoopInfo(Function &F); - unsigned getOffset(unsigned L, unsigned H) const; + unsigned getOffset(const DILocation *DIL) const; void clearFunctionData(); /// \brief Map basic blocks to their 
computed weights. @@ -202,9 +211,15 @@ protected: /// the same number of times. EquivalenceClassMap EquivalenceClass; + /// Map from function name to Function *. Used to find the function from + /// the function name. If the function name contains a suffix, an additional + /// entry is added to map the stripped name to the function, provided there + /// is a one-to-one mapping. + StringMap<Function *> SymbolMap; + /// \brief Dominance, post-dominance and loop information. std::unique_ptr<DominatorTree> DT; - std::unique_ptr<DominatorTreeBase<BasicBlock>> PDT; + std::unique_ptr<PostDomTreeBase<BasicBlock>> PDT; std::unique_ptr<LoopInfo> LI; AssumptionCacheTracker *ACT; @@ -326,11 +341,12 @@ SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS) const { // If there are inlined callsites in this function, count the samples found // in the respective bodies. However, do not bother counting callees with 0 // total samples, these are callees that were never invoked at runtime. - for (const auto &I : FS->getCallsiteSamples()) { - const FunctionSamples *CalleeSamples = &I.second; - if (callsiteIsHot(FS, CalleeSamples)) - Count += countUsedRecords(CalleeSamples); - } + for (const auto &I : FS->getCallsiteSamples()) + for (const auto &J : I.second) { + const FunctionSamples *CalleeSamples = &J.second; + if (callsiteIsHot(FS, CalleeSamples)) + Count += countUsedRecords(CalleeSamples); + } return Count; } @@ -343,11 +359,12 @@ SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS) const { unsigned Count = FS->getBodySamples().size(); // Only count records in hot callsites. - for (const auto &I : FS->getCallsiteSamples()) { - const FunctionSamples *CalleeSamples = &I.second; - if (callsiteIsHot(FS, CalleeSamples)) - Count += countBodyRecords(CalleeSamples); - } + for (const auto &I : FS->getCallsiteSamples()) + for (const auto &J : I.second) { + const FunctionSamples *CalleeSamples = &J.second; + if (callsiteIsHot(FS, CalleeSamples)) + Count += countBodyRecords(CalleeSamples); + } return Count; } @@ -362,11 +379,12 @@ SampleCoverageTracker::countBodySamples(const FunctionSamples *FS) const { Total += I.second.getSamples(); // Only count samples in hot callsites. - for (const auto &I : FS->getCallsiteSamples()) { - const FunctionSamples *CalleeSamples = &I.second; - if (callsiteIsHot(FS, CalleeSamples)) - Total += countBodySamples(CalleeSamples); - } + for (const auto &I : FS->getCallsiteSamples()) + for (const auto &J : I.second) { + const FunctionSamples *CalleeSamples = &J.second; + if (callsiteIsHot(FS, CalleeSamples)) + Total += countBodySamples(CalleeSamples); + } return Total; } @@ -398,15 +416,11 @@ void SampleProfileLoader::clearFunctionData() { CoverageTracker.clear(); } -/// \brief Returns the offset of lineno \p L to head_lineno \p H -/// -/// \param L Lineno -/// \param H Header lineno of the function -/// -/// \returns offset to the header lineno. 16 bits are used to represent offset. +/// Returns the line offset to the start line of the subprogram. /// We assume that a single function will not exceed 65535 LOC. -unsigned SampleProfileLoader::getOffset(unsigned L, unsigned H) const { - return (L - H) & 0xffff; +unsigned SampleProfileLoader::getOffset(const DILocation *DIL) const { + return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) & + 0xffff; } /// \brief Print the weight of edge \p E on stream \p OS. @@ -451,8 +465,7 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS, /// \param Inst Instruction to query.
/// /// \returns the weight of \p Inst. -ErrorOr<uint64_t> -SampleProfileLoader::getInstWeight(const Instruction &Inst) { +ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { const DebugLoc &DLoc = Inst.getDebugLoc(); if (!DLoc) return std::error_code(); @@ -470,19 +483,14 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) { // If a call/invoke instruction is inlined in profile, but not inlined here, // it means that the inlined callsite has no sample, thus the call // instruction should have 0 count. - bool IsCall = isa<CallInst>(Inst) || isa<InvokeInst>(Inst); - if (IsCall && findCalleeFunctionSamples(Inst)) + if ((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) && + findCalleeFunctionSamples(Inst)) return 0; const DILocation *DIL = DLoc; - unsigned Lineno = DLoc.getLine(); - unsigned HeaderLineno = DIL->getScope()->getSubprogram()->getLine(); - - uint32_t LineOffset = getOffset(Lineno, HeaderLineno); - uint32_t Discriminator = DIL->getDiscriminator(); - ErrorOr<uint64_t> R = IsCall - ? FS->findCallSamplesAt(LineOffset, Discriminator) - : FS->findSamplesAt(LineOffset, Discriminator); + uint32_t LineOffset = getOffset(DIL); + uint32_t Discriminator = DIL->getBaseDiscriminator(); + ErrorOr<uint64_t> R = FS->findSamplesAt(LineOffset, Discriminator); if (R) { bool FirstMark = CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator, R.get()); @@ -491,13 +499,14 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) { LLVMContext &Ctx = F->getContext(); emitOptimizationRemark( Ctx, DEBUG_TYPE, *F, DLoc, - Twine("Applied ") + Twine(*R) + " samples from profile (offset: " + - Twine(LineOffset) + + Twine("Applied ") + Twine(*R) + + " samples from profile (offset: " + Twine(LineOffset) + ((Discriminator) ? Twine(".") + Twine(Discriminator) : "") + ")"); } - DEBUG(dbgs() << " " << Lineno << "." << DIL->getDiscriminator() << ":" - << Inst << " (line offset: " << Lineno - HeaderLineno << "." - << DIL->getDiscriminator() << " - weight: " << R.get() + DEBUG(dbgs() << " " << DLoc.getLine() << "." + << DIL->getBaseDiscriminator() << ":" << Inst + << " (line offset: " << LineOffset << "." + << DIL->getBaseDiscriminator() << " - weight: " << R.get() << ")\n"); } return R; @@ -511,8 +520,7 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) { /// \param BB The basic block to query. /// /// \returns the weight for \p BB. -ErrorOr<uint64_t> -SampleProfileLoader::getBlockWeight(const BasicBlock *BB) { +ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) { uint64_t Max = 0; bool HasWeight = false; for (auto &I : BB->getInstList()) { @@ -565,16 +573,49 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const { if (!DIL) { return nullptr; } - DISubprogram *SP = DIL->getScope()->getSubprogram(); - if (!SP) - return nullptr; + + StringRef CalleeName; + if (const CallInst *CI = dyn_cast<CallInst>(&Inst)) + if (Function *Callee = CI->getCalledFunction()) + CalleeName = Callee->getName(); const FunctionSamples *FS = findFunctionSamples(Inst); if (FS == nullptr) return nullptr; - return FS->findFunctionSamplesAt(LineLocation( - getOffset(DIL->getLine(), SP->getLine()), DIL->getDiscriminator())); + return FS->findFunctionSamplesAt( + LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()), CalleeName); +} + +/// Returns a vector of FunctionSamples that are the indirect call targets +/// of \p Inst. The vector is sorted by the total number of samples. 
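For orientation: the getInstWeight hunk above now keys every profile lookup on the pair (line offset, base discriminator), where the offset is the distance from the subprogram's start line wrapped to 16 bits. A minimal standalone sketch of that offset computation, with hypothetical line numbers standing in for what the real pass reads out of the DILocation:

    #include <cstdint>
    #include <cstdio>

    // Stand-in for SampleProfileLoader::getOffset: in the real pass both the
    // instruction's line and the subprogram's start line come from debug info.
    static uint32_t getOffset(uint32_t InstLine, uint32_t SubprogramLine) {
      // Wrapped to 16 bits, assuming no single function exceeds 65535 LOC.
      return (InstLine - SubprogramLine) & 0xffff;
    }

    int main() {
      // An instruction at line 1042 in a function whose subprogram starts at 1000:
      std::printf("offset %u\n", getOffset(1042, 1000)); // prints "offset 42"
    }

Samples are then fetched with FS->findSamplesAt(LineOffset, Discriminator), as in the hunk above; the findIndirectCallFunctionSamples implementation continues below.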
+std::vector<const FunctionSamples *> +SampleProfileLoader::findIndirectCallFunctionSamples( + const Instruction &Inst) const { + const DILocation *DIL = Inst.getDebugLoc(); + std::vector<const FunctionSamples *> R; + + if (!DIL) { + return R; + } + + const FunctionSamples *FS = findFunctionSamples(Inst); + if (FS == nullptr) + return R; + + if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt( + LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()))) { + if (M->size() == 0) + return R; + for (const auto &NameFS : *M) { + R.push_back(&NameFS.second); + } + std::sort(R.begin(), R.end(), + [](const FunctionSamples *L, const FunctionSamples *R) { + return L->getTotalSamples() > R->getTotalSamples(); + }); + } + return R; } /// \brief Get the FunctionSamples for an instruction. @@ -588,23 +629,23 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const { /// \returns the FunctionSamples pointer to the inlined instance. const FunctionSamples * SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { - SmallVector<LineLocation, 10> S; + SmallVector<std::pair<LineLocation, StringRef>, 10> S; const DILocation *DIL = Inst.getDebugLoc(); - if (!DIL) { + if (!DIL) return Samples; - } + + const DILocation *PrevDIL = DIL; for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { - DISubprogram *SP = DIL->getScope()->getSubprogram(); - if (!SP) - return nullptr; - S.push_back(LineLocation(getOffset(DIL->getLine(), SP->getLine()), - DIL->getDiscriminator())); + S.push_back(std::make_pair( + LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()), + PrevDIL->getScope()->getSubprogram()->getLinkageName())); + PrevDIL = DIL; } if (S.size() == 0) return Samples; const FunctionSamples *FS = Samples; for (int i = S.size() - 1; i >= 0 && FS != nullptr; i--) { - FS = FS->findFunctionSamplesAt(S[i]); + FS = FS->findFunctionSamplesAt(S[i].first, S[i].second); } return FS; } @@ -614,14 +655,17 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { /// Iteratively traverse all callsites of the function \p F, and find if /// the corresponding inlined instance exists and is hot in the profile. If /// it is hot enough, inline the callsites and add new callsites of the -/// callee into the caller. -/// -/// TODO: investigate the possibility of not invoking InlineFunction directly. +/// callee into the caller. If the call is an indirect call, first promote +/// it to a direct call. Each indirect call is limited to a single target. /// /// \param F function to perform iterative inlining. +/// \param ImportGUIDs a set to be updated to include all GUIDs that come +/// from a different module but inlined in the profiled binary. /// /// \returns True if any inlining happened.
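The inlineHotFunctions rewrite that follows first promotes a hot indirect call to a guarded direct call (via promoteIndirectCall with an assumed 80/100 branch weight) so that InlineFunction can then inline it. A source-level analogue of the guarded call that promotion produces; the callee name and signature here are hypothetical:

    #include <cstdio>

    using FnPtr = int (*)(int);

    static int hotTarget(int X) { return X + 1; } // hypothetical profiled target

    // Conceptually what the promoted call site becomes: a pointer compare, a
    // direct (now inlinable) call on the hot path, and the original indirect
    // call kept as the fallback.
    static int callAfterPromotion(FnPtr FP, int Arg) {
      if (FP == &hotTarget)   // annotated as ~80% taken by the pass
        return hotTarget(Arg);
      return FP(Arg);         // cold path: the original indirect call
    }

    int main() { std::printf("%d\n", callAfterPromotion(&hotTarget, 41)); } // 42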
-bool SampleProfileLoader::inlineHotFunctions(Function &F) { +bool SampleProfileLoader::inlineHotFunctions( + Function &F, DenseSet<GlobalValue::GUID> &ImportGUIDs) { + DenseSet<Instruction *> PromotedInsns; bool Changed = false; LLVMContext &Ctx = F.getContext(); std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&]( @@ -635,7 +679,7 @@ bool SampleProfileLoader::inlineHotFunctions(Function &F) { for (auto &I : BB.getInstList()) { const FunctionSamples *FS = nullptr; if ((isa<CallInst>(I) || isa<InvokeInst>(I)) && - (FS = findCalleeFunctionSamples(I))) { + !isa<IntrinsicInst>(I) && (FS = findCalleeFunctionSamples(I))) { Candidates.push_back(&I); if (callsiteIsHot(Samples, FS)) Hot = true; @@ -647,18 +691,55 @@ bool SampleProfileLoader::inlineHotFunctions(Function &F) { } for (auto I : CIS) { InlineFunctionInfo IFI(nullptr, ACT ? &GetAssumptionCache : nullptr); - CallSite CS(I); - Function *CalledFunction = CS.getCalledFunction(); - if (!CalledFunction || !CalledFunction->getSubprogram()) + Function *CalledFunction = CallSite(I).getCalledFunction(); + // Do not inline recursive calls. + if (CalledFunction == &F) continue; + Instruction *DI = I; + if (!CalledFunction && !PromotedInsns.count(I) && + CallSite(I).isIndirectCall()) + for (const auto *FS : findIndirectCallFunctionSamples(*I)) { + auto CalleeFunctionName = FS->getName(); + // If it is a recursive call, we do not inline it as it could bloat + // the code exponentially. There is way to better handle this, e.g. + // clone the caller first, and inline the cloned caller if it is + // recursive. As llvm does not inline recursive calls, we will simply + // ignore it instead of handling it explicitly. + if (CalleeFunctionName == F.getName()) + continue; + const char *Reason = "Callee function not available"; + auto R = SymbolMap.find(CalleeFunctionName); + if (R == SymbolMap.end()) + continue; + CalledFunction = R->getValue(); + if (CalledFunction && isLegalToPromote(I, CalledFunction, &Reason)) { + // The indirect target was promoted and inlined in the profile, as a + // result, we do not have profile info for the branch probability. + // We set the probability to 80% taken to indicate that the static + // call is likely taken. + DI = dyn_cast<Instruction>( + promoteIndirectCall(I, CalledFunction, 80, 100, false) + ->stripPointerCasts()); + PromotedInsns.insert(I); + } else { + DEBUG(dbgs() << "\nFailed to promote indirect call to " + << CalleeFunctionName << " because " << Reason + << "\n"); + continue; + } + } + if (!CalledFunction || !CalledFunction->getSubprogram()) { + findCalleeFunctionSamples(*I)->findImportedFunctions( + ImportGUIDs, F.getParent(), + Samples->getTotalSamples() * SampleProfileHotThreshold / 100); + continue; + } DebugLoc DLoc = I->getDebugLoc(); - uint64_t NumSamples = findCalleeFunctionSamples(*I)->getTotalSamples(); - if (InlineFunction(CS, IFI)) { + if (InlineFunction(CallSite(DI), IFI)) { LocalChanged = true; emitOptimizationRemark(Ctx, DEBUG_TYPE, F, DLoc, Twine("inlined hot callee '") + - CalledFunction->getName() + "' with " + - Twine(NumSamples) + " samples into '" + + CalledFunction->getName() + "' into '" + F.getName() + "'"); } } @@ -694,9 +775,10 @@ bool SampleProfileLoader::inlineHotFunctions(Function &F) { /// \param DomTree Opposite dominator tree. If \p Descendants is filled /// with blocks from \p BB1's dominator tree, then /// this is the post-dominator tree, and vice versa. 
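The hunk below makes findEquivalencesFor generic over the tree kind, mirroring the DominatorTreeBase<NodeT, IsPostDom> split that also appears later as PostDomTreeBase<BasicBlock>. A minimal sketch of the same pattern, with a hypothetical stand-in for the tree type:

    #include <cstdio>

    // Hypothetical stand-in for llvm::DominatorTreeBase<BasicBlock, IsPostDom>.
    template <bool IsPostDom> struct DomTreeBase {
      static constexpr const char *kind() {
        return IsPostDom ? "post-dominator" : "dominator";
      }
    };
    using DomTree = DomTreeBase<false>;    // cf. llvm::DominatorTree
    using PostDomTree = DomTreeBase<true>; // cf. llvm::PostDomTreeBase<BasicBlock>

    // The caller passes either tree; the template parameter keeps the two
    // instantiations distinct at compile time, as in findEquivalencesFor.
    template <bool IsPostDom>
    static void findEquivalencesFor(DomTreeBase<IsPostDom> *DT) {
      std::printf("walking %s descendants\n", DomTreeBase<IsPostDom>::kind());
    }

    int main() {
      DomTree DT;
      PostDomTree PDT;
      findEquivalencesFor(&DT);
      findEquivalencesFor(&PDT);
    }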
+template <bool IsPostDom> void SampleProfileLoader::findEquivalencesFor( BasicBlock *BB1, ArrayRef<BasicBlock *> Descendants, - DominatorTreeBase<BasicBlock> *DomTree) { + DominatorTreeBase<BasicBlock, IsPostDom> *DomTree) { const BasicBlock *EC = EquivalenceClass[BB1]; uint64_t Weight = BlockWeights[EC]; for (const auto *BB2 : Descendants) { @@ -994,6 +1076,26 @@ void SampleProfileLoader::buildEdges(Function &F) { } } +/// Sorts the CallTargetMap \p M by count in descending order and stores the +/// sorted result in \p Sorted. Returns the total counts. +static uint64_t SortCallTargets(SmallVector<InstrProfValueData, 2> &Sorted, + const SampleRecord::CallTargetMap &M) { + Sorted.clear(); + uint64_t Sum = 0; + for (auto I = M.begin(); I != M.end(); ++I) { + Sum += I->getValue(); + Sorted.push_back({Function::getGUID(I->getKey()), I->getValue()}); + } + std::sort(Sorted.begin(), Sorted.end(), + [](const InstrProfValueData &L, const InstrProfValueData &R) { + if (L.Count == R.Count) + return L.Value > R.Value; + else + return L.Count > R.Count; + }); + return Sum; +} + /// \brief Propagate weights into edges /// /// The following rules are applied to every block BB in the CFG: @@ -1015,10 +1117,6 @@ void SampleProfileLoader::propagateWeights(Function &F) { bool Changed = true; unsigned I = 0; - // Add an entry count to the function using the samples gathered - // at the function entry. - F.setEntryCount(Samples->getHeadSamples() + 1); - // If BB weight is larger than its corresponding loop's header BB weight, // use the BB weight to replace the loop header BB weight. for (auto &BI : F) { @@ -1071,13 +1169,32 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (BlockWeights[BB]) { for (auto &I : BB->getInstList()) { - if (CallInst *CI = dyn_cast<CallInst>(&I)) { - if (!dyn_cast<IntrinsicInst>(&I)) { - SmallVector<uint32_t, 1> Weights; - Weights.push_back(BlockWeights[BB]); - CI->setMetadata(LLVMContext::MD_prof, - MDB.createBranchWeights(Weights)); - } + if (!isa<CallInst>(I) && !isa<InvokeInst>(I)) + continue; + CallSite CS(&I); + if (!CS.getCalledFunction()) { + const DebugLoc &DLoc = I.getDebugLoc(); + if (!DLoc) + continue; + const DILocation *DIL = DLoc; + uint32_t LineOffset = getOffset(DIL); + uint32_t Discriminator = DIL->getBaseDiscriminator(); + + const FunctionSamples *FS = findFunctionSamples(I); + if (!FS) + continue; + auto T = FS->findCallTargetMapAt(LineOffset, Discriminator); + if (!T || T.get().size() == 0) + continue; + SmallVector<InstrProfValueData, 2> SortedCallTargets; + uint64_t Sum = SortCallTargets(SortedCallTargets, T.get()); + annotateValueSite(*I.getParent()->getParent()->getParent(), I, + SortedCallTargets, Sum, IPVK_IndirectCallTarget, + SortedCallTargets.size()); + } else if (!dyn_cast<IntrinsicInst>(&I)) { + SmallVector<uint32_t, 1> Weights; + Weights.push_back(BlockWeights[BB]); + I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); } } } @@ -1087,8 +1204,11 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI)) continue; + DebugLoc BranchLoc = TI->getDebugLoc(); DEBUG(dbgs() << "\nGetting weights for branch at line " - << TI->getDebugLoc().getLine() << ".\n"); + << ((BranchLoc) ? 
Twine(BranchLoc.getLine()) : Twine("<UNKNOWN LOCATION>")) << ".\n"); SmallVector<uint32_t, 4> Weights; uint32_t MaxWeight = 0; DebugLoc MaxDestLoc; @@ -1115,13 +1235,16 @@ } } + uint64_t TempWeight; // Only set weights if there is at least one non-zero weight. // In any other case, let the analyzer set weights. - if (MaxWeight > 0) { + // Do not set weights if the weights are present. In ThinLTO, the profile + // annotation is done twice. If the first annotation already set the + // weights, the second pass does not need to set it. + if (MaxWeight > 0 && !TI->extractProfTotalWeight(TempWeight)) { DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n"); TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); - DebugLoc BranchLoc = TI->getDebugLoc(); emitOptimizationRemark( Ctx, DEBUG_TYPE, F, MaxDestLoc, Twine("most popular destination for conditional branches at ") + @@ -1163,7 +1286,7 @@ void SampleProfileLoader::computeDominanceAndLoopInfo(Function &F) { DT.reset(new DominatorTree); DT->recalculate(F); - PDT.reset(new DominatorTreeBase<BasicBlock>(true)); + PDT.reset(new PostDomTreeBase<BasicBlock>()); PDT->recalculate(F); LI.reset(new LoopInfo); @@ -1228,12 +1351,19 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { DEBUG(dbgs() << "Line number for the first instruction in " << F.getName() << ": " << getFunctionLoc(F) << "\n"); - Changed |= inlineHotFunctions(F); + DenseSet<GlobalValue::GUID> ImportGUIDs; + Changed |= inlineHotFunctions(F, ImportGUIDs); // Compute basic block weights. Changed |= computeBlockWeights(F); if (Changed) { + // Add an entry count to the function using the samples gathered at the + // function entry. Also set the GUIDs that come from a different + // module but inlined in the profiled binary. This aims to make + // the IR match the profiled binary before annotation. + F.setEntryCount(Samples->getHeadSamples() + 1, &ImportGUIDs); + // Compute dominance and loop info needed for propagation. computeDominanceAndLoopInfo(F); @@ -1309,6 +1439,26 @@ bool SampleProfileLoader::runOnModule(Module &M) { for (const auto &I : Reader->getProfiles()) TotalCollectedSamples += I.second.getTotalSamples(); + // Populate the symbol map. + for (const auto &N_F : M.getValueSymbolTable()) { + std::string OrigName = N_F.getKey(); + Function *F = dyn_cast<Function>(N_F.getValue()); + if (F == nullptr) + continue; + SymbolMap[OrigName] = F; + auto pos = OrigName.find('.'); + if (pos != std::string::npos) { + std::string NewName = OrigName.substr(0, pos); + auto r = SymbolMap.insert(std::make_pair(NewName, F)); + // Failing to insert means there is already an entry in SymbolMap, + // thus there are multiple functions that are mapped to the same + // stripped name. In this case of a name conflict, set the value + // to nullptr to avoid confusion.
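The hunk resumes below with exactly that poisoning step (r.first->second = nullptr). As a standalone model of the whole SymbolMap rule, with hypothetical function names and const char * standing in for Function *:

    #include <cstdio>
    #include <map>
    #include <string>

    // Values stand in for Function *; nullptr marks an ambiguous stripped name.
    static std::map<std::string, const char *> SymbolMap;

    static void addSymbol(const std::string &OrigName, const char *F) {
      SymbolMap[OrigName] = F;
      auto Pos = OrigName.find('.');
      if (Pos == std::string::npos)
        return;
      auto R = SymbolMap.insert({OrigName.substr(0, Pos), F});
      if (!R.second) // stripped name already claimed: poison the entry
        R.first->second = nullptr;
    }

    int main() {
      addSymbol("foo.llvm.1234", "first foo");  // also maps stripped name "foo"
      addSymbol("foo.part.2", "second foo");    // second claim on "foo"
      const char *F = SymbolMap["foo"];
      std::printf("foo -> %s\n", F ? F : "<ambiguous, ignored>");
    }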
+ if (!r.second) + r.first->second = nullptr; + } + } + bool retval = false; for (auto &F : M) if (!F.isDeclaration()) { @@ -1329,7 +1479,7 @@ bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { bool SampleProfileLoader::runOnFunction(Function &F) { F.setEntryCount(0); Samples = Reader->getSamplesFor(F); - if (!Samples->empty()) + if (Samples && !Samples->empty()) return emitAnnotations(F); return false; } @@ -1337,7 +1487,8 @@ bool SampleProfileLoader::runOnFunction(Function &F) { PreservedAnalyses SampleProfileLoaderPass::run(Module &M, ModuleAnalysisManager &AM) { - SampleProfileLoader SampleLoader(SampleProfileFile); + SampleProfileLoader SampleLoader( + ProfileFileName.empty() ? SampleProfileFile : ProfileFileName); SampleLoader.doInitialization(M); diff --git a/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp b/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp index 8f6f161..de1b51e 100644 --- a/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" @@ -30,6 +29,7 @@ #include "llvm/IR/TypeFinder.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -323,6 +323,14 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { LiveGVs.insert(GVE); } + std::set<DICompileUnit *> LiveCUs; + // Any CU referenced from a subprogram is live. + for (DISubprogram *SP : F.subprograms()) { + if (SP->getUnit()) + LiveCUs.insert(SP->getUnit()); + } + + bool HasDeadCUs = false; for (DICompileUnit *DIC : F.compile_units()) { // Create our live global variable list. bool GlobalVariableChange = false; @@ -341,6 +349,11 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { GlobalVariableChange = true; } + if (!LiveGlobalVariables.empty()) + LiveCUs.insert(DIC); + else if (!LiveCUs.count(DIC)) + HasDeadCUs = true; + // If we found dead global variables, replace the current global // variable list with our new live global variable list. if (GlobalVariableChange) { @@ -352,5 +365,16 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { LiveGlobalVariables.clear(); } + if (HasDeadCUs) { + // Delete the old node and replace it with a new one + NamedMDNode *NMD = M.getOrInsertNamedMetadata("llvm.dbg.cu"); + NMD->clearOperands(); + if (!LiveCUs.empty()) { + for (DICompileUnit *CU : LiveCUs) + NMD->addOperand(CU); + } + Changed = true; + } + return Changed; } diff --git a/contrib/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/contrib/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 3680cfc..8ef6bb6 100644 --- a/contrib/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/contrib/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -6,99 +6,67 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// -// -// This pass prepares a module containing type metadata for ThinLTO by splitting -// it into regular and thin LTO parts if possible, and writing both parts to -// a multi-module bitcode file. Modules that do not contain type metadata are -// written unmodified as a single module. 
-// -//===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/ModuleSummaryAnalysis.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; namespace { -// Produce a unique identifier for this module by taking the MD5 sum of the -// names of the module's strong external symbols. This identifier is -// normally guaranteed to be unique, or the program would fail to link due to -// multiply defined symbols. -// -// If the module has no strong external symbols (such a module may still have a -// semantic effect if it performs global initialization), we cannot produce a -// unique identifier for this module, so we return the empty string, which -// causes the entire module to be written as a regular LTO module. -std::string getModuleId(Module *M) { - MD5 Md5; - bool ExportsSymbols = false; - auto AddGlobal = [&](GlobalValue &GV) { - if (GV.isDeclaration() || GV.getName().startswith("llvm.") || - !GV.hasExternalLinkage()) - return; - ExportsSymbols = true; - Md5.update(GV.getName()); - Md5.update(ArrayRef<uint8_t>{0}); - }; - - for (auto &F : *M) - AddGlobal(F); - for (auto &GV : M->globals()) - AddGlobal(GV); - for (auto &GA : M->aliases()) - AddGlobal(GA); - for (auto &IF : M->ifuncs()) - AddGlobal(IF); - - if (!ExportsSymbols) - return ""; - - MD5::MD5Result R; - Md5.final(R); - - SmallString<32> Str; - MD5::stringifyResult(R, Str); - return ("$" + Str).str(); -} - // Promote each local-linkage entity defined by ExportM and used by ImportM by // changing visibility and appending the given ModuleId. 
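The hunk that follows rewrites promoteInternals; the rule it applies is purely name-based. A worked example of the renaming, using a hypothetical module id (the real one is "$" plus an MD5 of the module's exported symbol names, per the getModuleId implementation removed above, which lives on as getUniqueModuleId):

    #include <cstdio>
    #include <string>

    int main() {
      // Hypothetical module id of the shape getUniqueModuleId produces.
      std::string ModuleId = "$d41d8cd98f00b204e9800998ecf8427e";
      std::string Name = "_ZL15internal_helper"; // hypothetical local symbol
      // Matches the diff's (Name + ModuleId).str(); the symbol additionally
      // gets ExternalLinkage and HiddenVisibility, and a comdat named after
      // the old symbol is renamed the same way.
      std::printf("%s\n", (Name + ModuleId).c_str());
    }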
-void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId) { - auto PromoteInternal = [&](GlobalValue &ExportGV) { +void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId, + SetVector<GlobalValue *> &PromoteExtra) { + DenseMap<const Comdat *, Comdat *> RenamedComdats; + for (auto &ExportGV : ExportM.global_values()) { if (!ExportGV.hasLocalLinkage()) - return; + continue; + + auto Name = ExportGV.getName(); + GlobalValue *ImportGV = ImportM.getNamedValue(Name); + if ((!ImportGV || ImportGV->use_empty()) && !PromoteExtra.count(&ExportGV)) + continue; - GlobalValue *ImportGV = ImportM.getNamedValue(ExportGV.getName()); - if (!ImportGV || ImportGV->use_empty()) - return; + std::string NewName = (Name + ModuleId).str(); - std::string NewName = (ExportGV.getName() + ModuleId).str(); + if (const auto *C = ExportGV.getComdat()) + if (C->getName() == Name) + RenamedComdats.try_emplace(C, ExportM.getOrInsertComdat(NewName)); ExportGV.setName(NewName); ExportGV.setLinkage(GlobalValue::ExternalLinkage); ExportGV.setVisibility(GlobalValue::HiddenVisibility); - ImportGV->setName(NewName); - ImportGV->setVisibility(GlobalValue::HiddenVisibility); - }; + if (ImportGV) { + ImportGV->setName(NewName); + ImportGV->setVisibility(GlobalValue::HiddenVisibility); + } + } - for (auto &F : ExportM) - PromoteInternal(F); - for (auto &GV : ExportM.globals()) - PromoteInternal(GV); - for (auto &GA : ExportM.aliases()) - PromoteInternal(GA); - for (auto &IF : ExportM.ifuncs()) - PromoteInternal(IF); + if (!RenamedComdats.empty()) + for (auto &GO : ExportM.global_objects()) + if (auto *C = GO.getComdat()) { + auto Replacement = RenamedComdats.find(C); + if (Replacement != RenamedComdats.end()) + GO.setComdat(Replacement->second); + } } // Promote all internal (i.e. 
distinct) type ids used by the module by replacing @@ -194,24 +162,7 @@ void simplifyExternals(Module &M) { } void filterModule( - Module *M, std::function<bool(const GlobalValue *)> ShouldKeepDefinition) { - for (Function &F : *M) { - if (ShouldKeepDefinition(&F)) - continue; - - F.deleteBody(); - F.clearMetadata(); - } - - for (GlobalVariable &GV : M->globals()) { - if (ShouldKeepDefinition(&GV)) - continue; - - GV.setInitializer(nullptr); - GV.setLinkage(GlobalValue::ExternalLinkage); - GV.clearMetadata(); - } - + Module *M, function_ref<bool(const GlobalValue *)> ShouldKeepDefinition) { for (Module::alias_iterator I = M->alias_begin(), E = M->alias_end(); I != E;) { GlobalAlias *GA = &*I++; @@ -219,65 +170,227 @@ void filterModule( continue; GlobalObject *GO; - if (I->getValueType()->isFunctionTy()) + if (GA->getValueType()->isFunctionTy()) GO = Function::Create(cast<FunctionType>(GA->getValueType()), GlobalValue::ExternalLinkage, "", M); else GO = new GlobalVariable( *M, GA->getValueType(), false, GlobalValue::ExternalLinkage, - (Constant *)nullptr, "", (GlobalVariable *)nullptr, + nullptr, "", nullptr, GA->getThreadLocalMode(), GA->getType()->getAddressSpace()); GO->takeName(GA); GA->replaceAllUsesWith(GO); GA->eraseFromParent(); } + + for (Function &F : *M) { + if (ShouldKeepDefinition(&F)) + continue; + + F.deleteBody(); + F.setComdat(nullptr); + F.clearMetadata(); + } + + for (GlobalVariable &GV : M->globals()) { + if (ShouldKeepDefinition(&GV)) + continue; + + GV.setInitializer(nullptr); + GV.setLinkage(GlobalValue::ExternalLinkage); + GV.setComdat(nullptr); + GV.clearMetadata(); + } +} + +void forEachVirtualFunction(Constant *C, function_ref<void(Function *)> Fn) { + if (auto *F = dyn_cast<Function>(C)) + return Fn(F); + if (isa<GlobalValue>(C)) + return; + for (Value *Op : C->operands()) + forEachVirtualFunction(cast<Constant>(Op), Fn); } // If it's possible to split M into regular and thin LTO parts, do so and write // a multi-module bitcode file with the two parts to OS. Otherwise, write only a // regular LTO bitcode file to OS. -void splitAndWriteThinLTOBitcode(raw_ostream &OS, Module &M) { - std::string ModuleId = getModuleId(&M); +void splitAndWriteThinLTOBitcode( + raw_ostream &OS, raw_ostream *ThinLinkOS, + function_ref<AAResults &(Function &)> AARGetter, Module &M) { + std::string ModuleId = getUniqueModuleId(&M); if (ModuleId.empty()) { // We couldn't generate a module ID for this module, just write it out as a // regular LTO module. WriteBitcodeToFile(&M, OS); + if (ThinLinkOS) + // We don't have a ThinLTO part, but still write the module to the + // ThinLinkOS if requested so that the expected output file is produced. + WriteBitcodeToFile(&M, *ThinLinkOS); return; } promoteTypeIds(M, ModuleId); - auto IsInMergedM = [&](const GlobalValue *GV) { - auto *GVar = dyn_cast<GlobalVariable>(GV->getBaseObject()); - if (!GVar) - return false; - + // Returns whether a global has attached type metadata. Such globals may + // participate in CFI or whole-program devirtualization, so they need to + // appear in the merged module instead of the thin LTO module. + auto HasTypeMetadata = [&](const GlobalObject *GO) { SmallVector<MDNode *, 1> MDs; - GVar->getMetadata(LLVMContext::MD_type, MDs); + GO->getMetadata(LLVMContext::MD_type, MDs); return !MDs.empty(); }; + // Collect the set of virtual functions that are eligible for virtual constant + // propagation. 
Each eligible function must not access memory, must return + // an integer of width <=64 bits, must take at least one argument, must not + // use its first argument (assumed to be "this") and all arguments other than + // the first one must be of <=64 bit integer type. + // + // Note that we test whether this copy of the function is readnone, rather + // than testing function attributes, which must hold for any copy of the + // function, even a less optimized version substituted at link time. This is + // sound because the virtual constant propagation optimizations effectively + // inline all implementations of the virtual function into each call site, + // rather than using function attributes to perform local optimization. + std::set<const Function *> EligibleVirtualFns; + // If any member of a comdat lives in MergedM, put all members of that + // comdat in MergedM to keep the comdat together. + DenseSet<const Comdat *> MergedMComdats; + for (GlobalVariable &GV : M.globals()) + if (HasTypeMetadata(&GV)) { + if (const auto *C = GV.getComdat()) + MergedMComdats.insert(C); + forEachVirtualFunction(GV.getInitializer(), [&](Function *F) { + auto *RT = dyn_cast<IntegerType>(F->getReturnType()); + if (!RT || RT->getBitWidth() > 64 || F->arg_empty() || + !F->arg_begin()->use_empty()) + return; + for (auto &Arg : make_range(std::next(F->arg_begin()), F->arg_end())) { + auto *ArgT = dyn_cast<IntegerType>(Arg.getType()); + if (!ArgT || ArgT->getBitWidth() > 64) + return; + } + if (!F->isDeclaration() && + computeFunctionBodyMemoryAccess(*F, AARGetter(*F)) == MAK_ReadNone) + EligibleVirtualFns.insert(F); + }); + } + ValueToValueMapTy VMap; - std::unique_ptr<Module> MergedM(CloneModule(&M, VMap, IsInMergedM)); + std::unique_ptr<Module> MergedM( + CloneModule(&M, VMap, [&](const GlobalValue *GV) -> bool { + if (const auto *C = GV->getComdat()) + if (MergedMComdats.count(C)) + return true; + if (auto *F = dyn_cast<Function>(GV)) + return EligibleVirtualFns.count(F); + if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject())) + return HasTypeMetadata(GVar); + return false; + })); + StripDebugInfo(*MergedM); + + for (Function &F : *MergedM) + if (!F.isDeclaration()) { + // Reset the linkage of all functions eligible for virtual constant + // propagation. The canonical definitions live in the thin LTO module so + // that they can be imported. + F.setLinkage(GlobalValue::AvailableExternallyLinkage); + F.setComdat(nullptr); + } - filterModule(&M, [&](const GlobalValue *GV) { return !IsInMergedM(GV); }); + SetVector<GlobalValue *> CfiFunctions; + for (auto &F : M) + if ((!F.hasLocalLinkage() || F.hasAddressTaken()) && HasTypeMetadata(&F)) + CfiFunctions.insert(&F); + + // Remove all globals with type metadata, globals with comdats that live in + // MergedM, and aliases pointing to such globals from the thin LTO module. 
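Before the filterModule call below, it may help to see what the eligibility rules collected above mean at the source level: a virtual function qualifies for virtual constant propagation when its body is readnone, it returns an integer of width <= 64, it never touches 'this', and every non-'this' parameter is a small integer. A hypothetical C++ example of such a function:

    #include <cstdio>

    struct Base {
      virtual unsigned flags(unsigned Which) const = 0;
      virtual ~Base() = default;
    };

    struct Impl : Base {
      // Eligible: no memory access, a 32-bit integer result computed purely
      // from the (<= 64-bit) integer argument, and 'this' is never used.
      unsigned flags(unsigned Which) const override { return Which * 2 + 1; }
    };

    int main() {
      Impl I;
      const Base &B = I;
      // With a constant argument, the pass can evaluate this call at LTO time.
      std::printf("%u\n", B.flags(1)); // prints 3
    }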
+ filterModule(&M, [&](const GlobalValue *GV) { + if (auto *GVar = dyn_cast_or_null<GlobalVariable>(GV->getBaseObject())) + if (HasTypeMetadata(GVar)) + return false; + if (const auto *C = GV->getComdat()) + if (MergedMComdats.count(C)) + return false; + return true; + }); + + promoteInternals(*MergedM, M, ModuleId, CfiFunctions); + promoteInternals(M, *MergedM, ModuleId, CfiFunctions); + + SmallVector<MDNode *, 8> CfiFunctionMDs; + for (auto V : CfiFunctions) { + Function &F = *cast<Function>(V); + SmallVector<MDNode *, 2> Types; + F.getMetadata(LLVMContext::MD_type, Types); + + auto &Ctx = MergedM->getContext(); + SmallVector<Metadata *, 4> Elts; + Elts.push_back(MDString::get(Ctx, F.getName())); + CfiFunctionLinkage Linkage; + if (!F.isDeclarationForLinker()) + Linkage = CFL_Definition; + else if (F.isWeakForLinker()) + Linkage = CFL_WeakDeclaration; + else + Linkage = CFL_Declaration; + Elts.push_back(ConstantAsMetadata::get( + llvm::ConstantInt::get(Type::getInt8Ty(Ctx), Linkage))); + for (auto Type : Types) + Elts.push_back(Type); + CfiFunctionMDs.push_back(MDTuple::get(Ctx, Elts)); + } - promoteInternals(*MergedM, M, ModuleId); - promoteInternals(M, *MergedM, ModuleId); + if(!CfiFunctionMDs.empty()) { + NamedMDNode *NMD = MergedM->getOrInsertNamedMetadata("cfi.functions"); + for (auto MD : CfiFunctionMDs) + NMD->addOperand(MD); + } simplifyExternals(*MergedM); - SmallVector<char, 0> Buffer; - BitcodeWriter W(Buffer); - // FIXME: Try to re-use BSI and PFI from the original module here. - ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, nullptr); - W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, - /*GenerateHash=*/true); + ProfileSummaryInfo PSI(M); + ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, &PSI); - W.writeModule(MergedM.get()); + // Mark the merged module as requiring full LTO. We still want an index for + // it though, so that it can participate in summary-based dead stripping. + MergedM->addModuleFlag(Module::Error, "ThinLTO", uint32_t(0)); + ModuleSummaryIndex MergedMIndex = + buildModuleSummaryIndex(*MergedM, nullptr, &PSI); + SmallVector<char, 0> Buffer; + + BitcodeWriter W(Buffer); + // Save the module hash produced for the full bitcode, which will + // be used in the backends, and use that in the minimized bitcode + // produced for the full link. + ModuleHash ModHash = {{0}}; + W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, + /*GenerateHash=*/true, &ModHash); + W.writeModule(MergedM.get(), /*ShouldPreserveUseListOrder=*/false, + &MergedMIndex); + W.writeSymtab(); + W.writeStrtab(); OS << Buffer; + + // If a minimized bitcode module was requested for the thin link, + // strip the debug info (the merged module was already stripped above) + // and write it to the given OS. + if (ThinLinkOS) { + Buffer.clear(); + BitcodeWriter W2(Buffer); + StripDebugInfo(M); + W2.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, + /*GenerateHash=*/false, &ModHash); + W2.writeModule(MergedM.get(), /*ShouldPreserveUseListOrder=*/false, + &MergedMIndex); + W2.writeSymtab(); + W2.writeStrtab(); + *ThinLinkOS << Buffer; + } } // Returns whether this module needs to be split because it uses type metadata. 
@@ -292,28 +405,45 @@ bool requiresSplit(Module &M) { return false; } -void writeThinLTOBitcode(raw_ostream &OS, Module &M, - const ModuleSummaryIndex *Index) { +void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, + function_ref<AAResults &(Function &)> AARGetter, + Module &M, const ModuleSummaryIndex *Index) { // See if this module has any type metadata. If so, we need to split it. if (requiresSplit(M)) - return splitAndWriteThinLTOBitcode(OS, M); + return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M); // Otherwise we can just write it out as a regular module. + + // Save the module hash produced for the full bitcode, which will + // be used in the backends, and use that in the minimized bitcode + // produced for the full link. + ModuleHash ModHash = {{0}}; WriteBitcodeToFile(&M, OS, /*ShouldPreserveUseListOrder=*/false, Index, - /*GenerateHash=*/true); + /*GenerateHash=*/true, &ModHash); + // If a minimized bitcode module was requested for the thin link, + // strip the debug info and write it to the given OS. + if (ThinLinkOS) { + StripDebugInfo(M); + WriteBitcodeToFile(&M, *ThinLinkOS, /*ShouldPreserveUseListOrder=*/false, + Index, + /*GenerateHash=*/false, &ModHash); + } } class WriteThinLTOBitcode : public ModulePass { raw_ostream &OS; // raw_ostream to print on + // The output stream on which to emit a minimized module for use + // just in the thin link, if requested. + raw_ostream *ThinLinkOS; public: static char ID; // Pass identification, replacement for typeid - WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()) { + WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()), ThinLinkOS(nullptr) { initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry()); } - explicit WriteThinLTOBitcode(raw_ostream &o) - : ModulePass(ID), OS(o) { + explicit WriteThinLTOBitcode(raw_ostream &o, raw_ostream *ThinLinkOS) + : ModulePass(ID), OS(o), ThinLinkOS(ThinLinkOS) { initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry()); } @@ -322,12 +452,14 @@ public: bool runOnModule(Module &M) override { const ModuleSummaryIndex *Index = &(getAnalysis<ModuleSummaryIndexWrapperPass>().getIndex()); - writeThinLTOBitcode(OS, M, Index); + writeThinLTOBitcode(OS, ThinLinkOS, LegacyAARGetter(*this), M, Index); return true; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<ModuleSummaryIndexWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; } // anonymous namespace @@ -335,10 +467,25 @@ public: char WriteThinLTOBitcode::ID = 0; INITIALIZE_PASS_BEGIN(WriteThinLTOBitcode, "write-thinlto-bitcode", "Write ThinLTO Bitcode", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(ModuleSummaryIndexWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(WriteThinLTOBitcode, "write-thinlto-bitcode", "Write ThinLTO Bitcode", false, true) -ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str) { - return new WriteThinLTOBitcode(Str); +ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str, + raw_ostream *ThinLinkOS) { + return new WriteThinLTOBitcode(Str, ThinLinkOS); +} + +PreservedAnalyses +llvm::ThinLTOBitcodeWriterPass::run(Module &M, ModuleAnalysisManager &AM) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + writeThinLTOBitcode(OS, ThinLinkOS, + [&FAM](Function &F) -> AAResults & { + return FAM.getResult<AAManager>(F); + }, 
+ M, &AM.getResult<ModuleSummaryIndexAnalysis>(M)); + return PreservedAnalyses::all(); } diff --git a/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 844cc0f..00769cd 100644 --- a/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -25,6 +25,20 @@ // returns 0, or a single vtable's function returns 1, replace each virtual // call with a comparison of the vptr against that vtable's address. // +// This pass is intended to be used during the regular and thin LTO pipelines. +// During regular LTO, the pass determines the best optimization for each +// virtual call and applies the resolutions directly to virtual calls that are +// eligible for virtual call optimization (i.e. calls that use either of the +// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During +// ThinLTO, the pass operates in two phases: +// - Export phase: this is run during the thin link over a single merged module +// that contains all vtables with !type metadata that participate in the link. +// The pass computes a resolution for each virtual call and stores it in the +// type identifier summary. +// - Import phase: this is run during the thin backends over the individual +// modules. The pass applies the resolutions previously computed during the +// import phase to each eligible virtual call. +// //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/WholeProgramDevirt.h" @@ -32,9 +46,11 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" @@ -54,12 +70,16 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndexYAML.h" #include "llvm/Pass.h" #include "llvm/PassRegistry.h" #include "llvm/PassSupport.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/Evaluator.h" #include <algorithm> #include <cstddef> @@ -72,6 +92,26 @@ using namespace wholeprogramdevirt; #define DEBUG_TYPE "wholeprogramdevirt" +static cl::opt<PassSummaryAction> ClSummaryAction( + "wholeprogramdevirt-summary-action", + cl::desc("What to do with the summary when running this pass"), + cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), + clEnumValN(PassSummaryAction::Import, "import", + "Import typeid resolutions from summary and globals"), + clEnumValN(PassSummaryAction::Export, "export", + "Export typeid resolutions to summary and globals")), + cl::Hidden); + +static cl::opt<std::string> ClReadSummary( + "wholeprogramdevirt-read-summary", + cl::desc("Read summary from given YAML file before running pass"), + cl::Hidden); + +static cl::opt<std::string> ClWriteSummary( + "wholeprogramdevirt-write-summary", + cl::desc("Write summary to given YAML file after running pass"), + cl::Hidden); + // Find the minimum offset that we may store a value of size Size bits at. 
If // IsAfter is set, look for an offset before the object, otherwise look for an // offset after the object. @@ -259,15 +299,92 @@ struct VirtualCallSite { } }; +// Call site information collected for a specific VTableSlot and possibly a list +// of constant integer arguments. The grouping by arguments is handled by the +// VTableSlotInfo class. +struct CallSiteInfo { + /// The set of call sites for this slot. Used during regular LTO and the + /// import phase of ThinLTO (as well as the export phase of ThinLTO for any + /// call sites that appear in the merged module itself); in each of these + /// cases we are directly operating on the call sites at the IR level. + std::vector<VirtualCallSite> CallSites; + + // These fields are used during the export phase of ThinLTO and reflect + // information collected from function summaries. + + /// Whether any function summary contains an llvm.assume(llvm.type.test) for + /// this slot. + bool SummaryHasTypeTestAssumeUsers; + + /// CFI-specific: a vector containing the list of function summaries that use + /// the llvm.type.checked.load intrinsic and therefore will require + /// resolutions for llvm.type.test in order to implement CFI checks if + /// devirtualization was unsuccessful. If devirtualization was successful, the + /// pass will clear this vector by calling markDevirt(). If at the end of the + /// pass the vector is non-empty, we will need to add a use of llvm.type.test + /// to each of the function summaries in the vector. + std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; + + bool isExported() const { + return SummaryHasTypeTestAssumeUsers || + !SummaryTypeCheckedLoadUsers.empty(); + } + + /// As explained in the comment for SummaryTypeCheckedLoadUsers. + void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); } +}; + +// Call site information collected for a specific VTableSlot. +struct VTableSlotInfo { + // The set of call sites which do not have all constant integer arguments + // (excluding "this"). + CallSiteInfo CSInfo; + + // The set of call sites with all constant integer arguments (excluding + // "this"), grouped by argument list. + std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; + + void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); + +private: + CallSiteInfo &findCallSiteInfo(CallSite CS); +}; + +CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { + std::vector<uint64_t> Args; + auto *CI = dyn_cast<IntegerType>(CS.getType()); + if (!CI || CI->getBitWidth() > 64 || CS.arg_empty()) + return CSInfo; + for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { + auto *CI = dyn_cast<ConstantInt>(Arg); + if (!CI || CI->getBitWidth() > 64) + return CSInfo; + Args.push_back(CI->getZExtValue()); + } + return ConstCSInfo[Args]; +} + +void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, + unsigned *NumUnsafeUses) { + findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); +} + struct DevirtModule { Module &M; + function_ref<AAResults &(Function &)> AARGetter; + + ModuleSummaryIndex *ExportSummary; + const ModuleSummaryIndex *ImportSummary; + IntegerType *Int8Ty; PointerType *Int8PtrTy; IntegerType *Int32Ty; + IntegerType *Int64Ty; + IntegerType *IntPtrTy; bool RemarksEnabled; - MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; + MapVector<VTableSlot, VTableSlotInfo> CallSlots; // This map keeps track of the number of "unsafe" uses of a loaded function // pointer. 
The key is the associated llvm.type.test intrinsic call generated @@ -279,11 +396,18 @@ struct DevirtModule { // true. std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; - DevirtModule(Module &M) - : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), + DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, + ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) + : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary), + ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())), - RemarksEnabled(areRemarksEnabled()) {} + Int64Ty(Type::getInt64Ty(M.getContext())), + IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), + RemarksEnabled(areRemarksEnabled()) { + assert(!(ExportSummary && ImportSummary)); + } bool areRemarksEnabled(); @@ -298,57 +422,169 @@ struct DevirtModule { tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset); + + void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, + bool &IsExported); bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites); + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res); + bool tryEvaluateFunctionsWithArgs( MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<ConstantInt *> Args); - bool tryUniformRetValOpt(IntegerType *RetType, - MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites); + ArrayRef<uint64_t> Args); + + void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, + uint64_t TheRetVal); + bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, + CallSiteInfo &CSInfo, + WholeProgramDevirtResolution::ByArg *Res); + + // Returns the global symbol name that is used to export information about the + // given vtable slot and list of arguments. + std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name); + + // This function is called during the export phase to create a symbol + // definition containing information about the given vtable slot and list of + // arguments. + void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, + Constant *C); + + // This function is called during the import phase to create a reference to + // the symbol definition created during the export phase. + Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, unsigned AbsWidth = 0); + + void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, + Constant *UniqueMemberAddr); bool tryUniqueRetValOpt(unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites); + CallSiteInfo &CSInfo, + WholeProgramDevirtResolution::ByArg *Res, + VTableSlot Slot, ArrayRef<uint64_t> Args); + + void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, + Constant *Byte, Constant *Bit); bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<VirtualCallSite> CallSites); + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot); void rebuildGlobal(VTableBits &B); + // Apply the summary resolution for Slot to all virtual calls in SlotInfo. 
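Stepping back to VTableSlotInfo above: call sites are bucketed by their constant integer arguments (excluding 'this'), and any call with a non-constant or wider-than-64-bit argument falls into the generic CSInfo bucket. A compact model of that grouping, using std::optional as a stand-in for "this argument is a ConstantInt":

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <optional>
    #include <vector>

    struct Bucket { unsigned NumCallSites = 0; };

    struct SlotInfo {
      Bucket CSInfo;                                       // non-constant-arg calls
      std::map<std::vector<uint64_t>, Bucket> ConstCSInfo; // keyed by const args

      Bucket &findBucket(const std::vector<std::optional<uint64_t>> &Args) {
        std::vector<uint64_t> Key;
        for (const auto &A : Args) {
          if (!A) return CSInfo; // any non-constant argument: generic bucket
          Key.push_back(*A);
        }
        return ConstCSInfo[Key];
      }
    };

    int main() {
      SlotInfo SI;
      SI.findBucket({{1}, {2}}).NumCallSites++;            // f(this, 1, 2)
      SI.findBucket({{1}, {2}}).NumCallSites++;            // same args, same bucket
      SI.findBucket({{1}, std::nullopt}).NumCallSites++;   // non-constant -> CSInfo
      std::printf("%u %u\n", SI.ConstCSInfo[{1, 2}].NumCallSites,
                  SI.CSInfo.NumCallSites); // prints "2 1"
    }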
+ void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo); + + // If we were able to eliminate all unsafe uses for a type checked load, + // eliminate the associated type tests by replacing them with true. + void removeRedundantTypeTests(); + bool run(); + + // Lower the module using the action and summary passed as command line + // arguments. For testing purposes only. + static bool runForTesting(Module &M, + function_ref<AAResults &(Function &)> AARGetter); }; struct WholeProgramDevirt : public ModulePass { static char ID; - WholeProgramDevirt() : ModulePass(ID) { + bool UseCommandLine = false; + + ModuleSummaryIndex *ExportSummary; + const ModuleSummaryIndex *ImportSummary; + + WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) { + initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); + } + + WholeProgramDevirt(ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) + : ModulePass(ID), ExportSummary(ExportSummary), + ImportSummary(ImportSummary) { initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); } bool runOnModule(Module &M) override { if (skipModule(M)) return false; + if (UseCommandLine) + return DevirtModule::runForTesting(M, LegacyAARGetter(*this)); + return DevirtModule(M, LegacyAARGetter(*this), ExportSummary, ImportSummary) + .run(); + } - return DevirtModule(M).run(); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; } // end anonymous namespace -INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt", - "Whole program devirtualization", false, false) +INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", + "Whole program devirtualization", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", + "Whole program devirtualization", false, false) char WholeProgramDevirt::ID = 0; -ModulePass *llvm::createWholeProgramDevirtPass() { - return new WholeProgramDevirt; +ModulePass * +llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) { + return new WholeProgramDevirt(ExportSummary, ImportSummary); } PreservedAnalyses WholeProgramDevirtPass::run(Module &M, - ModuleAnalysisManager &) { - if (!DevirtModule(M).run()) + ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto AARGetter = [&](Function &F) -> AAResults & { + return FAM.getResult<AAManager>(F); + }; + if (!DevirtModule(M, AARGetter, nullptr, nullptr).run()) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } +bool DevirtModule::runForTesting( + Module &M, function_ref<AAResults &(Function &)> AARGetter) { + ModuleSummaryIndex Summary; + + // Handle the command-line summary arguments. This code is for testing + // purposes only, so we handle errors directly. + if (!ClReadSummary.empty()) { + ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary + + ": "); + auto ReadSummaryFile = + ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); + + yaml::Input In(ReadSummaryFile->getBuffer()); + In >> Summary; + ExitOnErr(errorCodeToError(In.error())); + } + + bool Changed = + DevirtModule( + M, AARGetter, + ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, + ClSummaryAction == PassSummaryAction::Import ? 
&Summary : nullptr) + .run(); + + if (!ClWriteSummary.empty()) { + ExitOnError ExitOnErr( + "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); + std::error_code EC; + raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); + ExitOnErr(errorCodeToError(EC)); + + yaml::Output Out(OS); + Out << Summary; + } + + return Changed; +} + void DevirtModule::buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { @@ -443,9 +679,31 @@ bool DevirtModule::tryFindVirtualCallTargets( return !TargetsForSlot.empty(); } +void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, + Constant *TheFn, bool &IsExported) { + auto Apply = [&](CallSiteInfo &CSInfo) { + for (auto &&VCallSite : CSInfo.CallSites) { + if (RemarksEnabled) + VCallSite.emitRemark("single-impl", TheFn->getName()); + VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( + TheFn, VCallSite.CS.getCalledValue()->getType())); + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; + } + if (CSInfo.isExported()) { + IsExported = true; + CSInfo.markDevirt(); + } + }; + Apply(SlotInfo.CSInfo); + for (auto &P : SlotInfo.ConstCSInfo) + Apply(P.second); +} + bool DevirtModule::trySingleImplDevirt( MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites) { + VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) { // See if the program contains a single implementation of this virtual // function. Function *TheFn = TargetsForSlot[0].Fn; @@ -453,39 +711,51 @@ bool DevirtModule::trySingleImplDevirt( if (TheFn != Target.Fn) return false; + // If so, update each call site to call that implementation directly. if (RemarksEnabled) TargetsForSlot[0].WasDevirt = true; - // If so, update each call site to call that implementation directly. - for (auto &&VCallSite : CallSites) { - if (RemarksEnabled) - VCallSite.emitRemark("single-impl", TheFn->getName()); - VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( - TheFn, VCallSite.CS.getCalledValue()->getType())); - // This use is no longer unsafe. - if (VCallSite.NumUnsafeUses) - --*VCallSite.NumUnsafeUses; + + bool IsExported = false; + applySingleImplDevirt(SlotInfo, TheFn, IsExported); + if (!IsExported) + return false; + + // If the only implementation has local linkage, we must promote to external + // to make it visible to thin LTO objects. We can only get here during the + // ThinLTO export phase. + if (TheFn->hasLocalLinkage()) { + TheFn->setLinkage(GlobalValue::ExternalLinkage); + TheFn->setVisibility(GlobalValue::HiddenVisibility); + TheFn->setName(TheFn->getName() + "$merged"); } + + Res->TheKind = WholeProgramDevirtResolution::SingleImpl; + Res->SingleImplName = TheFn->getName(); + return true; } bool DevirtModule::tryEvaluateFunctionsWithArgs( MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<ConstantInt *> Args) { + ArrayRef<uint64_t> Args) { // Evaluate each function and store the result in each target's RetVal // field. 
for (VirtualCallTarget &Target : TargetsForSlot) { if (Target.Fn->arg_size() != Args.size() + 1) return false; - for (unsigned I = 0; I != Args.size(); ++I) - if (Target.Fn->getFunctionType()->getParamType(I + 1) != - Args[I]->getType()) - return false; Evaluator Eval(M.getDataLayout(), nullptr); SmallVector<Constant *, 2> EvalArgs; EvalArgs.push_back( Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); - EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end()); + for (unsigned I = 0; I != Args.size(); ++I) { + auto *ArgTy = dyn_cast<IntegerType>( + Target.Fn->getFunctionType()->getParamType(I + 1)); + if (!ArgTy) + return false; + EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); + } + Constant *RetVal; if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || !isa<ConstantInt>(RetVal)) @@ -495,9 +765,18 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs( return true; } +void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, + uint64_t TheRetVal) { + for (auto Call : CSInfo.CallSites) + Call.replaceAndErase( + "uniform-ret-val", FnName, RemarksEnabled, + ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); + CSInfo.markDevirt(); +} + bool DevirtModule::tryUniformRetValOpt( - IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites) { + MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo, + WholeProgramDevirtResolution::ByArg *Res) { // Uniform return value optimization. If all functions return the same // constant, replace all calls with that constant. uint64_t TheRetVal = TargetsForSlot[0].RetVal; @@ -505,19 +784,77 @@ bool DevirtModule::tryUniformRetValOpt( if (Target.RetVal != TheRetVal) return false; - auto TheRetValConst = ConstantInt::get(RetType, TheRetVal); - for (auto Call : CallSites) - Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(), - RemarksEnabled, TheRetValConst); + if (CSInfo.isExported()) { + Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal; + Res->Info = TheRetVal; + } + + applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal); if (RemarksEnabled) for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; return true; } +std::string DevirtModule::getGlobalName(VTableSlot Slot, + ArrayRef<uint64_t> Args, + StringRef Name) { + std::string FullName = "__typeid_"; + raw_string_ostream OS(FullName); + OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset; + for (uint64_t Arg : Args) + OS << '_' << Arg; + OS << '_' << Name; + return OS.str(); +} + +void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, Constant *C) { + GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, + getGlobalName(Slot, Args, Name), C, &M); + GA->setVisibility(GlobalValue::HiddenVisibility); +} + +Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, unsigned AbsWidth) { + Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty); + auto *GV = dyn_cast<GlobalVariable>(C); + // We only need to set metadata if the global is newly created, in which + // case it would not have hidden visibility. 
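importGlobal resumes below. For orientation, the export and import phases rendezvous purely by symbol name, which getGlobalName builds from the slot's type id, byte offset, constant argument list, and a payload name. A worked reimplementation of just that naming scheme (type id and values hypothetical):

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    static std::string getGlobalName(const std::string &TypeID, uint64_t ByteOffset,
                                     const std::vector<uint64_t> &Args,
                                     const std::string &Name) {
      std::string S = "__typeid_" + TypeID + "_" + std::to_string(ByteOffset);
      for (uint64_t Arg : Args)
        S += "_" + std::to_string(Arg);
      return S + "_" + Name;
    }

    int main() {
      // e.g. the unique-member address exported for slot (_ZTS1A, 0) with arg 1:
      std::printf("%s\n", getGlobalName("_ZTS1A", 0, {1}, "unique_member").c_str());
      // -> __typeid__ZTS1A_0_1_unique_member
    }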
+ if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) + return C; + + GV->setVisibility(GlobalValue::HiddenVisibility); + auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { + auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); + auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); + GV->setMetadata(LLVMContext::MD_absolute_symbol, + MDNode::get(M.getContext(), {MinC, MaxC})); + }; + if (AbsWidth == IntPtrTy->getBitWidth()) + SetAbsRange(~0ull, ~0ull); // Full set. + else if (AbsWidth) + SetAbsRange(0, 1ull << AbsWidth); + return GV; +} + +void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, + bool IsOne, + Constant *UniqueMemberAddr) { + for (auto &&Call : CSInfo.CallSites) { + IRBuilder<> B(Call.CS.getInstruction()); + Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + Call.VTable, UniqueMemberAddr); + Cmp = B.CreateZExt(Cmp, Call.CS->getType()); + Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp); + } + CSInfo.markDevirt(); +} + bool DevirtModule::tryUniqueRetValOpt( unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites) { + CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, + VTableSlot Slot, ArrayRef<uint64_t> Args) { // IsOne controls whether we look for a 0 or a 1. auto tryUniqueRetValOptFor = [&](bool IsOne) { const TypeMemberInfo *UniqueMember = nullptr; @@ -533,16 +870,23 @@ bool DevirtModule::tryUniqueRetValOpt( // checked for a uniform return value in tryUniformRetValOpt. assert(UniqueMember); - // Replace each call with the comparison. - for (auto &&Call : CallSites) { - IRBuilder<> B(Call.CS.getInstruction()); - Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy); - OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset); - Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, - Call.VTable, OneAddr); - Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(), - RemarksEnabled, Cmp); + Constant *UniqueMemberAddr = + ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy); + UniqueMemberAddr = ConstantExpr::getGetElementPtr( + Int8Ty, UniqueMemberAddr, + ConstantInt::get(Int64Ty, UniqueMember->Offset)); + + if (CSInfo.isExported()) { + Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; + Res->Info = IsOne; + + exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr); } + + // Replace each call with the comparison. + applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne, + UniqueMemberAddr); + // Update devirtualization statistics for targets. 
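Before the statistics update that follows, here is a source-level picture of what applyUniqueRetValOpt above emits: when exactly one vtable's implementation returns 1 (IsOne) or 0, the virtual call degenerates to comparing the object's vtable pointer against that unique member's address. The class layout and names below are a hypothetical analogue:

    #include <cstdio>

    struct VTable { int (*isSpecial)(); };
    static int retOne() { return 1; }
    static int retZero() { return 0; }
    static const VTable DogVT{retOne}; // the unique vtable whose slot returns 1
    static const VTable CatVT{retZero};

    struct Object { const VTable *VPtr; };

    // Before: int R = Obj->VPtr->isSpecial();
    // After unique-ret-val (IsOne == true): compare the vtable pointer against
    // the unique member's address and zero-extend the result.
    static int isSpecialDevirt(const Object *Obj) {
      return Obj->VPtr == &DogVT;
    }

    int main() {
      Object Dog{&DogVT}, Cat{&CatVT};
      std::printf("%d %d\n", isSpecialDevirt(&Dog), isSpecialDevirt(&Cat)); // 1 0
    }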
if (RemarksEnabled) for (auto &&Target : TargetsForSlot) @@ -560,9 +904,30 @@ bool DevirtModule::tryUniqueRetValOpt( return false; } +void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, + Constant *Byte, Constant *Bit) { + for (auto Call : CSInfo.CallSites) { + auto *RetType = cast<IntegerType>(Call.CS.getType()); + IRBuilder<> B(Call.CS.getInstruction()); + Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); + if (RetType->getBitWidth() == 1) { + Value *Bits = B.CreateLoad(Addr); + Value *BitsAndBit = B.CreateAnd(Bits, Bit); + auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); + Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, + IsBitSet); + } else { + Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); + Value *Val = B.CreateLoad(RetType, ValAddr); + Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val); + } + } + CSInfo.markDevirt(); +} + bool DevirtModule::tryVirtualConstProp( - MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<VirtualCallSite> CallSites) { + MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot) { // This only works if the function returns an integer. auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); if (!RetType) @@ -571,55 +936,38 @@ bool DevirtModule::tryVirtualConstProp( if (BitWidth > 64) return false; - // Make sure that each function does not access memory, takes at least one - // argument, does not use its first argument (which we assume is 'this'), - // and has the same return type. + // Make sure that each function is defined, does not access memory, takes at + // least one argument, does not use its first argument (which we assume is + // 'this'), and has the same return type. + // + // Note that we test whether this copy of the function is readnone, rather + // than testing function attributes, which must hold for any copy of the + // function, even a less optimized version substituted at link time. This is + // sound because the virtual constant propagation optimizations effectively + // inline all implementations of the virtual function into each call site, + // rather than using function attributes to perform local optimization. for (VirtualCallTarget &Target : TargetsForSlot) { - if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() || - !Target.Fn->arg_begin()->use_empty() || + if (Target.Fn->isDeclaration() || + computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != + MAK_ReadNone || + Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || Target.Fn->getReturnType() != RetType) return false; } - // Group call sites by the list of constant arguments they pass. - // The comparator ensures deterministic ordering. 
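// A minimal standalone sketch of the two by-argument decisions applied above
// (tryUniformRetValOpt / tryUniqueRetValOpt), in plain C++ rather than the
// pass's real types, illustrative only: uniform means every target evaluated
// to the same constant; unique, for 1-bit returns, means exactly one target
// evaluated to the sought value.
#if 0
#include <cstdint>
#include <vector>

// tryUniformRetValOpt's test: all evaluated return values agree.
// Assumes at least one target, as the pass guarantees.
static bool isUniform(const std::vector<uint64_t> &RetVals, uint64_t &RetVal) {
  RetVal = RetVals.front();
  for (uint64_t V : RetVals)
    if (V != RetVal)
      return false; // mixed results; try the unique-ret-val test instead
  return true;
}

// tryUniqueRetValOpt's test: exactly one target returns IsOne, so the call
// can be replaced by a compare against that target's vtable address.
static int uniqueIndex(const std::vector<uint64_t> &RetVals, bool IsOne) {
  int Unique = -1;
  for (size_t I = 0; I != RetVals.size(); ++I)
    if (RetVals[I] == uint64_t(IsOne)) {
      if (Unique != -1)
        return -1; // more than one match: not unique
      Unique = static_cast<int>(I);
    }
  return Unique;
}
#endif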
- struct ByAPIntValue { - bool operator()(const std::vector<ConstantInt *> &A, - const std::vector<ConstantInt *> &B) const { - return std::lexicographical_compare( - A.begin(), A.end(), B.begin(), B.end(), - [](ConstantInt *AI, ConstantInt *BI) { - return AI->getValue().ult(BI->getValue()); - }); - } - }; - std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>, - ByAPIntValue> - VCallSitesByConstantArg; - for (auto &&VCallSite : CallSites) { - std::vector<ConstantInt *> Args; - if (VCallSite.CS.getType() != RetType) - continue; - for (auto &&Arg : - make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) { - if (!isa<ConstantInt>(Arg)) - break; - Args.push_back(cast<ConstantInt>(&Arg)); - } - if (Args.size() + 1 != VCallSite.CS.arg_size()) - continue; - - VCallSitesByConstantArg[Args].push_back(VCallSite); - } - - for (auto &&CSByConstantArg : VCallSitesByConstantArg) { + for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) { if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) continue; - if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second)) + WholeProgramDevirtResolution::ByArg *ResByArg = nullptr; + if (Res) + ResByArg = &Res->ResByArg[CSByConstantArg.first]; + + if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg)) continue; - if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second)) + if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second, + ResByArg, Slot, CSByConstantArg.first)) continue; // Find an allocation offset in bits in all vtables associated with the @@ -659,26 +1007,20 @@ bool DevirtModule::tryVirtualConstProp( for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; - // Rewrite each call to a load from OffsetByte/OffsetBit. - for (auto Call : CSByConstantArg.second) { - IRBuilder<> B(Call.CS.getInstruction()); - Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte); - if (BitWidth == 1) { - Value *Bits = B.CreateLoad(Addr); - Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); - Value *BitsAndBit = B.CreateAnd(Bits, Bit); - auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); - Call.replaceAndErase("virtual-const-prop-1-bit", - TargetsForSlot[0].Fn->getName(), - RemarksEnabled, IsBitSet); - } else { - Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); - Value *Val = B.CreateLoad(RetType, ValAddr); - Call.replaceAndErase("virtual-const-prop", - TargetsForSlot[0].Fn->getName(), - RemarksEnabled, Val); - } + Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); + Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); + + if (CSByConstantArg.second.isExported()) { + ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp; + exportGlobal(Slot, CSByConstantArg.first, "byte", + ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy)); + exportGlobal(Slot, CSByConstantArg.first, "bit", + ConstantExpr::getIntToPtr(BitConst, Int8PtrTy)); } + + // Rewrite each call to a load from OffsetByte/OffsetBit. 
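  // Illustrative IR for that rewrite (hypothetical names; the evaluated
  // constant was stored into each vtable's padding by the rebuilding step
  // at the end of run()):
  //   %addr = getelementptr i8, i8* %vtable, i32 OffsetByte
  //   ; 1-bit returns:  %b = load i8, i8* %addr
  //   ;                 %r = icmp ne i8 (and %b, 1 << OffsetBit), 0
  //   ; wider returns:  %r = load iN, iN* (bitcast %addr)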
+ applyVirtualConstProp(CSByConstantArg.second, + TargetsForSlot[0].Fn->getName(), ByteConst, BitConst); } return true; } @@ -733,7 +1075,11 @@ bool DevirtModule::areRemarksEnabled() { if (FL.empty()) return false; const Function &Fn = FL.front(); - auto DI = OptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), ""); + + const auto &BBL = Fn.getBasicBlockList(); + if (BBL.empty()) + return false; + auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front()); return DI.isEnabled(); } @@ -766,8 +1112,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].push_back( - {CI->getArgOperand(0), Call.CS, nullptr}); + CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0), + Call.CS, nullptr); } } } @@ -853,14 +1199,79 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { if (HasNonCallUses) ++NumUnsafeUses; for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].push_back( - {Ptr, Call.CS, &NumUnsafeUses}); + CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, + &NumUnsafeUses); } CI->eraseFromParent(); } } +void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { + const TypeIdSummary *TidSummary = + ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString()); + if (!TidSummary) + return; + auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset); + if (ResI == TidSummary->WPDRes.end()) + return; + const WholeProgramDevirtResolution &Res = ResI->second; + + if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { + // The type of the function in the declaration is irrelevant because every + // call site will cast it to the correct type. + auto *SingleImpl = M.getOrInsertFunction( + Res.SingleImplName, Type::getVoidTy(M.getContext())); + + // This is the import phase so we should not be exporting anything. + bool IsExported = false; + applySingleImplDevirt(SlotInfo, SingleImpl, IsExported); + assert(!IsExported); + } + + for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) { + auto I = Res.ResByArg.find(CSByConstantArg.first); + if (I == Res.ResByArg.end()) + continue; + auto &ResByArg = I->second; + // FIXME: We should figure out what to do about the "function name" argument + // to the apply* functions, as the function names are unavailable during the + // importing phase. For now we just pass the empty string. This does not + // impact correctness because the function names are just used for remarks. 
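  // The globals consulted below follow getGlobalName's scheme,
  // "__typeid_" + type id + "_" + byte offset + ("_" + arg)* + "_" + name,
  // e.g. (hypothetical) "__typeid__ZTS1A_8_1_unique_member".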
+ switch (ResByArg.TheKind) { + case WholeProgramDevirtResolution::ByArg::UniformRetVal: + applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info); + break; + case WholeProgramDevirtResolution::ByArg::UniqueRetVal: { + Constant *UniqueMemberAddr = + importGlobal(Slot, CSByConstantArg.first, "unique_member"); + applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info, + UniqueMemberAddr); + break; + } + case WholeProgramDevirtResolution::ByArg::VirtualConstProp: { + Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32); + Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty); + Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8); + Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty); + applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit); + } + default: + break; + } + } +} + +void DevirtModule::removeRedundantTypeTests() { + auto True = ConstantInt::getTrue(M.getContext()); + for (auto &&U : NumUnsafeUsesForTypeTest) { + if (U.second == 0) { + U.first->replaceAllUsesWith(True); + U.first->eraseFromParent(); + } + } +} + bool DevirtModule::run() { Function *TypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_test)); @@ -868,7 +1279,11 @@ bool DevirtModule::run() { M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); - if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || + // Normally if there are no users of the devirtualization intrinsics in the + // module, this pass has nothing to do. But if we are exporting, we also need + // to handle any users that appear only in the function summaries. + if (!ExportSummary && + (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || AssumeFunc->use_empty()) && (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) return false; @@ -879,6 +1294,17 @@ bool DevirtModule::run() { if (TypeCheckedLoadFunc) scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); + if (ImportSummary) { + for (auto &S : CallSlots) + importResolution(S.first, S.second); + + removeRedundantTypeTests(); + + // The rest of the code is only necessary when exporting or during regular + // LTO, so we are done. + return true; + } + // Rebuild type metadata into a map for easy lookup. std::vector<VTableBits> Bits; DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; @@ -886,6 +1312,53 @@ bool DevirtModule::run() { if (TypeIdMap.empty()) return true; + // Collect information from summary about which calls to try to devirtualize. + if (ExportSummary) { + DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; + for (auto &P : TypeIdMap) { + if (auto *TypeId = dyn_cast<MDString>(P.first)) + MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( + TypeId); + } + + for (auto &P : *ExportSummary) { + for (auto &S : P.second.SummaryList) { + auto *FS = dyn_cast<FunctionSummary>(S.get()); + if (!FS) + continue; + // FIXME: Only add live functions. 
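  // Each virtual call recorded in a function summary seeds a CallSlot keyed
  // by (type id, offset), just as an in-module llvm.type.test or
  // llvm.type.checked.load user would; calls with all-constant arguments
  // additionally land in ConstCSInfo keyed by their argument vector.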
+ for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { + for (Metadata *MD : MetadataByGUID[VF.GUID]) { + CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers = + true; + } + } + for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { + for (Metadata *MD : MetadataByGUID[VF.GUID]) { + CallSlots[{MD, VF.Offset}] + .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS); + } + } + for (const FunctionSummary::ConstVCall &VC : + FS->type_test_assume_const_vcalls()) { + for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { + CallSlots[{MD, VC.VFunc.Offset}] + .ConstCSInfo[VC.Args] + .SummaryHasTypeTestAssumeUsers = true; + } + } + for (const FunctionSummary::ConstVCall &VC : + FS->type_checked_load_const_vcalls()) { + for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { + CallSlots[{MD, VC.VFunc.Offset}] + .ConstCSInfo[VC.Args] + .SummaryTypeCheckedLoadUsers.push_back(FS); + } + } + } + } + } + // For each (type, offset) pair: bool DidVirtualConstProp = false; std::map<std::string, Function*> DevirtTargets; @@ -894,19 +1367,39 @@ bool DevirtModule::run() { // function implementation at offset S.first.ByteOffset, and add to // TargetsForSlot. std::vector<VirtualCallTarget> TargetsForSlot; - if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], - S.first.ByteOffset)) - continue; - - if (!trySingleImplDevirt(TargetsForSlot, S.second) && - tryVirtualConstProp(TargetsForSlot, S.second)) + if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], + S.first.ByteOffset)) { + WholeProgramDevirtResolution *Res = nullptr; + if (ExportSummary && isa<MDString>(S.first.TypeID)) + Res = &ExportSummary + ->getOrInsertTypeIdSummary( + cast<MDString>(S.first.TypeID)->getString()) + .WPDRes[S.first.ByteOffset]; + + if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) && + tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first)) DidVirtualConstProp = true; - // Collect functions devirtualized at least for one call site for stats. - if (RemarksEnabled) - for (const auto &T : TargetsForSlot) - if (T.WasDevirt) - DevirtTargets[T.Fn->getName()] = T.Fn; + // Collect functions devirtualized at least for one call site for stats. + if (RemarksEnabled) + for (const auto &T : TargetsForSlot) + if (T.WasDevirt) + DevirtTargets[T.Fn->getName()] = T.Fn; + } + + // CFI-specific: if we are exporting and any llvm.type.checked.load + // intrinsics were *not* devirtualized, we need to add the resulting + // llvm.type.test intrinsics to the function summaries so that the + // LowerTypeTests pass will export them. + if (ExportSummary && isa<MDString>(S.first.TypeID)) { + auto GUID = + GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); + for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) + FS->addTypeTest(GUID); + for (auto &CCS : S.second.ConstCSInfo) + for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers) + FS->addTypeTest(GUID); + } } if (RemarksEnabled) { @@ -914,23 +1407,12 @@ bool DevirtModule::run() { for (const auto &DT : DevirtTargets) { Function *F = DT.second; DISubprogram *SP = F->getSubprogram(); - DebugLoc DL = SP ? DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc(); - emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL, + emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP, Twine("devirtualized ") + F->getName()); } } - // If we were able to eliminate all unsafe uses for a type checked load, - // eliminate the type test by replacing it with true. 
- if (TypeCheckedLoadFunc) { - auto True = ConstantInt::getTrue(M.getContext()); - for (auto &&U : NumUnsafeUsesForTypeTest) { - if (U.second == 0) { - U.first->replaceAllUsesWith(True); - U.first->eraseFromParent(); - } - } - } + removeRedundantTypeTests(); // Rebuild each global we touched as part of virtual constant propagation to // include the before and after bytes. diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 2d34c1c..809471c 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; using namespace PatternMatch; @@ -163,7 +164,7 @@ namespace { /// class FAddCombine { public: - FAddCombine(InstCombiner::BuilderTy *B) : Builder(B), Instr(nullptr) {} + FAddCombine(InstCombiner::BuilderTy &B) : Builder(B), Instr(nullptr) {} Value *simplify(Instruction *FAdd); private: @@ -186,7 +187,7 @@ namespace { Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota); void createInstPostProc(Instruction *NewInst, bool NoNumber = false); - InstCombiner::BuilderTy *Builder; + InstCombiner::BuilderTy &Builder; Instruction *Instr; // Debugging stuff are clustered here. @@ -734,7 +735,7 @@ Value *FAddCombine::createNaryFAdd } Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) { - Value *V = Builder->CreateFSub(Opnd0, Opnd1); + Value *V = Builder.CreateFSub(Opnd0, Opnd1); if (Instruction *I = dyn_cast<Instruction>(V)) createInstPostProc(I); return V; @@ -749,21 +750,21 @@ Value *FAddCombine::createFNeg(Value *V) { } Value *FAddCombine::createFAdd(Value *Opnd0, Value *Opnd1) { - Value *V = Builder->CreateFAdd(Opnd0, Opnd1); + Value *V = Builder.CreateFAdd(Opnd0, Opnd1); if (Instruction *I = dyn_cast<Instruction>(V)) createInstPostProc(I); return V; } Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) { - Value *V = Builder->CreateFMul(Opnd0, Opnd1); + Value *V = Builder.CreateFMul(Opnd0, Opnd1); if (Instruction *I = dyn_cast<Instruction>(V)) createInstPostProc(I); return V; } Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) { - Value *V = Builder->CreateFDiv(Opnd0, Opnd1); + Value *V = Builder.CreateFDiv(Opnd0, Opnd1); if (Instruction *I = dyn_cast<Instruction>(V)) createInstPostProc(I); return V; @@ -794,6 +795,11 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { if (Opnd->isConstant()) continue; + // The constant check above is really for a few special constant + // coefficients. + if (isa<UndefValue>(Opnd->getSymVal())) + continue; + const FAddendCoef &CE = Opnd->getCoef(); if (CE.isMinusOne() || CE.isMinusTwo()) NegOpndNum++; @@ -841,108 +847,28 @@ Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) { return createFMul(OpndVal, Coeff.getValue(Instr->getType())); } -// If one of the operands only has one non-zero bit, and if the other -// operand has a known-zero bit in a more significant place than it (not -// including the sign bit) the ripple may go up to and fill the zero, but -// won't change the sign. For example, (X & ~4) + 1. -static bool checkRippleForAdd(const APInt &Op0KnownZero, - const APInt &Op1KnownZero) { - APInt Op1MaybeOne = ~Op1KnownZero; - // Make sure that one of the operand has at most one bit set to 1. 
- if (Op1MaybeOne.countPopulation() != 1) - return false; - - // Find the most significant known 0 other than the sign bit. - int BitWidth = Op0KnownZero.getBitWidth(); - APInt Op0KnownZeroTemp(Op0KnownZero); - Op0KnownZeroTemp.clearBit(BitWidth - 1); - int Op0ZeroPosition = BitWidth - Op0KnownZeroTemp.countLeadingZeros() - 1; - - int Op1OnePosition = BitWidth - Op1MaybeOne.countLeadingZeros() - 1; - assert(Op1OnePosition >= 0); - - // This also covers the case of no known zero, since in that case - // Op0ZeroPosition is -1. - return Op0ZeroPosition >= Op1OnePosition; -} - -/// Return true if we can prove that: -/// (sext (add LHS, RHS)) === (add (sext LHS), (sext RHS)) -/// This basically requires proving that the add in the original type would not -/// overflow to change the sign bit or have a carry out. -bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS, - Instruction &CxtI) { - // There are different heuristics we can use for this. Here are some simple - // ones. - - // If LHS and RHS each have at least two sign bits, the addition will look - // like - // - // XX..... + - // YY..... - // - // If the carry into the most significant position is 0, X and Y can't both - // be 1 and therefore the carry out of the addition is also 0. - // - // If the carry into the most significant position is 1, X and Y can't both - // be 0 and therefore the carry out of the addition is also 1. - // - // Since the carry into the most significant position is always equal to - // the carry out of the addition, there is no signed overflow. - if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 && - ComputeNumSignBits(RHS, 0, &CxtI) > 1) - return true; - - unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI); - - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); - - // Addition of two 2's compliment numbers having opposite signs will never - // overflow. - if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) || - (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) - return true; - - // Check if carry bit of addition will not cause overflow. - if (checkRippleForAdd(LHSKnownZero, RHSKnownZero)) - return true; - if (checkRippleForAdd(RHSKnownZero, LHSKnownZero)) - return true; - - return false; -} - /// \brief Return true if we can prove that: /// (sub LHS, RHS) === (sub nsw LHS, RHS) /// This basically requires proving that the add in the original type would not /// overflow to change the sign bit or have a carry out. /// TODO: Handle this for Vectors. -bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS, - Instruction &CxtI) { +bool InstCombiner::willNotOverflowSignedSub(const Value *LHS, + const Value *RHS, + const Instruction &CxtI) const { // If LHS and RHS each have at least two sign bits, the subtraction // cannot overflow. 
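  // With at least two sign bits, each operand lies in [-2^(w-2), 2^(w-2)-1],
  // so LHS - RHS lies in [-2^(w-1)+1, 2^(w-1)-1] and always fits in w bits.
  // A standalone sanity check of that range argument, exhaustive for w == 8
  // (an illustrative sketch, not part of the pass):
#if 0
#include <cassert>
int main() {
  // Every i8 value with at least two sign bits lies in [-64, 63].
  for (int L = -64; L <= 63; ++L)
    for (int R = -64; R <= 63; ++R)
      assert(L - R >= -128 && L - R <= 127); // difference stays in i8 range
  return 0;
}
#endif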
if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 && ComputeNumSignBits(RHS, 0, &CxtI) > 1) return true; - unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI); + KnownBits LHSKnown = computeKnownBits(LHS, 0, &CxtI); - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); + KnownBits RHSKnown = computeKnownBits(RHS, 0, &CxtI); - // Subtraction of two 2's compliment numbers having identical signs will + // Subtraction of two 2's complement numbers having identical signs will // never overflow. - if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) || - (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1])) + if ((LHSKnown.isNegative() && RHSKnown.isNegative()) || + (LHSKnown.isNonNegative() && RHSKnown.isNonNegative())) return true; // TODO: implement logic similar to checkRippleForAdd @@ -951,16 +877,13 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS, /// \brief Return true if we can prove that: /// (sub LHS, RHS) === (sub nuw LHS, RHS) -bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, - Instruction &CxtI) { +bool InstCombiner::willNotOverflowUnsignedSub(const Value *LHS, + const Value *RHS, + const Instruction &CxtI) const { // If the LHS is negative and the RHS is non-negative, no unsigned wrap. - bool LHSKnownNonNegative, LHSKnownNegative; - bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, - &CxtI); - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, - &CxtI); - if (LHSKnownNegative && RHSKnownNonNegative) + KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, &CxtI); + KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, &CxtI); + if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) return true; return false; @@ -972,7 +895,7 @@ bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, // ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C)) // XOR(AND(Z, C), (C + 1)) == NEG(OR(Z, ~C)) if C is even static Value *checkForNegativeOperand(BinaryOperator &I, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy &Builder) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); // This function creates 2 instructions to replace ADD, we need at least one @@ -996,13 +919,13 @@ static Value *checkForNegativeOperand(BinaryOperator &I, // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1)) // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1)) if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) { - Value *NewAnd = Builder->CreateAnd(Z, *C1); - return Builder->CreateSub(RHS, NewAnd, "sub"); + Value *NewAnd = Builder.CreateAnd(Z, *C1); + return Builder.CreateSub(RHS, NewAnd, "sub"); } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) { // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1)) // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1)) - Value *NewOr = Builder->CreateOr(Z, ~(*C1)); - return Builder->CreateSub(RHS, NewOr, "sub"); + Value *NewOr = Builder.CreateOr(Z, ~(*C1)); + return Builder.CreateSub(RHS, NewOr, "sub"); } } } @@ -1021,12 +944,73 @@ static Value *checkForNegativeOperand(BinaryOperator &I, if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1)))) if (C1->countTrailingZeros() == 0) if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == 
(*C2 + 1)) { - Value *NewOr = Builder->CreateOr(Z, ~(*C2)); - return Builder->CreateSub(RHS, NewOr, "sub"); + Value *NewOr = Builder.CreateOr(Z, ~(*C2)); + return Builder.CreateSub(RHS, NewOr, "sub"); } return nullptr; } +static Instruction *foldAddWithConstant(BinaryOperator &Add, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + const APInt *C; + if (!match(Op1, m_APInt(C))) + return nullptr; + + if (C->isSignMask()) { + // If wrapping is not allowed, then the addition must set the sign bit: + // X + (signmask) --> X | signmask + if (Add.hasNoSignedWrap() || Add.hasNoUnsignedWrap()) + return BinaryOperator::CreateOr(Op0, Op1); + + // If wrapping is allowed, then the addition flips the sign bit of LHS: + // X + (signmask) --> X ^ signmask + return BinaryOperator::CreateXor(Op0, Op1); + } + + Value *X; + const APInt *C2; + Type *Ty = Add.getType(); + + // Is this add the last step in a convoluted sext? + // add(zext(xor i16 X, -32768), -32768) --> sext X + if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) && + C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) + return CastInst::Create(Instruction::SExt, X, Ty); + + // (add (zext (add nuw X, C2)), C) --> (zext (add nuw X, C2 + C)) + // FIXME: This should check hasOneUse to not increase the instruction count? + if (C->isNegative() && + match(Op0, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2)))) && + C->sge(-C2->sext(C->getBitWidth()))) { + Constant *NewC = + ConstantInt::get(X->getType(), *C2 + C->trunc(C2->getBitWidth())); + return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty); + } + + if (C->isOneValue() && Op0->hasOneUse()) { + // add (sext i1 X), 1 --> zext (not X) + // TODO: The smallest IR representation is (select X, 0, 1), and that would + // not require the one-use check. But we need to remove a transform in + // visitSelect and make sure that IR value tracking for select is equal or + // better than for these ops. + if (match(Op0, m_SExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return new ZExtInst(Builder.CreateNot(X), Ty); + + // Shifts and add used to flip and mask off the low bit: + // add (ashr (shl i32 X, 31), 31), 1 --> and (not X), 1 + const APInt *C3; + if (match(Op0, m_AShr(m_Shl(m_Value(X), m_APInt(C2)), m_APInt(C3))) && + C2 == C3 && *C2 == Ty->getScalarSizeInBits() - 1) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateAnd(NotX, ConstantInt::get(Ty, 1)); + } + } + + return nullptr; +} + Instruction *InstCombiner::visitAdd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); @@ -1034,51 +1018,21 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) + if (Value *V = + SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); - const APInt *Val; - if (match(RHS, m_APInt(Val))) { - // X + (signbit) --> X ^ signbit - if (Val->isSignBit()) - return BinaryOperator::CreateXor(LHS, RHS); - - // Is this add the last step in a convoluted sext? 
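// Both the old fold above and the new foldAddWithConstant rely on the same
// bit identity: adding the sign mask only toggles the top bit (any carry
// falls off the end), so X + signmask == X ^ signmask; and under nsw/nuw the
// sign bit of X must have been clear, giving X | signmask. An exhaustive i8
// check (illustrative sketch only):
#if 0
#include <cassert>
#include <cstdint>
int main() {
  for (int X = 0; X < 256; ++X) {
    assert(uint8_t(X + 0x80) == uint8_t(X ^ 0x80));
    if (X < 0x80) // sign bit clear, the no-wrap case
      assert(uint8_t(X + 0x80) == uint8_t(X | 0x80));
  }
  return 0;
}
#endif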
- Value *X; - const APInt *C; - if (match(LHS, m_ZExt(m_Xor(m_Value(X), m_APInt(C)))) && - C->isMinSignedValue() && - C->sext(LHS->getType()->getScalarSizeInBits()) == *Val) { - // add(zext(xor i16 X, -32768), -32768) --> sext X - return CastInst::Create(Instruction::SExt, X, LHS->getType()); - } - - if (Val->isNegative() && - match(LHS, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C)))) && - Val->sge(-C->sext(Val->getBitWidth()))) { - // (add (zext (add nuw X, C)), Val) -> (zext (add nuw X, C+Val)) - return CastInst::Create( - Instruction::ZExt, - Builder->CreateNUWAdd( - X, Constant::getIntegerValue(X->getType(), - *C + Val->trunc(C->getBitWidth()))), - I.getType()); - } - } + if (Instruction *X = foldAddWithConstant(I, Builder)) + return X; - // FIXME: Use the match above instead of dyn_cast to allow these transforms - // for splat vectors. + // FIXME: This should be moved into the above helper function to allow these + // transforms for splat vectors. if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // See if SimplifyDemandedBits can simplify this. This handles stuff like - // (X & 254)+1 -> (X&254)|1 - if (SimplifyDemandedInstructionBits(I)) - return &I; - // zext(bool) + C -> bool ? C + 1 : C if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) if (ZI->getSrcTy()->isIntegerTy(1)) @@ -1106,34 +1060,31 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (ExtendAmt) { Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt); - Value *NewShl = Builder->CreateShl(XorLHS, ShAmt, "sext"); + Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext"); return BinaryOperator::CreateAShr(NewShl, ShAmt); } // If this is a xor that was canonicalized from a sub, turn it back into // a sub and fuse this add with it. if (LHS->hasOneUse() && (XorRHS->getValue()+1).isPowerOf2()) { - IntegerType *IT = cast<IntegerType>(I.getType()); - APInt LHSKnownOne(IT->getBitWidth(), 0); - APInt LHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne, 0, &I); - if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue()) + KnownBits LHSKnown = computeKnownBits(XorLHS, 0, &I); + if ((XorRHS->getValue() | LHSKnown.Zero).isAllOnesValue()) return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI), XorLHS); } - // (X + signbit) + C could have gotten canonicalized to (X ^ signbit) + C, - // transform them into (X + (signbit ^ C)) - if (XorRHS->getValue().isSignBit()) + // (X + signmask) + C could have gotten canonicalized to (X^signmask) + C, + // transform them into (X + (signmask ^ C)) + if (XorRHS->getValue().isSignMask()) return BinaryOperator::CreateAdd(XorLHS, ConstantExpr::getXor(XorRHS, CI)); } } - if (isa<Constant>(RHS) && isa<PHINode>(LHS)) - if (Instruction *NV = FoldOpIntoPhi(I)) + if (isa<Constant>(RHS)) + if (Instruction *NV = foldOpWithConstantIntoOperand(I)) return NV; - if (I.getType()->getScalarType()->isIntegerTy(1)) + if (I.getType()->isIntOrIntVectorTy(1)) return BinaryOperator::CreateXor(LHS, RHS); // X + X --> X << 1 @@ -1150,7 +1101,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *LHSV = dyn_castNegVal(LHS)) { if (!isa<Constant>(RHS)) if (Value *RHSV = dyn_castNegVal(RHS)) { - Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum"); + Value *NewAdd = Builder.CreateAdd(LHSV, RHSV, "sum"); return BinaryOperator::CreateNeg(NewAdd); } @@ -1197,15 +1148,10 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (AddRHSHighBits == AddRHSHighBitsAnd) { // Okay, the xform is safe. Insert the new add pronto. 
- Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName()); + Value *NewAdd = Builder.CreateAdd(X, CRHS, LHS->getName()); return BinaryOperator::CreateAnd(NewAdd, C2); } } - - // Try to fold constant add into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(LHS)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; } // add (select X 0 (sub n A)) A --> select X A n @@ -1242,10 +1188,10 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { Constant *CI = ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); if (ConstantExpr::getSExt(CI, I.getType()) == RHSC && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { + willNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new, smaller add. Value *NewAdd = - Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); + Builder.CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); return new SExtInst(NewAdd, I.getType()); } } @@ -1253,16 +1199,16 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (add (sext x), (sext y)) --> (sext (add int x, y)) if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) { - // Only do this if x/y have the same type, if at last one of them has a + // Only do this if x/y have the same type, if at least one of them has a // single use (so we don't increase the number of sexts), and if the // integer add will not overflow. if (LHSConv->getOperand(0)->getType() == RHSConv->getOperand(0)->getType() && (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), + willNotOverflowSignedAdd(LHSConv->getOperand(0), RHSConv->getOperand(0), I)) { // Insert the new integer add. - Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), + Value *NewAdd = Builder.CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); return new SExtInst(NewAdd, I.getType()); } @@ -1278,11 +1224,10 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { Constant *CI = ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); if (ConstantExpr::getZExt(CI, I.getType()) == RHSC && - computeOverflowForUnsignedAdd(LHSConv->getOperand(0), CI, &I) == - OverflowResult::NeverOverflows) { + willNotOverflowUnsignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new, smaller add. Value *NewAdd = - Builder->CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv"); + Builder.CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv"); return new ZExtInst(NewAdd, I.getType()); } } @@ -1290,17 +1235,16 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (add (zext x), (zext y)) --> (zext (add int x, y)) if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) { - // Only do this if x/y have the same type, if at last one of them has a + // Only do this if x/y have the same type, if at least one of them has a // single use (so we don't increase the number of zexts), and if the // integer add will not overflow. if (LHSConv->getOperand(0)->getType() == RHSConv->getOperand(0)->getType() && (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && - computeOverflowForUnsignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0), - &I) == OverflowResult::NeverOverflows) { + willNotOverflowUnsignedAdd(LHSConv->getOperand(0), + RHSConv->getOperand(0), I)) { // Insert the new integer add. 
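  // e.g. add (zext i8 %x to i32), (zext i8 %y to i32)
  //        --> zext (add nuw i8 %x, %y) to i32
  // which is sound only because the narrow unsigned add was proven above
  // not to wrap.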
- Value *NewAdd = Builder->CreateNUWAdd( + Value *NewAdd = Builder.CreateNUWAdd( LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); return new ZExtInst(NewAdd, I.getType()); } @@ -1311,13 +1255,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { { Value *A = nullptr, *B = nullptr; if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - (match(LHS, m_And(m_Specific(A), m_Specific(B))) || - match(LHS, m_And(m_Specific(B), m_Specific(A))))) + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - (match(RHS, m_And(m_Specific(A), m_Specific(B))) || - match(RHS, m_And(m_Specific(B), m_Specific(A))))) + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); } @@ -1325,8 +1267,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { { Value *A = nullptr, *B = nullptr; if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - (match(LHS, m_And(m_Specific(A), m_Specific(B))) || - match(LHS, m_And(m_Specific(B), m_Specific(A))))) { + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { auto *New = BinaryOperator::CreateAdd(A, B); New->setHasNoSignedWrap(I.hasNoSignedWrap()); New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); @@ -1334,8 +1275,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - (match(RHS, m_And(m_Specific(A), m_Specific(B))) || - match(RHS, m_And(m_Specific(B), m_Specific(A))))) { + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { auto *New = BinaryOperator::CreateAdd(A, B); New->setHasNoSignedWrap(I.hasNoSignedWrap()); New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); @@ -1343,16 +1283,14 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - // TODO(jingyue): Consider WillNotOverflowSignedAdd and - // WillNotOverflowUnsignedAdd to reduce the number of invocations of + // TODO(jingyue): Consider willNotOverflowSignedAdd and + // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. - if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, I)) { + if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && - computeOverflowForUnsignedAdd(LHS, RHS, &I) == - OverflowResult::NeverOverflows) { + if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -1367,8 +1305,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = - SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (isa<Constant>(RHS)) @@ -1394,38 +1332,57 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { // Check for (fadd double (sitofp x), y), see if we can merge this into an // integer add followed by a promotion. if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) { + Value *LHSIntVal = LHSConv->getOperand(0); + Type *FPType = LHSConv->getType(); + + // TODO: This check is overly conservative. In many cases known bits + // analysis can tell us that the result of the addition has less significant + // bits than the integer type can hold. 
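  // The baseline test below is significand width: e.g. IEEE single carries a
  // 24-bit significand, so an iN add can be done in the integer domain and
  // then converted whenever N <= 24, while an i32 sum may already be
  // unrepresentable.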
+ auto IsValidPromotion = [](Type *FTy, Type *ITy) { + Type *FScalarTy = FTy->getScalarType(); + Type *IScalarTy = ITy->getScalarType(); + + // Do we have enough bits in the significand to represent the result of + // the integer addition? + unsigned MaxRepresentableBits = + APFloat::semanticsPrecision(FScalarTy->getFltSemantics()); + return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits; + }; + // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst)) // ... if the constant fits in the integer value. This is useful for things // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer // requires a constant pool load, and generally allows the add to be better // instcombined. - if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) { - Constant *CI = - ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType()); - if (LHSConv->hasOneUse() && - ConstantExpr::getSIToFP(CI, I.getType()) == CFP && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { - // Insert the new integer add. - Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), - CI, "addconv"); - return new SIToFPInst(NewAdd, I.getType()); + if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) + if (IsValidPromotion(FPType, LHSIntVal->getType())) { + Constant *CI = + ConstantExpr::getFPToSI(CFP, LHSIntVal->getType()); + if (LHSConv->hasOneUse() && + ConstantExpr::getSIToFP(CI, I.getType()) == CFP && + willNotOverflowSignedAdd(LHSIntVal, CI, I)) { + // Insert the new integer add. + Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv"); + return new SIToFPInst(NewAdd, I.getType()); + } } - } // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y)) if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) { - // Only do this if x/y have the same type, if at last one of them has a - // single use (so we don't increase the number of int->fp conversions), - // and if the integer add will not overflow. - if (LHSConv->getOperand(0)->getType() == - RHSConv->getOperand(0)->getType() && - (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0), I)) { - // Insert the new integer add. - Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0),"addconv"); - return new SIToFPInst(NewAdd, I.getType()); + Value *RHSIntVal = RHSConv->getOperand(0); + // It's enough to check LHS types only because we require int types to + // be the same for this transform. + if (IsValidPromotion(FPType, LHSIntVal->getType())) { + // Only do this if x/y have the same type, if at least one of them has a + // single use (so we don't increase the number of int->fp conversions), + // and if the integer add will not overflow. + if (LHSIntVal->getType() == RHSIntVal->getType() && + (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && + willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) { + // Insert the new integer add. + Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal, "addconv"); + return new SIToFPInst(NewAdd, I.getType()); + } } } } @@ -1521,14 +1478,14 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, // pointer, subtract it from the offset we have. if (GEP2) { Value *Offset = EmitGEPOffset(GEP2); - Result = Builder->CreateSub(Result, Offset); + Result = Builder.CreateSub(Result, Offset); } // If we have p - gep(p, ...) then we have to negate the result. 
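  // e.g. for p - gep(p, i32 1) over i32, the gep offset computed above is +4,
  // and the swapped operand order means the difference must come out as -4.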
if (Swapped) - Result = Builder->CreateNeg(Result, "diff.neg"); + Result = Builder.CreateNeg(Result, "diff.neg"); - return Builder->CreateIntCast(Result, Ty, true); + return Builder.CreateIntCast(Result, Ty, true); } Instruction *InstCombiner::visitSub(BinaryOperator &I) { @@ -1537,8 +1494,9 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) + if (Value *V = + SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc @@ -1562,7 +1520,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return Res; } - if (I.getType()->isIntegerTy(1)) + if (I.getType()->isIntOrIntVectorTy(1)) return BinaryOperator::CreateXor(Op0, Op1); // Replace (-1 - A) with (~A). @@ -1580,22 +1538,24 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; + // Try to fold constant sub into PHI values. + if (PHINode *PN = dyn_cast<PHINode>(Op1)) + if (Instruction *R = foldOpIntoPhi(I, PN)) + return R; + // C-(X+C2) --> (C-C2)-X Constant *C2; if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); - if (SimplifyDemandedInstructionBits(I)) - return &I; - // Fold (sub 0, (zext bool to B)) --> (sext bool to B) if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X)))) - if (X->getType()->getScalarType()->isIntegerTy(1)) + if (X->getType()->isIntOrIntVectorTy(1)) return CastInst::CreateSExtOrBitCast(X, Op1->getType()); // Fold (sub 0, (sext bool to B)) --> (zext bool to B) if (C->isNullValue() && match(Op1, m_SExt(m_Value(X)))) - if (X->getType()->getScalarType()->isIntegerTy(1)) + if (X->getType()->isIntOrIntVectorTy(1)) return CastInst::CreateZExtOrBitCast(X, Op1->getType()); } @@ -1605,7 +1565,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // -(X >>u 31) -> (X >>s 31) // -(X >>s 31) -> (X >>u 31) - if (*Op0C == 0) { + if (Op0C->isNullValue()) { Value *X; const APInt *ShAmt; if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) && @@ -1622,11 +1582,9 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. 
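  // e.g. with X known to fit in its low three bits, 7 - X never borrows, so
  // 7 - X == 7 ^ X: each set bit of X clears the corresponding mask bit.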
- if ((*Op0C + 1).isPowerOf2()) { - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(&I, KnownZero, KnownOne, 0, &I); - if ((*Op0C | KnownZero).isAllOnesValue()) + if (Op0C->isMask()) { + KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); + if ((*Op0C | RHSKnown.Zero).isAllOnesValue()) return BinaryOperator::CreateXor(Op1, Op0); } } @@ -1634,8 +1592,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { { Value *Y; // X-(X+Y) == -Y X-(Y+X) == -Y - if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) || - match(Op1, m_Add(m_Value(Y), m_Specific(Op0)))) + if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y)))) return BinaryOperator::CreateNeg(Y); // (X-Y)-X == -Y @@ -1645,38 +1602,34 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // (sub (or A, B) (xor A, B)) --> (and A, B) { - Value *A = nullptr, *B = nullptr; + Value *A, *B; if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && - (match(Op0, m_Or(m_Specific(A), m_Specific(B))) || - match(Op0, m_Or(m_Specific(B), m_Specific(A))))) + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateAnd(A, B); } - if (Op0->hasOneUse()) { - Value *Y = nullptr; + { + Value *Y; // ((X | Y) - X) --> (~X & Y) - if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) || - match(Op0, m_Or(m_Specific(Op1), m_Value(Y)))) + if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1))))) return BinaryOperator::CreateAnd( - Y, Builder->CreateNot(Op1, Op1->getName() + ".not")); + Y, Builder.CreateNot(Op1, Op1->getName() + ".not")); } if (Op1->hasOneUse()) { Value *X = nullptr, *Y = nullptr, *Z = nullptr; Constant *C = nullptr; - Constant *CI = nullptr; // (X - (Y - Z)) --> (X + (Z - Y)). if (match(Op1, m_Sub(m_Value(Y), m_Value(Z)))) return BinaryOperator::CreateAdd(Op0, - Builder->CreateSub(Z, Y, Op1->getName())); + Builder.CreateSub(Z, Y, Op1->getName())); // (X - (X & Y)) --> (X & ~Y) // - if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) || - match(Op1, m_And(m_Specific(Op0), m_Value(Y)))) + if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0)))) return BinaryOperator::CreateAnd(Op0, - Builder->CreateNot(Y, Y->getName() + ".not")); + Builder.CreateNot(Y, Y->getName() + ".not")); // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && match(Op0, m_Zero()) && @@ -1693,7 +1646,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // 'nuw' is dropped in favor of the canonical form. 
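  // i.e. X - (sext i1 %y) is X - (-1) when %y is true and X - 0 otherwise,
  // which is exactly X + (zext i1 %y):
  //   %z = zext i1 %y to i32   ; illustrative width
  //   %r = add i32 %x, %z      ; nsw carried over when present, nuw dropped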
if (match(Op1, m_SExt(m_Value(Y))) && Y->getType()->getScalarSizeInBits() == 1) { - Value *Zext = Builder->CreateZExt(Y, I.getType()); + Value *Zext = Builder.CreateZExt(Y, I.getType()); BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Zext); Add->setHasNoSignedWrap(I.hasNoSignedWrap()); return Add; @@ -1702,15 +1655,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // X - A*-B -> X + A*B // X - -A*B -> X + A*B Value *A, *B; - if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) || - match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B)))) - return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B)); + Constant *CI; + if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B))))) + return BinaryOperator::CreateAdd(Op0, Builder.CreateMul(A, B)); // X - A*CI -> X + A*-CI - // X - CI*A -> X + A*-CI - if (match(Op1, m_Mul(m_Value(A), m_Constant(CI))) || - match(Op1, m_Mul(m_Constant(CI), m_Value(A)))) { - Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI)); + // No need to handle commuted multiply because multiply handling will + // ensure constant will be move to the right hand side. + if (match(Op1, m_Mul(m_Value(A), m_Constant(CI)))) { + Value *NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(CI)); return BinaryOperator::CreateAdd(Op0, NewMul); } } @@ -1730,11 +1683,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return replaceInstUsesWith(I, Res); bool Changed = false; - if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, I)) { + if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, I)) { + if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -1748,8 +1701,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = - SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // fsub nsz 0, X ==> fsub nsz -0.0, X @@ -1774,14 +1727,14 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { } if (FPTruncInst *FPTI = dyn_cast<FPTruncInst>(Op1)) { if (Value *V = dyn_castFNegVal(FPTI->getOperand(0))) { - Value *NewTrunc = Builder->CreateFPTrunc(V, I.getType()); + Value *NewTrunc = Builder.CreateFPTrunc(V, I.getType()); Instruction *NewI = BinaryOperator::CreateFAdd(Op0, NewTrunc); NewI->copyFastMathFlags(&I); return NewI; } } else if (FPExtInst *FPEI = dyn_cast<FPExtInst>(Op1)) { if (Value *V = dyn_castFNegVal(FPEI->getOperand(0))) { - Value *NewExt = Builder->CreateFPExt(V, I.getType()); + Value *NewExt = Builder.CreateFPExt(V, I.getType()); Instruction *NewI = BinaryOperator::CreateFAdd(Op0, NewExt); NewI->copyFastMathFlags(&I); return NewI; diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index da5384a..fdc9c37 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -23,21 +23,6 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -static inline Value *dyn_castNotVal(Value *V) { - // If this is not(not(x)) don't return that this is a not: we want the two - // not's to be folded first. 
- if (BinaryOperator::isNot(V)) { - Value *Operand = BinaryOperator::getNotArgument(V); - if (!IsFreeToInvert(Operand, Operand->hasOneUse())) - return Operand; - } - - // Constants can be considered to be not'ed values... - if (ConstantInt *C = dyn_cast<ConstantInt>(V)) - return ConstantInt::get(C->getType(), ~C->getValue()); - return nullptr; -} - /// Similar to getICmpCode but for FCmpInst. This encodes a fcmp predicate into /// a four bit mask. static unsigned getFCmpCode(FCmpInst::Predicate CC) { @@ -69,17 +54,17 @@ static unsigned getFCmpCode(FCmpInst::Predicate CC) { /// instruction. The sign is passed in to determine which kind of predicate to /// use in the new icmp instruction. static Value *getNewICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy &Builder) { ICmpInst::Predicate NewPred; if (Value *NewConstant = getICmpValue(Sign, Code, LHS, RHS, NewPred)) return NewConstant; - return Builder->CreateICmp(NewPred, LHS, RHS); + return Builder.CreateICmp(NewPred, LHS, RHS); } /// This is the complement of getFCmpCode, which turns an opcode and two /// operands into either a FCmp instruction, or a true/false constant. static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy &Builder) { const auto Pred = static_cast<FCmpInst::Predicate>(Code); assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE && "Unexpected FCmp predicate!"); @@ -87,59 +72,50 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); if (Pred == FCmpInst::FCMP_TRUE) return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - return Builder->CreateFCmp(Pred, LHS, RHS); + return Builder.CreateFCmp(Pred, LHS, RHS); } -/// \brief Transform BITWISE_OP(BSWAP(A),BSWAP(B)) to BSWAP(BITWISE_OP(A, B)) +/// \brief Transform BITWISE_OP(BSWAP(A),BSWAP(B)) or +/// BITWISE_OP(BSWAP(A), Constant) to BSWAP(BITWISE_OP(A, B)) /// \param I Binary operator to transform. /// \return Pointer to node that must replace the original binary operator, or /// null pointer if no transformation was made. -Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { - IntegerType *ITy = dyn_cast<IntegerType>(I.getType()); - - // Can't do vectors. - if (I.getType()->isVectorTy()) - return nullptr; - - // Can only do bitwise ops. - if (!I.isBitwiseLogicOp()) - return nullptr; +static Value *SimplifyBSwap(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.isBitwiseLogicOp() && "Unexpected opcode for bswap simplifying"); Value *OldLHS = I.getOperand(0); Value *OldRHS = I.getOperand(1); - ConstantInt *ConstLHS = dyn_cast<ConstantInt>(OldLHS); - ConstantInt *ConstRHS = dyn_cast<ConstantInt>(OldRHS); - IntrinsicInst *IntrLHS = dyn_cast<IntrinsicInst>(OldLHS); - IntrinsicInst *IntrRHS = dyn_cast<IntrinsicInst>(OldRHS); - bool IsBswapLHS = (IntrLHS && IntrLHS->getIntrinsicID() == Intrinsic::bswap); - bool IsBswapRHS = (IntrRHS && IntrRHS->getIntrinsicID() == Intrinsic::bswap); - - if (!IsBswapLHS && !IsBswapRHS) - return nullptr; - - if (!IsBswapLHS && !ConstLHS) - return nullptr; - if (!IsBswapRHS && !ConstRHS) + Value *NewLHS; + if (!match(OldLHS, m_BSwap(m_Value(NewLHS)))) return nullptr; - /// OP( BSWAP(x), BSWAP(y) ) -> BSWAP( OP(x, y) ) - /// OP( BSWAP(x), CONSTANT ) -> BSWAP( OP(x, BSWAP(CONSTANT) ) ) - Value *NewLHS = IsBswapLHS ? 
IntrLHS->getOperand(0) : - Builder->getInt(ConstLHS->getValue().byteSwap()); + Value *NewRHS; + const APInt *C; - Value *NewRHS = IsBswapRHS ? IntrRHS->getOperand(0) : - Builder->getInt(ConstRHS->getValue().byteSwap()); + if (match(OldRHS, m_BSwap(m_Value(NewRHS)))) { + // OP( BSWAP(x), BSWAP(y) ) -> BSWAP( OP(x, y) ) + if (!OldLHS->hasOneUse() && !OldRHS->hasOneUse()) + return nullptr; + // NewRHS initialized by the matcher. + } else if (match(OldRHS, m_APInt(C))) { + // OP( BSWAP(x), CONSTANT ) -> BSWAP( OP(x, BSWAP(CONSTANT) ) ) + if (!OldLHS->hasOneUse()) + return nullptr; + NewRHS = ConstantInt::get(I.getType(), C->byteSwap()); + } else + return nullptr; - Value *BinOp = Builder->CreateBinOp(I.getOpcode(), NewLHS, NewRHS); - Function *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::bswap, ITy); - return Builder->CreateCall(F, BinOp); + Value *BinOp = Builder.CreateBinOp(I.getOpcode(), NewLHS, NewRHS); + Function *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::bswap, + I.getType()); + return Builder.CreateCall(F, BinOp); } /// This handles expressions of the form ((val OP C1) & C2). Where -/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is -/// guaranteed to be a binary operator. -Instruction *InstCombiner::OptAndOp(Instruction *Op, +/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. +Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS, ConstantInt *AndRHS, BinaryOperator &TheAnd) { @@ -149,30 +125,24 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, Together = ConstantExpr::getAnd(AndRHS, OpRHS); switch (Op->getOpcode()) { + default: break; case Instruction::Xor: if (Op->hasOneUse()) { // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) - Value *And = Builder->CreateAnd(X, AndRHS); + Value *And = Builder.CreateAnd(X, AndRHS); And->takeName(Op); return BinaryOperator::CreateXor(And, Together); } break; case Instruction::Or: if (Op->hasOneUse()){ - if (Together != OpRHS) { - // (X | C1) & C2 --> (X | (C1&C2)) & C2 - Value *Or = Builder->CreateOr(X, Together); - Or->takeName(Op); - return BinaryOperator::CreateAnd(Or, AndRHS); - } - ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together); if (TogetherCI && !TogetherCI->isZero()){ // (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1 // NOTE: This reduces the number of bits set in the & mask, which // can expose opportunities for store narrowing. Together = ConstantExpr::getXor(AndRHS, Together); - Value *And = Builder->CreateAnd(X, Together); + Value *And = Builder.CreateAnd(X, Together); And->takeName(Op); return BinaryOperator::CreateOr(And, OpRHS); } @@ -194,17 +164,17 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, const APInt& AddRHS = OpRHS->getValue(); // Check to see if any bits below the one bit set in AndRHSV are set. - if ((AddRHS & (AndRHSV-1)) == 0) { + if ((AddRHS & (AndRHSV - 1)).isNullValue()) { // If not, the only thing that can effect the output of the AND is // the bit specified by AndRHSV. If that bit is set, the effect of // the XOR is to toggle the bit. If it is clear, then the ADD has // no effect. - if ((AddRHS & AndRHSV) == 0) { // Bit is not set, noop + if ((AddRHS & AndRHSV).isNullValue()) { // Bit is not set, noop TheAnd.setOperand(0, X); return &TheAnd; } else { // Pull the XOR out of the AND. 
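  // e.g. ((X + 8) & 8) --> ((X & 8) ^ 8): with no addend bits below the mask
  // bit there is no carry into it, so the add simply toggles that bit.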
- Value *NewAnd = Builder->CreateAnd(X, AndRHS); + Value *NewAnd = Builder.CreateAnd(X, AndRHS); NewAnd->takeName(Op); return BinaryOperator::CreateXor(NewAnd, AndRHS); } @@ -220,7 +190,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, uint32_t BitWidth = AndRHS->getType()->getBitWidth(); uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); APInt ShlMask(APInt::getHighBitsSet(BitWidth, BitWidth-OpRHSVal)); - ConstantInt *CI = Builder->getInt(AndRHS->getValue() & ShlMask); + ConstantInt *CI = Builder.getInt(AndRHS->getValue() & ShlMask); if (CI->getValue() == ShlMask) // Masking out bits that the shift already masks. @@ -240,7 +210,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, uint32_t BitWidth = AndRHS->getType()->getBitWidth(); uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); - ConstantInt *CI = Builder->getInt(AndRHS->getValue() & ShrMask); + ConstantInt *CI = Builder.getInt(AndRHS->getValue() & ShrMask); if (CI->getValue() == ShrMask) // Masking out bits that the shift already masks. @@ -260,12 +230,12 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, uint32_t BitWidth = AndRHS->getType()->getBitWidth(); uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); - Constant *C = Builder->getInt(AndRHS->getValue() & ShrMask); + Constant *C = Builder.getInt(AndRHS->getValue() & ShrMask); if (C == AndRHS) { // Masking out bits shifted in. // (Val ashr C1) & C2 -> (Val lshr C1) & C2 // Make the argument unsigned. Value *ShVal = Op->getOperand(0); - ShVal = Builder->CreateLShr(ShVal, OpRHS, Op->getName()); + ShVal = Builder.CreateLShr(ShVal, OpRHS, Op->getName()); return BinaryOperator::CreateAnd(ShVal, AndRHS, TheAnd.getName()); } } @@ -291,189 +261,102 @@ Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, ICmpInst::Predicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; if (isSigned ? Lo.isMinSignedValue() : Lo.isMinValue()) { Pred = isSigned ? ICmpInst::getSignedPredicate(Pred) : Pred; - return Builder->CreateICmp(Pred, V, ConstantInt::get(Ty, Hi)); + return Builder.CreateICmp(Pred, V, ConstantInt::get(Ty, Hi)); } // V >= Lo && V < Hi --> V - Lo u< Hi - Lo // V < Lo || V >= Hi --> V - Lo u>= Hi - Lo Value *VMinusLo = - Builder->CreateSub(V, ConstantInt::get(Ty, Lo), V->getName() + ".off"); + Builder.CreateSub(V, ConstantInt::get(Ty, Lo), V->getName() + ".off"); Constant *HiMinusLo = ConstantInt::get(Ty, Hi - Lo); - return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo); + return Builder.CreateICmp(Pred, VMinusLo, HiMinusLo); } -/// Returns true iff Val consists of one contiguous run of 1s with any number -/// of 0s on either side. The 1s are allowed to wrap from LSB to MSB, -/// so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs. 0x0F0F0000 is -/// not, since all 1s are not contiguous. -static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) { - const APInt& V = Val->getValue(); - uint32_t BitWidth = Val->getType()->getBitWidth(); - if (!APIntOps::isShiftedMask(BitWidth, V)) return false; - - // look for the first zero bit after the run of ones - MB = BitWidth - ((V - 1) ^ V).countLeadingZeros(); - // look for the first non-zero bit - ME = V.getActiveBits(); - return true; -} - -/// This is part of an expression (LHS +/- RHS) & Mask, where isSub determines -/// whether the operator is a sub. 
If we can fold one of the following xforms:
+/// Classify (icmp eq (A & B), C) and (icmp ne (A & B), C) as matching patterns
+/// that can be simplified.
+/// One of A and B is considered the mask. The other is the value. This is
+/// described as the "AMask" or "BMask" part of the enum. If the enum contains
+/// only "Mask", then both A and B can be considered masks. If A is the mask,
+/// then it was proven that (A & C) == C. This is trivial if C == A or C == 0.
+/// If both A and C are constants, this proof is also easy.
+/// For the following explanations, we assume that A is the mask.
 ///
-/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask
-/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0
-/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0
+/// "AllOnes" declares that the comparison is true only if (A & B) == A or all
+/// bits of A are set in B.
+/// Example: (icmp eq (A & 3), 3) -> AMask_AllOnes
 ///
-/// return (A +/- B).
+/// "AllZeros" declares that the comparison is true only if (A & B) == 0 or all
+/// bits of A are cleared in B.
+/// Example: (icmp eq (A & 3), 0) -> Mask_AllZeros
+///
+/// "Mixed" declares that (A & B) == C and C might or might not contain any
+/// number of one bits and zero bits.
+/// Example: (icmp eq (A & 3), 1) -> AMask_Mixed
+///
+/// "Not" means that in above descriptions "==" should be replaced by "!=".
+/// Example: (icmp ne (A & 3), 3) -> AMask_NotAllOnes
 ///
-Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS,
-                                        ConstantInt *Mask, bool isSub,
-                                        Instruction &I) {
-  Instruction *LHSI = dyn_cast<Instruction>(LHS);
-  if (!LHSI || LHSI->getNumOperands() != 2 ||
-      !isa<ConstantInt>(LHSI->getOperand(1))) return nullptr;
-
-  ConstantInt *N = cast<ConstantInt>(LHSI->getOperand(1));
-
-  switch (LHSI->getOpcode()) {
-  default: return nullptr;
-  case Instruction::And:
-    if (ConstantExpr::getAnd(N, Mask) == Mask) {
-      // If the AndRHS is a power of two minus one (0+1+), this is simple.
-      if ((Mask->getValue().countLeadingZeros() +
-           Mask->getValue().countPopulation()) ==
-          Mask->getValue().getBitWidth())
-        break;
-
-      // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+
-      // part, we don't need any explicit masks to take them out of A. If that
-      // is all N is, ignore it.
-      uint32_t MB = 0, ME = 0;
-      if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive
-        uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth();
-        APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1));
-        if (MaskedValueIsZero(RHS, Mask, 0, &I))
-          break;
-      }
-    }
-    return nullptr;
-  case Instruction::Or:
-  case Instruction::Xor:
-    // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0
-    if ((Mask->getValue().countLeadingZeros() +
-         Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth()
-        && ConstantExpr::getAnd(N, Mask)->isNullValue())
-      break;
-    return nullptr;
-  }
-
-  if (isSub)
-    return Builder->CreateSub(LHSI->getOperand(0), RHS, "fold");
-  return Builder->CreateAdd(LHSI->getOperand(0), RHS, "fold");
-}
-
-/// enum for classifying (icmp eq (A & B), C) and (icmp ne (A & B), C)
-/// One of A and B is considered the mask, the other the value. This is
-/// described as the "AMask" or "BMask" part of the enum. If the enum
-/// contains only "Mask", then both A and B can be considered masks.
-/// If A is the mask, then it was proven, that (A & C) == C. This
-/// is trivial if C == A, or C == 0. If both A and C are constants, this
-/// proof is also easy.
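// Standalone sanity check (an illustrative sketch, not part of the patch)
// of the single-bit equivalence stated at the end of the comment above:
// for a one-bit mask A, "(A & B) == A" and "(A & B) != 0" coincide.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t A = 4; // arbitrary single-bit mask
  for (uint32_t B = 0; B < 64; ++B) {
    assert(((A & B) == A) == ((A & B) != 0));
    assert(((A & B) != A) == ((A & B) == 0));
  }
}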
-/// For the following explanations we assume that A is the mask. -/// The part "AllOnes" declares, that the comparison is true only -/// if (A & B) == A, or all bits of A are set in B. -/// Example: (icmp eq (A & 3), 3) -> FoldMskICmp_AMask_AllOnes -/// The part "AllZeroes" declares, that the comparison is true only -/// if (A & B) == 0, or all bits of A are cleared in B. -/// Example: (icmp eq (A & 3), 0) -> FoldMskICmp_Mask_AllZeroes -/// The part "Mixed" declares, that (A & B) == C and C might or might not -/// contain any number of one bits and zero bits. -/// Example: (icmp eq (A & 3), 1) -> FoldMskICmp_AMask_Mixed -/// The Part "Not" means, that in above descriptions "==" should be replaced -/// by "!=". -/// Example: (icmp ne (A & 3), 3) -> FoldMskICmp_AMask_NotAllOnes /// If the mask A contains a single bit, then the following is equivalent: /// (icmp eq (A & B), A) equals (icmp ne (A & B), 0) /// (icmp ne (A & B), A) equals (icmp eq (A & B), 0) enum MaskedICmpType { - FoldMskICmp_AMask_AllOnes = 1, - FoldMskICmp_AMask_NotAllOnes = 2, - FoldMskICmp_BMask_AllOnes = 4, - FoldMskICmp_BMask_NotAllOnes = 8, - FoldMskICmp_Mask_AllZeroes = 16, - FoldMskICmp_Mask_NotAllZeroes = 32, - FoldMskICmp_AMask_Mixed = 64, - FoldMskICmp_AMask_NotMixed = 128, - FoldMskICmp_BMask_Mixed = 256, - FoldMskICmp_BMask_NotMixed = 512 + AMask_AllOnes = 1, + AMask_NotAllOnes = 2, + BMask_AllOnes = 4, + BMask_NotAllOnes = 8, + Mask_AllZeros = 16, + Mask_NotAllZeros = 32, + AMask_Mixed = 64, + AMask_NotMixed = 128, + BMask_Mixed = 256, + BMask_NotMixed = 512 }; -/// Return the set of pattern classes (from MaskedICmpType) -/// that (icmp SCC (A & B), C) satisfies. -static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, - ICmpInst::Predicate SCC) -{ +/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C) +/// satisfies. +static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, + ICmpInst::Predicate Pred) { ConstantInt *ACst = dyn_cast<ConstantInt>(A); ConstantInt *BCst = dyn_cast<ConstantInt>(B); ConstantInt *CCst = dyn_cast<ConstantInt>(C); - bool icmp_eq = (SCC == ICmpInst::ICMP_EQ); - bool icmp_abit = (ACst && !ACst->isZero() && - ACst->getValue().isPowerOf2()); - bool icmp_bbit = (BCst && !BCst->isZero() && - BCst->getValue().isPowerOf2()); - unsigned result = 0; + bool IsEq = (Pred == ICmpInst::ICMP_EQ); + bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2()); + bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2()); + unsigned MaskVal = 0; if (CCst && CCst->isZero()) { // if C is zero, then both A and B qualify as mask - result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_AMask_Mixed | - FoldMskICmp_BMask_Mixed) - : (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_AMask_NotMixed | - FoldMskICmp_BMask_NotMixed)); - if (icmp_abit) - result |= (icmp_eq ? (FoldMskICmp_AMask_NotAllOnes | - FoldMskICmp_AMask_NotMixed) - : (FoldMskICmp_AMask_AllOnes | - FoldMskICmp_AMask_Mixed)); - if (icmp_bbit) - result |= (icmp_eq ? (FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_BMask_NotMixed) - : (FoldMskICmp_BMask_AllOnes | - FoldMskICmp_BMask_Mixed)); - return result; + MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed) + : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? (AMask_NotAllOnes | AMask_NotMixed) + : (AMask_AllOnes | AMask_Mixed)); + if (IsBPow2) + MaskVal |= (IsEq ? 
(BMask_NotAllOnes | BMask_NotMixed) + : (BMask_AllOnes | BMask_Mixed)); + return MaskVal; } + if (A == C) { - result |= (icmp_eq ? (FoldMskICmp_AMask_AllOnes | - FoldMskICmp_AMask_Mixed) - : (FoldMskICmp_AMask_NotAllOnes | - FoldMskICmp_AMask_NotMixed)); - if (icmp_abit) - result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_AMask_NotMixed) - : (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_AMask_Mixed)); - } else if (ACst && CCst && - ConstantExpr::getAnd(ACst, CCst) == CCst) { - result |= (icmp_eq ? FoldMskICmp_AMask_Mixed - : FoldMskICmp_AMask_NotMixed); + MaskVal |= (IsEq ? (AMask_AllOnes | AMask_Mixed) + : (AMask_NotAllOnes | AMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed) + : (Mask_AllZeros | AMask_Mixed)); + } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) { + MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed); } + if (B == C) { - result |= (icmp_eq ? (FoldMskICmp_BMask_AllOnes | - FoldMskICmp_BMask_Mixed) - : (FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_BMask_NotMixed)); - if (icmp_bbit) - result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_BMask_NotMixed) - : (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_BMask_Mixed)); - } else if (BCst && CCst && - ConstantExpr::getAnd(BCst, CCst) == CCst) { - result |= (icmp_eq ? FoldMskICmp_BMask_Mixed - : FoldMskICmp_BMask_NotMixed); - } - return result; + MaskVal |= (IsEq ? (BMask_AllOnes | BMask_Mixed) + : (BMask_NotAllOnes | BMask_NotMixed)); + if (IsBPow2) + MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed) + : (Mask_AllZeros | BMask_Mixed)); + } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) { + MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed); + } + + return MaskVal; } /// Convert an analysis of a masked ICmp into its equivalent if all boolean @@ -482,32 +365,30 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, /// involves swapping those bits over. static unsigned conjugateICmpMask(unsigned Mask) { unsigned NewMask; - NewMask = (Mask & (FoldMskICmp_AMask_AllOnes | FoldMskICmp_BMask_AllOnes | - FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed | - FoldMskICmp_BMask_Mixed)) + NewMask = (Mask & (AMask_AllOnes | BMask_AllOnes | Mask_AllZeros | + AMask_Mixed | BMask_Mixed)) << 1; - NewMask |= - (Mask & (FoldMskICmp_AMask_NotAllOnes | FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed | - FoldMskICmp_BMask_NotMixed)) - >> 1; + NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros | + AMask_NotMixed | BMask_NotMixed)) + >> 1; return NewMask; } -/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) -/// Return the set of pattern classes (from MaskedICmpType) -/// that both LHS and RHS satisfy. -static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, - Value*& B, Value*& C, - Value*& D, Value*& E, - ICmpInst *LHS, ICmpInst *RHS, - ICmpInst::Predicate &LHSCC, - ICmpInst::Predicate &RHSCC) { - if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0; +/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). +/// Return the set of pattern classes (from MaskedICmpType) that both LHS and +/// RHS satisfy. +static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, + Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, + ICmpInst::Predicate &PredL, + ICmpInst::Predicate &PredR) { + if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) + return 0; // vectors are not (yet?) 
supported - if (LHS->getOperand(0)->getType()->isVectorTy()) return 0; + if (LHS->getOperand(0)->getType()->isVectorTy()) + return 0; // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -517,9 +398,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, // above. Value *L1 = LHS->getOperand(0); Value *L2 = LHS->getOperand(1); - Value *L11,*L12,*L21,*L22; + Value *L11, *L12, *L21, *L22; // Check whether the icmp can be decomposed into a bit test. - if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) { + if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) { L21 = L22 = L1 = nullptr; } else { // Look for ANDs in the LHS icmp. @@ -543,22 +424,26 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, } // Bail if LHS was a icmp that can't be decomposed into an equality. - if (!ICmpInst::isEquality(LHSCC)) + if (!ICmpInst::isEquality(PredL)) return 0; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); - Value *R11,*R12; - bool ok = false; - if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) { + Value *R11, *R12; + bool Ok = false; + if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) { if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; + A = R11; + D = R12; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; + A = R12; + D = R11; } else { return 0; } - E = R2; R1 = nullptr; ok = true; + E = R2; + R1 = nullptr; + Ok = true; } else if (R1->getType()->isIntegerTy()) { if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { // As before, model no mask as a trivial mask if it'll let us do an @@ -568,60 +453,78 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, } if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; E = R2; ok = true; + A = R11; + D = R12; + E = R2; + Ok = true; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; E = R2; ok = true; + A = R12; + D = R11; + E = R2; + Ok = true; } } // Bail if RHS was a icmp that can't be decomposed into an equality. - if (!ICmpInst::isEquality(RHSCC)) + if (!ICmpInst::isEquality(PredR)) return 0; // Look for ANDs on the right side of the RHS icmp. - if (!ok && R2->getType()->isIntegerTy()) { + if (!Ok && R2->getType()->isIntegerTy()) { if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { R11 = R2; R12 = Constant::getAllOnesValue(R2->getType()); } if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; E = R1; ok = true; + A = R11; + D = R12; + E = R1; + Ok = true; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; E = R1; ok = true; + A = R12; + D = R11; + E = R1; + Ok = true; } else { return 0; } } - if (!ok) + if (!Ok) return 0; if (L11 == A) { - B = L12; C = L2; + B = L12; + C = L2; } else if (L12 == A) { - B = L11; C = L2; + B = L11; + C = L2; } else if (L21 == A) { - B = L22; C = L1; + B = L22; + C = L1; } else if (L22 == A) { - B = L21; C = L1; + B = L21; + C = L1; } - unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC); - unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC); + unsigned LeftType = getMaskedICmpType(A, B, C, PredL); + unsigned RightType = getMaskedICmpType(A, D, E, PredR); return LeftType & RightType; } /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// into a single (icmp(A & X) ==/!= Y). 
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - llvm::InstCombiner::BuilderTy *Builder) { + llvm::InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, - LHSCC, RHSCC); - if (Mask == 0) return nullptr; - assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && - "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + unsigned Mask = + getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); + if (Mask == 0) + return nullptr; + + assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && + "Expected equality predicates for masked type of icmps."); // In full generality: // (icmp (A & B) Op C) | (icmp (A & D) Op E) @@ -642,41 +545,43 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Mask = conjugateICmpMask(Mask); } - if (Mask & FoldMskICmp_Mask_AllZeroes) { + if (Mask & Mask_AllZeros) { // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) // -> (icmp eq (A & (B|D)), 0) - Value *NewOr = Builder->CreateOr(B, D); - Value *NewAnd = Builder->CreateAnd(A, NewOr); + Value *NewOr = Builder.CreateOr(B, D); + Value *NewAnd = Builder.CreateAnd(A, NewOr); // We can't use C as zero because we might actually handle // (icmp ne (A & B), B) & (icmp ne (A & D), D) // with B and D, having a single bit set. Value *Zero = Constant::getNullValue(A->getType()); - return Builder->CreateICmp(NewCC, NewAnd, Zero); + return Builder.CreateICmp(NewCC, NewAnd, Zero); } - if (Mask & FoldMskICmp_BMask_AllOnes) { + if (Mask & BMask_AllOnes) { // (icmp eq (A & B), B) & (icmp eq (A & D), D) // -> (icmp eq (A & (B|D)), (B|D)) - Value *NewOr = Builder->CreateOr(B, D); - Value *NewAnd = Builder->CreateAnd(A, NewOr); - return Builder->CreateICmp(NewCC, NewAnd, NewOr); + Value *NewOr = Builder.CreateOr(B, D); + Value *NewAnd = Builder.CreateAnd(A, NewOr); + return Builder.CreateICmp(NewCC, NewAnd, NewOr); } - if (Mask & FoldMskICmp_AMask_AllOnes) { + if (Mask & AMask_AllOnes) { // (icmp eq (A & B), A) & (icmp eq (A & D), A) // -> (icmp eq (A & (B&D)), A) - Value *NewAnd1 = Builder->CreateAnd(B, D); - Value *NewAnd2 = Builder->CreateAnd(A, NewAnd1); - return Builder->CreateICmp(NewCC, NewAnd2, A); + Value *NewAnd1 = Builder.CreateAnd(B, D); + Value *NewAnd2 = Builder.CreateAnd(A, NewAnd1); + return Builder.CreateICmp(NewCC, NewAnd2, A); } // Remaining cases assume at least that B and D are constant, and depend on // their actual values. This isn't strictly necessary, just a "handle the // easy cases for now" decision. 
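// Standalone sanity check (an illustrative sketch, not part of the patch)
// of the three unconditional masked-icmp folds above, exhaustive over an
// arbitrary 4-bit domain.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t A = 0; A < 16; ++A)
    for (uint32_t B = 0; B < 16; ++B)
      for (uint32_t D = 0; D < 16; ++D) {
        // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) -> (icmp eq (A & (B|D)), 0)
        assert((((A & B) == 0) && ((A & D) == 0)) == ((A & (B | D)) == 0));
        // (icmp eq (A & B), B) & (icmp eq (A & D), D) -> (icmp eq (A & (B|D)), B|D)
        assert((((A & B) == B) && ((A & D) == D)) == ((A & (B | D)) == (B | D)));
        // (icmp eq (A & B), A) & (icmp eq (A & D), A) -> (icmp eq (A & (B&D)), A)
        assert((((A & B) == A) && ((A & D) == A)) == ((A & (B & D)) == A));
      }
}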
ConstantInt *BCst = dyn_cast<ConstantInt>(B); - if (!BCst) return nullptr; + if (!BCst) + return nullptr; ConstantInt *DCst = dyn_cast<ConstantInt>(D); - if (!DCst) return nullptr; + if (!DCst) + return nullptr; - if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { + if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) @@ -689,7 +594,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (Mask & FoldMskICmp_AMask_NotAllOnes) { + + if (Mask & AMask_NotAllOnes) { // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) // Only valid if one of the masks is a superset of the other (check "B|D" is @@ -701,7 +607,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (Mask & FoldMskICmp_BMask_Mixed) { + + if (Mask & BMask_Mixed) { // (icmp eq (A & B), C) & (icmp eq (A & D), E) // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of @@ -713,23 +620,28 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D, having a single bit set. ConstantInt *CCst = dyn_cast<ConstantInt>(C); - if (!CCst) return nullptr; + if (!CCst) + return nullptr; ConstantInt *ECst = dyn_cast<ConstantInt>(E); - if (!ECst) return nullptr; - if (LHSCC != NewCC) + if (!ECst) + return nullptr; + if (PredL != NewCC) CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); - if (RHSCC != NewCC) + if (PredR != NewCC) ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + // If there is a conflict, we should actually return a false for the // whole construct. if (((BCst->getValue() & DCst->getValue()) & - (CCst->getValue() ^ ECst->getValue())) != 0) + (CCst->getValue() ^ ECst->getValue())).getBoolValue()) return ConstantInt::get(LHS->getType(), !IsAnd); - Value *NewOr1 = Builder->CreateOr(B, D); + + Value *NewOr1 = Builder.CreateOr(B, D); Value *NewOr2 = ConstantExpr::getOr(CCst, ECst); - Value *NewAnd = Builder->CreateAnd(A, NewOr1); - return Builder->CreateICmp(NewCC, NewAnd, NewOr2); + Value *NewAnd = Builder.CreateAnd(A, NewOr1); + return Builder.CreateICmp(NewCC, NewAnd, NewOr2); } + return nullptr; } @@ -778,23 +690,123 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, } // This simplification is only valid if the upper range is not negative. 
- bool IsNegative, IsNotNegative; - ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, /*Depth=*/0, Cmp1); - if (!IsNotNegative) + KnownBits Known = computeKnownBits(RangeEnd, /*Depth=*/0, Cmp1); + if (!Known.isNonNegative()) return nullptr; if (Inverted) NewPred = ICmpInst::getInversePredicate(NewPred); - return Builder->CreateICmp(NewPred, Input, RangeEnd); + return Builder.CreateICmp(NewPred, Input, RangeEnd); +} + +static Value * +foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, + bool JoinedByAnd, + InstCombiner::BuilderTy &Builder) { + Value *X = LHS->getOperand(0); + if (X != RHS->getOperand(0)) + return nullptr; + + const APInt *C1, *C2; + if (!match(LHS->getOperand(1), m_APInt(C1)) || + !match(RHS->getOperand(1), m_APInt(C2))) + return nullptr; + + // We only handle (X != C1 && X != C2) and (X == C1 || X == C2). + ICmpInst::Predicate Pred = LHS->getPredicate(); + if (Pred != RHS->getPredicate()) + return nullptr; + if (JoinedByAnd && Pred != ICmpInst::ICMP_NE) + return nullptr; + if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // The larger unsigned constant goes on the right. + if (C1->ugt(*C2)) + std::swap(C1, C2); + + APInt Xor = *C1 ^ *C2; + if (Xor.isPowerOf2()) { + // If LHSC and RHSC differ by only one bit, then set that bit in X and + // compare against the larger constant: + // (X == C1 || X == C2) --> (X | (C1 ^ C2)) == C2 + // (X != C1 && X != C2) --> (X | (C1 ^ C2)) != C2 + // We choose an 'or' with a Pow2 constant rather than the inverse mask with + // 'and' because that may lead to smaller codegen from a smaller constant. + Value *Or = Builder.CreateOr(X, ConstantInt::get(X->getType(), Xor)); + return Builder.CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2)); + } + + // Special case: get the ordering right when the values wrap around zero. + // Ie, we assumed the constants were unsigned when swapping earlier. + if (C1->isNullValue() && C2->isAllOnesValue()) + std::swap(C1, C2); + + if (*C1 == *C2 - 1) { + // (X == 13 || X == 14) --> X - 13 <=u 1 + // (X != 13 && X != 14) --> X - 13 >u 1 + // An 'add' is the canonical IR form, so favor that over a 'sub'. + Value *Add = Builder.CreateAdd(X, ConstantInt::get(X->getType(), -(*C1))); + auto NewPred = JoinedByAnd ? 
ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE; + return Builder.CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1)); + } + + return nullptr; +} + +// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) +// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) +Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, + bool JoinedByAnd, + Instruction &CxtI) { + ICmpInst::Predicate Pred = LHS->getPredicate(); + if (Pred != RHS->getPredicate()) + return nullptr; + if (JoinedByAnd && Pred != ICmpInst::ICMP_NE) + return nullptr; + if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // TODO support vector splats + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); + if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero()) + return nullptr; + + Value *A, *B, *C, *D; + if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) && + match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) { + if (A == D || B == D) + std::swap(C, D); + if (B == C) + std::swap(A, B); + + if (A == C && + isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) && + isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) { + Value *Mask = Builder.CreateOr(B, D); + Value *Masked = Builder.CreateAnd(A, Mask); + auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + return Builder.CreateICmp(NewPred, Masked, Mask); + } + } + + return nullptr; } /// Fold (icmp)&(icmp) if possible. -Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); +Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, + Instruction &CxtI) { + // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) + // if K1 and K2 are a one-bit mask. + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI)) + return V; + + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(LHSCC, RHSCC)) { + if (PredicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -819,86 +831,91 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false)) return V; + if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). 
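// Standalone sanity check (an illustrative sketch, not part of the patch;
// 13/14 and 8/12 are arbitrary example constants) of the two equality-cmp
// folds implemented in foldAndOrOfEqualityCmpsWithConstants above.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned I = 0; I < 256; ++I) {
    uint8_t X = (uint8_t)I;
    // (X == 13 || X == 14) --> (X - 13) u<= 1 (adjacent constants)
    assert(((X == 13) || (X == 14)) == ((uint8_t)(X - 13) <= 1));
    // (X == 8 || X == 12) --> (X | (8 ^ 12)) == 12 (constants differ by one bit)
    assert(((X == 8) || (X == 12)) == ((uint8_t)(X | (8 ^ 12)) == 12));
  }
}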
- Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); - ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); - if (!LHSCst || !RHSCst) return nullptr; + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); + if (!LHSC || !RHSC) + return nullptr; - if (LHSCst == RHSCst && LHSCC == RHSCC) { + if (LHSC == RHSC && PredL == PredR) { // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) // where C is a power of 2 or // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) || - (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); + if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) || + (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) { + Value *NewOr = Builder.CreateOr(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewOr, LHSC); } } // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 // where CMAX is the all ones value for the truncated type, // iff the lower bits of C2 and CA are zero. - if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC && - LHS->hasOneUse() && RHS->hasOneUse()) { + if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() && + RHS->hasOneUse()) { Value *V; - ConstantInt *AndCst, *SmallCst = nullptr, *BigCst = nullptr; + ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr; // (trunc x) == C1 & (and x, CA) == C2 // (and x, CA) == C2 & (trunc x) == C1 - if (match(Val2, m_Trunc(m_Value(V))) && - match(Val, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { - SmallCst = RHSCst; - BigCst = LHSCst; - } else if (match(Val, m_Trunc(m_Value(V))) && - match(Val2, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { - SmallCst = LHSCst; - BigCst = RHSCst; + if (match(RHS0, m_Trunc(m_Value(V))) && + match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + SmallC = RHSC; + BigC = LHSC; + } else if (match(LHS0, m_Trunc(m_Value(V))) && + match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + SmallC = LHSC; + BigC = RHSC; } - if (SmallCst && BigCst) { - unsigned BigBitSize = BigCst->getType()->getBitWidth(); - unsigned SmallBitSize = SmallCst->getType()->getBitWidth(); + if (SmallC && BigC) { + unsigned BigBitSize = BigC->getType()->getBitWidth(); + unsigned SmallBitSize = SmallC->getType()->getBitWidth(); // Check that the low bits are zero. APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); - if ((Low & AndCst->getValue()) == 0 && (Low & BigCst->getValue()) == 0) { - Value *NewAnd = Builder->CreateAnd(V, Low | AndCst->getValue()); - APInt N = SmallCst->getValue().zext(BigBitSize) | BigCst->getValue(); - Value *NewVal = ConstantInt::get(AndCst->getType()->getContext(), N); - return Builder->CreateICmp(LHSCC, NewAnd, NewVal); + if ((Low & AndC->getValue()).isNullValue() && + (Low & BigC->getValue()).isNullValue()) { + Value *NewAnd = Builder.CreateAnd(V, Low | AndC->getValue()); + APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue(); + Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N); + return Builder.CreateICmp(PredL, NewAnd, NewVal); } } } // From here on, we only handle: // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. 
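// Standalone sanity check (an illustrative sketch, not part of the patch;
// 16 is an arbitrary power of two) of the shared-constant folds just above.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      uint8_t X = (uint8_t)A, Y = (uint8_t)B;
      // (X u< 16 & Y u< 16) --> ((X|Y) u< 16); requires C to be a power of 2
      assert(((X < 16) && (Y < 16)) == ((uint8_t)(X | Y) < 16));
      // (X == 0 & Y == 0) --> ((X|Y) == 0)
      assert(((X == 0) && (Y == 0)) == ((X | Y) == 0));
    }
}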
- if (Val != Val2) return nullptr; + if (LHS0 != RHS0) + return nullptr; - // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. - if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || - RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || - LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || - RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) + // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. + if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || + PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || + PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || + PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) return nullptr; // We can't fold (ugt x, C) & (sgt x, C2). - if (!PredicatesFoldable(LHSCC, RHSCC)) + if (!PredicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. bool ShouldSwap; - if (CmpInst::isSigned(LHSCC) || - (ICmpInst::isEquality(LHSCC) && - CmpInst::isSigned(RHSCC))) - ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); + if (CmpInst::isSigned(PredL) || + (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) + ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); else - ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); + ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); if (ShouldSwap) { std::swap(LHS, RHS); - std::swap(LHSCst, RHSCst); - std::swap(LHSCC, RHSCC); + std::swap(LHSC, RHSC); + std::swap(PredL, PredR); } // At this point, we know we have two icmp instructions @@ -907,113 +924,55 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know // (from the icmp folding check above), that the two constants // are not equal and that the larger constant is on the RHS - assert(LHSCst != RHSCst && "Compares not folded above?"); + assert(LHSC != RHSC && "Compares not folded above?"); - switch (LHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13 - case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13 - case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13 - return LHS; - } + switch (PredL) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_NE: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_ULT: - if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 - return Builder->CreateICmpULT(Val, LHSCst); - if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13 + return Builder.CreateICmpULT(LHS0, LHSC); + if (LHSC->isZero()) // (X != 0 & X u< 14) -> X-1 u< 13 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); - break; // (X != 13 & X u< 15) -> no change + break; // (X != 13 & X u< 15) -> no change case ICmpInst::ICMP_SLT: - if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 - return Builder->CreateICmpSLT(Val, LHSCst); - break; // (X != 13 & X s< 15) -> no change - case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15 - 
case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15 - return RHS; + if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13 + return Builder.CreateICmpSLT(LHS0, LHSC); + break; // (X != 13 & X s< 15) -> no change case ICmpInst::ICMP_NE: - // Special case to get the ordering right when the values wrap around - // zero. - if (LHSCst->getValue() == 0 && RHSCst->getValue().isAllOnesValue()) - std::swap(LHSCst, RHSCst); - if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1 - Constant *AddCST = ConstantExpr::getNeg(LHSCst); - Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); - return Builder->CreateICmpUGT(Add, ConstantInt::get(Add->getType(), 1), - Val->getName()+".cmp"); - } - break; // (X != 13 & X != 15) -> no change - } - break; - case ICmpInst::ICMP_ULT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false - case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - case ICmpInst::ICMP_SGT: // (X u< 13 & X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13 - case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13 - return LHS; - case ICmpInst::ICMP_SLT: // (X u< 13 & X s< 15) -> no change - break; - } - break; - case ICmpInst::ICMP_SLT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_UGT: // (X s< 13 & X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13 - case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13 - return LHS; - case ICmpInst::ICMP_ULT: // (X s< 13 & X u< 15) -> no change + // Potential folds for this case should already be handled. 
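// Standalone sanity check (an illustrative sketch, not part of the patch;
// 13/14/15 are arbitrary example constants) of two of the switch folds and
// the insertRangeTest rewrite used above.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned I = 0; I < 256; ++I) {
    uint8_t X = (uint8_t)I;
    // (X != 13 & X u< 14) --> X u< 13
    assert(((X != 13) && (X < 14)) == (X < 13));
    // (X u> 13 & X u< 15) --> (X - 14) u< 1
    assert(((X > 13) && (X < 15)) == ((uint8_t)(X - 14) < 1));
  }
}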
break; } break; case ICmpInst::ICMP_UGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15 - return RHS; - case ICmpInst::ICMP_SGT: // (X u> 13 & X s> 15) -> no change - break; + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_NE: - if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14 - return Builder->CreateICmp(LHSCC, Val, RHSCst); - break; // (X u> 13 & X != 15) -> no change - case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14 + return Builder.CreateICmp(PredL, LHS0, RHSC); + break; // (X u> 13 & X != 15) -> no change + case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); - case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15 - return RHS; - case ICmpInst::ICMP_UGT: // (X s> 13 & X u> 15) -> no change - break; + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_NE: - if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14 - return Builder->CreateICmp(LHSCC, Val, RHSCst); - break; // (X s> 13 & X != 15) -> no change - case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), - true, true); - case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change - break; + if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14 + return Builder.CreateICmp(PredL, LHS0, RHSC); + break; // (X s> 13 & X != 15) -> no change + case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true, + true); } break; } @@ -1023,7 +982,7 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { /// Optimize (fcmp)&(fcmp). NOTE: Unlike the rest of instcombine, this returns /// a Value which should already be inserted into the function. -Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { +Value *InstCombiner::foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); @@ -1058,15 +1017,15 @@ Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { // If either of the constants are nans, then the whole thing returns // false. if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN()) - return Builder->getFalse(); - return Builder->CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); + return Builder.getFalse(); + return Builder.CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); } // Handle vector zeros. This occurs because the canonical form of // "fcmp ord x,x" is "fcmp ord x, 0". 
if (isa<ConstantAggregateZero>(LHS->getOperand(1)) && isa<ConstantAggregateZero>(RHS->getOperand(1))) - return Builder->CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); + return Builder.CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); return nullptr; } @@ -1077,26 +1036,22 @@ Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { /// (~A & ~B) == (~(A | B)) /// (~A | ~B) == (~(A & B)) static Instruction *matchDeMorgansLaws(BinaryOperator &I, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy &Builder) { auto Opcode = I.getOpcode(); assert((Opcode == Instruction::And || Opcode == Instruction::Or) && "Trying to match De Morgan's Laws with something other than and/or"); + // Flip the logic operation. - if (Opcode == Instruction::And) - Opcode = Instruction::Or; - else - Opcode = Instruction::And; + Opcode = (Opcode == Instruction::And) ? Instruction::Or : Instruction::And; - Value *Op0 = I.getOperand(0); - Value *Op1 = I.getOperand(1); - // TODO: Use pattern matchers instead of dyn_cast. - if (Value *Op0NotVal = dyn_castNotVal(Op0)) - if (Value *Op1NotVal = dyn_castNotVal(Op1)) - if (Op0->hasOneUse() && Op1->hasOneUse()) { - Value *LogicOp = Builder->CreateBinOp(Opcode, Op0NotVal, Op1NotVal, - I.getName() + ".demorgan"); - return BinaryOperator::CreateNot(LogicOp); - } + Value *A, *B; + if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) && + match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) && + !IsFreeToInvert(A, A->hasOneUse()) && + !IsFreeToInvert(B, B->hasOneUse())) { + Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan"); + return BinaryOperator::CreateNot(AndOr); + } return nullptr; } @@ -1125,7 +1080,7 @@ bool InstCombiner::shouldOptimizeCast(CastInst *CI) { /// Fold {and,or,xor} (cast X), C. static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy &Builder) { Constant *C; if (!match(Logic.getOperand(1), m_Constant(C))) return nullptr; @@ -1134,26 +1089,17 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, Type *DestTy = Logic.getType(); Type *SrcTy = Cast->getSrcTy(); - // If the first operand is bitcast, move the logic operation ahead of the - // bitcast (do the logic operation in the original type). This can eliminate - // bitcasts and allow combines that would otherwise be impeded by the bitcast. + // Move the logic operation ahead of a zext if the constant is unchanged in + // the smaller source type. Performing the logic in a smaller type may provide + // more information to later folds, and the smaller logic instruction may be + // cheaper (particularly in the case of vectors). Value *X; - if (match(Cast, m_BitCast(m_Value(X)))) { - Value *NewConstant = ConstantExpr::getBitCast(C, SrcTy); - Value *NewOp = Builder->CreateBinOp(LogicOpc, X, NewConstant); - return CastInst::CreateBitOrPointerCast(NewOp, DestTy); - } - - // Similarly, move the logic operation ahead of a zext if the constant is - // unchanged in the smaller source type. Performing the logic in a smaller - // type may provide more information to later folds, and the smaller logic - // instruction may be cheaper (particularly in the case of vectors). 
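// Standalone sanity check (an illustrative sketch, not part of the patch)
// of the two De Morgan identities matchDeMorgansLaws relies on, over an
// arbitrary 4-bit domain (masking with 0xF emulates a 4-bit 'not').
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t M = 0xF;
  for (uint32_t A = 0; A < 16; ++A)
    for (uint32_t B = 0; B < 16; ++B) {
      // (~A & ~B) == ~(A | B)
      assert(((~A & M) & (~B & M)) == ((~(A | B)) & M));
      // (~A | ~B) == ~(A & B)
      assert((((~A & M) | (~B & M)) & M) == ((~(A & B)) & M));
    }
}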
if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy); if (ZextTruncC == C) { // LogicOpc (zext X), C --> zext (LogicOpc X, C) - Value *NewOp = Builder->CreateBinOp(LogicOpc, X, TruncC); + Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC); return new ZExtInst(NewOp, DestTy); } } @@ -1196,7 +1142,7 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { // fold logic(cast(A), cast(B)) -> cast(logic(A, B)) if (shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) { - Value *NewOp = Builder->CreateBinOp(LogicOpc, Cast0Src, Cast1Src, + Value *NewOp = Builder.CreateBinOp(LogicOpc, Cast0Src, Cast1Src, I.getName()); return CastInst::Create(CastOpcode, NewOp, DestTy); } @@ -1210,8 +1156,8 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src); ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src); if (ICmp0 && ICmp1) { - Value *Res = LogicOpc == Instruction::And ? FoldAndOfICmps(ICmp0, ICmp1) - : FoldOrOfICmps(ICmp0, ICmp1, &I); + Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1, I) + : foldOrOfICmps(ICmp0, ICmp1, I); if (Res) return CastInst::Create(CastOpcode, Res, DestTy); return nullptr; @@ -1222,8 +1168,8 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src); FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src); if (FCmp0 && FCmp1) { - Value *Res = LogicOpc == Instruction::And ? FoldAndOfFCmps(FCmp0, FCmp1) - : FoldOrOfFCmps(FCmp0, FCmp1); + Value *Res = LogicOpc == Instruction::And ? foldAndOfFCmps(FCmp0, FCmp1) + : foldOrOfFCmps(FCmp0, FCmp1); if (Res) return CastInst::Create(CastOpcode, Res, DestTy); return nullptr; @@ -1242,15 +1188,14 @@ static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { // Fold (and (sext bool to A), B) --> (select bool, B, 0) Value *X = nullptr; - if (match(Op0, m_SExt(m_Value(X))) && - X->getType()->getScalarType()->isIntegerTy(1)) { + if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { Value *Zero = Constant::getNullValue(Op1->getType()); return SelectInst::Create(X, Op1, Zero); } // Fold (and ~(sext bool to A), B) --> (select bool, 0, B) if (match(Op0, m_Not(m_SExt(m_Value(X)))) && - X->getType()->getScalarType()->isIntegerTy(1)) { + X->getType()->isIntOrIntVectorTy(1)) { Value *Zero = Constant::getNullValue(Op0->getType()); return SelectInst::Create(X, Zero, Op1); } @@ -1258,6 +1203,58 @@ static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { return nullptr; } +static Instruction *foldAndToXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.getOpcode() == Instruction::And); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *A, *B; + + // Operand complexity canonicalization guarantees that the 'or' is Op0. 
+ // (A | B) & ~(A & B) --> A ^ B + // (A | B) & ~(B & A) --> A ^ B + if (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_And(m_Specific(A), m_Specific(B))))) + return BinaryOperator::CreateXor(A, B); + + // (A | ~B) & (~A | B) --> ~(A ^ B) + // (A | ~B) & (B | ~A) --> ~(A ^ B) + // (~B | A) & (~A | B) --> ~(A ^ B) + // (~B | A) & (B | ~A) --> ~(A ^ B) + if (Op0->hasOneUse() || Op1->hasOneUse()) + if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + return nullptr; +} + +static Instruction *foldOrToXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.getOpcode() == Instruction::Or); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *A, *B; + + // Operand complexity canonicalization guarantees that the 'and' is Op0. + // (A & B) | ~(A | B) --> ~(A ^ B) + // (A & B) | ~(B | A) --> ~(A ^ B) + if (Op0->hasOneUse() || Op1->hasOneUse()) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + // (A & ~B) | (~A & B) --> A ^ B + // (A & ~B) | (B & ~A) --> A ^ B + // (~B & A) | (~A & B) --> A ^ B + // (~B & A) | (B & ~A) --> A ^ B + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))) + return BinaryOperator::CreateXor(A, B); + + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -1268,11 +1265,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, DL, &TLI, &DT, &AC)) - return replaceInstUsesWith(I, V); - - // (A|B)&(A|C) -> A|(B&C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = SimplifyAndInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole @@ -1280,9 +1273,27 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; - if (Value *V = SimplifyBSwap(I)) + // Do this before using distributive laws to catch simple and/or/not patterns. + if (Instruction *Xor = foldAndToXor(I, Builder)) + return Xor; + + // (A|B)&(A|C) -> A|(B&C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); + if (Value *V = SimplifyBSwap(I, Builder)) + return replaceInstUsesWith(I, V); + + if (match(Op1, m_One())) { + // (1 << x) & 1 --> zext(x == 0) + // (1 >> x) & 1 --> zext(x == 0) + Value *X; + if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X))))) { + Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(I.getType(), 0)); + return new ZExtInst(IsZero, I.getType()); + } + } + if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { const APInt &AndRHSMask = AndRHS->getValue(); @@ -1300,65 +1311,47 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { APInt NotAndRHS(~AndRHSMask); if (MaskedValueIsZero(Op0LHS, NotAndRHS, 0, &I)) { // Not masking anything out for the LHS, move to RHS. 
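// Standalone sanity check (an illustrative sketch, not part of the patch)
// of the and/or-to-xor rewrites in foldAndToXor and foldOrToXor above,
// over an arbitrary 4-bit domain (masking with 0xF emulates a 4-bit 'not').
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t M = 0xF;
  for (uint32_t A = 0; A < 16; ++A)
    for (uint32_t B = 0; B < 16; ++B) {
      // (A | B) & ~(A & B) --> A ^ B
      assert(((A | B) & (~(A & B) & M)) == (A ^ B));
      // (A & ~B) | (~A & B) --> A ^ B
      assert(((A & (~B & M)) | ((~A & M) & B)) == (A ^ B));
      // (A | ~B) & (~A | B) --> ~(A ^ B)
      assert((((A | (~B & M)) & ((~A & M) | B)) & M) == ((~(A ^ B)) & M));
    }
}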
- Value *NewRHS = Builder->CreateAnd(Op0RHS, AndRHS, - Op0RHS->getName()+".masked"); + Value *NewRHS = Builder.CreateAnd(Op0RHS, AndRHS, + Op0RHS->getName()+".masked"); return BinaryOperator::Create(Op0I->getOpcode(), Op0LHS, NewRHS); } if (!isa<Constant>(Op0RHS) && MaskedValueIsZero(Op0RHS, NotAndRHS, 0, &I)) { // Not masking anything out for the RHS, move to LHS. - Value *NewLHS = Builder->CreateAnd(Op0LHS, AndRHS, - Op0LHS->getName()+".masked"); + Value *NewLHS = Builder.CreateAnd(Op0LHS, AndRHS, + Op0LHS->getName()+".masked"); return BinaryOperator::Create(Op0I->getOpcode(), NewLHS, Op0RHS); } break; } - case Instruction::Add: - // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS. - // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 - // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 - if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I)) - return BinaryOperator::CreateAnd(V, AndRHS); - if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I)) - return BinaryOperator::CreateAnd(V, AndRHS); // Add commutes - break; + } + // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth + // of X and OP behaves well when given trunc(C1) and X. + switch (Op0I->getOpcode()) { + default: + break; + case Instruction::Xor: + case Instruction::Or: + case Instruction::Mul: + case Instruction::Add: case Instruction::Sub: - // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS. - // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 - // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 - if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I)) - return BinaryOperator::CreateAnd(V, AndRHS); - - // -x & 1 -> x & 1 - if (AndRHSMask == 1 && match(Op0LHS, m_Zero())) - return BinaryOperator::CreateAnd(Op0RHS, AndRHS); - - // (A - N) & AndRHS -> -N & AndRHS iff A&AndRHS==0 and AndRHS - // has 1's for all bits that the subtraction with A might affect. - if (Op0I->hasOneUse() && !match(Op0LHS, m_Zero())) { - uint32_t BitWidth = AndRHSMask.getBitWidth(); - uint32_t Zeros = AndRHSMask.countLeadingZeros(); - APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros); - - if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) { - Value *NewNeg = Builder->CreateNeg(Op0RHS); - return BinaryOperator::CreateAnd(NewNeg, AndRHS); + Value *X; + ConstantInt *C1; + if (match(Op0I, m_c_BinOp(m_ZExt(m_Value(X)), m_ConstantInt(C1)))) { + if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { + auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); + Value *BinOp; + if (isa<ZExtInst>(Op0LHS)) + BinOp = Builder.CreateBinOp(Op0I->getOpcode(), X, TruncC1); + else + BinOp = Builder.CreateBinOp(Op0I->getOpcode(), TruncC1, X); + auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType()); + auto *And = Builder.CreateAnd(BinOp, TruncC2); + return new ZExtInst(And, I.getType()); } } - break; - - case Instruction::Shl: - case Instruction::LShr: - // (1 << x) & 1 --> zext(x == 0) - // (1 >> x) & 1 --> zext(x == 0) - if (AndRHSMask == 1 && Op0LHS == AndRHS) { - Value *NewICmp = - Builder->CreateICmpEQ(Op0RHS, Constant::getNullValue(I.getType())); - return new ZExtInst(NewICmp, I.getType()); - } - break; } if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) @@ -1375,34 +1368,23 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { // into : and (trunc X to T), trunc(YC) & C2 // This will fold the two constants together, which may allow // other simplifications. 
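// Standalone sanity check (an illustrative sketch, not part of the patch;
// 0x1234 and 0xE1 are arbitrary example constants) of the zext fold above:
// ((C1 OP zext(X)) & C2) == zext((trunc(C1) OP X) & trunc(C2)) whenever C2
// fits in the width of X.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C1 = 0x1234u, C2 = 0xE1u; // C2 fits in 8 bits
  for (unsigned I = 0; I < 256; ++I) {
    uint8_t X = (uint8_t)I;
    assert(((C1 + (uint32_t)X) & C2) ==
           (uint32_t)(uint8_t)(((uint8_t)C1 + X) & (uint8_t)C2));
    assert(((C1 ^ (uint32_t)X) & C2) ==
           (uint32_t)(uint8_t)(((uint8_t)C1 ^ X) & (uint8_t)C2));
  }
}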
- Value *NewCast = Builder->CreateTrunc(X, I.getType(), "and.shrunk"); + Value *NewCast = Builder.CreateTrunc(X, I.getType(), "and.shrunk"); Constant *C3 = ConstantExpr::getTrunc(YC, I.getType()); C3 = ConstantExpr::getAnd(C3, AndRHS); return BinaryOperator::CreateAnd(NewCast, C3); } } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) return DeMorgan; { - Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - // (A|B) & ~(A&B) -> A^B - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_And(m_Value(C), m_Value(D)))) && - ((A == C && B == D) || (A == D && B == C))) - return BinaryOperator::CreateXor(A, B); - - // ~(A&B) & (A|B) -> A^B - if (match(Op1, m_Or(m_Value(A), m_Value(B))) && - match(Op0, m_Not(m_And(m_Value(C), m_Value(D)))) && - ((A == C && B == D) || (A == D && B == C))) - return BinaryOperator::CreateXor(A, B); - + Value *A = nullptr, *B = nullptr, *C = nullptr; // A&(A^B) => A & ~B { Value *tmpOp0 = Op0; @@ -1424,38 +1406,35 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { // an endless loop. By checking that A is non-constant we ensure that // we will never get to the loop. if (A == tmpOp0 && !isa<Constant>(A)) // A&(A^B) -> A & ~B - return BinaryOperator::CreateAnd(A, Builder->CreateNot(B)); + return BinaryOperator::CreateAnd(A, Builder.CreateNot(B)); } } - // (A&((~A)|B)) -> A&B - if (match(Op0, m_Or(m_Not(m_Specific(Op1)), m_Value(A))) || - match(Op0, m_Or(m_Value(A), m_Not(m_Specific(Op1))))) - return BinaryOperator::CreateAnd(A, Op1); - if (match(Op1, m_Or(m_Not(m_Specific(Op0)), m_Value(A))) || - match(Op1, m_Or(m_Value(A), m_Not(m_Specific(Op0))))) - return BinaryOperator::CreateAnd(A, Op0); - // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) - if (Op1->hasOneUse() || cast<BinaryOperator>(Op1)->hasOneUse()) - return BinaryOperator::CreateAnd(Op0, Builder->CreateNot(C)); + if (Op1->hasOneUse() || IsFreeToInvert(C, C->hasOneUse())) + return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C)); // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) - if (Op0->hasOneUse() || cast<BinaryOperator>(Op0)->hasOneUse()) - return BinaryOperator::CreateAnd(Op1, Builder->CreateNot(C)); + if (Op0->hasOneUse() || IsFreeToInvert(C, C->hasOneUse())) + return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C)); // (A | B) & ((~A) ^ B) -> (A & B) - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_Xor(m_Not(m_Specific(A)), m_Specific(B)))) + // (A | B) & (B ^ (~A)) -> (A & B) + // (B | A) & ((~A) ^ B) -> (A & B) + // (B | A) & (B ^ (~A)) -> (A & B) + if (match(Op1, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) && + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateAnd(A, B); // ((~A) ^ B) & (A | B) -> (A & B) // ((~A) ^ B) & (B | A) -> (A & B) - if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) && + // (B ^ (~A)) & (A | B) -> (A & B) + // (B ^ (~A)) & (B | A) -> (A & B) + if (match(Op0, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) && match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateAnd(A, B); } @@ -1464,7 +1443,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); ICmpInst *RHS = 
dyn_cast<ICmpInst>(Op1); if (LHS && RHS) - if (Value *Res = FoldAndOfICmps(LHS, RHS)) + if (Value *Res = foldAndOfICmps(LHS, RHS, I)) return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary @@ -1472,26 +1451,26 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { Value *X, *Y; if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = FoldAndOfICmps(LHS, Cmp)) - return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + if (Value *Res = foldAndOfICmps(LHS, Cmp, I)) + return replaceInstUsesWith(I, Builder.CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = FoldAndOfICmps(LHS, Cmp)) - return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + if (Value *Res = foldAndOfICmps(LHS, Cmp, I)) + return replaceInstUsesWith(I, Builder.CreateAnd(Res, X)); } if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = FoldAndOfICmps(Cmp, RHS)) - return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + if (Value *Res = foldAndOfICmps(Cmp, RHS, I)) + return replaceInstUsesWith(I, Builder.CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = FoldAndOfICmps(Cmp, RHS)) - return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + if (Value *Res = foldAndOfICmps(Cmp, RHS, I)) + return replaceInstUsesWith(I, Builder.CreateAnd(Res, X)); } } // If and'ing two fcmp, try combine them into one. if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = FoldAndOfFCmps(LHS, RHS)) + if (Value *Res = foldAndOfFCmps(LHS, RHS)) return replaceInstUsesWith(I, Res); if (Instruction *CastedAnd = foldCastedBitwiseLogic(I)) @@ -1566,16 +1545,19 @@ static Value *getSelectCondition(Value *A, Value *B, InstCombiner::BuilderTy &Builder) { // If these are scalars or vectors of i1, A can be used directly. Type *Ty = A->getType(); - if (match(A, m_Not(m_Specific(B))) && Ty->getScalarType()->isIntegerTy(1)) + if (match(A, m_Not(m_Specific(B))) && Ty->isIntOrIntVectorTy(1)) return A; // If A and B are sign-extended, look through the sexts to find the booleans. Value *Cond; + Value *NotB; if (match(A, m_SExt(m_Value(Cond))) && - Cond->getType()->getScalarType()->isIntegerTy(1) && - match(B, m_CombineOr(m_Not(m_SExt(m_Specific(Cond))), - m_SExt(m_Not(m_Specific(Cond)))))) - return Cond; + Cond->getType()->isIntOrIntVectorTy(1) && + match(B, m_OneUse(m_Not(m_Value(NotB))))) { + NotB = peekThroughBitcast(NotB, true); + if (match(NotB, m_SExt(m_Specific(Cond)))) + return Cond; + } // All scalar (and most vector) possibilities should be handled now. // Try more matches that only apply to non-splat constant vectors. @@ -1592,7 +1574,7 @@ static Value *getSelectCondition(Value *A, Value *B, // operand, see if the constants are inverse bitmasks. if (match(A, (m_Xor(m_SExt(m_Value(Cond)), m_Constant(AC)))) && match(B, (m_Xor(m_SExt(m_Specific(Cond)), m_Constant(BC)))) && - Cond->getType()->getScalarType()->isIntegerTy(1) && + Cond->getType()->isIntOrIntVectorTy(1) && areInverseVectorBitmasks(AC, BC)) { AC = ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); return Builder.CreateXor(Cond, AC); @@ -1607,12 +1589,8 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D, // The potential condition of the select may be bitcasted. 
In that case, look // through its bitcast and the corresponding bitcast of the 'not' condition. Type *OrigType = A->getType(); - Value *SrcA, *SrcB; - if (match(A, m_OneUse(m_BitCast(m_Value(SrcA)))) && - match(B, m_OneUse(m_BitCast(m_Value(SrcB))))) { - A = SrcA; - B = SrcB; - } + A = peekThroughBitcast(A, true); + B = peekThroughBitcast(B, true); if (Value *Cond = getSelectCondition(A, B, Builder)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) @@ -1628,46 +1606,17 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D, } /// Fold (icmp)|(icmp) if possible. -Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, - Instruction *CxtI) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - +Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, + Instruction &CxtI) { // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // if K1 and K2 are a one-bit mask. - ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); - - if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero() && - RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { - - BinaryOperator *LAnd = dyn_cast<BinaryOperator>(LHS->getOperand(0)); - BinaryOperator *RAnd = dyn_cast<BinaryOperator>(RHS->getOperand(0)); - if (LAnd && RAnd && LAnd->hasOneUse() && RHS->hasOneUse() && - LAnd->getOpcode() == Instruction::And && - RAnd->getOpcode() == Instruction::And) { - - Value *Mask = nullptr; - Value *Masked = nullptr; - if (LAnd->getOperand(0) == RAnd->getOperand(0) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(1), DL, false, 0, &AC, CxtI, - &DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(1), DL, false, 0, &AC, CxtI, - &DT)) { - Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1)); - Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask); - } else if (LAnd->getOperand(1) == RAnd->getOperand(1) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(0), DL, false, 0, &AC, - CxtI, &DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(0), DL, false, 0, &AC, - CxtI, &DT)) { - Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0)); - Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask); - } + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI)) + return V; - if (Masked) - return Builder->CreateICmp(ICmpInst::ICMP_NE, Masked, Mask); - } - } + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) @@ -1680,52 +1629,52 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. // This implies all values in the two ranges differ by exactly one bit. 
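// Illustrative sketch, not from the original patch: a concrete instance of
// the fold above with made-up i8 constants that satisfy conditions 1-4
// (C1 = -8, C2 = -24, C3 = 4, so the two ranges are [8,12) and [24,28)):
//   %l = add i8 %x, -8
//   %c1 = icmp ult i8 %l, 4       ; x in [8, 12)
//   %r = add i8 %x, -24
//   %c2 = icmp ult i8 %r, 4       ; x in [24, 28)
//   %or = or i1 %c1, %c2
// C1 ^ C2 == 16, so the ranges differ only in bit 4 and the 'or' becomes:
//   %m = and i8 %x, -17           ; ~(C1 ^ C2) clears bit 4
//   %a = add i8 %m, -8            ; max(C1, C2) as unsigned values
//   %or = icmp ult i8 %a, 4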
- if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) && - LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() && - RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() && - LHSCst->getValue() == (RHSCst->getValue())) { + if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && + PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && + LHSC->getType() == RHSC->getType() && + LHSC->getValue() == (RHSC->getValue())) { Value *LAdd = LHS->getOperand(0); Value *RAdd = RHS->getOperand(0); Value *LAddOpnd, *RAddOpnd; - ConstantInt *LAddCst, *RAddCst; - if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) && - match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) && - LAddCst->getValue().ugt(LHSCst->getValue()) && - RAddCst->getValue().ugt(LHSCst->getValue())) { - - APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue(); - if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) { - ConstantInt *MaxAddCst = nullptr; - if (LAddCst->getValue().ult(RAddCst->getValue())) - MaxAddCst = RAddCst; + ConstantInt *LAddC, *RAddC; + if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) && + match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) && + LAddC->getValue().ugt(LHSC->getValue()) && + RAddC->getValue().ugt(LHSC->getValue())) { + + APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); + if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) { + ConstantInt *MaxAddC = nullptr; + if (LAddC->getValue().ult(RAddC->getValue())) + MaxAddC = RAddC; else - MaxAddCst = LAddCst; + MaxAddC = LAddC; - APInt RRangeLow = -RAddCst->getValue(); - APInt RRangeHigh = RRangeLow + LHSCst->getValue(); - APInt LRangeLow = -LAddCst->getValue(); - APInt LRangeHigh = LRangeLow + LHSCst->getValue(); + APInt RRangeLow = -RAddC->getValue(); + APInt RRangeHigh = RRangeLow + LHSC->getValue(); + APInt LRangeLow = -LAddC->getValue(); + APInt LRangeHigh = LRangeLow + LHSC->getValue(); APInt LowRangeDiff = RRangeLow ^ LRangeLow; APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? 
LRangeLow - RRangeLow : RRangeLow - LRangeLow; if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(LHSCst->getValue())) { - Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst); + RangeDiff.ugt(LHSC->getValue())) { + Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); - Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst); - Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst); - return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst)); + Value *NewAnd = Builder.CreateAnd(LAddOpnd, MaskC); + Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC); + return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC); } } } } // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(LHSCC, RHSCC)) { + if (PredicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -1743,31 +1692,31 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder)) return V; - Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); if (LHS->hasOneUse() || RHS->hasOneUse()) { // (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1) // (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1) Value *A = nullptr, *B = nullptr; - if (LHSCC == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero()) { - B = Val; - if (RHSCC == ICmpInst::ICMP_ULT && Val == RHS->getOperand(1)) - A = Val2; - else if (RHSCC == ICmpInst::ICMP_UGT && Val == Val2) + if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) { + B = LHS0; + if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1)) + A = RHS0; + else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0) A = RHS->getOperand(1); } // (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1) // (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1) - else if (RHSCC == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { - B = Val2; - if (LHSCC == ICmpInst::ICMP_ULT && Val2 == LHS->getOperand(1)) - A = Val; - else if (LHSCC == ICmpInst::ICMP_UGT && Val2 == Val) + else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) { + B = RHS0; + if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1)) + A = LHS0; + else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0) A = LHS->getOperand(1); } if (A && B) - return Builder->CreateICmp( + return Builder.CreateICmp( ICmpInst::ICMP_UGE, - Builder->CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A); + Builder.CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A); } // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n @@ -1778,54 +1727,58 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) return V; + if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). 
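// Illustrative sketch, not from the original patch: the first constant fold
// handled below merges two nonzero tests into one (%x and %y are made up):
//   %a = icmp ne i32 %x, 0
//   %b = icmp ne i32 %y, 0
//   %r = or i1 %a, %b
//   -->
//   %o = or i32 %x, %y
//   %r = icmp ne i32 %o, 0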
- if (!LHSCst || !RHSCst) return nullptr; + if (!LHSC || !RHSC) + return nullptr; - if (LHSCst == RHSCst && LHSCC == RHSCC) { + if (LHSC == RHSC && PredL == PredR) { // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) - if (LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); + if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) { + Value *NewOr = Builder.CreateOr(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewOr, LHSC); } } // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) // iff C2 + CA == C1. - if (LHSCC == ICmpInst::ICMP_ULT && RHSCC == ICmpInst::ICMP_EQ) { - ConstantInt *AddCst; - if (match(Val, m_Add(m_Specific(Val2), m_ConstantInt(AddCst)))) - if (RHSCst->getValue() + AddCst->getValue() == LHSCst->getValue()) - return Builder->CreateICmpULE(Val, LHSCst); + if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) { + ConstantInt *AddC; + if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC)))) + if (RHSC->getValue() + AddC->getValue() == LHSC->getValue()) + return Builder.CreateICmpULE(LHS0, LHSC); } // From here on, we only handle: // (icmp1 A, C1) | (icmp2 A, C2) --> something simpler. - if (Val != Val2) return nullptr; + if (LHS0 != RHS0) + return nullptr; - // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. - if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || - RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || - LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || - RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) + // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. + if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || + PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || + PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || + PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) return nullptr; // We can't fold (ugt x, C) | (sgt x, C2). - if (!PredicatesFoldable(LHSCC, RHSCC)) + if (!PredicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. bool ShouldSwap; - if (CmpInst::isSigned(LHSCC) || - (ICmpInst::isEquality(LHSCC) && - CmpInst::isSigned(RHSCC))) - ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); + if (CmpInst::isSigned(PredL) || + (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) + ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); else - ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); + ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); if (ShouldSwap) { std::swap(LHS, RHS); - std::swap(LHSCst, RHSCst); - std::swap(LHSCC, RHSCC); + std::swap(LHSC, RHSC); + std::swap(PredL, PredR); } // At this point, we know we have two icmp instructions @@ -1834,127 +1787,45 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the // icmp folding check above), that the two constants are not // equal. 
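// Illustrative check, not from the original patch: the ult/eq merge above
// with made-up constants CA = 1, C1 = 5, C2 = 4, so that C2 + CA == C1:
//   ((x + 1) u< 5) | (x == 4)  -->  (x + 1) u<= 5
// The strict compare covers x in {-1, 0, 1, 2, 3}; x == 4 adds exactly the
// one value that widens the range by a single step.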
- assert(LHSCst != RHSCst && "Compares not folded above?"); + assert(LHSC != RHSC && "Compares not folded above?"); - switch (LHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredL) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - if (LHS->getOperand(0) == RHS->getOperand(0)) { - // if LHSCst and RHSCst differ only by one bit: - // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2 - assert(LHSCst->getValue().ule(LHSCst->getValue())); - - APInt Xor = LHSCst->getValue() ^ RHSCst->getValue(); - if (Xor.isPowerOf2()) { - Value *Cst = Builder->getInt(Xor); - Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst); - return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst); - } - } - - if (LHSCst == SubOne(RHSCst)) { - // (X == 13 | X == 14) -> X-13 <u 2 - Constant *AddCST = ConstantExpr::getNeg(LHSCst); - Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); - AddCST = ConstantExpr::getSub(AddOne(RHSCst), LHSCst); - return Builder->CreateICmpULT(Add, AddCST); - } - - break; // (X == 13 | X == 15) -> no change - case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change - case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change + // Potential folds for this case should already be handled. + break; + case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change + case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change break; - case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15 - case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15 - return RHS; } break; - case ICmpInst::ICMP_NE: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13 - case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13 - case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13 - return LHS; - case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true - case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true - case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true - return Builder->getTrue(); - } case ICmpInst::ICMP_ULT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change break; - case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 - // If RHSCst is [us]MAXINT, it is always false. Not handling - // this can cause overflow. 
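// Illustrative sketch, not from the original patch: the insertRangeTest
// calls below rewrite the two-sided test as one subtract-and-compare, e.g.:
//   (x u< 13) | (x u> 15)  -->  (x + -13) u> 2
// x in [13, 16) maps to x + -13 in [0, 3), so the complement test u> 2 holds
// exactly when x lies outside [13, 16).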
- if (RHSCst->isMaxValue(false)) - return LHS; - return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, + case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 + assert(!RHSC->isMaxValue(false) && "Missed icmp simplification"); + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, false, false); - case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15 - return RHS; - case ICmpInst::ICMP_SLT: // (X u< 13 | X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SLT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change - break; - case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 - // If RHSCst is [us]MAXINT, it is always false. Not handling - // this can cause overflow. - if (RHSCst->isMaxValue(true)) - return LHS; - return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, - true, false); - case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15 - return RHS; - case ICmpInst::ICMP_ULT: // (X s< 13 | X u< 15) -> no change - break; - } - break; - case ICmpInst::ICMP_UGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13 - case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13 - return LHS; - case ICmpInst::ICMP_SGT: // (X u> 13 | X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true - case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true - return Builder->getTrue(); - case ICmpInst::ICMP_SLT: // (X u> 13 | X s< 15) -> no change - break; - } - break; - case ICmpInst::ICMP_SGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13 - case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13 - return LHS; - case ICmpInst::ICMP_UGT: // (X s> 13 | X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true - case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true - return Builder->getTrue(); - case ICmpInst::ICMP_ULT: // (X s> 13 | X u< 15) -> no change + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change break; + case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 + assert(!RHSC->isMaxValue(true) && "Missed icmp simplification"); + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true, + false); } break; } @@ -1963,7 +1834,7 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, /// Optimize (fcmp)|(fcmp). NOTE: Unlike the rest of instcombine, this returns /// a Value which should already be inserted into the function. 
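// Illustrative sketch, not from the original patch: the common case handled
// by this function when neither constant is NaN:
//   %a = fcmp uno double %x, 1.0   ; true iff x is NaN
//   %b = fcmp uno double %y, 2.0   ; true iff y is NaN
//   %r = or i1 %a, %b
//   -->
//   %r = fcmp uno double %x, %y    ; true iff x or y is NaN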
-Value *InstCombiner::FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { +Value *InstCombiner::foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); @@ -1993,18 +1864,18 @@ Value *InstCombiner::FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { // If either of the constants are nans, then the whole thing returns // true. if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN()) - return Builder->getTrue(); + return Builder.getTrue(); // Otherwise, no need to compare the two constants, compare the // rest. - return Builder->CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); + return Builder.CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); } // Handle vector zeros. This occurs because the canonical form of // "fcmp uno x,x" is "fcmp uno x, 0". if (isa<ConstantAggregateZero>(LHS->getOperand(1)) && isa<ConstantAggregateZero>(RHS->getOperand(1))) - return Builder->CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); + return Builder.CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); return nullptr; } @@ -2021,8 +1892,9 @@ Value *InstCombiner::FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { /// (A & C1) | B /// /// when the XOR of the two constants is "all ones" (-1). -Instruction *InstCombiner::FoldOrWithConstants(BinaryOperator &I, Value *Op, - Value *A, Value *B, Value *C) { +static Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, + Value *A, Value *B, Value *C, + InstCombiner::BuilderTy &Builder) { ConstantInt *CI1 = dyn_cast<ConstantInt>(C); if (!CI1) return nullptr; @@ -2034,7 +1906,7 @@ Instruction *InstCombiner::FoldOrWithConstants(BinaryOperator &I, Value *Op, if (!Xor.isAllOnesValue()) return nullptr; if (V1 == A || V1 == B) { - Value *NewOp = Builder->CreateAnd((V1 == A) ? B : A, CI1); + Value *NewOp = Builder.CreateAnd((V1 == A) ? B : A, CI1); return BinaryOperator::CreateOr(NewOp, V1); } @@ -2043,15 +1915,16 @@ Instruction *InstCombiner::FoldOrWithConstants(BinaryOperator &I, Value *Op, /// \brief This helper function folds: /// -/// ((A | B) & C1) ^ (B & C2) +/// ((A ^ B) & C1) | (B & C2) /// /// into: /// /// (A & C1) ^ B /// /// when the XOR of the two constants is "all ones" (-1). -Instruction *InstCombiner::FoldXorWithConstants(BinaryOperator &I, Value *Op, - Value *A, Value *B, Value *C) { +static Instruction *FoldXorWithConstants(BinaryOperator &I, Value *Op, + Value *A, Value *B, Value *C, + InstCombiner::BuilderTy &Builder) { ConstantInt *CI1 = dyn_cast<ConstantInt>(C); if (!CI1) return nullptr; @@ -2066,7 +1939,7 @@ Instruction *InstCombiner::FoldXorWithConstants(BinaryOperator &I, Value *Op, return nullptr; if (V1 == A || V1 == B) { - Value *NewOp = Builder->CreateAnd(V1 == A ? B : A, CI1); + Value *NewOp = Builder.CreateAnd(V1 == A ? 
B : A, CI1); return BinaryOperator::CreateXor(NewOp, V1); } @@ -2083,11 +1956,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, DL, &TLI, &DT, &AC)) - return replaceInstUsesWith(I, V); - - // (A&B)|(A&C) -> A&(B|C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = SimplifyOrInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole @@ -2095,92 +1964,58 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; - if (Value *V = SimplifyBSwap(I)) - return replaceInstUsesWith(I, V); + // Do this before using distributive laws to catch simple and/or/not patterns. + if (Instruction *Xor = foldOrToXor(I, Builder)) + return Xor; - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - ConstantInt *C1 = nullptr; Value *X = nullptr; - // (X & C1) | C2 --> (X | C2) & (C1|C2) - // iff (C1 & C2) == 0. - if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) && - (RHS->getValue() & C1->getValue()) != 0 && - Op0->hasOneUse()) { - Value *Or = Builder->CreateOr(X, RHS); - Or->takeName(Op0); - return BinaryOperator::CreateAnd(Or, - Builder->getInt(RHS->getValue() | C1->getValue())); - } + // (A&B)|(A&C) -> A&(B|C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); - // (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2) - if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) && - Op0->hasOneUse()) { - Value *Or = Builder->CreateOr(X, RHS); - Or->takeName(Op0); - return BinaryOperator::CreateXor(Or, - Builder->getInt(C1->getValue() & ~RHS->getValue())); - } + if (Value *V = SimplifyBSwap(I, Builder)) + return replaceInstUsesWith(I, V); + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } // Given an OR instruction, check to see if this is a bswap. 
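// Illustrative sketch, not from the original patch: the kind of pattern
// MatchBSwap below recognizes, here an i16 byte swap built from shifts:
//   %hi = lshr i16 %x, 8
//   %lo = shl i16 %x, 8
//   %r = or i16 %lo, %hi
//   -->
//   %r = call i16 @llvm.bswap.i16(i16 %x)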
if (Instruction *BSwap = MatchBSwap(I)) return BSwap; - Value *A = nullptr, *B = nullptr; - ConstantInt *C1 = nullptr, *C2 = nullptr; - - // (X^C)|Y -> (X|Y)^C iff Y&C == 0 - if (Op0->hasOneUse() && - match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) { - Value *NOr = Builder->CreateOr(A, Op1); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, C1); - } + { + Value *A; + const APInt *C; + // (X^C)|Y -> (X|Y)^C iff Y&C == 0 + if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && + MaskedValueIsZero(Op1, *C, 0, &I)) { + Value *NOr = Builder.CreateOr(A, Op1); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, + ConstantInt::get(NOr->getType(), *C)); + } - // Y|(X^C) -> (X|Y)^C iff Y&C == 0 - if (Op1->hasOneUse() && - match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) { - Value *NOr = Builder->CreateOr(A, Op0); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, C1); + // Y|(X^C) -> (X|Y)^C iff Y&C == 0 + if (match(Op1, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && + MaskedValueIsZero(Op0, *C, 0, &I)) { + Value *NOr = Builder.CreateOr(A, Op0); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, + ConstantInt::get(NOr->getType(), *C)); + } } - // ((~A & B) | A) -> (A | B) - if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_Specific(A))) - return BinaryOperator::CreateOr(A, B); - - // ((A & B) | ~A) -> (~A | B) - if (match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_Specific(A)))) - return BinaryOperator::CreateOr(Builder->CreateNot(A), B); - - // (A & ~B) | (A ^ B) -> (A ^ B) - // (~B & A) | (A ^ B) -> (A ^ B) - if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_Xor(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateXor(A, B); - - // Commute the 'or' operands. - // (A ^ B) | (A & ~B) -> (A ^ B) - // (A ^ B) | (~B & A) -> (A ^ B) - if (match(Op1, m_c_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op0, m_Xor(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateXor(A, B); + Value *A, *B; // (A & C)|(B & D) Value *C = nullptr, *D = nullptr; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { Value *V1 = nullptr, *V2 = nullptr; - C1 = dyn_cast<ConstantInt>(C); - C2 = dyn_cast<ConstantInt>(D); + ConstantInt *C1 = dyn_cast<ConstantInt>(C); + ConstantInt *C2 = dyn_cast<ConstantInt>(D); if (C1 && C2) { // (A & C1)|(B & C2) - if ((C1->getValue() & C2->getValue()) == 0) { + if ((C1->getValue() & C2->getValue()).isNullValue()) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) // iff (C1&C2) == 0 and (N&~C1) == 0 if (match(A, m_Or(m_Value(V1), m_Value(V2))) && @@ -2189,7 +2024,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { (V2 == B && MaskedValueIsZero(V1, ~C1->getValue(), 0, &I)))) // (N|V) return BinaryOperator::CreateAnd(A, - Builder->getInt(C1->getValue()|C2->getValue())); + Builder.getInt(C1->getValue()|C2->getValue())); // Or commutes, try both ways. if (match(B, m_Or(m_Value(V1), m_Value(V2))) && ((V1 == A && @@ -2197,18 +2032,18 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { (V2 == A && MaskedValueIsZero(V1, ~C2->getValue(), 0, &I)))) // (N|V) return BinaryOperator::CreateAnd(B, - Builder->getInt(C1->getValue()|C2->getValue())); + Builder.getInt(C1->getValue()|C2->getValue())); // ((V|C3)&C1) | ((V|C4)&C2) --> (V|C3|C4)&(C1|C2) // iff (C1&C2) == 0 and (C3&~C1) == 0 and (C4&~C2) == 0. 
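// Illustrative check, not from the original patch: the bitfield fold above
// with made-up i8 constants C1 = 15, C2 = -16, C3 = 1, C4 = 16 (C1 & C2 == 0,
// C3 & ~C1 == 0, C4 & ~C2 == 0):
//   ((v | 1) & 15) | ((v | 16) & -16)  -->  (v | 17) & -1
// The two masked bitfield inserts combine into a single 'or'; the final
// all-ones mask is a no-op that later folds away.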
ConstantInt *C3 = nullptr, *C4 = nullptr; if (match(A, m_Or(m_Value(V1), m_ConstantInt(C3))) && - (C3->getValue() & ~C1->getValue()) == 0 && + (C3->getValue() & ~C1->getValue()).isNullValue() && match(B, m_Or(m_Specific(V1), m_ConstantInt(C4))) && - (C4->getValue() & ~C2->getValue()) == 0) { - V2 = Builder->CreateOr(V1, ConstantExpr::getOr(C3, C4), "bitfield"); + (C4->getValue() & ~C2->getValue()).isNullValue()) { + V2 = Builder.CreateOr(V1, ConstantExpr::getOr(C3, C4), "bitfield"); return BinaryOperator::CreateAnd(V2, - Builder->getInt(C1->getValue()|C2->getValue())); + Builder.getInt(C1->getValue()|C2->getValue())); } } } @@ -2218,82 +2053,59 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // 'or' that it is replacing. if (Op0->hasOneUse() || Op1->hasOneUse()) { // (Cond & C) | (~Cond & D) -> Cond ? C : D, and commuted variants. - if (Value *V = matchSelectFromAndOr(A, C, B, D, *Builder)) + if (Value *V = matchSelectFromAndOr(A, C, B, D, Builder)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(A, C, D, B, *Builder)) + if (Value *V = matchSelectFromAndOr(A, C, D, B, Builder)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(C, A, B, D, *Builder)) + if (Value *V = matchSelectFromAndOr(C, A, B, D, Builder)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(C, A, D, B, *Builder)) + if (Value *V = matchSelectFromAndOr(C, A, D, B, Builder)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(B, D, A, C, *Builder)) + if (Value *V = matchSelectFromAndOr(B, D, A, C, Builder)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(B, D, C, A, *Builder)) + if (Value *V = matchSelectFromAndOr(B, D, C, A, Builder)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(D, B, A, C, *Builder)) + if (Value *V = matchSelectFromAndOr(D, B, A, C, Builder)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(D, B, C, A, *Builder)) + if (Value *V = matchSelectFromAndOr(D, B, C, A, Builder)) return replaceInstUsesWith(I, V); } - // ((A&~B)|(~A&B)) -> A^B - if ((match(C, m_Not(m_Specific(D))) && - match(B, m_Not(m_Specific(A))))) - return BinaryOperator::CreateXor(A, D); - // ((~B&A)|(~A&B)) -> A^B - if ((match(A, m_Not(m_Specific(D))) && - match(B, m_Not(m_Specific(C))))) - return BinaryOperator::CreateXor(C, D); - // ((A&~B)|(B&~A)) -> A^B - if ((match(C, m_Not(m_Specific(B))) && - match(D, m_Not(m_Specific(A))))) - return BinaryOperator::CreateXor(A, B); - // ((~B&A)|(B&~A)) -> A^B - if ((match(A, m_Not(m_Specific(B))) && - match(D, m_Not(m_Specific(C))))) - return BinaryOperator::CreateXor(C, B); - // ((A|B)&1)|(B&-2) -> (A&1) | B - if (match(A, m_Or(m_Value(V1), m_Specific(B))) || - match(A, m_Or(m_Specific(B), m_Value(V1)))) { - Instruction *Ret = FoldOrWithConstants(I, Op1, V1, B, C); - if (Ret) return Ret; + if (match(A, m_c_Or(m_Value(V1), m_Specific(B)))) { + if (Instruction *Ret = FoldOrWithConstants(I, Op1, V1, B, C, Builder)) + return Ret; } // (B&-2)|((A|B)&1) -> (A&1) | B - if (match(B, m_Or(m_Specific(A), m_Value(V1))) || - match(B, m_Or(m_Value(V1), m_Specific(A)))) { - Instruction *Ret = FoldOrWithConstants(I, Op0, A, V1, D); - if (Ret) return Ret; + if (match(B, m_c_Or(m_Specific(A), m_Value(V1)))) { + if (Instruction *Ret = FoldOrWithConstants(I, Op0, A, V1, D, Builder)) + return Ret; } // ((A^B)&1)|(B&-2) -> (A&1) ^ B - if (match(A, m_Xor(m_Value(V1), m_Specific(B))) || - match(A, m_Xor(m_Specific(B), m_Value(V1)))) { - Instruction *Ret 
= FoldXorWithConstants(I, Op1, V1, B, C); - if (Ret) return Ret; + if (match(A, m_c_Xor(m_Value(V1), m_Specific(B)))) { + if (Instruction *Ret = FoldXorWithConstants(I, Op1, V1, B, C, Builder)) + return Ret; } // (B&-2)|((A^B)&1) -> (A&1) ^ B - if (match(B, m_Xor(m_Specific(A), m_Value(V1))) || - match(B, m_Xor(m_Value(V1), m_Specific(A)))) { - Instruction *Ret = FoldXorWithConstants(I, Op0, A, V1, D); - if (Ret) return Ret; + if (match(B, m_c_Xor(m_Specific(A), m_Value(V1)))) { + if (Instruction *Ret = FoldXorWithConstants(I, Op0, A, V1, D, Builder)) + return Ret; } } // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) - if (Op1->hasOneUse() || cast<BinaryOperator>(Op1)->hasOneUse()) - return BinaryOperator::CreateOr(Op0, C); + return BinaryOperator::CreateOr(Op0, C); // ((A ^ C) ^ B) | (B ^ A) -> (B ^ A) | C if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) - if (Op0->hasOneUse() || cast<BinaryOperator>(Op0)->hasOneUse()) - return BinaryOperator::CreateOr(Op1, C); + return BinaryOperator::CreateOr(Op1, C); // ((B | C) & A) | B -> B | (A & C) if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A)))) - return BinaryOperator::CreateOr(Op1, Builder->CreateAnd(A, C)); + return BinaryOperator::CreateOr(Op1, Builder.CreateAnd(A, C)); if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) return DeMorgan; @@ -2317,11 +2129,11 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return BinaryOperator::CreateOr(A, B); if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) { - Value *Not = Builder->CreateNot(B, B->getName()+".not"); + Value *Not = Builder.CreateNot(B, B->getName() + ".not"); return BinaryOperator::CreateOr(Not, Op0); } if (Op1->hasOneUse() && match(B, m_Not(m_Specific(Op0)))) { - Value *Not = Builder->CreateNot(A, A->getName()+".not"); + Value *Not = Builder.CreateNot(A, A->getName() + ".not"); return BinaryOperator::CreateOr(Not, Op0); } } @@ -2335,21 +2147,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { B->getOpcode() == Instruction::Xor)) { Value *NotOp = Op0 == B->getOperand(0) ? B->getOperand(1) : B->getOperand(0); - Value *Not = Builder->CreateNot(NotOp, NotOp->getName()+".not"); + Value *Not = Builder.CreateNot(NotOp, NotOp->getName() + ".not"); return BinaryOperator::CreateOr(Not, Op0); } - // (A & B) | (~A ^ B) -> (~A ^ B) - // (A & B) | (B ^ ~A) -> (~A ^ B) - // (B & A) | (~A ^ B) -> (~A ^ B) - // (B & A) | (B ^ ~A) -> (~A ^ B) - // The match order is important: match the xor first because the 'not' - // operation defines 'A'. We do not need to match the xor as Op0 because the - // xor was canonicalized to Op1 above. 
- if (match(Op1, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) && - match(Op0, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateXor(Builder->CreateNot(A), B); - if (SwappedForXor) std::swap(Op0, Op1); @@ -2357,7 +2158,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) - if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) + if (Value *Res = foldOrOfICmps(LHS, RHS, I)) return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary @@ -2365,26 +2166,26 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Value *X, *Y; if (LHS && match(Op1, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) - return replaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + if (Value *Res = foldOrOfICmps(LHS, Cmp, I)) + return replaceInstUsesWith(I, Builder.CreateOr(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) - return replaceInstUsesWith(I, Builder->CreateOr(Res, X)); + if (Value *Res = foldOrOfICmps(LHS, Cmp, I)) + return replaceInstUsesWith(I, Builder.CreateOr(Res, X)); } if (RHS && match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) - return replaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + if (Value *Res = foldOrOfICmps(Cmp, RHS, I)) + return replaceInstUsesWith(I, Builder.CreateOr(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) - return replaceInstUsesWith(I, Builder->CreateOr(Res, X)); + if (Value *Res = foldOrOfICmps(Cmp, RHS, I)) + return replaceInstUsesWith(I, Builder.CreateOr(Res, X)); } } // (fcmp uno x, c) | (fcmp uno y, c) -> (fcmp uno x, y) if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = FoldOrOfFCmps(LHS, RHS)) + if (Value *Res = foldOrOfFCmps(LHS, RHS)) return replaceInstUsesWith(I, Res); if (Instruction *CastedOr = foldCastedBitwiseLogic(I)) @@ -2392,10 +2193,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // or(sext(A), B) / or(B, sext(A)) --> A ? -1 : B, where A is i1 or <N x i1>. if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && - A->getType()->getScalarType()->isIntegerTy(1)) + A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, ConstantInt::getSigned(I.getType(), -1), Op1); if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && - A->getType()->getScalarType()->isIntegerTy(1)) + A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, ConstantInt::getSigned(I.getType(), -1), Op0); // Note: If we've gotten to the point of visiting the outer OR, then the @@ -2403,9 +2204,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // be simplified by a later pass either, so we try swapping the inner/outer // ORs in the hopes that we'll be able to simplify it this way. 
// (X|C) | V --> (X|V) | C + ConstantInt *C1; if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) && match(Op0, m_Or(m_Value(A), m_ConstantInt(C1)))) { - Value *Inner = Builder->CreateOr(A, Op1); + Value *Inner = Builder.CreateOr(A, Op1); Inner->takeName(Op0); return BinaryOperator::CreateOr(Inner, C1); } @@ -2418,8 +2220,8 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Op0->hasOneUse() && Op1->hasOneUse() && match(Op0, m_Select(m_Value(X), m_Value(A), m_Value(B))) && match(Op1, m_Select(m_Value(Y), m_Value(C), m_Value(D))) && X == Y) { - Value *orTrue = Builder->CreateOr(A, C); - Value *orFalse = Builder->CreateOr(B, D); + Value *orTrue = Builder.CreateOr(A, C); + Value *orFalse = Builder.CreateOr(B, D); return SelectInst::Create(X, orTrue, orFalse); } } @@ -2427,6 +2229,116 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return Changed ? &I : nullptr; } +/// A ^ B can be specified using other logic ops in a variety of patterns. We +/// can fold these early and efficiently by morphing an existing instruction. +static Instruction *foldXorToXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.getOpcode() == Instruction::Xor); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *A, *B; + + // There are 4 commuted variants for each of the basic patterns. + + // (A & B) ^ (A | B) -> A ^ B + // (A & B) ^ (B | A) -> A ^ B + // (A | B) ^ (A & B) -> A ^ B + // (A | B) ^ (B & A) -> A ^ B + if ((match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) || + (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B))))) { + I.setOperand(0, A); + I.setOperand(1, B); + return &I; + } + + // (A | ~B) ^ (~A | B) -> A ^ B + // (~B | A) ^ (~A | B) -> A ^ B + // (~A | B) ^ (A | ~B) -> A ^ B + // (B | ~A) ^ (A | ~B) -> A ^ B + if ((match(Op0, m_Or(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) || + (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Not(m_Specific(B)))))) { + I.setOperand(0, A); + I.setOperand(1, B); + return &I; + } + + // (A & ~B) ^ (~A & B) -> A ^ B + // (~B & A) ^ (~A & B) -> A ^ B + // (~A & B) ^ (A & ~B) -> A ^ B + // (B & ~A) ^ (A & ~B) -> A ^ B + if ((match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))) || + (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))))) { + I.setOperand(0, A); + I.setOperand(1, B); + return &I; + } + + // For the remaining cases we need to get rid of one of the operands. + if (!Op0->hasOneUse() && !Op1->hasOneUse()) + return nullptr; + + // (A | B) ^ ~(A & B) -> ~(A ^ B) + // (A | B) ^ ~(B & A) -> ~(A ^ B) + // (A & B) ^ ~(A | B) -> ~(A ^ B) + // (A & B) ^ ~(B | A) -> ~(A ^ B) + // Complexity sorting ensures the not will be on the right side. 
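// Illustrative check, not from the original patch, verifying the first of
// these identities by truth table:
//   A B | A|B  ~(A&B)  xor | ~(A^B)
//   0 0 |  0     1      1  |   1
//   0 1 |  1     1      0  |   0
//   1 0 |  1     1      0  |   0
//   1 1 |  1     0      1  |   1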
+ if ((match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_And(m_Specific(A), m_Specific(B))))) || + (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + return nullptr; +} + +Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { + if (PredicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) { + if (LHS->getOperand(0) == RHS->getOperand(1) && + LHS->getOperand(1) == RHS->getOperand(0)) + LHS->swapOperands(); + if (LHS->getOperand(0) == RHS->getOperand(0) && + LHS->getOperand(1) == RHS->getOperand(1)) { + // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) + Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); + unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS); + bool isSigned = LHS->isSigned() || RHS->isSigned(); + return getNewICmpValue(isSigned, Code, Op0, Op1, Builder); + } + } + + // Instead of trying to imitate the folds for and/or, decompose this 'xor' + // into those logic ops. That is, try to turn this into an and-of-icmps + // because we have many folds for that pattern. + // + // This is based on a truth table definition of xor: + // X ^ Y --> (X | Y) & !(X & Y) + if (Value *OrICmp = SimplifyBinOp(Instruction::Or, LHS, RHS, SQ)) { + // TODO: If OrICmp is true, then the definition of xor simplifies to !(X&Y). + // TODO: If OrICmp is false, the whole thing is false (InstSimplify?). + if (Value *AndICmp = SimplifyBinOp(Instruction::And, LHS, RHS, SQ)) { + // TODO: Independently handle cases where the 'and' side is a constant. + if (OrICmp == LHS && AndICmp == RHS && RHS->hasOneUse()) { + // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS + RHS->setPredicate(RHS->getInversePredicate()); + return Builder.CreateAnd(LHS, RHS); + } + if (OrICmp == RHS && AndICmp == LHS && LHS->hasOneUse()) { + // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS + LHS->setPredicate(LHS->getInversePredicate()); + return Builder.CreateAnd(LHS, RHS); + } + } + } + + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -2437,9 +2349,12 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyXorInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (Instruction *NewXor = foldXorToXor(I, Builder)) + return NewXor; + // (A&B)^(A&C) -> A&(B^C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -2449,68 +2364,85 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; - if (Value *V = SimplifyBSwap(I)) + if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); - // Is this a ~ operation? 
- if (Value *NotOp = dyn_castNotVal(&I)) { - if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(NotOp)) { - if (Op0I->getOpcode() == Instruction::And || - Op0I->getOpcode() == Instruction::Or) { - // ~(~X & Y) --> (X | ~Y) - De Morgan's Law - // ~(~X | Y) === (X & ~Y) - De Morgan's Law - if (dyn_castNotVal(Op0I->getOperand(1))) - Op0I->swapOperands(); - if (Value *Op0NotVal = dyn_castNotVal(Op0I->getOperand(0))) { - Value *NotY = - Builder->CreateNot(Op0I->getOperand(1), - Op0I->getOperand(1)->getName()+".not"); - if (Op0I->getOpcode() == Instruction::And) - return BinaryOperator::CreateOr(Op0NotVal, NotY); - return BinaryOperator::CreateAnd(Op0NotVal, NotY); - } + // Apply DeMorgan's Law for 'nand' / 'nor' logic with an inverted operand. + Value *X, *Y; + + // We must eliminate the and/or (one-use) for these transforms to not increase + // the instruction count. + // ~(~X & Y) --> (X | ~Y) + // ~(Y & ~X) --> (X | ~Y) + if (match(&I, m_Not(m_OneUse(m_c_And(m_Not(m_Value(X)), m_Value(Y)))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return BinaryOperator::CreateOr(X, NotY); + } + // ~(~X | Y) --> (X & ~Y) + // ~(Y | ~X) --> (X & ~Y) + if (match(&I, m_Not(m_OneUse(m_c_Or(m_Not(m_Value(X)), m_Value(Y)))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return BinaryOperator::CreateAnd(X, NotY); + } + + // Is this a 'not' (~) fed by a binary operator? + BinaryOperator *NotVal; + if (match(&I, m_Not(m_BinOp(NotVal)))) { + if (NotVal->getOpcode() == Instruction::And || + NotVal->getOpcode() == Instruction::Or) { + // Apply DeMorgan's Law when inverts are free: + // ~(X & Y) --> (~X | ~Y) + // ~(X | Y) --> (~X & ~Y) + if (IsFreeToInvert(NotVal->getOperand(0), + NotVal->getOperand(0)->hasOneUse()) && + IsFreeToInvert(NotVal->getOperand(1), + NotVal->getOperand(1)->hasOneUse())) { + Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs"); + Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs"); + if (NotVal->getOpcode() == Instruction::And) + return BinaryOperator::CreateOr(NotX, NotY); + return BinaryOperator::CreateAnd(NotX, NotY); + } + } - // ~(X & Y) --> (~X | ~Y) - De Morgan's Law - // ~(X | Y) === (~X & ~Y) - De Morgan's Law - if (IsFreeToInvert(Op0I->getOperand(0), - Op0I->getOperand(0)->hasOneUse()) && - IsFreeToInvert(Op0I->getOperand(1), - Op0I->getOperand(1)->hasOneUse())) { - Value *NotX = - Builder->CreateNot(Op0I->getOperand(0), "notlhs"); - Value *NotY = - Builder->CreateNot(Op0I->getOperand(1), "notrhs"); - if (Op0I->getOpcode() == Instruction::And) - return BinaryOperator::CreateOr(NotX, NotY); - return BinaryOperator::CreateAnd(NotX, NotY); - } + // ~(~X >>s Y) --> (X >>s Y) + if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) + return BinaryOperator::CreateAShr(X, Y); - } else if (Op0I->getOpcode() == Instruction::AShr) { - // ~(~X >>s Y) --> (X >>s Y) - if (Value *Op0NotVal = dyn_castNotVal(Op0I->getOperand(0))) - return BinaryOperator::CreateAShr(Op0NotVal, Op0I->getOperand(1)); - } + // If we are inverting a right-shifted constant, we may be able to eliminate + // the 'not' by inverting the constant and using the opposite shift type. + // Canonicalization rules ensure that only a negative constant uses 'ashr', + // but we must check that in case that transform has not fired yet. 
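// Illustrative check, not from the original patch: with the made-up negative
// i8 constant C = -8 (0b11111000),
//   ~(-8 >>s y)  -->  7 >>u y
// ashr shifts in copies of the set sign bit, which the 'not' turns into the
// zeros that lshr of ~C shifts in; e.g. for y = 2, ~(-2) == 1 == 7 >>u 2.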
+ const APInt *C; + if (match(NotVal, m_AShr(m_APInt(C), m_Value(Y))) && C->isNegative()) { + // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) + Constant *NotC = ConstantInt::get(I.getType(), ~(*C)); + return BinaryOperator::CreateLShr(NotC, Y); + } + + if (match(NotVal, m_LShr(m_APInt(C), m_Value(Y))) && C->isNonNegative()) { + // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) + Constant *NotC = ConstantInt::get(I.getType(), ~(*C)); + return BinaryOperator::CreateAShr(NotC, Y); } } - if (Constant *RHS = dyn_cast<Constant>(Op1)) { - if (RHS->isAllOnesValue() && Op0->hasOneUse()) - // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B - if (CmpInst *CI = dyn_cast<CmpInst>(Op0)) - return CmpInst::Create(CI->getOpcode(), - CI->getInversePredicate(), - CI->getOperand(0), CI->getOperand(1)); + // not (cmp A, B) = !cmp A, B + CmpInst::Predicate Pred; + if (match(&I, m_Not(m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))))) { + cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred)); + return replaceInstUsesWith(I, Op0); } - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp). if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) { if (CI->hasOneUse() && Op0C->hasOneUse()) { Instruction::CastOps Opcode = Op0C->getOpcode(); if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) && - (RHS == ConstantExpr::getCast(Opcode, Builder->getTrue(), - Op0C->getDestTy()))) { + (RHSC == ConstantExpr::getCast(Opcode, Builder.getTrue(), + Op0C->getDestTy()))) { CI->setPredicate(CI->getInversePredicate()); return CastInst::Create(Opcode, CI, Op0C->getType()); } @@ -2520,26 +2452,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { // ~(c-X) == X-c-1 == X+(-c-1) - if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue()) + if (Op0I->getOpcode() == Instruction::Sub && RHSC->isMinusOne()) if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) { Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C); - Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C, - ConstantInt::get(I.getType(), 1)); - return BinaryOperator::CreateAdd(Op0I->getOperand(1), ConstantRHS); + return BinaryOperator::CreateAdd(Op0I->getOperand(1), + SubOne(NegOp0I0C)); } if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { if (Op0I->getOpcode() == Instruction::Add) { // ~(X-c) --> (-c-1)-X - if (RHS->isAllOnesValue()) { + if (RHSC->isMinusOne()) { Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI); - return BinaryOperator::CreateSub( - ConstantExpr::getSub(NegOp0CI, - ConstantInt::get(I.getType(), 1)), - Op0I->getOperand(0)); - } else if (RHS->getValue().isSignBit()) { - // (X + C) ^ signbit -> (X + C + signbit) - Constant *C = Builder->getInt(RHS->getValue() + Op0CI->getValue()); + return BinaryOperator::CreateSub(SubOne(NegOp0CI), + Op0I->getOperand(0)); + } else if (RHSC->getValue().isSignMask()) { + // (X + C) ^ signmask -> (X + C + signmask) + Constant *C = Builder.getInt(RHSC->getValue() + Op0CI->getValue()); return BinaryOperator::CreateAdd(Op0I->getOperand(0), C); } @@ -2547,10 +2476,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0 if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(), 0, &I)) { - Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS); + 
Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC); // Anything in both C1 and C2 is known to be zero, remove it from // NewRHS. - Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHS); + Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC); NewRHS = ConstantExpr::getAnd(NewRHS, ConstantExpr::getNot(CommonBits)); Worklist.Add(Op0I); @@ -2568,11 +2497,11 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { E1->getOpcode() == Instruction::Xor && (C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) { // fold (C1 >> C2) ^ C3 - ConstantInt *C2 = Op0CI, *C3 = RHS; + ConstantInt *C2 = Op0CI, *C3 = RHSC; APInt FoldConst = C1->getValue().lshr(C2->getValue()); FoldConst ^= C3->getValue(); // Prepare the two operands. - Value *Opnd0 = Builder->CreateLShr(E1->getOperand(0), C2); + Value *Opnd0 = Builder.CreateLShr(E1->getOperand(0), C2); Opnd0->takeName(Op0I); cast<Instruction>(Opnd0)->setDebugLoc(I.getDebugLoc()); Value *FoldVal = ConstantInt::get(Opnd0->getType(), FoldConst); @@ -2582,27 +2511,26 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } - BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1); - if (Op1I) { + { Value *A, *B; - if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) { - if (A == Op0) { // B^(B|A) == (A|B)^B - Op1I->swapOperands(); - I.swapOperands(); - std::swap(Op0, Op1); - } else if (B == Op0) { // B^(A|B) == (A|B)^B + if (match(Op1, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { + if (A == Op0) { // A^(A|B) == A^(B|A) + cast<BinaryOperator>(Op1)->swapOperands(); + std::swap(A, B); + } + if (B == Op0) { // A^(B|A) == (B|A)^A I.swapOperands(); // Simplified below. std::swap(Op0, Op1); } - } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) && - Op1I->hasOneUse()){ + } else if (match(Op1, m_OneUse(m_And(m_Value(A), m_Value(B))))) { if (A == Op0) { // A^(A&B) -> A^(B&A) - Op1I->swapOperands(); + cast<BinaryOperator>(Op1)->swapOperands(); std::swap(A, B); } if (B == Op0) { // A^(B&A) -> (B&A)^A @@ -2612,89 +2540,53 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0); - if (Op0I) { + { Value *A, *B; - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - Op0I->hasOneUse()) { + if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { if (A == Op1) // (B|A)^B == (A|B)^B std::swap(A, B); if (B == Op1) // (A|B)^B == A & ~B - return BinaryOperator::CreateAnd(A, Builder->CreateNot(Op1)); - } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - Op0I->hasOneUse()){ + return BinaryOperator::CreateAnd(A, Builder.CreateNot(Op1)); + } else if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B))))) { if (A == Op1) // (A&B)^A -> (B&A)^A std::swap(A, B); + const APInt *C; if (B == Op1 && // (B&A)^A == ~B & A - !isa<ConstantInt>(Op1)) { // Canonical form is (B&C)^C - return BinaryOperator::CreateAnd(Builder->CreateNot(A), Op1); + !match(Op1, m_APInt(C))) { // Canonical form is (B&C)^C + return BinaryOperator::CreateAnd(Builder.CreateNot(A), Op1); } } } - if (Op0I && Op1I) { + { Value *A, *B, *C, *D; - // (A & B)^(A | B) -> A ^ B - if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - match(Op1I, m_Or(m_Value(C), m_Value(D)))) { - if ((A == C && B == D) || (A == D && B == C)) - return BinaryOperator::CreateXor(A, B); - } - // (A | B)^(A & B) -> A ^ B - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - match(Op1I, m_And(m_Value(C), m_Value(D)))) { - if ((A == C && B == D) || (A == D && B == C)) - 
return BinaryOperator::CreateXor(A, B); - } - // (A | ~B) ^ (~A | B) -> A ^ B - // (~B | A) ^ (~A | B) -> A ^ B - if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) - return BinaryOperator::CreateXor(A, B); - - // (~A | B) ^ (A | ~B) -> A ^ B - if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) && - match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { - return BinaryOperator::CreateXor(A, B); - } - // (A & ~B) ^ (~A & B) -> A ^ B - // (~B & A) ^ (~A & B) -> A ^ B - if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) - return BinaryOperator::CreateXor(A, B); - - // (~A & B) ^ (A & ~B) -> A ^ B - if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) { - return BinaryOperator::CreateXor(A, B); - } // (A ^ C)^(A | B) -> ((~A) & B) ^ C - if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) && - match(Op1I, m_Or(m_Value(A), m_Value(B)))) { + if (match(Op0, m_Xor(m_Value(D), m_Value(C))) && + match(Op1, m_Or(m_Value(A), m_Value(B)))) { if (D == A) return BinaryOperator::CreateXor( - Builder->CreateAnd(Builder->CreateNot(A), B), C); + Builder.CreateAnd(Builder.CreateNot(A), B), C); if (D == B) return BinaryOperator::CreateXor( - Builder->CreateAnd(Builder->CreateNot(B), A), C); + Builder.CreateAnd(Builder.CreateNot(B), A), C); } // (A | B)^(A ^ C) -> ((~A) & B) ^ C - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - match(Op1I, m_Xor(m_Value(D), m_Value(C)))) { + if (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_Xor(m_Value(D), m_Value(C)))) { if (D == A) return BinaryOperator::CreateXor( - Builder->CreateAnd(Builder->CreateNot(A), B), C); + Builder.CreateAnd(Builder.CreateNot(A), B), C); if (D == B) return BinaryOperator::CreateXor( - Builder->CreateAnd(Builder->CreateNot(B), A), C); + Builder.CreateAnd(Builder.CreateNot(B), A), C); } // (A & B) ^ (A ^ B) -> (A | B) - if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - match(Op1I, m_Xor(m_Specific(A), m_Specific(B)))) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Xor(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); // (A ^ B) ^ (A & B) -> (A | B) - if (match(Op0I, m_Xor(m_Value(A), m_Value(B))) && - match(Op1I, m_And(m_Specific(A), m_Specific(B)))) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); } @@ -2703,25 +2595,12 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Value *A, *B; if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Not(m_Specific(A)))) - return BinaryOperator::CreateNot(Builder->CreateAnd(A, B)); - - // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) - if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) - if (PredicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) { - if (LHS->getOperand(0) == RHS->getOperand(1) && - LHS->getOperand(1) == RHS->getOperand(0)) - LHS->swapOperands(); - if (LHS->getOperand(0) == RHS->getOperand(0) && - LHS->getOperand(1) == RHS->getOperand(1)) { - Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); - unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS); - bool isSigned = LHS->isSigned() || RHS->isSigned(); - return replaceInstUsesWith(I, - getNewICmpValue(isSigned, Code, Op0, Op1, - Builder)); - } - } + return BinaryOperator::CreateNot(Builder.CreateAnd(A, 
B)); + + if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) + if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) + if (Value *V = foldXorOfICmps(LHS, RHS)) + return replaceInstUsesWith(I, V); if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) return CastedXor; diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 2ef82ba..391c430 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -16,9 +16,9 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" -#include "llvm/ADT/Statistic.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -44,6 +44,7 @@ #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" @@ -60,6 +61,12 @@ using namespace PatternMatch; STATISTIC(NumSimplified, "Number of library calls simplified"); +static cl::opt<unsigned> UnfoldElementAtomicMemcpyMaxElements( + "unfold-element-atomic-memcpy-max-elements", + cl::init(16), + cl::desc("Maximum number of elements in atomic memcpy the optimizer is " + "allowed to unfold")); + /// Return the specified type promoted as it would be to pass though a va_arg /// area. static Type *getPromotedType(Type *Ty) { @@ -70,27 +77,6 @@ static Type *getPromotedType(Type *Ty) { return Ty; } -/// Given an aggregate type which ultimately holds a single scalar element, -/// like {{{type}}} or [1 x type], return type. -static Type *reduceToSingleValueType(Type *T) { - while (!T->isSingleValueType()) { - if (StructType *STy = dyn_cast<StructType>(T)) { - if (STy->getNumElements() == 1) - T = STy->getElementType(0); - else - break; - } else if (ArrayType *ATy = dyn_cast<ArrayType>(T)) { - if (ATy->getNumElements() == 1) - T = ATy->getElementType(); - else - break; - } else - break; - } - - return T; -} - /// Return a constant boolean vector that has true elements in all positions /// where the input constant data vector has an element with the sign bit set. static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { @@ -108,6 +94,83 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { return ConstantVector::get(BoolVec); } +Instruction *InstCombiner::SimplifyElementUnorderedAtomicMemCpy( + ElementUnorderedAtomicMemCpyInst *AMI) { + // Try to unfold this intrinsic into sequence of explicit atomic loads and + // stores. + // First check that number of elements is compile time constant. + auto *LengthCI = dyn_cast<ConstantInt>(AMI->getLength()); + if (!LengthCI) + return nullptr; + + // Check that there are not too many elements. + uint64_t LengthInBytes = LengthCI->getZExtValue(); + uint32_t ElementSizeInBytes = AMI->getElementSizeInBytes(); + uint64_t NumElements = LengthInBytes / ElementSizeInBytes; + if (NumElements >= UnfoldElementAtomicMemcpyMaxElements) + return nullptr; + + // Only expand if there are elements to copy. 
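// Illustrative sketch, not from the original patch (the exact intrinsic
// mangling is assumed): given a call roughly like
//   call void @llvm.memcpy.element.unordered.atomic.p0i8.p0i8.i32(
//       i8* align 4 %dst, i8* align 4 %src, i32 16, i32 4)
// the length (16) and element size (4) are compile-time constants, so the
// code below unfolds the copy into 4 unordered-atomic i32 load/store pairs.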
+  if (NumElements > 0) {
+    // Don't unfold into illegal integers.
+    uint64_t ElementSizeInBits = ElementSizeInBytes * 8;
+    if (!getDataLayout().isLegalInteger(ElementSizeInBits))
+      return nullptr;
+
+    // Cast source and destination to the correct type. Intrinsic input
+    // arguments are usually represented as i8*. Often operands will be
+    // explicitly cast to i8*, and we could just strip those casts instead of
+    // inserting new ones. However, it's easier to rely on other InstCombine
+    // rules, which will cover the trivial cases anyway.
+    Value *Src = AMI->getRawSource();
+    Value *Dst = AMI->getRawDest();
+    Type *ElementPointerType =
+        Type::getIntNPtrTy(AMI->getContext(), ElementSizeInBits,
+                           Src->getType()->getPointerAddressSpace());
+
+    Value *SrcCasted = Builder.CreatePointerCast(Src, ElementPointerType,
+                                                 "memcpy_unfold.src_casted");
+    Value *DstCasted = Builder.CreatePointerCast(Dst, ElementPointerType,
+                                                 "memcpy_unfold.dst_casted");
+
+    for (uint64_t i = 0; i < NumElements; ++i) {
+      // Get the current element addresses.
+      ConstantInt *ElementIdxCI =
+          ConstantInt::get(AMI->getContext(), APInt(64, i));
+      Value *SrcElementAddr =
+          Builder.CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr");
+      Value *DstElementAddr =
+          Builder.CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr");
+
+      // Load from the source. Transfer alignment information and mark the
+      // load as unordered atomic.
+      LoadInst *Load = Builder.CreateLoad(SrcElementAddr, "memcpy_unfold.val");
+      Load->setOrdering(AtomicOrdering::Unordered);
+      // We know the alignment of the first element. The verifier also
+      // guarantees that the element size is less than or equal to the first
+      // element's alignment and that both values are powers of two. This
+      // means that all subsequent accesses are at least element-size aligned.
+      // TODO: We can infer better alignment, but there is no evidence that
+      // this will matter.
+      Load->setAlignment(i == 0 ? AMI->getParamAlignment(1)
+                                : ElementSizeInBytes);
+      Load->setDebugLoc(AMI->getDebugLoc());
+
+      // Store the loaded value via an unordered atomic store.
+      StoreInst *Store = Builder.CreateStore(Load, DstElementAddr);
+      Store->setOrdering(AtomicOrdering::Unordered);
+      Store->setAlignment(i == 0 ? AMI->getParamAlignment(0)
+                                 : ElementSizeInBytes);
+      Store->setDebugLoc(AMI->getDebugLoc());
+    }
+  }
+
+  // Set the number of elements of the copy to 0; it will be deleted on the
+  // next iteration.
+  AMI->setLength(Constant::getNullValue(LengthCI->getType()));
+  return AMI;
+}
+
 Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
   unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT);
   unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT);
@@ -144,41 +207,19 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
   Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);
   Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp);
-  // Memcpy forces the use of i8* for the source and destination. That means
-  // that if you're using memcpy to move one double around, you'll get a cast
-  // from double* to i8*. We'd much rather use a double load+store rather than
-  // an i64 load+store, here because this improves the odds that the source or
-  // dest address will be promotable. See if we can find a better type than the
-  // integer datatype.
-  Value *StrippedDest = MI->getArgOperand(0)->stripPointerCasts();
+  // If the memcpy has metadata describing the members, see if we can get the
+  // TBAA tag describing our copy.
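A conceptual model of the per-element loop above. Plain C++ cannot express LLVM's unordered atomic ordering, so ordinary loads and stores stand in for the atomic accesses the pass actually emits; 4-byte elements are assumed.

    #include <cstdint>

    // Each iteration models one memcpy_unfold load/store pair; the pass emits
    // these as unordered *atomic* accesses at element-size alignment.
    void unfoldedCopy(uint32_t *Dst, const uint32_t *Src, uint64_t NumElements) {
      for (uint64_t I = 0; I != NumElements; ++I) {
        uint32_t Val = Src[I]; // element-size-aligned load
        Dst[I] = Val;          // element-size-aligned store
      }
    }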
MDNode *CopyMD = nullptr; - if (StrippedDest != MI->getArgOperand(0)) { - Type *SrcETy = cast<PointerType>(StrippedDest->getType()) - ->getElementType(); - if (SrcETy->isSized() && DL.getTypeStoreSize(SrcETy) == Size) { - // The SrcETy might be something like {{{double}}} or [1 x double]. Rip - // down through these levels if so. - SrcETy = reduceToSingleValueType(SrcETy); - - if (SrcETy->isSingleValueType()) { - NewSrcPtrTy = PointerType::get(SrcETy, SrcAddrSp); - NewDstPtrTy = PointerType::get(SrcETy, DstAddrSp); - - // If the memcpy has metadata describing the members, see if we can - // get the TBAA tag describing our copy. - if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { - if (M->getNumOperands() == 3 && M->getOperand(0) && - mdconst::hasa<ConstantInt>(M->getOperand(0)) && - mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() && - M->getOperand(1) && - mdconst::hasa<ConstantInt>(M->getOperand(1)) && - mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == - Size && - M->getOperand(2) && isa<MDNode>(M->getOperand(2))) - CopyMD = cast<MDNode>(M->getOperand(2)); - } - } - } + if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { + if (M->getNumOperands() == 3 && M->getOperand(0) && + mdconst::hasa<ConstantInt>(M->getOperand(0)) && + mdconst::extract<ConstantInt>(M->getOperand(0))->isZero() && + M->getOperand(1) && + mdconst::hasa<ConstantInt>(M->getOperand(1)) && + mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == + Size && + M->getOperand(2) && isa<MDNode>(M->getOperand(2))) + CopyMD = cast<MDNode>(M->getOperand(2)); } // If the memcpy/memmove provides better alignment info than we can @@ -186,9 +227,9 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { SrcAlign = std::max(SrcAlign, CopyAlign); DstAlign = std::max(DstAlign, CopyAlign); - Value *Src = Builder->CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy); - Value *Dest = Builder->CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); - LoadInst *L = Builder->CreateLoad(Src, MI->isVolatile()); + Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy); + Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); + LoadInst *L = Builder.CreateLoad(Src, MI->isVolatile()); L->setAlignment(SrcAlign); if (CopyMD) L->setMetadata(LLVMContext::MD_tbaa, CopyMD); @@ -197,7 +238,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { if (LoopMemParallelMD) L->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); - StoreInst *S = Builder->CreateStore(L, Dest, MI->isVolatile()); + StoreInst *S = Builder.CreateStore(L, Dest, MI->isVolatile()); S->setAlignment(DstAlign); if (CopyMD) S->setMetadata(LLVMContext::MD_tbaa, CopyMD); @@ -233,15 +274,15 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { Value *Dest = MI->getDest(); unsigned DstAddrSp = cast<PointerType>(Dest->getType())->getAddressSpace(); Type *NewDstPtrTy = PointerType::get(ITy, DstAddrSp); - Dest = Builder->CreateBitCast(Dest, NewDstPtrTy); + Dest = Builder.CreateBitCast(Dest, NewDstPtrTy); // Alignment 0 is identity for alignment 1 for memset, but not store. if (Alignment == 0) Alignment = 1; // Extract the fill value and store. 
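The fill-value line that follows uses a byte-splat trick: multiplying the 8-bit fill constant by 0x0101010101010101 replicates it into every byte of an i64. A quick standalone check (the fill byte 0xAB is an arbitrary example):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint64_t FillC = 0xAB; // the memset fill byte, zero-extended
      uint64_t Fill = FillC * 0x0101010101010101ULL;
      std::printf("%016llx\n", (unsigned long long)Fill); // abababababababab
      return 0;
    }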
uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; - StoreInst *S = Builder->CreateStore(ConstantInt::get(ITy, Fill), Dest, - MI->isVolatile()); + StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest, + MI->isVolatile()); S->setAlignment(Alignment); // Set the size of the copy to 0, it will be deleted on the next iteration. @@ -343,7 +384,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, for (unsigned i = 0; i != NumSubElts; ++i) { unsigned SubEltIdx = (NumSubElts - 1) - i; auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx)); - Count = Count.shl(BitWidth); + Count <<= BitWidth; Count |= SubElt->getValue().zextOrTrunc(64); } } @@ -357,7 +398,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, unsigned BitWidth = SVT->getPrimitiveSizeInBits(); // If shift-by-zero then just return the original value. - if (Count == 0) + if (Count.isNullValue()) return Vec; // Handle cases when Shift >= BitWidth. @@ -510,8 +551,131 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } -static Value *simplifyX86movmsk(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { +static Value *simplifyX86muldq(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + Value *Arg0 = II.getArgOperand(0); + Value *Arg1 = II.getArgOperand(1); + Type *ResTy = II.getType(); + assert(Arg0->getType()->getScalarSizeInBits() == 32 && + Arg1->getType()->getScalarSizeInBits() == 32 && + ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types"); + + // muldq/muludq(undef, undef) -> zero (matches generic mul behavior) + if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1)) + return ConstantAggregateZero::get(ResTy); + + // Constant folding. + // PMULDQ = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)), + // vXi64 sext(shuffle<0,2,..>(Arg1)))) + // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)), + // vXi64 zext(shuffle<0,2,..>(Arg1)))) + if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) + return nullptr; + + unsigned NumElts = ResTy->getVectorNumElements(); + assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) && + Arg1->getType()->getVectorNumElements() == (2 * NumElts) && + "Unexpected muldq/muludq types"); + + unsigned IntrinsicID = II.getIntrinsicID(); + bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID || + Intrinsic::x86_avx2_pmul_dq == IntrinsicID || + Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID); + + SmallVector<unsigned, 16> ShuffleMask; + for (unsigned i = 0; i != NumElts; ++i) + ShuffleMask.push_back(i * 2); + + auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask); + auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask); + + if (IsSigned) { + LHS = Builder.CreateSExt(LHS, ResTy); + RHS = Builder.CreateSExt(RHS, ResTy); + } else { + LHS = Builder.CreateZExt(LHS, ResTy); + RHS = Builder.CreateZExt(RHS, ResTy); + } + + return Builder.CreateMul(LHS, RHS); +} + +static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { + Value *Arg0 = II.getArgOperand(0); + Value *Arg1 = II.getArgOperand(1); + Type *ResTy = II.getType(); + + // Fast all undef handling. 
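simplifyX86pack, continued below, constant-folds each element by truncating with signed (PACKSS) or unsigned (PACKUS) saturation. A scalar sketch of the two clamping rules, assuming 32-bit sources and 16-bit destinations; the function names are illustrative.

    #include <cstdint>
    #include <cstdio>

    int16_t packssElt(int32_t V) { // PACKSS: signed saturation
      if (V < INT16_MIN) return INT16_MIN;
      if (V > INT16_MAX) return INT16_MAX;
      return (int16_t)V;
    }

    uint16_t packusElt(int32_t V) { // PACKUS: unsigned saturation
      if (V < 0) return 0;
      if (V > (int32_t)UINT16_MAX) return UINT16_MAX;
      return (uint16_t)V;
    }

    int main() {
      // 70000 overflows i16 and saturates; -5 clamps to 0 for PACKUS.
      std::printf("%d %u\n", (int)packssElt(70000), (unsigned)packusElt(-5));
      return 0; // prints: 32767 0
    }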
+ if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1)) + return UndefValue::get(ResTy); + + Type *ArgTy = Arg0->getType(); + unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128; + unsigned NumDstElts = ResTy->getVectorNumElements(); + unsigned NumSrcElts = ArgTy->getVectorNumElements(); + assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types"); + + unsigned NumDstEltsPerLane = NumDstElts / NumLanes; + unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; + unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits(); + assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) && + "Unexpected packing types"); + + // Constant folding. + auto *Cst0 = dyn_cast<Constant>(Arg0); + auto *Cst1 = dyn_cast<Constant>(Arg1); + if (!Cst0 || !Cst1) + return nullptr; + + SmallVector<Constant *, 32> Vals; + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) { + unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane; + auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0; + auto *COp = Cst->getAggregateElement(SrcIdx); + if (COp && isa<UndefValue>(COp)) { + Vals.push_back(UndefValue::get(ResTy->getScalarType())); + continue; + } + + auto *CInt = dyn_cast_or_null<ConstantInt>(COp); + if (!CInt) + return nullptr; + + APInt Val = CInt->getValue(); + assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() && + "Unexpected constant bitwidth"); + + if (IsSigned) { + // PACKSS: Truncate signed value with signed saturation. + // Source values less than dst minint are saturated to minint. + // Source values greater than dst maxint are saturated to maxint. + if (Val.isSignedIntN(DstScalarSizeInBits)) + Val = Val.trunc(DstScalarSizeInBits); + else if (Val.isNegative()) + Val = APInt::getSignedMinValue(DstScalarSizeInBits); + else + Val = APInt::getSignedMaxValue(DstScalarSizeInBits); + } else { + // PACKUS: Truncate signed value with unsigned saturation. + // Source values less than zero are saturated to zero. + // Source values greater than dst maxuint are saturated to maxuint. + if (Val.isIntN(DstScalarSizeInBits)) + Val = Val.trunc(DstScalarSizeInBits); + else if (Val.isNegative()) + Val = APInt::getNullValue(DstScalarSizeInBits); + else + Val = APInt::getAllOnesValue(DstScalarSizeInBits); + } + + Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val)); + } + } + + return ConstantVector::get(Vals); +} + +static Value *simplifyX86movmsk(const IntrinsicInst &II) { Value *Arg = II.getArgOperand(0); Type *ResTy = II.getType(); Type *ArgTy = Arg->getType(); @@ -679,7 +843,8 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, // Length bits. if (CI0) { APInt Elt = CI0->getValue(); - Elt = Elt.lshr(Index).zextOrTrunc(Length); + Elt.lshrInPlace(Index); + Elt = Elt.zextOrTrunc(Length); return LowConstantHighUndef(Elt.getZExtValue()); } @@ -693,7 +858,7 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, } // Constant Fold - extraction from zero is always {zero, undef}. - if (CI0 && CI0->equalsInt(0)) + if (CI0 && CI0->isZero()) return LowConstantHighUndef(0); return nullptr; @@ -876,7 +1041,7 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, // The PD variants uses bit 1 to select per-lane element index, so // shift down to convert to generic shuffle mask index. if (IsPD) - Index = Index.lshr(1); + Index.lshrInPlace(1); // The _256 variants are a bit trickier since the mask bits always index // into the corresponding 128 half. 
In order to convert to a generic @@ -1211,42 +1376,78 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { II.getIntrinsicID() == Intrinsic::ctlz) && "Expected cttz or ctlz intrinsic"); Value *Op0 = II.getArgOperand(0); - // FIXME: Try to simplify vectors of integers. - auto *IT = dyn_cast<IntegerType>(Op0->getType()); - if (!IT) - return nullptr; - unsigned BitWidth = IT->getBitWidth(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - IC.computeKnownBits(Op0, KnownZero, KnownOne, 0, &II); + KnownBits Known = IC.computeKnownBits(Op0, 0, &II); // Create a mask for bits above (ctlz) or below (cttz) the first known one. bool IsTZ = II.getIntrinsicID() == Intrinsic::cttz; - unsigned NumMaskBits = IsTZ ? KnownOne.countTrailingZeros() - : KnownOne.countLeadingZeros(); - APInt Mask = IsTZ ? APInt::getLowBitsSet(BitWidth, NumMaskBits) - : APInt::getHighBitsSet(BitWidth, NumMaskBits); + unsigned PossibleZeros = IsTZ ? Known.countMaxTrailingZeros() + : Known.countMaxLeadingZeros(); + unsigned DefiniteZeros = IsTZ ? Known.countMinTrailingZeros() + : Known.countMinLeadingZeros(); // If all bits above (ctlz) or below (cttz) the first known one are known // zero, this value is constant. // FIXME: This should be in InstSimplify because we're replacing an // instruction with a constant. - if ((Mask & KnownZero) == Mask) { - auto *C = ConstantInt::get(IT, APInt(BitWidth, NumMaskBits)); + if (PossibleZeros == DefiniteZeros) { + auto *C = ConstantInt::get(Op0->getType(), DefiniteZeros); return IC.replaceInstUsesWith(II, C); } // If the input to cttz/ctlz is known to be non-zero, // then change the 'ZeroIsUndef' parameter to 'true' // because we know the zero behavior can't affect the result. - if (KnownOne != 0 || isKnownNonZero(Op0, IC.getDataLayout())) { + if (!Known.One.isNullValue() || + isKnownNonZero(Op0, IC.getDataLayout(), 0, &IC.getAssumptionCache(), &II, + &IC.getDominatorTree())) { if (!match(II.getArgOperand(1), m_One())) { - II.setOperand(1, IC.Builder->getTrue()); + II.setOperand(1, IC.Builder.getTrue()); return &II; } } + // Add range metadata since known bits can't completely reflect what we know. + // TODO: Handle splat vectors. + auto *IT = dyn_cast<IntegerType>(Op0->getType()); + if (IT && IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { + Metadata *LowAndHigh[] = { + ConstantAsMetadata::get(ConstantInt::get(IT, DefiniteZeros)), + ConstantAsMetadata::get(ConstantInt::get(IT, PossibleZeros + 1))}; + II.setMetadata(LLVMContext::MD_range, + MDNode::get(II.getContext(), LowAndHigh)); + return &II; + } + + return nullptr; +} + +static Instruction *foldCtpop(IntrinsicInst &II, InstCombiner &IC) { + assert(II.getIntrinsicID() == Intrinsic::ctpop && + "Expected ctpop intrinsic"); + Value *Op0 = II.getArgOperand(0); + // FIXME: Try to simplify vectors of integers. + auto *IT = dyn_cast<IntegerType>(Op0->getType()); + if (!IT) + return nullptr; + + unsigned BitWidth = IT->getBitWidth(); + KnownBits Known(BitWidth); + IC.computeKnownBits(Op0, Known, 0, &II); + + unsigned MinCount = Known.countMinPopulation(); + unsigned MaxCount = Known.countMaxPopulation(); + + // Add range metadata since known bits can't completely reflect what we know. 
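An aside on the cttz/ctlz fold above, before the ctpop range code continues: when the maximum and minimum possible zero counts agree, the count is a constant. A small model for cttz on i32, using GCC/Clang builtins and illustrative known-bit masks:

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t KnownZero = 0x3; // bits 0-1 known clear (assumed example)
      uint32_t KnownOne = 0x4;  // bit 2 known set (assumed example)
      // countMinTrailingZeros: length of the trailing known-zero run.
      unsigned DefiniteZeros = __builtin_ctz(~KnownZero); // 2
      // countMaxTrailingZeros: the lowest known-one bit caps the count.
      unsigned PossibleZeros = __builtin_ctz(KnownOne); // 2
      if (PossibleZeros == DefiniteZeros)
        std::printf("cttz folds to the constant %u\n", DefiniteZeros);
      return 0;
    }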
+ if (IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { + Metadata *LowAndHigh[] = { + ConstantAsMetadata::get(ConstantInt::get(IT, MinCount)), + ConstantAsMetadata::get(ConstantInt::get(IT, MaxCount + 1))}; + II.setMetadata(LLVMContext::MD_range, + MDNode::get(II.getContext(), LowAndHigh)); + return &II; + } + return nullptr; } @@ -1274,7 +1475,7 @@ static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) { // the LLVM intrinsic definition for the pointer argument. unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace); - Value *PtrCast = IC.Builder->CreateBitCast(Ptr, VecPtrTy, "castvec"); + Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec"); // Second, convert the x86 XMM integer vector mask to a vector of bools based // on each element's most significant bit (the sign bit). @@ -1282,7 +1483,7 @@ static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) { // The pass-through vector for an x86 masked load is a zero vector. CallInst *NewMaskedLoad = - IC.Builder->CreateMaskedLoad(PtrCast, 1, BoolMask, ZeroVec); + IC.Builder.CreateMaskedLoad(PtrCast, 1, BoolMask, ZeroVec); return IC.replaceInstUsesWith(II, NewMaskedLoad); } @@ -1317,19 +1518,40 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) { // the LLVM intrinsic definition for the pointer argument. unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace); - Value *PtrCast = IC.Builder->CreateBitCast(Ptr, VecPtrTy, "castvec"); + Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec"); // Second, convert the x86 XMM integer vector mask to a vector of bools based // on each element's most significant bit (the sign bit). Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); - IC.Builder->CreateMaskedStore(Vec, PtrCast, 1, BoolMask); + IC.Builder.CreateMaskedStore(Vec, PtrCast, 1, BoolMask); // 'Replace uses' doesn't work for stores. Erase the original masked store. IC.eraseInstFromFunction(II); return true; } +// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs. +// +// A single NaN input is folded to minnum, so we rely on that folding for +// handling NaNs. +static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1, + const APFloat &Src2) { + APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2); + + APFloat::cmpResult Cmp0 = Max3.compare(Src0); + assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately"); + if (Cmp0 == APFloat::cmpEqual) + return maxnum(Src1, Src2); + + APFloat::cmpResult Cmp1 = Max3.compare(Src1); + assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately"); + if (Cmp1 == APFloat::cmpEqual) + return maxnum(Src0, Src2); + + return maxnum(Src0, Src1); +} + // Returns true iff the 2 intrinsics have the same operands, limiting the // comparison to the first NumOperands. static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, @@ -1373,6 +1595,254 @@ static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID, return false; } +// Convert NVVM intrinsics to target-generic LLVM code where possible. 
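Before the NVVM section starts below, a note on the foldCtpop bounds above: the minimum population count is the number of known-one bits, and the maximum is the width minus the number of known-zero bits. A standalone sketch on i8, with made-up known bits:

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint8_t KnownOne = 0x81;  // two bits known set (assumed example)
      uint8_t KnownZero = 0x7E; // six bits known clear (assumed example)
      unsigned MinCount = __builtin_popcount(KnownOne);      // countMinPopulation
      unsigned MaxCount = 8 - __builtin_popcount(KnownZero); // countMaxPopulation
      // The !range metadata is the half-open interval [MinCount, MaxCount + 1).
      std::printf("ctpop range [%u, %u)\n", MinCount, MaxCount + 1);
      return 0;
    }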
+static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { + // Each NVVM intrinsic we can simplify can be replaced with one of: + // + // * an LLVM intrinsic, + // * an LLVM cast operation, + // * an LLVM binary operation, or + // * ad-hoc LLVM IR for the particular operation. + + // Some transformations are only valid when the module's + // flush-denormals-to-zero (ftz) setting is true/false, whereas other + // transformations are valid regardless of the module's ftz setting. + enum FtzRequirementTy { + FTZ_Any, // Any ftz setting is ok. + FTZ_MustBeOn, // Transformation is valid only if ftz is on. + FTZ_MustBeOff, // Transformation is valid only if ftz is off. + }; + // Classes of NVVM intrinsics that can't be replaced one-to-one with a + // target-generic intrinsic, cast op, or binary op but that we can nonetheless + // simplify. + enum SpecialCase { + SPC_Reciprocal, + }; + + // SimplifyAction is a poor-man's variant (plus an additional flag) that + // represents how to replace an NVVM intrinsic with target-generic LLVM IR. + struct SimplifyAction { + // Invariant: At most one of these Optionals has a value. + Optional<Intrinsic::ID> IID; + Optional<Instruction::CastOps> CastOp; + Optional<Instruction::BinaryOps> BinaryOp; + Optional<SpecialCase> Special; + + FtzRequirementTy FtzRequirement = FTZ_Any; + + SimplifyAction() = default; + + SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq) + : IID(IID), FtzRequirement(FtzReq) {} + + // Cast operations don't have anything to do with FTZ, so we skip that + // argument. + SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {} + + SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq) + : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {} + + SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq) + : Special(Special), FtzRequirement(FtzReq) {} + }; + + // Try to generate a SimplifyAction describing how to replace our + // IntrinsicInstr with target-generic LLVM IR. + const SimplifyAction Action = [II]() -> SimplifyAction { + switch (II->getIntrinsicID()) { + + // NVVM intrinsics that map directly to LLVM intrinsics. 
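One structural note on the code above: Action is initialized by an immediately-invoked lambda, which lets a plain switch produce a const value. A distilled sketch of that idiom (the values are placeholders):

    #include <cstdio>

    int main() {
      int Kind = 2;
      // Immediately-invoked lambda: the switch runs once, the result is const.
      const int Action = [Kind]() -> int {
        switch (Kind) {
        case 1:
          return 10;
        case 2:
          return 20;
        default:
          return 0;
        }
      }();
      std::printf("%d\n", Action); // prints: 20
      return 0;
    }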
+ case Intrinsic::nvvm_ceil_d: + return {Intrinsic::ceil, FTZ_Any}; + case Intrinsic::nvvm_ceil_f: + return {Intrinsic::ceil, FTZ_MustBeOff}; + case Intrinsic::nvvm_ceil_ftz_f: + return {Intrinsic::ceil, FTZ_MustBeOn}; + case Intrinsic::nvvm_fabs_d: + return {Intrinsic::fabs, FTZ_Any}; + case Intrinsic::nvvm_fabs_f: + return {Intrinsic::fabs, FTZ_MustBeOff}; + case Intrinsic::nvvm_fabs_ftz_f: + return {Intrinsic::fabs, FTZ_MustBeOn}; + case Intrinsic::nvvm_floor_d: + return {Intrinsic::floor, FTZ_Any}; + case Intrinsic::nvvm_floor_f: + return {Intrinsic::floor, FTZ_MustBeOff}; + case Intrinsic::nvvm_floor_ftz_f: + return {Intrinsic::floor, FTZ_MustBeOn}; + case Intrinsic::nvvm_fma_rn_d: + return {Intrinsic::fma, FTZ_Any}; + case Intrinsic::nvvm_fma_rn_f: + return {Intrinsic::fma, FTZ_MustBeOff}; + case Intrinsic::nvvm_fma_rn_ftz_f: + return {Intrinsic::fma, FTZ_MustBeOn}; + case Intrinsic::nvvm_fmax_d: + return {Intrinsic::maxnum, FTZ_Any}; + case Intrinsic::nvvm_fmax_f: + return {Intrinsic::maxnum, FTZ_MustBeOff}; + case Intrinsic::nvvm_fmax_ftz_f: + return {Intrinsic::maxnum, FTZ_MustBeOn}; + case Intrinsic::nvvm_fmin_d: + return {Intrinsic::minnum, FTZ_Any}; + case Intrinsic::nvvm_fmin_f: + return {Intrinsic::minnum, FTZ_MustBeOff}; + case Intrinsic::nvvm_fmin_ftz_f: + return {Intrinsic::minnum, FTZ_MustBeOn}; + case Intrinsic::nvvm_round_d: + return {Intrinsic::round, FTZ_Any}; + case Intrinsic::nvvm_round_f: + return {Intrinsic::round, FTZ_MustBeOff}; + case Intrinsic::nvvm_round_ftz_f: + return {Intrinsic::round, FTZ_MustBeOn}; + case Intrinsic::nvvm_sqrt_rn_d: + return {Intrinsic::sqrt, FTZ_Any}; + case Intrinsic::nvvm_sqrt_f: + // nvvm_sqrt_f is a special case. For most intrinsics, foo_ftz_f is the + // ftz version, and foo_f is the non-ftz version. But nvvm_sqrt_f adopts + // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are + // the versions with explicit ftz-ness. + return {Intrinsic::sqrt, FTZ_Any}; + case Intrinsic::nvvm_sqrt_rn_f: + return {Intrinsic::sqrt, FTZ_MustBeOff}; + case Intrinsic::nvvm_sqrt_rn_ftz_f: + return {Intrinsic::sqrt, FTZ_MustBeOn}; + case Intrinsic::nvvm_trunc_d: + return {Intrinsic::trunc, FTZ_Any}; + case Intrinsic::nvvm_trunc_f: + return {Intrinsic::trunc, FTZ_MustBeOff}; + case Intrinsic::nvvm_trunc_ftz_f: + return {Intrinsic::trunc, FTZ_MustBeOn}; + + // NVVM intrinsics that map to LLVM cast operations. + // + // Note that llvm's target-generic conversion operators correspond to the rz + // (round to zero) versions of the nvvm conversion intrinsics, even though + // most everything else here uses the rn (round to nearest even) nvvm ops. + case Intrinsic::nvvm_d2i_rz: + case Intrinsic::nvvm_f2i_rz: + case Intrinsic::nvvm_d2ll_rz: + case Intrinsic::nvvm_f2ll_rz: + return {Instruction::FPToSI}; + case Intrinsic::nvvm_d2ui_rz: + case Intrinsic::nvvm_f2ui_rz: + case Intrinsic::nvvm_d2ull_rz: + case Intrinsic::nvvm_f2ull_rz: + return {Instruction::FPToUI}; + case Intrinsic::nvvm_i2d_rz: + case Intrinsic::nvvm_i2f_rz: + case Intrinsic::nvvm_ll2d_rz: + case Intrinsic::nvvm_ll2f_rz: + return {Instruction::SIToFP}; + case Intrinsic::nvvm_ui2d_rz: + case Intrinsic::nvvm_ui2f_rz: + case Intrinsic::nvvm_ull2d_rz: + case Intrinsic::nvvm_ull2f_rz: + return {Instruction::UIToFP}; + + // NVVM intrinsics that map to LLVM binary ops. 
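The gate that consumes these FTZ tags appears after the case list; a minimal sketch of its logic, with the enum copied from the patch and the helper name invented here:

    #include <cstdio>

    enum FtzRequirementTy { FTZ_Any, FTZ_MustBeOn, FTZ_MustBeOff };

    // FTZ_MustBeOn transforms may fire only when "nvptx-f32ftz" is "true";
    // FTZ_MustBeOff only when it is not; FTZ_Any fires either way.
    static bool ftzRequirementSatisfied(FtzRequirementTy Req, bool FtzEnabled) {
      if (Req == FTZ_Any)
        return true;
      return FtzEnabled == (Req == FTZ_MustBeOn);
    }

    int main() {
      std::printf("%d %d\n", ftzRequirementSatisfied(FTZ_MustBeOn, true),
                  ftzRequirementSatisfied(FTZ_MustBeOff, true)); // 1 0
      return 0;
    }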
+ case Intrinsic::nvvm_add_rn_d: + return {Instruction::FAdd, FTZ_Any}; + case Intrinsic::nvvm_add_rn_f: + return {Instruction::FAdd, FTZ_MustBeOff}; + case Intrinsic::nvvm_add_rn_ftz_f: + return {Instruction::FAdd, FTZ_MustBeOn}; + case Intrinsic::nvvm_mul_rn_d: + return {Instruction::FMul, FTZ_Any}; + case Intrinsic::nvvm_mul_rn_f: + return {Instruction::FMul, FTZ_MustBeOff}; + case Intrinsic::nvvm_mul_rn_ftz_f: + return {Instruction::FMul, FTZ_MustBeOn}; + case Intrinsic::nvvm_div_rn_d: + return {Instruction::FDiv, FTZ_Any}; + case Intrinsic::nvvm_div_rn_f: + return {Instruction::FDiv, FTZ_MustBeOff}; + case Intrinsic::nvvm_div_rn_ftz_f: + return {Instruction::FDiv, FTZ_MustBeOn}; + + // The remainder of cases are NVVM intrinsics that map to LLVM idioms, but + // need special handling. + // + // We seem to be missing intrinsics for rcp.approx.{ftz.}f32, which is just + // as well. + case Intrinsic::nvvm_rcp_rn_d: + return {SPC_Reciprocal, FTZ_Any}; + case Intrinsic::nvvm_rcp_rn_f: + return {SPC_Reciprocal, FTZ_MustBeOff}; + case Intrinsic::nvvm_rcp_rn_ftz_f: + return {SPC_Reciprocal, FTZ_MustBeOn}; + + // We do not currently simplify intrinsics that give an approximate answer. + // These include: + // + // - nvvm_cos_approx_{f,ftz_f} + // - nvvm_ex2_approx_{d,f,ftz_f} + // - nvvm_lg2_approx_{d,f,ftz_f} + // - nvvm_sin_approx_{f,ftz_f} + // - nvvm_sqrt_approx_{f,ftz_f} + // - nvvm_rsqrt_approx_{d,f,ftz_f} + // - nvvm_div_approx_{ftz_d,ftz_f,f} + // - nvvm_rcp_approx_ftz_d + // + // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast" + // means that fastmath is enabled in the intrinsic. Unfortunately only + // binary operators (currently) have a fastmath bit in SelectionDAG, so this + // information gets lost and we can't select on it. + // + // TODO: div and rcp are lowered to a binary op, so these we could in theory + // lower them to "fast fdiv". + + default: + return {}; + } + }(); + + // If Action.FtzRequirementTy is not satisfied by the module's ftz state, we + // can bail out now. (Notice that in the case that IID is not an NVVM + // intrinsic, we don't have to look up any module metadata, as + // FtzRequirementTy will be FTZ_Any.) + if (Action.FtzRequirement != FTZ_Any) { + bool FtzEnabled = + II->getFunction()->getFnAttribute("nvptx-f32ftz").getValueAsString() == + "true"; + + if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn)) + return nullptr; + } + + // Simplify to target-generic intrinsic. + if (Action.IID) { + SmallVector<Value *, 4> Args(II->arg_operands()); + // All the target-generic intrinsics currently of interest to us have one + // type argument, equal to that of the nvvm intrinsic's argument. + Type *Tys[] = {II->getArgOperand(0)->getType()}; + return CallInst::Create( + Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args); + } + + // Simplify to target-generic binary op. + if (Action.BinaryOp) + return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0), + II->getArgOperand(1), II->getName()); + + // Simplify to target-generic cast op. + if (Action.CastOp) + return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(), + II->getName()); + + // All that's left are the special cases. + if (!Action.Special) + return nullptr; + + switch (*Action.Special) { + case SPC_Reciprocal: + // Simplify reciprocal. 
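The SPC_Reciprocal case resolves, just below, to a plain fdiv of 1.0 by the operand; as scalar arithmetic (the function name is invented for illustration):

    #include <cstdio>

    // nvvm rcp with round-to-nearest becomes fdiv 1.0, x in the patch.
    double rcpModel(double X) { return 1.0 / X; }

    int main() {
      std::printf("%g\n", rcpModel(4.0)); // prints: 0.25
      return 0;
    }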
+ return BinaryOperator::Create( + Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1), + II->getArgOperand(0), II->getName()); + } + llvm_unreachable("All SpecialCase enumerators should be handled in switch."); +} + Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) { removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this); return nullptr; @@ -1388,8 +1858,8 @@ Instruction *InstCombiner::visitVACopyInst(VACopyInst &I) { /// lifting. Instruction *InstCombiner::visitCallInst(CallInst &CI) { auto Args = CI.arg_operands(); - if (Value *V = SimplifyCall(CI.getCalledValue(), Args.begin(), Args.end(), DL, - &TLI, &DT, &AC)) + if (Value *V = SimplifyCall(&CI, CI.getCalledValue(), Args.begin(), + Args.end(), SQ.getWithInstruction(&CI))) return replaceInstUsesWith(CI, V); if (isFreeCall(&CI, &TLI)) @@ -1462,6 +1932,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } + if (auto *AMI = dyn_cast<ElementUnorderedAtomicMemCpyInst>(II)) { + if (Constant *C = dyn_cast<Constant>(AMI->getLength())) + if (C->isNullValue()) + return eraseInstFromFunction(*AMI); + + if (Instruction *I = SimplifyElementUnorderedAtomicMemCpy(AMI)) + return I; + } + + if (Instruction *I = SimplifyNVVMIntrinsic(II, *this)) + return I; + auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width, unsigned DemandedWidth) { APInt UndefElts(Width, 0); @@ -1481,16 +1963,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *IIOperand = II->getArgOperand(0); Value *X = nullptr; + // TODO should this be in InstSimplify? // bswap(bswap(x)) -> x if (match(IIOperand, m_BSwap(m_Value(X)))) - return replaceInstUsesWith(CI, X); + return replaceInstUsesWith(CI, X); // bswap(trunc(bswap(x))) -> trunc(lshr(x, c)) if (match(IIOperand, m_Trunc(m_BSwap(m_Value(X))))) { unsigned C = X->getType()->getPrimitiveSizeInBits() - IIOperand->getType()->getPrimitiveSizeInBits(); Value *CV = ConstantInt::get(X->getType(), C); - Value *V = Builder->CreateLShr(X, CV); + Value *V = Builder.CreateLShr(X, CV); return new TruncInst(V, IIOperand->getType()); } break; @@ -1500,14 +1983,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *IIOperand = II->getArgOperand(0); Value *X = nullptr; + // TODO should this be in InstSimplify? 
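A quick concrete check of the two bswap folds above, using i32 to i16 so the shift amount c is 32 - 16 = 16 (GCC/Clang builtins; the bitreverse case continues below):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t X = 0x11223344;
      // bswap(bswap(x)) -> x
      std::printf("%08x\n", __builtin_bswap32(__builtin_bswap32(X)));
      // bswap(trunc(bswap(x))) -> trunc(lshr(x, 16))
      uint16_t Unfolded = __builtin_bswap16((uint16_t)__builtin_bswap32(X));
      uint16_t Folded = (uint16_t)(X >> 16);
      std::printf("%04x %04x\n", Unfolded, Folded); // both print: 1122
      return 0;
    }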
// bitreverse(bitreverse(x)) -> x - if (match(IIOperand, m_Intrinsic<Intrinsic::bitreverse>(m_Value(X)))) + if (match(IIOperand, m_BitReverse(m_Value(X)))) return replaceInstUsesWith(CI, X); break; } case Intrinsic::masked_load: - if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II, *Builder)) + if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II, Builder)) return replaceInstUsesWith(CI, SimplifiedMaskedOp); break; case Intrinsic::masked_store: @@ -1526,7 +2010,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Power->isOne()) return replaceInstUsesWith(CI, II->getArgOperand(0)); // powi(x, -1) -> 1/x - if (Power->isAllOnesValue()) + if (Power->isMinusOne()) return BinaryOperator::CreateFDiv(ConstantFP::get(CI.getType(), 1.0), II->getArgOperand(0)); } @@ -1538,6 +2022,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return I; break; + case Intrinsic::ctpop: + if (auto *I = foldCtpop(*II, *this)) + return I; + break; + case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: case Intrinsic::umul_with_overflow: @@ -1581,8 +2070,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; } - case Intrinsic::fma: case Intrinsic::fmuladd: { + // Canonicalize fast fmuladd to the separate fmul + fadd. + if (II->hasUnsafeAlgebra()) { + BuilderTy::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(II->getFastMathFlags()); + Value *Mul = Builder.CreateFMul(II->getArgOperand(0), + II->getArgOperand(1)); + Value *Add = Builder.CreateFAdd(Mul, II->getArgOperand(2)); + Add->takeName(II); + return replaceInstUsesWith(*II, Add); + } + + LLVM_FALLTHROUGH; + } + case Intrinsic::fma: { Value *Src0 = II->getArgOperand(0); Value *Src1 = II->getArgOperand(1); @@ -1626,11 +2128,31 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Constant *LHS, *RHS; if (match(II->getArgOperand(0), m_Select(m_Value(Cond), m_Constant(LHS), m_Constant(RHS)))) { - CallInst *Call0 = Builder->CreateCall(II->getCalledFunction(), {LHS}); - CallInst *Call1 = Builder->CreateCall(II->getCalledFunction(), {RHS}); + CallInst *Call0 = Builder.CreateCall(II->getCalledFunction(), {LHS}); + CallInst *Call1 = Builder.CreateCall(II->getCalledFunction(), {RHS}); return SelectInst::Create(Cond, Call0, Call1); } + LLVM_FALLTHROUGH; + } + case Intrinsic::ceil: + case Intrinsic::floor: + case Intrinsic::round: + case Intrinsic::nearbyint: + case Intrinsic::rint: + case Intrinsic::trunc: { + Value *ExtSrc; + if (match(II->getArgOperand(0), m_FPExt(m_Value(ExtSrc))) && + II->getArgOperand(0)->hasOneUse()) { + // fabs (fpext x) -> fpext (fabs x) + Value *F = Intrinsic::getDeclaration(II->getModule(), II->getIntrinsicID(), + { ExtSrc->getType() }); + CallInst *NewFabs = Builder.CreateCall(F, ExtSrc); + NewFabs->copyFastMathFlags(II); + NewFabs->takeName(II); + return new FPExtInst(NewFabs, II->getType()); + } + break; } case Intrinsic::cos: @@ -1652,7 +2174,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC lvx -> load if the pointer is known aligned. if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, &DT) >= 16) { - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); } @@ -1660,8 +2182,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_vsx_lxvw4x: case Intrinsic::ppc_vsx_lxvd2x: { // Turn PPC VSX loads into normal loads. 
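To annotate the powi folds above: powi(x, 1) reduces to x and powi(x, -1) to 1.0/x. A scalar model, where std::pow merely stands in for keeping the intrinsic call:

    #include <cmath>
    #include <cstdio>

    double powiModel(double X, int Power) {
      if (Power == 1) return X;          // powi(x, 1) -> x
      if (Power == -1) return 1.0 / X;   // powi(x, -1) -> fdiv 1.0, x
      return std::pow(X, (double)Power); // otherwise the call is kept
    }

    int main() {
      std::printf("%g %g\n", powiModel(2.0, 1), powiModel(2.0, -1)); // 2 0.5
      return 0;
    }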
- Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), - PointerType::getUnqual(II->getType())); + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), + PointerType::getUnqual(II->getType())); return new LoadInst(Ptr, Twine(""), false, 1); } case Intrinsic::ppc_altivec_stvx: @@ -1671,7 +2193,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { &DT) >= 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); return new StoreInst(II->getArgOperand(0), Ptr); } break; @@ -1679,18 +2201,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_vsx_stxvd2x: { // Turn PPC VSX stores into normal stores. Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); return new StoreInst(II->getArgOperand(0), Ptr, false, 1); } case Intrinsic::ppc_qpx_qvlfs: // Turn PPC QPX qvlfs -> load if the pointer is known aligned. if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, &DT) >= 16) { - Type *VTy = VectorType::get(Builder->getFloatTy(), + Type *VTy = VectorType::get(Builder.getFloatTy(), II->getType()->getVectorNumElements()); - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(VTy)); - Value *Load = Builder->CreateLoad(Ptr); + Value *Load = Builder.CreateLoad(Ptr); return new FPExtInst(Load, II->getType()); } break; @@ -1698,7 +2220,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC QPX qvlfd -> load if the pointer is known aligned. if (getOrEnforceKnownAlignment(II->getArgOperand(0), 32, DL, II, &AC, &DT) >= 32) { - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); } @@ -1707,11 +2229,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC QPX qvstfs -> store if the pointer is known aligned. 
if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC, &DT) >= 16) { - Type *VTy = VectorType::get(Builder->getFloatTy(), + Type *VTy = VectorType::get(Builder.getFloatTy(), II->getArgOperand(0)->getType()->getVectorNumElements()); - Value *TOp = Builder->CreateFPTrunc(II->getArgOperand(0), VTy); + Value *TOp = Builder.CreateFPTrunc(II->getArgOperand(0), VTy); Type *OpPtrTy = PointerType::getUnqual(VTy); - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); return new StoreInst(TOp, Ptr); } break; @@ -1721,7 +2243,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { &DT) >= 32) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); + Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); return new StoreInst(II->getArgOperand(0), Ptr); } break; @@ -1750,15 +2272,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { SmallVector<uint32_t, 8> SubVecMask; for (unsigned i = 0; i != RetWidth; ++i) SubVecMask.push_back((int)i); - VectorHalfAsShorts = Builder->CreateShuffleVector( + VectorHalfAsShorts = Builder.CreateShuffleVector( Arg, UndefValue::get(ArgType), SubVecMask); } auto VectorHalfType = VectorType::get(Type::getHalfTy(II->getContext()), RetWidth); auto VectorHalfs = - Builder->CreateBitCast(VectorHalfAsShorts, VectorHalfType); - auto VectorFloats = Builder->CreateFPExt(VectorHalfs, RetType); + Builder.CreateBitCast(VectorHalfAsShorts, VectorHalfType); + auto VectorFloats = Builder.CreateFPExt(VectorHalfs, RetType); return replaceInstUsesWith(*II, VectorFloats); } @@ -1812,7 +2334,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx_movmsk_pd_256: case Intrinsic::x86_avx_movmsk_ps_256: case Intrinsic::x86_avx2_pmovmskb: { - if (Value *V = simplifyX86movmsk(*II, *Builder)) + if (Value *V = simplifyX86movmsk(*II)) return replaceInstUsesWith(*II, V); break; } @@ -1863,6 +2385,37 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; break; } + case Intrinsic::x86_avx512_mask_cmp_pd_128: + case Intrinsic::x86_avx512_mask_cmp_pd_256: + case Intrinsic::x86_avx512_mask_cmp_pd_512: + case Intrinsic::x86_avx512_mask_cmp_ps_128: + case Intrinsic::x86_avx512_mask_cmp_ps_256: + case Intrinsic::x86_avx512_mask_cmp_ps_512: { + // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a) + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + bool Arg0IsZero = match(Arg0, m_Zero()); + if (Arg0IsZero) + std::swap(Arg0, Arg1); + Value *A, *B; + // This fold requires only the NINF(not +/- inf) since inf minus + // inf is nan. + // NSZ(No Signed Zeros) is not needed because zeros of any sign are + // equal for both compares. + // NNAN is not needed because nans compare the same for both compares. + // The compare intrinsic uses the above assumptions and therefore + // doesn't require additional flags. 
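The no-infs requirement spelled out just above is easy to demonstrate: for a == b == +inf, a - b is NaN, so cmp(sub(a, b), 0) and cmp(a, b) disagree. A minimal check (the matching code follows):

    #include <cstdio>
    #include <limits>

    int main() {
      double A = std::numeric_limits<double>::infinity();
      double B = A;
      // Without the ninf guarantee the fold would change the result here:
      std::printf("(a-b)==0: %d, a==b: %d\n", (A - B) == 0.0, A == B); // 0, 1
      return 0;
    }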
+ if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) && + match(Arg1, m_Zero()) && + cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) { + if (Arg0IsZero) + std::swap(A, B); + II->setArgOperand(0, A); + II->setArgOperand(1, B); + return II; + } + break; + } case Intrinsic::x86_avx512_mask_add_ps_512: case Intrinsic::x86_avx512_mask_div_ps_512: @@ -1884,25 +2437,25 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { default: llvm_unreachable("Case stmts out of sync!"); case Intrinsic::x86_avx512_mask_add_ps_512: case Intrinsic::x86_avx512_mask_add_pd_512: - V = Builder->CreateFAdd(Arg0, Arg1); + V = Builder.CreateFAdd(Arg0, Arg1); break; case Intrinsic::x86_avx512_mask_sub_ps_512: case Intrinsic::x86_avx512_mask_sub_pd_512: - V = Builder->CreateFSub(Arg0, Arg1); + V = Builder.CreateFSub(Arg0, Arg1); break; case Intrinsic::x86_avx512_mask_mul_ps_512: case Intrinsic::x86_avx512_mask_mul_pd_512: - V = Builder->CreateFMul(Arg0, Arg1); + V = Builder.CreateFMul(Arg0, Arg1); break; case Intrinsic::x86_avx512_mask_div_ps_512: case Intrinsic::x86_avx512_mask_div_pd_512: - V = Builder->CreateFDiv(Arg0, Arg1); + V = Builder.CreateFDiv(Arg0, Arg1); break; } // Create a select for the masking. V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), - *Builder); + Builder); return replaceInstUsesWith(*II, V); } } @@ -1923,27 +2476,27 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Extract the element as scalars. Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - Value *LHS = Builder->CreateExtractElement(Arg0, (uint64_t)0); - Value *RHS = Builder->CreateExtractElement(Arg1, (uint64_t)0); + Value *LHS = Builder.CreateExtractElement(Arg0, (uint64_t)0); + Value *RHS = Builder.CreateExtractElement(Arg1, (uint64_t)0); Value *V; switch (II->getIntrinsicID()) { default: llvm_unreachable("Case stmts out of sync!"); case Intrinsic::x86_avx512_mask_add_ss_round: case Intrinsic::x86_avx512_mask_add_sd_round: - V = Builder->CreateFAdd(LHS, RHS); + V = Builder.CreateFAdd(LHS, RHS); break; case Intrinsic::x86_avx512_mask_sub_ss_round: case Intrinsic::x86_avx512_mask_sub_sd_round: - V = Builder->CreateFSub(LHS, RHS); + V = Builder.CreateFSub(LHS, RHS); break; case Intrinsic::x86_avx512_mask_mul_ss_round: case Intrinsic::x86_avx512_mask_mul_sd_round: - V = Builder->CreateFMul(LHS, RHS); + V = Builder.CreateFMul(LHS, RHS); break; case Intrinsic::x86_avx512_mask_div_ss_round: case Intrinsic::x86_avx512_mask_div_sd_round: - V = Builder->CreateFDiv(LHS, RHS); + V = Builder.CreateFDiv(LHS, RHS); break; } @@ -1953,18 +2506,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // We don't need a select if we know the mask bit is a 1. if (!C || !C->getValue()[0]) { // Cast the mask to an i1 vector and then extract the lowest element. - auto *MaskTy = VectorType::get(Builder->getInt1Ty(), + auto *MaskTy = VectorType::get(Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth()); - Mask = Builder->CreateBitCast(Mask, MaskTy); - Mask = Builder->CreateExtractElement(Mask, (uint64_t)0); + Mask = Builder.CreateBitCast(Mask, MaskTy); + Mask = Builder.CreateExtractElement(Mask, (uint64_t)0); // Extract the lowest element from the passthru operand. 
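The scalar AVX-512 rewrite continuing below keeps lane 0 of the computed result only when bit 0 of the mask is set, and otherwise the passthru lane. A scalar model of that select; the add variant and the name are assumed for illustration:

    #include <cstdio>

    // Lane 0 of the masked scalar op: the sum if mask bit 0 is set, else the
    // passthru value (lanes 1+ always come from the first operand).
    double maskedAddSD(double A, double B, double Passthru, unsigned Mask) {
      double V = A + B;
      return (Mask & 1) ? V : Passthru;
    }

    int main() {
      std::printf("%g %g\n", maskedAddSD(1, 2, 9, 1),
                  maskedAddSD(1, 2, 9, 0)); // prints: 3 9
      return 0;
    }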
- Value *Passthru = Builder->CreateExtractElement(II->getArgOperand(2), + Value *Passthru = Builder.CreateExtractElement(II->getArgOperand(2), (uint64_t)0); - V = Builder->CreateSelect(Mask, V, Passthru); + V = Builder.CreateSelect(Mask, V, Passthru); } // Insert the result back into the original argument 0. - V = Builder->CreateInsertElement(Arg0, V, (uint64_t)0); + V = Builder.CreateInsertElement(Arg0, V, (uint64_t)0); return replaceInstUsesWith(*II, V); } @@ -2045,7 +2598,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx512_pslli_d_512: case Intrinsic::x86_avx512_pslli_q_512: case Intrinsic::x86_avx512_pslli_w_512: - if (Value *V = simplifyX86immShift(*II, *Builder)) + if (Value *V = simplifyX86immShift(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -2076,7 +2629,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx512_psll_d_512: case Intrinsic::x86_avx512_psll_q_512: case Intrinsic::x86_avx512_psll_w_512: { - if (Value *V = simplifyX86immShift(*II, *Builder)) + if (Value *V = simplifyX86immShift(*II, Builder)) return replaceInstUsesWith(*II, V); // SSE2/AVX2 uses only the first 64-bits of the 128-bit vector @@ -2120,7 +2673,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx512_psrlv_w_128: case Intrinsic::x86_avx512_psrlv_w_256: case Intrinsic::x86_avx512_psrlv_w_512: - if (Value *V = simplifyX86varShift(*II, *Builder)) + if (Value *V = simplifyX86varShift(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -2130,6 +2683,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_pmulu_dq: case Intrinsic::x86_avx512_pmul_dq_512: case Intrinsic::x86_avx512_pmulu_dq_512: { + if (Value *V = simplifyX86muldq(*II, Builder)) + return replaceInstUsesWith(*II, V); + unsigned VWidth = II->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); APInt DemandedElts = APInt::getAllOnesValue(VWidth); @@ -2141,8 +2697,66 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_sse2_packssdw_128: + case Intrinsic::x86_sse2_packsswb_128: + case Intrinsic::x86_avx2_packssdw: + case Intrinsic::x86_avx2_packsswb: + case Intrinsic::x86_avx512_packssdw_512: + case Intrinsic::x86_avx512_packsswb_512: + if (Value *V = simplifyX86pack(*II, true)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_sse2_packuswb_128: + case Intrinsic::x86_sse41_packusdw: + case Intrinsic::x86_avx2_packusdw: + case Intrinsic::x86_avx2_packuswb: + case Intrinsic::x86_avx512_packusdw_512: + case Intrinsic::x86_avx512_packuswb_512: + if (Value *V = simplifyX86pack(*II, false)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_pclmulqdq: { + if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) { + unsigned Imm = C->getZExtValue(); + + bool MadeChange = false; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg0->getType()->getVectorNumElements(); + APInt DemandedElts(VWidth, 0); + + APInt UndefElts1(VWidth, 0); + DemandedElts = (Imm & 0x01) ? 2 : 1; + if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts, + UndefElts1)) { + II->setArgOperand(0, V); + MadeChange = true; + } + + APInt UndefElts2(VWidth, 0); + DemandedElts = (Imm & 0x10) ? 2 : 1; + if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts, + UndefElts2)) { + II->setArgOperand(1, V); + MadeChange = true; + } + + // If both input elements are undef, the result is undef. 
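On the pclmulqdq handling above: immediate bit 0 selects which 64-bit element of the first operand participates and bit 4 selects for the second, so only one element of each input is demanded (a DemandedElts value of 2 means element 1, and 1 means element 0). A sketch of the decoding:

    #include <cstdio>
    #include <initializer_list>

    int main() {
      for (unsigned Imm : {0x00u, 0x01u, 0x10u, 0x11u}) {
        unsigned LHSElt = (Imm & 0x01) ? 1 : 0; // DemandedElts = Imm & 1 ? 2 : 1
        unsigned RHSElt = (Imm & 0x10) ? 1 : 0; // DemandedElts = Imm & 16 ? 2 : 1
        std::printf("imm=0x%02x -> lhs elt %u, rhs elt %u\n", Imm, LHSElt, RHSElt);
      }
      return 0;
    }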
+ if (UndefElts1[(Imm & 0x01) ? 1 : 0] || + UndefElts2[(Imm & 0x10) ? 1 : 0]) + return replaceInstUsesWith(*II, + ConstantAggregateZero::get(II->getType())); + + if (MadeChange) + return II; + } + break; + } + case Intrinsic::x86_sse41_insertps: - if (Value *V = simplifyX86insertps(*II, *Builder)) + if (Value *V = simplifyX86insertps(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -2165,7 +2779,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { : nullptr; // Attempt to simplify to a constant, shuffle vector or EXTRQI call. - if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) + if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, Builder)) return replaceInstUsesWith(*II, V); // EXTRQ only uses the lowest 64-bits of the first 128-bit vector @@ -2197,7 +2811,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { ConstantInt *CIIndex = dyn_cast<ConstantInt>(II->getArgOperand(2)); // Attempt to simplify to a constant or shuffle vector. - if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) + if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, Builder)) return replaceInstUsesWith(*II, V); // EXTRQI only uses the lowest 64-bits of the first 128-bit vector @@ -2229,7 +2843,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { const APInt &V11 = CI11->getValue(); APInt Len = V11.zextOrTrunc(6); APInt Idx = V11.lshr(8).zextOrTrunc(6); - if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) + if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, Builder)) return replaceInstUsesWith(*II, V); } @@ -2262,7 +2876,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (CILength && CIIndex) { APInt Len = CILength->getValue().zextOrTrunc(6); APInt Idx = CIIndex->getValue().zextOrTrunc(6); - if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) + if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, Builder)) return replaceInstUsesWith(*II, V); } @@ -2316,7 +2930,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_ssse3_pshuf_b_128: case Intrinsic::x86_avx2_pshuf_b: case Intrinsic::x86_avx512_pshuf_b_512: - if (Value *V = simplifyX86pshufb(*II, *Builder)) + if (Value *V = simplifyX86pshufb(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -2326,13 +2940,13 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx_vpermilvar_pd: case Intrinsic::x86_avx_vpermilvar_pd_256: case Intrinsic::x86_avx512_vpermilvar_pd_512: - if (Value *V = simplifyX86vpermilvar(*II, *Builder)) + if (Value *V = simplifyX86vpermilvar(*II, Builder)) return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_avx2_permd: case Intrinsic::x86_avx2_permps: - if (Value *V = simplifyX86vpermv(*II, *Builder)) + if (Value *V = simplifyX86vpermv(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -2350,10 +2964,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx512_mask_permvar_sf_512: case Intrinsic::x86_avx512_mask_permvar_si_256: case Intrinsic::x86_avx512_mask_permvar_si_512: - if (Value *V = simplifyX86vpermv(*II, *Builder)) { + if (Value *V = simplifyX86vpermv(*II, Builder)) { // We simplified the permuting, now create a select for the masking. 
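The extrq/extrqi folds earlier in this hunk reduce, for constant inputs, to a shift and truncate on the low 64 bits. An integer model; the length and index values in main are arbitrary:

    #include <cstdint>
    #include <cstdio>

    // Extract Length bits starting at Index from the low 64 bits and
    // zero-extend them, mirroring Elt.lshrInPlace(Index) + zextOrTrunc(Length).
    uint64_t extrqiModel(uint64_t Src, unsigned Length, unsigned Index) {
      uint64_t Mask = (Length >= 64) ? ~0ULL : ((1ULL << Length) - 1);
      return (Src >> Index) & Mask;
    }

    int main() {
      std::printf("%llx\n", (unsigned long long)
                      extrqiModel(0x1122334455667788ULL, 16, 8)); // 6677
      return 0;
    }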
V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), - *Builder); + Builder); return replaceInstUsesWith(*II, V); } break; @@ -2362,7 +2976,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx_vperm2f128_ps_256: case Intrinsic::x86_avx_vperm2f128_si_256: case Intrinsic::x86_avx2_vperm2i128: - if (Value *V = simplifyX86vperm2(*II, *Builder)) + if (Value *V = simplifyX86vperm2(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -2395,7 +3009,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_xop_vpcomd: case Intrinsic::x86_xop_vpcomq: case Intrinsic::x86_xop_vpcomw: - if (Value *V = simplifyX86vpcom(*II, *Builder, true)) + if (Value *V = simplifyX86vpcom(*II, Builder, true)) return replaceInstUsesWith(*II, V); break; @@ -2403,7 +3017,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_xop_vpcomud: case Intrinsic::x86_xop_vpcomuq: case Intrinsic::x86_xop_vpcomuw: - if (Value *V = simplifyX86vpcom(*II, *Builder, false)) + if (Value *V = simplifyX86vpcom(*II, Builder, false)) return replaceInstUsesWith(*II, V); break; @@ -2430,10 +3044,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (AllEltsOk) { // Cast the input vectors to byte vectors. - Value *Op0 = Builder->CreateBitCast(II->getArgOperand(0), - Mask->getType()); - Value *Op1 = Builder->CreateBitCast(II->getArgOperand(1), - Mask->getType()); + Value *Op0 = Builder.CreateBitCast(II->getArgOperand(0), + Mask->getType()); + Value *Op1 = Builder.CreateBitCast(II->getArgOperand(1), + Mask->getType()); Value *Result = UndefValue::get(Op0->getType()); // Only extract each element once. @@ -2453,13 +3067,13 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Op0ToUse = (DL.isLittleEndian()) ? Op1 : Op0; Value *Op1ToUse = (DL.isLittleEndian()) ? Op0 : Op1; ExtractedElts[Idx] = - Builder->CreateExtractElement(Idx < 16 ? Op0ToUse : Op1ToUse, - Builder->getInt32(Idx&15)); + Builder.CreateExtractElement(Idx < 16 ? Op0ToUse : Op1ToUse, + Builder.getInt32(Idx&15)); } // Insert this value into the result vector. - Result = Builder->CreateInsertElement(Result, ExtractedElts[Idx], - Builder->getInt32(i)); + Result = Builder.CreateInsertElement(Result, ExtractedElts[Idx], + Builder.getInt32(i)); } return CastInst::Create(Instruction::BitCast, Result, CI.getType()); } @@ -2531,9 +3145,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::amdgcn_rcp: { - if (const ConstantFP *C = dyn_cast<ConstantFP>(II->getArgOperand(0))) { + Value *Src = II->getArgOperand(0); + + // TODO: Move to ConstantFolding/InstSimplify? + if (isa<UndefValue>(Src)) + return replaceInstUsesWith(CI, Src); + + if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { const APFloat &ArgVal = C->getValueAPF(); APFloat Val(ArgVal.getSemantics(), 1.0); APFloat::opStatus Status = Val.divide(ArgVal, @@ -2546,6 +3165,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_rsq: { + Value *Src = II->getArgOperand(0); + + // TODO: Move to ConstantFolding/InstSimplify? + if (isa<UndefValue>(Src)) + return replaceInstUsesWith(CI, Src); + break; + } case Intrinsic::amdgcn_frexp_mant: case Intrinsic::amdgcn_frexp_exp: { Value *Src = II->getArgOperand(0); @@ -2611,7 +3238,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Mask == (S_NAN | Q_NAN)) { // Equivalent of isnan. Replace with standard fcmp. 
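On the class-mask fold above: a class test for signaling-or-quiet NaN is exactly fcmp uno x, x, which in C++ (under standard IEEE semantics, without fast-math) is x != x:

    #include <cmath>
    #include <cstdio>

    // fcmp uno %x, %x is true iff x is NaN of either kind, matching the
    // S_NAN | Q_NAN class mask.
    bool isNaNViaUno(double X) { return X != X; }

    int main() {
      std::printf("%d %d\n", isNaNViaUno(std::nan("")), isNaNViaUno(1.0)); // 1 0
      return 0;
    }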
- Value *FCmp = Builder->CreateFCmpUNO(Src0, Src0); + Value *FCmp = Builder.CreateFCmpUNO(Src0, Src0); FCmp->takeName(II); return replaceInstUsesWith(*II, FCmp); } @@ -2623,7 +3250,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Clamp mask to used bits if ((Mask & FullMask) != Mask) { - CallInst *NewCall = Builder->CreateCall(II->getCalledFunction(), + CallInst *NewCall = Builder.CreateCall(II->getCalledFunction(), { Src0, ConstantInt::get(Src1->getType(), Mask & FullMask) } ); @@ -2650,6 +3277,289 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result)); } + case Intrinsic::amdgcn_cvt_pkrtz: { + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { + if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { + const fltSemantics &HalfSem + = II->getType()->getScalarType()->getFltSemantics(); + bool LosesInfo; + APFloat Val0 = C0->getValueAPF(); + APFloat Val1 = C1->getValueAPF(); + Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); + Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); + + Constant *Folded = ConstantVector::get({ + ConstantFP::get(II->getContext(), Val0), + ConstantFP::get(II->getContext(), Val1) }); + return replaceInstUsesWith(*II, Folded); + } + } + + if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + break; + } + case Intrinsic::amdgcn_ubfe: + case Intrinsic::amdgcn_sbfe: { + // Decompose simple cases into standard shifts. + Value *Src = II->getArgOperand(0); + if (isa<UndefValue>(Src)) + return replaceInstUsesWith(*II, Src); + + unsigned Width; + Type *Ty = II->getType(); + unsigned IntSize = Ty->getIntegerBitWidth(); + + ConstantInt *CWidth = dyn_cast<ConstantInt>(II->getArgOperand(2)); + if (CWidth) { + Width = CWidth->getZExtValue(); + if ((Width & (IntSize - 1)) == 0) + return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty)); + + if (Width >= IntSize) { + // Hardware ignores high bits, so remove those. + II->setArgOperand(2, ConstantInt::get(CWidth->getType(), + Width & (IntSize - 1))); + return II; + } + } + + unsigned Offset; + ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1)); + if (COffset) { + Offset = COffset->getZExtValue(); + if (Offset >= IntSize) { + II->setArgOperand(1, ConstantInt::get(COffset->getType(), + Offset & (IntSize - 1))); + return II; + } + } + + bool Signed = II->getIntrinsicID() == Intrinsic::amdgcn_sbfe; + + // TODO: Also emit sub if only width is constant. + if (!CWidth && COffset && Offset == 0) { + Constant *KSize = ConstantInt::get(COffset->getType(), IntSize); + Value *ShiftVal = Builder.CreateSub(KSize, II->getArgOperand(2)); + ShiftVal = Builder.CreateZExt(ShiftVal, II->getType()); + + Value *Shl = Builder.CreateShl(Src, ShiftVal); + Value *RightShift = Signed ? Builder.CreateAShr(Shl, ShiftVal) + : Builder.CreateLShr(Shl, ShiftVal); + RightShift->takeName(II); + return replaceInstUsesWith(*II, RightShift); + } + + if (!CWidth || !COffset) + break; + + // TODO: This allows folding to undef when the hardware has specific + // behavior? + if (Offset + Width < IntSize) { + Value *Shl = Builder.CreateShl(Src, IntSize - Offset - Width); + Value *RightShift = Signed ? Builder.CreateAShr(Shl, IntSize - Width) + : Builder.CreateLShr(Shl, IntSize - Width); + RightShift->takeName(II); + return replaceInstUsesWith(*II, RightShift); + } + + Value *RightShift = Signed ? 
Builder.CreateAShr(Src, Offset) + : Builder.CreateLShr(Src, Offset); + + RightShift->takeName(II); + return replaceInstUsesWith(*II, RightShift); + } + case Intrinsic::amdgcn_exp: + case Intrinsic::amdgcn_exp_compr: { + ConstantInt *En = dyn_cast<ConstantInt>(II->getArgOperand(1)); + if (!En) // Illegal. + break; + + unsigned EnBits = En->getZExtValue(); + if (EnBits == 0xf) + break; // All inputs enabled. + + bool IsCompr = II->getIntrinsicID() == Intrinsic::amdgcn_exp_compr; + bool Changed = false; + for (int I = 0; I < (IsCompr ? 2 : 4); ++I) { + if ((!IsCompr && (EnBits & (1 << I)) == 0) || + (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) { + Value *Src = II->getArgOperand(I + 2); + if (!isa<UndefValue>(Src)) { + II->setArgOperand(I + 2, UndefValue::get(Src->getType())); + Changed = true; + } + } + } + + if (Changed) + return II; + + break; + + } + case Intrinsic::amdgcn_fmed3: { + // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled + // for the shader. + + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + Value *Src2 = II->getArgOperand(2); + + bool Swap = false; + // Canonicalize constants to RHS operands. + // + // fmed3(c0, x, c1) -> fmed3(x, c0, c1) + if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { + std::swap(Src0, Src1); + Swap = true; + } + + if (isa<Constant>(Src1) && !isa<Constant>(Src2)) { + std::swap(Src1, Src2); + Swap = true; + } + + if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { + std::swap(Src0, Src1); + Swap = true; + } + + if (Swap) { + II->setArgOperand(0, Src0); + II->setArgOperand(1, Src1); + II->setArgOperand(2, Src2); + return II; + } + + if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { + CallInst *NewCall = Builder.CreateMinNum(Src0, Src1); + NewCall->copyFastMathFlags(II); + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + + if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { + if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { + if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) { + APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(), + C2->getValueAPF()); + return replaceInstUsesWith(*II, + ConstantFP::get(Builder.getContext(), Result)); + } + } + } + + break; + } + case Intrinsic::amdgcn_icmp: + case Intrinsic::amdgcn_fcmp: { + const ConstantInt *CC = dyn_cast<ConstantInt>(II->getArgOperand(2)); + if (!CC) + break; + + // Guard against invalid arguments. + int64_t CCVal = CC->getZExtValue(); + bool IsInteger = II->getIntrinsicID() == Intrinsic::amdgcn_icmp; + if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE || + CCVal > CmpInst::LAST_ICMP_PREDICATE)) || + (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE || + CCVal > CmpInst::LAST_FCMP_PREDICATE))) + break; + + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + + if (auto *CSrc0 = dyn_cast<Constant>(Src0)) { + if (auto *CSrc1 = dyn_cast<Constant>(Src1)) { + Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1); + if (CCmp->isNullValue()) { + return replaceInstUsesWith( + *II, ConstantExpr::getSExt(CCmp, II->getType())); + } + + // The result of V_ICMP/V_FCMP assembly instructions (which this + // intrinsic exposes) is one bit per thread, masked with the EXEC + // register (which contains the bitmask of live threads). So a + // comparison that always returns true is the same as a read of the + // EXEC register. 
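The ubfe/sbfe hunk above leans on a classic identity: with constant offset and width, a bitfield extract is just two shifts. A minimal standalone sketch of that identity, not part of the patch; the helper names and test values are mine:

#include <cassert>
#include <cstdint>

static int32_t sbfe32(int32_t Src, unsigned Offset, unsigned Width) {
  // Requires 1 <= Width and Offset + Width <= 32 -- the cases the fold keeps.
  uint32_t Shl = static_cast<uint32_t>(Src) << (32 - Offset - Width);
  // Arithmetic right shift sign-extends the field (implementation-defined
  // pre-C++20, but arithmetic on every mainstream compiler).
  return static_cast<int32_t>(Shl) >> (32 - Width);
}

static uint32_t ubfe32(uint32_t Src, unsigned Offset, unsigned Width) {
  return (Src << (32 - Offset - Width)) >> (32 - Width); // logical shifts
}

int main() {
  assert(sbfe32(0x00000F00, 8, 4) == -1);   // field is 0b1111 -> sign-extends
  assert(ubfe32(0x00000F00, 8, 4) == 0xF);  // same field, zero-extended
  assert(ubfe32(0xABCD1234, 4, 8) == 0x23); // bits 4..11
  return 0;
}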
+ Value *NewF = Intrinsic::getDeclaration( + II->getModule(), Intrinsic::read_register, II->getType()); + Metadata *MDArgs[] = {MDString::get(II->getContext(), "exec")}; + MDNode *MD = MDNode::get(II->getContext(), MDArgs); + Value *Args[] = {MetadataAsValue::get(II->getContext(), MD)}; + CallInst *NewCall = Builder.CreateCall(NewF, Args); + NewCall->addAttribute(AttributeList::FunctionIndex, + Attribute::Convergent); + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + + // Canonicalize constants to RHS. + CmpInst::Predicate SwapPred + = CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal)); + II->setArgOperand(0, Src1); + II->setArgOperand(1, Src0); + II->setArgOperand(2, ConstantInt::get(CC->getType(), + static_cast<int>(SwapPred))); + return II; + } + + if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE) + break; + + // Canonicalize compare eq with true value to compare != 0 + // llvm.amdgcn.icmp(zext (i1 x), 1, eq) + // -> llvm.amdgcn.icmp(zext (i1 x), 0, ne) + // llvm.amdgcn.icmp(sext (i1 x), -1, eq) + // -> llvm.amdgcn.icmp(sext (i1 x), 0, ne) + Value *ExtSrc; + if (CCVal == CmpInst::ICMP_EQ && + ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) || + (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) && + ExtSrc->getType()->isIntegerTy(1)) { + II->setArgOperand(1, ConstantInt::getNullValue(Src1->getType())); + II->setArgOperand(2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE)); + return II; + } + + CmpInst::Predicate SrcPred; + Value *SrcLHS; + Value *SrcRHS; + + // Fold compare eq/ne with 0 from a compare result as the predicate to the + // intrinsic. The typical use is a wave vote function in the library, which + // will be fed from a user code condition compared with 0. Fold in the + // redundant compare. + + // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne) + // -> llvm.amdgcn.[if]cmp(a, b, pred) + // + // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq) + // -> llvm.amdgcn.[if]cmp(a, b, inv pred) + if (match(Src1, m_Zero()) && + match(Src0, + m_ZExtOrSExt(m_Cmp(SrcPred, m_Value(SrcLHS), m_Value(SrcRHS))))) { + if (CCVal == CmpInst::ICMP_EQ) + SrcPred = CmpInst::getInversePredicate(SrcPred); + + Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ? + Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp; + + Value *NewF = Intrinsic::getDeclaration(II->getModule(), NewIID, + SrcLHS->getType()); + Value *Args[] = { SrcLHS, SrcRHS, + ConstantInt::get(CC->getType(), SrcPred) }; + CallInst *NewCall = Builder.CreateCall(NewF, Args); + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + + break; + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. @@ -2720,16 +3630,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // the InstCombineIRInserter object. 
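The icmp/fcmp canonicalizations above all reduce to scalar facts about comparing a zero- or sign-extended i1 against a constant. A quick standalone check of those facts (plain C++ stand-ins for the intrinsic semantics, not the intrinsics themselves):

#include <cassert>
#include <cstdint>

int main() {
  for (int a = -3; a <= 3; ++a) {
    for (int b = -3; b <= 3; ++b) {
      bool Cmp = a < b;              // icmp slt a, b
      int32_t Z = Cmp ? 1 : 0;       // zext i1 -> i32
      int32_t S = Cmp ? -1 : 0;      // sext i1 -> i32
      assert((Z != 0) == (a < b));   // icmp ne (zext c), 0 -> c
      assert((S != 0) == (a < b));   // icmp ne (sext c), 0 -> c
      assert((Z == 0) == (a >= b));  // icmp eq (zext c), 0 -> inverse pred
      assert((Z == 1) == (Z != 0));  // eq 1 canonicalizes to ne 0
      assert((S == -1) == (S != 0)); // eq -1 canonicalizes to ne 0
    }
  }
  return 0;
}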
Value *AssumeIntrinsic = II->getCalledValue(), *A, *B; if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) { - Builder->CreateCall(AssumeIntrinsic, A, II->getName()); - Builder->CreateCall(AssumeIntrinsic, B, II->getName()); + Builder.CreateCall(AssumeIntrinsic, A, II->getName()); + Builder.CreateCall(AssumeIntrinsic, B, II->getName()); return eraseInstFromFunction(*II); } // assume(!(a || b)) -> assume(!a); assume(!b); if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) { - Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(A), - II->getName()); - Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(B), - II->getName()); + Builder.CreateCall(AssumeIntrinsic, Builder.CreateNot(A), II->getName()); + Builder.CreateCall(AssumeIntrinsic, Builder.CreateNot(B), II->getName()); return eraseInstFromFunction(*II); } @@ -2751,9 +3659,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // If there is a dominating assume with the same condition as this one, // then this one is redundant, and should be removed. - APInt KnownZero(1, 0), KnownOne(1, 0); - computeKnownBits(IIOperand, KnownZero, KnownOne, 0, II); - if (KnownOne.isAllOnesValue()) + KnownBits Known(1); + computeKnownBits(IIOperand, Known, 0, II); + if (Known.isAllOnes()) return eraseInstFromFunction(*II); // Update the cache of affected values for this assumption (we might be @@ -2790,7 +3698,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // isKnownNonNull -> nonnull attribute if (isKnownNonNullAt(DerivedPtr, II, &DT)) - II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); + II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); } // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) @@ -2799,11 +3707,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...) break; } - } + case Intrinsic::experimental_guard: { + // Is this guard followed by another guard? + Instruction *NextInst = II->getNextNode(); + Value *NextCond = nullptr; + if (match(NextInst, + m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) { + Value *CurrCond = II->getArgOperand(0); + + // Remove a guard that is immediately preceded by an identical guard. + if (CurrCond == NextCond) + return eraseInstFromFunction(*NextInst); + + // Otherwise canonicalize guard(a); guard(b) -> guard(a & b). + II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond)); + return eraseInstFromFunction(*NextInst); + } + break; + } + } return visitCallSite(II); } +// Fence instruction simplification +Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { + // Remove identical consecutive fences. + if (auto *NFI = dyn_cast<FenceInst>(FI.getNextNode())) + if (FI.isIdenticalTo(NFI)) + return eraseInstFromFunction(FI); + return nullptr; +} + // InvokeInst simplification // Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { @@ -2945,24 +3880,24 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { // Mark any parameters that are known to be non-null with the nonnull // attribute. This is helpful for inlining calls to functions with null // checks on their arguments.
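The assume and guard rewrites above are plain Boolean identities; a short exhaustive check (sketch only):

#include <cassert>

int main() {
  for (int a = 0; a <= 1; ++a) {
    for (int b = 0; b <= 1; ++b) {
      assert((a && b) == (a & b));      // assume(a && b) -> assume(a); assume(b)
      assert(!(a || b) == (!a && !b));  // assume(!(a || b)) -> assume(!a); assume(!b)
      assert(((a != 0) && (b != 0)) == ((a & b) != 0)); // guard(a); guard(b) -> guard(a & b)
    }
  }
  return 0;
}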
- SmallVector<unsigned, 4> Indices; + SmallVector<unsigned, 4> ArgNos; unsigned ArgNo = 0; for (Value *V : CS.args()) { if (V->getType()->isPointerTy() && - !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && + !CS.paramHasAttr(ArgNo, Attribute::NonNull) && isKnownNonNullAt(V, CS.getInstruction(), &DT)) - Indices.push_back(ArgNo + 1); + ArgNos.push_back(ArgNo); ArgNo++; } assert(ArgNo == CS.arg_size() && "sanity check"); - if (!Indices.empty()) { - AttributeSet AS = CS.getAttributes(); + if (!ArgNos.empty()) { + AttributeList AS = CS.getAttributes(); LLVMContext &Ctx = CS.getInstruction()->getContext(); - AS = AS.addAttribute(Ctx, Indices, - Attribute::get(Ctx, Attribute::NonNull)); + AS = AS.addParamAttribute(Ctx, ArgNos, + Attribute::get(Ctx, Attribute::NonNull)); CS.setAttributes(AS); Changed = true; } @@ -3081,7 +4016,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { return false; Instruction *Caller = CS.getInstruction(); - const AttributeSet &CallerPAL = CS.getAttributes(); + const AttributeList &CallerPAL = CS.getAttributes(); // Okay, this is a cast from a function to a different type. Unless doing so // would cause a type conversion of one of our arguments, change this call to @@ -3108,7 +4043,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { } if (!CallerPAL.isEmpty() && !Caller->use_empty()) { - AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex); + AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex); if (RAttrs.overlaps(AttributeFuncs::typeIncompatible(NewRetTy))) return false; // Attribute not compatible with transformed value. } @@ -3149,8 +4084,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL)) return false; // Cannot transform this parameter value. - if (AttrBuilder(CallerPAL.getParamAttributes(i + 1), i + 1). - overlaps(AttributeFuncs::typeIncompatible(ParamTy))) + if (AttrBuilder(CallerPAL.getParamAttributes(i)) + .overlaps(AttributeFuncs::typeIncompatible(ParamTy))) return false; // Attribute not compatible with transformed value. if (CS.isInAllocaArgument(i)) @@ -3158,9 +4093,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // If the parameter is passed as a byval argument, then we have to have a // sized type and the sized type has to have the same size as the old type. - if (ParamTy != ActTy && - CallerPAL.getParamAttributes(i + 1).hasAttribute(i + 1, - Attribute::ByVal)) { + if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy); if (!ParamPTy || !ParamPTy->getElementType()->isSized()) return false; @@ -3195,62 +4128,49 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { } if (FT->getNumParams() < NumActualArgs && FT->isVarArg() && - !CallerPAL.isEmpty()) + !CallerPAL.isEmpty()) { // In this case we have more arguments than the new function type, but we // won't be dropping them. Check that these extra arguments have attributes // that are compatible with being a vararg call argument. - for (unsigned i = CallerPAL.getNumSlots(); i; --i) { - unsigned Index = CallerPAL.getSlotIndex(i - 1); - if (Index <= FT->getNumParams()) - break; - - // Check if it has an attribute that's incompatible with varargs. 
- AttributeSet PAttrs = CallerPAL.getSlotAttributes(i - 1); - if (PAttrs.hasAttribute(Index, Attribute::StructRet)) - return false; - } - + unsigned SRetIdx; + if (CallerPAL.hasAttrSomewhere(Attribute::StructRet, &SRetIdx) && + SRetIdx > FT->getNumParams()) + return false; + } // Okay, we decided that this is a safe thing to do: go ahead and start // inserting cast instructions as necessary. - std::vector<Value*> Args; + SmallVector<Value *, 8> Args; + SmallVector<AttributeSet, 8> ArgAttrs; Args.reserve(NumActualArgs); - SmallVector<AttributeSet, 8> attrVec; - attrVec.reserve(NumCommonArgs); + ArgAttrs.reserve(NumActualArgs); // Get any return attributes. - AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex); + AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex); // If the return value is not being used, the type may not be compatible // with the existing attributes. Wipe out any problematic attributes. RAttrs.remove(AttributeFuncs::typeIncompatible(NewRetTy)); - // Add the new return attributes. - if (RAttrs.hasAttributes()) - attrVec.push_back(AttributeSet::get(Caller->getContext(), - AttributeSet::ReturnIndex, RAttrs)); - AI = CS.arg_begin(); for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) { Type *ParamTy = FT->getParamType(i); - if ((*AI)->getType() == ParamTy) { - Args.push_back(*AI); - } else { - Args.push_back(Builder->CreateBitOrPointerCast(*AI, ParamTy)); - } + Value *NewArg = *AI; + if ((*AI)->getType() != ParamTy) + NewArg = Builder.CreateBitOrPointerCast(*AI, ParamTy); + Args.push_back(NewArg); // Add any parameter attributes. - AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1); - if (PAttrs.hasAttributes()) - attrVec.push_back(AttributeSet::get(Caller->getContext(), i + 1, - PAttrs)); + ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); } // If the function takes more arguments than the call was taking, add them // now. - for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) + for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) { Args.push_back(Constant::getNullValue(FT->getParamType(i))); + ArgAttrs.push_back(AttributeSet()); + } // If we are removing arguments to the function, emit an obnoxious warning. if (FT->getNumParams() < NumActualArgs) { @@ -3259,54 +4179,56 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // Add all of the arguments in their promoted form to the arg list. for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) { Type *PTy = getPromotedType((*AI)->getType()); + Value *NewArg = *AI; if (PTy != (*AI)->getType()) { // Must promote to pass through va_arg area! Instruction::CastOps opcode = CastInst::getCastOpcode(*AI, false, PTy, false); - Args.push_back(Builder->CreateCast(opcode, *AI, PTy)); - } else { - Args.push_back(*AI); + NewArg = Builder.CreateCast(opcode, *AI, PTy); } + Args.push_back(NewArg); // Add any parameter attributes. - AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1); - if (PAttrs.hasAttributes()) - attrVec.push_back(AttributeSet::get(FT->getContext(), i + 1, - PAttrs)); + ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); } } } AttributeSet FnAttrs = CallerPAL.getFnAttributes(); - if (CallerPAL.hasAttributes(AttributeSet::FunctionIndex)) - attrVec.push_back(AttributeSet::get(Callee->getContext(), FnAttrs)); if (NewRetTy->isVoidTy()) Caller->setName(""); // Void type should not have a name. 
- const AttributeSet &NewCallerPAL = AttributeSet::get(Callee->getContext(), - attrVec); + assert((ArgAttrs.size() == FT->getNumParams() || FT->isVarArg()) && + "missing argument attributes"); + LLVMContext &Ctx = Callee->getContext(); + AttributeList NewCallerPAL = AttributeList::get( + Ctx, FnAttrs, AttributeSet::get(Ctx, RAttrs), ArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; CS.getOperandBundlesAsDefs(OpBundles); - Instruction *NC; + CallSite NewCS; if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { - NC = Builder->CreateInvoke(Callee, II->getNormalDest(), II->getUnwindDest(), - Args, OpBundles); - NC->takeName(II); - cast<InvokeInst>(NC)->setCallingConv(II->getCallingConv()); - cast<InvokeInst>(NC)->setAttributes(NewCallerPAL); + NewCS = Builder.CreateInvoke(Callee, II->getNormalDest(), + II->getUnwindDest(), Args, OpBundles); } else { - CallInst *CI = cast<CallInst>(Caller); - NC = Builder->CreateCall(Callee, Args, OpBundles); - NC->takeName(CI); - cast<CallInst>(NC)->setTailCallKind(CI->getTailCallKind()); - cast<CallInst>(NC)->setCallingConv(CI->getCallingConv()); - cast<CallInst>(NC)->setAttributes(NewCallerPAL); + NewCS = Builder.CreateCall(Callee, Args, OpBundles); + cast<CallInst>(NewCS.getInstruction()) + ->setTailCallKind(cast<CallInst>(Caller)->getTailCallKind()); } + NewCS->takeName(Caller); + NewCS.setCallingConv(CS.getCallingConv()); + NewCS.setAttributes(NewCallerPAL); + + // Preserve the weight metadata for the new call instruction. The metadata + // is used by SamplePGO to check callsite's hotness. + uint64_t W; + if (Caller->extractProfTotalWeight(W)) + NewCS->setProfWeight(W); // Insert a cast of the return type as necessary. + Instruction *NC = NewCS.getInstruction(); Value *NV = NC; if (OldRetTy != NV->getType() && !Caller->use_empty()) { if (!NV->getType()->isVoidTy()) { @@ -3351,7 +4273,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, Value *Callee = CS.getCalledValue(); PointerType *PTy = cast<PointerType>(Callee->getType()); FunctionType *FTy = cast<FunctionType>(PTy->getElementType()); - const AttributeSet &Attrs = CS.getAttributes(); + AttributeList Attrs = CS.getAttributes(); // If the call already has the 'nest' attribute somewhere then give up - // otherwise 'nest' would occur twice after splicing in the chain. @@ -3364,50 +4286,46 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts()); FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType()); - const AttributeSet &NestAttrs = NestF->getAttributes(); + AttributeList NestAttrs = NestF->getAttributes(); if (!NestAttrs.isEmpty()) { - unsigned NestIdx = 1; + unsigned NestArgNo = 0; Type *NestTy = nullptr; AttributeSet NestAttr; // Look for a parameter marked with the 'nest' attribute. for (FunctionType::param_iterator I = NestFTy->param_begin(), - E = NestFTy->param_end(); I != E; ++NestIdx, ++I) - if (NestAttrs.hasAttribute(NestIdx, Attribute::Nest)) { + E = NestFTy->param_end(); + I != E; ++NestArgNo, ++I) { + AttributeSet AS = NestAttrs.getParamAttributes(NestArgNo); + if (AS.hasAttribute(Attribute::Nest)) { // Record the parameter type and any other attributes. 
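The attribute hunks above are mechanical fallout of the AttributeSet-to-AttributeList migration: parameter attributes are now addressed by zero-based argument number, and a whole list is rebuilt from one function set, one return set, and one set per argument. A hedged sketch using only calls that appear in the patch; markArgsNonNull and rebuildAttrs are made-up names, and this assumes LLVM-5.0-era headers:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

// Parameter attributes are addressed by zero-based argument number; the old
// slot scheme (0 = return value, i+1 = argument i) is gone.
static void markArgsNonNull(CallSite CS, ArrayRef<unsigned> ArgNos) {
  LLVMContext &Ctx = CS.getInstruction()->getContext();
  AttributeList AL = CS.getAttributes();
  AL = AL.addParamAttribute(Ctx, ArgNos,
                            Attribute::get(Ctx, Attribute::NonNull));
  CS.setAttributes(AL);
}

// Rebuilding a list wholesale takes one function set, one return set, and one
// AttributeSet per argument, the shape transformConstExprCastCall now uses.
static AttributeList rebuildAttrs(LLVMContext &Ctx, AttributeSet FnAttrs,
                                  AttrBuilder &RetAttrs,
                                  ArrayRef<AttributeSet> ArgAttrs) {
  return AttributeList::get(Ctx, FnAttrs, AttributeSet::get(Ctx, RetAttrs),
                            ArgAttrs);
}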
NestTy = *I; - NestAttr = NestAttrs.getParamAttributes(NestIdx); + NestAttr = AS; break; } + } if (NestTy) { Instruction *Caller = CS.getInstruction(); std::vector<Value*> NewArgs; + std::vector<AttributeSet> NewArgAttrs; NewArgs.reserve(CS.arg_size() + 1); - - SmallVector<AttributeSet, 8> NewAttrs; - NewAttrs.reserve(Attrs.getNumSlots() + 1); + NewArgAttrs.reserve(CS.arg_size()); // Insert the nest argument into the call argument list, which may // mean appending it. Likewise for attributes. - // Add any result attributes. - if (Attrs.hasAttributes(AttributeSet::ReturnIndex)) - NewAttrs.push_back(AttributeSet::get(Caller->getContext(), - Attrs.getRetAttributes())); - { - unsigned Idx = 1; + unsigned ArgNo = 0; CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); do { - if (Idx == NestIdx) { + if (ArgNo == NestArgNo) { // Add the chain argument and attributes. Value *NestVal = Tramp->getArgOperand(2); if (NestVal->getType() != NestTy) - NestVal = Builder->CreateBitCast(NestVal, NestTy, "nest"); + NestVal = Builder.CreateBitCast(NestVal, NestTy, "nest"); NewArgs.push_back(NestVal); - NewAttrs.push_back(AttributeSet::get(Caller->getContext(), - NestAttr)); + NewArgAttrs.push_back(NestAttr); } if (I == E) @@ -3415,23 +4333,13 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Add the original argument and attributes. NewArgs.push_back(*I); - AttributeSet Attr = Attrs.getParamAttributes(Idx); - if (Attr.hasAttributes(Idx)) { - AttrBuilder B(Attr, Idx); - NewAttrs.push_back(AttributeSet::get(Caller->getContext(), - Idx + (Idx >= NestIdx), B)); - } + NewArgAttrs.push_back(Attrs.getParamAttributes(ArgNo)); - ++Idx; + ++ArgNo; ++I; } while (true); } - // Add any function attributes. - if (Attrs.hasAttributes(AttributeSet::FunctionIndex)) - NewAttrs.push_back(AttributeSet::get(FTy->getContext(), - Attrs.getFnAttributes())); - // The trampoline may have been bitcast to a bogus type (FTy). // Handle this by synthesizing a new function type, equal to FTy // with the chain parameter inserted. @@ -3442,12 +4350,12 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Insert the chain's type into the list of parameter types, which may // mean appending it. { - unsigned Idx = 1; + unsigned ArgNo = 0; FunctionType::param_iterator I = FTy->param_begin(), E = FTy->param_end(); do { - if (Idx == NestIdx) + if (ArgNo == NestArgNo) // Add the chain's type. NewTypes.push_back(NestTy); @@ -3457,7 +4365,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Add the original type. NewTypes.push_back(*I); - ++Idx; + ++ArgNo; ++I; } while (true); } @@ -3470,8 +4378,9 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, NestF->getType() == PointerType::getUnqual(NewFTy) ? 
NestF : ConstantExpr::getBitCast(NestF, PointerType::getUnqual(NewFTy)); - const AttributeSet &NewPAL = - AttributeSet::get(FTy->getContext(), NewAttrs); + AttributeList NewPAL = + AttributeList::get(FTy->getContext(), Attrs.getFnAttributes(), + Attrs.getRetAttributes(), NewArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; CS.getOperandBundlesAsDefs(OpBundles); diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index e74b590..dfdfd3e 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -14,9 +14,10 @@ #include "InstCombineInternal.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; using namespace PatternMatch; @@ -83,7 +84,7 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI) { PointerType *PTy = cast<PointerType>(CI.getType()); - BuilderTy AllocaBuilder(*Builder); + BuilderTy AllocaBuilder(Builder); AllocaBuilder.SetInsertPoint(&AI); // Get the type really allocated and the type casted to. @@ -274,12 +275,12 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { return NV; // If we are casting a PHI, then fold the cast into the PHI. - if (isa<PHINode>(Src)) { + if (auto *PN = dyn_cast<PHINode>(Src)) { // Don't do this if it would create a PHI node with an illegal type from a // legal type. if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || - ShouldChangeType(CI.getType(), Src->getType())) - if (Instruction *NV = FoldOpIntoPhi(CI)) + shouldChangeType(CI.getType(), Src->getType())) + if (Instruction *NV = foldOpIntoPhi(CI, PN)) return NV; } @@ -405,8 +406,7 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, /// trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32 /// ---> /// extractelement <4 x i32> %X, 1 -static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC, - const DataLayout &DL) { +static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { Value *TruncOp = Trunc.getOperand(0); Type *DestType = Trunc.getType(); if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType)) @@ -433,21 +433,21 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC, unsigned NumVecElts = VecWidth / DestWidth; if (VecType->getElementType() != DestType) { VecType = VectorType::get(DestType, NumVecElts); - VecInput = IC.Builder->CreateBitCast(VecInput, VecType, "bc"); + VecInput = IC.Builder.CreateBitCast(VecInput, VecType, "bc"); } unsigned Elt = ShiftAmount / DestWidth; - if (DL.isBigEndian()) + if (IC.getDataLayout().isBigEndian()) Elt = NumVecElts - 1 - Elt; - return ExtractElementInst::Create(VecInput, IC.Builder->getInt32(Elt)); + return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); } /// Try to narrow the width of bitwise logic instructions with constants. 
Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); - if (isa<IntegerType>(SrcTy) && !ShouldChangeType(SrcTy, DestTy)) + if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) return nullptr; BinaryOperator *LogicOp; @@ -459,10 +459,60 @@ Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { // trunc (logic X, C) --> logic (trunc X, C') Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); - Value *NarrowOp0 = Builder->CreateTrunc(LogicOp->getOperand(0), DestTy); + Value *NarrowOp0 = Builder.CreateTrunc(LogicOp->getOperand(0), DestTy); return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC); } +/// Try to narrow the width of a splat shuffle. This could be generalized to any +/// shuffle with a constant operand, but we limit the transform to avoid +/// creating a shuffle type that targets may not be able to lower effectively. +static Instruction *shrinkSplatShuffle(TruncInst &Trunc, + InstCombiner::BuilderTy &Builder) { + auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); + if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) && + Shuf->getMask()->getSplatValue() && + Shuf->getType() == Shuf->getOperand(0)->getType()) { + // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask + Constant *NarrowUndef = UndefValue::get(Trunc.getType()); + Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); + return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask()); + } + + return nullptr; +} + +/// Try to narrow the width of an insert element. This could be generalized for +/// any vector constant, but we limit the transform to insertion into undef to +/// avoid potential backend problems from unsupported insertion widths. This +/// could also be extended to handle the case of inserting a scalar constant +/// into a vector variable. +static Instruction *shrinkInsertElt(CastInst &Trunc, + InstCombiner::BuilderTy &Builder) { + Instruction::CastOps Opcode = Trunc.getOpcode(); + assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) && + "Unexpected instruction for shrinking"); + + auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0)); + if (!InsElt || !InsElt->hasOneUse()) + return nullptr; + + Type *DestTy = Trunc.getType(); + Type *DestScalarTy = DestTy->getScalarType(); + Value *VecOp = InsElt->getOperand(0); + Value *ScalarOp = InsElt->getOperand(1); + Value *Index = InsElt->getOperand(2); + + if (isa<UndefValue>(VecOp)) { + // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index + // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index + UndefValue *NarrowUndef = UndefValue::get(DestTy); + Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy); + return InsertElementInst::Create(NarrowUndef, NarrowOp, Index); + } + + return nullptr; +} + Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; @@ -488,7 +538,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // type. Only do this if the dest type is a simple type, don't convert the // expression tree to something weird like i93 unless the source is also // strange. 
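shrinkBitwiseLogic above is sound because truncation distributes over bitwise logic, so narrowing the constant first loses nothing. An exhaustive 16-to-8-bit check of that identity (standalone sketch, constant is mine):

#include <cassert>
#include <cstdint>

int main() {
  const uint16_t C = 0x1234;                  // C' below is trunc(C)
  for (uint32_t v = 0; v <= 0xFFFF; ++v) {
    uint16_t X = static_cast<uint16_t>(v);
    uint8_t NarrowX = static_cast<uint8_t>(X);
    uint8_t NarrowC = static_cast<uint8_t>(C);
    // trunc (and/or/xor X, C) --> and/or/xor (trunc X), C'
    assert(static_cast<uint8_t>(X & C) == static_cast<uint8_t>(NarrowX & NarrowC));
    assert(static_cast<uint8_t>(X | C) == static_cast<uint8_t>(NarrowX | NarrowC));
    assert(static_cast<uint8_t>(X ^ C) == static_cast<uint8_t>(NarrowX ^ NarrowC));
  }
  return 0;
}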
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && + if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && canEvaluateTruncated(Src, DestTy, *this, &CI)) { // If this cast is a truncate, evaluating in a different type always @@ -503,11 +553,14 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0), likewise for vector. if (DestTy->getScalarSizeInBits() == 1) { Constant *One = ConstantInt::get(SrcTy, 1); - Src = Builder->CreateAnd(Src, One); + Src = Builder.CreateAnd(Src, One); Value *Zero = Constant::getNullValue(Src->getType()); return new ICmpInst(ICmpInst::ICMP_NE, Src, Zero); } + // FIXME: Maybe combine the next two transforms to handle the no cast case + // more efficiently. Support vector types. Cleanup code by using m_OneUse. + // Transform trunc(lshr (zext A), Cst) to eliminate one type conversion. Value *A = nullptr; ConstantInt *Cst = nullptr; if (Src->hasOneUse() && @@ -526,36 +579,54 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // Since we're doing an lshr and a zero extend, and know that the shift // amount is smaller than ASize, it is always safe to do the shift in A's // type, then zero extend or truncate to the result. - Value *Shift = Builder->CreateLShr(A, Cst->getZExtValue()); + Value *Shift = Builder.CreateLShr(A, Cst->getZExtValue()); Shift->takeName(Src); return CastInst::CreateIntegerCast(Shift, DestTy, false); } + // FIXME: We should canonicalize to zext/trunc and remove this transform. // Transform trunc(lshr (sext A), Cst) to ashr A, Cst to eliminate type // conversion. // It works because bits coming from sign extension have the same value as // the sign bit of the original value; performing ashr instead of lshr // generates bits of the same value as the sign bit. if (Src->hasOneUse() && - match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst))) && - cast<Instruction>(Src)->getOperand(0)->hasOneUse()) { + match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) { + Value *SExt = cast<Instruction>(Src)->getOperand(0); + const unsigned SExtSize = SExt->getType()->getPrimitiveSizeInBits(); const unsigned ASize = A->getType()->getPrimitiveSizeInBits(); + const unsigned CISize = CI.getType()->getPrimitiveSizeInBits(); + const unsigned MaxAmt = SExtSize - std::max(CISize, ASize); + unsigned ShiftAmt = Cst->getZExtValue(); + // This optimization can only be performed when zero bits generated by // the original lshr aren't pulled into the value after truncation, so we - // can only shift by values smaller than the size of destination type (in - // bits). - if (Cst->getValue().ult(ASize)) { - Value *Shift = Builder->CreateAShr(A, Cst->getZExtValue()); - Shift->takeName(Src); - return CastInst::CreateIntegerCast(Shift, CI.getType(), true); + // can only shift by values no larger than the number of extension bits. + // FIXME: Instead of bailing when the shift is too large, use and to clear + // the extra bits.
+ if (ShiftAmt <= MaxAmt) { + if (CISize == ASize) + return BinaryOperator::CreateAShr(A, ConstantInt::get(CI.getType(), + std::min(ShiftAmt, ASize - 1))); + if (SExt->hasOneUse()) { + Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1)); + Shift->takeName(Src); + return CastInst::CreateIntegerCast(Shift, CI.getType(), true); + } } } if (Instruction *I = shrinkBitwiseLogic(CI)) return I; + if (Instruction *I = shrinkSplatShuffle(CI, Builder)) + return I; + + if (Instruction *I = shrinkInsertElt(CI, Builder)) + return I; + if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && - ShouldChangeType(SrcTy, DestTy)) { + shouldChangeType(SrcTy, DestTy)) { // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the // dest type is native and cst < dest size. if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) && @@ -564,7 +635,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // FoldShiftByConstant and is the extend in reg pattern. const unsigned DestSize = DestTy->getScalarSizeInBits(); if (Cst->getValue().ult(DestSize)) { - Value *NewTrunc = Builder->CreateTrunc(A, DestTy, A->getName() + ".tr"); + Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr"); return BinaryOperator::Create( Instruction::Shl, NewTrunc, @@ -573,7 +644,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { } } - if (Instruction *I = foldVecTruncToExtElt(CI, *this, DL)) + if (Instruction *I = foldVecTruncToExtElt(CI, *this)) return I; return nullptr; @@ -589,20 +660,20 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // zext (x <s 0) to i32 --> x>>u31 true if signbit set. // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || + if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV.isNullValue()) || (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) { if (!DoTransform) return ICI; Value *In = ICI->getOperand(0); Value *Sh = ConstantInt::get(In->getType(), In->getType()->getScalarSizeInBits() - 1); - In = Builder->CreateLShr(In, Sh, In->getName() + ".lobit"); + In = Builder.CreateLShr(In, Sh, In->getName() + ".lobit"); if (In->getType() != CI.getType()) - In = Builder->CreateIntCast(In, CI.getType(), false/*ZExt*/); + In = Builder.CreateIntCast(In, CI.getType(), false /*ZExt*/); if (ICI->getPredicate() == ICmpInst::ICMP_SGT) { Constant *One = ConstantInt::get(In->getType(), 1); - In = Builder->CreateXor(In, One, In->getName() + ".not"); + In = Builder.CreateXor(In, One, In->getName() + ".not"); } return replaceInstUsesWith(CI, In); @@ -616,20 +687,18 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV == 0 || Op1CV.isPowerOf2()) && + if ((Op1CV.isNullValue() || Op1CV.isPowerOf2()) && // This only works for EQ and NE ICI->isEquality()) { // If Op1C some other power of two, convert: - uint32_t BitWidth = Op1C->getType()->getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne, 0, &CI); + KnownBits Known = computeKnownBits(ICI->getOperand(0), 0, &CI); - APInt KnownZeroMask(~KnownZero); + APInt KnownZeroMask(~Known.Zero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? 
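The reworked bound above (MaxAmt = SExtSize - max(CISize, ASize)) can be sanity-checked by brute force for the i8 -> i32 -> i8 shape, where MaxAmt is 24 and the whole trunc(lshr(sext A)) collapses to an arithmetic shift by min(shift, 7). Standalone sketch, names mine:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (int a = -128; a <= 127; ++a) {
    for (unsigned ShiftAmt = 0; ShiftAmt <= 24; ++ShiftAmt) {
      int8_t A = static_cast<int8_t>(a);
      uint32_t Wide = static_cast<uint32_t>(static_cast<int32_t>(A)); // sext
      int8_t Narrow = static_cast<int8_t>(Wide >> ShiftAmt);          // lshr + trunc
      int8_t AShr = static_cast<int8_t>(a >> std::min(ShiftAmt, 7u)); // ashr A
      assert(Narrow == AShr);
    }
  }
  return 0;
}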
if (!DoTransform) return ICI; bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; - if (Op1CV != 0 && (Op1CV != KnownZeroMask)) { + if (!Op1CV.isNullValue() && (Op1CV != KnownZeroMask)) { // (X&4) == 2 --> false // (X&4) != 2 --> true Constant *Res = ConstantInt::get(Type::getInt1Ty(CI.getContext()), @@ -643,19 +712,19 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, if (ShAmt) { // Perform a logical shr by shiftamt. // Insert the shift to put the result in the low bit. - In = Builder->CreateLShr(In, ConstantInt::get(In->getType(), ShAmt), - In->getName() + ".lobit"); + In = Builder.CreateLShr(In, ConstantInt::get(In->getType(), ShAmt), + In->getName() + ".lobit"); } - if ((Op1CV != 0) == isNE) { // Toggle the low bit. + if (!Op1CV.isNullValue() == isNE) { // Toggle the low bit. Constant *One = ConstantInt::get(In->getType(), 1); - In = Builder->CreateXor(In, One); + In = Builder.CreateXor(In, One); } if (CI.getType() == In->getType()) return replaceInstUsesWith(CI, In); - Value *IntCast = Builder->CreateIntCast(In, CI.getType(), false); + Value *IntCast = Builder.CreateIntCast(In, CI.getType(), false); return replaceInstUsesWith(CI, IntCast); } } @@ -666,34 +735,31 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // may lead to additional simplifications. if (ICI->isEquality() && CI.getType() == ICI->getOperand(0)->getType()) { if (IntegerType *ITy = dyn_cast<IntegerType>(CI.getType())) { - uint32_t BitWidth = ITy->getBitWidth(); Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); - APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0); - APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0); - computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS, 0, &CI); - computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS, 0, &CI); + KnownBits KnownLHS = computeKnownBits(LHS, 0, &CI); + KnownBits KnownRHS = computeKnownBits(RHS, 0, &CI); - if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) { - APInt KnownBits = KnownZeroLHS | KnownOneLHS; + if (KnownLHS.Zero == KnownRHS.Zero && KnownLHS.One == KnownRHS.One) { + APInt KnownBits = KnownLHS.Zero | KnownLHS.One; APInt UnknownBit = ~KnownBits; if (UnknownBit.countPopulation() == 1) { if (!DoTransform) return ICI; - Value *Result = Builder->CreateXor(LHS, RHS); + Value *Result = Builder.CreateXor(LHS, RHS); // Mask off any bits that are set and won't be shifted away. - if (KnownOneLHS.uge(UnknownBit)) - Result = Builder->CreateAnd(Result, + if (KnownLHS.One.uge(UnknownBit)) + Result = Builder.CreateAnd(Result, ConstantInt::get(ITy, UnknownBit)); // Shift the bit we're testing down to the lsb. - Result = Builder->CreateLShr( + Result = Builder.CreateLShr( Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros())); if (ICI->getPredicate() == ICmpInst::ICMP_EQ) - Result = Builder->CreateXor(Result, ConstantInt::get(ITy, 1)); + Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1)); Result->takeName(ICI); return replaceInstUsesWith(CI, Result); } @@ -838,11 +904,6 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; - // See if we can simplify any instructions used by the input whose sole - // purpose is to compute bits we don't care about. 
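The zext(icmp) folds above turn sign-bit and single-known-bit tests into shifts and xors. A standalone exhaustive check over i8 of the identities being used (mask value is mine):

#include <cassert>
#include <cstdint>

int main() {
  for (int x = -128; x <= 127; ++x) {
    uint8_t U = static_cast<uint8_t>(x);
    int Lobit = U >> 7;                         // x >>u 7, i.e. the sign bit
    assert(((x < 0) ? 1 : 0) == Lobit);         // zext (x <s 0)  --> x >>u 7
    assert(((x > -1) ? 1 : 0) == (Lobit ^ 1));  // zext (x >s -1) --> (x >>u 7) ^ 1
    // Known bits say (x & 16) has a single possibly-set bit, so the equality
    // tests reduce to shifting that bit down to the LSB.
    int Masked = U & 16;
    assert(((Masked != 0) ? 1 : 0) == (Masked >> 4));  // zext ((x & 16) != 0)
    assert(((Masked == 16) ? 1 : 0) == (Masked >> 4)); // zext ((x & 16) == 16)
  }
  return 0;
}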
- if (SimplifyDemandedInstructionBits(CI)) - return &CI; - Value *Src = CI.getOperand(0); Type *SrcTy = Src->getType(), *DestTy = CI.getType(); @@ -851,10 +912,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // expression tree to something weird like i93 unless the source is also // strange. unsigned BitsToClear; - if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && + if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { - assert(BitsToClear < SrcTy->getScalarSizeInBits() && - "Unreasonable BitsToClear"); + assert(BitsToClear <= SrcTy->getScalarSizeInBits() && + "Can't clear more bits than in SrcTy"); // Okay, we can transform this! Insert the new expression now. DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" @@ -898,7 +959,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { if (SrcSize < DstSize) { APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); Constant *AndConst = ConstantInt::get(A->getType(), AndValue); - Value *And = Builder->CreateAnd(A, AndConst, CSrc->getName()+".mask"); + Value *And = Builder.CreateAnd(A, AndConst, CSrc->getName() + ".mask"); return new ZExtInst(And, CI.getType()); } @@ -908,7 +969,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { AndValue)); } if (SrcSize > DstSize) { - Value *Trunc = Builder->CreateTrunc(A, CI.getType()); + Value *Trunc = Builder.CreateTrunc(A, CI.getType()); APInt AndValue(APInt::getLowBitsSet(DstSize, MidSize)); return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Trunc->getType(), @@ -930,8 +991,8 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { (transformZExtICmp(LHS, CI, false) || transformZExtICmp(RHS, CI, false))) { // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) - Value *LCast = Builder->CreateZExt(LHS, CI.getType(), LHS->getName()); - Value *RCast = Builder->CreateZExt(RHS, CI.getType(), RHS->getName()); + Value *LCast = Builder.CreateZExt(LHS, CI.getType(), LHS->getName()); + Value *RCast = Builder.CreateZExt(RHS, CI.getType(), RHS->getName()); BinaryOperator *Or = BinaryOperator::Create(Instruction::Or, LCast, RCast); // Perform the elimination. @@ -958,7 +1019,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && X->getType() == CI.getType()) { Constant *ZC = ConstantExpr::getZExt(C, CI.getType()); - return BinaryOperator::CreateXor(Builder->CreateAnd(X, ZC), ZC); + return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); } return nullptr; @@ -981,12 +1042,12 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { Value *Sh = ConstantInt::get(Op0->getType(), Op0->getType()->getScalarSizeInBits()-1); - Value *In = Builder->CreateAShr(Op0, Sh, Op0->getName()+".lobit"); + Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); if (In->getType() != CI.getType()) - In = Builder->CreateIntCast(In, CI.getType(), true/*SExt*/); + In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); if (Pred == ICmpInst::ICMP_SGT) - In = Builder->CreateNot(In, In->getName()+".not"); + In = Builder.CreateNot(In, In->getName() + ".not"); return replaceInstUsesWith(CI, In); } } @@ -997,11 +1058,9 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { // the icmp and sext into bitwise/integer operations. 
if (ICI->hasOneUse() && ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ - unsigned BitWidth = Op1C->getType()->getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(Op0, KnownZero, KnownOne, 0, &CI); + KnownBits Known = computeKnownBits(Op0, 0, &CI); - APInt KnownZeroMask(~KnownZero); + APInt KnownZeroMask(~Known.Zero); if (KnownZeroMask.isPowerOf2()) { Value *In = ICI->getOperand(0); @@ -1019,26 +1078,26 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { unsigned ShiftAmt = KnownZeroMask.countTrailingZeros(); // Perform a right shift to place the desired bit in the LSB. if (ShiftAmt) - In = Builder->CreateLShr(In, - ConstantInt::get(In->getType(), ShiftAmt)); + In = Builder.CreateLShr(In, + ConstantInt::get(In->getType(), ShiftAmt)); // At this point "In" is either 1 or 0. Subtract 1 to turn // {1, 0} -> {0, -1}. - In = Builder->CreateAdd(In, - ConstantInt::getAllOnesValue(In->getType()), - "sext"); + In = Builder.CreateAdd(In, + ConstantInt::getAllOnesValue(In->getType()), + "sext"); } else { // sext ((x & 2^n) != 0) -> (x << bitwidth-n) a>> bitwidth-1 // sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1 unsigned ShiftAmt = KnownZeroMask.countLeadingZeros(); // Perform a left shift to place the desired bit in the MSB. if (ShiftAmt) - In = Builder->CreateShl(In, - ConstantInt::get(In->getType(), ShiftAmt)); + In = Builder.CreateShl(In, + ConstantInt::get(In->getType(), ShiftAmt)); // Distribute the bit over the whole bit width. - In = Builder->CreateAShr(In, ConstantInt::get(In->getType(), - BitWidth - 1), "sext"); + In = Builder.CreateAShr(In, ConstantInt::get(In->getType(), + KnownZeroMask.getBitWidth() - 1), "sext"); } if (CI.getType() == In->getType()) @@ -1124,20 +1183,14 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { if (Instruction *I = commonCastTransforms(CI)) return I; - // See if we can simplify any instructions used by the input whose sole - // purpose is to compute bits we don't care about. - if (SimplifyDemandedInstructionBits(CI)) - return &CI; - Value *Src = CI.getOperand(0); Type *SrcTy = Src->getType(), *DestTy = CI.getType(); // If we know that the value being extended is positive, we can use a zext // instead. - bool KnownZero, KnownOne; - ComputeSignBit(Src, KnownZero, KnownOne, 0, &CI); - if (KnownZero) { - Value *ZExt = Builder->CreateZExt(Src, DestTy); + KnownBits Known = computeKnownBits(Src, 0, &CI); + if (Known.isNonNegative()) { + Value *ZExt = Builder.CreateZExt(Src, DestTy); return replaceInstUsesWith(CI, ZExt); } @@ -1145,7 +1198,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // type. Only do this if the dest type is a simple type, don't convert the // expression tree to something weird like i93 unless the source is also // strange. - if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && + if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && canEvaluateSExtd(Src, DestTy)) { // Okay, we can transform this! Insert the new expression now. DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" @@ -1163,22 +1216,20 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // We need to emit a shl + ashr to do the sign extend. 
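The two sext(icmp) shapes above do the same for sign extension: either splat the tested bit from the sign position, or move it to the LSB and subtract one. Exhaustive standalone check at W = 8, n = 4 (the bit choice and helpers are mine):

#include <cassert>
#include <cstdint>

int main() {
  const unsigned n = 4;                          // testing bit 4 (mask 16)
  for (int x = 0; x <= 255; ++x) {
    uint8_t Masked = static_cast<uint8_t>(x) & (1u << n);
    // sext ((x & 2^n) != 0) --> (x << (W-1-n)) a>> (W-1)
    int8_t Splat = static_cast<int8_t>(Masked << (7 - n));
    Splat = static_cast<int8_t>(Splat >> 7);     // arithmetic shift
    assert(Splat == ((Masked != 0) ? -1 : 0));
    // sext ((x & 2^n) == 0) --> ((x & 2^n) >>u n) - 1, i.e. {1, 0} -> {0, -1}
    int8_t Sub = static_cast<int8_t>((Masked >> n) - 1);
    assert(Sub == ((Masked == 0) ? -1 : 0));
  }
  return 0;
}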
Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); - return BinaryOperator::CreateAShr(Builder->CreateShl(Res, ShAmt, "sext"), + return BinaryOperator::CreateAShr(Builder.CreateShl(Res, ShAmt, "sext"), ShAmt); } - // If this input is a trunc from our destination, then turn sext(trunc(x)) + // If the input is a trunc from the destination type, then turn sext(trunc(x)) // into shifts. - if (TruncInst *TI = dyn_cast<TruncInst>(Src)) - if (TI->hasOneUse() && TI->getOperand(0)->getType() == DestTy) { - uint32_t SrcBitSize = SrcTy->getScalarSizeInBits(); - uint32_t DestBitSize = DestTy->getScalarSizeInBits(); - - // We need to emit a shl + ashr to do the sign extend. - Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); - Value *Res = Builder->CreateShl(TI->getOperand(0), ShAmt, "sext"); - return BinaryOperator::CreateAShr(Res, ShAmt); - } + Value *X; + if (match(Src, m_OneUse(m_Trunc(m_Value(X)))) && X->getType() == DestTy) { + // sext(trunc(X)) --> ashr(shl(X, C), C) + unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); + unsigned DestBitSize = DestTy->getScalarSizeInBits(); + Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); + return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShAmt), ShAmt); + } if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) return transformSExtICmp(ICI, CI); @@ -1206,7 +1257,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { unsigned SrcDstSize = CI.getType()->getScalarSizeInBits(); unsigned ShAmt = CA->getZExtValue()+SrcDstSize-MidSize; Constant *ShAmtV = ConstantInt::get(CI.getType(), ShAmt); - A = Builder->CreateShl(A, ShAmtV, CI.getName()); + A = Builder.CreateShl(A, ShAmtV, CI.getName()); return BinaryOperator::CreateAShr(A, ShAmtV); } @@ -1225,17 +1276,15 @@ static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { return nullptr; } -/// If this is a floating-point extension instruction, look -/// through it until we get the source value. +/// Look through floating-point extensions until we get the source value. static Value *lookThroughFPExtensions(Value *V) { - if (Instruction *I = dyn_cast<Instruction>(V)) - if (I->getOpcode() == Instruction::FPExt) - return lookThroughFPExtensions(I->getOperand(0)); + while (auto *FPExt = dyn_cast<FPExtInst>(V)) + V = FPExt->getOperand(0); // If this value is a constant, return the constant in the smallest FP type // that can accurately represent it. This allows us to turn // (float)((double)X+2.0) into x+2.0f. - if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { + if (auto *CFP = dyn_cast<ConstantFP>(V)) { if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext())) return V; // No constant folding of this. // See if the value can be truncated to half and then reextended. @@ -1297,9 +1346,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // case of interest here is (float)((double)float + float)). if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) { if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder->CreateFPExt(LHSOrig, CI.getType()); + LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder->CreateFPExt(RHSOrig, CI.getType()); + RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); Instruction *RI = BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig); RI->copyFastMathFlags(OpI); @@ -1314,9 +1363,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // in the destination format if it can represent both sources. 
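The sext(trunc(x)) rewrite above is the textbook shl/ashr pair. A standalone check for the i32 -> i8 -> i32 shape, where the shift amount is DestBitSize - SrcBitSize = 24 (sampled sweep, names mine):

#include <cassert>
#include <cstdint>

static int32_t sextTrunc(int32_t X) {
  return static_cast<int32_t>(static_cast<int8_t>(X));  // sext(trunc X)
}

static int32_t shlAshr(int32_t X) {
  const unsigned C = 32 - 8;                            // DestBitSize - SrcBitSize
  return static_cast<int32_t>(static_cast<uint32_t>(X) << C) >> C;
}

int main() {
  for (int64_t x = INT32_MIN; x <= INT32_MAX; x += 9973)
    assert(sextTrunc(static_cast<int32_t>(x)) == shlAshr(static_cast<int32_t>(x)));
  return 0;
}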
if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) { if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder->CreateFPExt(LHSOrig, CI.getType()); + LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder->CreateFPExt(RHSOrig, CI.getType()); + RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); Instruction *RI = BinaryOperator::CreateFMul(LHSOrig, RHSOrig); RI->copyFastMathFlags(OpI); @@ -1332,9 +1381,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // TODO: Tighten bound via rigorous analysis of the unbalanced case. if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) { if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder->CreateFPExt(LHSOrig, CI.getType()); + LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder->CreateFPExt(RHSOrig, CI.getType()); + RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); Instruction *RI = BinaryOperator::CreateFDiv(LHSOrig, RHSOrig); RI->copyFastMathFlags(OpI); @@ -1349,11 +1398,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { if (SrcWidth == OpWidth) break; if (LHSWidth < SrcWidth) - LHSOrig = Builder->CreateFPExt(LHSOrig, RHSOrig->getType()); + LHSOrig = Builder.CreateFPExt(LHSOrig, RHSOrig->getType()); else if (RHSWidth <= SrcWidth) - RHSOrig = Builder->CreateFPExt(RHSOrig, LHSOrig->getType()); + RHSOrig = Builder.CreateFPExt(RHSOrig, LHSOrig->getType()); if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) { - Value *ExactResult = Builder->CreateFRem(LHSOrig, RHSOrig); + Value *ExactResult = Builder.CreateFRem(LHSOrig, RHSOrig); if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) RI->copyFastMathFlags(OpI); return CastInst::CreateFPCast(ExactResult, CI.getType()); @@ -1362,8 +1411,8 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // (fptrunc (fneg x)) -> (fneg (fptrunc x)) if (BinaryOperator::isFNeg(OpI)) { - Value *InnerTrunc = Builder->CreateFPTrunc(OpI->getOperand(1), - CI.getType()); + Value *InnerTrunc = Builder.CreateFPTrunc(OpI->getOperand(1), + CI.getType()); Instruction *RI = BinaryOperator::CreateFNeg(InnerTrunc); RI->copyFastMathFlags(OpI); return RI; @@ -1382,34 +1431,57 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { (isa<ConstantFP>(SI->getOperand(1)) || isa<ConstantFP>(SI->getOperand(2))) && matchSelectPattern(SI, LHS, RHS).Flavor == SPF_UNKNOWN) { - Value *LHSTrunc = Builder->CreateFPTrunc(SI->getOperand(1), - CI.getType()); - Value *RHSTrunc = Builder->CreateFPTrunc(SI->getOperand(2), - CI.getType()); + Value *LHSTrunc = Builder.CreateFPTrunc(SI->getOperand(1), CI.getType()); + Value *RHSTrunc = Builder.CreateFPTrunc(SI->getOperand(2), CI.getType()); return SelectInst::Create(SI->getOperand(0), LHSTrunc, RHSTrunc); } IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI.getOperand(0)); if (II) { switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::fabs: { - // (fptrunc (fabs x)) -> (fabs (fptrunc x)) - Value *InnerTrunc = Builder->CreateFPTrunc(II->getArgOperand(0), - CI.getType()); - Type *IntrinsicType[] = { CI.getType() }; - Function *Overload = Intrinsic::getDeclaration( - CI.getModule(), II->getIntrinsicID(), IntrinsicType); - - SmallVector<OperandBundleDef, 1> OpBundles; - II->getOperandBundlesAsDefs(OpBundles); - - Value *Args[] = { InnerTrunc }; - return CallInst::Create(Overload, Args, OpBundles, II->getName()); + default: break; + case Intrinsic::fabs: + case Intrinsic::ceil: + case Intrinsic::floor: + 
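The width tests above (OpWidth >= 2*DstWidth+1 and friends) encode the double-rounding safety margin: double's 53-bit significand exceeds what float addition needs, so computing in double and truncating matches the single-precision result. A spot check, not a proof; it assumes strict float evaluation (FLT_EVAL_METHOD == 0) and the LCG constants are mine:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t State = 12345u;
  for (int i = 0; i < 100000; ++i) {
    State = State * 1664525u + 1013904223u;  // LCG step
    float A = static_cast<float>(State % 2000001u) - 1000000.0f;
    State = State * 1664525u + 1013904223u;
    float B = (static_cast<float>(State % 2000001u) - 1000000.0f) * 0.125f;
    float Wide = static_cast<float>(static_cast<double>(A) +
                                    static_cast<double>(B));
    // fptrunc (fadd (fpext a), (fpext b)) --> fadd a, b
    assert(Wide == A + B);
  }
  return 0;
}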
case Intrinsic::rint: + case Intrinsic::round: + case Intrinsic::nearbyint: + case Intrinsic::trunc: { + Value *Src = II->getArgOperand(0); + if (!Src->hasOneUse()) + break; + + // Except for fabs, this transformation requires the input of the unary FP + // operation to be itself an fpext from the type to which we're + // truncating. + if (II->getIntrinsicID() != Intrinsic::fabs) { + FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src); + if (!FPExtSrc || FPExtSrc->getOperand(0)->getType() != CI.getType()) + break; } + + // Do unary FP operation on smaller type. + // (fptrunc (fabs x)) -> (fabs (fptrunc x)) + Value *InnerTrunc = Builder.CreateFPTrunc(Src, CI.getType()); + Type *IntrinsicType[] = { CI.getType() }; + Function *Overload = Intrinsic::getDeclaration( + CI.getModule(), II->getIntrinsicID(), IntrinsicType); + + SmallVector<OperandBundleDef, 1> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); + + Value *Args[] = { InnerTrunc }; + CallInst *NewCI = CallInst::Create(Overload, Args, + OpBundles, II->getName()); + NewCI->copyFastMathFlags(II); + return NewCI; + } } } + if (Instruction *I = shrinkInsertElt(CI, Builder)) + return I; + return nullptr; } @@ -1502,7 +1574,7 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) { if (CI.getType()->isVectorTy()) // Handle vectors of pointers. Ty = VectorType::get(Ty, CI.getType()->getVectorNumElements()); - Value *P = Builder->CreateZExtOrTrunc(CI.getOperand(0), Ty); + Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty); return new IntToPtrInst(P, CI.getType()); } @@ -1524,7 +1596,7 @@ Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { // GEP into CI would undo canonicalizing addrspacecast with different // pointer types, causing infinite loops. (!isa<AddrSpaceCastInst>(CI) || - GEP->getType() == GEP->getPointerOperand()->getType())) { + GEP->getType() == GEP->getPointerOperandType())) { // Changing the cast operand is usually not a good idea but it is safe // here because the pointer operand is being replaced with another // pointer operand so the opcode doesn't need to change. @@ -1552,7 +1624,7 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { if (Ty->isVectorTy()) // Handle vectors of pointers. PtrTy = VectorType::get(PtrTy, Ty->getVectorNumElements()); - Value *P = Builder->CreatePtrToInt(CI.getOperand(0), PtrTy); + Value *P = Builder.CreatePtrToInt(CI.getOperand(0), PtrTy); return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); } @@ -1578,7 +1650,7 @@ static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy, return nullptr; SrcTy = VectorType::get(DestTy->getElementType(), SrcTy->getNumElements()); - InVal = IC.Builder->CreateBitCast(InVal, SrcTy); + InVal = IC.Builder.CreateBitCast(InVal, SrcTy); } // Now that the element types match, get the shuffle mask and RHS of the @@ -1758,8 +1830,8 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, for (unsigned i = 0, e = Elements.size(); i != e; ++i) { if (!Elements[i]) continue; // Unset element. - Result = IC.Builder->CreateInsertElement(Result, Elements[i], - IC.Builder->getInt32(i)); + Result = IC.Builder.CreateInsertElement(Result, Elements[i], + IC.Builder.getInt32(i)); } return Result; @@ -1770,8 +1842,7 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, /// vectors better than bitcasts of scalars because vector registers are /// usually not type-specific like scalar integer or scalar floating-point. 
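The generalization above works because, when the operand was itself widened from the destination type, rounding-style intrinsics commute with the truncation (floor of a float value is exactly representable as a float), while fabs needs no precondition since it only clears the sign bit. A standalone spot check (test values mine):

#include <cassert>
#include <cmath>

int main() {
  const float Vals[] = {0.0f, -0.0f, 0.5f, -2.75f, 1e7f, -123456.78f};
  for (float X : Vals) {
    // fptrunc (floor (fpext x)) --> floor x
    assert(static_cast<float>(std::floor(static_cast<double>(X))) ==
           std::floor(X));
    // fptrunc (fabs (fpext x)) --> fabs x
    assert(static_cast<float>(std::fabs(static_cast<double>(X))) ==
           std::fabs(X));
  }
  return 0;
}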
static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, - InstCombiner &IC, - const DataLayout &DL) { + InstCombiner &IC) { // TODO: Create and use a pattern matcher for ExtractElementInst. auto *ExtElt = dyn_cast<ExtractElementInst>(BitCast.getOperand(0)); if (!ExtElt || !ExtElt->hasOneUse()) @@ -1785,8 +1856,8 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements(); auto *NewVecType = VectorType::get(DestType, NumElts); - auto *NewBC = IC.Builder->CreateBitCast(ExtElt->getVectorOperand(), - NewVecType, "bc"); + auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(), + NewVecType, "bc"); return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); } @@ -1795,7 +1866,7 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, InstCombiner::BuilderTy &Builder) { Type *DestTy = BitCast.getType(); BinaryOperator *BO; - if (!DestTy->getScalarType()->isIntegerTy() || + if (!DestTy->isIntOrIntVectorTy() || !match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || !BO->isBitwiseLogicOp()) return nullptr; @@ -1821,6 +1892,18 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, return BinaryOperator::Create(BO->getOpcode(), CastedOp0, X); } + // Canonicalize vector bitcasts to come before vector bitwise logic with a + // constant. This eases recognition of special constants for later ops. + // Example: + // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b + Constant *C; + if (match(BO->getOperand(1), m_Constant(C))) { + // bitcast (logic X, C) --> logic (bitcast X, C') + Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy); + Value *CastedC = ConstantExpr::getBitCast(C, DestTy); + return BinaryOperator::Create(BO->getOpcode(), CastedOp0, CastedC); + } + return nullptr; } @@ -1946,8 +2029,8 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { // For each old PHI node, create a corresponding new PHI node with a type A. SmallDenseMap<PHINode *, PHINode *> NewPNodes; for (auto *OldPN : OldPhiNodes) { - Builder->SetInsertPoint(OldPN); - PHINode *NewPN = Builder->CreatePHI(DestTy, OldPN->getNumOperands()); + Builder.SetInsertPoint(OldPN); + PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands()); NewPNodes[OldPN] = NewPN; } @@ -1960,8 +2043,8 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { if (auto *C = dyn_cast<Constant>(V)) { NewV = ConstantExpr::getBitCast(C, DestTy); } else if (auto *LI = dyn_cast<LoadInst>(V)) { - Builder->SetInsertPoint(LI->getNextNode()); - NewV = Builder->CreateBitCast(LI, DestTy); + Builder.SetInsertPoint(LI->getNextNode()); + NewV = Builder.CreateBitCast(LI, DestTy); Worklist.Add(LI); } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { NewV = BCI->getOperand(0); @@ -1977,9 +2060,9 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { for (User *U : PN->users()) { auto *SI = dyn_cast<StoreInst>(U); if (SI && SI->isSimple() && SI->getOperand(0) == PN) { - Builder->SetInsertPoint(SI); + Builder.SetInsertPoint(SI); auto *NewBC = - cast<BitCastInst>(Builder->CreateBitCast(NewPNodes[PN], SrcTy)); + cast<BitCastInst>(Builder.CreateBitCast(NewPNodes[PN], SrcTy)); SI->setOperand(0, NewBC); Worklist.Add(SI); assert(hasStoreUsersOnly(*NewBC)); @@ -2034,14 +2117,14 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // If we found a path from the src to dest, create the getelementptr now. 
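The new canonicalization above relies on a bitcast touching no bits, so bitwise logic commutes with it and the constant can be recast for free. A standalone sketch modeling u64 <-> "two u32 lanes" with memcpy (the struct, constant, and helper are mine):

#include <cassert>
#include <cstdint>
#include <cstring>

struct V2x32 { uint32_t Lo, Hi; };         // stand-in for <2 x i32>

static V2x32 bitcastTo(uint64_t X) {
  V2x32 V;
  std::memcpy(&V, &X, sizeof(V));
  return V;
}

int main() {
  const uint64_t C = 0x8000000080000000ull; // e.g. a per-lane sign mask
  uint64_t X = 0x0123456789ABCDEFull;
  // bitcast (xor X, C) --> xor (bitcast X), (bitcast C)
  V2x32 A = bitcastTo(X ^ C);
  V2x32 B = bitcastTo(X), M = bitcastTo(C);
  assert(A.Lo == (B.Lo ^ M.Lo) && A.Hi == (B.Hi ^ M.Hi));
  return 0;
}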
if (SrcElTy == DstElTy) { - SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder->getInt32(0)); + SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0)); return GetElementPtrInst::CreateInBounds(Src, Idxs); } } if (VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) { if (DestVTy->getNumElements() == 1 && !SrcTy->isVectorTy()) { - Value *Elem = Builder->CreateBitCast(Src, DestVTy->getElementType()); + Value *Elem = Builder.CreateBitCast(Src, DestVTy->getElementType()); return InsertElementInst::Create(UndefValue::get(DestTy), Elem, Constant::getNullValue(Type::getInt32Ty(CI.getContext()))); // FIXME: Canonicalize bitcast(insertelement) -> insertelement(bitcast) @@ -2074,7 +2157,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // scalar-scalar cast. if (!DestTy->isVectorTy()) { Value *Elem = - Builder->CreateExtractElement(Src, + Builder.CreateExtractElement(Src, Constant::getNullValue(Type::getInt32Ty(CI.getContext()))); return CastInst::Create(Instruction::BitCast, Elem, DestTy); } @@ -2103,8 +2186,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { Tmp->getOperand(0)->getType() == DestTy) || ((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(1))) && Tmp->getOperand(0)->getType() == DestTy)) { - Value *LHS = Builder->CreateBitCast(SVI->getOperand(0), DestTy); - Value *RHS = Builder->CreateBitCast(SVI->getOperand(1), DestTy); + Value *LHS = Builder.CreateBitCast(SVI->getOperand(0), DestTy); + Value *RHS = Builder.CreateBitCast(SVI->getOperand(1), DestTy); // Return a new shuffle vector. Use the same element ID's, as we // know the vector types match #elts. return new ShuffleVectorInst(LHS, RHS, SVI->getOperand(2)); @@ -2117,13 +2200,13 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { if (Instruction *I = optimizeBitCastFromPhi(CI, PN)) return I; - if (Instruction *I = canonicalizeBitCastExtElt(CI, *this, DL)) + if (Instruction *I = canonicalizeBitCastExtElt(CI, *this)) return I; - if (Instruction *I = foldBitCastBitwiseLogic(CI, *Builder)) + if (Instruction *I = foldBitCastBitwiseLogic(CI, Builder)) return I; - if (Instruction *I = foldBitCastSelect(CI, *Builder)) + if (Instruction *I = foldBitCastSelect(CI, Builder)) return I; if (SrcTy->isPointerTy()) @@ -2147,7 +2230,7 @@ Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) { MidTy = VectorType::get(MidTy, VT->getNumElements()); } - Value *NewBitCast = Builder->CreateBitCast(Src, MidTy); + Value *NewBitCast = Builder.CreateBitCast(Src, MidTy); return new AddrSpaceCastInst(NewBitCast, CI.getType()); } diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 428f94b..a8faaec 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -26,6 +26,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; using namespace PatternMatch; @@ -111,10 +112,10 @@ static bool subWithOverflow(Constant *&Result, Constant *In1, /// Given an icmp instruction, return true if any use of this comparison is a /// branch on sign bit comparison. 
-static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) { +static bool hasBranchUse(ICmpInst &I) { for (auto *U : I.users()) if (isa<BranchInst>(U)) - return isSignBit; + return true; return false; } @@ -126,7 +127,7 @@ static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, switch (Pred) { case ICmpInst::ICMP_SLT: // True if LHS s< 0 TrueIfSigned = true; - return RHS == 0; + return RHS.isNullValue(); case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1 TrueIfSigned = true; return RHS.isAllOnesValue(); @@ -140,7 +141,7 @@ static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, case ICmpInst::ICMP_UGE: // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) TrueIfSigned = true; - return RHS.isSignBit(); + return RHS.isSignMask(); default: return false; } @@ -154,10 +155,10 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { if (!ICmpInst::isSigned(Pred)) return false; - if (C == 0) + if (C.isNullValue()) return ICmpInst::isRelational(Pred); - if (C == 1) { + if (C.isOneValue()) { if (Pred == ICmpInst::ICMP_SLT) { Pred = ICmpInst::ICMP_SLE; return true; @@ -175,42 +176,40 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { /// Given a signed integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void computeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, - const APInt &KnownOne, +/// TODO: Move to method on KnownBits struct? +static void computeSignedMinMaxValuesFromKnownBits(const KnownBits &Known, APInt &Min, APInt &Max) { - assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && - KnownZero.getBitWidth() == Min.getBitWidth() && - KnownZero.getBitWidth() == Max.getBitWidth() && + assert(Known.getBitWidth() == Min.getBitWidth() && + Known.getBitWidth() == Max.getBitWidth() && "KnownZero, KnownOne and Min, Max must have equal bitwidth."); - APInt UnknownBits = ~(KnownZero|KnownOne); + APInt UnknownBits = ~(Known.Zero|Known.One); // The minimum value is when all unknown bits are zeros, EXCEPT for the sign // bit if it is unknown. - Min = KnownOne; - Max = KnownOne|UnknownBits; + Min = Known.One; + Max = Known.One|UnknownBits; if (UnknownBits.isNegative()) { // Sign bit is unknown - Min.setBit(Min.getBitWidth()-1); - Max.clearBit(Max.getBitWidth()-1); + Min.setSignBit(); + Max.clearSignBit(); } } /// Given an unsigned integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void computeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, - const APInt &KnownOne, +/// TODO: Move to method on KnownBits struct? +static void computeUnsignedMinMaxValuesFromKnownBits(const KnownBits &Known, APInt &Min, APInt &Max) { - assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && - KnownZero.getBitWidth() == Min.getBitWidth() && - KnownZero.getBitWidth() == Max.getBitWidth() && + assert(Known.getBitWidth() == Min.getBitWidth() && + Known.getBitWidth() == Max.getBitWidth() && "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); - APInt UnknownBits = ~(KnownZero|KnownOne); + APInt UnknownBits = ~(Known.Zero|Known.One); // The minimum value is when the unknown bits are all zeros. - Min = KnownOne; + Min = Known.One; // The maximum value is when the unknown bits are all ones. 
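A worked i8 example for the signed helper above (bit patterns invented), using the new KnownBits fields:

; Known.Zero = 0b00001100, Known.One = 0b00100001    (sign bit unknown)
; UnknownBits = ~(Known.Zero | Known.One) = 0b11010010
; Min = Known.One with the sign bit forced set       = 0b10100001 = -95
; Max = (Known.One | UnknownBits), sign bit cleared  = 0b01110011 = +115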
- Max = KnownOne|UnknownBits; + Max = Known.One|UnknownBits; } /// This is called when we see this pattern: @@ -230,7 +229,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, return nullptr; uint64_t ArrayElementCount = Init->getType()->getArrayNumElements(); - if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays. + // Don't blow up on huge arrays. + if (ArrayElementCount > MaxArraySizeForCombine) + return nullptr; // There are many forms of this optimization we can handle, for now, just do // the simple index into a single-dimensional array. @@ -391,7 +392,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); if (Idx->getType()->getPrimitiveSizeInBits() > PtrSize) - Idx = Builder->CreateTrunc(Idx, IntPtrTy); + Idx = Builder.CreateTrunc(Idx, IntPtrTy); } // If the comparison is only true for one or two elements, emit direct @@ -399,7 +400,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, if (SecondTrueElement != Overdefined) { // None true -> false. if (FirstTrueElement == Undefined) - return replaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder.getFalse()); Value *FirstTrueIdx = ConstantInt::get(Idx->getType(), FirstTrueElement); @@ -408,9 +409,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, return new ICmpInst(ICmpInst::ICMP_EQ, Idx, FirstTrueIdx); // True for two elements -> 'i == 47 | i == 72'. - Value *C1 = Builder->CreateICmpEQ(Idx, FirstTrueIdx); + Value *C1 = Builder.CreateICmpEQ(Idx, FirstTrueIdx); Value *SecondTrueIdx = ConstantInt::get(Idx->getType(), SecondTrueElement); - Value *C2 = Builder->CreateICmpEQ(Idx, SecondTrueIdx); + Value *C2 = Builder.CreateICmpEQ(Idx, SecondTrueIdx); return BinaryOperator::CreateOr(C1, C2); } @@ -419,7 +420,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, if (SecondFalseElement != Overdefined) { // None false -> true. if (FirstFalseElement == Undefined) - return replaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder.getTrue()); Value *FirstFalseIdx = ConstantInt::get(Idx->getType(), FirstFalseElement); @@ -428,9 +429,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, return new ICmpInst(ICmpInst::ICMP_NE, Idx, FirstFalseIdx); // False for two elements -> 'i != 47 & i != 72'. - Value *C1 = Builder->CreateICmpNE(Idx, FirstFalseIdx); + Value *C1 = Builder.CreateICmpNE(Idx, FirstFalseIdx); Value *SecondFalseIdx = ConstantInt::get(Idx->getType(),SecondFalseElement); - Value *C2 = Builder->CreateICmpNE(Idx, SecondFalseIdx); + Value *C2 = Builder.CreateICmpNE(Idx, SecondFalseIdx); return BinaryOperator::CreateAnd(C1, C2); } @@ -442,7 +443,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // Generate (i-FirstTrue) <u (TrueRangeEnd-FirstTrue+1). if (FirstTrueElement) { Value *Offs = ConstantInt::get(Idx->getType(), -FirstTrueElement); - Idx = Builder->CreateAdd(Idx, Offs); + Idx = Builder.CreateAdd(Idx, Offs); } Value *End = ConstantInt::get(Idx->getType(), @@ -456,7 +457,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // Generate (i-FirstFalse) >u (FalseRangeEnd-FirstFalse). 
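foldCmpLoadFromIndexedGlobal (now bounded by MaxArraySizeForCombine) turns a compare of a loaded constant-array element into a compare of the index. A minimal sketch with invented array contents:

@tbl = internal constant [4 x i32] [i32 0, i32 9, i32 0, i32 0]

define i1 @cmp_table(i64 %i) {
  %p = getelementptr inbounds [4 x i32], [4 x i32]* @tbl, i64 0, i64 %i
  %v = load i32, i32* %p
  %c = icmp eq i32 %v, 9
  ret i1 %c
}

; only element 1 compares equal, so the load is elided:
;   %c = icmp eq i64 %i, 1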
if (FirstFalseElement) { Value *Offs = ConstantInt::get(Idx->getType(), -FirstFalseElement); - Idx = Builder->CreateAdd(Idx, Offs); + Idx = Builder.CreateAdd(Idx, Offs); } Value *End = ConstantInt::get(Idx->getType(), @@ -480,9 +481,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount); if (Ty) { - Value *V = Builder->CreateIntCast(Idx, Ty, false); - V = Builder->CreateLShr(ConstantInt::get(Ty, MagicBitvector), V); - V = Builder->CreateAnd(ConstantInt::get(Ty, 1), V); + Value *V = Builder.CreateIntCast(Idx, Ty, false); + V = Builder.CreateLShr(ConstantInt::get(Ty, MagicBitvector), V); + V = Builder.CreateAnd(ConstantInt::get(Ty, 1), V); return new ICmpInst(ICmpInst::ICMP_NE, V, ConstantInt::get(Ty, 0)); } } @@ -565,7 +566,7 @@ static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, // we don't need to bother extending: the extension won't affect where the // computation crosses zero. if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) { - VariableIdx = IC.Builder->CreateTrunc(VariableIdx, IntPtrTy); + VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy); } return VariableIdx; } @@ -587,10 +588,10 @@ static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, // Okay, we can do this evaluation. Start by converting the index to intptr. if (VariableIdx->getType() != IntPtrTy) - VariableIdx = IC.Builder->CreateIntCast(VariableIdx, IntPtrTy, + VariableIdx = IC.Builder.CreateIntCast(VariableIdx, IntPtrTy, true /*Signed*/); Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs); - return IC.Builder->CreateAdd(VariableIdx, OffsetVal, "offset"); + return IC.Builder.CreateAdd(VariableIdx, OffsetVal, "offset"); } /// Returns true if we can rewrite Start as a GEP with pointer Base @@ -980,13 +981,13 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, if (LHSIndexTy != RHSIndexTy) { if (LHSIndexTy->getPrimitiveSizeInBits() < RHSIndexTy->getPrimitiveSizeInBits()) { - ROffset = Builder->CreateTrunc(ROffset, LHSIndexTy); + ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy); } else - LOffset = Builder->CreateTrunc(LOffset, RHSIndexTy); + LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy); } - Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond), - LOffset, ROffset); + Value *Cmp = Builder.CreateICmp(ICmpInst::getSignedPredicate(Cond), + LOffset, ROffset); return replaceInstUsesWith(I, Cmp); } @@ -1025,7 +1026,7 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, if (NumDifferences == 0) // SAME GEP? return replaceInstUsesWith(I, // No comparison is needed here. - Builder->getInt1(ICmpInst::isTrueWhenEqual(Cond))); + Builder.getInt1(ICmpInst::isTrueWhenEqual(Cond))); else if (NumDifferences == 1 && GEPsInBounds) { Value *LHSV = GEPLHS->getOperand(DiffOperand); @@ -1173,7 +1174,7 @@ Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, // (X+ -1) >s X --> X <s (MAXSINT-(-1-1)) --> X == -128 assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE); - Constant *C = Builder->getInt(CI->getValue()-1); + Constant *C = Builder.getInt(CI->getValue() - 1); return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); } @@ -1192,7 +1193,7 @@ Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A, }; // Don't bother doing any work for cases which InstSimplify handles. 
- if (AP2 == 0) + if (AP2.isNullValue()) return nullptr; bool IsAShr = isa<AShrOperator>(I.getOperand(0)); @@ -1251,7 +1252,7 @@ Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A, }; // Don't bother doing any work for cases which InstSimplify handles. - if (AP2 == 0) + if (AP2.isNullValue()) return nullptr; unsigned AP2TrailingZeros = AP2.countTrailingZeros(); @@ -1346,17 +1347,17 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, Value *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::sadd_with_overflow, NewType); - InstCombiner::BuilderTy *Builder = IC.Builder; + InstCombiner::BuilderTy &Builder = IC.Builder; // Put the new code above the original add, in case there are any uses of the // add between the add and the compare. - Builder->SetInsertPoint(OrigAdd); + Builder.SetInsertPoint(OrigAdd); - Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName() + ".trunc"); - Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName() + ".trunc"); - CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); - Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); - Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); + Value *TruncA = Builder.CreateTrunc(A, NewType, A->getName() + ".trunc"); + Value *TruncB = Builder.CreateTrunc(B, NewType, B->getName() + ".trunc"); + CallInst *Call = Builder.CreateCall(F, {TruncA, TruncB}, "sadd"); + Value *Add = Builder.CreateExtractValue(Call, 0, "sadd.result"); + Value *ZExt = Builder.CreateZExt(Add, OrigAdd->getType()); // The inner add was the result of the narrow add, zero extended to the // wider type. Replace it with the result computed by the intrinsic. @@ -1398,12 +1399,12 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { } // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) - if (*C == 0 && Pred == ICmpInst::ICMP_SGT) { + if (C->isNullValue() && Pred == ICmpInst::ICMP_SGT) { SelectPatternResult SPR = matchSelectPattern(X, A, B); if (SPR.Flavor == SPF_SMIN) { - if (isKnownPositive(A, DL)) + if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) return new ICmpInst(Pred, B, Cmp.getOperand(1)); - if (isKnownPositive(B, DL)) + if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) return new ICmpInst(Pred, A, Cmp.getOperand(1)); } } @@ -1433,9 +1434,9 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { ConstantRange Intersection = DominatingCR.intersectWith(CR); ConstantRange Difference = DominatingCR.difference(CR); if (Intersection.isEmptySet()) - return replaceInstUsesWith(Cmp, Builder->getFalse()); + return replaceInstUsesWith(Cmp, Builder.getFalse()); if (Difference.isEmptySet()) - return replaceInstUsesWith(Cmp, Builder->getTrue()); + return replaceInstUsesWith(Cmp, Builder.getTrue()); // If this is a normal comparison, it demands all bits. If it is a sign // bit comparison, it only demands the sign bit. @@ -1447,12 +1448,13 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { // of a test and branch. So we avoid canonicalizing in such situations // because test and branch instruction has better branch displacement // than compare and branch instruction. 
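The smin fold above (isKnownPositive now receives the assumption-cache and dominator-tree context) fires on patterns like this sketch (hand-written, names invented):

define i1 @smin_sgt_zero(i32 %x, i32 %b) {
  %lo  = lshr i32 %x, 1          ; sign bit known clear
  %pos = or i32 %lo, 1           ; ...and known non-zero => known positive
  %cmp = icmp slt i32 %pos, %b
  %min = select i1 %cmp, i32 %pos, i32 %b
  %res = icmp sgt i32 %min, 0
  ret i1 %res
}

; smin(%pos, %b) with %pos known positive is positive iff %b is, so:
;   %res = icmp sgt i32 %b, 0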
- if (!isBranchOnSignBitCheck(Cmp, IsSignBit) && !Cmp.isEquality()) { - if (auto *AI = Intersection.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI)); - if (auto *AD = Difference.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD)); - } + if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp))) + return nullptr; + + if (auto *AI = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*AI)); + if (auto *AD = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, X, Builder.getInt(*AD)); } return nullptr; @@ -1464,7 +1466,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, const APInt *C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); - if (*C == 1 && C->getBitWidth() > 1) { + if (C->isOneValue() && C->getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) @@ -1477,14 +1479,13 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, // of the high bits truncated out of x are known. unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), SrcBits = X->getType()->getScalarSizeInBits(); - APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(X, KnownZero, KnownOne, 0, &Cmp); + KnownBits Known = computeKnownBits(X, 0, &Cmp); // If all the high bits are known, we can do this xform. - if ((KnownZero | KnownOne).countLeadingOnes() >= SrcBits - DstBits) { + if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) { // Pull in the high bits from known-ones set. APInt NewRHS = C->zext(SrcBits); - NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); + NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); } } @@ -1505,7 +1506,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, // If this is a comparison that tests the signbit (X < 0) or (x > -1), // fold the xor. ICmpInst::Predicate Pred = Cmp.getPredicate(); - if ((Pred == ICmpInst::ICMP_SLT && *C == 0) || + if ((Pred == ICmpInst::ICMP_SLT && C->isNullValue()) || (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) { // If the sign bit of the XorCst is not set, there is no change to @@ -1530,14 +1531,14 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, } if (Xor->hasOneUse()) { - // (icmp u/s (xor X SignBit), C) -> (icmp s/u X, (xor C SignBit)) - if (!Cmp.isEquality() && XorC->isSignBit()) { + // (icmp u/s (xor X SignMask), C) -> (icmp s/u X, (xor C SignMask)) + if (!Cmp.isEquality() && XorC->isSignMask()) { Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); } - // (icmp u/s (xor X ~SignBit), C) -> (icmp s/u X, (xor C ~SignBit)) + // (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask)) if (!Cmp.isEquality() && XorC->isMaxSignedValue()) { Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); @@ -1623,15 +1624,15 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is // preferable because it allows the C2 << Y expression to be hoisted out of a // loop if Y is invariant and X is not. 
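The rewrite described in the comment above (its code continues just below) looks like this in IR, sketched:

define i1 @hoistable_mask(i32 %x, i32 %y) {
  %sh = lshr i32 %x, %y
  %a  = and i32 %sh, 8
  %c  = icmp eq i32 %a, 0
  ret i1 %c
}

; becomes a test with the mask shifted instead of the value:
;   %m  = shl i32 8, %y
;   %a2 = and i32 %x, %m
;   %c  = icmp eq i32 %a2, 0
; %m is loop-invariant whenever %y is, so it can be hoisted.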
- if (Shift->hasOneUse() && *C1 == 0 && Cmp.isEquality() && + if (Shift->hasOneUse() && C1->isNullValue() && Cmp.isEquality() && !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { // Compute C2 << Y. Value *NewShift = - IsShl ? Builder->CreateLShr(And->getOperand(1), Shift->getOperand(1)) - : Builder->CreateShl(And->getOperand(1), Shift->getOperand(1)); + IsShl ? Builder.CreateLShr(And->getOperand(1), Shift->getOperand(1)) + : Builder.CreateShl(And->getOperand(1), Shift->getOperand(1)); // Compute X & (C2 << Y). - Value *NewAnd = Builder->CreateAnd(Shift->getOperand(0), NewShift); + Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift); Cmp.setOperand(0, NewAnd); return &Cmp; } @@ -1663,13 +1664,13 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { // TODO: Is this a good transform for vectors? Wider types may reduce // throughput. Should this transform be limited (even for scalars) by using - // ShouldChangeType()? + // shouldChangeType()? if (!Cmp.getType()->isVectorTy()) { Type *WideType = W->getType(); unsigned WideScalarBits = WideType->getScalarSizeInBits(); Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); - Value *NewAnd = Builder->CreateAnd(W, ZextC2, And->getName()); + Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName()); return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); } } @@ -1681,7 +1682,8 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed - if (!Cmp.isSigned() && *C1 == 0 && match(And->getOperand(1), m_One())) { + if (!Cmp.isSigned() && C1->isNullValue() && + match(And->getOperand(1), m_One())) { Constant *One = cast<Constant>(And->getOperand(1)); Value *Or = And->getOperand(0); Value *A, *B, *LShr; @@ -1702,12 +1704,12 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); } else { if (UsesRemoved >= 3) - NewOr = Builder->CreateOr(Builder->CreateShl(One, B, LShr->getName(), - /*HasNUW=*/true), - One, Or->getName()); + NewOr = Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); } if (NewOr) { - Value *NewAnd = Builder->CreateAnd(A, NewOr, And->getName()); + Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName()); Cmp.setOperand(0, NewAnd); return &Cmp; } @@ -1764,13 +1766,13 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, // (X & C2) != 0 -> (trunc X) < 0 // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. const APInt *C2; - if (And->hasOneUse() && *C == 0 && match(Y, m_APInt(C2))) { + if (And->hasOneUse() && C->isNullValue() && match(Y, m_APInt(C2))) { int32_t ExactLogBase2 = C2->exactLogBase2(); if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); if (And->getType()->isVectorTy()) NTy = VectorType::get(NTy, And->getType()->getVectorNumElements()); - Value *Trunc = Builder->CreateTrunc(X, NTy); + Value *Trunc = Builder.CreateTrunc(X, NTy); auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? 
CmpInst::ICMP_SGE : CmpInst::ICMP_SLT; return new ICmpInst(NewPred, Trunc, Constant::getNullValue(NTy)); @@ -1784,7 +1786,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, const APInt *C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (*C == 1) { + if (C->isOneValue()) { // icmp slt signum(V) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) @@ -1792,7 +1794,16 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, ConstantInt::get(V->getType(), 1)); } - if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse()) + // X | C == C --> X <=u C + // X | C != C --> X >u C + // iff C+1 is a power of 2 (C is a bitmask of the low bits) + if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) && + (*C + 1).isPowerOf2()) { + Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; + return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1)); + } + + if (!Cmp.isEquality() || !C->isNullValue() || !Or->hasOneUse()) return nullptr; Value *P, *Q; @@ -1800,12 +1811,24 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 // -> and (icmp eq P, null), (icmp eq Q, null). Value *CmpP = - Builder->CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType())); + Builder.CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType())); Value *CmpQ = - Builder->CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType())); - auto LogicOpc = Pred == ICmpInst::Predicate::ICMP_EQ ? Instruction::And - : Instruction::Or; - return BinaryOperator::Create(LogicOpc, CmpP, CmpQ); + Builder.CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType())); + auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; + return BinaryOperator::Create(BOpc, CmpP, CmpQ); + } + + // Are we using xors to bitwise check for a pair of (in)equalities? Convert to + // a shorter form that has more potential to be folded even further. + Value *X1, *X2, *X3, *X4; + if (match(Or->getOperand(0), m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) && + match(Or->getOperand(1), m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) { + // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) + // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) + Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2); + Value *Cmp34 = Builder.CreateICmp(Pred, X3, X4); + auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; + return BinaryOperator::Create(BOpc, Cmp12, Cmp34); } return nullptr; @@ -1914,61 +1937,89 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Shl->getOperand(0); - if (Cmp.isEquality()) { - // If the shift is NUW, then it is just shifting out zeros, no need for an - // AND. - Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt)); - if (Shl->hasNoUnsignedWrap()) - return new ICmpInst(Pred, X, LShrC); - - // If the shift is NSW and we compare to 0, then it is just shifting out - // sign bits, no need for an AND either. - if (Shl->hasNoSignedWrap() && *C == 0) - return new ICmpInst(Pred, X, LShrC); - - if (Shl->hasOneUse()) { - // Otherwise, strength reduce the shift into an and. 
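Two sketches for the new foldICmpOrConstant rewrites a few hunks above, the low-bit-mask test and the xor-pair test (hand-written, not from the patch):

define i1 @masked_ule(i32 %x) {
  %o = or i32 %x, 7
  %c = icmp eq i32 %o, 7          ; 7 + 1 is a power of 2
  ret i1 %c                       ; --> icmp ule i32 %x, 7
}

define i1 @both_equal(i32 %a, i32 %b, i32 %c, i32 %d) {
  %x1 = xor i32 %a, %b
  %x2 = xor i32 %c, %d
  %or = or i32 %x1, %x2
  %r  = icmp eq i32 %or, 0
  ret i1 %r                       ; --> and (icmp eq %a, %b), (icmp eq %c, %d)
}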
- Constant *Mask = ConstantInt::get(Shl->getType(), - APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); - - Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); - return new ICmpInst(Pred, And, LShrC); + Type *ShType = Shl->getType(); + + // NSW guarantees that we are only shifting out sign bits from the high bits, + // so we can ASHR the compare constant without needing a mask and eliminate + // the shift. + if (Shl->hasNoSignedWrap()) { + if (Pred == ICmpInst::ICMP_SGT) { + // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) + APInt ShiftedC = C->ashr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { + // This is the same code as the SGT case, but assert the pre-condition + // that is needed for this to work with equality predicates. + assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C && + "Compare known true or false was not folded"); + APInt ShiftedC = C->ashr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_SLT) { + // SLE is the same as above, but SLE is canonicalized to SLT, so convert: + // (X << S) <=s C is equiv to X <=s (C >> S) for all C + // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX + // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN + assert(!C->isMinSignedValue() && "Unexpected icmp slt"); + APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1; + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + // If this is a signed comparison to 0 and the shift is sign preserving, + // use the shift LHS operand instead; isSignTest may change 'Pred', so only + // do that if we're sure to not continue on in this function. + if (isSignTest(Pred, *C)) + return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); + } + + // NUW guarantees that we are only shifting out zero bits from the high bits, + // so we can LSHR the compare constant without needing a mask and eliminate + // the shift. + if (Shl->hasNoUnsignedWrap()) { + if (Pred == ICmpInst::ICMP_UGT) { + // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) + APInt ShiftedC = C->lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { + // This is the same code as the UGT case, but assert the pre-condition + // that is needed for this to work with equality predicates. + assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C && + "Compare known true or false was not folded"); + APInt ShiftedC = C->lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_ULT) { + // ULE is the same as above, but ULE is canonicalized to ULT, so convert: + // (X << S) <=u C is equiv to X <=u (C >> S) for all C + // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u + // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 + assert(C->ugt(0) && "ult 0 should have been eliminated"); + APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1; + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } } - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead; isSignTest may change 'Pred', so only - // do that if we're sure to not continue on in this function. 
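The effect of the new nsw/nuw shl handling, sketched on i32 (constants invented):

define i1 @shl_nsw_sgt(i32 %x) {
  %s = shl nsw i32 %x, 3
  %c = icmp sgt i32 %s, 36        ; 8*x > 36
  ret i1 %c                       ; --> icmp sgt i32 %x, 4   (36 ashr 3)
}

define i1 @shl_nuw_ult(i32 %x) {
  %s = shl nuw i32 %x, 2
  %c = icmp ult i32 %s, 30        ; 4*x < 30
  ret i1 %c                       ; --> icmp ult i32 %x, 8   ((30-1) lshr 2, plus 1)
}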
- if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C)) - return new ICmpInst(Pred, X, Constant::getNullValue(X->getType())); + if (Cmp.isEquality() && Shl->hasOneUse()) { + // Strength-reduce the shift into an 'and'. + Constant *Mask = ConstantInt::get( + ShType, + APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); + Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); + Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt)); + return new ICmpInst(Pred, And, LShrC); + } // Otherwise, if this is a comparison of the sign bit, simplify to and/test. bool TrueIfSigned = false; if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { // (X << 31) <s 0 --> (X & 1) != 0 Constant *Mask = ConstantInt::get( - X->getType(), + ShType, APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1)); - Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, - And, Constant::getNullValue(And->getType())); - } - - // When the shift is nuw and pred is >u or <=u, comparison only really happens - // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the - // <=u case can be further converted to match <u (see below). - if (Shl->hasNoUnsignedWrap() && - (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) { - // Derivation for the ult case: - // (X << S) <=u C is equiv to X <=u (C >> S) for all C - // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u - // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 - assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) && - "Encountered `ult 0` that should have been eliminated by " - "InstSimplify."); - APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1 - : C->lshr(*ShiftAmt); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC)); + And, Constant::getNullValue(ShType)); } // Transform (icmp pred iM (shl iM %v, N), C) @@ -1981,11 +2032,11 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); - if (X->getType()->isVectorTy()) - TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements()); + if (ShType->isVectorTy()) + TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements()); Constant *NewC = ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); - return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC); + return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); } return nullptr; @@ -1999,7 +2050,8 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 Value *X = Shr->getOperand(0); CmpInst::Predicate Pred = Cmp.getPredicate(); - if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && *C == 0) + if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && + C->isNullValue()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); const APInt *ShiftVal; @@ -2036,8 +2088,8 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, Constant *DivCst = ConstantInt::get( Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); - Value *Tmp = IsAShr ? Builder->CreateSDiv(X, DivCst, "", Shr->isExact()) - : Builder->CreateUDiv(X, DivCst, "", Shr->isExact()); + Value *Tmp = IsAShr ? 
Builder.CreateSDiv(X, DivCst, "", Shr->isExact()) + : Builder.CreateUDiv(X, DivCst, "", Shr->isExact()); Cmp.setOperand(0, Tmp); @@ -2075,7 +2127,7 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, // Otherwise strength reduce the shift into an 'and'. APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); Constant *Mask = ConstantInt::get(Shr->getType(), Val); - Value *And = Builder->CreateAnd(X, Mask, Shr->getName() + ".mask"); + Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask"); return new ICmpInst(Pred, And, ShiftedCmpRHS); } @@ -2090,7 +2142,7 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, if (!match(UDiv->getOperand(0), m_APInt(C2))) return nullptr; - assert(C2 != 0 && "udiv 0, X should have been simplified already."); + assert(*C2 != 0 && "udiv 0, X should have been simplified already."); // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) Value *Y = UDiv->getOperand(1); @@ -2103,7 +2155,7 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { - assert(C != 0 && "icmp ult X, 0 should have been simplified already."); + assert(*C != 0 && "icmp ult X, 0 should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_UGT, Y, ConstantInt::get(Y->getType(), C2->udiv(*C))); } @@ -2141,7 +2193,8 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, // INT_MIN will also fail if the divisor is 1. Although folds of all these // division-by-constant cases should be present, we can not assert that they // have happened before we reach this icmp instruction. - if (*C2 == 0 || *C2 == 1 || (DivIsSigned && C2->isAllOnesValue())) + if (C2->isNullValue() || C2->isOneValue() || + (DivIsSigned && C2->isAllOnesValue())) return nullptr; // TODO: We could do all of the computations below using APInt. @@ -2187,7 +2240,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } } else if (C2->isStrictlyPositive()) { // Divisor is > 0. - if (*C == 0) { // (X / pos) op 0 + if (C->isNullValue()) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); HiBound = RangeSize; @@ -2208,7 +2261,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, } else if (C2->isNegative()) { // Divisor is < 0. if (Div->isExact()) RangeSize = ConstantExpr::getNeg(RangeSize); - if (*C == 0) { // (X / neg) op 0 + if (C->isNullValue()) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) LoBound = AddOne(RangeSize); HiBound = ConstantExpr::getNeg(RangeSize); @@ -2238,7 +2291,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, default: llvm_unreachable("Unhandled icmp opcode!"); case ICmpInst::ICMP_EQ: if (LoOverflow && HiOverflow) - return replaceInstUsesWith(Cmp, Builder->getFalse()); + return replaceInstUsesWith(Cmp, Builder.getFalse()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, X, LoBound); @@ -2250,7 +2303,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, HiBound->getUniqueInteger(), DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) - return replaceInstUsesWith(Cmp, Builder->getTrue()); + return replaceInstUsesWith(Cmp, Builder.getTrue()); if (HiOverflow) return new ICmpInst(DivIsSigned ? 
ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, LoBound); @@ -2264,16 +2317,16 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SLT: if (LoOverflow == +1) // Low bound is greater than input range. - return replaceInstUsesWith(Cmp, Builder->getTrue()); + return replaceInstUsesWith(Cmp, Builder.getTrue()); if (LoOverflow == -1) // Low bound is less than input range. - return replaceInstUsesWith(Cmp, Builder->getFalse()); + return replaceInstUsesWith(Cmp, Builder.getFalse()); return new ICmpInst(Pred, X, LoBound); case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. - return replaceInstUsesWith(Cmp, Builder->getFalse()); + return replaceInstUsesWith(Cmp, Builder.getFalse()); if (HiOverflow == -1) // High bound less than input range. - return replaceInstUsesWith(Cmp, Builder->getTrue()); + return replaceInstUsesWith(Cmp, Builder.getTrue()); if (Pred == ICmpInst::ICMP_UGT) return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); @@ -2300,15 +2353,15 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) - if (Pred == ICmpInst::ICMP_SGT && *C == 0) + if (Pred == ICmpInst::ICMP_SGT && C->isNullValue()) return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) - if (Pred == ICmpInst::ICMP_SLT && *C == 0) + if (Pred == ICmpInst::ICMP_SLT && C->isNullValue()) return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) - if (Pred == ICmpInst::ICMP_SLT && *C == 1) + if (Pred == ICmpInst::ICMP_SLT && C->isOneValue()) return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } @@ -2320,12 +2373,12 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == (*C - 1)) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateOr(Y, *C - 1), X); + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, *C - 1), X); // C2 - Y >u C -> (Y | C) != C2 // iff C2 & C == C and C + 1 is a power of 2 if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C) - return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateOr(Y, *C), X); + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, *C), X); return nullptr; } @@ -2342,14 +2395,30 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, // Fold icmp pred (add X, C2), C. Value *X = Add->getOperand(0); Type *Ty = Add->getType(); - auto CR = - ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2); + CmpInst::Predicate Pred = Cmp.getPredicate(); + + // If the add does not wrap, we can always adjust the compare by subtracting + // the constants. Equality comparisons are handled elsewhere. SGE/SLE are + // canonicalized to SGT/SLT. + if (Add->hasNoSignedWrap() && + (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) { + bool Overflow; + APInt NewC = C->ssub_ov(*C2, Overflow); + // If there is overflow, the result must be true or false. + // TODO: Can we assert there is no overflow because InstSimplify always + // handles those cases? 
+ if (!Overflow) + // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2) + return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); + } + + auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2); const APInt &Upper = CR.getUpper(); const APInt &Lower = CR.getLower(); if (Cmp.isSigned()) { - if (Lower.isSignBit()) + if (Lower.isSignMask()) return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper)); - if (Upper.isSignBit()) + if (Upper.isSignMask()) return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower)); } else { if (Lower.isMinValue()) @@ -2364,22 +2433,91 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 - if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && - (*C2 & (*C - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)), + if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -(*C)), ConstantExpr::getNeg(cast<Constant>(Y))); // X+C >u C2 -> (X & ~C2) != C // iff C & C2 == 0 // C2+1 is a power of 2 - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && - (*C2 & *C) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)), + if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~(*C)), ConstantExpr::getNeg(cast<Constant>(Y))); return nullptr; } +bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, + Value *&RHS, ConstantInt *&Less, + ConstantInt *&Equal, + ConstantInt *&Greater) { + // TODO: Generalize this to work with other comparison idioms or ensure + // they get canonicalized into this form. + + // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32 + // Greater), where Equal, Less and Greater are placeholders for any three + // constants. + ICmpInst::Predicate PredA, PredB; + if (match(SI->getTrueValue(), m_ConstantInt(Equal)) && + match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) && + PredA == ICmpInst::ICMP_EQ && + match(SI->getFalseValue(), + m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)), + m_ConstantInt(Less), m_ConstantInt(Greater))) && + PredB == ICmpInst::ICMP_SLT) { + return true; + } + return false; +} + +Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, + Instruction *Select, + ConstantInt *C) { + + assert(C && "Cmp RHS should be a constant int!"); + // If we're testing a constant value against the result of a three way + // comparison, the result can be expressed directly in terms of the + // original values being compared. Note: We could possibly be more + // aggressive here and remove the hasOneUse test. The original select is + // really likely to simplify or sink when we remove a test of the result. 
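A sketch of the no-wrap add adjustment introduced just above:

define i1 @add_nsw_slt(i32 %x) {
  %a = add nsw i32 %x, 5
  %c = icmp slt i32 %a, 30
  ret i1 %c                       ; --> icmp slt i32 %x, 25   (30 - 5 does not overflow)
}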
+ Value *OrigLHS, *OrigRHS; + ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; + if (Cmp.hasOneUse() && + matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS, + C1LessThan, C2Equal, C3GreaterThan)) { + assert(C1LessThan && C2Equal && C3GreaterThan); + + bool TrueWhenLessThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C) + ->isAllOnesValue(); + bool TrueWhenEqual = + ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C) + ->isAllOnesValue(); + bool TrueWhenGreaterThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C) + ->isAllOnesValue(); + + // This generates the new instruction that will replace the original Cmp + // Instruction. Instead of enumerating the various combinations when + // TrueWhenLessThan, TrueWhenEqual and TrueWhenGreaterThan are true versus + // false, we rely on chaining of ORs and future passes of InstCombine to + // simplify the OR further (i.e. a s< b || a == b becomes a s<= b). + + // When none of the three constants satisfy the predicate for the RHS (C), + // the entire original Cmp can be simplified to a false. + Value *Cond = Builder.getFalse(); + if (TrueWhenLessThan) + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS)); + if (TrueWhenEqual) + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS)); + if (TrueWhenGreaterThan) + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS)); + + return replaceInstUsesWith(Cmp, Cond); + } + return nullptr; +} + /// Try to fold integer comparisons with a constant operand: icmp Pred X, C /// where X is some kind of instruction. Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { @@ -2439,11 +2577,28 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { return I; } + // Match against CmpInst LHS being instructions other than binary operators. Instruction *LHSI; - if (match(Cmp.getOperand(0), m_Instruction(LHSI)) && - LHSI->getOpcode() == Instruction::Trunc) - if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) - return I; + if (match(Cmp.getOperand(0), m_Instruction(LHSI))) { + switch (LHSI->getOpcode()) { + case Instruction::Select: + { + // For now, we only support constant integers while folding the + // ICMP(SELECT) pattern. We can extend this to support vectors of integers + // similar to the cases handled by binary ops above. + if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS)) + return I; + break; + } + case Instruction::Trunc: + if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + return I; + break; + default: + break; + } + } if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) return I; @@ -2469,10 +2624,10 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, switch (BO->getOpcode()) { case Instruction::SRem: // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one.
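The three-way-compare fold matches select chains of this shape; a sketch with -1/0/1 chosen as the three constants:

define i1 @sge_via_threeway(i32 %a, i32 %b) {
  %eq   = icmp eq i32 %a, %b
  %lt   = icmp slt i32 %a, %b
  %ltgt = select i1 %lt, i32 -1, i32 1
  %cmp3 = select i1 %eq, i32 0, i32 %ltgt
  %res  = icmp sgt i32 %cmp3, -1
  ret i1 %res
}

; only the Equal (0) and Greater (1) legs satisfy 'sgt -1', so %res becomes
;   or (icmp eq %a, %b), (icmp sgt %a, %b)
; which later combines can tighten to icmp sge %a, %b.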
- if (*C == 0 && BO->hasOneUse()) { + if (C->isNullValue() && BO->hasOneUse()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { - Value *NewRem = Builder->CreateURem(BOp0, BOp1, BO->getName()); + Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName()); return new ICmpInst(Pred, NewRem, Constant::getNullValue(BO->getType())); } @@ -2486,7 +2641,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); return new ICmpInst(Pred, BOp0, SubC); } - } else if (*C == 0) { + } else if (C->isNullValue()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. if (Value *NegVal = dyn_castNegVal(BOp1)) @@ -2494,7 +2649,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, if (Value *NegVal = dyn_castNegVal(BOp0)) return new ICmpInst(Pred, NegVal, BOp1); if (BO->hasOneUse()) { - Value *Neg = Builder->CreateNeg(BOp1); + Value *Neg = Builder.CreateNeg(BOp1); Neg->takeName(BO); return new ICmpInst(Pred, BOp0, Neg); } @@ -2507,7 +2662,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // For the xor case, we can xor two constants together, eliminating // the explicit xor. return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); - } else if (*C == 0) { + } else if (C->isNullValue()) { // Replace ((xor A, B) != 0) with (A != B) return new ICmpInst(Pred, BOp0, BOp1); } @@ -2520,7 +2675,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // Replace ((sub BOC, B) != C) with (B != BOC-C). Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); return new ICmpInst(Pred, BOp1, SubC); - } else if (*C == 0) { + } else if (C->isNullValue()) { // Replace ((sub A, B) != 0) with (A != B). return new ICmpInst(Pred, BOp0, BOp1); } @@ -2533,7 +2688,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // Replace (X | C) == -1 with (X & ~C) == ~C. // This removes the -1 constant. Constant *NotBOC = ConstantExpr::getNot(cast<Constant>(BOp1)); - Value *And = Builder->CreateAnd(BOp0, NotBOC); + Value *And = Builder.CreateAnd(BOp0, NotBOC); return new ICmpInst(Pred, And, NotBOC); } break; @@ -2551,14 +2706,14 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, break; // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 - if (BOC->isSignBit()) { + if (BOC->isSignMask()) { Constant *Zero = Constant::getNullValue(BOp0->getType()); auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; return new ICmpInst(NewPred, BOp0, Zero); } // ((X & ~7) == 0) --> X < 8 - if (*C == 0 && (~(*BOC) + 1).isPowerOf2()) { + if (C->isNullValue() && (~(*BOC) + 1).isPowerOf2()) { Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; return new ICmpInst(NewPred, BOp0, NegBOC); @@ -2567,9 +2722,9 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, break; } case Instruction::Mul: - if (*C == 0 && BO->hasNoSignedWrap()) { + if (C->isNullValue() && BO->hasNoSignedWrap()) { const APInt *BOC; - if (match(BOp1, m_APInt(BOC)) && *BOC != 0) { + if (match(BOp1, m_APInt(BOC)) && !BOC->isNullValue()) { // The trivial case (mul X, 0) is handled by InstSimplify. 
// General case : (mul X, C) != 0 iff X != 0 // (mul X, C) == 0 iff X == 0 @@ -2578,7 +2733,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } break; case Instruction::UDiv: - if (*C == 0) { + if (C->isNullValue()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -2597,32 +2752,35 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, if (!II || !Cmp.isEquality()) return nullptr; - // Handle icmp {eq|ne} <intrinsic>, intcst. + // Handle icmp {eq|ne} <intrinsic>, Constant. + Type *Ty = II->getType(); switch (II->getIntrinsicID()) { case Intrinsic::bswap: Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); - Cmp.setOperand(1, Builder->getInt(C->byteSwap())); + Cmp.setOperand(1, ConstantInt::get(Ty, C->byteSwap())); return &Cmp; + case Intrinsic::ctlz: case Intrinsic::cttz: // ctz(A) == bitwidth(A) -> A == 0 and likewise for != if (*C == C->getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); - Cmp.setOperand(1, ConstantInt::getNullValue(II->getType())); + Cmp.setOperand(1, ConstantInt::getNullValue(Ty)); return &Cmp; } break; + case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != - bool IsZero = *C == 0; + bool IsZero = C->isNullValue(); if (IsZero || *C == C->getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); - auto *NewOp = IsZero ? Constant::getNullValue(II->getType()) - : Constant::getAllOnesValue(II->getType()); + auto *NewOp = + IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty); Cmp.setOperand(1, NewOp); return &Cmp; } @@ -2631,6 +2789,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, default: break; } + return nullptr; } @@ -2656,7 +2815,7 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { // block. If in the same block, we're encouraging jump threading. If // not, we are just pessimizing the code by making an i1 phi. if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I)) + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) return NV; break; case Instruction::Select: { @@ -2698,11 +2857,11 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { } if (Transform) { if (!Op1) - Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), RHSC, - I.getName()); + Op1 = Builder.CreateICmp(I.getPredicate(), LHSI->getOperand(1), RHSC, + I.getName()); if (!Op2) - Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2), RHSC, - I.getName()); + Op2 = Builder.CreateICmp(I.getPredicate(), LHSI->getOperand(2), RHSC, + I.getName()); return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); } break; @@ -2733,6 +2892,9 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { } /// Try to fold icmp (binop), X or icmp X, (binop). +/// TODO: A large part of this logic is duplicated in InstSimplify's +/// simplifyICmpWithBinOp(). We should be able to share that and avoid the code +/// duplication. 
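ctlz now shares the cttz handling above; for example (sketch, function name invented):

declare i32 @llvm.ctlz.i32(i32, i1)

define i1 @ctlz_eq_bitwidth(i32 %x) {
  %z = call i32 @llvm.ctlz.i32(i32 %x, i1 false)
  %c = icmp eq i32 %z, 32         ; all 32 bits are leading zeros
  ret i1 %c                       ; --> icmp eq i32 %x, 0
}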
Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2742,7 +2904,7 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { if (!BO0 && !BO1) return nullptr; - CmpInst::Predicate Pred = I.getPredicate(); + const CmpInst::Predicate Pred = I.getPredicate(); bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; if (BO0 && isa<OverflowingBinaryOperator>(BO0)) NoOp0WrapProblem = @@ -2767,12 +2929,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { D = BO1->getOperand(1); } - // icmp (X+cst) < 0 --> X < -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) - if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. if ((A == Op1 || B == Op1) && NoOp0WrapProblem) return new ICmpInst(Pred, A == Op1 ? B : A, @@ -2847,6 +3003,31 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One())) return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); + // TODO: The subtraction-related identities shown below also hold, but + // canonicalization from (X -nuw 1) to (X + -1) means that the combinations + // wouldn't happen even if they were implemented. + // + // icmp ult (X - 1), Y -> icmp ule X, Y + // icmp uge (X - 1), Y -> icmp ugt X, Y + // icmp ugt X, (Y - 1) -> icmp uge X, Y + // icmp ule X, (Y - 1) -> icmp ult X, Y + + // icmp ule (X + 1), Y -> icmp ult X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_ULT, A, Op1); + + // icmp ugt (X + 1), Y -> icmp uge X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_UGE, A, Op1); + + // icmp uge X, (Y + 1) -> icmp ugt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_UGT, Op0, C); + + // icmp ult X, (Y + 1) -> icmp ule X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_ULE, Op0, C); + // if C1 has greater magnitude than C2: // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y // s.t. 
C3 = C1 - C2 @@ -2864,12 +3045,12 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { APInt AP1Abs = C1->getValue().abs(); APInt AP2Abs = C2->getValue().abs(); if (AP1Abs.uge(AP2Abs)) { - ConstantInt *C3 = Builder->getInt(AP1 - AP2); - Value *NewAdd = Builder->CreateNSWAdd(A, C3); + ConstantInt *C3 = Builder.getInt(AP1 - AP2); + Value *NewAdd = Builder.CreateNSWAdd(A, C3); return new ICmpInst(Pred, NewAdd, C); } else { - ConstantInt *C3 = Builder->getInt(AP2 - AP1); - Value *NewAdd = Builder->CreateNSWAdd(C, C3); + ConstantInt *C3 = Builder.getInt(AP2 - AP1); + Value *NewAdd = Builder.CreateNSWAdd(C, C3); return new ICmpInst(Pred, A, NewAdd); } } @@ -2956,56 +3137,69 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { break; case Instruction::Add: case Instruction::Sub: - case Instruction::Xor: + case Instruction::Xor: { if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - if (CI->getValue().isSignBit()) { - ICmpInst::Predicate Pred = + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + + const APInt *C; + if (match(BO0->getOperand(1), m_APInt(C))) { + // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b + if (C->isSignMask()) { + ICmpInst::Predicate NewPred = I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); - return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0)); } - if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) { - ICmpInst::Predicate Pred = + // icmp u/s (a ^ maxsignval), (b ^ maxsignval) --> icmp s/u' a, b + if (BO0->getOpcode() == Instruction::Xor && C->isMaxSignedValue()) { + ICmpInst::Predicate NewPred = I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); - Pred = I.getSwappedPredicate(Pred); - return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + NewPred = I.getSwappedPredicate(NewPred); + return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0)); } } break; - case Instruction::Mul: + } + case Instruction::Mul: { if (!I.isEquality()) break; - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask - // Mask = -1 >> count-trailing-zeros(Cst). - if (!CI->isZero() && !CI->isOne()) { - const APInt &AP = CI->getValue(); - ConstantInt *Mask = ConstantInt::get( - I.getContext(), - APInt::getLowBitsSet(AP.getBitWidth(), - AP.getBitWidth() - AP.countTrailingZeros())); - Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); - Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); - return new ICmpInst(I.getPredicate(), And1, And2); + const APInt *C; + if (match(BO0->getOperand(1), m_APInt(C)) && !C->isNullValue() && + !C->isOneValue()) { + // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask) + // Mask = -1 >> count-trailing-zeros(C). 
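The signmask xor pair handled above is the classic biased-compare trick; sketched:

define i1 @biased_ult(i32 %a, i32 %b) {
  %ax = xor i32 %a, -2147483648   ; flip both sign bits
  %bx = xor i32 %b, -2147483648
  %c  = icmp ult i32 %ax, %bx
  ret i1 %c                       ; --> icmp slt i32 %a, %b
}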
+ if (unsigned TZs = C->countTrailingZeros()) { + Constant *Mask = ConstantInt::get( + BO0->getType(), + APInt::getLowBitsSet(C->getBitWidth(), C->getBitWidth() - TZs)); + Value *And1 = Builder.CreateAnd(BO0->getOperand(0), Mask); + Value *And2 = Builder.CreateAnd(BO1->getOperand(0), Mask); + return new ICmpInst(Pred, And1, And2); } + // If there are no trailing zeros in the multiplier, just eliminate + // the multiplies (no masking is needed): + // icmp eq/ne (X * C), (Y * C) --> icmp eq/ne X, Y + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); } break; + } case Instruction::UDiv: case Instruction::LShr: - if (I.isSigned()) + if (I.isSigned() || !BO0->isExact() || !BO1->isExact()) break; - LLVM_FALLTHROUGH; + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + case Instruction::SDiv: + if (!I.isEquality() || !BO0->isExact() || !BO1->isExact()) + break; + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + case Instruction::AShr: if (!BO0->isExact() || !BO1->isExact()) break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + case Instruction::Shl: { bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); @@ -3013,8 +3207,7 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { break; if (!NSW && I.isSigned()) break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); } } } @@ -3022,10 +3215,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { if (BO0) { // Transform A & (L - 1) `ult` L --> L != 0 auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); - auto BitwiseAnd = - m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value())); + auto BitwiseAnd = m_c_And(m_Value(), LSubOne); - if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) { + if (match(BO0, BitwiseAnd) && Pred == ICmpInst::ICMP_ULT) { auto *Zero = Constant::getNullValue(BO0->getType()); return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); } @@ -3126,12 +3318,12 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { return nullptr; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + const CmpInst::Predicate Pred = I.getPredicate(); Value *A, *B, *C, *D; if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 Value *OtherVal = A == Op1 ? 
B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); + return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType())); } if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { @@ -3139,28 +3331,27 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { ConstantInt *C1, *C2; if (match(B, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2)) && Op1->hasOneUse()) { - Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); - Value *Xor = Builder->CreateXor(C, NC); - return new ICmpInst(I.getPredicate(), A, Xor); + Constant *NC = Builder.getInt(C1->getValue() ^ C2->getValue()); + Value *Xor = Builder.CreateXor(C, NC); + return new ICmpInst(Pred, A, Xor); } // A^B == A^D -> B == D if (A == C) - return new ICmpInst(I.getPredicate(), B, D); + return new ICmpInst(Pred, B, D); if (A == D) - return new ICmpInst(I.getPredicate(), B, C); + return new ICmpInst(Pred, B, C); if (B == C) - return new ICmpInst(I.getPredicate(), A, D); + return new ICmpInst(Pred, A, D); if (B == D) - return new ICmpInst(I.getPredicate(), A, C); + return new ICmpInst(Pred, A, C); } } if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { // A == (A^B) -> B == 0 Value *OtherVal = A == Op0 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); + return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType())); } // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 @@ -3187,8 +3378,8 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { } if (X) { // Build (X^Y) & Z - Op1 = Builder->CreateXor(X, Y); - Op1 = Builder->CreateAnd(Op1, Z); + Op1 = Builder.CreateXor(X, Y); + Op1 = Builder.CreateAnd(Op1, Z); I.setOperand(0, Op1); I.setOperand(1, Constant::getNullValue(Op1->getType())); return &I; @@ -3205,8 +3396,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { APInt Pow2 = Cst1->getValue() + 1; if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) - return new ICmpInst(I.getPredicate(), A, - Builder->CreateTrunc(B, A->getType())); + return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); } // (A >> C) == (B >> C) --> (A^B) u< (1 << C) @@ -3218,12 +3408,11 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { unsigned TypeBits = Cst1->getBitWidth(); unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); if (ShAmt < TypeBits && ShAmt != 0) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE - ? ICmpInst::ICMP_UGE - : ICmpInst::ICMP_ULT; - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + ICmpInst::Predicate NewPred = + Pred == ICmpInst::ICMP_NE ? 
ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted"); APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); - return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); + return new ICmpInst(NewPred, Xor, Builder.getInt(CmpVal)); } } @@ -3233,12 +3422,11 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { unsigned TypeBits = Cst1->getBitWidth(); unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); if (ShAmt < TypeBits && ShAmt != 0) { - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted"); APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); - Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), + Value *And = Builder.CreateAnd(Xor, Builder.getInt(AndVal), I.getName() + ".mask"); - return new ICmpInst(I.getPredicate(), And, - Constant::getNullValue(Cst1->getType())); + return new ICmpInst(Pred, And, Constant::getNullValue(Cst1->getType())); } } @@ -3261,11 +3449,20 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { APInt CmpV = Cst1->getValue().zext(ASize); CmpV <<= ShAmt; - Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); - return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); + Value *Mask = Builder.CreateAnd(A, Builder.getInt(MaskV)); + return new ICmpInst(Pred, Mask, Builder.getInt(CmpV)); } } + // If both operands are byte-swapped or bit-reversed, just compare the + // original values. + // TODO: Move this to a function similar to foldICmpIntrinsicWithConstant() + // and handle more intrinsics. + if ((match(Op0, m_BSwap(m_Value(A))) && match(Op1, m_BSwap(m_Value(B)))) || + (match(Op0, m_BitReverse(m_Value(A))) && + match(Op1, m_BitReverse(m_Value(B))))) + return new ICmpInst(Pred, A, B); + return nullptr; } @@ -3290,7 +3487,7 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { RHSOp = RHSC->getOperand(0); // If the pointer types don't match, insert a bitcast. if (LHSCIOp->getType() != RHSOp->getType()) - RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType()); + RHSOp = Builder.CreateBitCast(RHSOp, LHSCIOp->getType()); } } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); @@ -3374,7 +3571,7 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { // We're performing an unsigned comp with a sign extended value. // This is true if the input is >= 0. [aka >s -1] Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Value *Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName()); + Value *Result = Builder.CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName()); // Finally, return the value computed. if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) @@ -3402,7 +3599,7 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, // may be pointing to the compare. We want to insert the new instructions // before the add in case there are uses of the add between the add and the // compare. 
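The byte-swap/bit-reverse fold earlier in this hunk is sound because both intrinsics are bijections, so equality of the transformed values coincides with equality of the originals. A self-contained sketch (illustrative C++, not LLVM code) checking this for an 8-bit bit-reverse:

#include <cassert>
#include <cstdint>

// Reverse the bits of an 8-bit value.
static uint8_t bitrev8(uint8_t V) {
  uint8_t R = 0;
  for (int I = 0; I < 8; ++I)
    R = uint8_t((R << 1) | ((V >> I) & 1));
  return R;
}

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      // icmp eq (bitreverse A), (bitreverse B) <=> icmp eq A, B
      assert((bitrev8(uint8_t(A)) == bitrev8(uint8_t(B))) == (A == B));
  return 0;
}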
- Builder->SetInsertPoint(&OrigI); + Builder.SetInsertPoint(&OrigI); switch (OCF) { case OCF_INVALID: @@ -3411,11 +3608,11 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, case OCF_UNSIGNED_ADD: { OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI); if (OR == OverflowResult::NeverOverflows) - return SetResult(Builder->CreateNUWAdd(LHS, RHS), Builder->getFalse(), + return SetResult(Builder.CreateNUWAdd(LHS, RHS), Builder.getFalse(), true); if (OR == OverflowResult::AlwaysOverflows) - return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true); + return SetResult(Builder.CreateAdd(LHS, RHS), Builder.getTrue(), true); // Fall through uadd into sadd LLVM_FALLTHROUGH; @@ -3423,13 +3620,13 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, case OCF_SIGNED_ADD: { // X + 0 -> {X, false} if (match(RHS, m_Zero())) - return SetResult(LHS, Builder->getFalse(), false); + return SetResult(LHS, Builder.getFalse(), false); // We can strength reduce this signed add into a regular add if we can prove // that it will never overflow. if (OCF == OCF_SIGNED_ADD) - if (WillNotOverflowSignedAdd(LHS, RHS, OrigI)) - return SetResult(Builder->CreateNSWAdd(LHS, RHS), Builder->getFalse(), + if (willNotOverflowSignedAdd(LHS, RHS, OrigI)) + return SetResult(Builder.CreateNSWAdd(LHS, RHS), Builder.getFalse(), true); break; } @@ -3438,15 +3635,15 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, case OCF_SIGNED_SUB: { // X - 0 -> {X, false} if (match(RHS, m_Zero())) - return SetResult(LHS, Builder->getFalse(), false); + return SetResult(LHS, Builder.getFalse(), false); if (OCF == OCF_SIGNED_SUB) { - if (WillNotOverflowSignedSub(LHS, RHS, OrigI)) - return SetResult(Builder->CreateNSWSub(LHS, RHS), Builder->getFalse(), + if (willNotOverflowSignedSub(LHS, RHS, OrigI)) + return SetResult(Builder.CreateNSWSub(LHS, RHS), Builder.getFalse(), true); } else { - if (WillNotOverflowUnsignedSub(LHS, RHS, OrigI)) - return SetResult(Builder->CreateNUWSub(LHS, RHS), Builder->getFalse(), + if (willNotOverflowUnsignedSub(LHS, RHS, OrigI)) + return SetResult(Builder.CreateNUWSub(LHS, RHS), Builder.getFalse(), true); } break; @@ -3455,28 +3652,28 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, case OCF_UNSIGNED_MUL: { OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI); if (OR == OverflowResult::NeverOverflows) - return SetResult(Builder->CreateNUWMul(LHS, RHS), Builder->getFalse(), + return SetResult(Builder.CreateNUWMul(LHS, RHS), Builder.getFalse(), true); if (OR == OverflowResult::AlwaysOverflows) - return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true); + return SetResult(Builder.CreateMul(LHS, RHS), Builder.getTrue(), true); LLVM_FALLTHROUGH; } case OCF_SIGNED_MUL: // X * undef -> undef if (isa<UndefValue>(RHS)) - return SetResult(RHS, UndefValue::get(Builder->getInt1Ty()), false); + return SetResult(RHS, UndefValue::get(Builder.getInt1Ty()), false); // X * 0 -> {0, false} if (match(RHS, m_Zero())) - return SetResult(RHS, Builder->getFalse(), false); + return SetResult(RHS, Builder.getFalse(), false); // X * 1 -> {X, false} if (match(RHS, m_One())) - return SetResult(LHS, Builder->getFalse(), false); + return SetResult(LHS, Builder.getFalse(), false); if (OCF == OCF_SIGNED_MUL) - if (WillNotOverflowSignedMul(LHS, RHS, OrigI)) - return SetResult(Builder->CreateNSWMul(LHS, RHS), Builder->getFalse(), + if (willNotOverflowSignedMul(LHS, RHS, OrigI)) + 
return SetResult(Builder.CreateNSWMul(LHS, RHS), Builder.getFalse(), true); break; } @@ -3552,6 +3749,11 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, const APInt &CVal = CI->getValue(); if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth) return nullptr; + } else { + // In this case we could have the operand of the binary operation + // being defined in another block, and performing the replacement + // could break the dominance relation. + return nullptr; } } else { // Other uses prohibit this transformation. @@ -3641,25 +3843,25 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, return nullptr; } - InstCombiner::BuilderTy *Builder = IC.Builder; - Builder->SetInsertPoint(MulInstr); + InstCombiner::BuilderTy &Builder = IC.Builder; + Builder.SetInsertPoint(MulInstr); // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B) Value *MulA = A, *MulB = B; if (WidthA < MulWidth) - MulA = Builder->CreateZExt(A, MulType); + MulA = Builder.CreateZExt(A, MulType); if (WidthB < MulWidth) - MulB = Builder->CreateZExt(B, MulType); + MulB = Builder.CreateZExt(B, MulType); Value *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::umul_with_overflow, MulType); - CallInst *Call = Builder->CreateCall(F, {MulA, MulB}, "umul"); + CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul"); IC.Worklist.Add(MulInstr); // If there are uses of mul result other than the comparison, we know that // they are truncation or binary AND. Change them to use result of // mul.with.overflow and adjust properly mask/size. if (MulVal->hasNUsesOrMore(2)) { - Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value"); + Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value"); for (User *U : MulVal->users()) { if (U == &I || U == OtherVal) continue; @@ -3673,9 +3875,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, // Replace (mul & mask) --> zext (mul.with.overflow & short_mask) ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); APInt ShortMask = CI->getValue().trunc(MulWidth); - Value *ShortAnd = Builder->CreateAnd(Mul, ShortMask); + Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask); Instruction *Zext = - cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType())); + cast<Instruction>(Builder.CreateZExt(ShortAnd, BO->getType())); IC.Worklist.Add(Zext); IC.replaceInstUsesWith(*BO, Zext); } else { @@ -3712,7 +3914,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, llvm_unreachable("Unexpected predicate"); } if (Inverse) { - Value *Res = Builder->CreateExtractValue(Call, 1); + Value *Res = Builder.CreateExtractValue(Call, 1); return BinaryOperator::CreateNot(Res); } @@ -3725,7 +3927,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, bool isSignCheck) { if (isSignCheck) - return APInt::getSignBit(BitWidth); + return APInt::getSignMask(BitWidth); ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand(1)); if (!CI) return APInt::getAllOnesValue(BitWidth); @@ -3738,16 +3940,14 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, // greater than the RHS must differ in a bit higher than these due to carry. 
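For the processUMulZExtIdiom rewrite above, the underlying fact is that multiplying zero-extended operands in a double-width type and testing whether the product exceeds the narrow maximum is exactly an unsigned multiply-overflow check, which is what umul.with.overflow reports directly. A standalone sketch (illustrative C++, 8-bit operands widened to 16 bits):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A) {
    for (unsigned B = 0; B < 256; ++B) {
      uint16_t Wide = uint16_t(A * B);       // mul (zext A), (zext B)
      bool IdiomOverflow = (Wide >> 8) != 0; // high half nonzero, i.e. > 0xFF
      bool RealOverflow = A * B > 0xFF;      // does A*B fit back into 8 bits?
      assert(IdiomOverflow == RealOverflow);
    }
  }
  return 0;
}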
case ICmpInst::ICMP_UGT: { unsigned trailingOnes = RHS.countTrailingOnes(); - APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes); - return ~lowBitsSet; + return APInt::getBitsSetFrom(BitWidth, trailingOnes); } // Similarly, for a ULT comparison, we don't care about the trailing zeros. // Any value less than the RHS must differ in a higher bit because of carries. case ICmpInst::ICMP_ULT: { unsigned trailingZeros = RHS.countTrailingZeros(); - APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros); - return ~lowBitsSet; + return APInt::getBitsSetFrom(BitWidth, trailingZeros); } default: @@ -3887,7 +4087,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!"); if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) { BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1); - // The check for the unique predecessor is not the best that can be + // The check for the single predecessor is not the best that can be // done. But it protects efficiently against cases like when SI's // home block has two successors, Succ and Succ1, and Succ1 is a predecessor // of Succ. Then SI can't be replaced by SIOpd because the use that gets @@ -3895,8 +4095,10 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, // guarantees that the path all uses of SI (outside SI's parent) are on // is disjoint from all other paths out of SI. But that information // is more expensive to compute, and the trade-off here is in favor - // of compile-time. + // of compile-time. It should also be noticed that we check for a single + // predecessor and not only uniqueness. This is to handle the situation when + // Succ and Succ1 point to the same basic block. 
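The getDemandedBitsLHSMask reasoning above can be checked directly: for icmp ugt X, RHS the bits of X at the positions covered by RHS's trailing ones are irrelevant, and for icmp ult X, RHS the bits covered by RHS's trailing zeros are irrelevant. A standalone exhaustive 8-bit sketch (illustrative C++):

#include <cassert>
#include <cstdint>

static unsigned countTrailingOnes(uint8_t V) {
  unsigned N = 0;
  while (V & 1) { V = uint8_t(V >> 1); ++N; }
  return N;
}

int main() {
  for (unsigned RHS = 0; RHS < 256; ++RHS) {
    unsigned TO = countTrailingOnes(uint8_t(RHS));
    unsigned TZ = countTrailingOnes(uint8_t(~RHS)); // trailing zeros of RHS
    uint8_t UgtMask = uint8_t(0xFFu << TO);         // getBitsSetFrom(8, TO)
    uint8_t UltMask = uint8_t(0xFFu << TZ);         // getBitsSetFrom(8, TZ)
    for (unsigned X = 0; X < 256; ++X) {
      assert((X > RHS) == ((X & UgtMask) > RHS)); // low TO bits don't matter
      assert((X < RHS) == ((X & UltMask) < RHS)); // low TZ bits don't matter
    }
  }
  return 0;
}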
+ if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) { NumSel++; SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent()); return true; @@ -3929,16 +4131,16 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit); } - APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); - APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); + KnownBits Op0Known(BitWidth); + KnownBits Op1Known(BitWidth); - if (SimplifyDemandedBits(I.getOperandUse(0), + if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth, IsSignBit), - Op0KnownZero, Op0KnownOne, 0)) + Op0Known, 0)) return &I; - if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth), - Op1KnownZero, Op1KnownOne, 0)) + if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth), + Op1Known, 0)) return &I; // Given the known and unknown bits, compute a range that the LHS could be @@ -3947,15 +4149,11 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); if (I.isSigned()) { - computeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, - Op0Max); - computeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, - Op1Max); + computeSignedMinMaxValuesFromKnownBits(Op0Known, Op0Min, Op0Max); + computeSignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max); } else { - computeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, - Op0Max); - computeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, - Op1Max); + computeUnsignedMinMaxValuesFromKnownBits(Op0Known, Op0Min, Op0Max); + computeUnsignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max); } // If Min and Max are known to be the same, then SimplifyDemandedBits @@ -3982,8 +4180,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { // If all bits are known zero except for one, then we know at most one bit // is set. If the comparison is against zero, then this is a check to see if // *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { + APInt Op0KnownZeroInverted = ~Op0Known.Zero; + if (Op1Known.isZero()) { // If the LHS is an AND with the same constant, look through it. Value *LHS = nullptr; const APInt *LHSC; @@ -4013,7 +4211,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { // Check if the LHS is 8 >>u x and the result is a power of 2 like 1. 
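The computeUnsignedMinMaxValuesFromKnownBits calls above reduce to a simple fact about KnownBits: with Zero the mask of bits known clear and One the mask of bits known set, every value V consistent with that knowledge satisfies One <= V <= ~Zero, and both bounds are attainable. A standalone 8-bit sketch (illustrative C++, not the LLVM KnownBits type):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned Zero = 0; Zero < 256; ++Zero) {
    for (unsigned One = 0; One < 256; ++One) {
      if (Zero & One)
        continue; // contradictory known bits
      uint8_t Min = uint8_t(One);   // every unknown bit cleared
      uint8_t Max = uint8_t(~Zero); // every unknown bit set
      for (unsigned V = 0; V < 256; ++V) {
        if ((V & Zero) || ((V & One) != One))
          continue; // V is inconsistent with the known bits
        assert(Min <= V && V <= Max);
      }
    }
  }
  return 0;
}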
const APInt *CI; - if (Op0KnownZeroInverted == 1 && + if (Op0KnownZeroInverted.isOneValue() && match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) { // ((8 >>u X) & 1) == 0 -> X != 3 // ((8 >>u X) & 1) != 0 -> X == 3 @@ -4071,7 +4269,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue() - 1)); + Builder.getInt(CI->getValue() - 1)); } break; case ICmpInst::ICMP_SGT: @@ -4085,7 +4283,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue() + 1)); + Builder.getInt(CI->getValue() + 1)); } break; case ICmpInst::ICMP_SGE: @@ -4121,8 +4319,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { // Turn a signed comparison into an unsigned one if both operands are known to // have the same sign. if (I.isSigned() && - ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || - (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) + ((Op0Known.Zero.isNegative() && Op1Known.Zero.isNegative()) || + (Op0Known.One.isNegative() && Op1Known.One.isNegative()))) return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); return nullptr; @@ -4186,6 +4384,80 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); } +/// Integer compare with boolean values can always be turned into bitwise ops. +static Instruction *canonicalizeICmpBool(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + Value *A = I.getOperand(0), *B = I.getOperand(1); + assert(A->getType()->isIntOrIntVectorTy(1) && "Bools only"); + + // A boolean compared to true/false can be simplified to Op0/true/false in + // 14 out of the 20 (10 predicates * 2 constants) possible combinations. + // Cases not handled by InstSimplify are always 'not' of Op0. 
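The rewrites that canonicalizeICmpBool implements below follow from the i1 truth tables; the twist is that, viewed as a signed value, i1 true is -1, so the signed and unsigned orderings swap roles. A standalone check (illustrative C++, modeling i1 as bool):

#include <cassert>

int main() {
  for (int AI = 0; AI <= 1; ++AI) {
    for (int BI = 0; BI <= 1; ++BI) {
      bool A = AI, B = BI;
      int SA = A ? -1 : 0, SB = B ? -1 : 0; // signed interpretation of i1
      assert((A != B) == (A ^ B));          // icmp ne  -> xor
      assert((A == B) == !(A ^ B));         // icmp eq  -> not(xor)
      assert((A < B) == (!A && B));         // icmp ult -> ~A & B
      assert((A <= B) == (!A || B));        // icmp ule -> ~A | B
      assert((SA < SB) == (A && !B));       // icmp slt -> A & ~B
      assert((SA <= SB) == (A || !B));      // icmp sle -> A | ~B
    }
  }
  return 0;
}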
+ if (match(B, m_Zero())) { + switch (I.getPredicate()) { + case CmpInst::ICMP_EQ: // A == 0 -> !A + case CmpInst::ICMP_ULE: // A <=u 0 -> !A + case CmpInst::ICMP_SGE: // A >=s 0 -> !A + return BinaryOperator::CreateNot(A); + default: + llvm_unreachable("ICmp i1 X, C not simplified as expected."); + } + } else if (match(B, m_One())) { + switch (I.getPredicate()) { + case CmpInst::ICMP_NE: // A != 1 -> !A + case CmpInst::ICMP_ULT: // A <u 1 -> !A + case CmpInst::ICMP_SGT: // A >s -1 -> !A + return BinaryOperator::CreateNot(A); + default: + llvm_unreachable("ICmp i1 X, C not simplified as expected."); + } + } + + switch (I.getPredicate()) { + default: + llvm_unreachable("Invalid icmp instruction!"); + case ICmpInst::ICMP_EQ: + // icmp eq i1 A, B -> ~(A ^ B) + return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + + case ICmpInst::ICMP_NE: + // icmp ne i1 A, B -> A ^ B + return BinaryOperator::CreateXor(A, B); + + case ICmpInst::ICMP_UGT: + // icmp ugt -> icmp ult + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_ULT: + // icmp ult i1 A, B -> ~A & B + return BinaryOperator::CreateAnd(Builder.CreateNot(A), B); + + case ICmpInst::ICMP_SGT: + // icmp sgt -> icmp slt + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SLT: + // icmp slt i1 A, B -> A & ~B + return BinaryOperator::CreateAnd(Builder.CreateNot(B), A); + + case ICmpInst::ICMP_UGE: + // icmp uge -> icmp ule + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_ULE: + // icmp ule i1 A, B -> ~A | B + return BinaryOperator::CreateOr(Builder.CreateNot(A), B); + + case ICmpInst::ICMP_SGE: + // icmp sge -> icmp sle + std::swap(A, B); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SLE: + // icmp sle i1 A, B -> A | ~B + return BinaryOperator::CreateOr(Builder.CreateNot(B), A); + } +} + Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -4202,8 +4474,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = - SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, &TLI, &DT, &AC, &I)) + if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val @@ -4223,49 +4495,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } - Type *Ty = Op0->getType(); - - // icmp's with boolean values can always be turned into bitwise operations - if (Ty->getScalarType()->isIntegerTy(1)) { - switch (I.getPredicate()) { - default: llvm_unreachable("Invalid icmp instruction!"); - case ICmpInst::ICMP_EQ: { // icmp eq i1 A, B -> ~(A^B) - Value *Xor = Builder->CreateXor(Op0, Op1, I.getName() + "tmp"); - return BinaryOperator::CreateNot(Xor); - } - case ICmpInst::ICMP_NE: // icmp ne i1 A, B -> A^B - return BinaryOperator::CreateXor(Op0, Op1); - - case ICmpInst::ICMP_UGT: - std::swap(Op0, Op1); // Change icmp ugt -> icmp ult - LLVM_FALLTHROUGH; - case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B - Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); - return BinaryOperator::CreateAnd(Not, Op1); - } - case ICmpInst::ICMP_SGT: - std::swap(Op0, Op1); // Change icmp sgt -> icmp slt - LLVM_FALLTHROUGH; - case ICmpInst::ICMP_SLT: { // icmp slt i1 A, B -> A & ~B - Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); - return BinaryOperator::CreateAnd(Not, Op0); - } - case ICmpInst::ICMP_UGE: - std::swap(Op0, Op1); // Change icmp uge -> icmp ule - LLVM_FALLTHROUGH; - case 
ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B - Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); - return BinaryOperator::CreateOr(Not, Op1); - } - case ICmpInst::ICMP_SGE: - std::swap(Op0, Op1); // Change icmp sge -> icmp sle - LLVM_FALLTHROUGH; - case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B - Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); - return BinaryOperator::CreateOr(Not, Op0); - } - } - } + if (Op0->getType()->isIntOrIntVectorTy(1)) + if (Instruction *Res = canonicalizeICmpBool(I, Builder)) + return Res; if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) return NewICmp; @@ -4357,7 +4589,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Op1 = ConstantExpr::getBitCast(Op1C, Op0->getType()); } else { // Otherwise, cast the RHS right before the icmp - Op1 = Builder->CreateBitCast(Op1, Op0->getType()); + Op1 = Builder.CreateBitCast(Op1, Op0->getType()); } } return new ICmpInst(I.getPredicate(), Op0, Op1); @@ -4389,18 +4621,20 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // if A is a power of 2. if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Zero()) && - isKnownToBeAPowerOfTwo(A, DL, false, 0, &AC, &I, &DT) && I.isEquality()) - return new ICmpInst(I.getInversePredicate(), - Builder->CreateAnd(A, B), + isKnownToBeAPowerOfTwo(A, false, 0, &I) && I.isEquality()) + return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(A, B), Op1); - // ~x < ~y --> y < x - // ~x < cst --> ~cst < x + // ~X < ~Y --> Y < X + // ~X < C --> X > ~C if (match(Op0, m_Not(m_Value(A)))) { if (match(Op1, m_Not(m_Value(B)))) return new ICmpInst(I.getPredicate(), B, A); - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) - return new ICmpInst(I.getPredicate(), ConstantExpr::getNot(RHSC), A); + + const APInt *C; + if (match(Op1, m_APInt(C))) + return new ICmpInst(I.getSwappedPredicate(), A, + ConstantInt::get(Op1->getType(), ~(*C))); } Instruction *AddI = nullptr; @@ -4488,10 +4722,10 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven); if (RHS.compare(RHSRoundInt) != APFloat::cmpEqual) { if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ) - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getFalse()); assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE); - return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder.getTrue()); } } @@ -4557,9 +4791,9 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Pred = ICmpInst::ICMP_NE; break; case FCmpInst::FCMP_ORD: - return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder.getTrue()); case FCmpInst::FCMP_UNO: - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getFalse()); } // Now we know that the APFloat is a normal number, zero or inf. @@ -4577,8 +4811,8 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, if (SMax.compare(RHS) == APFloat::cmpLessThan) { // smax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return replaceInstUsesWith(I, Builder->getTrue()); - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); } } else { // If the RHS value is > UnsignedMax, fold the comparison. 
This handles @@ -4589,8 +4823,8 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, if (UMax.compare(RHS) == APFloat::cmpLessThan) { // umax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) - return replaceInstUsesWith(I, Builder->getTrue()); - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); } } @@ -4602,8 +4836,8 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // smin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return replaceInstUsesWith(I, Builder->getTrue()); - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); } } else { // See if the RHS value is < UnsignedMin. @@ -4613,8 +4847,8 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // umin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) - return replaceInstUsesWith(I, Builder->getTrue()); - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, Builder.getFalse()); } } @@ -4636,14 +4870,14 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, switch (Pred) { default: llvm_unreachable("Unexpected integer comparison!"); case ICmpInst::ICMP_NE: // (float)int != 4.4 --> true - return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder.getTrue()); case ICmpInst::ICMP_EQ: // (float)int == 4.4 --> false - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getFalse()); case ICmpInst::ICMP_ULE: // (float)int <= 4.4 --> int <= 4 // (float)int <= -4.4 --> false if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getFalse()); break; case ICmpInst::ICMP_SLE: // (float)int <= 4.4 --> int <= 4 @@ -4655,7 +4889,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, // (float)int < -4.4 --> false // (float)int < 4.4 --> int <= 4 if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder.getFalse()); Pred = ICmpInst::ICMP_ULE; break; case ICmpInst::ICMP_SLT: @@ -4668,7 +4902,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, // (float)int > 4.4 --> int > 4 // (float)int > -4.4 --> true if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder.getTrue()); break; case ICmpInst::ICMP_SGT: // (float)int > 4.4 --> int > 4 @@ -4680,7 +4914,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, // (float)int >= -4.4 --> true // (float)int >= 4.4 --> int > 4 if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder.getTrue()); Pred = ICmpInst::ICMP_UGT; break; case ICmpInst::ICMP_SGE: @@ -4711,8 +4945,9 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, - 
I.getFastMathFlags(), DL, &TLI, &DT, &AC, &I)) + if (Value *V = + SimplifyFCmpInst(I.getPredicate(), Op0, Op1, I.getFastMathFlags(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' @@ -4801,7 +5036,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // block. If in the same block, we're encouraging jump threading. If // not, we are just pessimizing the code by making an i1 phi. if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I)) + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) return NV; break; case Instruction::SIToFP: diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 2847ce8..c38a498 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -17,6 +17,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" @@ -27,7 +28,9 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Pass.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/Local.h" #define DEBUG_TYPE "instcombine" @@ -40,27 +43,68 @@ class DbgDeclareInst; class MemIntrinsic; class MemSetInst; -/// \brief Assign a complexity or rank value to LLVM Values. +/// Assign a complexity or rank value to LLVM Values. This is used to reduce +/// the amount of pattern matching needed for compares and commutative +/// instructions. For example, if we have: +/// icmp ugt X, Constant +/// or +/// xor (add X, Constant), cast Z +/// +/// We do not have to consider the commuted variants of these patterns because +/// canonicalization based on complexity guarantees the above ordering. /// /// This routine maps IR values to various complexity ranks: /// 0 -> undef /// 1 -> Constants /// 2 -> Other non-instructions /// 3 -> Arguments -/// 3 -> Unary operations -/// 4 -> Other instructions +/// 4 -> Cast and (f)neg/not instructions +/// 5 -> Other instructions static inline unsigned getComplexity(Value *V) { if (isa<Instruction>(V)) { - if (BinaryOperator::isNeg(V) || BinaryOperator::isFNeg(V) || - BinaryOperator::isNot(V)) - return 3; - return 4; + if (isa<CastInst>(V) || BinaryOperator::isNeg(V) || + BinaryOperator::isFNeg(V) || BinaryOperator::isNot(V)) + return 4; + return 5; } if (isa<Argument>(V)) return 3; return isa<Constant>(V) ? (isa<UndefValue>(V) ? 0 : 1) : 2; } +/// Predicate canonicalization reduces the number of patterns that need to be +/// matched by other transforms. For example, we may swap the operands of a +/// conditional branch or select to create a compare with a canonical (inverted) +/// predicate which is then more likely to be matched with other values. +static inline bool isCanonicalPredicate(CmpInst::Predicate Pred) { + switch (Pred) { + case CmpInst::ICMP_NE: + case CmpInst::ICMP_ULE: + case CmpInst::ICMP_SLE: + case CmpInst::ICMP_UGE: + case CmpInst::ICMP_SGE: + // TODO: There are 16 FCMP predicates. Should others be (not) canonical? 
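The canonicalization that isCanonicalPredicate enables (its list of non-canonical cases continues just below) is semantics-preserving because inverting the predicate and swapping the dependent operands yields the same result, e.g. select(a <= b, X, Y) == select(a > b, Y, X). A small sketch (illustrative C++):

#include <cassert>

int main() {
  for (int A = -2; A <= 2; ++A) {
    for (int B = -2; B <= 2; ++B) {
      int X = 10, Y = 20;
      int NonCanonical = (A <= B) ? X : Y; // icmp sle feeding a select
      int Canonical = (A > B) ? Y : X;     // inverted icmp sgt, arms swapped
      assert(NonCanonical == Canonical);
    }
  }
  return 0;
}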
+ case CmpInst::FCMP_ONE: + case CmpInst::FCMP_OLE: + case CmpInst::FCMP_OGE: + return false; + default: + return true; + } +} + +/// Return the source operand of a potentially bitcasted value while optionally +/// checking if it has one use. If there is no bitcast or the one use check is +/// not met, return the input value itself. +static inline Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) { + if (auto *BitCast = dyn_cast<BitCastInst>(V)) + if (!OneUseOnly || BitCast->hasOneUse()) + return BitCast->getOperand(0); + + // V is not a bitcast or V has more than one use and OneUseOnly is true. + return V; +} + /// \brief Add one to a Constant static inline Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); @@ -85,11 +129,10 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return true; // A vector of constant integers can be inverted easily. - Constant *CV; - if (V->getType()->isVectorTy() && match(V, PatternMatch::m_Constant(CV))) { + if (V->getType()->isVectorTy() && isa<Constant>(V)) { unsigned NumElts = V->getType()->getVectorNumElements(); for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = CV->getAggregateElement(i); + Constant *Elt = cast<Constant>(V)->getAggregateElement(i); if (!Elt) return false; @@ -167,7 +210,7 @@ public: /// \brief An IRBuilder that automatically inserts new instructions into the /// worklist. typedef IRBuilder<TargetFolder, IRBuilderCallbackInserter> BuilderTy; - BuilderTy *Builder; + BuilderTy &Builder; private: // Mode in which we are running the combiner. @@ -182,7 +225,7 @@ private: TargetLibraryInfo &TLI; DominatorTree &DT; const DataLayout &DL; - + const SimplifyQuery SQ; // Optional analyses. When non-null, these can both be used to do better // combining and will be updated to reflect any changes. LoopInfo *LI; @@ -190,13 +233,13 @@ private: bool MadeIRChange; public: - InstCombiner(InstCombineWorklist &Worklist, BuilderTy *Builder, + InstCombiner(InstCombineWorklist &Worklist, BuilderTy &Builder, bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, - AssumptionCache &AC, TargetLibraryInfo &TLI, - DominatorTree &DT, const DataLayout &DL, LoopInfo *LI) + AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, + const DataLayout &DL, LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), - DL(DL), LI(LI), MadeIRChange(false) {} + DL(DL), SQ(DL, &TLI, &DT, &AC), LI(LI), MadeIRChange(false) {} /// \brief Run the combiner over the entire worklist until it is empty. 
/// @@ -241,15 +284,7 @@ public: Instruction *visitSDiv(BinaryOperator &I); Instruction *visitFDiv(BinaryOperator &I); Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); - Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS); - Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Instruction *visitAnd(BinaryOperator &I); - Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI); - Value *FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS); - Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, Value *A, - Value *B, Value *C); - Instruction *FoldXorWithConstants(BinaryOperator &I, Value *Op, Value *A, - Value *B, Value *C); Instruction *visitOr(BinaryOperator &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); @@ -289,6 +324,7 @@ public: Instruction *visitLoadInst(LoadInst &LI); Instruction *visitStoreInst(StoreInst &SI); Instruction *visitBranchInst(BranchInst &BI); + Instruction *visitFenceInst(FenceInst &FI); Instruction *visitSwitchInst(SwitchInst &SI); Instruction *visitReturnInst(ReturnInst &RI); Instruction *visitInsertValueInst(InsertValueInst &IV); @@ -313,9 +349,14 @@ public: bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd); + /// Try to replace instruction \p I with value \p V which are pointers + /// in different address space. + /// \return true if successful. + bool replacePointer(Instruction &I, Value *V); + private: - bool ShouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; - bool ShouldChangeType(Type *From, Type *To) const; + bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; + bool shouldChangeType(Type *From, Type *To) const; Value *dyn_castNegVal(Value *V) const; Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const; Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset, @@ -370,10 +411,27 @@ private: bool DoTransform = true); Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); - bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction &CxtI); - bool WillNotOverflowSignedSub(Value *LHS, Value *RHS, Instruction &CxtI); - bool WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, Instruction &CxtI); - bool WillNotOverflowSignedMul(Value *LHS, Value *RHS, Instruction &CxtI); + bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, + const Instruction &CxtI) const { + return computeOverflowForSignedAdd(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + }; + bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS, + const Instruction &CxtI) const { + return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + }; + bool willNotOverflowSignedSub(const Value *LHS, const Value *RHS, + const Instruction &CxtI) const; + bool willNotOverflowUnsignedSub(const Value *LHS, const Value *RHS, + const Instruction &CxtI) const; + bool willNotOverflowSignedMul(const Value *LHS, const Value *RHS, + const Instruction &CxtI) const; + bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS, + const Instruction &CxtI) const { + return computeOverflowForUnsignedMul(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + }; Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); @@ -394,6 +452,14 @@ private: Instruction::CastOps isEliminableCastPair(const CastInst *CI1, const CastInst *CI2); + Value *foldAndOfICmps(ICmpInst *LHS, 
ICmpInst *RHS, Instruction &CxtI); + Value *foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); + Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); + Value *foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS); + Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS); + + Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, + bool JoinedByAnd, Instruction &CxtI); public: /// \brief Inserts an instruction \p New before instruction \p Old /// @@ -456,8 +522,9 @@ public: /// methods should return the value returned by this function. Instruction *eraseInstFromFunction(Instruction &I) { DEBUG(dbgs() << "IC: ERASE " << I << '\n'); - assert(I.use_empty() && "Cannot erase instruction that is used!"); + salvageDebugInfo(I); + // Make sure that we reprocess all operands now that we reduced their // use counts. if (I.getNumOperands() < 8) { @@ -471,33 +538,47 @@ public: return nullptr; // Don't do anything with FI } - void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - unsigned Depth, Instruction *CxtI) const { - return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, &AC, CxtI, - &DT); + void computeKnownBits(const Value *V, KnownBits &Known, + unsigned Depth, const Instruction *CxtI) const { + llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT); + } + KnownBits computeKnownBits(const Value *V, unsigned Depth, + const Instruction *CxtI) const { + return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT); + } + + bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false, + unsigned Depth = 0, + const Instruction *CxtI = nullptr) { + return llvm::isKnownToBeAPowerOfTwo(V, DL, OrZero, Depth, &AC, CxtI, &DT); } - bool MaskedValueIsZero(Value *V, const APInt &Mask, unsigned Depth = 0, - Instruction *CxtI = nullptr) const { + bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth = 0, + const Instruction *CxtI = nullptr) const { return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT); } - unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0, - Instruction *CxtI = nullptr) const { + unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0, + const Instruction *CxtI = nullptr) const { return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT); } - void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, - unsigned Depth = 0, Instruction *CxtI = nullptr) const { - return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, &AC, CxtI, - &DT); - } - OverflowResult computeOverflowForUnsignedMul(Value *LHS, Value *RHS, - const Instruction *CxtI) { + OverflowResult computeOverflowForUnsignedMul(const Value *LHS, + const Value *RHS, + const Instruction *CxtI) const { return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } - OverflowResult computeOverflowForUnsignedAdd(Value *LHS, Value *RHS, - const Instruction *CxtI) { + OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, + const Value *RHS, + const Instruction *CxtI) const { return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForSignedAdd(const Value *LHS, + const Value *RHS, + const Instruction *CxtI) const { + return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); + } + + /// Maximum size of array considered when transforming. + uint64_t MaxArraySizeForCombine; private: /// \brief Performs a few simplifications for operators which are associative @@ -513,18 +594,39 @@ private: /// value, or null if it didn't simplify. 
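The willNotOverflowSignedAdd wrapper above delegates to computeOverflowForSignedAdd; the kind of fact being proved is the classic sign rule: a two's complement add overflows exactly when both operands share a sign and the wrapped result has the opposite sign. A standalone 8-bit sketch (illustrative C++, not LLVM's implementation) verifying the rule against widened arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  for (int A = -128; A <= 127; ++A) {
    for (int B = -128; B <= 127; ++B) {
      uint8_t R = uint8_t(uint8_t(A) + uint8_t(B)); // wrapped i8 add
      // Sign rule: result sign differs from the sign of both operands.
      bool SignRule = ((uint8_t(A) ^ R) & (uint8_t(B) ^ R) & 0x80) != 0;
      bool Overflows = (A + B) < -128 || (A + B) > 127;
      assert(SignRule == Overflows);
    }
  }
  return 0;
}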
Value *SimplifyUsingDistributiveLaws(BinaryOperator &I); + /// This tries to simplify binary operations by factorizing out common terms + /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). + Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *, + Value *, Value *, Value *); + + /// Match a select chain which produces one of three values based on whether + /// the LHS is less than, equal to, or greater than RHS respectively. + /// Return true if we matched a three way compare idiom. The LHS, RHS, Less, + /// Equal and Greater values are saved in the matching process and returned to + /// the caller. + bool matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, Value *&RHS, + ConstantInt *&Less, ConstantInt *&Equal, + ConstantInt *&Greater); + /// \brief Attempts to replace V with a simpler value based on the demanded /// bits. - Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, - APInt &KnownOne, unsigned Depth, - Instruction *CxtI); - bool SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero, - APInt &KnownOne, unsigned Depth = 0); + Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known, + unsigned Depth, Instruction *CxtI); + bool SimplifyDemandedBits(Instruction *I, unsigned Op, + const APInt &DemandedMask, KnownBits &Known, + unsigned Depth = 0); + /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne + /// bits. It also tries to handle simplifications that can be done based on + /// DemandedMask, but without modifying the Instruction. + Value *SimplifyMultipleUseDemandedBits(Instruction *I, + const APInt &DemandedMask, + KnownBits &Known, + unsigned Depth, Instruction *CxtI); /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. - Value *SimplifyShrShlDemandedBits(Instruction *Lsr, Instruction *Sftl, - const APInt &DemandedMask, APInt &KnownZero, - APInt &KnownOne); + Value *simplifyShrShlDemandedBits( + Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, + const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known); /// \brief Tries to simplify operands to an integer instruction based on its /// demanded bits. @@ -534,13 +636,12 @@ private: APInt &UndefElts, unsigned Depth = 0); Value *SimplifyVectorOp(BinaryOperator &Inst); - Value *SimplifyBSwap(BinaryOperator &Inst); /// Given a binary operator, cast instruction, or select which has a PHI node /// as operand #0, see if we can fold the instruction into the PHI (which is /// only possible if all operands to the PHI are constants). - Instruction *FoldOpIntoPhi(Instruction &I); + Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN); /// Given an instruction with a select as one operand and a constant as the /// other operand, try to fold the binary operator into the select arguments. @@ -549,7 +650,7 @@ private: Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); /// This is a convenience wrapper function for the above two functions. - Instruction *foldOpWithConstantIntoOperand(Instruction &I); + Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I); /// \brief Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. 
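The factorization documented for tryFactorization above, e.g. (A*B)+(A*C) -> A*(B+C), holds unconditionally in wrapping modular arithmetic, so the flag-free form needs no overflow side conditions (any nsw/nuw flags would still have to be dropped). A standalone exhaustive 8-bit sketch (illustrative C++):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      for (unsigned C = 0; C < 256; ++C)
        // Distributivity survives truncation mod 2^8.
        assert(uint8_t(A * B + A * C) == uint8_t(A * (B + C)));
  return 0;
}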
@@ -583,6 +684,8 @@ private: Instruction *foldICmpBinOp(ICmpInst &Cmp); Instruction *foldICmpEquality(ICmpInst &Cmp); + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, Instruction *Select, + ConstantInt *C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc, const APInt *C); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, @@ -628,16 +731,17 @@ private: SelectPatternFlavor SPF2, Value *C); Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); - Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS, + Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS, ConstantInt *AndRHS, BinaryOperator &TheAnd); - Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask, - bool isSub, Instruction &I); Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); + + Instruction * + SimplifyElementUnorderedAtomicMemCpy(ElementUnorderedAtomicMemCpyInst *AMI); Instruction *SimplifyMemTransfer(MemIntrinsic *MI); Instruction *SimplifyMemSet(MemSetInst *MI); diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 49e516e..4510365 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -12,13 +12,15 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" @@ -167,6 +169,18 @@ isOnlyCopiedFromConstantGlobal(AllocaInst *AI, return nullptr; } +/// Returns true if V is dereferenceable for size of alloca. +static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI, + const DataLayout &DL) { + if (AI->isArrayAllocation()) + return false; + uint64_t AllocaSize = DL.getTypeStoreSize(AI->getAllocatedType()); + if (!AllocaSize) + return false; + return isDereferenceableAndAlignedPointer(V, AI->getAlignment(), + APInt(64, AllocaSize), DL); +} + static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { // Check for array size of 1 (scalar allocation). if (!AI.isArrayAllocation()) { @@ -175,7 +189,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { return nullptr; // Canonicalize it. 
- Value *V = IC.Builder->getInt32(1); + Value *V = IC.Builder.getInt32(1); AI.setOperand(0, V); return &AI; } @@ -183,7 +197,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { // Convert: alloca Ty, C - where C is a constant != 1 into: alloca [C x Ty], 1 if (const ConstantInt *C = dyn_cast<ConstantInt>(AI.getArraySize())) { Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); - AllocaInst *New = IC.Builder->CreateAlloca(NewTy, nullptr, AI.getName()); + AllocaInst *New = IC.Builder.CreateAlloca(NewTy, nullptr, AI.getName()); New->setAlignment(AI.getAlignment()); // Scan to the end of the allocation instructions, to skip over a block of @@ -215,7 +229,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { // any casting is exposed early. Type *IntPtrTy = IC.getDataLayout().getIntPtrType(AI.getType()); if (AI.getArraySize()->getType() != IntPtrTy) { - Value *V = IC.Builder->CreateIntCast(AI.getArraySize(), IntPtrTy, false); + Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), IntPtrTy, false); AI.setOperand(0, V); return &AI; } @@ -223,6 +237,107 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { return nullptr; } +namespace { +// If I and V are pointers in different address spaces, it is not allowed to +// use replaceAllUsesWith since I and V have different types. A +// non-target-specific transformation should not use addrspacecast on V since +// the two address spaces may be disjoint depending on the target. +// +// This class chases down uses of the old pointer until reaching the load +// instructions, then replaces the old pointer in the load instructions with +// the new pointer. If during the chasing it sees a bitcast or GEP, it will +// create a new bitcast or GEP with the new pointer and use them in the load + // instruction. 
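A toy model of the strategy just described (illustrative C++ only; Node and Kind are invented stand-ins, not LLVM types): rebuild each derived value on top of the replacement root and record old->new in a map so later users find their rebuilt operands, instead of performing a type-changing replaceAllUsesWith:

#include <cassert>
#include <map>

struct Node { Node *Op = nullptr; int Kind = 0; };

int main() {
  Node Root, Cast{&Root, 1}, Load{&Cast, 2}; // Root <- cast <- load
  Node NewRoot;                              // the replacement pointer
  std::map<Node *, Node *> WorkMap{{&Root, &NewRoot}};
  for (Node *N : {&Cast, &Load}) {           // users, visited in def-use order
    // Clone N on top of the already-rebuilt operand (leaked; fine for a sketch).
    Node *Clone = new Node{WorkMap.at(N->Op), N->Kind};
    WorkMap[N] = Clone;                      // later users will see the clone
  }
  // The rebuilt load now reaches NewRoot through the rebuilt cast.
  assert(WorkMap.at(&Load)->Op->Op == &NewRoot);
  return 0;
}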
+class PointerReplacer { +public: + PointerReplacer(InstCombiner &IC) : IC(IC) {} + void replacePointer(Instruction &I, Value *V); + +private: + void findLoadAndReplace(Instruction &I); + void replace(Instruction *I); + Value *getReplacement(Value *I); + + SmallVector<Instruction *, 4> Path; + MapVector<Value *, Value *> WorkMap; + InstCombiner &IC; +}; +} // end anonymous namespace + +void PointerReplacer::findLoadAndReplace(Instruction &I) { + for (auto U : I.users()) { + auto *Inst = dyn_cast<Instruction>(&*U); + if (!Inst) + return; + DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); + if (isa<LoadInst>(Inst)) { + for (auto P : Path) + replace(P); + replace(Inst); + } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) { + Path.push_back(Inst); + findLoadAndReplace(*Inst); + Path.pop_back(); + } else { + return; + } + } +} + +Value *PointerReplacer::getReplacement(Value *V) { + auto Loc = WorkMap.find(V); + if (Loc != WorkMap.end()) + return Loc->second; + return nullptr; +} + +void PointerReplacer::replace(Instruction *I) { + if (getReplacement(I)) + return; + + if (auto *LT = dyn_cast<LoadInst>(I)) { + auto *V = getReplacement(LT->getPointerOperand()); + assert(V && "Operand not replaced"); + auto *NewI = new LoadInst(V); + NewI->takeName(LT); + IC.InsertNewInstWith(NewI, *LT); + IC.replaceInstUsesWith(*LT, NewI); + WorkMap[LT] = NewI; + } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + auto *V = getReplacement(GEP->getPointerOperand()); + assert(V && "Operand not replaced"); + SmallVector<Value *, 8> Indices; + Indices.append(GEP->idx_begin(), GEP->idx_end()); + auto *NewI = GetElementPtrInst::Create( + V->getType()->getPointerElementType(), V, Indices); + IC.InsertNewInstWith(NewI, *GEP); + NewI->takeName(GEP); + WorkMap[GEP] = NewI; + } else if (auto *BC = dyn_cast<BitCastInst>(I)) { + auto *V = getReplacement(BC->getOperand(0)); + assert(V && "Operand not replaced"); + auto *NewT = PointerType::get(BC->getType()->getPointerElementType(), + V->getType()->getPointerAddressSpace()); + auto *NewI = new BitCastInst(V, NewT); + IC.InsertNewInstWith(NewI, *BC); + NewI->takeName(BC); + WorkMap[BC] = NewI; + } else { + llvm_unreachable("should never reach here"); + } +} + +void PointerReplacer::replacePointer(Instruction &I, Value *V) { +#ifndef NDEBUG + auto *PT = cast<PointerType>(I.getType()); + auto *NT = cast<PointerType>(V->getType()); + assert(PT != NT && PT->getElementType() == NT->getElementType() && + "Invalid usage"); +#endif + WorkMap[&I] = V; + findLoadAndReplace(I); +} + Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { if (auto *I = simplifyAllocaArraySize(*this, AI)) return I; @@ -287,18 +402,29 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { if (MemTransferInst *Copy = isOnlyCopiedFromConstantGlobal(&AI, ToDelete)) { unsigned SourceAlign = getOrEnforceKnownAlignment( Copy->getSource(), AI.getAlignment(), DL, &AI, &AC, &DT); - if (AI.getAlignment() <= SourceAlign) { + if (AI.getAlignment() <= SourceAlign && + isDereferenceableForAllocaSize(Copy->getSource(), &AI, DL)) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) eraseInstFromFunction(*ToDelete[i]); Constant *TheSrc = cast<Constant>(Copy->getSource()); - Constant *Cast - = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType()); - Instruction *NewI = replaceInstUsesWith(AI, Cast); - eraseInstFromFunction(*Copy); - ++NumGlobalCopies; - 
return NewI; + auto *SrcTy = TheSrc->getType(); + auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(), + SrcTy->getPointerAddressSpace()); + Constant *Cast = + ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, DestTy); + if (AI.getType()->getPointerAddressSpace() == + SrcTy->getPointerAddressSpace()) { + Instruction *NewI = replaceInstUsesWith(AI, Cast); + eraseInstFromFunction(*Copy); + ++NumGlobalCopies; + return NewI; + } else { + PointerReplacer PtrReplacer(*this); + PtrReplacer.replacePointer(AI, Cast); + ++NumGlobalCopies; + } } } } @@ -332,10 +458,10 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT SmallVector<std::pair<unsigned, MDNode *>, 8> MD; LI.getAllMetadata(MD); - LoadInst *NewLoad = IC.Builder->CreateAlignedLoad( - IC.Builder->CreateBitCast(Ptr, NewTy->getPointerTo(AS)), + LoadInst *NewLoad = IC.Builder.CreateAlignedLoad( + IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)), LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); - NewLoad->setAtomic(LI.getOrdering(), LI.getSynchScope()); + NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); MDBuilder MDB(NewLoad->getContext()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; @@ -363,21 +489,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT break; case LLVMContext::MD_nonnull: - // This only directly applies if the new type is also a pointer. - if (NewTy->isPointerTy()) { - NewLoad->setMetadata(ID, N); - break; - } - // If it's integral now, translate it to !range metadata. - if (NewTy->isIntegerTy()) { - auto *ITy = cast<IntegerType>(NewTy); - auto *NullInt = ConstantExpr::getPtrToInt( - ConstantPointerNull::get(cast<PointerType>(Ptr->getType())), ITy); - auto *NonNullInt = - ConstantExpr::getAdd(NullInt, ConstantInt::get(ITy, 1)); - NewLoad->setMetadata(LLVMContext::MD_range, - MDB.createRange(NonNullInt, NullInt)); - } + copyNonnullMetadata(LI, N, *NewLoad); break; case LLVMContext::MD_align: case LLVMContext::MD_dereferenceable: @@ -387,17 +499,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT NewLoad->setMetadata(ID, N); break; case LLVMContext::MD_range: - // FIXME: It would be nice to propagate this in some way, but the type - // conversions make it hard. - - // If it's a pointer now and the range does not contain 0, make it !nonnull. 
- if (NewTy->isPointerTy()) { - unsigned BitWidth = IC.getDataLayout().getTypeSizeInBits(NewTy); - if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { - MDNode *NN = MDNode::get(LI.getContext(), None); - NewLoad->setMetadata(LLVMContext::MD_nonnull, NN); - } - } + copyRangeMetadata(IC.getDataLayout(), LI, N, *NewLoad); break; } } @@ -416,10 +518,10 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value SmallVector<std::pair<unsigned, MDNode *>, 8> MD; SI.getAllMetadata(MD); - StoreInst *NewStore = IC.Builder->CreateAlignedStore( - V, IC.Builder->CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), + StoreInst *NewStore = IC.Builder.CreateAlignedStore( + V, IC.Builder.CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), SI.getAlignment(), SI.isVolatile()); - NewStore->setAtomic(SI.getOrdering(), SI.getSynchScope()); + NewStore->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; MDNode *N = MDPair.second; @@ -511,7 +613,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // Replace all the stores with stores of the newly loaded value. for (auto UI = LI.user_begin(), UE = LI.user_end(); UI != UE;) { auto *SI = cast<StoreInst>(*UI++); - IC.Builder->SetInsertPoint(SI); + IC.Builder.SetInsertPoint(SI); combineStoreToNewValue(IC, *SI, NewLoad); IC.eraseInstFromFunction(*SI); } @@ -559,7 +661,10 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { if (NumElements == 1) { LoadInst *NewLoad = combineLoadToNewType(IC, LI, ST->getTypeAtIndex(0U), ".unpack"); - return IC.replaceInstUsesWith(LI, IC.Builder->CreateInsertValue( + AAMDNodes AAMD; + LI.getAAMetadata(AAMD); + NewLoad->setAAMetadata(AAMD); + return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( UndefValue::get(T), NewLoad, 0, Name)); } @@ -584,11 +689,15 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), - Name + ".elt"); + auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + Name + ".elt"); auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); - auto *L = IC.Builder->CreateAlignedLoad(Ptr, EltAlign, Name + ".unpack"); - V = IC.Builder->CreateInsertValue(V, L, i); + auto *L = IC.Builder.CreateAlignedLoad(Ptr, EltAlign, Name + ".unpack"); + // Propagate AA metadata. It'll still be valid on the narrowed load. + AAMDNodes AAMD; + LI.getAAMetadata(AAMD); + L->setAAMetadata(AAMD); + V = IC.Builder.CreateInsertValue(V, L, i); } V->setName(Name); @@ -600,7 +709,10 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { auto NumElements = AT->getNumElements(); if (NumElements == 1) { LoadInst *NewLoad = combineLoadToNewType(IC, LI, ET, ".unpack"); - return IC.replaceInstUsesWith(LI, IC.Builder->CreateInsertValue( + AAMDNodes AAMD; + LI.getAAMetadata(AAMD); + NewLoad->setAAMetadata(AAMD); + return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( UndefValue::get(T), NewLoad, 0, Name)); } @@ -608,7 +720,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { // arrays of arbitrary size but this has a terrible impact on compile time. // The threshold here is chosen arbitrarily, maybe needs a little bit of // tuning. 
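At the IR level, the unpacking in these hunks rewrites one aggregate load as per-element loads that are reassembled, each element access carrying the original load's alignment and (with this patch) its AA metadata. A rough C++ analogy, with the struct and names invented purely for illustration:

    struct Pair { int A; long B; };

    // Before: a single aggregate load.
    Pair loadWhole(const Pair *P) { return *P; }

    // After unpacking: element-wise loads feeding an aggregate rebuild;
    // each element access gets alignment MinAlign(base, element offset).
    Pair loadUnpacked(const Pair *P) {
      Pair R;
      R.A = P->A; // load of element 0
      R.B = P->B; // load of element 1
      return R;
    }
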
- if (NumElements > 1024) + if (NumElements > IC.MaxArraySizeForCombine) return nullptr; const DataLayout &DL = IC.getDataLayout(); @@ -628,11 +740,14 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), - Name + ".elt"); - auto *L = IC.Builder->CreateAlignedLoad(Ptr, MinAlign(Align, Offset), - Name + ".unpack"); - V = IC.Builder->CreateInsertValue(V, L, i); + auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + Name + ".elt"); + auto *L = IC.Builder.CreateAlignedLoad(Ptr, MinAlign(Align, Offset), + Name + ".unpack"); + AAMDNodes AAMD; + LI.getAAMetadata(AAMD); + L->setAAMetadata(AAMD); + V = IC.Builder.CreateInsertValue(V, L, i); Offset += EltSize; } @@ -772,10 +887,8 @@ static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI, // first non-zero index. auto IsAllNonNegative = [&]() { for (unsigned i = Idx+1, e = GEPI->getNumOperands(); i != e; ++i) { - bool KnownNonNegative, KnownNegative; - IC.ComputeSignBit(GEPI->getOperand(i), KnownNonNegative, - KnownNegative, 0, MemI); - if (KnownNonNegative) + KnownBits Known = IC.computeKnownBits(GEPI->getOperand(i), 0, MemI); + if (Known.isNonNegative()) continue; return false; } @@ -818,6 +931,18 @@ static Instruction *replaceGEPIdxWithZero(InstCombiner &IC, Value *Ptr, return nullptr; } +static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { + const Value *GEPI0 = GEPI->getOperand(0); + if (isa<ConstantPointerNull>(GEPI0) && GEPI->getPointerAddressSpace() == 0) + return true; + } + if (isa<UndefValue>(Op) || + (isa<ConstantPointerNull>(Op) && LI.getPointerAddressSpace() == 0)) + return true; + return false; +} + Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { Value *Op = LI.getOperand(0); @@ -857,8 +982,8 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { combineMetadataForCSE(cast<LoadInst>(AvailableVal), &LI); return replaceInstUsesWith( - LI, Builder->CreateBitOrPointerCast(AvailableVal, LI.getType(), - LI.getName() + ".cast")); + LI, Builder.CreateBitOrPointerCast(AvailableVal, LI.getType(), + LI.getName() + ".cast")); } // None of the following transforms are legal for volatile/ordered atomic @@ -866,29 +991,16 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { if (!LI.isUnordered()) return nullptr; // load(gep null, ...) -> unreachable - if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { - const Value *GEPI0 = GEPI->getOperand(0); - // TODO: Consider a target hook for valid address spaces for this xform. - if (isa<ConstantPointerNull>(GEPI0) && GEPI->getPointerAddressSpace() == 0){ - // Insert a new store to null instruction before the load to indicate - // that this code is not reachable. We do this instead of inserting - // an unreachable instruction directly because we cannot modify the - // CFG. - new StoreInst(UndefValue::get(LI.getType()), - Constant::getNullValue(Op->getType()), &LI); - return replaceInstUsesWith(LI, UndefValue::get(LI.getType())); - } - } - // load null/undef -> unreachable - // TODO: Consider a target hook for valid address spaces for this xform. - if (isa<UndefValue>(Op) || - (isa<ConstantPointerNull>(Op) && LI.getPointerAddressSpace() == 0)) { - // Insert a new store to null instruction before the load to indicate that - // this code is not reachable. 
We do this instead of inserting an - // unreachable instruction directly because we cannot modify the CFG. - new StoreInst(UndefValue::get(LI.getType()), - Constant::getNullValue(Op->getType()), &LI); + // TODO: Consider a target hook for valid address spaces for these xforms. + if (canSimplifyNullLoadOrGEP(LI, Op)) { + // Insert a new store to null instruction before the load to indicate + // that this code is not reachable. We do this instead of inserting + // an unreachable instruction directly because we cannot modify the + // CFG. + StoreInst *SI = new StoreInst(UndefValue::get(LI.getType()), + Constant::getNullValue(Op->getType()), &LI); + SI->setDebugLoc(LI.getDebugLoc()); return replaceInstUsesWith(LI, UndefValue::get(LI.getType())); } @@ -908,15 +1020,15 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { unsigned Align = LI.getAlignment(); if (isSafeToLoadUnconditionally(SI->getOperand(1), Align, DL, SI) && isSafeToLoadUnconditionally(SI->getOperand(2), Align, DL, SI)) { - LoadInst *V1 = Builder->CreateLoad(SI->getOperand(1), - SI->getOperand(1)->getName()+".val"); - LoadInst *V2 = Builder->CreateLoad(SI->getOperand(2), - SI->getOperand(2)->getName()+".val"); + LoadInst *V1 = Builder.CreateLoad(SI->getOperand(1), + SI->getOperand(1)->getName()+".val"); + LoadInst *V2 = Builder.CreateLoad(SI->getOperand(2), + SI->getOperand(2)->getName()+".val"); assert(LI.isUnordered() && "implied by above"); V1->setAlignment(Align); - V1->setAtomic(LI.getOrdering(), LI.getSynchScope()); + V1->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); V2->setAlignment(Align); - V2->setAtomic(LI.getOrdering(), LI.getSynchScope()); + V2->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); return SelectInst::Create(SI->getCondition(), V1, V2); } @@ -1061,7 +1173,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { // If the struct only has one element, we unpack. unsigned Count = ST->getNumElements(); if (Count == 1) { - V = IC.Builder->CreateExtractValue(V, 0); + V = IC.Builder.CreateExtractValue(V, 0); combineStoreToNewValue(IC, SI, V); return true; } @@ -1090,11 +1202,14 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), - AddrName); - auto *Val = IC.Builder->CreateExtractValue(V, i, EltName); + auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + AddrName); + auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); - IC.Builder->CreateAlignedStore(Val, Ptr, EltAlign); + llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); + AAMDNodes AAMD; + SI.getAAMetadata(AAMD); + NS->setAAMetadata(AAMD); } return true; @@ -1104,7 +1219,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { // If the array only has one element, we unpack. auto NumElements = AT->getNumElements(); if (NumElements == 1) { - V = IC.Builder->CreateExtractValue(V, 0); + V = IC.Builder.CreateExtractValue(V, 0); combineStoreToNewValue(IC, SI, V); return true; } @@ -1113,7 +1228,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { // arrays of arbitrary size but this has a terrible impact on compile time. // The threshold here is chosen arbitrarily, maybe needs a little bit of // tuning.
- if (NumElements > 1024) + if (NumElements > IC.MaxArraySizeForCombine) return false; const DataLayout &DL = IC.getDataLayout(); @@ -1137,11 +1252,14 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), - AddrName); - auto *Val = IC.Builder->CreateExtractValue(V, i, EltName); + auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + AddrName); + auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = MinAlign(Align, Offset); - IC.Builder->CreateAlignedStore(Val, Ptr, EltAlign); + Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); + AAMDNodes AAMD; + SI.getAAMetadata(AAMD); + NS->setAAMetadata(AAMD); Offset += EltSize; } @@ -1268,8 +1386,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { break; } - // Don't skip over loads or things that can modify memory. - if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory()) + // Don't skip over loads, throws or things that can modify memory. + if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory() || BBI->mayThrow()) break; } @@ -1392,8 +1510,8 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { } // If we find something that may be using or overwriting the stored // value, or if we run out of instructions, we can't do the xform. - if (BBI->mayReadFromMemory() || BBI->mayWriteToMemory() || - BBI == OtherBB->begin()) + if (BBI->mayReadFromMemory() || BBI->mayThrow() || + BBI->mayWriteToMemory() || BBI == OtherBB->begin()) return false; } @@ -1402,7 +1520,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { // StoreBB. for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) { // FIXME: This should really be AA driven. - if (I->mayReadFromMemory() || I->mayWriteToMemory()) + if (I->mayReadFromMemory() || I->mayThrow() || I->mayWriteToMemory()) return false; } } @@ -1423,9 +1541,11 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { SI.isVolatile(), SI.getAlignment(), SI.getOrdering(), - SI.getSynchScope()); + SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); - NewSI->setDebugLoc(OtherStore->getDebugLoc()); + // The debug locations of the original instructions might differ; merge them. + NewSI->setDebugLoc(DILocation::getMergedLocation(SI.getDebugLoc(), + OtherStore->getDebugLoc())); // If the two stores had AA tags, merge them. AAMDNodes AATags; diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 45a19fb..e3a5022 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -39,17 +39,15 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, Value *A = nullptr, *B = nullptr, *One = nullptr; if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(One), m_Value(A))), m_Value(B))) && match(One, m_One())) { - A = IC.Builder->CreateSub(A, B); - return IC.Builder->CreateShl(One, A); + A = IC.Builder.CreateSub(A, B); + return IC.Builder.CreateShl(One, A); } // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. 
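The exactness claim in the comment above is easy to sanity-check: if a power of two survives a logical right shift as a non-zero value, no set bit was discarded, so shifting back recovers the input. A standalone check (plain C++, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned K = 0; K < 32; ++K) {
        uint32_t Pow2 = UINT32_C(1) << K;
        for (unsigned B = 0; B < 32; ++B) {
          uint32_t Shifted = Pow2 >> B;
          if (Shifted != 0)                 // the known-non-zero context
            assert((Shifted << B) == Pow2); // exact: nothing shifted out
        }
      }
      return 0;
    }
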
BinaryOperator *I = dyn_cast<BinaryOperator>(V); if (I && I->isLogicalShift() && - isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, - &IC.getAssumptionCache(), &CxtI, - &IC.getDominatorTree())) { + IC.isKnownToBeAPowerOfTwo(I->getOperand(0), false, 0, &CxtI)) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { @@ -132,8 +130,9 @@ static Constant *getLogBase2Vector(ConstantDataVector *CV) { /// \brief Return true if we can prove that: /// (mul LHS, RHS) === (mul nsw LHS, RHS) -bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS, - Instruction &CxtI) { +bool InstCombiner::willNotOverflowSignedMul(const Value *LHS, + const Value *RHS, + const Instruction &CxtI) const { // Multiplying n * m significant bits yields a result of n + m significant // bits. If the total number of significant bits does not exceed the // result bit width (minus 1), there is no overflow. @@ -162,11 +161,9 @@ bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS, // product is exactly the minimum negative number. // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 // For simplicity we just check if at least one side is not negative. - bool LHSNonNegative, LHSNegative; - bool RHSNonNegative, RHSNegative; - ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, &CxtI); - ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, &CxtI); - if (LHSNonNegative || RHSNonNegative) + KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, &CxtI); + KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, &CxtI); + if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) return true; } return false; @@ -179,7 +176,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyMulInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) @@ -230,8 +227,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap(); if (I.hasNoSignedWrap()) { - uint64_t V; - if (match(NewCst, m_ConstantInt(V)) && V != Width - 1) + const APInt *V; + if (match(NewCst, m_APInt(V)) && *V != Width - 1) Shl->setHasNoSignedWrap(); } @@ -253,9 +250,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { ConstantInt *C1; Value *Sub = nullptr; if (match(Op0, m_Sub(m_Value(Y), m_Value(X)))) - Sub = Builder->CreateSub(X, Y, "suba"); + Sub = Builder.CreateSub(X, Y, "suba"); else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1)))) - Sub = Builder->CreateSub(Builder->CreateNeg(C1), Y, "subc"); + Sub = Builder.CreateSub(Builder.CreateNeg(C1), Y, "subc"); if (Sub) return BinaryOperator::CreateMul(Sub, @@ -275,11 +272,11 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { Value *X; Constant *C1; if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) { - Value *Mul = Builder->CreateMul(C1, Op1); + Value *Mul = Builder.CreateMul(C1, Op1); // Only go forward with the transform if C1*CI simplifies to a tidier // constant. 
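The distribution itself is always sound in two's complement arithmetic, where wrapping makes (X + C1) * C2 equal to X * C2 + C1 * C2; the guard on the following line is purely a profitability check, since a non-folding constant product would turn one multiply into two. A quick numeric spot-check with arbitrarily chosen constants (plain C++, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t C1 = 7, C2 = 12;
      for (uint32_t X = 0; X < 100000; ++X)
        assert((X + C1) * C2 == X * C2 + C1 * C2); // wraps consistently
      return 0;
    }
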
if (!match(Mul, m_Mul(m_Value(), m_Value()))) - return BinaryOperator::CreateAdd(Builder->CreateMul(X, Op1), Mul); + return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul); } } } @@ -298,44 +295,38 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // (X / Y) * Y = X - (X % Y) // (X / Y) * -Y = (X % Y) - X { - Value *Op1C = Op1; - BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0); - if (!BO || - (BO->getOpcode() != Instruction::UDiv && - BO->getOpcode() != Instruction::SDiv)) { - Op1C = Op0; - BO = dyn_cast<BinaryOperator>(Op1); + Value *Y = Op1; + BinaryOperator *Div = dyn_cast<BinaryOperator>(Op0); + if (!Div || (Div->getOpcode() != Instruction::UDiv && + Div->getOpcode() != Instruction::SDiv)) { + Y = Op0; + Div = dyn_cast<BinaryOperator>(Op1); } - Value *Neg = dyn_castNegVal(Op1C); - if (BO && BO->hasOneUse() && - (BO->getOperand(1) == Op1C || BO->getOperand(1) == Neg) && - (BO->getOpcode() == Instruction::UDiv || - BO->getOpcode() == Instruction::SDiv)) { - Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1); + Value *Neg = dyn_castNegVal(Y); + if (Div && Div->hasOneUse() && + (Div->getOperand(1) == Y || Div->getOperand(1) == Neg) && + (Div->getOpcode() == Instruction::UDiv || + Div->getOpcode() == Instruction::SDiv)) { + Value *X = Div->getOperand(0), *DivOp1 = Div->getOperand(1); // If the division is exact, X % Y is zero, so we end up with X or -X. - if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO)) - if (SDiv->isExact()) { - if (Op1BO == Op1C) - return replaceInstUsesWith(I, Op0BO); - return BinaryOperator::CreateNeg(Op0BO); - } - - Value *Rem; - if (BO->getOpcode() == Instruction::UDiv) - Rem = Builder->CreateURem(Op0BO, Op1BO); - else - Rem = Builder->CreateSRem(Op0BO, Op1BO); - Rem->takeName(BO); + if (Div->isExact()) { + if (DivOp1 == Y) + return replaceInstUsesWith(I, X); + return BinaryOperator::CreateNeg(X); + } - if (Op1BO == Op1C) - return BinaryOperator::CreateSub(Op0BO, Rem); - return BinaryOperator::CreateSub(Rem, Op0BO); + auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem + : Instruction::SRem; + Value *Rem = Builder.CreateBinOp(RemOpc, X, DivOp1); + if (DivOp1 == Y) + return BinaryOperator::CreateSub(X, Rem); + return BinaryOperator::CreateSub(Rem, X); } } /// i1 mul -> i1 and. - if (I.getType()->getScalarType()->isIntegerTy(1)) + if (I.getType()->isIntOrIntVectorTy(1)) return BinaryOperator::CreateAnd(Op0, Op1); // X*(1 << Y) --> X << Y @@ -377,7 +368,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } if (BoolCast) { - Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()), + Value *V = Builder.CreateSub(Constant::getNullValue(I.getType()), BoolCast); return BinaryOperator::CreateAnd(V, OtherOp); } @@ -392,10 +383,10 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { Constant *CI = ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); if (ConstantExpr::getSExt(CI, I.getType()) == Op1C && - WillNotOverflowSignedMul(Op0Conv->getOperand(0), CI, I)) { + willNotOverflowSignedMul(Op0Conv->getOperand(0), CI, I)) { // Insert the new, smaller mul. 
Value *NewMul = - Builder->CreateNSWMul(Op0Conv->getOperand(0), CI, "mulconv"); + Builder.CreateNSWMul(Op0Conv->getOperand(0), CI, "mulconv"); return new SExtInst(NewMul, I.getType()); } } @@ -409,10 +400,10 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Op0Conv->getOperand(0)->getType() == Op1Conv->getOperand(0)->getType() && (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && - WillNotOverflowSignedMul(Op0Conv->getOperand(0), + willNotOverflowSignedMul(Op0Conv->getOperand(0), Op1Conv->getOperand(0), I)) { // Insert the new integer mul. - Value *NewMul = Builder->CreateNSWMul( + Value *NewMul = Builder.CreateNSWMul( Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); return new SExtInst(NewMul, I.getType()); } @@ -428,11 +419,10 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { Constant *CI = ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); if (ConstantExpr::getZExt(CI, I.getType()) == Op1C && - computeOverflowForUnsignedMul(Op0Conv->getOperand(0), CI, &I) == - OverflowResult::NeverOverflows) { + willNotOverflowUnsignedMul(Op0Conv->getOperand(0), CI, I)) { // Insert the new, smaller mul. Value *NewMul = - Builder->CreateNUWMul(Op0Conv->getOperand(0), CI, "mulconv"); + Builder.CreateNUWMul(Op0Conv->getOperand(0), CI, "mulconv"); return new ZExtInst(NewMul, I.getType()); } } @@ -446,25 +436,22 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Op0Conv->getOperand(0)->getType() == Op1Conv->getOperand(0)->getType() && (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && - computeOverflowForUnsignedMul(Op0Conv->getOperand(0), - Op1Conv->getOperand(0), - &I) == OverflowResult::NeverOverflows) { + willNotOverflowUnsignedMul(Op0Conv->getOperand(0), + Op1Conv->getOperand(0), I)) { // Insert the new integer mul. 
- Value *NewMul = Builder->CreateNUWMul( + Value *NewMul = Builder.CreateNUWMul( Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); return new ZExtInst(NewMul, I.getType()); } } } - if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, I)) { + if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && - computeOverflowForUnsignedMul(Op0, Op1, &I) == - OverflowResult::NeverOverflows) { + if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -612,8 +599,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (isa<Constant>(Op0)) std::swap(Op0, Op1); - if (Value *V = - SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -711,11 +698,11 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } // if pattern detected emit alternate sequence if (OpX && OpY) { - BuilderTy::FastMathFlagGuard Guard(*Builder); - Builder->setFastMathFlags(Log2->getFastMathFlags()); + BuilderTy::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(Log2->getFastMathFlags()); Log2->setArgOperand(0, OpY); - Value *FMulVal = Builder->CreateFMul(OpX, Log2); - Value *FSub = Builder->CreateFSub(FMulVal, OpX); + Value *FMulVal = Builder.CreateFMul(OpX, Log2); + Value *FSub = Builder.CreateFSub(FMulVal, OpX); FSub->takeName(&I); return replaceInstUsesWith(I, FSub); } @@ -727,23 +714,23 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { for (int i = 0; i < 2; i++) { bool IgnoreZeroSign = I.hasNoSignedZeros(); if (BinaryOperator::isFNeg(Opnd0, IgnoreZeroSign)) { - BuilderTy::FastMathFlagGuard Guard(*Builder); - Builder->setFastMathFlags(I.getFastMathFlags()); + BuilderTy::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); Value *N0 = dyn_castFNegVal(Opnd0, IgnoreZeroSign); Value *N1 = dyn_castFNegVal(Opnd1, IgnoreZeroSign); // -X * -Y => X*Y if (N1) { - Value *FMul = Builder->CreateFMul(N0, N1); + Value *FMul = Builder.CreateFMul(N0, N1); FMul->takeName(&I); return replaceInstUsesWith(I, FMul); } if (Opnd0->hasOneUse()) { // -X * Y => -(X*Y) (Promote negation as high as possible) - Value *T = Builder->CreateFMul(N0, Opnd1); - Value *Neg = Builder->CreateFNeg(T); + Value *T = Builder.CreateFMul(N0, Opnd1); + Value *Neg = Builder.CreateFNeg(T); Neg->takeName(&I); return replaceInstUsesWith(I, Neg); } @@ -768,10 +755,10 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Y = Opnd0_0; if (Y) { - BuilderTy::FastMathFlagGuard Guard(*Builder); - Builder->setFastMathFlags(I.getFastMathFlags()); - Value *T = Builder->CreateFMul(Opnd1, Opnd1); - Value *R = Builder->CreateFMul(T, Y); + BuilderTy::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + Value *T = Builder.CreateFMul(Opnd1, Opnd1); + Value *R = Builder.CreateFMul(T, Y); R->takeName(&I); return replaceInstUsesWith(I, R); } @@ -837,7 +824,7 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { *I = SI->getOperand(NonNullOperand); Worklist.Add(&*BBI); } else if (*I == SelectCond) { - *I = Builder->getInt1(NonNullOperand == 1); + *I = Builder.getInt1(NonNullOperand == 1); Worklist.Add(&*BBI); } } @@ -944,28 +931,25 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { } } - if (*C2 
!= 0) // avoid X udiv 0 + if (!C2->isNullValue()) // avoid X udiv 0 if (Instruction *FoldedDiv = foldOpWithConstantIntoOperand(I)) return FoldedDiv; } } - if (ConstantInt *One = dyn_cast<ConstantInt>(Op0)) { - if (One->isOne() && !I.getType()->isIntegerTy(1)) { - bool isSigned = I.getOpcode() == Instruction::SDiv; - if (isSigned) { - // If Op1 is 0 then it's undefined behaviour, if Op1 is 1 then the - // result is one, if Op1 is -1 then the result is minus one, otherwise - // it's zero. - Value *Inc = Builder->CreateAdd(Op1, One); - Value *Cmp = Builder->CreateICmpULT( - Inc, ConstantInt::get(I.getType(), 3)); - return SelectInst::Create(Cmp, Op1, ConstantInt::get(I.getType(), 0)); - } else { - // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the - // result is one, otherwise it's zero. - return new ZExtInst(Builder->CreateICmpEQ(Op1, One), I.getType()); - } + if (match(Op0, m_One())) { + assert(!I.getType()->isIntOrIntVectorTy(1) && "i1 divide not removed?"); + if (I.getOpcode() == Instruction::SDiv) { + // If Op1 is 0 then it's undefined behaviour, if Op1 is 1 then the + // result is one, if Op1 is -1 then the result is minus one, otherwise + // it's zero. + Value *Inc = Builder.CreateAdd(Op1, Op0); + Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(I.getType(), 3)); + return SelectInst::Create(Cmp, Op1, ConstantInt::get(I.getType(), 0)); + } else { + // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the + // result is one, otherwise it's zero. + return new ZExtInst(Builder.CreateICmpEQ(Op1, Op0), I.getType()); } } @@ -1040,7 +1024,7 @@ static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, // X udiv C, where C >= signbit static Instruction *foldUDivNegCst(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC) { - Value *ICI = IC.Builder->CreateICmpULT(Op0, cast<ConstantInt>(Op1)); + Value *ICI = IC.Builder.CreateICmpULT(Op0, cast<ConstantInt>(Op1)); return SelectInst::Create(ICI, Constant::getNullValue(I.getType()), ConstantInt::get(I.getType(), 1)); @@ -1059,10 +1043,9 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, if (!match(ShiftLeft, m_Shl(m_APInt(CI), m_Value(N)))) llvm_unreachable("match should never fail here!"); if (*CI != 1) - N = IC.Builder->CreateAdd(N, - ConstantInt::get(N->getType(), CI->logBase2())); + N = IC.Builder.CreateAdd(N, ConstantInt::get(N->getType(), CI->logBase2())); if (Op1 != ShiftLeft) - N = IC.Builder->CreateZExt(N, Op1->getType()); + N = IC.Builder.CreateZExt(N, Op1->getType()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); if (I.isExact()) LShr->setIsExact(); @@ -1118,7 +1101,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyUDivInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1148,7 +1131,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) return new ZExtInst( - Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()), + Builder.CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()), I.getType()); // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) @@ -1191,7 +1174,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return 
replaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifySDivInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1223,7 +1206,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { Constant *NarrowDivisor = ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType()); - Value *NarrowOp = Builder->CreateSDiv(Op0Src, NarrowDivisor); + Value *NarrowOp = Builder.CreateSDiv(Op0Src, NarrowDivisor); return new SExtInst(NarrowOp, Op0->getType()); } } @@ -1231,7 +1214,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Constant *RHS = dyn_cast<Constant>(Op1)) { // X/INT_MIN -> X == INT_MIN if (RHS->isMinSignedValue()) - return new ZExtInst(Builder->CreateICmpEQ(Op0, Op1), I.getType()); + return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), I.getType()); // -X/C --> X/-C provided the negation doesn't overflow. Value *X; @@ -1244,25 +1227,23 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a udiv. - if (I.getType()->isIntegerTy()) { - APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op0, Mask, 0, &I)) { - if (MaskedValueIsZero(Op1, Mask, 0, &I)) { - // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set - auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); - BO->setIsExact(I.isExact()); - return BO; - } + APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits())); + if (MaskedValueIsZero(Op0, Mask, 0, &I)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I)) { + // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set + auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; + } - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { - // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) - // Safe because the only negative value (1 << Y) can take on is - // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have - // the sign bit set. - auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); - BO->setIsExact(I.isExact()); - return BO; - } + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { + // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) + // Safe because the only negative value (1 << Y) can take on is + // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have + // the sign bit set. 
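Both sdiv-to-udiv rewrites in this hunk rest on the same fact: once the sign bits of the operands are known clear, signed and unsigned division agree bit for bit. An exhaustive check at i8 width (plain C++, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // All i8 values with the sign bit clear: [0, 127].
      for (int X = 0; X <= 127; ++X)
        for (int Y = 1; Y <= 127; ++Y) {
          int8_t SDiv = int8_t(X) / int8_t(Y);    // X sdiv Y
          uint8_t UDiv = uint8_t(X) / uint8_t(Y); // X udiv Y
          assert(uint8_t(SDiv) == UDiv);
        }
      return 0;
    }
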
+ auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; } } @@ -1306,7 +1287,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), - DL, &TLI, &DT, &AC)) + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1396,7 +1377,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { // (X/Y) / Z => X / (Y*Z) // if (!isa<Constant>(Y) || !isa<Constant>(Op1)) { - NewInst = Builder->CreateFMul(Y, Op1); + NewInst = Builder.CreateFMul(Y, Op1); if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { FastMathFlags Flags = I.getFastMathFlags(); Flags &= cast<Instruction>(Op0)->getFastMathFlags(); @@ -1408,7 +1389,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { // Z / (X/Y) => Z*Y / X // if (!isa<Constant>(Y) || !isa<Constant>(Op0)) { - NewInst = Builder->CreateFMul(Op0, Y); + NewInst = Builder.CreateFMul(Op0, Y); if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { FastMathFlags Flags = I.getFastMathFlags(); Flags &= cast<Instruction>(Op1)->getFastMathFlags(); @@ -1461,16 +1442,16 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - } else if (isa<PHINode>(Op0I)) { + } else if (auto *PN = dyn_cast<PHINode>(Op0I)) { using namespace llvm::PatternMatch; const APInt *Op1Int; if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() && (I.getOpcode() == Instruction::URem || !Op1Int->isMinSignedValue())) { - // FoldOpIntoPhi will speculate instructions to the end of the PHI's + // foldOpIntoPhi will speculate instructions to the end of the PHI's // predecessor blocks, so do this only if we know the srem or urem // will not fault. - if (Instruction *NV = FoldOpIntoPhi(I)) + if (Instruction *NV = foldOpIntoPhi(I, PN)) return NV; } } @@ -1490,7 +1471,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyURemInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1499,28 +1480,28 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { // (zext A) urem (zext B) --> zext (A urem B) if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst(Builder->CreateURem(ZOp0->getOperand(0), ZOp1), + return new ZExtInst(Builder.CreateURem(ZOp0->getOperand(0), ZOp1), I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); - Value *Add = Builder->CreateAdd(Op1, N1); + Value *Add = Builder.CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); } // 1 urem X -> zext(X != 1) if (match(Op0, m_One())) { - Value *Cmp = Builder->CreateICmpNE(Op1, Op0); - Value *Ext = Builder->CreateZExt(Cmp, I.getType()); + Value *Cmp = Builder.CreateICmpNE(Op1, Op0); + Value *Ext = Builder.CreateZExt(Cmp, I.getType()); return replaceInstUsesWith(I, Ext); } // X urem C -> X < C ? X : X - C, where C >= signbit. 
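The identity in the comment above (implemented just below) holds because a divisor with its high bit set fits into an unsigned value at most once, so the remainder is either X itself or X minus one copy of C. A small exhaustive check at i8 width, with the divisor chosen arbitrarily (plain C++, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint8_t C = 0x90; // high bit set, i.e. C >= signbit
      for (unsigned X = 0; X <= 0xFF; ++X) {
        uint8_t Rem = uint8_t(X) % C;
        uint8_t Sel = (uint8_t(X) < C) ? uint8_t(X) : uint8_t(X - C);
        assert(Sel == Rem);
      }
      return 0;
    }
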
const APInt *DivisorC; if (match(Op1, m_APInt(DivisorC)) && DivisorC->isNegative()) { - Value *Cmp = Builder->CreateICmpULT(Op0, Op1); - Value *Sub = Builder->CreateSub(Op0, Op1); + Value *Cmp = Builder.CreateICmpULT(Op0, Op1); + Value *Sub = Builder.CreateSub(Op0, Op1); return SelectInst::Create(Cmp, Op0, Sub); } @@ -1533,7 +1514,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifySRemInst(Op0, Op1, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // Handle the integer rem common cases @@ -1552,13 +1533,11 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a urem. - if (I.getType()->isIntegerTy()) { - APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op1, Mask, 0, &I) && - MaskedValueIsZero(Op0, Mask, 0, &I)) { - // X srem Y -> X urem Y, iff X and Y don't have sign bit set - return BinaryOperator::CreateURem(Op0, Op1, I.getName()); - } + APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits())); + if (MaskedValueIsZero(Op1, Mask, 0, &I) && + MaskedValueIsZero(Op0, Mask, 0, &I)) { + // X srem Y -> X urem Y, iff X and Y don't have sign bit set + return BinaryOperator::CreateURem(Op0, Op1, I.getName()); } // If it's a constant vector, flip any negative values positive. @@ -1609,7 +1588,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), - DL, &TLI, &DT, &AC)) + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 4cbffe9..0011412 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -16,9 +16,9 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/IR/DebugInfo.h" using namespace llvm; using namespace llvm::PatternMatch; @@ -457,8 +457,8 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) { } // The more common cases of a phi with no constant operands or just one - // variable operand are handled by FoldPHIArgOpIntoPHI() and FoldOpIntoPhi() - // respectively. FoldOpIntoPhi() wants to do the opposite transform that is + // variable operand are handled by FoldPHIArgOpIntoPHI() and foldOpIntoPhi() + // respectively. foldOpIntoPhi() wants to do the opposite transform that is // performed here. It tries to replicate a cast in the phi operand's basic // block to expose other folding opportunities. Thus, InstCombine will // infinite loop without this check. @@ -507,7 +507,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { // Be careful about transforming integer PHIs. We don't want to pessimize // the code by turning an i32 into an i1293. 
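The width check that follows is driven by the target's legal integer types: converting a PHI from a legal width into one the target cannot handle natively (the "i1293" of the comment) would pessimize codegen. A loose sketch of the kind of DataLayout query involved; the helper name and exact policy here are invented, only DataLayout::isLegalInteger() is real API:

    #include "llvm/IR/DataLayout.h"
    using namespace llvm;

    // Roughly: moving toward a legal width is attractive, moving from a
    // legal width to an illegal one is not.
    static bool widthChangeIsAttractive(const DataLayout &DL,
                                        unsigned FromBits, unsigned ToBits) {
      if (DL.isLegalInteger(ToBits))
        return true;
      return !DL.isLegalInteger(FromBits);
    }
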
if (PN.getType()->isIntegerTy() && CastSrcTy->isIntegerTy()) { - if (!ShouldChangeType(PN.getType(), CastSrcTy)) + if (!shouldChangeType(PN.getType(), CastSrcTy)) return nullptr; } } else if (isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)) { @@ -636,10 +636,10 @@ static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, /// Return an existing non-zero constant if this phi node has one, otherwise /// return constant 1. static ConstantInt *GetAnyNonZeroConstInt(PHINode &PN) { - assert(isa<IntegerType>(PN.getType()) && "Expect only intger type phi"); + assert(isa<IntegerType>(PN.getType()) && "Expect only integer type phi"); for (Value *V : PN.operands()) if (auto *ConstVA = dyn_cast<ConstantInt>(V)) - if (!ConstVA->isZeroValue()) + if (!ConstVA->isZero()) return ConstVA; return ConstantInt::get(cast<IntegerType>(PN.getType()), 1); } @@ -836,12 +836,12 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { } // Otherwise, do an extract in the predecessor. - Builder->SetInsertPoint(Pred->getTerminator()); + Builder.SetInsertPoint(Pred->getTerminator()); Value *Res = InVal; if (Offset) - Res = Builder->CreateLShr(Res, ConstantInt::get(InVal->getType(), + Res = Builder.CreateLShr(Res, ConstantInt::get(InVal->getType(), Offset), "extract"); - Res = Builder->CreateTrunc(Res, Ty, "extract.t"); + Res = Builder.CreateTrunc(Res, Ty, "extract.t"); PredVal = Res; EltPHI->addIncoming(Res, Pred); @@ -880,7 +880,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { - if (Value *V = SimplifyInstruction(&PN, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyInstruction(&PN, SQ.getWithInstruction(&PN))) return replaceInstUsesWith(PN, V); if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN)) diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 3664484..4eebe82 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; using namespace PatternMatch; @@ -60,12 +61,12 @@ static CmpInst::Predicate getCmpPredicateForMinMax(SelectPatternFlavor SPF, } } -static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder, +static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, SelectPatternFlavor SPF, Value *A, Value *B) { CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF); assert(CmpInst::isIntPredicate(Pred)); - return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B); + return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } /// We want to turn code that looks like this: @@ -120,6 +121,16 @@ static Constant *getSelectFoldableConstant(Instruction *I) { /// We have (select c, TI, FI), and we know that TI and FI have the same opcode. Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI) { + // Don't break up min/max patterns. The hasOneUse checks below prevent that + // for most cases, but vector min/max with bitcasts can be transformed. If the + // one-use restrictions are eased for other patterns, we still don't want to + // obfuscate min/max. 
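The bail-out described in the comment above (implemented just below with m_SMin/m_SMax/m_UMin/m_UMax) is built from the composable matchers in llvm/IR/PatternMatch.h; the same idiom works for probing any icmp+select min/max shape. A minimal sketch, with the function name invented but the matchers real:

    #include "llvm/IR/PatternMatch.h"
    using namespace llvm;
    using namespace llvm::PatternMatch;

    // Returns true if V computes smax(A, B) as an icmp+select pair,
    // binding A and B on success.
    static bool isSMaxSelect(Value *V, Value *&A, Value *&B) {
      return match(V, m_SMax(m_Value(A), m_Value(B)));
    }
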
+ if ((match(&SI, m_SMin(m_Value(), m_Value())) || + match(&SI, m_SMax(m_Value(), m_Value())) || + match(&SI, m_UMin(m_Value(), m_Value())) || + match(&SI, m_UMax(m_Value(), m_Value())))) + return nullptr; + // If this is a cast from the same type, merge. if (TI->getNumOperands() == 1 && TI->isCast()) { Type *FIOpndTy = FI->getOperand(0)->getType(); @@ -156,8 +167,8 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, // Fold this by inserting a select from the input values. Value *NewSI = - Builder->CreateSelect(SI.getCondition(), TI->getOperand(0), - FI->getOperand(0), SI.getName() + ".v", &SI); + Builder.CreateSelect(SI.getCondition(), TI->getOperand(0), + FI->getOperand(0), SI.getName() + ".v", &SI); return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI, TI->getType()); } @@ -200,8 +211,8 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, } // If we reach here, they do have operations in common. - Value *NewSI = Builder->CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, - SI.getName() + ".v", &SI); + Value *NewSI = Builder.CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, + SI.getName() + ".v", &SI); Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); @@ -216,8 +227,8 @@ static bool isSelect01(Constant *C1, Constant *C2) { return false; if (!C1I->isZero() && !C2I->isZero()) // One side must be zero. return false; - return C1I->isOne() || C1I->isAllOnesValue() || - C2I->isOne() || C2I->isAllOnesValue(); + return C1I->isOne() || C1I->isMinusOne() || + C2I->isOne() || C2I->isMinusOne(); } /// Try to fold the select into one of the operands to allow further @@ -243,7 +254,7 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { - Value *NewSel = Builder->CreateSelect(SI.getCondition(), OOp, C); + Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); NewSel->takeName(TVI); BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), @@ -273,7 +284,7 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { - Value *NewSel = Builder->CreateSelect(SI.getCondition(), C, OOp); + Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); NewSel->takeName(FVI); BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), @@ -292,7 +303,7 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: -/// (or (shl (and X, C1), C3), y) +/// (or (shl (and X, C1), C3), Y) /// iff: /// C1 and C2 are both powers of 2 /// where: @@ -304,21 +315,46 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// 3. 
The magnitude of C2 and C1 are flipped static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, Value *FalseVal, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy &Builder) { const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy()) + if (!IC || !SI.getType()->isIntegerTy()) return nullptr; Value *CmpLHS = IC->getOperand(0); Value *CmpRHS = IC->getOperand(1); - if (!match(CmpRHS, m_Zero())) - return nullptr; + Value *V; + unsigned C1Log; + bool IsEqualZero; + bool NeedAnd = false; + if (IC->isEquality()) { + if (!match(CmpRHS, m_Zero())) + return nullptr; + + const APInt *C1; + if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) + return nullptr; + + V = CmpLHS; + C1Log = C1->logBase2(); + IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ; + } else if (IC->getPredicate() == ICmpInst::ICMP_SLT || + IC->getPredicate() == ICmpInst::ICMP_SGT) { + // We also need to recognize (icmp slt (trunc (X)), 0) and + // (icmp sgt (trunc (X)), -1). + IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT; + if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) || + (!IsEqualZero && !match(CmpRHS, m_Zero()))) + return nullptr; + + if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V))))) + return nullptr; - Value *X; - const APInt *C1; - if (!match(CmpLHS, m_And(m_Value(X), m_Power2(C1)))) + C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1; + NeedAnd = true; + } else { return nullptr; + } const APInt *C2; bool OrOnTrueVal = false; @@ -329,26 +365,40 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, if (!OrOnFalseVal && !OrOnTrueVal) return nullptr; - Value *V = CmpLHS; Value *Y = OrOnFalseVal ? TrueVal : FalseVal; - unsigned C1Log = C1->logBase2(); unsigned C2Log = C2->logBase2(); + + bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); + bool NeedShift = C1Log != C2Log; + bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() != + V->getType()->getIntegerBitWidth(); + + // Make sure we don't create more instructions than we save. + Value *Or = OrOnFalseVal ? FalseVal : TrueVal; + if ((NeedShift + NeedXor + NeedZExtTrunc) > + (IC->hasOneUse() + Or->hasOneUse())) + return nullptr; + + if (NeedAnd) { + // Insert the AND instruction on the input to the truncate. 
+ APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log); + V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1)); + } + if (C2Log > C1Log) { - V = Builder->CreateZExtOrTrunc(V, Y->getType()); - V = Builder->CreateShl(V, C2Log - C1Log); + V = Builder.CreateZExtOrTrunc(V, Y->getType()); + V = Builder.CreateShl(V, C2Log - C1Log); } else if (C1Log > C2Log) { - V = Builder->CreateLShr(V, C1Log - C2Log); - V = Builder->CreateZExtOrTrunc(V, Y->getType()); + V = Builder.CreateLShr(V, C1Log - C2Log); + V = Builder.CreateZExtOrTrunc(V, Y->getType()); } else - V = Builder->CreateZExtOrTrunc(V, Y->getType()); + V = Builder.CreateZExtOrTrunc(V, Y->getType()); - ICmpInst::Predicate Pred = IC->getPredicate(); - if ((Pred == ICmpInst::ICMP_NE && OrOnFalseVal) || - (Pred == ICmpInst::ICMP_EQ && OrOnTrueVal)) - V = Builder->CreateXor(V, *C2); + if (NeedXor) + V = Builder.CreateXor(V, *C2); - return Builder->CreateOr(V, Y); + return Builder.CreateOr(V, Y); } /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single @@ -364,7 +414,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, /// into: /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy &Builder) { ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); @@ -395,7 +445,6 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) || match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) { IntrinsicInst *II = cast<IntrinsicInst>(Count); - IRBuilder<> Builder(II); // Explicitly clear the 'undef_on_zero' flag. IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); Type *Ty = NewI->getArgOperand(1)->getType(); @@ -500,18 +549,16 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { return true; } -/// If this is an integer min/max where the select's 'true' operand is a -/// constant, canonicalize that constant to the 'false' operand: -/// select (icmp Pred X, C), C, X --> select (icmp Pred' X, C), X, C +/// If this is an integer min/max (icmp + select) with a constant operand, +/// create the canonical icmp for the min/max operation and canonicalize the +/// constant to the 'false' operand of the select: +/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2 +/// Note: if C1 != C2, this will change the icmp constant to the existing +/// constant operand of the select. static Instruction * canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, InstCombiner::BuilderTy &Builder) { - // TODO: We should also canonicalize min/max when the select has a different - // constant value than the cmp constant, but we need to fix the backend first. - if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)) || - !isa<Constant>(Sel.getTrueValue()) || - isa<Constant>(Sel.getFalseValue()) || - Cmp.getOperand(1) != Sel.getTrueValue()) + if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; // Canonicalize the compare predicate based on whether we have min or max. @@ -526,22 +573,31 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, default: return nullptr; } - // Canonicalize the constant to the right side. - if (isa<Constant>(LHS)) - std::swap(LHS, RHS); + // Is this already canonical? 
+ if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS && + Cmp.getPredicate() == NewPred) + return nullptr; + + // Create the canonical compare and plug it into the select. + Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS)); - Value *NewCmp = Builder.CreateICmp(NewPred, LHS, RHS); - SelectInst *NewSel = SelectInst::Create(NewCmp, LHS, RHS, "", nullptr, &Sel); + // If the select operands did not change, we're done. + if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) + return &Sel; - // We swapped the select operands, so swap the metadata too. - NewSel->swapProfMetadata(); - return NewSel; + // If we are swapping the select operands, swap the metadata too. + assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS && + "Unexpected results from matchSelectPattern"); + Sel.setTrueValue(LHS); + Sel.setFalseValue(RHS); + Sel.swapProfMetadata(); + return &Sel; } /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { - if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *Builder)) + if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; bool Changed = adjustMinMax(SI, *ICI); @@ -561,23 +617,23 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (TrueVal->getType() == Ty) { if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) { ConstantInt *C1 = nullptr, *C2 = nullptr; - if (Pred == ICmpInst::ICMP_SGT && Cmp->isAllOnesValue()) { + if (Pred == ICmpInst::ICMP_SGT && Cmp->isMinusOne()) { C1 = dyn_cast<ConstantInt>(TrueVal); C2 = dyn_cast<ConstantInt>(FalseVal); - } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isNullValue()) { + } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isZero()) { C1 = dyn_cast<ConstantInt>(FalseVal); C2 = dyn_cast<ConstantInt>(TrueVal); } if (C1 && C2) { // This shift results in either -1 or 0. - Value *AShr = Builder->CreateAShr(CmpLHS, Ty->getBitWidth()-1); + Value *AShr = Builder.CreateAShr(CmpLHS, Ty->getBitWidth() - 1); // Check if we can express the operation with a single or. - if (C2->isAllOnesValue()) - return replaceInstUsesWith(SI, Builder->CreateOr(AShr, C1)); + if (C2->isMinusOne()) + return replaceInstUsesWith(SI, Builder.CreateOr(AShr, C1)); - Value *And = Builder->CreateAnd(AShr, C2->getValue()-C1->getValue()); - return replaceInstUsesWith(SI, Builder->CreateAdd(And, C1)); + Value *And = Builder.CreateAnd(AShr, C2->getValue() - C1->getValue()); + return replaceInstUsesWith(SI, Builder.CreateAdd(And, C1)); } } } @@ -602,7 +658,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, { unsigned BitWidth = DL.getTypeSizeInBits(TrueVal->getType()->getScalarType()); - APInt MinSignedValue = APInt::getSignBit(BitWidth); + APInt MinSignedValue = APInt::getSignedMinValue(BitWidth); Value *X; const APInt *Y, *C; bool TrueWhenUnset; @@ -628,19 +684,19 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, // (X & Y) == 0 ? X : X ^ Y --> X & ~Y if (TrueWhenUnset && TrueVal == X && match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) - V = Builder->CreateAnd(X, ~(*Y)); + V = Builder.CreateAnd(X, ~(*Y)); // (X & Y) != 0 ? X ^ Y : X --> X & ~Y else if (!TrueWhenUnset && FalseVal == X && match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) - V = Builder->CreateAnd(X, ~(*Y)); + V = Builder.CreateAnd(X, ~(*Y)); // (X & Y) == 0 ? 
X ^ Y : X --> X | Y else if (TrueWhenUnset && FalseVal == X && match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) - V = Builder->CreateOr(X, *Y); + V = Builder.CreateOr(X, *Y); // (X & Y) != 0 ? X : X ^ Y --> X | Y else if (!TrueWhenUnset && TrueVal == X && match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) - V = Builder->CreateOr(X, *Y); + V = Builder.CreateOr(X, *Y); if (V) return replaceInstUsesWith(SI, V); @@ -753,8 +809,8 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, (SPF1 == SPF_NABS && SPF2 == SPF_ABS)) { SelectInst *SI = cast<SelectInst>(Inner); Value *NewSI = - Builder->CreateSelect(SI->getCondition(), SI->getFalseValue(), - SI->getTrueValue(), SI->getName(), SI); + Builder.CreateSelect(SI->getCondition(), SI->getFalseValue(), + SI->getTrueValue(), SI->getName(), SI); return replaceInstUsesWith(Outer, NewSI); } @@ -786,19 +842,21 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, // This transform is performance neutral if we can elide at least one xor from // the set of three operands, since we'll be tacking on an xor at the very // end. - if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && + if (SelectPatternResult::isMinOrMax(SPF1) && + SelectPatternResult::isMinOrMax(SPF2) && + IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && IsFreeOrProfitableToInvert(B, NotB, ElidesXor) && IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) { if (!NotA) - NotA = Builder->CreateNot(A); + NotA = Builder.CreateNot(A); if (!NotB) - NotB = Builder->CreateNot(B); + NotB = Builder.CreateNot(B); if (!NotC) - NotC = Builder->CreateNot(C); + NotC = Builder.CreateNot(C); Value *NewInner = generateMinMaxSelectPattern( Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); - Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern( + Value *NewOuter = Builder.CreateNot(generateMinMaxSelectPattern( Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); return replaceInstUsesWith(Outer, NewOuter); } @@ -810,9 +868,9 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, /// icmp instruction with zero, and we have an 'and' with the non-constant value /// and a power of two we can turn the select into a shift on the result of the /// 'and'. -static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal, - ConstantInt *FalseVal, - InstCombiner::BuilderTy *Builder) { +static Value *foldSelectICmpAnd(const SelectInst &SI, APInt TrueVal, + APInt FalseVal, + InstCombiner::BuilderTy &Builder) { const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy()) return nullptr; @@ -828,56 +886,53 @@ static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal, // If both select arms are non-zero see if we have a select of the form // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic // for 'x ? 2^n : 0' and fix the thing up at the end. - ConstantInt *Offset = nullptr; - if (!TrueVal->isZero() && !FalseVal->isZero()) { - if ((TrueVal->getValue() - FalseVal->getValue()).isPowerOf2()) + APInt Offset(TrueVal.getBitWidth(), 0); + if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { + if ((TrueVal - FalseVal).isPowerOf2()) Offset = FalseVal; - else if ((FalseVal->getValue() - TrueVal->getValue()).isPowerOf2()) + else if ((FalseVal - TrueVal).isPowerOf2()) Offset = TrueVal; else return nullptr; // Adjust TrueVal and FalseVal to the offset. 
- TrueVal = ConstantInt::get(Builder->getContext(), - TrueVal->getValue() - Offset->getValue()); - FalseVal = ConstantInt::get(Builder->getContext(), - FalseVal->getValue() - Offset->getValue()); + TrueVal -= Offset; + FalseVal -= Offset; } // Make sure the mask in the 'and' and one of the select arms is a power of 2. if (!AndRHS->getValue().isPowerOf2() || - (!TrueVal->getValue().isPowerOf2() && - !FalseVal->getValue().isPowerOf2())) + (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2())) return nullptr; // Determine which shift is needed to transform result of the 'and' into the // desired result. - ConstantInt *ValC = !TrueVal->isZero() ? TrueVal : FalseVal; - unsigned ValZeros = ValC->getValue().logBase2(); + const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndRHS->getValue().logBase2(); // If types don't match we can still convert the select by introducing a zext // or a trunc of the 'and'. The trunc case requires that all of the truncated // bits are zero, we can figure that out by looking at the 'and' mask. - if (AndZeros >= ValC->getBitWidth()) + if (AndZeros >= ValC.getBitWidth()) return nullptr; - Value *V = Builder->CreateZExtOrTrunc(LHS, SI.getType()); + Value *V = Builder.CreateZExtOrTrunc(LHS, SI.getType()); if (ValZeros > AndZeros) - V = Builder->CreateShl(V, ValZeros - AndZeros); + V = Builder.CreateShl(V, ValZeros - AndZeros); else if (ValZeros < AndZeros) - V = Builder->CreateLShr(V, AndZeros - ValZeros); + V = Builder.CreateLShr(V, AndZeros - ValZeros); // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal->isZero(); + bool ShouldNotVal = !TrueVal.isNullValue(); ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE; if (ShouldNotVal) - V = Builder->CreateXor(V, ValC); + V = Builder.CreateXor(V, ValC); // Apply an offset if needed. - if (Offset) - V = Builder->CreateAdd(V, Offset); + if (!Offset.isNullValue()) + V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); return V; } @@ -966,7 +1021,7 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { // TODO: Handle larger types? That requires adjusting FoldOpIntoSelect too. Value *X = ExtInst->getOperand(0); Type *SmallType = X->getType(); - if (!SmallType->getScalarType()->isIntegerTy(1)) + if (!SmallType->isIntOrIntVectorTy(1)) return nullptr; Constant *C; @@ -987,7 +1042,7 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { // select Cond, (ext X), C --> ext(select Cond, X, C') // select Cond, C, (ext X) --> ext(select Cond, C', X) - Value *NewSel = Builder->CreateSelect(Cond, X, TruncCVal, "narrow", &Sel); + Value *NewSel = Builder.CreateSelect(Cond, X, TruncCVal, "narrow", &Sel); return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType); } @@ -1035,8 +1090,10 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { // If the select condition element is false, choose from the 2nd vector. Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts)); } else if (isa<UndefValue>(Elt)) { - // If the select condition element is undef, the shuffle mask is undef. - Mask.push_back(UndefValue::get(Int32Ty)); + // Undef in a select condition (choose one of the operands) does not mean + // the same thing as undef in a shuffle mask (any value is acceptable), so + // give up. + return nullptr; } else { // Bail out on a constant expression. 
return nullptr; @@ -1100,14 +1157,31 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *FalseVal = SI.getFalseValue(); Type *SelType = SI.getType(); - if (Value *V = - SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, + SQ.getWithInstruction(&SI))) return replaceInstUsesWith(SI, V); if (Instruction *I = canonicalizeSelectToShuffle(SI)) return I; - if (SelType->getScalarType()->isIntegerTy(1) && + // Canonicalize a one-use integer compare with a non-canonical predicate by + // inverting the predicate and swapping the select operands. This matches a + // compare canonicalization for conditional branches. + // TODO: Should we do the same for FP compares? + CmpInst::Predicate Pred; + if (match(CondVal, m_OneUse(m_ICmp(Pred, m_Value(), m_Value()))) && + !isCanonicalPredicate(Pred)) { + // Swap true/false values and condition. + CmpInst *Cond = cast<CmpInst>(CondVal); + Cond->setPredicate(CmpInst::getInversePredicate(Pred)); + SI.setOperand(1, FalseVal); + SI.setOperand(2, TrueVal); + SI.swapProfMetadata(); + Worklist.Add(Cond); + return &SI; + } + + if (SelType->isIntOrIntVectorTy(1) && TrueVal->getType() == CondVal->getType()) { if (match(TrueVal, m_One())) { // Change: A = select B, true, C --> A = or B, C @@ -1115,7 +1189,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } if (match(TrueVal, m_Zero())) { // Change: A = select B, false, C --> A = and !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); return BinaryOperator::CreateAnd(NotCond, FalseVal); } if (match(FalseVal, m_Zero())) { @@ -1124,7 +1198,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } if (match(FalseVal, m_One())) { // Change: A = select B, C, true --> A = or !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); return BinaryOperator::CreateOr(NotCond, TrueVal); } @@ -1149,7 +1223,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> // because that may need 3 instructions to splat the condition value: // extend, insertelement, shufflevector. - if (CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { + if (SelType->isIntOrIntVectorTy() && + CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { // select C, 1, 0 -> zext C to int if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) return new ZExtInst(CondVal, SelType); @@ -1160,20 +1235,21 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // select C, 0, 1 -> zext !C to int if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) { - Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); return new ZExtInst(NotCond, SelType); } // select C, 0, -1 -> sext !C to int if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) { - Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + Value *NotCond = Builder.CreateNot(CondVal, "not." 
+ CondVal->getName()); return new SExtInst(NotCond, SelType); } } if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) - if (Value *V = foldSelectICmpAnd(SI, TrueValC, FalseValC, Builder)) + if (Value *V = foldSelectICmpAnd(SI, TrueValC->getValue(), + FalseValC->getValue(), Builder)) return replaceInstUsesWith(SI, V); // See if we are selecting two values based on a comparison of the two values. @@ -1211,10 +1287,10 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { FCmpInst::Predicate InvPred = FCI->getInversePredicate(); - IRBuilder<>::FastMathFlagGuard FMFG(*Builder); - Builder->setFastMathFlags(FCI->getFastMathFlags()); - Value *NewCond = Builder->CreateFCmp(InvPred, TrueVal, FalseVal, - FCI->getName() + ".inv"); + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + Builder.setFastMathFlags(FCI->getFastMathFlags()); + Value *NewCond = Builder.CreateFCmp(InvPred, TrueVal, FalseVal, + FCI->getName() + ".inv"); return SelectInst::Create(NewCond, FalseVal, TrueVal, SI.getName() + ".p"); @@ -1254,10 +1330,10 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { FCmpInst::Predicate InvPred = FCI->getInversePredicate(); - IRBuilder<>::FastMathFlagGuard FMFG(*Builder); - Builder->setFastMathFlags(FCI->getFastMathFlags()); - Value *NewCond = Builder->CreateFCmp(InvPred, FalseVal, TrueVal, - FCI->getName() + ".inv"); + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + Builder.setFastMathFlags(FCI->getFastMathFlags()); + Value *NewCond = Builder.CreateFCmp(InvPred, FalseVal, TrueVal, + FCI->getName() + ".inv"); return SelectInst::Create(NewCond, FalseVal, TrueVal, SI.getName() + ".p"); @@ -1273,7 +1349,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Result = foldSelectInstWithICmp(SI, ICI)) return Result; - if (Instruction *Add = foldAddSubSelect(SI, *Builder)) + if (Instruction *Add = foldAddSubSelect(SI, Builder)) return Add; // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) @@ -1304,16 +1380,16 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *Cmp; if (CmpInst::isIntPredicate(Pred)) { - Cmp = Builder->CreateICmp(Pred, LHS, RHS); + Cmp = Builder.CreateICmp(Pred, LHS, RHS); } else { - IRBuilder<>::FastMathFlagGuard FMFG(*Builder); + IRBuilder<>::FastMathFlagGuard FMFG(Builder); auto FMF = cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); - Builder->setFastMathFlags(FMF); - Cmp = Builder->CreateFCmp(Pred, LHS, RHS); + Builder.setFastMathFlags(FMF); + Cmp = Builder.CreateFCmp(Pred, LHS, RHS); } - Value *NewSI = Builder->CreateCast( - CastOp, Builder->CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), + Value *NewSI = Builder.CreateCast( + CastOp, Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), SelType); return replaceInstUsesWith(SI, NewSI); } @@ -1348,13 +1424,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); if (NumberOfNots >= 2) { - Value *NewLHS = Builder->CreateNot(LHS); - Value *NewRHS = Builder->CreateNot(RHS); - Value *NewCmp = SPF == SPF_SMAX - ? 
Builder->CreateICmpSLT(NewLHS, NewRHS) - : Builder->CreateICmpULT(NewLHS, NewRHS); + Value *NewLHS = Builder.CreateNot(LHS); + Value *NewRHS = Builder.CreateNot(RHS); + Value *NewCmp = SPF == SPF_SMAX ? Builder.CreateICmpSLT(NewLHS, NewRHS) + : Builder.CreateICmpULT(NewLHS, NewRHS); Value *NewSI = - Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); + Builder.CreateNot(Builder.CreateSelect(NewCmp, NewLHS, NewRHS)); return replaceInstUsesWith(SI, NewSI); } } @@ -1364,11 +1439,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } // See if we can fold the select into a phi node if the condition is a select. - if (isa<PHINode>(SI.getCondition())) + if (auto *PN = dyn_cast<PHINode>(SI.getCondition())) // The true/false values have to be live in the PHI predecessor's blocks. if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && canSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) - if (Instruction *NV = FoldOpIntoPhi(SI)) + if (Instruction *NV = foldOpIntoPhi(SI, PN)) return NV; if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { @@ -1384,7 +1459,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // We choose this as normal form to enable folding on the And and shortening // paths for the values (this helps GetUnderlyingObjects() for example). if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { - Value *And = Builder->CreateAnd(CondVal, TrueSI->getCondition()); + Value *And = Builder.CreateAnd(CondVal, TrueSI->getCondition()); SI.setOperand(0, And); SI.setOperand(1, TrueSI->getTrueValue()); return &SI; @@ -1402,7 +1477,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { - Value *Or = Builder->CreateOr(CondVal, FalseSI->getCondition()); + Value *Or = Builder.CreateOr(CondVal, FalseSI->getCondition()); SI.setOperand(0, Or); SI.setOperand(2, FalseSI->getFalseValue()); return &SI; @@ -1450,7 +1525,21 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder)) + // If we can compute the condition, there's no need for a select. + // Like the above fold, we are attempting to reduce compile-time cost by + // putting this fold here with limitations rather than in InstSimplify. + // The motivation for this call into value tracking is to take advantage of + // the assumption cache, so make sure that is populated. 
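The select(C0, select(C1, a, b), b) --> select(C0&C1, a, b) and select(C0, a, select(C1, a, b)) --> select(C0|C1, a, b) rewrites above are plain boolean identities. A standalone sketch (not part of this commit) that checks them exhaustively:

#include <cassert>

static int sel(bool C, int A, int B) { return C ? A : B; }

int main() {
  for (int C0 = 0; C0 < 2; ++C0)
    for (int C1 = 0; C1 < 2; ++C1) {
      int A = 1, B = 2;
      assert(sel(C0, sel(C1, A, B), B) == sel(C0 && C1, A, B));  // and form
      assert(sel(C0, A, sel(C1, A, B)) == sel(C0 || C1, A, B));  // or form
    }
}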
+ if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) { + KnownBits Known(1); + computeKnownBits(CondVal, Known, 0, &SI); + if (Known.One.isOneValue()) + return replaceInstUsesWith(SI, TrueVal); + if (Known.Zero.isOneValue()) + return replaceInstUsesWith(SI, FalseVal); + } + + if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder)) return BitCastSel; return nullptr; diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 4ff9b64..7ed141c 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -22,8 +22,8 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { - assert(I.getOperand(1)->getType() == I.getOperand(0)->getType()); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + assert(Op0->getType() == Op1->getType()); // See if we can fold away this shift. if (SimplifyDemandedInstructionBits(I)) @@ -44,9 +44,10 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { Value *A; Constant *C; if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C)))) - if (isKnownNonNegative(A, DL) && isKnownNonNegative(C, DL)) + if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) && + isKnownNonNegative(C, DL, 0, &AC, &I, &DT)) return BinaryOperator::Create( - I.getOpcode(), Builder->CreateBinOp(I.getOpcode(), Op0, C), A); + I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A); // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2. // Because shifts by negative values (which could occur if A were negative) @@ -55,8 +56,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) { // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't // demand the sign bit (and many others) here?? - Value *Rem = Builder->CreateAnd(A, ConstantInt::get(I.getType(), *B-1), - Op1->getName()); + Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1), + Op1->getName()); I.setOperand(1, Rem); return &I; } @@ -65,63 +66,60 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { } /// Return true if we can simplify two logical (either left or right) shifts -/// that have constant shift amounts. -static bool canEvaluateShiftedShift(unsigned FirstShiftAmt, - bool IsFirstShiftLeft, - Instruction *SecondShift, InstCombiner &IC, +/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2. +static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, + Instruction *InnerShift, InstCombiner &IC, Instruction *CxtI) { - assert(SecondShift->isLogicalShift() && "Unexpected instruction type"); + assert(InnerShift->isLogicalShift() && "Unexpected instruction type"); - // We need constant shifts. - auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1)); - if (!SecondShiftConst) + // We need constant scalar or constant splat shifts. + const APInt *InnerShiftConst; + if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst))) return false; - unsigned SecondShiftAmt = SecondShiftConst->getZExtValue(); - bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl; - - // We can always fold shl(c1) + shl(c2) -> shl(c1+c2). - // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2). 
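A standalone sketch (not part of this commit) of the same-direction rule described in the comments above: two logical shifts by constants fold into one shift by the sum, and a combined amount of the full bit width or more leaves zero:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0xDEADBEEF;
  assert(((X << 5) << 7) == (X << 12));   // shl(c1) + shl(c2) -> shl(c1+c2)
  assert(((X >> 5) >> 7) == (X >> 12));   // lshr(c1) + lshr(c2) -> lshr(c1+c2)
  // Oversized composite shifts: every bit is shifted out, so the result is
  // zero. Each individual shift here stays below 32 to avoid C++ UB.
  assert(((X << 20) << 12) == 0u);
  assert(((X >> 20) >> 12) == 0u);
}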
- if (IsFirstShiftLeft == IsSecondShiftLeft) + // Two logical shifts in the same direction: + // shl (shl X, C1), C2 --> shl X, C1 + C2 + // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 + bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; + if (IsInnerShl == IsOuterShl) return true; - // We can always fold lshr(c) + shl(c) -> and(c2). - // We can always fold shl(c) + lshr(c) -> and(c2). - if (FirstShiftAmt == SecondShiftAmt) + // Equal shift amounts in opposite directions become bitwise 'and': + // lshr (shl X, C), C --> and X, C' + // shl (lshr X, C), C --> and X, C' + unsigned InnerShAmt = InnerShiftConst->getZExtValue(); + if (InnerShAmt == OuterShAmt) return true; - unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits(); - // If the 2nd shift is bigger than the 1st, we can fold: - // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or - // shl(c1) + lshr(c2) -> lshr(c3) + and(c4), + // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3 + // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3 // but it isn't profitable unless we know the and'd out bits are already zero. - // Also check that the 2nd shift is valid (less than the type width) or we'll - // crash trying to produce the bit mask for the 'and'. - if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) { - unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt - : SecondShiftAmt - FirstShiftAmt; - APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift; - if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI)) + // Also, check that the inner shift is valid (less than the type width) or + // we'll crash trying to produce the bit mask for the 'and'. + unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); + if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) { + unsigned MaskShift = + IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; + APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; + if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI)) return true; } return false; } -/// See if we can compute the specified value, but shifted -/// logically to the left or right by some number of bits. This should return -/// true if the expression can be computed for the same cost as the current -/// expression tree. This is used to eliminate extraneous shifting from things -/// like: +/// See if we can compute the specified value, but shifted logically to the left +/// or right by some number of bits. This should return true if the expression +/// can be computed for the same cost as the current expression tree. This is +/// used to eliminate extraneous shifting from things like: /// %C = shl i128 %A, 64 /// %D = shl i128 %B, 96 /// %E = or i128 %C, %D /// %F = lshr i128 %E, 64 -/// where the client will ask if E can be computed shifted right by 64-bits. If -/// this succeeds, the GetShiftedValue function will be called to produce the -/// value. -static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, +/// where the client will ask if E can be computed shifted right by 64-bits. If +/// this succeeds, getShiftedValue() will be called to produce the value. +static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, InstCombiner &IC, Instruction *CxtI) { // We can always evaluate constants shifted. 
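The i128 example in the doc comment above scales down to 64 bits; a standalone sketch (not part of this commit) showing the final lshr being folded into the operands of the 'or' instead of being emitted:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t A = 0x1122334455667788, B = 0x99AABBCCDDEEFF00;
  uint64_t C = A << 32;
  uint64_t D = B << 48;
  uint64_t E = C | D;
  uint64_t F = E >> 32;                       // the shift we want to remove
  // Shift each input through the expression instead:
  uint64_t G = (A & 0xFFFFFFFF) | ((B & 0xFFFF) << 16);
  assert(F == G);
}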
if (isa<Constant>(V)) @@ -165,8 +163,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, case Instruction::Or: case Instruction::Xor: // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. - return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && - CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); + return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && + canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); case Instruction::Shl: case Instruction::LShr: @@ -176,8 +174,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, SelectInst *SI = cast<SelectInst>(I); Value *TrueVal = SI->getTrueValue(); Value *FalseVal = SI->getFalseValue(); - return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && - CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); + return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && + canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -185,23 +183,86 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (Value *IncValue : PN->incoming_values()) - if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) + if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) return false; return true; } } } -/// When CanEvaluateShifted returned true for an expression, -/// this value inserts the new computation that produces the shifted value. -static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, +/// Fold OuterShift (InnerShift X, C1), C2. +/// See canEvaluateShiftedShift() for the constraints on these instructions. +static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt, + bool IsOuterShl, + InstCombiner::BuilderTy &Builder) { + bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; + Type *ShType = InnerShift->getType(); + unsigned TypeWidth = ShType->getScalarSizeInBits(); + + // We only accept shifts-by-a-constant in canEvaluateShifted(). + const APInt *C1; + match(InnerShift->getOperand(1), m_APInt(C1)); + unsigned InnerShAmt = C1->getZExtValue(); + + // Change the shift amount and clear the appropriate IR flags. + auto NewInnerShift = [&](unsigned ShAmt) { + InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt)); + if (IsInnerShl) { + InnerShift->setHasNoUnsignedWrap(false); + InnerShift->setHasNoSignedWrap(false); + } else { + InnerShift->setIsExact(false); + } + return InnerShift; + }; + + // Two logical shifts in the same direction: + // shl (shl X, C1), C2 --> shl X, C1 + C2 + // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 + if (IsInnerShl == IsOuterShl) { + // If this is an oversized composite shift, then unsigned shifts get 0. + if (InnerShAmt + OuterShAmt >= TypeWidth) + return Constant::getNullValue(ShType); + + return NewInnerShift(InnerShAmt + OuterShAmt); + } + + // Equal shift amounts in opposite directions become bitwise 'and': + // lshr (shl X, C), C --> and X, C' + // shl (lshr X, C), C --> and X, C' + if (InnerShAmt == OuterShAmt) { + APInt Mask = IsInnerShl + ? 
APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt) + : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt); + Value *And = Builder.CreateAnd(InnerShift->getOperand(0), + ConstantInt::get(ShType, Mask)); + if (auto *AndI = dyn_cast<Instruction>(And)) { + AndI->moveBefore(InnerShift); + AndI->takeName(InnerShift); + } + return And; + } + + assert(InnerShAmt > OuterShAmt && + "Unexpected opposite direction logical shift pair"); + + // In general, we would need an 'and' for this transform, but + // canEvaluateShiftedShift() guarantees that the masked-off bits are not used. + // lshr (shl X, C1), C2 --> shl X, C1 - C2 + // shl (lshr X, C1), C2 --> lshr X, C1 - C2 + return NewInnerShift(InnerShAmt - OuterShAmt); +} + +/// When canEvaluateShifted() returns true for an expression, this function +/// inserts the new computation that produces the shifted value. +static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, InstCombiner &IC, const DataLayout &DL) { // We can always evaluate constants shifted. if (Constant *C = dyn_cast<Constant>(V)) { if (isLeftShift) - V = IC.Builder->CreateShl(C, NumBits); + V = IC.Builder.CreateShl(C, NumBits); else - V = IC.Builder->CreateLShr(C, NumBits); + V = IC.Builder.CreateLShr(C, NumBits); // If we got a constantexpr back, try to simplify it with TD info. if (auto *C = dyn_cast<Constant>(V)) if (auto *FoldedC = @@ -220,100 +281,21 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Xor: // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. I->setOperand( - 0, GetShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL)); + 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL)); I->setOperand( - 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); + 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); return I; - case Instruction::Shl: { - BinaryOperator *BO = cast<BinaryOperator>(I); - unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); - - // We only accept shifts-by-a-constant in CanEvaluateShifted. - ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); - - // We can always fold shl(c1)+shl(c2) -> shl(c1+c2). - if (isLeftShift) { - // If this is oversized composite shift, then unsigned shifts get 0. - unsigned NewShAmt = NumBits+CI->getZExtValue(); - if (NewShAmt >= TypeWidth) - return Constant::getNullValue(I->getType()); - - BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt)); - BO->setHasNoUnsignedWrap(false); - BO->setHasNoSignedWrap(false); - return I; - } - - // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have - // zeros. - if (CI->getValue() == NumBits) { - APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits)); - V = IC.Builder->CreateAnd(BO->getOperand(0), - ConstantInt::get(BO->getContext(), Mask)); - if (Instruction *VI = dyn_cast<Instruction>(V)) { - VI->moveBefore(BO); - VI->takeName(BO); - } - return V; - } - - // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that - // the and won't be needed. - assert(CI->getZExtValue() > NumBits); - BO->setOperand(1, ConstantInt::get(BO->getType(), - CI->getZExtValue() - NumBits)); - BO->setHasNoUnsignedWrap(false); - BO->setHasNoSignedWrap(false); - return BO; - } - // FIXME: This is almost identical to the SHL case. Refactor both cases into - // a helper function. 
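A standalone sketch (not part of this commit) of the equal-amount, opposite-direction cases that foldShiftedShift above turns into a single mask:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0xDEADBEEF;
  for (unsigned C = 1; C < 32; ++C) {
    assert(((X << C) >> C) == (X & (~0u >> C)));  // lshr (shl X, C), C --> and
    assert(((X >> C) << C) == (X & (~0u << C)));  // shl (lshr X, C), C --> and
  }
}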
- case Instruction::LShr: { - BinaryOperator *BO = cast<BinaryOperator>(I); - unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); - // We only accept shifts-by-a-constant in CanEvaluateShifted. - ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); - - // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2). - if (!isLeftShift) { - // If this is oversized composite shift, then unsigned shifts get 0. - unsigned NewShAmt = NumBits+CI->getZExtValue(); - if (NewShAmt >= TypeWidth) - return Constant::getNullValue(BO->getType()); - - BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt)); - BO->setIsExact(false); - return I; - } - - // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have - // zeros. - if (CI->getValue() == NumBits) { - APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits)); - V = IC.Builder->CreateAnd(I->getOperand(0), - ConstantInt::get(BO->getContext(), Mask)); - if (Instruction *VI = dyn_cast<Instruction>(V)) { - VI->moveBefore(I); - VI->takeName(I); - } - return V; - } - - // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that - // the and won't be needed. - assert(CI->getZExtValue() > NumBits); - BO->setOperand(1, ConstantInt::get(BO->getType(), - CI->getZExtValue() - NumBits)); - BO->setIsExact(false); - return BO; - } + case Instruction::Shl: + case Instruction::LShr: + return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift, + IC.Builder); case Instruction::Select: I->setOperand( - 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); + 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); I->setOperand( - 2, GetShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); + 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); return I; case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -321,215 +303,39 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - PN->setIncomingValue(i, GetShiftedValue(PN->getIncomingValue(i), NumBits, + PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits, isLeftShift, IC, DL)); return PN; } } } -/// Try to fold (X << C1) << C2, where the shifts are some combination of -/// shl/ashr/lshr. -static Instruction * -foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1, - InstCombiner::BuilderTy *Builder) { - Value *Op0 = I.getOperand(0); - uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); - - // Find out if this is a shift of a shift by a constant. - BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); - if (ShiftOp && !ShiftOp->isShift()) - ShiftOp = nullptr; - - if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { - - // This is a constant shift of a constant shift. Be careful about hiding - // shl instructions behind bit masks. They are used to represent multiplies - // by a constant, and it is important that simple arithmetic expressions - // are still recognizable by scalar evolution. - // - // The transforms applied to shl are very similar to the transforms applied - // to mul by constant. We can be more aggressive about optimizing right - // shifts. - // - // Combinations of right and left shifts will still be optimized in - // DAGCombine where scalar evolution no longer applies. 
- - ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); - uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); - uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); - assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); - if (ShiftAmt1 == 0) - return nullptr; // Will be simplified in the future. - Value *X = ShiftOp->getOperand(0); - - IntegerType *Ty = cast<IntegerType>(I.getType()); - - // Check for (X << c1) << c2 and (X >> c1) >> c2 - if (I.getOpcode() == ShiftOp->getOpcode()) { - uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift. - // If this is an oversized composite shift, then unsigned shifts become - // zero (handled in InstSimplify) and ashr saturates. - if (AmtSum >= TypeBits) { - if (I.getOpcode() != Instruction::AShr) - return nullptr; - AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr. - } - - return BinaryOperator::Create(I.getOpcode(), X, - ConstantInt::get(Ty, AmtSum)); - } - - if (ShiftAmt1 == ShiftAmt2) { - // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); - return BinaryOperator::CreateAnd( - X, ConstantInt::get(I.getContext(), Mask)); - } - } else if (ShiftAmt1 < ShiftAmt2) { - uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1; - - // (X >>?,exact C1) << C2 --> X << (C2-C1) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { - assert(ShiftOp->getOpcode() == Instruction::LShr || - ShiftOp->getOpcode() == Instruction::AShr); - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = - BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); - return NewShl; - } - - // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) - if (ShiftOp->hasNoUnsignedWrap()) { - BinaryOperator *NewLShr = - BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst); - NewLShr->setIsExact(I.isExact()); - return NewLShr; - } - Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd( - Shift, ConstantInt::get(I.getContext(), Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X >>s (C2-C1) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewAShr = - BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst); - NewAShr->setIsExact(I.isExact()); - return NewAShr; - } - } - } else { - assert(ShiftAmt2 < ShiftAmt1); - uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2; - - // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. 
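The 'exact' requirement in the comment above is what lets the shr/shl pair collapse into one shift: exact means no set bits were shifted out, i.e. the low C1 bits of X are zero. A standalone sketch (illustrative values, not part of this commit):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0xBEEF0000;  // low 16 bits are zero, so >> by up to 16 is exact
  assert(((X >> 8) << 12) == (X << 4));   // C2 > C1: net left shift by C2 - C1
  assert(((X >> 8) << 4) == (X >> 4));    // C2 < C1: net right shift by C1 - C2
}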
- if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShr = - BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst); - NewShr->setIsExact(true); - return NewShr; - } - - // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - if (ShiftOp->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2) - BinaryOperator *NewShl = - BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(true); - return NewShl; - } - Value *Shift = Builder->CreateShl(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd( - Shift, ConstantInt::get(I.getContext(), Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = - BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); - NewShl->setHasNoSignedWrap(true); - return NewShl; - } - } - } - } - - return nullptr; -} - Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { bool isLeftShift = I.getOpcode() == Instruction::Shl; - ConstantInt *COp1 = nullptr; - if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1)) - COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); - else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1)) - COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); - else - COp1 = dyn_cast<ConstantInt>(Op1); - - if (!COp1) + const APInt *Op1C; + if (!match(Op1, m_APInt(Op1C))) return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && - CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { + canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); return replaceInstUsesWith( - I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); + I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); } // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. 
- uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); + unsigned TypeBits = Op0->getType()->getScalarSizeInBits(); - assert(!COp1->uge(TypeBits) && + assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); - // ((X*C1) << C2) == (X * (C1 << C2)) - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0)) - if (BO->getOpcode() == Instruction::Mul && isLeftShift) - if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) - return BinaryOperator::CreateMul(BO->getOperand(0), - ConstantExpr::getShl(BOOp, Op1)); - if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I)) return FoldedShift; @@ -544,9 +350,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, if (TrOp && I.isLogicalShift() && TrOp->isShift() && isa<ConstantInt>(TrOp->getOperand(1))) { // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType()); + Constant *ShAmt = + ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType()); // (shift2 (shift1 & 0x00FF), c2) - Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); + Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName()); // For logical shifts, the truncation has the effect of making the high // part of the register be zeros. Emulate this by inserting an AND to @@ -561,16 +368,16 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // shift. We know that it is a logical shift by a constant, so adjust the // mask as appropriate. if (I.getOpcode() == Instruction::Shl) - MaskV <<= COp1->getZExtValue(); + MaskV <<= Op1C->getZExtValue(); else { assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); - MaskV = MaskV.lshr(COp1->getZExtValue()); + MaskV.lshrInPlace(Op1C->getZExtValue()); } // shift1 & 0x00FF - Value *And = Builder->CreateAnd(NSh, - ConstantInt::get(I.getContext(), MaskV), - TI->getName()); + Value *And = Builder.CreateAnd(NSh, + ConstantInt::get(I.getContext(), MaskV), + TI->getName()); // Return the value truncated to the interesting size. 
return new TruncInst(And, I.getType()); @@ -594,11 +401,11 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, match(Op0BO->getOperand(1), m_Shr(m_Value(V1), m_Specific(Op1)))) { Value *YS = // (Y << C) - Builder->CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); + Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); // (X + (Y << C)) - Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1, - Op0BO->getOperand(1)->getName()); - uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1, + Op0BO->getOperand(1)->getName()); + unsigned Op1Val = Op1C->getLimitedValue(TypeBits); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -614,11 +421,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), m_ConstantInt(CC)))) { Value *YS = // (Y << C) - Builder->CreateShl(Op0BO->getOperand(0), Op1, - Op0BO->getName()); + Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); // X & (CC << C) - Value *XM = Builder->CreateAnd(V1, ConstantExpr::getShl(CC, Op1), - V1->getName()+".mask"); + Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1), + V1->getName()+".mask"); return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); } LLVM_FALLTHROUGH; @@ -630,11 +436,11 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, match(Op0BO->getOperand(0), m_Shr(m_Value(V1), m_Specific(Op1)))) { Value *YS = // (Y << C) - Builder->CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); + Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); // (X + (Y << C)) - Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS, - Op0BO->getOperand(0)->getName()); - uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS, + Op0BO->getOperand(0)->getName()); + unsigned Op1Val = Op1C->getLimitedValue(TypeBits); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -649,10 +455,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))), m_ConstantInt(CC))) && V2 == Op1) { Value *YS = // (Y << C) - Builder->CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); + Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); // X & (CC << C) - Value *XM = Builder->CreateAnd(V1, ConstantExpr::getShl(CC, Op1), - V1->getName()+".mask"); + Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1), + V1->getName()+".mask"); return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); } @@ -695,7 +501,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, Constant *NewRHS = ConstantExpr::get(I.getOpcode(), Op0C, Op1); Value *NewShift = - Builder->CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); + Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); NewShift->takeName(Op0BO); return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, @@ -705,9 +511,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } } - if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder)) - return Folded; - return nullptr; } @@ -715,59 +518,97 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); + Value *Op0 = 
I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = - SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) + SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) return V; - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) { - unsigned ShAmt = Op1C->getZExtValue(); - - // Turn: - // %zext = zext i32 %V to i64 - // %res = shl i64 %V, 8 - // - // Into: - // %shl = shl i32 %V, 8 - // %res = zext i32 %shl to i64 - // - // This is only valid if %V would have zeros shifted out. - if (auto *ZI = dyn_cast<ZExtInst>(I.getOperand(0))) { - unsigned SrcBitWidth = ZI->getSrcTy()->getScalarSizeInBits(); - if (ShAmt < SrcBitWidth && - MaskedValueIsZero(ZI->getOperand(0), - APInt::getHighBitsSet(SrcBitWidth, ShAmt), 0, &I)) { - auto *Shl = Builder->CreateShl(ZI->getOperand(0), ShAmt); - return new ZExtInst(Shl, I.getType()); + const APInt *ShAmtAPInt; + if (match(Op1, m_APInt(ShAmtAPInt))) { + unsigned ShAmt = ShAmtAPInt->getZExtValue(); + unsigned BitWidth = I.getType()->getScalarSizeInBits(); + Type *Ty = I.getType(); + + // shl (zext X), ShAmt --> zext (shl X, ShAmt) + // This is only valid if X would have zeros shifted out. + Value *X; + if (match(Op0, m_ZExt(m_Value(X)))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + if (ShAmt < SrcWidth && + MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) + return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty); + } + + // (X >>u C) << C --> X & (-1 << C) + if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) { + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); + } + + // Be careful about hiding shl instructions behind bit masks. They are used + // to represent multiplies by a constant, and it is important that simple + // arithmetic expressions are still recognizable by scalar evolution. + // The inexact versions are deferred to DAGCombine, so we don't hide shl + // behind a bit mask. + const APInt *ShOp1; + if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) { + unsigned ShrAmt = ShOp1->getZExtValue(); + if (ShrAmt < ShAmt) { + // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + return NewShl; + } + if (ShrAmt > ShAmt) { + // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + auto *NewShr = BinaryOperator::Create( + cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); + NewShr->setIsExact(true); + return NewShr; } } + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { + unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Oversized shifts are simplified to zero in InstSimplify. + if (AmtSum < BitWidth) + // (X << C1) << C2 --> X << (C1 + C2) + return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum)); + } + // If the shifted-out value is known-zero, then this is a NUW shift. 
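The nuw inference just below this point rests on a simple fact: if the top ShAmt bits of X are already zero (the MaskedValueIsZero query), the shift discards nothing and cannot wrap. A standalone sketch (not part of this commit):

#include <cassert>
#include <cstdint>

int main() {
  unsigned ShAmt = 8;
  uint32_t X = 0x00ABCDEF;                       // high 8 bits are zero
  assert((X & (~0u << (32 - ShAmt))) == 0);      // MaskedValueIsZero analogue
  // The 32-bit result matches the infinitely wide one: no unsigned wrap.
  assert((uint64_t)(X << ShAmt) == ((uint64_t)X << ShAmt));
}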
if (!I.hasNoUnsignedWrap() && - MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0, - &I)) { + MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) { I.setHasNoUnsignedWrap(); return &I; } - // If the shifted out value is all signbits, this is a NSW shift. - if (!I.hasNoSignedWrap() && - ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) { + // If the shifted-out value is all signbits, then this is a NSW shift. + if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) { I.setHasNoSignedWrap(); return &I; } } - // (C1 << A) << C2 -> (C1 << C2) << A - Constant *C1, *C2; - Value *A; - if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) && - match(I.getOperand(1), m_Constant(C2))) - return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A); + Constant *C1; + if (match(Op1, m_Constant(C1))) { + Constant *C2; + Value *X; + // (C2 << X) << C1 --> (C2 << C1) << X + if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X))))) + return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X); + + // (X * C2) << C1 --> X * (C2 << C1) + if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) + return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); + } return nullptr; } @@ -776,43 +617,109 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, &TLI, &DT, &AC)) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (Value *V = + SimplifyLShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { - unsigned ShAmt = Op1C->getZExtValue(); - - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { - unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + Type *Ty = I.getType(); + const APInt *ShAmtAPInt; + if (match(Op1, m_APInt(ShAmtAPInt))) { + unsigned ShAmt = ShAmtAPInt->getZExtValue(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + auto *II = dyn_cast<IntrinsicInst>(Op0); + if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt && + (II->getIntrinsicID() == Intrinsic::ctlz || + II->getIntrinsicID() == Intrinsic::cttz || + II->getIntrinsicID() == Intrinsic::ctpop)) { // ctlz.i32(x)>>5 --> zext(x == 0) // cttz.i32(x)>>5 --> zext(x == 0) // ctpop.i32(x)>>5 --> zext(x == -1) - if ((II->getIntrinsicID() == Intrinsic::ctlz || - II->getIntrinsicID() == Intrinsic::cttz || - II->getIntrinsicID() == Intrinsic::ctpop) && - isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) { - bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop; - Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0); - Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS); - return new ZExtInst(Cmp, II->getType()); + bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop; + Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? 
-1 : 0); + Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS); + return new ZExtInst(Cmp, Ty); + } + + Value *X; + const APInt *ShOp1; + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { + unsigned ShlAmt = ShOp1->getZExtValue(); + if (ShlAmt < ShAmt) { + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) + auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); + NewLShr->setIsExact(I.isExact()); + return NewLShr; + } + // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2) + Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact()); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); + } + if (ShlAmt > ShAmt) { + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(true); + return NewShl; + } + // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2) + Value *NewShl = Builder.CreateShl(X, ShiftDiff); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); + } + assert(ShlAmt == ShAmt); + // (X << C) >>u C --> X & (-1 >>u C) + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); + } + + if (match(Op0, m_SExt(m_Value(X))) && + (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { + // Are we moving the sign bit to the low bit and widening with high zeros? + unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits(); + if (ShAmt == BitWidth - 1) { + // lshr (sext i1 X to iN), N-1 --> zext X to iN + if (SrcTyBitWidth == 1) + return new ZExtInst(X, Ty); + + // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN + if (Op0->hasOneUse()) { + Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1); + return new ZExtInst(NewLShr, Ty); + } + } + + // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN + if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) { + // The new shift amount can't be more than the narrow source type. + unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1); + Value *AShr = Builder.CreateAShr(X, NewShAmt); + return new ZExtInst(AShr, Ty); } } + if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { + unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Oversized shifts are simplified to zero in InstSimplify. + if (AmtSum < BitWidth) + // (X >>u C1) >>u C2 --> X >>u (C1 + C2) + return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); + } + // If the shifted-out value is known-zero, then this is an exact shift. 
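The intrinsic folds in the visitLShr hunk above are equality tests in disguise: for i32, only x == 0 drives ctlz/cttz to 32, and only x == -1 drives ctpop to 32, so the lshr by 5 extracts exactly that predicate bit. A standalone sketch (not part of this commit; assumes a C++20 <bit> header):

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 1u, 0x80000000u, 0xFFFFFFFFu, 12345u}) {
    assert((std::countl_zero(X) >> 5) == (X == 0));   // ctlz.i32(x) >> 5
    assert((std::countr_zero(X) >> 5) == (X == 0));   // cttz.i32(x) >> 5
    assert((std::popcount(X) >> 5) == (X == ~0u));    // ctpop.i32(x) >> 5
  }
}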
if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), - 0, &I)){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { I.setIsExact(); return &I; } } - return nullptr; } @@ -820,48 +727,67 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, &TLI, &DT, &AC)) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (Value *V = + SimplifyAShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { - unsigned ShAmt = Op1C->getZExtValue(); + Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + const APInt *ShAmtAPInt; + if (match(Op1, m_APInt(ShAmtAPInt))) { + unsigned ShAmt = ShAmtAPInt->getZExtValue(); - // If the input is a SHL by the same constant (ashr (shl X, C), C), then we - // have a sign-extend idiom. + // If the shift amount equals the difference in width of the destination + // and source scalar types: + // ashr (shl (zext X), C), C --> sext X Value *X; - if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) { - // If the input is an extension from the shifted amount value, e.g. - // %x = zext i8 %A to i32 - // %y = shl i32 %x, 24 - // %z = ashr %y, 24 - // then turn this into "z = sext i8 A to i32". - if (ZExtInst *ZI = dyn_cast<ZExtInst>(X)) { - uint32_t SrcBits = ZI->getOperand(0)->getType()->getScalarSizeInBits(); - uint32_t DestBits = ZI->getType()->getScalarSizeInBits(); - if (Op1C->getZExtValue() == DestBits-SrcBits) - return new SExtInst(ZI->getOperand(0), ZI->getType()); + if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) && + ShAmt == BitWidth - X->getType()->getScalarSizeInBits()) + return new SExtInst(X, Ty); + + // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, + // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. + const APInt *ShOp1; + if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) { + unsigned ShlAmt = ShOp1->getZExtValue(); + if (ShlAmt < ShAmt) { + // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff); + NewAShr->setIsExact(I.isExact()); + return NewAShr; } + if (ShlAmt > ShAmt) { + // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff); + NewShl->setHasNoSignedWrap(true); + return NewShl; + } + } + + if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) { + unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Oversized arithmetic shifts replicate the sign bit. + AmtSum = std::min(AmtSum, BitWidth - 1); + // (X >>s C1) >>s C2 --> X >>s (C1 + C2) + return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); } // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), - 0, &I)) { + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { I.setIsExact(); return &I; } } // See if we can turn a signed shr into an unsigned shr. 
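The check that follows turns a signed shift right into an unsigned one whenever the sign bit is known zero, since the two shifts only disagree on what they shift in for negative values. A standalone sketch (not part of this commit; arithmetic right shift of non-negative values is well defined in C++):

#include <cassert>
#include <cstdint>

int main() {
  for (int32_t X : {0, 1, 0x7FFFFFFF, 42}) {   // sign bit clear in all cases
    assert((X & 0x80000000u) == 0);            // APInt::getSignMask analogue
    for (unsigned C = 0; C < 32; ++C)
      assert((uint32_t)(X >> C) == ((uint32_t)X >> C));  // ashr == lshr
  }
}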
- if (MaskedValueIsZero(Op0, - APInt::getSignBit(I.getType()->getScalarSizeInBits()), - 0, &I)) + if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) return BinaryOperator::CreateLShr(Op0, Op1); return nullptr; diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 8b930bd..a20f474 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -16,6 +16,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/KnownBits.h" using namespace llvm; using namespace llvm::PatternMatch; @@ -26,22 +27,22 @@ using namespace llvm::PatternMatch; /// constant integer. If so, check to see if there are any bits set in the /// constant that are not demanded. If so, shrink the constant and return true. static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, - APInt Demanded) { + const APInt &Demanded) { assert(I && "No instruction?"); assert(OpNo < I->getNumOperands() && "Operand index too large"); - // If the operand is not a constant integer, nothing to do. - ConstantInt *OpC = dyn_cast<ConstantInt>(I->getOperand(OpNo)); - if (!OpC) return false; + // The operand must be a constant integer or splat integer. + Value *Op = I->getOperand(OpNo); + const APInt *C; + if (!match(Op, m_APInt(C))) + return false; // If there are no bits set that aren't demanded, nothing to do. - Demanded = Demanded.zextOrTrunc(OpC->getValue().getBitWidth()); - if ((~Demanded & OpC->getValue()) == 0) + if (C->isSubsetOf(Demanded)) return false; // This instruction is producing bits that are not demanded. Shrink the RHS. - Demanded &= OpC->getValue(); - I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded)); + I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded)); return true; } @@ -52,10 +53,10 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, /// the instruction has any properties that allow us to simplify its operands. bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + KnownBits Known(BitWidth); APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); - Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, KnownZero, KnownOne, + Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 0, &Inst); if (!V) return false; if (V == &Inst) return true; @@ -66,12 +67,13 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { /// This form of SimplifyDemandedBits simplifies the specified instruction /// operand if possible, updating it in place. It returns true if it made any /// change and false otherwise. 
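ShrinkDemandedConstant, rewritten above to use m_APInt and isSubsetOf, is justified by masking: if callers only ever observe the bits in Demanded, clearing the other bits of a constant operand is unobservable. A standalone sketch (not part of this commit):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t Demanded = 0x0000FFFF;
  uint32_t C = 0x12345678, X = 0xCAFEBABE;
  uint32_t Shrunk = C & Demanded;               // the shrunken constant
  assert(((X & C) & Demanded) == ((X & Shrunk) & Demanded));
  assert(((X | C) & Demanded) == ((X | Shrunk) & Demanded));
  assert(((X ^ C) & Demanded) == ((X ^ Shrunk) & Demanded));
  // If C sets no bits outside Demanded (C->isSubsetOf(Demanded) above),
  // shrinking would change nothing, hence the early return of false.
}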
-bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask, - APInt &KnownZero, APInt &KnownOne, +bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo, + const APInt &DemandedMask, + KnownBits &Known, unsigned Depth) { - auto *UserI = dyn_cast<Instruction>(U.getUser()); - Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero, - KnownOne, Depth, UserI); + Use &U = I->getOperandUse(OpNo); + Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known, + Depth, I); if (!NewVal) return false; U = NewVal; return true; @@ -85,15 +87,16 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask, /// with a constant or one of its operands. In such cases, this function does /// the replacement and returns true. In all other cases, it returns false after /// analyzing the expression and setting KnownOne and known to be one in the -/// expression. KnownZero contains all the bits that are known to be zero in the -/// expression. These are provided to potentially allow the caller (which might -/// recursively be SimplifyDemandedBits itself) to simplify the expression. -/// KnownOne and KnownZero always follow the invariant that: -/// KnownOne & KnownZero == 0. -/// That is, a bit can't be both 1 and 0. Note that the bits in KnownOne and -/// KnownZero may only be accurate for those bits set in DemandedMask. Note also -/// that the bitwidth of V, DemandedMask, KnownZero and KnownOne must all be the -/// same. +/// expression. Known.Zero contains all the bits that are known to be zero in +/// the expression. These are provided to potentially allow the caller (which +/// might recursively be SimplifyDemandedBits itself) to simplify the +/// expression. +/// Known.One and Known.Zero always follow the invariant that: +/// Known.One & Known.Zero == 0. +/// That is, a bit can't be both 1 and 0. Note that the bits in Known.One and +/// Known.Zero may only be accurate for those bits set in DemandedMask. Note +/// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must all +/// be the same. /// /// This returns null if it did not change anything and it permits no /// simplification. This returns V itself if it did some simplification of V's @@ -101,8 +104,7 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask, /// some other non-null value if it found out that V is equal to another value /// in the context where the specified bits are demanded, but not for all users. Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, - APInt &KnownZero, APInt &KnownOne, - unsigned Depth, + KnownBits &Known, unsigned Depth, Instruction *CxtI) { assert(V != nullptr && "Null pointer of Value???"); assert(Depth <= 6 && "Limit Search Depth"); @@ -110,246 +112,144 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Type *VTy = V->getType(); assert( (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) && - KnownZero.getBitWidth() == BitWidth && - KnownOne.getBitWidth() == BitWidth && - "Value *V, DemandedMask, KnownZero and KnownOne " - "must have same BitWidth"); - if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - // We know all of the bits for a constant! - KnownOne = CI->getValue() & DemandedMask; - KnownZero = ~KnownOne & DemandedMask; - return nullptr; - } - if (isa<ConstantPointerNull>(V)) { - // We know all of the bits for a constant! 
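// Illustrative aside, not part of the patch: a minimal, hypothetical
// stand-in for the llvm::KnownBits class this patch migrates to (the real
// one wraps two APInts and lives in llvm/Support/KnownBits.h). It exists
// only to make the Zero/One invariant used throughout these hunks concrete.
#include <cassert>
#include <cstdint>

struct KnownBits32 {
  uint32_t Zero = 0;  // bits known to be 0
  uint32_t One = 0;   // bits known to be 1
  bool hasConflict() const { return (Zero & One) != 0; }   // must stay false
  bool isNonNegative() const { return (Zero >> 31) != 0; } // sign bit is 0
  bool isNegative() const { return (One >> 31) != 0; }     // sign bit is 1
  void resetAll() { Zero = One = 0; }
};

int main() {
  KnownBits32 K;
  K.Zero = 0xFFFF0000;  // top half known zero
  K.One = 0x000000FF;   // low byte known one; the byte between is unknown
  assert(!K.hasConflict() && K.isNonNegative() && !K.isNegative());
  return 0;
}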
- KnownOne.clearAllBits(); - KnownZero = DemandedMask; + Known.getBitWidth() == BitWidth && + "Value *V, DemandedMask and Known must have same BitWidth"); + + if (isa<Constant>(V)) { + computeKnownBits(V, Known, Depth, CxtI); return nullptr; } - KnownZero.clearAllBits(); - KnownOne.clearAllBits(); - if (DemandedMask == 0) { // Not demanding any bits from V. - if (isa<UndefValue>(V)) - return nullptr; + Known.resetAll(); + if (DemandedMask.isNullValue()) // Not demanding any bits from V. return UndefValue::get(VTy); - } if (Depth == 6) // Limit search depth. return nullptr; - APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - Instruction *I = dyn_cast<Instruction>(V); if (!I) { - computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); + computeKnownBits(V, Known, Depth, CxtI); return nullptr; // Only analyze instructions. } // If there are multiple uses of this value and we aren't at the root, then // we can't do any simplifications of the operands, because DemandedMask // only reflects the bits demanded by *one* of the users. - if (Depth != 0 && !I->hasOneUse()) { - // Despite the fact that we can't simplify this instruction in all User's - // context, we can at least compute the knownzero/knownone bits, and we can - // do simplifications that apply to *just* the one user if we know that - // this instruction has a simpler value in that context. - if (I->getOpcode() == Instruction::And) { - // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, - CxtI); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - CxtI); - - // If all of the demanded bits are known 1 on one side, return the other. - // These bits cannot contribute to the result of the 'and' in this - // context. - if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == - (DemandedMask & ~LHSKnownZero)) - return I->getOperand(0); - if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == - (DemandedMask & ~RHSKnownZero)) - return I->getOperand(1); - - // If all of the demanded bits in the inputs are known zeros, return zero. - if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) - return Constant::getNullValue(VTy); - - } else if (I->getOpcode() == Instruction::Or) { - // We can simplify (X|Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - - // If either the LHS or the RHS are One, the result is One. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, - CxtI); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - CxtI); - - // If all of the demanded bits are known zero on one side, return the - // other. These bits cannot contribute to the result of the 'or' in this - // context. - if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == - (DemandedMask & ~LHSKnownOne)) - return I->getOperand(0); - if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == - (DemandedMask & ~RHSKnownOne)) - return I->getOperand(1); - - // If all of the potentially set bits on one side are known to be set on - // the other side, just use the 'other' side. 
- if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == - (DemandedMask & (~RHSKnownZero))) - return I->getOperand(0); - if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == - (DemandedMask & (~LHSKnownZero))) - return I->getOperand(1); - } else if (I->getOpcode() == Instruction::Xor) { - // We can simplify (X^Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, - CxtI); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - CxtI); - - // If all of the demanded bits are known zero on one side, return the - // other. - if ((DemandedMask & RHSKnownZero) == DemandedMask) - return I->getOperand(0); - if ((DemandedMask & LHSKnownZero) == DemandedMask) - return I->getOperand(1); - } + if (Depth != 0 && !I->hasOneUse()) + return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI); - // Compute the KnownZero/KnownOne bits to simplify things downstream. - computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); - return nullptr; - } + KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth); // If this is the root being simplified, allow it to have multiple uses, // just set the DemandedMask to all bits so that we can try to simplify the // operands. This allows visitTruncInst (for example) to simplify the // operand of a trunc without duplicating all the logic below. if (Depth == 0 && !V->hasOneUse()) - DemandedMask = APInt::getAllOnesValue(BitWidth); + DemandedMask.setAllBits(); switch (I->getOpcode()) { default: - computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); + computeKnownBits(I, Known, Depth, CxtI); break; - case Instruction::And: + case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero, - LHSKnownZero, LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown, + Depth + 1)) return I; - assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); - assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + // Output known-0 are known to be clear if zero in either the LHS | RHS. + APInt IKnownZero = RHSKnown.Zero | LHSKnown.Zero; + // Output known-1 bits are only known if set in both the LHS & RHS. + APInt IKnownOne = RHSKnown.One & LHSKnown.One; // If the client is only demanding bits that we know, return the known // constant. - if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)| - (RHSKnownOne & LHSKnownOne))) == DemandedMask) - return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne); + if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) + return Constant::getIntegerValue(VTy, IKnownOne); // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and'. 
- if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == - (DemandedMask & ~LHSKnownZero)) + if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) return I->getOperand(0); - if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == - (DemandedMask & ~RHSKnownZero)) + if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) return I->getOperand(1); - // If all of the demanded bits in the inputs are known zeros, return zero. - if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) - return Constant::getNullValue(VTy); - // If the RHS is a constant, see if we can simplify it. - if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero)) + if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero)) return I; - // Output known-1 bits are only known if set in both the LHS & RHS. - KnownOne = RHSKnownOne & LHSKnownOne; - // Output known-0 are known to be clear if zero in either the LHS | RHS. - KnownZero = RHSKnownZero | LHSKnownZero; + Known.Zero = std::move(IKnownZero); + Known.One = std::move(IKnownOne); break; - case Instruction::Or: + } + case Instruction::Or: { // If either the LHS or the RHS are One, the result is One. - if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne, - LHSKnownZero, LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown, + Depth + 1)) return I; - assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); - assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + // Output known-0 bits are only known if clear in both the LHS & RHS. + APInt IKnownZero = RHSKnown.Zero & LHSKnown.Zero; + // Output known-1 are known to be set if set in either the LHS | RHS. + APInt IKnownOne = RHSKnown.One | LHSKnown.One; // If the client is only demanding bits that we know, return the known // constant. - if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)| - (RHSKnownOne | LHSKnownOne))) == DemandedMask) - return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne); + if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) + return Constant::getIntegerValue(VTy, IKnownOne); // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'or'. - if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == - (DemandedMask & ~LHSKnownOne)) - return I->getOperand(0); - if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == - (DemandedMask & ~RHSKnownOne)) - return I->getOperand(1); - - // If all of the potentially set bits on one side are known to be set on - // the other side, just use the 'other' side. - if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == - (DemandedMask & (~RHSKnownZero))) + if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) return I->getOperand(0); - if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == - (DemandedMask & (~LHSKnownZero))) + if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) return I->getOperand(1); // If the RHS is a constant, see if we can simplify it. if (ShrinkDemandedConstant(I, 1, DemandedMask)) return I; - // Output known-0 bits are only known if clear in both the LHS & RHS.
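// Illustrative aside, not part of the patch: a brute-force check (plain
// C++, 8-bit width) of the known-bits combination rules the And and Or
// cases above compute as IKnownZero/IKnownOne.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t LZ = 0x0C, LO = 0x03;  // LHS: bits 2-3 known zero, 0-1 known one
  uint8_t RZ = 0x80, RO = 0x01;  // RHS: bit 7 known zero, bit 0 known one
  uint8_t AndZ = RZ | LZ, AndO = RO & LO;  // and: zero if either, one if both
  uint8_t OrZ = RZ & LZ, OrO = RO | LO;    // or: zero if both, one if either
  for (unsigned L = 0; L != 256; ++L) {
    if ((L & LZ) != 0 || (L & LO) != LO)   // L must match LHS known bits
      continue;
    for (unsigned R = 0; R != 256; ++R) {
      if ((R & RZ) != 0 || (R & RO) != RO) // R must match RHS known bits
        continue;
      uint8_t A = L & R, O = L | R;
      assert((A & AndZ) == 0 && (A & AndO) == AndO);
      assert((O & OrZ) == 0 && (O & OrO) == OrO);
    }
  }
  return 0;
}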
- KnownZero = RHSKnownZero & LHSKnownZero; - // Output known-1 are known to be set if set in either the LHS | RHS. - KnownOne = RHSKnownOne | LHSKnownOne; + Known.Zero = std::move(IKnownZero); + Known.One = std::move(IKnownOne); break; + } case Instruction::Xor: { - if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, LHSKnownZero, - LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1)) return I; - assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); - assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); // Output known-0 bits are known if clear or set in both the LHS & RHS. - APInt IKnownZero = (RHSKnownZero & LHSKnownZero) | - (RHSKnownOne & LHSKnownOne); + APInt IKnownZero = (RHSKnown.Zero & LHSKnown.Zero) | + (RHSKnown.One & LHSKnown.One); // Output known-1 are known to be set if set in only one of the LHS, RHS. - APInt IKnownOne = (RHSKnownZero & LHSKnownOne) | - (RHSKnownOne & LHSKnownZero); + APInt IKnownOne = (RHSKnown.Zero & LHSKnown.One) | + (RHSKnown.One & LHSKnown.Zero); // If the client is only demanding bits that we know, return the known // constant. - if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) return Constant::getIntegerValue(VTy, IKnownOne); // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'xor'. - if ((DemandedMask & RHSKnownZero) == DemandedMask) + if (DemandedMask.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); - if ((DemandedMask & LHSKnownZero) == DemandedMask) + if (DemandedMask.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); // If all of the demanded bits are known to be zero on one side or the // other, turn this into an *inclusive* or. // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0 - if ((DemandedMask & ~RHSKnownZero & ~LHSKnownZero) == 0) { + if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) { Instruction *Or = BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1), I->getName()); @@ -360,14 +260,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // bits on that side are also known to be set on the other side, turn this // into an AND, as we know the bits will be cleared. // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2 - if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) { - // all known - if ((RHSKnownOne & LHSKnownOne) == RHSKnownOne) { - Constant *AndC = Constant::getIntegerValue(VTy, - ~RHSKnownOne & DemandedMask); - Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC); - return InsertNewInstWith(And, *I); - } + if (DemandedMask.isSubsetOf(RHSKnown.Zero|RHSKnown.One) && + RHSKnown.One.isSubsetOf(LHSKnown.One)) { + Constant *AndC = Constant::getIntegerValue(VTy, + ~RHSKnown.One & DemandedMask); + Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC); + return InsertNewInstWith(And, *I); } // If the RHS is a constant, see if we can simplify it. 
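// Illustrative aside, not part of the patch: the Xor rules above, checked
// in plain C++. Known-zero bits are those where the operands are known to
// agree, known-one bits are those where they are known to differ, and when
// every demanded bit is known zero on one side or the other, xor and
// inclusive-or coincide (the fold to 'or' in the hunk above).
#include <cassert>
#include <cstdint>

int main() {
  uint8_t LZ = 0xF0, LO = 0x00;  // LHS: high nibble known zero
  uint8_t RZ = 0x0F, RO = 0x30;  // RHS: low nibble zero, bits 4-5 one
  uint8_t XZ = (RZ & LZ) | (RO & LO);  // known to agree  -> known zero
  uint8_t XO = (RZ & LO) | (RO & LZ);  // known to differ -> known one
  const uint8_t Rs[] = {0x30, 0x70, 0xB0, 0xF0};  // every valid RHS value
  for (unsigned L = 0; L != 16; ++L)              // every valid LHS value
    for (uint8_t R : Rs) {
      uint8_t X = L ^ R;
      assert((X & XZ) == 0 && (X & XO) == XO);
      assert((L ^ R) == (L | R));  // set bits are disjoint: xor == or
    }
  return 0;
}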
@@ -383,10 +281,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() && isa<ConstantInt>(I->getOperand(1)) && isa<ConstantInt>(LHSInst->getOperand(1)) && - (LHSKnownOne & RHSKnownOne & DemandedMask) != 0) { + (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) { ConstantInt *AndRHS = cast<ConstantInt>(LHSInst->getOperand(1)); ConstantInt *XorRHS = cast<ConstantInt>(I->getOperand(1)); - APInt NewMask = ~(LHSKnownOne & RHSKnownOne & DemandedMask); + APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask); Constant *AndC = ConstantInt::get(I->getType(), NewMask & AndRHS->getValue()); @@ -400,9 +298,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } // Output known-0 bits are known if clear or set in both the LHS & RHS. - KnownZero= (RHSKnownZero & LHSKnownZero) | (RHSKnownOne & LHSKnownOne); + Known.Zero = std::move(IKnownZero); // Output known-1 are known to be set if set in only one of the LHS, RHS. - KnownOne = (RHSKnownZero & LHSKnownOne) | (RHSKnownOne & LHSKnownZero); + Known.One = std::move(IKnownOne); break; } case Instruction::Select: @@ -412,13 +310,11 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN) return nullptr; - if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, LHSKnownZero, - LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1)) return I; - assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); - assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. if (ShrinkDemandedConstant(I, 1, DemandedMask) || @@ -426,21 +322,22 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; // Only known if known in both the LHS and RHS. - KnownOne = RHSKnownOne & LHSKnownOne; - KnownZero = RHSKnownZero & LHSKnownZero; + Known.One = RHSKnown.One & LHSKnown.One; + Known.Zero = RHSKnown.Zero & LHSKnown.Zero; break; + case Instruction::ZExt: case Instruction::Trunc: { - unsigned truncBf = I->getOperand(0)->getType()->getScalarSizeInBits(); - DemandedMask = DemandedMask.zext(truncBf); - KnownZero = KnownZero.zext(truncBf); - KnownOne = KnownOne.zext(truncBf); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, - KnownOne, Depth + 1)) + unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); + + APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); + KnownBits InputKnown(SrcBitWidth); + if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) return I; - DemandedMask = DemandedMask.trunc(BitWidth); - KnownZero = KnownZero.trunc(BitWidth); - KnownOne = KnownOne.trunc(BitWidth); - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); + Known = Known.zextOrTrunc(BitWidth); + // Any top bits are known to be zero. 
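// Illustrative aside, not part of the patch: the fold quoted above,
// "(X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1 & C2) == C2", checked in plain
// C++. Every bit of C2 is already known one before the xor, so the xor
// merely clears those bits, which is exactly what the 'and' does.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C1 = 0xFF00FF00, C2 = 0x0F000F00;
  static_assert((C1 & C2) == C2, "C2 must be a subset of C1");
  const uint32_t Vals[] = {0u, 0x12345678u, 0xFFFFFFFFu};
  for (uint32_t X : Vals)
    assert(((X | C1) ^ C2) == ((X | C1) & ~C2));
  return 0;
}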
+ if (BitWidth > SrcBitWidth) + Known.Zero.setBitsFrom(SrcBitWidth); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } case Instruction::BitCast: @@ -460,65 +357,38 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Don't touch a vector-to-scalar bitcast. return nullptr; - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, - KnownOne, Depth + 1)) - return I; - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); - break; - case Instruction::ZExt: { - // Compute the bits in the result that are not present in the input. - unsigned SrcBitWidth =I->getOperand(0)->getType()->getScalarSizeInBits(); - - DemandedMask = DemandedMask.trunc(SrcBitWidth); - KnownZero = KnownZero.trunc(SrcBitWidth); - KnownOne = KnownOne.trunc(SrcBitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMask, Known, Depth + 1)) return I; - DemandedMask = DemandedMask.zext(BitWidth); - KnownZero = KnownZero.zext(BitWidth); - KnownOne = KnownOne.zext(BitWidth); - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); - // The top bits are known to be zero. - KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; - } case Instruction::SExt: { // Compute the bits in the result that are not present in the input. - unsigned SrcBitWidth =I->getOperand(0)->getType()->getScalarSizeInBits(); + unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); - APInt InputDemandedBits = DemandedMask & - APInt::getLowBitsSet(BitWidth, SrcBitWidth); + APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth); - APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth)); // If any of the sign extended bits are demanded, we know that the sign // bit is demanded. - if ((NewBits & DemandedMask) != 0) + if (DemandedMask.getActiveBits() > SrcBitWidth) InputDemandedBits.setBit(SrcBitWidth-1); - InputDemandedBits = InputDemandedBits.trunc(SrcBitWidth); - KnownZero = KnownZero.trunc(SrcBitWidth); - KnownOne = KnownOne.trunc(SrcBitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits, KnownZero, - KnownOne, Depth + 1)) + KnownBits InputKnown(SrcBitWidth); + if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1)) return I; - InputDemandedBits = InputDemandedBits.zext(BitWidth); - KnownZero = KnownZero.zext(BitWidth); - KnownOne = KnownOne.zext(BitWidth); - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); - - // If the sign bit of the input is known set or clear, then we know the - // top bits of the result. // If the input sign bit is known zero, or if the NewBits are not demanded // convert this into a zero extension. - if (KnownZero[SrcBitWidth-1] || (NewBits & ~DemandedMask) == NewBits) { - // Convert to ZExt cast + if (InputKnown.isNonNegative() || + DemandedMask.getActiveBits() <= SrcBitWidth) { + // Convert to ZExt cast. CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName()); return InsertNewInstWith(NewCast, *I); - } else if (KnownOne[SrcBitWidth-1]) { // Input sign bit known set - KnownOne |= NewBits; - } + } + + // If the sign bit of the input is known set or clear, then we know the + // top bits of the result. 
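// Illustrative aside, not part of the patch: the SExt rule above in plain
// C++. When the source sign bit is known zero, or no bit above the source
// width is demanded, sign-extension and zero-extension are interchangeable,
// which is why the hunk converts the sext into a zext.
#include <cassert>
#include <cstdint>

int main() {
  for (int V = -128; V != 128; ++V) {
    int8_t X = static_cast<int8_t>(V);
    uint32_t SExt = static_cast<uint32_t>(static_cast<int32_t>(X));
    uint32_t ZExt = static_cast<uint8_t>(X);
    if (X >= 0)                    // input sign bit known zero
      assert(SExt == ZExt);
    uint32_t DemandedMask = 0xFF;  // only bits below the source width
    assert((SExt & DemandedMask) == (ZExt & DemandedMask));
  }
  return 0;
}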
+ Known = InputKnown.sext(BitWidth); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } case Instruction::Add: @@ -530,11 +400,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Right fill the mask of bits for this ADD/SUB to demand the most // significant bit and all those below it. APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps, - LHSKnownZero, LHSKnownOne, Depth + 1) || + if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps, - LHSKnownZero, LHSKnownOne, Depth + 1)) { + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { // Disable the nsw and nuw flags here: We can no longer guarantee that // we won't wrap after simplification. Removing the nsw/nuw flags is // legal here because the top bit is not demanded. @@ -543,24 +412,33 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, BinOP.setHasNoUnsignedWrap(false); return I; } + + // If we are known to be adding/subtracting zeros to every bit below + // the highest demanded bit, we just return the other side. + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + // We can't do this with the LHS for subtraction, unless we are only + // demanding the LSB. + if ((I->getOpcode() == Instruction::Add || + DemandedFromOps.isOneValue()) && + DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); } // Otherwise just hand the add/sub off to computeKnownBits to fill in // the known zeros and ones. - computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); + computeKnownBits(V, Known, Depth, CxtI); break; } - case Instruction::Shl: - if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { - { - Value *VarX; ConstantInt *C1; - if (match(I->getOperand(0), m_Shr(m_Value(VarX), m_ConstantInt(C1)))) { - Instruction *Shr = cast<Instruction>(I->getOperand(0)); - Value *R = SimplifyShrShlDemandedBits(Shr, I, DemandedMask, - KnownZero, KnownOne); - if (R) - return R; - } + case Instruction::Shl: { + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { + const APInt *ShrAmt; + if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) { + Instruction *Shr = cast<Instruction>(I->getOperand(0)); + if (Value *R = simplifyShrShlDemandedBits( + Shr, *ShrAmt, I, *SA, DemandedMask, Known)) + return R; } uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); @@ -569,24 +447,24 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the shift is NUW/NSW, then it does demand the high bits. 
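// Illustrative aside, not part of the patch: the add/sub rule above in
// plain C++. DemandedFromOps right-fills the demanded mask because carries
// only propagate upward, and when the RHS is known zero in all of those
// positions, the add or sub cannot change any demanded bit of the LHS.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t DemandedMask = 0x00FF0000;
  uint32_t DemandedFromOps = 0x00FFFFFF;  // MSB of the mask, right-filled
  uint32_t R = 0xAB000000;                // known zero in DemandedFromOps
  assert((R & DemandedFromOps) == 0);
  const uint32_t Vals[] = {0u, 0x12345678u, 0xDEADBEEFu};
  for (uint32_t X : Vals) {
    assert(((X + R) & DemandedMask) == (X & DemandedMask));
    assert(((X - R) & DemandedMask) == (X & DemandedMask));
  }
  return 0;
}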
ShlOperator *IOp = cast<ShlOperator>(I); if (IOp->hasNoSignedWrap()) - DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1); + DemandedMaskIn.setHighBits(ShiftAmt+1); else if (IOp->hasNoUnsignedWrap()) - DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt); + DemandedMaskIn.setHighBits(ShiftAmt); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); - KnownZero <<= ShiftAmt; - KnownOne <<= ShiftAmt; + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + Known.Zero <<= ShiftAmt; + Known.One <<= ShiftAmt; // low bits known zero. if (ShiftAmt) - KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + Known.Zero.setLowBits(ShiftAmt); } break; - case Instruction::LShr: - // For a logical shift right - if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + } + case Instruction::LShr: { + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); // Unsigned shift right. @@ -595,27 +473,24 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the shift is exact, then it does demand the low bits (and knows that // they are zero). if (cast<LShrOperator>(I)->isExact()) - DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + DemandedMaskIn.setLowBits(ShiftAmt); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); - KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); - KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); - if (ShiftAmt) { - // Compute the new bits that are at the top now. - APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); - KnownZero |= HighBits; // high bits known zero. - } + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + Known.Zero.lshrInPlace(ShiftAmt); + Known.One.lshrInPlace(ShiftAmt); + if (ShiftAmt) + Known.Zero.setHighBits(ShiftAmt); // high bits known zero. } break; - case Instruction::AShr: + } + case Instruction::AShr: { // If this is an arithmetic shift right and only the low-bit is set, we can // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless // the shift amount is >= the size of the datatype, which is undefined. - if (DemandedMask == 1) { + if (DemandedMask.isOneValue()) { // Perform the logical shift right. Instruction *NewVal = BinaryOperator::CreateLShr( I->getOperand(0), I->getOperand(1), I->getName()); @@ -624,57 +499,58 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the sign bit is the only bit demanded by this ashr, then there is no // need to do it, the shift doesn't change the high bit. - if (DemandedMask.isSignBit()) + if (DemandedMask.isSignMask()) return I->getOperand(0); - if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1); // Signed shift right. APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); - // If any of the "high bits" are demanded, we should set the sign bit as + // If any of the high bits are demanded, we should set the sign bit as // demanded. 
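// Illustrative aside, not part of the patch: the shift bookkeeping above in
// plain C++. For Res = X << Amt, result bit i comes from input bit i - Amt,
// so the demanded input mask is DemandedMask >> Amt and the low Amt result
// bits are known zero (the setLowBits call in the hunk).
#include <cassert>
#include <cstdint>

int main() {
  const unsigned Amt = 5;
  const uint32_t DemandedMask = 0x0000FF00;
  const uint32_t DemandedMaskIn = DemandedMask >> Amt;
  const uint32_t Vals[] = {0u, 0x12345678u, 0xFFFFFFFFu};
  for (uint32_t X : Vals) {
    uint32_t Res = X << Amt;
    // Flipping every bit outside DemandedMaskIn must not disturb any
    // demanded bit of the result.
    uint32_t Y = X ^ ~DemandedMaskIn;
    assert(((Y << Amt) & DemandedMask) == (Res & DemandedMask));
    assert((Res & ((1u << Amt) - 1)) == 0);  // low Amt bits known zero
  }
  return 0;
}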
if (DemandedMask.countLeadingZeros() <= ShiftAmt) - DemandedMaskIn.setBit(BitWidth-1); + DemandedMaskIn.setSignBit(); // If the shift is exact, then it does demand the low bits (and knows that // they are zero). if (cast<AShrOperator>(I)->isExact()) - DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + DemandedMaskIn.setLowBits(ShiftAmt); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); + + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); // Compute the new bits that are at the top now. APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); - KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); - KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); + Known.Zero.lshrInPlace(ShiftAmt); + Known.One.lshrInPlace(ShiftAmt); // Handle the sign bits. - APInt SignBit(APInt::getSignBit(BitWidth)); + APInt SignMask(APInt::getSignMask(BitWidth)); // Adjust to where it is now in the mask. - SignBit = APIntOps::lshr(SignBit, ShiftAmt); + SignMask.lshrInPlace(ShiftAmt); // If the input sign bit is known to be zero, or if none of the top bits // are demanded, turn this into an unsigned shift right. - if (BitWidth <= ShiftAmt || KnownZero[BitWidth-ShiftAmt-1] || - (HighBits & ~DemandedMask) == HighBits) { - // Perform the logical shift right. - BinaryOperator *NewVal = BinaryOperator::CreateLShr(I->getOperand(0), - SA, I->getName()); - NewVal->setIsExact(cast<BinaryOperator>(I)->isExact()); - return InsertNewInstWith(NewVal, *I); - } else if ((KnownOne & SignBit) != 0) { // New bits are known one. - KnownOne |= HighBits; + if (BitWidth <= ShiftAmt || Known.Zero[BitWidth-ShiftAmt-1] || + !DemandedMask.intersects(HighBits)) { + BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), + I->getOperand(1)); + LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); + return InsertNewInstWith(LShr, *I); + } else if (Known.One.intersects(SignMask)) { // New bits are known one. + Known.One |= HighBits; } } break; + } case Instruction::SRem: if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { // X % -1 demands all the bits because we don't want to introduce // INT_MIN % -1 (== undef) by accident. - if (Rem->isAllOnesValue()) + if (Rem->isMinusOne()) break; APInt RA = Rem->getValue().abs(); if (RA.isPowerOf2()) { @@ -682,53 +558,47 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I->getOperand(0); APInt LowBits = RA - 1; - APInt Mask2 = LowBits | APInt::getSignBit(BitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), Mask2, LHSKnownZero, - LHSKnownOne, Depth + 1)) + APInt Mask2 = LowBits | APInt::getSignMask(BitWidth); + if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1)) return I; // The low bits of LHS are unchanged by the srem. - KnownZero = LHSKnownZero & LowBits; - KnownOne = LHSKnownOne & LowBits; + Known.Zero = LHSKnown.Zero & LowBits; + Known.One = LHSKnown.One & LowBits; // If LHS is non-negative or has all low bits zero, then the upper bits // are all zero. - if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits)) - KnownZero |= ~LowBits; + if (LHSKnown.isNonNegative() || LowBits.isSubsetOf(LHSKnown.Zero)) + Known.Zero |= ~LowBits; // If LHS is negative and not all low bits are zero, then the upper bits // are all one. 
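// Illustrative aside, not part of the patch: the ashr-to-lshr conversion
// above in plain C++. Arithmetic and logical right shifts agree on every
// bit below the shifted-in high bits, so if no high bit is demanded, or the
// input sign bit is known zero, an lshr gives the same demanded bits.
// Assumes >> on a negative int32_t is an arithmetic shift (guaranteed since
// C++20, true on mainstream compilers before that).
#include <cassert>
#include <cstdint>

int main() {
  const unsigned Amt = 8;
  const uint32_t HighBits = 0xFFFFFFFFu << (32 - Amt);  // shifted-in bits
  for (int64_t V = -70000; V <= 70000; V += 9973) {
    int32_t X = static_cast<int32_t>(V);
    uint32_t A = static_cast<uint32_t>(X >> Amt);  // ashr
    uint32_t L = static_cast<uint32_t>(X) >> Amt;  // lshr
    assert((A & ~HighBits) == (L & ~HighBits));    // agree below HighBits
    if (X >= 0)                                    // sign bit known zero
      assert(A == L);
  }
  return 0;
}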
- if (LHSKnownOne[BitWidth-1] && ((LHSKnownOne & LowBits) != 0)) - KnownOne |= ~LowBits; + if (LHSKnown.isNegative() && LowBits.intersects(LHSKnown.One)) + Known.One |= ~LowBits; - assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + break; } } // The sign bit is the LHS's sign bit, except when the result of the // remainder is zero. - if (DemandedMask.isNegative() && KnownZero.isNonNegative()) { - APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - CxtI); + if (DemandedMask.isSignBitSet()) { + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); // If it's known zero, our sign bit is also zero. - if (LHSKnownZero.isNegative()) - KnownZero.setBit(KnownZero.getBitWidth() - 1); + if (LHSKnown.isNonNegative()) + Known.makeNonNegative(); } break; case Instruction::URem: { - APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0); + KnownBits Known2(BitWidth); APInt AllOnes = APInt::getAllOnesValue(BitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes, KnownZero2, - KnownOne2, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(1), AllOnes, KnownZero2, - KnownOne2, Depth + 1)) + if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) || + SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1)) return I; - unsigned Leaders = KnownZero2.countLeadingOnes(); - Leaders = std::max(Leaders, - KnownZero2.countLeadingOnes()); - KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; + unsigned Leaders = Known2.countMinLeadingZeros(); + Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; break; } case Instruction::Call: @@ -788,29 +658,156 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If we don't need any of low bits then return zero, // we know that DemandedMask is non-zero already. APInt DemandedElts = DemandedMask.zextOrTrunc(ArgWidth); - if (DemandedElts == 0) + if (DemandedElts.isNullValue()) return ConstantInt::getNullValue(VTy); // We know that the upper bits are set to zero. - KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - ArgWidth); + Known.Zero.setBitsFrom(ArgWidth); return nullptr; } case Intrinsic::x86_sse42_crc32_64_64: - KnownZero = APInt::getHighBitsSet(64, 32); + Known.Zero.setBitsFrom(32); return nullptr; } } - computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); + computeKnownBits(V, Known, Depth, CxtI); break; } // If the client is only demanding bits that we know, return the known // constant. - if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask) - return Constant::getIntegerValue(VTy, KnownOne); + if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) + return Constant::getIntegerValue(VTy, Known.One); + return nullptr; +} + +/// Helper routine of SimplifyDemandedUseBits. It computes Known +/// bits. It also tries to handle simplifications that can be done based on +/// DemandedMask, but without modifying the Instruction. 
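// Illustrative aside, not part of the patch: the SRem-by-power-of-two facts
// above in plain C++ (C++'s % truncates toward zero, matching LLVM srem).
// The low bits of the LHS pass through unchanged, and a non-negative LHS
// leaves every upper bit of the result zero.
#include <cassert>
#include <cstdint>

int main() {
  const int32_t RA = 8;             // power-of-two modulus
  const uint32_t LowBits = RA - 1;  // mask of the low log2(RA) bits
  for (int32_t X = -100; X <= 100; ++X) {
    int32_t R = X % RA;
    assert((static_cast<uint32_t>(R) & LowBits) ==
           (static_cast<uint32_t>(X) & LowBits));
    if (X >= 0)                     // LHS non-negative
      assert((static_cast<uint32_t>(R) & ~LowBits) == 0);
  }
  return 0;
}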
+Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, + const APInt &DemandedMask, + KnownBits &Known, + unsigned Depth, + Instruction *CxtI) { + unsigned BitWidth = DemandedMask.getBitWidth(); + Type *ITy = I->getType(); + + KnownBits LHSKnown(BitWidth); + KnownBits RHSKnown(BitWidth); + + // Despite the fact that we can't simplify this instruction in all User's + // context, we can at least compute the known bits, and we can + // do simplifications that apply to *just* the one user if we know that + // this instruction has a simpler value in that context. + switch (I->getOpcode()) { + case Instruction::And: { + // If either the LHS or the RHS are Zero, the result is zero. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, + CxtI); + + // Output known-0 are known to be clear if zero in either the LHS | RHS. + APInt IKnownZero = RHSKnown.Zero | LHSKnown.Zero; + // Output known-1 bits are only known if set in both the LHS & RHS. + APInt IKnownOne = RHSKnown.One & LHSKnown.One; + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) + return Constant::getIntegerValue(ITy, IKnownOne); + + // If all of the demanded bits are known 1 on one side, return the other. + // These bits cannot contribute to the result of the 'and' in this + // context. + if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) + return I->getOperand(1); + + Known.Zero = std::move(IKnownZero); + Known.One = std::move(IKnownOne); + break; + } + case Instruction::Or: { + // We can simplify (X|Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + + // If either the LHS or the RHS are One, the result is One. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, + CxtI); + + // Output known-0 bits are only known if clear in both the LHS & RHS. + APInt IKnownZero = RHSKnown.Zero & LHSKnown.Zero; + // Output known-1 are known to be set if set in either the LHS | RHS. + APInt IKnownOne = RHSKnown.One | LHSKnown.One; + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) + return Constant::getIntegerValue(ITy, IKnownOne); + + // If all of the demanded bits are known zero on one side, return the + // other. These bits cannot contribute to the result of the 'or' in this + // context. + if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) + return I->getOperand(1); + + Known.Zero = std::move(IKnownZero); + Known.One = std::move(IKnownOne); + break; + } + case Instruction::Xor: { + // We can simplify (X^Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, + CxtI); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt IKnownZero = (RHSKnown.Zero & LHSKnown.Zero) | + (RHSKnown.One & LHSKnown.One); + // Output known-1 are known to be set if set in only one of the LHS, RHS. 
+ APInt IKnownOne = (RHSKnown.Zero & LHSKnown.One) | + (RHSKnown.One & LHSKnown.Zero); + + // If the client is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) + return Constant::getIntegerValue(ITy, IKnownOne); + + // If all of the demanded bits are known zero on one side, return the + // other. + if (DemandedMask.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedMask.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + Known.Zero = std::move(IKnownZero); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + Known.One = std::move(IKnownOne); + break; + } + default: + // Compute the Known bits to simplify things downstream. + computeKnownBits(I, Known, Depth, CxtI); + + // If this user is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) + return Constant::getIntegerValue(ITy, Known.One); + + break; + } + return nullptr; } + /// Helper routine of SimplifyDemandedUseBits. It tries to simplify /// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign @@ -828,29 +825,26 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, /// /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was /// not successful. -Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, - Instruction *Shl, - const APInt &DemandedMask, - APInt &KnownZero, - APInt &KnownOne) { - - const APInt &ShlOp1 = cast<ConstantInt>(Shl->getOperand(1))->getValue(); - const APInt &ShrOp1 = cast<ConstantInt>(Shr->getOperand(1))->getValue(); +Value * +InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, + Instruction *Shl, const APInt &ShlOp1, + const APInt &DemandedMask, + KnownBits &Known) { if (!ShlOp1 || !ShrOp1) - return nullptr; // Noop. + return nullptr; // No-op. Value *VarX = Shr->getOperand(0); Type *Ty = VarX->getType(); - unsigned BitWidth = Ty->getIntegerBitWidth(); + unsigned BitWidth = Ty->getScalarSizeInBits(); if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth)) return nullptr; // Undef. unsigned ShlAmt = ShlOp1.getZExtValue(); unsigned ShrAmt = ShrOp1.getZExtValue(); - KnownOne.clearAllBits(); - KnownZero = APInt::getBitsSet(KnownZero.getBitWidth(), 0, ShlAmt-1); - KnownZero &= DemandedMask; + Known.One.clearAllBits(); + Known.Zero.setLowBits(ShlAmt - 1); + Known.Zero &= DemandedMask; APInt BitMask1(APInt::getAllOnesValue(BitWidth)); APInt BitMask2(APInt::getAllOnesValue(BitWidth)); @@ -916,7 +910,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, return nullptr; } - if (DemandedElts == 0) { // If nothing is demanded, provide undef. + if (DemandedElts.isNullValue()) { // If nothing is demanded, provide undef. 
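// Illustrative aside, not part of the patch: the shr+shl fold handled by
// simplifyShrShlDemandedBits above, checked in plain C++ for the C1 < C2
// case. "(X lsr C1) shl C2" and "X shl (C2 - C1)" agree on every bit at or
// above C2, so the fold is legal when only those bits are demanded, and the
// low C2 bits of E1 are known zero.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C1 = 3, C2 = 5;
  const uint32_t DemandedMask = ~((1u << C2) - 1);  // bits C2 and above
  const uint32_t Vals[] = {0u, 0x12345678u, 0xDEADBEEFu, 0xFFFFFFFFu};
  for (uint32_t X : Vals) {
    uint32_t E1 = (X >> C1) << C2;
    uint32_t E2 = X << (C2 - C1);
    assert((E1 & DemandedMask) == (E2 & DemandedMask));
    assert((E1 & ~DemandedMask) == 0);  // low C2 bits known zero
  }
  return 0;
}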
UndefElts = EltMask; return UndefValue::get(V->getType()); } @@ -1472,14 +1466,213 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } + case Intrinsic::x86_sse2_packssdw_128: + case Intrinsic::x86_sse2_packsswb_128: + case Intrinsic::x86_sse2_packuswb_128: + case Intrinsic::x86_sse41_packusdw: + case Intrinsic::x86_avx2_packssdw: + case Intrinsic::x86_avx2_packsswb: + case Intrinsic::x86_avx2_packusdw: + case Intrinsic::x86_avx2_packuswb: + case Intrinsic::x86_avx512_packssdw_512: + case Intrinsic::x86_avx512_packsswb_512: + case Intrinsic::x86_avx512_packusdw_512: + case Intrinsic::x86_avx512_packuswb_512: { + auto *Ty0 = II->getArgOperand(0)->getType(); + unsigned InnerVWidth = Ty0->getVectorNumElements(); + assert(VWidth == (InnerVWidth * 2) && "Unexpected input size"); + + unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128; + unsigned VWidthPerLane = VWidth / NumLanes; + unsigned InnerVWidthPerLane = InnerVWidth / NumLanes; + + // Per lane, pack the elements of the first input and then the second. + // e.g. + // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3]) + // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15]) + for (int OpNum = 0; OpNum != 2; ++OpNum) { + APInt OpDemandedElts(InnerVWidth, 0); + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + unsigned LaneIdx = Lane * VWidthPerLane; + for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) { + unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum; + if (DemandedElts[Idx]) + OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt); + } + } + + // Demand elements from the operand. + auto *Op = II->getArgOperand(OpNum); + APInt OpUndefElts(InnerVWidth, 0); + TmpV = SimplifyDemandedVectorElts(Op, OpDemandedElts, OpUndefElts, + Depth + 1); + if (TmpV) { + II->setArgOperand(OpNum, TmpV); + MadeChange = true; + } + + // Pack the operand's UNDEF elements, one lane at a time. + OpUndefElts = OpUndefElts.zext(VWidth); + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane); + LaneElts = LaneElts.getLoBits(InnerVWidthPerLane); + LaneElts <<= InnerVWidthPerLane * (2 * Lane + OpNum); + UndefElts |= LaneElts; + } + } + break; + } + + // PSHUFB + case Intrinsic::x86_ssse3_pshuf_b_128: + case Intrinsic::x86_avx2_pshuf_b: + case Intrinsic::x86_avx512_pshuf_b_512: + // PERMILVAR + case Intrinsic::x86_avx_vpermilvar_ps: + case Intrinsic::x86_avx_vpermilvar_ps_256: + case Intrinsic::x86_avx512_vpermilvar_ps_512: + case Intrinsic::x86_avx_vpermilvar_pd: + case Intrinsic::x86_avx_vpermilvar_pd_256: + case Intrinsic::x86_avx512_vpermilvar_pd_512: + // PERMV + case Intrinsic::x86_avx2_permd: + case Intrinsic::x86_avx2_permps: { + Value *Op1 = II->getArgOperand(1); + TmpV = SimplifyDemandedVectorElts(Op1, DemandedElts, UndefElts, + Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + break; + } + // SSE4A instructions leave the upper 64-bits of the 128-bit result // in an undefined state. 
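// Illustrative aside, not part of the patch: the per-lane demanded-element
// mapping used for the PACK* intrinsics above, replayed in plain C++ for
// the single-lane v8i16 = PACK(v4i32 X, v4i32 Y) shape, where result
// elements 0-3 come from X and 4-7 from Y.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned VWidth = 8, InnerVWidth = 4, NumLanes = 1;
  const unsigned VWidthPerLane = VWidth / NumLanes;
  const unsigned InnerVWidthPerLane = InnerVWidth / NumLanes;
  const uint32_t DemandedElts = 0x12;  // result elements 1 and 4
  for (unsigned OpNum = 0; OpNum != 2; ++OpNum) {
    uint32_t OpDemandedElts = 0;
    for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
      unsigned LaneIdx = Lane * VWidthPerLane;
      for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) {
        unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum;
        if (DemandedElts & (1u << Idx))
          OpDemandedElts |= 1u << (Lane * InnerVWidthPerLane + Elt);
      }
    }
    // Element 1 of X feeds result element 1; element 0 of Y feeds result
    // element 4.
    assert(OpDemandedElts == (OpNum == 0 ? 0x2u : 0x1u));
  }
  return 0;
}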
case Intrinsic::x86_sse4a_extrq: case Intrinsic::x86_sse4a_extrqi: case Intrinsic::x86_sse4a_insertq: case Intrinsic::x86_sse4a_insertqi: - UndefElts |= APInt::getHighBitsSet(VWidth, VWidth / 2); + UndefElts.setHighBits(VWidth / 2); break; + case Intrinsic::amdgcn_buffer_load: + case Intrinsic::amdgcn_buffer_load_format: + case Intrinsic::amdgcn_image_sample: + case Intrinsic::amdgcn_image_sample_cl: + case Intrinsic::amdgcn_image_sample_d: + case Intrinsic::amdgcn_image_sample_d_cl: + case Intrinsic::amdgcn_image_sample_l: + case Intrinsic::amdgcn_image_sample_b: + case Intrinsic::amdgcn_image_sample_b_cl: + case Intrinsic::amdgcn_image_sample_lz: + case Intrinsic::amdgcn_image_sample_cd: + case Intrinsic::amdgcn_image_sample_cd_cl: + + case Intrinsic::amdgcn_image_sample_c: + case Intrinsic::amdgcn_image_sample_c_cl: + case Intrinsic::amdgcn_image_sample_c_d: + case Intrinsic::amdgcn_image_sample_c_d_cl: + case Intrinsic::amdgcn_image_sample_c_l: + case Intrinsic::amdgcn_image_sample_c_b: + case Intrinsic::amdgcn_image_sample_c_b_cl: + case Intrinsic::amdgcn_image_sample_c_lz: + case Intrinsic::amdgcn_image_sample_c_cd: + case Intrinsic::amdgcn_image_sample_c_cd_cl: + + case Intrinsic::amdgcn_image_sample_o: + case Intrinsic::amdgcn_image_sample_cl_o: + case Intrinsic::amdgcn_image_sample_d_o: + case Intrinsic::amdgcn_image_sample_d_cl_o: + case Intrinsic::amdgcn_image_sample_l_o: + case Intrinsic::amdgcn_image_sample_b_o: + case Intrinsic::amdgcn_image_sample_b_cl_o: + case Intrinsic::amdgcn_image_sample_lz_o: + case Intrinsic::amdgcn_image_sample_cd_o: + case Intrinsic::amdgcn_image_sample_cd_cl_o: + + case Intrinsic::amdgcn_image_sample_c_o: + case Intrinsic::amdgcn_image_sample_c_cl_o: + case Intrinsic::amdgcn_image_sample_c_d_o: + case Intrinsic::amdgcn_image_sample_c_d_cl_o: + case Intrinsic::amdgcn_image_sample_c_l_o: + case Intrinsic::amdgcn_image_sample_c_b_o: + case Intrinsic::amdgcn_image_sample_c_b_cl_o: + case Intrinsic::amdgcn_image_sample_c_lz_o: + case Intrinsic::amdgcn_image_sample_c_cd_o: + case Intrinsic::amdgcn_image_sample_c_cd_cl_o: + + case Intrinsic::amdgcn_image_getlod: { + if (VWidth == 1 || !DemandedElts.isMask()) + return nullptr; + + // TODO: Handle 3 vectors when supported in code gen. + unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countTrailingOnes()); + if (NewNumElts == VWidth) + return nullptr; + + Module *M = II->getParent()->getParent()->getParent(); + Type *EltTy = V->getType()->getVectorElementType(); + + Type *NewTy = (NewNumElts == 1) ? EltTy : + VectorType::get(EltTy, NewNumElts); + + auto IID = II->getIntrinsicID(); + + bool IsBuffer = IID == Intrinsic::amdgcn_buffer_load || + IID == Intrinsic::amdgcn_buffer_load_format; + + Function *NewIntrin = IsBuffer ? + Intrinsic::getDeclaration(M, IID, NewTy) : + // Samplers have 3 mangled types. 
+ Intrinsic::getDeclaration(M, IID, + { NewTy, II->getArgOperand(0)->getType(), + II->getArgOperand(1)->getType()}); + + SmallVector<Value *, 5> Args; + for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) + Args.push_back(II->getArgOperand(I)); + + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(II); + + CallInst *NewCall = Builder.CreateCall(NewIntrin, Args); + NewCall->takeName(II); + NewCall->copyMetadata(*II); + + if (!IsBuffer) { + ConstantInt *DMask = dyn_cast<ConstantInt>(NewCall->getArgOperand(3)); + if (DMask) { + unsigned DMaskVal = DMask->getZExtValue() & 0xf; + + unsigned PopCnt = 0; + unsigned NewDMask = 0; + for (unsigned I = 0; I < 4; ++I) { + const unsigned Bit = 1 << I; + if (!!(DMaskVal & Bit)) { + if (++PopCnt > NewNumElts) + break; + + NewDMask |= Bit; + } + } + + NewCall->setArgOperand(3, ConstantInt::get(DMask->getType(), NewDMask)); + } + } + + + if (NewNumElts == 1) { + return Builder.CreateInsertElement(UndefValue::get(V->getType()), + NewCall, static_cast<uint64_t>(0)); + } + + SmallVector<uint32_t, 8> EltMask; + for (unsigned I = 0; I < VWidth; ++I) + EltMask.push_back(I); + + Value *Shuffle = Builder.CreateShuffleVector( + NewCall, UndefValue::get(NewTy), EltMask); + + MadeChange = true; + return Shuffle; + } } break; } diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index b2477f6..dd71a31 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -144,8 +144,9 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { } Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { - if (Value *V = SimplifyExtractElementInst( - EI.getVectorOperand(), EI.getIndexOperand(), DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyExtractElementInst(EI.getVectorOperand(), + EI.getIndexOperand(), + SQ.getWithInstruction(&EI))) return replaceInstUsesWith(EI, V); // If vector val is constant with all elements the same, replace EI with @@ -203,11 +204,11 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { if (I->hasOneUse() && cheapToScalarize(BO, isa<ConstantInt>(EI.getOperand(1)))) { Value *newEI0 = - Builder->CreateExtractElement(BO->getOperand(0), EI.getOperand(1), - EI.getName()+".lhs"); + Builder.CreateExtractElement(BO->getOperand(0), EI.getOperand(1), + EI.getName()+".lhs"); Value *newEI1 = - Builder->CreateExtractElement(BO->getOperand(1), EI.getOperand(1), - EI.getName()+".rhs"); + Builder.CreateExtractElement(BO->getOperand(1), EI.getOperand(1), + EI.getName()+".rhs"); return BinaryOperator::CreateWithCopiedFlags(BO->getOpcode(), newEI0, newEI1, BO); } @@ -249,8 +250,8 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { // Bitcasts can change the number of vector elements, and they cost // nothing. 
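// Illustrative aside, not part of the patch: the dmask-trimming loop above,
// extracted into a standalone helper (the name trimDMask is made up). It
// keeps only the first NewNumElts set channels of a 4-bit image dmask.
#include <cassert>

static unsigned trimDMask(unsigned DMaskVal, unsigned NewNumElts) {
  unsigned PopCnt = 0, NewDMask = 0;
  for (unsigned I = 0; I < 4; ++I) {
    const unsigned Bit = 1u << I;
    if (DMaskVal & Bit) {
      if (++PopCnt > NewNumElts)
        break;
      NewDMask |= Bit;
    }
  }
  return NewDMask;
}

int main() {
  assert(trimDMask(0xF, 2) == 0x3);  // keep the first two channels
  assert(trimDMask(0xA, 1) == 0x2);  // first set channel is bit 1
  assert(trimDMask(0x5, 4) == 0x5);  // nothing needs trimming
  return 0;
}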
if (CI->hasOneUse() && (CI->getOpcode() != Instruction::BitCast)) { - Value *EE = Builder->CreateExtractElement(CI->getOperand(0), - EI.getIndexOperand()); + Value *EE = Builder.CreateExtractElement(CI->getOperand(0), + EI.getIndexOperand()); Worklist.AddValue(EE); return CastInst::Create(CI->getOpcode(), EE, EI.getType()); } @@ -268,20 +269,20 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Value *Cond = SI->getCondition(); if (Cond->getType()->isVectorTy()) { - Cond = Builder->CreateExtractElement(Cond, - EI.getIndexOperand(), - Cond->getName() + ".elt"); + Cond = Builder.CreateExtractElement(Cond, + EI.getIndexOperand(), + Cond->getName() + ".elt"); } Value *V1Elem - = Builder->CreateExtractElement(TrueVal, - EI.getIndexOperand(), - TrueVal->getName() + ".elt"); + = Builder.CreateExtractElement(TrueVal, + EI.getIndexOperand(), + TrueVal->getName() + ".elt"); Value *V2Elem - = Builder->CreateExtractElement(FalseVal, - EI.getIndexOperand(), - FalseVal->getName() + ".elt"); + = Builder.CreateExtractElement(FalseVal, + EI.getIndexOperand(), + FalseVal->getName() + ".elt"); return SelectInst::Create(Cond, V1Elem, V2Elem, @@ -440,7 +441,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (!OldExt || OldExt->getParent() != WideVec->getParent()) continue; auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); - NewExt->insertAfter(WideVec); + NewExt->insertAfter(OldExt); IC.replaceInstUsesWith(*OldExt, NewExt); } } @@ -645,6 +646,36 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask); } +/// If we have an insertelement instruction feeding into another insertelement +/// and the 2nd is inserting a constant into the vector, canonicalize that +/// constant insertion before the insertion of a variable: +/// +/// insertelement (insertelement X, Y, IdxC1), ScalarC, IdxC2 --> +/// insertelement (insertelement X, ScalarC, IdxC2), Y, IdxC1 +/// +/// This has the potential of eliminating the 2nd insertelement instruction +/// via constant folding of the scalar constant into a vector constant. +static Instruction *hoistInsEltConst(InsertElementInst &InsElt2, + InstCombiner::BuilderTy &Builder) { + auto *InsElt1 = dyn_cast<InsertElementInst>(InsElt2.getOperand(0)); + if (!InsElt1 || !InsElt1->hasOneUse()) + return nullptr; + + Value *X, *Y; + Constant *ScalarC; + ConstantInt *IdxC1, *IdxC2; + if (match(InsElt1->getOperand(0), m_Value(X)) && + match(InsElt1->getOperand(1), m_Value(Y)) && !isa<Constant>(Y) && + match(InsElt1->getOperand(2), m_ConstantInt(IdxC1)) && + match(InsElt2.getOperand(1), m_Constant(ScalarC)) && + match(InsElt2.getOperand(2), m_ConstantInt(IdxC2)) && IdxC1 != IdxC2) { + Value *NewInsElt1 = Builder.CreateInsertElement(X, ScalarC, IdxC2); + return InsertElementInst::Create(NewInsElt1, Y, IdxC1); + } + + return nullptr; +} + /// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex /// --> shufflevector X, CVec', Mask' static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { @@ -806,6 +837,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE)) return Shuf; + if (Instruction *NewInsElt = hoistInsEltConst(IE, Builder)) + return NewInsElt; + // Turn a sequence of inserts that broadcasts a scalar into a single // insert + shufflevector. 
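// Illustrative aside, not part of the patch: the reordering performed by
// hoistInsEltConst above, modeled with plain arrays. Two insertions at
// distinct indices commute, so the constant insertion can be hoisted first
// and later folded into a constant vector.
#include <array>
#include <cassert>

int main() {
  std::array<int, 4> X = {10, 20, 30, 40};
  const int Y = 99;       // the variable scalar (not a constant in IR)
  const int ScalarC = 7;  // the constant scalar
  const int IdxC1 = 1, IdxC2 = 3;
  auto A = X; A[IdxC1] = Y; A[IdxC2] = ScalarC;  // insert Y, then ScalarC
  auto B = X; B[IdxC2] = ScalarC; B[IdxC1] = Y;  // hoisted: ScalarC first
  assert(A == B);  // identical because IdxC1 != IdxC2
  return 0;
}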
if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE)) @@ -986,9 +1020,9 @@ InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { SmallVector<Constant *, 16> MaskValues; for (int i = 0, e = Mask.size(); i != e; ++i) { if (Mask[i] == -1) - MaskValues.push_back(UndefValue::get(Builder->getInt32Ty())); + MaskValues.push_back(UndefValue::get(Builder.getInt32Ty())); else - MaskValues.push_back(Builder->getInt32(Mask[i])); + MaskValues.push_back(Builder.getInt32(Mask[i])); } return ConstantExpr::getShuffleVector(C, UndefValue::get(C->getType()), ConstantVector::get(MaskValues)); @@ -1061,7 +1095,7 @@ InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { Value *V = EvaluateInDifferentElementOrder(I->getOperand(0), Mask); return InsertElementInst::Create(V, I->getOperand(1), - Builder->getInt32(Index), "", I); + Builder.getInt32(Index), "", I); } } llvm_unreachable("failed to reorder elements of vector instruction!"); @@ -1107,12 +1141,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { SmallVector<int, 16> Mask = SVI.getShuffleMask(); Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); - bool MadeChange = false; - - // Undefined shuffle mask -> undefined value. - if (isa<UndefValue>(SVI.getOperand(2))) - return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); + if (auto *V = SimplifyShuffleVectorInst( + LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) + return replaceInstUsesWith(SVI, V); + bool MadeChange = false; unsigned VWidth = SVI.getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); @@ -1209,7 +1242,6 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (isShuffleExtractingFromLHS(SVI, Mask)) { Value *V = LHS; unsigned MaskElems = Mask.size(); - unsigned BegIdx = Mask.front(); VectorType *SrcTy = cast<VectorType>(V->getType()); unsigned VecBitWidth = SrcTy->getBitWidth(); unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType()); @@ -1223,6 +1255,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // Only visit bitcasts that weren't previously handled. BCs.push_back(BC); for (BitCastInst *BC : BCs) { + unsigned BegIdx = Mask.front(); Type *TgtTy = BC->getDestTy(); unsigned TgtElemBitWidth = DL.getTypeSizeInBits(TgtTy); if (!TgtElemBitWidth) @@ -1242,9 +1275,9 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { UndefValue::get(Int32Ty)); for (unsigned I = 0, E = MaskElems, Idx = BegIdx; I != E; ++Idx, ++I) ShuffleMask[I] = ConstantInt::get(Int32Ty, Idx); - V = Builder->CreateShuffleVector(V, UndefValue::get(V->getType()), - ConstantVector::get(ShuffleMask), - SVI.getName() + ".extract"); + V = Builder.CreateShuffleVector(V, UndefValue::get(V->getType()), + ConstantVector::get(ShuffleMask), + SVI.getName() + ".extract"); BegIdx = 0; } unsigned SrcElemsPerTgtElem = TgtElemBitWidth / SrcElemBitWidth; @@ -1254,10 +1287,10 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { auto *NewBC = BCAlreadyExists ? NewBCs[CastSrcTy] - : Builder->CreateBitCast(V, CastSrcTy, SVI.getName() + ".bc"); + : Builder.CreateBitCast(V, CastSrcTy, SVI.getName() + ".bc"); if (!BCAlreadyExists) NewBCs[CastSrcTy] = NewBC; - auto *Ext = Builder->CreateExtractElement( + auto *Ext = Builder.CreateExtractElement( NewBC, ConstantInt::get(Int32Ty, BegIdx), SVI.getName() + ".extract"); // The shufflevector isn't being replaced: the bitcast that used it // is. 
InstCombine will visit the newly-created instructions. diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 27fc34d..c776656 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -33,7 +33,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/InstCombine/InstCombine.h" #include "InstCombineInternal.h" #include "llvm-c/Initialization.h" #include "llvm/ADT/SmallPtrSet.h" @@ -60,7 +59,9 @@ #include "llvm/IR/ValueHandle.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> @@ -82,18 +83,24 @@ static cl::opt<bool> EnableExpensiveCombines("expensive-combines", cl::desc("Enable expensive instruction combines")); +static cl::opt<unsigned> +MaxArraySize("instcombine-maxarray-size", cl::init(1024), + cl::desc("Maximum array size considered when doing a combine")); + Value *InstCombiner::EmitGEPOffset(User *GEP) { - return llvm::EmitGEPOffset(Builder, DL, GEP); + return llvm::EmitGEPOffset(&Builder, DL, GEP); } /// Return true if it is desirable to convert an integer computation from a /// given bit width to a new bit width. -/// We don't want to convert from a legal to an illegal type for example or from -/// a smaller to a larger illegal type. -bool InstCombiner::ShouldChangeType(unsigned FromWidth, +/// We don't want to convert from a legal to an illegal type or from a smaller +/// to a larger illegal type. A width of '1' is always treated as a legal type +/// because i1 is a fundamental type in IR, and there are many specialized +/// optimizations for i1 types. +bool InstCombiner::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { - bool FromLegal = DL.isLegalInteger(FromWidth); - bool ToLegal = DL.isLegalInteger(ToWidth); + bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); + bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth); // If this is a legal integer from type, and the result would be an illegal // type, don't do the transformation. @@ -109,14 +116,16 @@ bool InstCombiner::ShouldChangeType(unsigned FromWidth, } /// Return true if it is desirable to convert a computation from 'From' to 'To'. -/// We don't want to convert from a legal to an illegal type for example or from -/// a smaller to a larger illegal type. -bool InstCombiner::ShouldChangeType(Type *From, Type *To) const { +/// We don't want to convert from a legal to an illegal type or from a smaller +/// to a larger illegal type. i1 is always treated as a legal type because it is +/// a fundamental type in IR, and there are many specialized optimizations for +/// i1 types. +bool InstCombiner::shouldChangeType(Type *From, Type *To) const { assert(From->isIntegerTy() && To->isIntegerTy()); unsigned FromWidth = From->getPrimitiveSizeInBits(); unsigned ToWidth = To->getPrimitiveSizeInBits(); - return ShouldChangeType(FromWidth, ToWidth); + return shouldChangeType(FromWidth, ToWidth); } // Return true, if No Signed Wrap should be maintained for I. 
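The i1 carve-out introduced in shouldChangeType above is self-contained enough to check in isolation. A small sketch of the same legality rule, using a hypothetical target whose only legal integer widths are 32 and 64 (standing in for DataLayout::isLegalInteger):

#include <cassert>

static bool isLegalInteger(unsigned W) { return W == 32 || W == 64; }

static bool shouldChangeType(unsigned FromWidth, unsigned ToWidth) {
  // i1 is treated as legal unconditionally: it is fundamental in IR and
  // has many dedicated combines.
  bool FromLegal = FromWidth == 1 || isLegalInteger(FromWidth);
  bool ToLegal = ToWidth == 1 || isLegalInteger(ToWidth);
  // Never replace a legal type with an illegal one.
  if (FromLegal && !ToLegal)
    return false;
  // Between two illegal types, shrinking is fine but widening is not.
  if (!FromLegal && !ToLegal && FromWidth < ToWidth)
    return false;
  return true;
}

int main() {
  assert(shouldChangeType(32, 1));   // legal -> i1: ok
  assert(!shouldChangeType(32, 17)); // legal -> illegal: rejected
  assert(shouldChangeType(160, 64)); // illegal -> legal: ok
  assert(!shouldChangeType(17, 24)); // illegal widening: rejected
  return 0;
}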
@@ -140,9 +149,9 @@ static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) { bool Overflow = false; if (Opcode == Instruction::Add) - BVal->sadd_ov(*CVal, Overflow); + (void)BVal->sadd_ov(*CVal, Overflow); else - BVal->ssub_ov(*CVal, Overflow); + (void)BVal->ssub_ov(*CVal, Overflow); return !Overflow; } @@ -247,7 +256,7 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = I.getOperand(1); // Does "B op C" simplify? - if (Value *V = SimplifyBinOp(Opcode, B, C, DL)) { + if (Value *V = SimplifyBinOp(Opcode, B, C, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "A op V". I.setOperand(0, A); I.setOperand(1, V); @@ -276,7 +285,7 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = Op1->getOperand(1); // Does "A op B" simplify? - if (Value *V = SimplifyBinOp(Opcode, A, B, DL)) { + if (Value *V = SimplifyBinOp(Opcode, A, B, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "V op C". I.setOperand(0, V); I.setOperand(1, C); @@ -304,7 +313,7 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = I.getOperand(1); // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, DL)) { + if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "V op B". I.setOperand(0, V); I.setOperand(1, B); @@ -324,7 +333,7 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = Op1->getOperand(1); // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, DL)) { + if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "B op V". I.setOperand(0, B); I.setOperand(1, V); @@ -447,16 +456,11 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp, /// This function returns identity value for given opcode, which can be used to /// factor patterns like (X * 2) + X ==> (X * 2) + (X * 1) ==> X * (2 + 1). -static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) { +static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) { if (isa<Constant>(V)) return nullptr; - if (OpCode == Instruction::Mul) - return ConstantInt::get(V->getType(), 1); - - // TODO: We can handle other cases e.g. Instruction::And, Instruction::Or etc. - - return nullptr; + return ConstantExpr::getBinOpIdentity(Opcode, V->getType()); } /// This function factors binary ops which can be combined using distributive @@ -468,8 +472,7 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) { static Instruction::BinaryOps getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode, BinaryOperator *Op, Value *&LHS, Value *&RHS) { - if (!Op) - return Instruction::BinaryOpsEnd; + assert(Op && "Expected a binary operator"); LHS = Op->getOperand(0); RHS = Op->getOperand(1); @@ -495,15 +498,10 @@ getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode, /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). -static Value *tryFactorization(InstCombiner::BuilderTy *Builder, - const DataLayout &DL, BinaryOperator &I, - Instruction::BinaryOps InnerOpcode, Value *A, - Value *B, Value *C, Value *D) { - - // If any of A, B, C, D are null, we can not factor I, return early. - // Checking A and C should be enough. 
- if (!A || !C || !B || !D) - return nullptr; +Value *InstCombiner::tryFactorization(BinaryOperator &I, + Instruction::BinaryOps InnerOpcode, + Value *A, Value *B, Value *C, Value *D) { + assert(A && B && C && D && "All values must be provided"); Value *V = nullptr; Value *SimplifiedInst = nullptr; @@ -522,13 +520,13 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder, std::swap(C, D); // Consider forming "A op' (B op D)". // If "B op D" simplifies then it can be formed with no cost. - V = SimplifyBinOp(TopLevelOpcode, B, D, DL); + V = SimplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I)); // If "B op D" doesn't simplify then only go on if both of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. if (!V && LHS->hasOneUse() && RHS->hasOneUse()) - V = Builder->CreateBinOp(TopLevelOpcode, B, D, RHS->getName()); + V = Builder.CreateBinOp(TopLevelOpcode, B, D, RHS->getName()); if (V) { - SimplifiedInst = Builder->CreateBinOp(InnerOpcode, A, V); + SimplifiedInst = Builder.CreateBinOp(InnerOpcode, A, V); } } @@ -541,14 +539,14 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder, std::swap(C, D); // Consider forming "(A op C) op' B". // If "A op C" simplifies then it can be formed with no cost. - V = SimplifyBinOp(TopLevelOpcode, A, C, DL); + V = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); // If "A op C" doesn't simplify then only go on if both of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. if (!V && LHS->hasOneUse() && RHS->hasOneUse()) - V = Builder->CreateBinOp(TopLevelOpcode, A, C, LHS->getName()); + V = Builder.CreateBinOp(TopLevelOpcode, A, C, LHS->getName()); if (V) { - SimplifiedInst = Builder->CreateBinOp(InnerOpcode, V, B); + SimplifiedInst = Builder.CreateBinOp(InnerOpcode, V, B); } } @@ -564,13 +562,11 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder, if (isa<OverflowingBinaryOperator>(&I)) HasNSW = I.hasNoSignedWrap(); - if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS)) - if (isa<OverflowingBinaryOperator>(Op0)) - HasNSW &= Op0->hasNoSignedWrap(); + if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) + HasNSW &= LOBO->hasNoSignedWrap(); - if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS)) - if (isa<OverflowingBinaryOperator>(Op1)) - HasNSW &= Op1->hasNoSignedWrap(); + if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) + HasNSW &= ROBO->hasNoSignedWrap(); // We can propagate 'nsw' if we know that // %Y = mul nsw i16 %X, C @@ -599,31 +595,39 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); + Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); - // Factorization. - Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - auto TopLevelOpcode = I.getOpcode(); - auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); - auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); - - // The instruction has the form "(A op' B) op (C op' D)". Try to factorize - // a common term. - if (LHSOpcode == RHSOpcode) { - if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D)) - return V; - } - - // The instruction has the form "(A op' B) op (C)". Try to factorize common - // term. 
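The identity-value trick that getIdentityValue now delegates to ConstantExpr::getBinOpIdentity is plain algebra: when only one side of the top-level op is a product, the bare operand is treated as "operand times identity" so the common term can still be factored. A toy check with ordinary integers (no LLVM machinery involved):

#include <cassert>

int main() {
  int A = 7, B = 2, D = 3, X = 5;
  // "(A op' B) op (A op' D)" -> "A op' (B op D)": mul distributes over add.
  assert(A * B + A * D == A * (B + D));
  // "(X * 2) + X" factors via the identity value 1: X == X * 1,
  // so the sum becomes X * (2 + 1).
  assert(X * 2 + X == X * (2 + 1));
  return 0;
}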
- if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS, - getIdentityValue(LHSOpcode, RHS))) - return V; + { + // Factorization. + Value *A, *B, *C, *D; + Instruction::BinaryOps LHSOpcode, RHSOpcode; + if (Op0) + LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); + if (Op1) + RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); + + // The instruction has the form "(A op' B) op (C op' D)". Try to factorize + // a common term. + if (Op0 && Op1 && LHSOpcode == RHSOpcode) + if (Value *V = tryFactorization(I, LHSOpcode, A, B, C, D)) + return V; + + // The instruction has the form "(A op' B) op (C)". Try to factorize common + // term. + if (Op0) + if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) + if (Value *V = + tryFactorization(I, LHSOpcode, A, B, RHS, Ident)) + return V; - // The instruction has the form "(B) op (C op' D)". Try to factorize common - // term. - if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS, - getIdentityValue(RHSOpcode, LHS), C, D)) - return V; + // The instruction has the form "(B) op (C op' D)". Try to factorize common + // term. + if (Op1) + if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) + if (Value *V = + tryFactorization(I, RHSOpcode, LHS, Ident, C, D)) + return V; + } // Expansion. if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { @@ -632,23 +636,35 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; Instruction::BinaryOps InnerOpcode = Op0->getOpcode(); // op' + Value *L = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); + Value *R = SimplifyBinOp(TopLevelOpcode, B, C, SQ.getWithInstruction(&I)); + // Do "A op C" and "B op C" both simplify? - if (Value *L = SimplifyBinOp(TopLevelOpcode, A, C, DL)) - if (Value *R = SimplifyBinOp(TopLevelOpcode, B, C, DL)) { - // They do! Return "L op' R". - ++NumExpand; - // If "L op' R" equals "A op' B" then "L op' R" is just the LHS. - if ((L == A && R == B) || - (Instruction::isCommutative(InnerOpcode) && L == B && R == A)) - return Op0; - // Otherwise return "L op' R" if it simplifies. - if (Value *V = SimplifyBinOp(InnerOpcode, L, R, DL)) - return V; - // Otherwise, create a new instruction. - C = Builder->CreateBinOp(InnerOpcode, L, R); - C->takeName(&I); - return C; - } + if (L && R) { + // They do! Return "L op' R". + ++NumExpand; + C = Builder.CreateBinOp(InnerOpcode, L, R); + C->takeName(&I); + return C; + } + + // Does "A op C" simplify to the identity value for the inner opcode? + if (L && L == ConstantExpr::getBinOpIdentity(InnerOpcode, L->getType())) { + // They do! Return "B op C". + ++NumExpand; + C = Builder.CreateBinOp(TopLevelOpcode, B, C); + C->takeName(&I); + return C; + } + + // Does "B op C" simplify to the identity value for the inner opcode? + if (R && R == ConstantExpr::getBinOpIdentity(InnerOpcode, R->getType())) { + // They do! Return "A op C". 
+ ++NumExpand; + C = Builder.CreateBinOp(TopLevelOpcode, A, C); + C->takeName(&I); + return C; + } } if (Op1 && LeftDistributesOverRight(TopLevelOpcode, Op1->getOpcode())) { @@ -657,23 +673,35 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); Instruction::BinaryOps InnerOpcode = Op1->getOpcode(); // op' + Value *L = SimplifyBinOp(TopLevelOpcode, A, B, SQ.getWithInstruction(&I)); + Value *R = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); + // Do "A op B" and "A op C" both simplify? - if (Value *L = SimplifyBinOp(TopLevelOpcode, A, B, DL)) - if (Value *R = SimplifyBinOp(TopLevelOpcode, A, C, DL)) { - // They do! Return "L op' R". - ++NumExpand; - // If "L op' R" equals "B op' C" then "L op' R" is just the RHS. - if ((L == B && R == C) || - (Instruction::isCommutative(InnerOpcode) && L == C && R == B)) - return Op1; - // Otherwise return "L op' R" if it simplifies. - if (Value *V = SimplifyBinOp(InnerOpcode, L, R, DL)) - return V; - // Otherwise, create a new instruction. - A = Builder->CreateBinOp(InnerOpcode, L, R); - A->takeName(&I); - return A; - } + if (L && R) { + // They do! Return "L op' R". + ++NumExpand; + A = Builder.CreateBinOp(InnerOpcode, L, R); + A->takeName(&I); + return A; + } + + // Does "A op B" simplify to the identity value for the inner opcode? + if (L && L == ConstantExpr::getBinOpIdentity(InnerOpcode, L->getType())) { + // They do! Return "A op C". + ++NumExpand; + A = Builder.CreateBinOp(TopLevelOpcode, A, C); + A->takeName(&I); + return A; + } + + // Does "A op C" simplify to the identity value for the inner opcode? + if (R && R == ConstantExpr::getBinOpIdentity(InnerOpcode, R->getType())) { + // They do! Return "A op B". 
+ ++NumExpand; + A = Builder.CreateBinOp(TopLevelOpcode, A, B); + A->takeName(&I); + return A; + } } // (op (select (a, c, b)), (select (a, d, b))) -> (select (a, (op c, d), 0)) @@ -682,19 +710,21 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { if (auto *SI1 = dyn_cast<SelectInst>(RHS)) { if (SI0->getCondition() == SI1->getCondition()) { Value *SI = nullptr; - if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue(), DL, &TLI, &DT, &AC)) - SI = Builder->CreateSelect(SI0->getCondition(), - Builder->CreateBinOp(TopLevelOpcode, - SI0->getTrueValue(), - SI1->getTrueValue()), - V); - if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(), - SI1->getTrueValue(), DL, &TLI, &DT, &AC)) - SI = Builder->CreateSelect( + if (Value *V = + SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(), + SI1->getFalseValue(), SQ.getWithInstruction(&I))) + SI = Builder.CreateSelect(SI0->getCondition(), + Builder.CreateBinOp(TopLevelOpcode, + SI0->getTrueValue(), + SI1->getTrueValue()), + V); + if (Value *V = + SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(), + SI1->getTrueValue(), SQ.getWithInstruction(&I))) + SI = Builder.CreateSelect( SI0->getCondition(), V, - Builder->CreateBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue())); + Builder.CreateBinOp(TopLevelOpcode, SI0->getFalseValue(), + SI1->getFalseValue())); if (SI) { SI->takeName(&I); return SI; @@ -720,6 +750,21 @@ Value *InstCombiner::dyn_castNegVal(Value *V) const { if (C->getType()->getElementType()->isIntegerTy()) return ConstantExpr::getNeg(C); + if (ConstantVector *CV = dyn_cast<ConstantVector>(V)) { + for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { + Constant *Elt = CV->getAggregateElement(i); + if (!Elt) + return nullptr; + + if (isa<UndefValue>(Elt)) + continue; + + if (!isa<ConstantInt>(Elt)) + return nullptr; + } + return ConstantExpr::getNeg(CV); + } + return nullptr; } @@ -741,9 +786,9 @@ Value *InstCombiner::dyn_castFNegVal(Value *V, bool IgnoreZeroSign) const { } static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, - InstCombiner *IC) { + InstCombiner::BuilderTy &Builder) { if (auto *Cast = dyn_cast<CastInst>(&I)) - return IC->Builder->CreateCast(Cast->getOpcode(), SO, I.getType()); + return Builder.CreateCast(Cast->getOpcode(), SO, I.getType()); assert(I.isBinaryOp() && "Unexpected opcode for select folding"); @@ -762,8 +807,8 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, std::swap(Op0, Op1); auto *BO = cast<BinaryOperator>(&I); - Value *RI = IC->Builder->CreateBinOp(BO->getOpcode(), Op0, Op1, - SO->getName() + ".op"); + Value *RI = Builder.CreateBinOp(BO->getOpcode(), Op0, Op1, + SO->getName() + ".op"); auto *FPInst = dyn_cast<Instruction>(RI); if (FPInst && isa<FPMathOperator>(FPInst)) FPInst->copyFastMathFlags(BO); @@ -781,7 +826,7 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { return nullptr; // Bool selects with constant operands can be folded to logical ops. 
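The two identity-value shortcuts added to the expansion above avoid materializing both halves of "(A op' B) op C" when one half collapses. For example, expanding "(A & B) | C" gives "(A | C) & (B | C)"; if "A | C" is all-ones, the identity of '&', the whole expression is just "B | C". A toy verification of that identity:

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t A = 0xF0, C = 0x0F; // chosen so that (A | C) == 0xFF
  for (unsigned B = 0; B < 256; ++B) {
    // (A|C) is the identity of '&', so ((A|C) & (B|C)) == (B|C).
    assert(uint8_t((A & B) | C) == uint8_t(B | C));
  }
  return 0;
}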
- if (SI->getType()->getScalarType()->isIntegerTy(1)) + if (SI->getType()->isIntOrIntVectorTy(1)) return nullptr; // If it's a bitcast involving vectors, make sure it has the same number of @@ -815,13 +860,34 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { } } - Value *NewTV = foldOperationIntoSelectOperand(Op, TV, this); - Value *NewFV = foldOperationIntoSelectOperand(Op, FV, this); + Value *NewTV = foldOperationIntoSelectOperand(Op, TV, Builder); + Value *NewFV = foldOperationIntoSelectOperand(Op, FV, Builder); return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } -Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { - PHINode *PN = cast<PHINode>(I.getOperand(0)); +static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, + InstCombiner::BuilderTy &Builder) { + bool ConstIsRHS = isa<Constant>(I->getOperand(1)); + Constant *C = cast<Constant>(I->getOperand(ConstIsRHS)); + + if (auto *InC = dyn_cast<Constant>(InV)) { + if (ConstIsRHS) + return ConstantExpr::get(I->getOpcode(), InC, C); + return ConstantExpr::get(I->getOpcode(), C, InC); + } + + Value *Op0 = InV, *Op1 = C; + if (!ConstIsRHS) + std::swap(Op0, Op1); + + Value *RI = Builder.CreateBinOp(I->getOpcode(), Op0, Op1, "phitmp"); + auto *FPInst = dyn_cast<Instruction>(RI); + if (FPInst && isa<FPMathOperator>(FPInst)) + FPInst->copyFastMathFlags(I); + return RI; +} + +Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { unsigned NumPHIValues = PN->getNumIncomingValues(); if (NumPHIValues == 0) return nullptr; @@ -885,7 +951,7 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { // If we are going to have to insert a new computation, do so right before the // predecessor's terminator. if (NonConstBB) - Builder->SetInsertPoint(NonConstBB->getTerminator()); + Builder.SetInsertPoint(NonConstBB->getTerminator()); // Next, add all of the operands to the PHI. if (SelectInst *SI = dyn_cast<SelectInst>(&I)) { @@ -902,11 +968,25 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { // Beware of ConstantExpr: it may eventually evaluate to getNullValue, // even if currently isNullValue gives false. Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)); - if (InC && !isa<ConstantExpr>(InC)) + // For vector constants, we cannot use isNullValue to fold into + // FalseVInPred versus TrueVInPred. When we have individual nonzero + // elements in the vector, we will incorrectly fold InC to + // `TrueVInPred`. + if (InC && !isa<ConstantExpr>(InC) && isa<ConstantInt>(InC)) InV = InC->isNullValue() ? FalseVInPred : TrueVInPred; - else - InV = Builder->CreateSelect(PN->getIncomingValue(i), - TrueVInPred, FalseVInPred, "phitmp"); + else { + // Generate the select in the same block as PN's current incoming block. + // Note: ThisBB need not be the NonConstBB because vector constants + // which are constants by definition are handled here. + // FIXME: This can lead to an increase in IR generation because we might + // generate selects for vector constant phi operand, that could not be + // folded to TrueVInPred or FalseVInPred as done for ConstantInt. For + // non-vector phis, this transformation was always profitable because + // the select would be generated exactly once in the NonConstBB. 
+ Builder.SetInsertPoint(ThisBB->getTerminator()); + InV = Builder.CreateSelect(PN->getIncomingValue(i), TrueVInPred, + FalseVInPred, "phitmp"); + } NewPN->addIncoming(InV, ThisBB); } } else if (CmpInst *CI = dyn_cast<CmpInst>(&I)) { @@ -916,22 +996,17 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C); else if (isa<ICmpInst>(CI)) - InV = Builder->CreateICmp(CI->getPredicate(), PN->getIncomingValue(i), - C, "phitmp"); + InV = Builder.CreateICmp(CI->getPredicate(), PN->getIncomingValue(i), + C, "phitmp"); else - InV = Builder->CreateFCmp(CI->getPredicate(), PN->getIncomingValue(i), - C, "phitmp"); + InV = Builder.CreateFCmp(CI->getPredicate(), PN->getIncomingValue(i), + C, "phitmp"); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } - } else if (I.getNumOperands() == 2) { - Constant *C = cast<Constant>(I.getOperand(1)); + } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) { for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV = nullptr; - if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) - InV = ConstantExpr::get(I.getOpcode(), InC, C); - else - InV = Builder->CreateBinOp(cast<BinaryOperator>(I).getOpcode(), - PN->getIncomingValue(i), C, "phitmp"); + Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), + Builder); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } } else { @@ -942,8 +1017,8 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy); else - InV = Builder->CreateCast(CI->getOpcode(), - PN->getIncomingValue(i), I.getType(), "phitmp"); + InV = Builder.CreateCast(CI->getOpcode(), PN->getIncomingValue(i), + I.getType(), "phitmp"); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } } @@ -957,14 +1032,14 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { return replaceInstUsesWith(I, NewPN); } -Instruction *InstCombiner::foldOpWithConstantIntoOperand(Instruction &I) { +Instruction *InstCombiner::foldOpWithConstantIntoOperand(BinaryOperator &I) { assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type"); if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) return NewSel; - } else if (isa<PHINode>(I.getOperand(0))) { - if (Instruction *NewPhi = FoldOpIntoPhi(I)) + } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) { + if (Instruction *NewPhi = foldOpIntoPhi(I, PN)) return NewPhi; } return nullptr; @@ -1289,8 +1364,8 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { /// \brief Creates node of binary operation with the same attributes as the /// specified one but with other operands. static Value *CreateBinOpAsGiven(BinaryOperator &Inst, Value *LHS, Value *RHS, - InstCombiner::BuilderTy *B) { - Value *BO = B->CreateBinOp(Inst.getOpcode(), LHS, RHS); + InstCombiner::BuilderTy &B) { + Value *BO = B.CreateBinOp(Inst.getOpcode(), LHS, RHS); // If LHS and RHS are constant, BO won't be a binary operator. 
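The caveat in the comment just above, and the dyn_cast guarding it below, come from IRBuilder's on-the-fly folding: with two constant operands, CreateBinOp hands back a plain Constant, so there is no instruction to copy flags onto. A minimal demonstration, assuming LLVM headers:

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  IRBuilder<> B(Ctx);
  // 2 + 3 folds immediately; V is the constant i32 5.
  Value *V = B.CreateBinOp(Instruction::Add, B.getInt32(2), B.getInt32(3));
  return isa<BinaryOperator>(V) ? 1 : 0; // 0: not a BinaryOperator
}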
if (BinaryOperator *NewBO = dyn_cast<BinaryOperator>(BO)) NewBO->copyIRFlags(&Inst); @@ -1315,22 +1390,19 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth); assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth); - // If both arguments of binary operation are shuffles, which use the same - // mask and shuffle within a single vector, it is worthwhile to move the - // shuffle after binary operation: + // If both arguments of the binary operation are shuffles that use the same + // mask and shuffle within a single vector, move the shuffle after the binop: // Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m) - if (isa<ShuffleVectorInst>(LHS) && isa<ShuffleVectorInst>(RHS)) { - ShuffleVectorInst *LShuf = cast<ShuffleVectorInst>(LHS); - ShuffleVectorInst *RShuf = cast<ShuffleVectorInst>(RHS); - if (isa<UndefValue>(LShuf->getOperand(1)) && - isa<UndefValue>(RShuf->getOperand(1)) && - LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType() && - LShuf->getMask() == RShuf->getMask()) { - Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), - RShuf->getOperand(0), Builder); - return Builder->CreateShuffleVector(NewBO, - UndefValue::get(NewBO->getType()), LShuf->getMask()); - } + auto *LShuf = dyn_cast<ShuffleVectorInst>(LHS); + auto *RShuf = dyn_cast<ShuffleVectorInst>(RHS); + if (LShuf && RShuf && LShuf->getMask() == RShuf->getMask() && + isa<UndefValue>(LShuf->getOperand(1)) && + isa<UndefValue>(RShuf->getOperand(1)) && + LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType()) { + Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), + RShuf->getOperand(0), Builder); + return Builder.CreateShuffleVector( + NewBO, UndefValue::get(NewBO->getType()), LShuf->getMask()); } // If one argument is a shuffle within one vector, the other is a constant, @@ -1368,7 +1440,7 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Value *NewLHS = isa<Constant>(LHS) ? C2 : Shuffle->getOperand(0); Value *NewRHS = isa<Constant>(LHS) ? Shuffle->getOperand(0) : C2; Value *NewBO = CreateBinOpAsGiven(Inst, NewLHS, NewRHS, Builder); - return Builder->CreateShuffleVector(NewBO, + return Builder.CreateShuffleVector(NewBO, UndefValue::get(Inst.getType()), Shuffle->getMask()); } } @@ -1379,8 +1451,8 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = - SimplifyGEPInst(GEP.getSourceElementType(), Ops, DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, + SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1416,7 +1488,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // If we are using a wider index than needed for this platform, shrink // it to what we need. If narrower, sign-extend it to what we need. // This explicit cast can make subsequent optimizations more obvious. - *I = Builder->CreateIntCast(*I, NewIndexType, true); + *I = Builder.CreateIntCast(*I, NewIndexType, true); MadeChange = true; } } @@ -1510,10 +1582,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // set that index. 
PHINode *NewPN; { - IRBuilderBase::InsertPointGuard Guard(*Builder); - Builder->SetInsertPoint(PN); - NewPN = Builder->CreatePHI(Op1->getOperand(DI)->getType(), - PN->getNumOperands()); + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(PN); + NewPN = Builder.CreatePHI(Op1->getOperand(DI)->getType(), + PN->getNumOperands()); } for (auto &I : PN->operands()) @@ -1559,27 +1631,22 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Replace: gep (gep %P, long B), long A, ... // With: T = long A+B; gep %P, T, ... // - Value *Sum; Value *SO1 = Src->getOperand(Src->getNumOperands()-1); Value *GO1 = GEP.getOperand(1); - if (SO1 == Constant::getNullValue(SO1->getType())) { - Sum = GO1; - } else if (GO1 == Constant::getNullValue(GO1->getType())) { - Sum = SO1; - } else { - // If they aren't the same type, then the input hasn't been processed - // by the loop above yet (which canonicalizes sequential index types to - // intptr_t). Just avoid transforming this until the input has been - // normalized. - if (SO1->getType() != GO1->getType()) - return nullptr; - // Only do the combine when GO1 and SO1 are both constants. Only in - // this case, we are sure the cost after the merge is never more than - // that before the merge. - if (!isa<Constant>(GO1) || !isa<Constant>(SO1)) - return nullptr; - Sum = Builder->CreateAdd(SO1, GO1, PtrOp->getName()+".sum"); - } + + // If they aren't the same type, then the input hasn't been processed + // by the loop above yet (which canonicalizes sequential index types to + // intptr_t). Just avoid transforming this until the input has been + // normalized. + if (SO1->getType() != GO1->getType()) + return nullptr; + + Value *Sum = + SimplifyAddInst(GO1, SO1, false, false, SQ.getWithInstruction(&GEP)); + // Only do the combine when we are sure the cost after the + // merge is never more than that before the merge. + if (Sum == nullptr) + return nullptr; // Update the GEP in place if possible. if (Src->getNumOperands() == 2) { @@ -1638,8 +1705,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // pointer arithmetic. if (match(V, m_Neg(m_PtrToInt(m_Value())))) { Operator *Index = cast<Operator>(V); - Value *PtrToInt = Builder->CreatePtrToInt(PtrOp, Index->getType()); - Value *NewSub = Builder->CreateSub(PtrToInt, Index->getOperand(1)); + Value *PtrToInt = Builder.CreatePtrToInt(PtrOp, Index->getType()); + Value *NewSub = Builder.CreateSub(PtrToInt, Index->getOperand(1)); return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType()); } // Canonicalize (gep i8* X, (ptrtoint Y)-(ptrtoint X)) @@ -1654,14 +1721,14 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } - // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). - Value *StrippedPtr = PtrOp->stripPointerCasts(); - PointerType *StrippedPtrTy = dyn_cast<PointerType>(StrippedPtr->getType()); - // We do not handle pointer-vector geps here. - if (!StrippedPtrTy) + if (GEP.getType()->isVectorTy()) return nullptr; + // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). + Value *StrippedPtr = PtrOp->stripPointerCasts(); + PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType()); + if (StrippedPtr != PtrOp) { bool HasZeroPointerIndex = false; if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) @@ -1692,7 +1759,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // -> // %0 = GEP i8 addrspace(1)* X, ... 
// addrspacecast i8 addrspace(1)* %0 to i8* - return new AddrSpaceCastInst(Builder->Insert(Res), GEP.getType()); + return new AddrSpaceCastInst(Builder.Insert(Res), GEP.getType()); } if (ArrayType *XATy = @@ -1720,10 +1787,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // addrspacecast i8 addrspace(1)* %0 to i8* SmallVector<Value*, 8> Idx(GEP.idx_begin(), GEP.idx_end()); Value *NewGEP = GEP.isInBounds() - ? Builder->CreateInBoundsGEP( + ? Builder.CreateInBoundsGEP( nullptr, StrippedPtr, Idx, GEP.getName()) - : Builder->CreateGEP(nullptr, StrippedPtr, Idx, - GEP.getName()); + : Builder.CreateGEP(nullptr, StrippedPtr, Idx, + GEP.getName()); return new AddrSpaceCastInst(NewGEP, GEP.getType()); } } @@ -1741,9 +1808,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Value *Idx[2] = { Constant::getNullValue(IdxType), GEP.getOperand(1) }; Value *NewGEP = GEP.isInBounds() - ? Builder->CreateInBoundsGEP(nullptr, StrippedPtr, Idx, - GEP.getName()) - : Builder->CreateGEP(nullptr, StrippedPtr, Idx, GEP.getName()); + ? Builder.CreateInBoundsGEP(nullptr, StrippedPtr, Idx, + GEP.getName()) + : Builder.CreateGEP(nullptr, StrippedPtr, Idx, GEP.getName()); // V and GEP are both pointer types --> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, @@ -1776,10 +1843,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // GEP may not be "inbounds". Value *NewGEP = GEP.isInBounds() && NSW - ? Builder->CreateInBoundsGEP(nullptr, StrippedPtr, NewIdx, - GEP.getName()) - : Builder->CreateGEP(nullptr, StrippedPtr, NewIdx, - GEP.getName()); + ? Builder.CreateInBoundsGEP(nullptr, StrippedPtr, NewIdx, + GEP.getName()) + : Builder.CreateGEP(nullptr, StrippedPtr, NewIdx, + GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, @@ -1818,10 +1885,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { NewIdx}; Value *NewGEP = GEP.isInBounds() && NSW - ? Builder->CreateInBoundsGEP( + ? Builder.CreateInBoundsGEP( SrcElTy, StrippedPtr, Off, GEP.getName()) - : Builder->CreateGEP(SrcElTy, StrippedPtr, Off, - GEP.getName()); + : Builder.CreateGEP(SrcElTy, StrippedPtr, Off, + GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEP.getType()); @@ -1885,8 +1952,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (FindElementAtOffset(OpType, Offset.getSExtValue(), NewIndices)) { Value *NGEP = GEP.isInBounds() - ? Builder->CreateInBoundsGEP(nullptr, Operand, NewIndices) - : Builder->CreateGEP(nullptr, Operand, NewIndices); + ? 
Builder.CreateInBoundsGEP(nullptr, Operand, NewIndices) + : Builder.CreateGEP(nullptr, Operand, NewIndices); if (NGEP->getType() == GEP.getType()) return replaceInstUsesWith(GEP, NGEP); @@ -1935,9 +2002,9 @@ static bool isNeverEqualToUnescapedAlloc(Value *V, const TargetLibraryInfo *TLI, return isAllocLikeFn(V, TLI) && V != AI; } -static bool -isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, - const TargetLibraryInfo *TLI) { +static bool isAllocSiteRemovable(Instruction *AI, + SmallVectorImpl<WeakTrackingVH> &Users, + const TargetLibraryInfo *TLI) { SmallVector<Instruction*, 4> Worklist; Worklist.push_back(AI); @@ -1950,6 +2017,7 @@ isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, // Give up the moment we see something we can't handle. return false; + case Instruction::AddrSpaceCast: case Instruction::BitCast: case Instruction::GetElementPtr: Users.emplace_back(I); @@ -2021,7 +2089,7 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { // If we have a malloc call which is only used in any amount of comparisons // to null and free calls, delete the calls and replace the comparisons with // true or false as appropriate. - SmallVector<WeakVH, 64> Users; + SmallVector<WeakTrackingVH, 64> Users; if (isAllocSiteRemovable(&MI, Users, &TLI)) { for (unsigned i = 0, e = Users.size(); i != e; ++i) { // Lowering all @llvm.objectsize calls first because they may @@ -2051,7 +2119,8 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { replaceInstUsesWith(*C, ConstantInt::get(Type::getInt1Ty(C->getContext()), C->isFalseWhenEqual())); - } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I)) { + } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I) || + isa<AddrSpaceCastInst>(I)) { replaceInstUsesWith(*I, UndefValue::get(I->getType())); } eraseInstFromFunction(*I); @@ -2133,8 +2202,8 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { // free undef -> unreachable. if (isa<UndefValue>(Op)) { // Insert a new store to null because we cannot modify the CFG here. - Builder->CreateStore(ConstantInt::getTrue(FI.getContext()), - UndefValue::get(Type::getInt1PtrTy(FI.getContext()))); + Builder.CreateStore(ConstantInt::getTrue(FI.getContext()), + UndefValue::get(Type::getInt1PtrTy(FI.getContext()))); return eraseInstFromFunction(FI); } @@ -2167,11 +2236,9 @@ Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { // There might be assume intrinsics dominating this return that completely // determine the value. If so, constant fold it. - unsigned BitWidth = VTy->getPrimitiveSizeInBits(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(ResultOp, KnownZero, KnownOne, 0, &RI); - if ((KnownZero|KnownOne).isAllOnesValue()) - RI.setOperand(0, Constant::getIntegerValue(VTy, KnownOne)); + KnownBits Known = computeKnownBits(ResultOp, 0, &RI); + if (Known.isConstant()) + RI.setOperand(0, Constant::getIntegerValue(VTy, Known.getConstant())); return nullptr; } @@ -2198,37 +2265,18 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { return &BI; } - // Canonicalize fcmp_one -> fcmp_oeq - FCmpInst::Predicate FPred; Value *Y; - if (match(&BI, m_Br(m_FCmp(FPred, m_Value(X), m_Value(Y)), - TrueDest, FalseDest)) && - BI.getCondition()->hasOneUse()) - if (FPred == FCmpInst::FCMP_ONE || FPred == FCmpInst::FCMP_OLE || - FPred == FCmpInst::FCMP_OGE) { - FCmpInst *Cond = cast<FCmpInst>(BI.getCondition()); - Cond->setPredicate(FCmpInst::getInversePredicate(FPred)); - - // Swap Destinations and condition. 
- BI.swapSuccessors(); - Worklist.Add(Cond); - return &BI; - } - - // Canonicalize icmp_ne -> icmp_eq - ICmpInst::Predicate IPred; - if (match(&BI, m_Br(m_ICmp(IPred, m_Value(X), m_Value(Y)), - TrueDest, FalseDest)) && - BI.getCondition()->hasOneUse()) - if (IPred == ICmpInst::ICMP_NE || IPred == ICmpInst::ICMP_ULE || - IPred == ICmpInst::ICMP_SLE || IPred == ICmpInst::ICMP_UGE || - IPred == ICmpInst::ICMP_SGE) { - ICmpInst *Cond = cast<ICmpInst>(BI.getCondition()); - Cond->setPredicate(ICmpInst::getInversePredicate(IPred)); - // Swap Destinations and condition. - BI.swapSuccessors(); - Worklist.Add(Cond); - return &BI; - } + // Canonicalize, for example, icmp_ne -> icmp_eq or fcmp_one -> fcmp_oeq. + CmpInst::Predicate Pred; + if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), TrueDest, + FalseDest)) && + !isCanonicalPredicate(Pred)) { + // Swap destinations and condition. + CmpInst *Cond = cast<CmpInst>(BI.getCondition()); + Cond->setPredicate(CmpInst::getInversePredicate(Pred)); + BI.swapSuccessors(); + Worklist.Add(Cond); + return &BI; + } return nullptr; } @@ -2239,21 +2287,19 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { ConstantInt *AddRHS; if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) { // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. - for (SwitchInst::CaseIt CaseIter : SI.cases()) { - Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS); + for (auto Case : SI.cases()) { + Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS); assert(isa<ConstantInt>(NewCase) && "Result of expression should be constant"); - CaseIter.setValue(cast<ConstantInt>(NewCase)); + Case.setValue(cast<ConstantInt>(NewCase)); } SI.setCondition(Op0); return &SI; } - unsigned BitWidth = cast<IntegerType>(Cond->getType())->getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(Cond, KnownZero, KnownOne, 0, &SI); - unsigned LeadingKnownZeros = KnownZero.countLeadingOnes(); - unsigned LeadingKnownOnes = KnownOne.countLeadingOnes(); + KnownBits Known = computeKnownBits(Cond, 0, &SI); + unsigned LeadingKnownZeros = Known.countMinLeadingZeros(); + unsigned LeadingKnownOnes = Known.countMinLeadingOnes(); // Compute the number of leading bits we can ignore. // TODO: A better way to determine this would use ComputeNumSignBits(). @@ -2264,20 +2310,20 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes()); } - unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes); + unsigned NewWidth = Known.getBitWidth() - std::max(LeadingKnownZeros, LeadingKnownOnes); // Shrink the condition operand if the new type is smaller than the old type. // This may produce a non-standard type for the switch, but that's ok because // the backend should extend back to a legal type for the target. 
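Both KnownBits queries this patch migrates to (here, in visitReturnInst, and in run() further down) are easy to exercise standalone. A sketch assuming llvm/Support/KnownBits.h: a value is constant once every bit is known, and the switch truncation guarded below drops the condition's known leading bits:

#include <algorithm>
#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

int main() {
  KnownBits Known(8);
  Known.Zero = APInt(8, 0xF0); // top nibble known to be zero
  Known.One  = APInt(8, 0x05); // bits 0 and 2 known to be one
  // Bits 1 and 3 are still unknown, so the value is not a constant yet.
  bool IsConst = Known.isConstant(); // false
  // Four leading zeros are certain, though, so a switch on this value
  // could be truncated from i8 down to i4.
  unsigned Drop = std::max(Known.countMinLeadingZeros(),
                           Known.countMinLeadingOnes()); // == 4
  return (!IsConst && Drop == 4) ? 0 : 1;
}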
- if (NewWidth > 0 && NewWidth < BitWidth) { + if (NewWidth > 0 && NewWidth < Known.getBitWidth()) { IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); - Builder->SetInsertPoint(&SI); - Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc"); + Builder.SetInsertPoint(&SI); + Value *NewCond = Builder.CreateTrunc(Cond, Ty, "trunc"); SI.setCondition(NewCond); - for (SwitchInst::CaseIt CaseIter : SI.cases()) { - APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth); - CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); + for (auto Case : SI.cases()) { + APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth); + Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); } return &SI; } @@ -2291,8 +2337,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { if (!EV.hasIndices()) return replaceInstUsesWith(EV, Agg); - if (Value *V = - SimplifyExtractValueInst(Agg, EV.getIndices(), DL, &TLI, &DT, &AC)) + if (Value *V = SimplifyExtractValueInst(Agg, EV.getIndices(), + SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) { @@ -2329,8 +2375,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { // %E = insertvalue { i32 } %X, i32 42, 0 // by switching the order of the insert and extract (though the // insertvalue should be left in, since it may have other uses). - Value *NewEV = Builder->CreateExtractValue(IV->getAggregateOperand(), - EV.getIndices()); + Value *NewEV = Builder.CreateExtractValue(IV->getAggregateOperand(), + EV.getIndices()); return InsertValueInst::Create(NewEV, IV->getInsertedValueOperand(), makeArrayRef(insi, inse)); } @@ -2405,19 +2451,25 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { // extractvalue has integer indices, getelementptr has Value*s. Convert. SmallVector<Value*, 4> Indices; // Prefix an i32 0 since we need the first element. - Indices.push_back(Builder->getInt32(0)); + Indices.push_back(Builder.getInt32(0)); for (ExtractValueInst::idx_iterator I = EV.idx_begin(), E = EV.idx_end(); I != E; ++I) - Indices.push_back(Builder->getInt32(*I)); + Indices.push_back(Builder.getInt32(*I)); // We need to insert these at the location of the old load, not at that of // the extractvalue. - Builder->SetInsertPoint(L); - Value *GEP = Builder->CreateInBoundsGEP(L->getType(), - L->getPointerOperand(), Indices); + Builder.SetInsertPoint(L); + Value *GEP = Builder.CreateInBoundsGEP(L->getType(), + L->getPointerOperand(), Indices); + Instruction *NL = Builder.CreateLoad(GEP); + // Whatever aliasing information we had for the original load must also + // hold for the smaller load, so propagate the annotations. + AAMDNodes Nodes; + L->getAAMetadata(Nodes); + NL->setAAMetadata(Nodes); // Returning the load directly will cause the main loop to insert it in // the wrong spot, so use replaceInstUsesWith(). - return replaceInstUsesWith(EV, Builder->CreateLoad(GEP)); + return replaceInstUsesWith(EV, NL); } // We could simplify extracts from other values. Note that nested extracts may // already be simplified implicitly by the above: extract (extract (insert) ) @@ -2849,12 +2901,9 @@ bool InstCombiner::run() { // a value even when the operands are not all constants.
Type *Ty = I->getType(); if (ExpensiveCombines && !I->use_empty() && Ty->isIntOrIntVectorTy()) { - unsigned BitWidth = Ty->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(I, KnownZero, KnownOne, /*Depth*/0, I); - if ((KnownZero | KnownOne).isAllOnesValue()) { - Constant *C = ConstantInt::get(Ty, KnownOne); + KnownBits Known = computeKnownBits(I, /*Depth*/0, I); + if (Known.isConstant()) { + Constant *C = ConstantInt::get(Ty, Known.getConstant()); DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C << " from: " << *I << '\n'); @@ -2909,8 +2958,8 @@ bool InstCombiner::run() { } // Now that we have an instruction, try combining it to simplify it. - Builder->SetInsertPoint(I); - Builder->SetCurrentDebugLocation(I->getDebugLoc()); + Builder.SetInsertPoint(I); + Builder.SetCurrentDebugLocation(I->getDebugLoc()); #ifndef NDEBUG std::string OrigI; @@ -2934,8 +2983,8 @@ bool InstCombiner::run() { Result->takeName(I); // Push the new instruction and any users onto the worklist. - Worklist.Add(Result); Worklist.AddUsersToWorkList(*Result); + Worklist.Add(Result); // Insert the new instruction into the basic block... BasicBlock *InstParent = I->getParent(); @@ -2958,8 +3007,8 @@ bool InstCombiner::run() { if (isInstructionTriviallyDead(I, &TLI)) { eraseInstFromFunction(*I); } else { - Worklist.Add(I); Worklist.AddUsersToWorkList(*I); + Worklist.Add(I); } } MadeIRChange = true; @@ -3005,6 +3054,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, ++NumDeadInst; DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); Inst->eraseFromParent(); + MadeIRChange = true; continue; } @@ -3018,16 +3068,16 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, ++NumConstProp; if (isInstructionTriviallyDead(Inst, TLI)) Inst->eraseFromParent(); + MadeIRChange = true; continue; } // See if we can constant fold its operands. - for (User::op_iterator i = Inst->op_begin(), e = Inst->op_end(); i != e; - ++i) { - if (!isa<ConstantVector>(i) && !isa<ConstantExpr>(i)) + for (Use &U : Inst->operands()) { + if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U)) continue; - auto *C = cast<Constant>(i); + auto *C = cast<Constant>(U); Constant *&FoldRes = FoldedConstants[C]; if (!FoldRes) FoldRes = ConstantFoldConstant(C, DL, TLI); @@ -3035,12 +3085,18 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, FoldRes = C; if (FoldRes != C) { - *i = FoldRes; + DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst + << "\n Old = " << *C + << "\n New = " << *FoldRes << '\n'); + U = FoldRes; MadeIRChange = true; } } - InstrsForInstCombineWorklist.push_back(Inst); + // Skip processing debug intrinsics in InstCombine. Processing these call instructions + // consumes non-trivial amount of time and provides no value for the optimization. + if (!isa<DbgInfoIntrinsic>(Inst)) + InstrsForInstCombineWorklist.push_back(Inst); } // Recursively visit successors. If this is a branch or switch on a @@ -3055,17 +3111,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { - // See if this is an explicit destination. 
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) - if (i.getCaseValue() == Cond) { - BasicBlock *ReachableBB = i.getCaseSuccessor(); - Worklist.push_back(ReachableBB); - continue; - } - - // Otherwise it is the default destination. - Worklist.push_back(SI->getDefaultDest()); + Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor()); continue; } } @@ -3139,7 +3185,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, // Lower dbg.declare intrinsics otherwise their value may be clobbered // by instcombiner. - bool DbgDeclaresChanged = LowerDbgDeclare(F); + bool MadeIRChange = LowerDbgDeclare(F); // Iterate while there is work to do. int Iteration = 0; @@ -3148,17 +3194,17 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); - bool Changed = prepareICWorklistFromFunction(F, DL, &TLI, Worklist); + MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines, + InstCombiner IC(Worklist, Builder, F.optForMinSize(), ExpensiveCombines, AA, AC, TLI, DT, DL, LI); - Changed |= IC.run(); + IC.MaxArraySizeForCombine = MaxArraySize; - if (!Changed) + if (!IC.run()) break; } - return DbgDeclaresChanged || Iteration > 1; + return MadeIRChange || Iteration > 1; } PreservedAnalyses InstCombinePass::run(Function &F, @@ -3176,9 +3222,10 @@ PreservedAnalyses InstCombinePass::run(Function &F, return PreservedAnalyses::all(); // Mark all the analyses that instcombine updates as preserved. - // FIXME: This should also 'preserve the CFG'. PreservedAnalyses PA; - PA.preserve<DominatorTreeAnalysis>(); + PA.preserveSet<CFGAnalyses>(); + PA.preserve<AAManager>(); + PA.preserve<GlobalsAA>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index f5e9e7d..f8d2552 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -22,9 +22,11 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Triple.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" @@ -43,6 +45,7 @@ #include "llvm/Support/DataTypes.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Endian.h" +#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/SwapByteOrder.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" @@ -80,6 +83,7 @@ static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37; static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; +static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; // The shadow memory space is dynamically allocated. 
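The k*_ShadowOffset constants above (and the kDynamicShadowSentinel that the Windows entry just below uses) all feed the classic ASan address mapping, Shadow = (Addr >> Scale) + Offset; getShadowMapping further down turns the add into an OR when the offset is a power of two and the target permits it. A standalone sketch of that arithmetic (not code from this file):

#include <cassert>
#include <cstdint>

static uint64_t shadowFor(uint64_t Addr, uint64_t Offset,
                          unsigned Scale = 3, bool OrShadowOffset = false) {
  uint64_t Shifted = Addr >> Scale; // 8 application bytes per shadow byte
  return OrShadowOffset ? (Shifted | Offset) : (Shifted + Offset);
}

int main() {
  const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; // the new constant above
  // For a power-of-two offset, '+' and '|' agree whenever the shifted
  // address stays below the offset, which is why OrShadowOffset is legal.
  uint64_t Addr = 0x7fff8000ULL;
  assert(shadowFor(Addr, kPS4CPU_ShadowOffset64, 3, false) ==
         shadowFor(Addr, kPS4CPU_ShadowOffset64, 3, true));
  return 0;
}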
static const uint64_t kWindowsShadowOffset64 = kDynamicShadowSentinel; @@ -100,6 +104,10 @@ static const char *const kAsanRegisterImageGlobalsName = "__asan_register_image_globals"; static const char *const kAsanUnregisterImageGlobalsName = "__asan_unregister_image_globals"; +static const char *const kAsanRegisterElfGlobalsName = + "__asan_register_elf_globals"; +static const char *const kAsanUnregisterElfGlobalsName = + "__asan_unregister_elf_globals"; static const char *const kAsanPoisonGlobalsName = "__asan_before_dynamic_init"; static const char *const kAsanUnpoisonGlobalsName = "__asan_after_dynamic_init"; static const char *const kAsanInitName = "__asan_init"; @@ -119,8 +127,11 @@ static const char *const kAsanPoisonStackMemoryName = "__asan_poison_stack_memory"; static const char *const kAsanUnpoisonStackMemoryName = "__asan_unpoison_stack_memory"; + +// ASan version script has __asan_* wildcard. Triple underscore prevents a +// linker (gold) warning about attempting to export a local symbol. static const char *const kAsanGlobalsRegisteredFlagName = - "__asan_globals_registered"; + "___asan_globals_registered"; static const char *const kAsanOptionDetectUseAfterReturn = "__asan_option_detect_stack_use_after_return"; @@ -184,6 +195,11 @@ static cl::opt<uint32_t> ClMaxInlinePoisoningSize( static cl::opt<bool> ClUseAfterReturn("asan-use-after-return", cl::desc("Check stack-use-after-return"), cl::Hidden, cl::init(true)); +static cl::opt<bool> ClRedzoneByvalArgs("asan-redzone-byval-args", + cl::desc("Create redzones for byval " + "arguments (extra copy " + "required)"), cl::Hidden, + cl::init(true)); static cl::opt<bool> ClUseAfterScope("asan-use-after-scope", cl::desc("Check stack-use-after-scope"), cl::Hidden, cl::init(false)); @@ -264,11 +280,17 @@ static cl::opt<bool> cl::Hidden, cl::init(false)); static cl::opt<bool> - ClUseMachOGlobalsSection("asan-globals-live-support", - cl::desc("Use linker features to support dead " - "code stripping of globals " - "(Mach-O only)"), - cl::Hidden, cl::init(true)); + ClUseGlobalsGC("asan-globals-live-support", + cl::desc("Use linker features to support dead " + "code stripping of globals"), + cl::Hidden, cl::init(true)); + +// This is on by default even though there is a bug in gold: +// https://sourceware.org/bugzilla/show_bug.cgi?id=19002 +static cl::opt<bool> + ClWithComdat("asan-with-comdat", + cl::desc("Place ASan constructors in comdat sections"), + cl::Hidden, cl::init(true)); // Debug flags. 
static cl::opt<int> ClDebug("asan-debug", cl::desc("debug"), cl::Hidden, @@ -380,6 +402,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, bool IsAndroid = TargetTriple.isAndroid(); bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS(); bool IsFreeBSD = TargetTriple.isOSFreeBSD(); + bool IsPS4CPU = TargetTriple.isPS4CPU(); bool IsLinux = TargetTriple.isOSLinux(); bool IsPPC64 = TargetTriple.getArch() == llvm::Triple::ppc64 || TargetTriple.getArch() == llvm::Triple::ppc64le; @@ -392,6 +415,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, TargetTriple.getArch() == llvm::Triple::mips64el; bool IsAArch64 = TargetTriple.getArch() == llvm::Triple::aarch64; bool IsWindows = TargetTriple.isOSWindows(); + bool IsFuchsia = TargetTriple.isOSFuchsia(); ShadowMapping Mapping; @@ -412,12 +436,18 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, else Mapping.Offset = kDefaultShadowOffset32; } else { // LongSize == 64 - if (IsPPC64) + // Fuchsia is always PIE, which means that the beginning of the address + // space is always available. + if (IsFuchsia) + Mapping.Offset = 0; + else if (IsPPC64) Mapping.Offset = kPPC64_ShadowOffset64; else if (IsSystemZ) Mapping.Offset = kSystemZ_ShadowOffset64; else if (IsFreeBSD) Mapping.Offset = kFreeBSD_ShadowOffset64; + else if (IsPS4CPU) + Mapping.Offset = kPS4CPU_ShadowOffset64; else if (IsLinux && IsX86_64) { if (IsKasan) Mapping.Offset = kLinuxKasan_ShadowOffset64; @@ -456,9 +486,9 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, // offset is not necessary 1/8-th of the address space. On SystemZ, // we could OR the constant in a single instruction, but it's more // efficient to load it once and use indexed addressing. - Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ - && !(Mapping.Offset & (Mapping.Offset - 1)) - && Mapping.Offset != kDynamicShadowSentinel; + Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !IsPS4CPU && + !(Mapping.Offset & (Mapping.Offset - 1)) && + Mapping.Offset != kDynamicShadowSentinel; return Mapping; } @@ -567,8 +597,6 @@ struct AddressSanitizer : public FunctionPass { Type *IntptrTy; ShadowMapping Mapping; DominatorTree *DT; - Function *AsanCtorFunction = nullptr; - Function *AsanInitFunction = nullptr; Function *AsanHandleNoReturnFunc; Function *AsanPtrCmpFunction, *AsanPtrSubFunction; // This array is indexed by AccessIsWrite, Experiment and log2(AccessSize). @@ -587,22 +615,36 @@ struct AddressSanitizer : public FunctionPass { }; class AddressSanitizerModule : public ModulePass { - public: +public: explicit AddressSanitizerModule(bool CompileKernel = false, - bool Recover = false) + bool Recover = false, + bool UseGlobalsGC = true) : ModulePass(ID), CompileKernel(CompileKernel || ClEnableKasan), - Recover(Recover || ClRecover) {} + Recover(Recover || ClRecover), + UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), + // Not a typo: ClWithComdat is almost completely pointless without + // ClUseGlobalsGC (because then it only works on modules without + // globals, which are rare); it is a prerequisite for ClUseGlobalsGC; + // and both suffer from gold PR19002 for which UseGlobalsGC constructor + // argument is designed as workaround. Therefore, disable both + // ClWithComdat and ClUseGlobalsGC unless the frontend says it's ok to + // do globals-gc. 
+ UseCtorComdat(UseGlobalsGC && ClWithComdat) {} bool runOnModule(Module &M) override; - static char ID; // Pass identification, replacement for typeid + static char ID; // Pass identification, replacement for typeid StringRef getPassName() const override { return "AddressSanitizerModule"; } private: void initializeCallbacks(Module &M); - bool InstrumentGlobals(IRBuilder<> &IRB, Module &M); + bool InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat); void InstrumentGlobalsCOFF(IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers); + void InstrumentGlobalsELF(IRBuilder<> &IRB, Module &M, + ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers, + const std::string &UniqueModuleId); void InstrumentGlobalsMachO(IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers); @@ -613,7 +655,8 @@ private: GlobalVariable *CreateMetadataGlobal(Module &M, Constant *Initializer, StringRef OriginalName); - void SetComdatForGlobalMetadata(GlobalVariable *G, GlobalVariable *Metadata); + void SetComdatForGlobalMetadata(GlobalVariable *G, GlobalVariable *Metadata, + StringRef InternalSuffix); IRBuilder<> CreateAsanModuleDtor(Module &M); bool ShouldInstrumentGlobal(GlobalVariable *G); @@ -628,6 +671,8 @@ private: GlobalsMetadata GlobalsMD; bool CompileKernel; bool Recover; + bool UseGlobalsGC; + bool UseCtorComdat; Type *IntptrTy; LLVMContext *C; Triple TargetTriple; @@ -638,6 +683,11 @@ private: Function *AsanUnregisterGlobals; Function *AsanRegisterImageGlobals; Function *AsanUnregisterImageGlobals; + Function *AsanRegisterElfGlobals; + Function *AsanUnregisterElfGlobals; + + Function *AsanCtorFunction = nullptr; + Function *AsanDtorFunction = nullptr; }; // Stack poisoning does not play well with exception handling. @@ -705,6 +755,10 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { bool runOnFunction() { if (!ClStack) return false; + + if (ClRedzoneByvalArgs && Mapping.Offset != kDynamicShadowSentinel) + copyArgsPassedByValToAllocas(); + // Collect alloca, ret, lifetime instructions etc. for (BasicBlock *BB : depth_first(&F.getEntryBlock())) visit(*BB); @@ -721,6 +775,11 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { return true; } + // Arguments marked with the "byval" attribute are implicitly copied without + // using an alloca instruction. To produce redzones for those arguments, we + // copy them a second time into memory allocated with an alloca instruction. + void copyArgsPassedByValToAllocas(); + // Finds all Alloca instructions and puts // poisoned red zones around all of them. // Then unpoison everything back before the function returns. 
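copyArgsPassedByValToAllocas, declared above, exists because byval arguments are copied by the caller without any alloca in the callee, so they would otherwise get no redzones. A hedged sketch of the idea for a single argument, assuming LLVM ~5.0 APIs; this is an illustration, not the pass's actual implementation:

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static void copyByValArgToAlloca(Argument &Arg) {
  if (!Arg.hasByValAttr())
    return;
  Function &F = *Arg.getParent();
  const DataLayout &DL = F.getParent()->getDataLayout();
  Type *Ty = Arg.getType()->getPointerElementType();
  unsigned Align = Arg.getParamAlignment();
  if (Align == 0)
    Align = DL.getABITypeAlignment(Ty);
  IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().getFirstInsertionPt());
  // Give the argument an explicit stack slot so it can be wrapped in
  // redzones like any other alloca.
  AllocaInst *Copy = B.CreateAlloca(Ty, nullptr, Arg.getName() + ".byval");
  Arg.replaceAllUsesWith(Copy);
  // The "extra copy" the ClRedzoneByvalArgs help text warns about.
  B.CreateMemCpy(Copy, &Arg, DL.getTypeAllocSize(Ty), Align);
}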
@@ -906,9 +965,10 @@ INITIALIZE_PASS( "ModulePass", false, false) ModulePass *llvm::createAddressSanitizerModulePass(bool CompileKernel, - bool Recover) { + bool Recover, + bool UseGlobalsGC) { assert(!CompileKernel || Recover); - return new AddressSanitizerModule(CompileKernel, Recover); + return new AddressSanitizerModule(CompileKernel, Recover, UseGlobalsGC); } static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { @@ -1187,7 +1247,7 @@ static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, if (auto *Vector = dyn_cast<ConstantVector>(Mask)) { // dyn_cast as we might get UndefValue if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) { - if (Masked->isNullValue()) + if (Masked->isZero()) // Mask is constant false, so no instrumentation needed. continue; // If we have a true or undef value, fall through to doInstrumentAddress @@ -1421,8 +1481,13 @@ void AddressSanitizerModule::poisonOneInitializer(Function &GlobalInit, void AddressSanitizerModule::createInitializerPoisonCalls( Module &M, GlobalValue *ModuleName) { GlobalVariable *GV = M.getGlobalVariable("llvm.global_ctors"); + if (!GV) + return; + + ConstantArray *CA = dyn_cast<ConstantArray>(GV->getInitializer()); + if (!CA) + return; - ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); for (Use &OP : CA->operands()) { if (isa<ConstantAggregateZero>(OP)) continue; ConstantStruct *CS = cast<ConstantStruct>(OP); @@ -1530,9 +1595,6 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { // binary in order to allow the linker to properly dead strip. This is only // supported on recent versions of ld64. bool AddressSanitizerModule::ShouldUseMachOGlobalsSection() const { - if (!ClUseMachOGlobalsSection) - return false; - if (!TargetTriple.isOSBinFormatMachO()) return false; @@ -1561,38 +1623,48 @@ void AddressSanitizerModule::initializeCallbacks(Module &M) { // Declare our poisoning and unpoisoning functions. AsanPoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); + kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy)); AsanPoisonGlobals->setLinkage(Function::ExternalLinkage); AsanUnpoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanUnpoisonGlobalsName, IRB.getVoidTy(), nullptr)); + kAsanUnpoisonGlobalsName, IRB.getVoidTy())); AsanUnpoisonGlobals->setLinkage(Function::ExternalLinkage); // Declare functions that register/unregister globals. AsanRegisterGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy)); AsanRegisterGlobals->setLinkage(Function::ExternalLinkage); AsanUnregisterGlobals = checkSanitizerInterfaceFunction( M.getOrInsertFunction(kAsanUnregisterGlobalsName, IRB.getVoidTy(), - IntptrTy, IntptrTy, nullptr)); + IntptrTy, IntptrTy)); AsanUnregisterGlobals->setLinkage(Function::ExternalLinkage); // Declare the functions that find globals in a shared object and then invoke // the (un)register function on them. 
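The trailing-nullptr churn in this and the surrounding hunks tracks an LLVM API change: Module::getOrInsertFunction used to be a C-style vararg that required a null terminator after the parameter types, and is now a variadic template, so the sentinel has to go. A minimal before/after, with a made-up callback name:

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
using namespace llvm;
Constant *declareCb(Module &M, Type *RetTy, Type *ArgTy) {
  // Old C-vararg form: M.getOrInsertFunction("cb", RetTy, ArgTy, nullptr);
  // New variadic-template form, arity checked at compile time:
  return M.getOrInsertFunction("cb", RetTy, ArgTy);
}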
AsanRegisterImageGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); + kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy)); AsanRegisterImageGlobals->setLinkage(Function::ExternalLinkage); AsanUnregisterImageGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); + kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy)); AsanUnregisterImageGlobals->setLinkage(Function::ExternalLinkage); + + AsanRegisterElfGlobals = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(kAsanRegisterElfGlobalsName, IRB.getVoidTy(), + IntptrTy, IntptrTy, IntptrTy)); + AsanRegisterElfGlobals->setLinkage(Function::ExternalLinkage); + + AsanUnregisterElfGlobals = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(kAsanUnregisterElfGlobalsName, IRB.getVoidTy(), + IntptrTy, IntptrTy, IntptrTy)); + AsanUnregisterElfGlobals->setLinkage(Function::ExternalLinkage); } // Put the metadata and the instrumented global in the same group. This ensures // that the metadata is discarded if the instrumented global is discarded. void AddressSanitizerModule::SetComdatForGlobalMetadata( - GlobalVariable *G, GlobalVariable *Metadata) { + GlobalVariable *G, GlobalVariable *Metadata, StringRef InternalSuffix) { Module &M = *G->getParent(); Comdat *C = G->getComdat(); if (!C) { @@ -1602,7 +1674,15 @@ void AddressSanitizerModule::SetComdatForGlobalMetadata( assert(G->hasLocalLinkage()); G->setName(Twine(kAsanGenPrefix) + "_anon_global"); } - C = M.getOrInsertComdat(G->getName()); + + if (!InternalSuffix.empty() && G->hasLocalLinkage()) { + std::string Name = G->getName(); + Name += InternalSuffix; + C = M.getOrInsertComdat(Name); + } else { + C = M.getOrInsertComdat(G->getName()); + } + // Make this IMAGE_COMDAT_SELECT_NODUPLICATES on COFF. if (TargetTriple.isOSBinFormatCOFF()) C->setSelectionKind(Comdat::NoDuplicates); @@ -1618,21 +1698,21 @@ void AddressSanitizerModule::SetComdatForGlobalMetadata( GlobalVariable * AddressSanitizerModule::CreateMetadataGlobal(Module &M, Constant *Initializer, StringRef OriginalName) { - GlobalVariable *Metadata = - new GlobalVariable(M, Initializer->getType(), false, - GlobalVariable::InternalLinkage, Initializer, - Twine("__asan_global_") + - GlobalValue::getRealLinkageName(OriginalName)); + auto Linkage = TargetTriple.isOSBinFormatMachO() + ? 
GlobalVariable::InternalLinkage + : GlobalVariable::PrivateLinkage; + GlobalVariable *Metadata = new GlobalVariable( + M, Initializer->getType(), false, Linkage, Initializer, + Twine("__asan_global_") + GlobalValue::dropLLVMManglingEscape(OriginalName)); Metadata->setSection(getGlobalMetadataSection()); return Metadata; } IRBuilder<> AddressSanitizerModule::CreateAsanModuleDtor(Module &M) { - Function *AsanDtorFunction = + AsanDtorFunction = Function::Create(FunctionType::get(Type::getVoidTy(*C), false), GlobalValue::InternalLinkage, kAsanModuleDtorName, &M); BasicBlock *AsanDtorBB = BasicBlock::Create(*C, "", AsanDtorFunction); - appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority); return IRBuilder<>(ReturnInst::Create(*C, AsanDtorBB)); } @@ -1657,10 +1737,69 @@ void AddressSanitizerModule::InstrumentGlobalsCOFF( "global metadata will not be padded appropriately"); Metadata->setAlignment(SizeOfGlobalStruct); - SetComdatForGlobalMetadata(G, Metadata); + SetComdatForGlobalMetadata(G, Metadata, ""); } } +void AddressSanitizerModule::InstrumentGlobalsELF( + IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers, + const std::string &UniqueModuleId) { + assert(ExtendedGlobals.size() == MetadataInitializers.size()); + + SmallVector<GlobalValue *, 16> MetadataGlobals(ExtendedGlobals.size()); + for (size_t i = 0; i < ExtendedGlobals.size(); i++) { + GlobalVariable *G = ExtendedGlobals[i]; + GlobalVariable *Metadata = + CreateMetadataGlobal(M, MetadataInitializers[i], G->getName()); + MDNode *MD = MDNode::get(M.getContext(), ValueAsMetadata::get(G)); + Metadata->setMetadata(LLVMContext::MD_associated, MD); + MetadataGlobals[i] = Metadata; + + SetComdatForGlobalMetadata(G, Metadata, UniqueModuleId); + } + + // Update llvm.compiler.used, adding the new metadata globals. This is + // needed so that during LTO these variables stay alive. + if (!MetadataGlobals.empty()) + appendToCompilerUsed(M, MetadataGlobals); + + // RegisteredFlag serves two purposes. First, we can pass it to dladdr() + // to look up the loaded image that contains it. Second, we can store in it + // whether registration has already occurred, to prevent duplicate + // registration. + // + // Common linkage ensures that there is only one global per shared library. + GlobalVariable *RegisteredFlag = new GlobalVariable( + M, IntptrTy, false, GlobalVariable::CommonLinkage, + ConstantInt::get(IntptrTy, 0), kAsanGlobalsRegisteredFlagName); + RegisteredFlag->setVisibility(GlobalVariable::HiddenVisibility); + + // Create start and stop symbols. + GlobalVariable *StartELFMetadata = new GlobalVariable( + M, IntptrTy, false, GlobalVariable::ExternalWeakLinkage, nullptr, + "__start_" + getGlobalMetadataSection()); + StartELFMetadata->setVisibility(GlobalVariable::HiddenVisibility); + GlobalVariable *StopELFMetadata = new GlobalVariable( + M, IntptrTy, false, GlobalVariable::ExternalWeakLinkage, nullptr, + "__stop_" + getGlobalMetadataSection()); + StopELFMetadata->setVisibility(GlobalVariable::HiddenVisibility); + + // Create a call to register the globals with the runtime. + IRB.CreateCall(AsanRegisterElfGlobals, + {IRB.CreatePointerCast(RegisteredFlag, IntptrTy), + IRB.CreatePointerCast(StartELFMetadata, IntptrTy), + IRB.CreatePointerCast(StopELFMetadata, IntptrTy)}); + + // We also need to unregister globals at the end, e.g., when a shared library + // gets closed. 
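A hedged sketch of the runtime half of this protocol (the real code lives in compiler-rt and differs in detail; Global here stands in for the eight-field descriptor this pass emits per instrumented global):

#include <cstdint>
struct Global { uintptr_t Field[8]; };  // stand-in for the real descriptor
void registerGlobal(const Global *G);   // assumed runtime helper
extern "C" void __asan_register_elf_globals(uintptr_t *Flag, Global *Start,
                                            Global *Stop) {
  if (*Flag) return;  // common-linkage flag: one copy per DSO, register once
  *Flag = 1;
  // [Start, Stop) are the linker-defined __start_/__stop_ section bounds.
  for (Global *G = Start; G != Stop; ++G)
    registerGlobal(G);
}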
+ IRBuilder<> IRB_Dtor = CreateAsanModuleDtor(M); + IRB_Dtor.CreateCall(AsanUnregisterElfGlobals, + {IRB.CreatePointerCast(RegisteredFlag, IntptrTy), + IRB.CreatePointerCast(StartELFMetadata, IntptrTy), + IRB.CreatePointerCast(StopELFMetadata, IntptrTy)}); +} + void AddressSanitizerModule::InstrumentGlobalsMachO( IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers) { @@ -1669,7 +1808,7 @@ void AddressSanitizerModule::InstrumentGlobalsMachO( // On recent Mach-O platforms, use a structure which binds the liveness of // the global variable to the metadata struct. Keep the list of "Liveness" GV // created to be added to llvm.compiler.used - StructType *LivenessTy = StructType::get(IntptrTy, IntptrTy, nullptr); + StructType *LivenessTy = StructType::get(IntptrTy, IntptrTy); SmallVector<GlobalValue *, 16> LivenessGlobals(ExtendedGlobals.size()); for (size_t i = 0; i < ExtendedGlobals.size(); i++) { @@ -1680,9 +1819,9 @@ void AddressSanitizerModule::InstrumentGlobalsMachO( // On recent Mach-O platforms, we emit the global metadata in a way that // allows the linker to properly strip dead globals. - auto LivenessBinder = ConstantStruct::get( - LivenessTy, Initializer->getAggregateElement(0u), - ConstantExpr::getPointerCast(Metadata, IntptrTy), nullptr); + auto LivenessBinder = + ConstantStruct::get(LivenessTy, Initializer->getAggregateElement(0u), + ConstantExpr::getPointerCast(Metadata, IntptrTy)); GlobalVariable *Liveness = new GlobalVariable( M, LivenessTy, false, GlobalVariable::InternalLinkage, LivenessBinder, Twine("__asan_binder_") + G->getName()); @@ -1748,7 +1887,10 @@ void AddressSanitizerModule::InstrumentGlobalsWithMetadataArray( // This function replaces all global variables with new variables that have // trailing redzones. It also creates a function that poisons // redzones and inserts this function into llvm.global_ctors. -bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { +// Sets *CtorComdat to true if the global registration code emitted into the +// asan constructor is comdat-compatible. +bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat) { + *CtorComdat = false; GlobalsMD.init(M); SmallVector<GlobalVariable *, 16> GlobalsToChange; @@ -1758,7 +1900,10 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { } size_t n = GlobalsToChange.size(); - if (n == 0) return false; + if (n == 0) { + *CtorComdat = true; + return false; + } auto &DL = M.getDataLayout(); @@ -1774,7 +1919,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { // We initialize an array of such structures and pass it to a run-time call. 
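For orientation, the eight IntptrTy fields assembled below correspond one to one with the descriptor the runtime consumes; an approximate mirror (compare __asan_global in compiler-rt; field names here are descriptive guesses following the initializer order in this hunk):

#include <cstdint>
struct asan_global {
  uintptr_t beg;                // address of the instrumented global
  uintptr_t size;               // original size in bytes
  uintptr_t size_with_redzone;  // size including the trailing redzone
  uintptr_t name;               // const char * with the global's name
  uintptr_t module_name;        // const char * with the module's name
  uintptr_t has_dynamic_init;   // nonzero for dynamically initialized globals
  uintptr_t source_location;    // pointer to source-location metadata
  uintptr_t odr_indicator;      // address of the ODR violation indicator
};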
StructType *GlobalStructTy = StructType::get(IntptrTy, IntptrTy, IntptrTy, IntptrTy, IntptrTy, - IntptrTy, IntptrTy, IntptrTy, nullptr); + IntptrTy, IntptrTy, IntptrTy); SmallVector<GlobalVariable *, 16> NewGlobals(n); SmallVector<Constant *, 16> Initializers(n); @@ -1810,10 +1955,9 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { assert(((RightRedzoneSize + SizeInBytes) % MinRZ) == 0); Type *RightRedZoneTy = ArrayType::get(IRB.getInt8Ty(), RightRedzoneSize); - StructType *NewTy = StructType::get(Ty, RightRedZoneTy, nullptr); - Constant *NewInitializer = - ConstantStruct::get(NewTy, G->getInitializer(), - Constant::getNullValue(RightRedZoneTy), nullptr); + StructType *NewTy = StructType::get(Ty, RightRedZoneTy); + Constant *NewInitializer = ConstantStruct::get( + NewTy, G->getInitializer(), Constant::getNullValue(RightRedZoneTy)); // Create a new global variable with enough space for a redzone. GlobalValue::LinkageTypes Linkage = G->getLinkage(); @@ -1862,7 +2006,8 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { GlobalValue *InstrumentedGlobal = NewGlobal; bool CanUsePrivateAliases = - TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO(); + TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO() || + TargetTriple.isOSBinFormatWasm(); if (CanUsePrivateAliases && ClUsePrivateAliasForGlobals) { // Create local alias for NewGlobal to avoid crash on ODR between // instrumented and non-instrumented libraries. @@ -1893,7 +2038,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { ConstantExpr::getPointerCast(Name, IntptrTy), ConstantExpr::getPointerCast(ModuleName, IntptrTy), ConstantInt::get(IntptrTy, MD.IsDynInit), SourceLoc, - ConstantExpr::getPointerCast(ODRIndicator, IntptrTy), nullptr); + ConstantExpr::getPointerCast(ODRIndicator, IntptrTy)); if (ClInitializers && MD.IsDynInit) HasDynamicallyInitializedGlobals = true; @@ -1902,9 +2047,16 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { Initializers[i] = Initializer; } - if (TargetTriple.isOSBinFormatCOFF()) { + std::string ELFUniqueModuleId = + (UseGlobalsGC && TargetTriple.isOSBinFormatELF()) ? getUniqueModuleId(&M) + : ""; + + if (!ELFUniqueModuleId.empty()) { + InstrumentGlobalsELF(IRB, M, NewGlobals, Initializers, ELFUniqueModuleId); + *CtorComdat = true; + } else if (UseGlobalsGC && TargetTriple.isOSBinFormatCOFF()) { InstrumentGlobalsCOFF(IRB, M, NewGlobals, Initializers); - } else if (ShouldUseMachOGlobalsSection()) { + } else if (UseGlobalsGC && ShouldUseMachOGlobalsSection()) { InstrumentGlobalsMachO(IRB, M, NewGlobals, Initializers); } else { InstrumentGlobalsWithMetadataArray(IRB, M, NewGlobals, Initializers); @@ -1926,14 +2078,39 @@ bool AddressSanitizerModule::runOnModule(Module &M) { Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel); initializeCallbacks(M); - bool Changed = false; + if (CompileKernel) + return false; + + // Create a module constructor. A destructor is created lazily because not all + // platforms, and not all modules need it. + std::tie(AsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions( + M, kAsanModuleCtorName, kAsanInitName, /*InitArgTypes=*/{}, + /*InitArgs=*/{}, kAsanVersionCheckName); + bool CtorComdat = true; + bool Changed = false; // TODO(glider): temporarily disabled globals instrumentation for KASan. 
- if (ClGlobals && !CompileKernel) { - Function *CtorFunc = M.getFunction(kAsanModuleCtorName); - assert(CtorFunc); - IRBuilder<> IRB(CtorFunc->getEntryBlock().getTerminator()); - Changed |= InstrumentGlobals(IRB, M); + if (ClGlobals) { + IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator()); + Changed |= InstrumentGlobals(IRB, M, &CtorComdat); + } + + // Put the constructor and destructor in comdat if both + // (1) global instrumentation is not TU-specific + // (2) target is ELF. + if (UseCtorComdat && TargetTriple.isOSBinFormatELF() && CtorComdat) { + AsanCtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleCtorName)); + appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority, + AsanCtorFunction); + if (AsanDtorFunction) { + AsanDtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleDtorName)); + appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority, + AsanDtorFunction); + } + } else { + appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority); + if (AsanDtorFunction) + appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority); } return Changed; @@ -1949,49 +2126,60 @@ void AddressSanitizer::initializeCallbacks(Module &M) { const std::string ExpStr = Exp ? "exp_" : ""; const std::string SuffixStr = CompileKernel ? "N" : "_n"; const std::string EndingStr = Recover ? "_noabort" : ""; - Type *ExpType = Exp ? Type::getInt32Ty(*C) : nullptr; - AsanErrorCallbackSized[AccessIsWrite][Exp] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + EndingStr, - IRB.getVoidTy(), IntptrTy, IntptrTy, ExpType, nullptr)); - AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, - IRB.getVoidTy(), IntptrTy, IntptrTy, ExpType, nullptr)); - for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; - AccessSizeIndex++) { - const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); - AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, - IRB.getVoidTy(), IntptrTy, ExpType, nullptr)); - AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, - IRB.getVoidTy(), IntptrTy, ExpType, nullptr)); + + SmallVector<Type *, 3> Args2 = {IntptrTy, IntptrTy}; + SmallVector<Type *, 2> Args1{1, IntptrTy}; + if (Exp) { + Type *ExpType = Type::getInt32Ty(*C); + Args2.push_back(ExpType); + Args1.push_back(ExpType); } - } + AsanErrorCallbackSized[AccessIsWrite][Exp] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args2, false))); + + AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args2, false))); + + for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; + AccessSizeIndex++) { + const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); + AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, + 
FunctionType::get(IRB.getVoidTy(), Args1, false))); + + AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args1, false))); + } + } } const std::string MemIntrinCallbackPrefix = CompileKernel ? std::string("") : ClMemoryAccessCallbackPrefix; AsanMemmove = checkSanitizerInterfaceFunction(M.getOrInsertFunction( MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); AsanMemcpy = checkSanitizerInterfaceFunction(M.getOrInsertFunction( MemIntrinCallbackPrefix + "memcpy", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); AsanMemset = checkSanitizerInterfaceFunction(M.getOrInsertFunction( MemIntrinCallbackPrefix + "memset", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy)); AsanHandleNoReturnFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy(), nullptr)); + M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy())); AsanPtrCmpFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy)); AsanPtrSubFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy)); // We insert an empty inline asm after __asan_report* to avoid callback merge. EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), @@ -2001,7 +2189,6 @@ void AddressSanitizer::initializeCallbacks(Module &M) { // virtual bool AddressSanitizer::doInitialization(Module &M) { // Initialize the private fields. No one has accessed them before. - GlobalsMD.init(M); C = &(M.getContext()); @@ -2009,13 +2196,6 @@ bool AddressSanitizer::doInitialization(Module &M) { IntptrTy = Type::getIntNTy(*C, LongSize); TargetTriple = Triple(M.getTargetTriple()); - if (!CompileKernel) { - std::tie(AsanCtorFunction, AsanInitFunction) = - createSanitizerCtorAndInitFunctions( - M, kAsanModuleCtorName, kAsanInitName, - /*InitArgTypes=*/{}, /*InitArgs=*/{}, kAsanVersionCheckName); - appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority); - } Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel); return true; } @@ -2034,6 +2214,8 @@ bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) { // We cannot just ignore these methods, because they may call other // instrumented functions. 
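Returning to the callback declarations earlier in this hunk: the string concatenation fans out into a small matrix of runtime entry points. A sketch of the user-mode naming scheme (kernel builds use "N" rather than "_n" for the sized variants, per SuffixStr above):

#include <string>
// e.g. reportName(true, 4, false) == "__asan_report_store4"
std::string reportName(bool IsWrite, unsigned Size, bool Recover) {
  std::string N = std::string("__asan_report_") +
                  (IsWrite ? "store" : "load") + std::to_string(Size);
  if (Recover)
    N += "_noabort";  // recoverable builds report and keep running
  return N;
}
// The "exp_" variants slot in after the prefix and take one extra i32
// argument, which is why Args1/Args2 above grow by ExpType when Exp is set.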
if (F.getName().find(" load]") != std::string::npos) { + Function *AsanInitFunction = + declareSanitizerInitFunction(*F.getParent(), kAsanInitName, {}); IRBuilder<> IRB(&F.front(), F.front().begin()); IRB.CreateCall(AsanInitFunction, {}); return true; @@ -2081,7 +2263,6 @@ void AddressSanitizer::markEscapedLocalAllocas(Function &F) { } bool AddressSanitizer::runOnFunction(Function &F) { - if (&F == AsanCtorFunction) return false; if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false; if (F.getName().startswith("__asan_")) return false; @@ -2175,8 +2356,9 @@ bool AddressSanitizer::runOnFunction(Function &F) { (ClInstrumentationWithCallsThreshold >= 0 && ToInstrument.size() > (unsigned)ClInstrumentationWithCallsThreshold); const DataLayout &DL = F.getParent()->getDataLayout(); - ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(), - /*RoundToAlign=*/true); + ObjectSizeOpts ObjSizeOpts; + ObjSizeOpts.RoundToAlign = true; + ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(), ObjSizeOpts); // Instrument. int NumInstrumented = 0; @@ -2234,18 +2416,18 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { std::string Suffix = itostr(i); AsanStackMallocFunc[i] = checkSanitizerInterfaceFunction( M.getOrInsertFunction(kAsanStackMallocNameTemplate + Suffix, IntptrTy, - IntptrTy, nullptr)); + IntptrTy)); AsanStackFreeFunc[i] = checkSanitizerInterfaceFunction( M.getOrInsertFunction(kAsanStackFreeNameTemplate + Suffix, - IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + IRB.getVoidTy(), IntptrTy, IntptrTy)); } if (ASan.UseAfterScope) { AsanPoisonStackMemoryFunc = checkSanitizerInterfaceFunction( M.getOrInsertFunction(kAsanPoisonStackMemoryName, IRB.getVoidTy(), - IntptrTy, IntptrTy, nullptr)); + IntptrTy, IntptrTy)); AsanUnpoisonStackMemoryFunc = checkSanitizerInterfaceFunction( M.getOrInsertFunction(kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), - IntptrTy, IntptrTy, nullptr)); + IntptrTy, IntptrTy)); } for (size_t Val : {0x00, 0xf1, 0xf2, 0xf3, 0xf5, 0xf8}) { @@ -2254,14 +2436,14 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { Name << std::setw(2) << std::setfill('0') << std::hex << Val; AsanSetShadowFunc[Val] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy)); } AsanAllocaPoisonFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy)); AsanAllocasUnpoisonFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy)); } void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask, @@ -2363,6 +2545,28 @@ static int StackMallocSizeClass(uint64_t LocalStackSize) { llvm_unreachable("impossible LocalStackSize"); } +void FunctionStackPoisoner::copyArgsPassedByValToAllocas() { + BasicBlock &FirstBB = *F.begin(); + IRBuilder<> IRB(&FirstBB, FirstBB.getFirstInsertionPt()); + const DataLayout &DL = F.getParent()->getDataLayout(); + for (Argument &Arg : F.args()) { + if (Arg.hasByValAttr()) { + Type *Ty = Arg.getType()->getPointerElementType(); + unsigned Align = Arg.getParamAlignment(); + if (Align == 0) Align = DL.getABITypeAlignment(Ty); + + const std::string &Name = Arg.hasName() ? 
Arg.getName().str() : + "Arg" + llvm::to_string(Arg.getArgNo()); + AllocaInst *AI = IRB.CreateAlloca(Ty, nullptr, Twine(Name) + ".byval"); + AI->setAlignment(Align); + Arg.replaceAllUsesWith(AI); + + uint64_t AllocSize = DL.getTypeAllocSize(Ty); + IRB.CreateMemCpy(AI, &Arg, AllocSize, Align); + } + } +} + PHINode *FunctionStackPoisoner::createPHI(IRBuilder<> &IRB, Value *Cond, Value *ValueIfTrue, Instruction *ThenTerm, @@ -2566,7 +2770,7 @@ void FunctionStackPoisoner::processStaticAllocas() { Value *NewAllocaPtr = IRB.CreateIntToPtr( IRB.CreateAdd(LocalStackBase, ConstantInt::get(IntptrTy, Desc.Offset)), AI->getType()); - replaceDbgDeclareForAlloca(AI, NewAllocaPtr, DIB, /*Deref=*/true); + replaceDbgDeclareForAlloca(AI, NewAllocaPtr, DIB, DIExpression::NoDeref); AI->replaceAllUsesWith(NewAllocaPtr); } diff --git a/contrib/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/contrib/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index d4c8369..a193efe 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetFolder.h" @@ -25,6 +24,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" using namespace llvm; #define DEBUG_TYPE "bounds-checking" diff --git a/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h b/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h index 3802f9f..16e2e6b 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h +++ b/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h @@ -12,6 +12,9 @@ // //===----------------------------------------------------------------------===// +#ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H +#define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/BlockFrequencyInfo.h" @@ -24,10 +27,10 @@ #include <utility> #include <vector> -namespace llvm { - #define DEBUG_TYPE "cfgmst" +namespace llvm { + /// \brief An union-find based Minimum Spanning Tree for CFG /// /// Implements a Union-find algorithm to compute Minimum Spanning Tree @@ -220,5 +223,8 @@ public: } }; -#undef DEBUG_TYPE // "cfgmst" } // end namespace llvm + +#undef DEBUG_TYPE // "cfgmst" + +#endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H diff --git a/contrib/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index b34d5b8..ddc975c 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -44,15 +44,14 @@ /// For more information, please refer to the design document: /// http://clang.llvm.org/docs/DataFlowSanitizerDesign.html -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/DebugInfo.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstVisitor.h" 
@@ -63,6 +62,7 @@ #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/SpecialCaseList.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> @@ -254,7 +254,7 @@ class DataFlowSanitizer : public ModulePass { MDNode *ColdCallWeights; DFSanABIList ABIList; DenseMap<Value *, Function *> UnwrappedFnMap; - AttributeSet ReadOnlyNoneAttrs; + AttrBuilder ReadOnlyNoneAttrs; bool DFSanRuntimeShadowMask; Value *getShadowAddress(Value *Addr, Instruction *Pos); @@ -331,6 +331,10 @@ class DFSanVisitor : public InstVisitor<DFSanVisitor> { DFSanFunction &DFSF; DFSanVisitor(DFSanFunction &DFSF) : DFSF(DFSF) {} + const DataLayout &getDataLayout() const { + return DFSF.F->getParent()->getDataLayout(); + } + void visitOperandShadowInst(Instruction &I); void visitBinaryOperator(BinaryOperator &BO); @@ -384,7 +388,7 @@ FunctionType *DataFlowSanitizer::getArgsFunctionType(FunctionType *T) { ArgTypes.push_back(ShadowPtrTy); Type *RetType = T->getReturnType(); if (!RetType->isVoidTy()) - RetType = StructType::get(RetType, ShadowTy, (Type *)nullptr); + RetType = StructType::get(RetType, ShadowTy); return FunctionType::get(RetType, ArgTypes, T->isVarArg()); } @@ -472,16 +476,14 @@ bool DataFlowSanitizer::doInitialization(Module &M) { GetArgTLS = ConstantExpr::getIntToPtr( ConstantInt::get(IntptrTy, uintptr_t(GetArgTLSPtr)), PointerType::getUnqual( - FunctionType::get(PointerType::getUnqual(ArgTLSTy), - (Type *)nullptr))); + FunctionType::get(PointerType::getUnqual(ArgTLSTy), false))); } if (GetRetvalTLSPtr) { RetvalTLS = nullptr; GetRetvalTLS = ConstantExpr::getIntToPtr( ConstantInt::get(IntptrTy, uintptr_t(GetRetvalTLSPtr)), PointerType::getUnqual( - FunctionType::get(PointerType::getUnqual(ShadowTy), - (Type *)nullptr))); + FunctionType::get(PointerType::getUnqual(ShadowTy), false))); } ColdCallWeights = MDBuilder(*Ctx).createBranchWeights(1, 1000); @@ -539,16 +541,13 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, F->getParent()); NewF->copyAttributesFrom(F); NewF->removeAttributes( - AttributeSet::ReturnIndex, - AttributeSet::get(F->getContext(), AttributeSet::ReturnIndex, - AttributeFuncs::typeIncompatible(NewFT->getReturnType()))); + AttributeList::ReturnIndex, + AttributeFuncs::typeIncompatible(NewFT->getReturnType())); BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", NewF); if (F->isVarArg()) { - NewF->removeAttributes( - AttributeSet::FunctionIndex, - AttributeSet().addAttribute(*Ctx, AttributeSet::FunctionIndex, - "split-stack")); + NewF->removeAttributes(AttributeList::FunctionIndex, + AttrBuilder().addAttribute("split-stack")); CallInst::Create(DFSanVarargWrapperFn, IRBuilder<>(BB).CreateGlobalStringPtr(F->getName()), "", BB); @@ -580,8 +579,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, Function::arg_iterator AI = F->arg_begin(); ++AI; for (unsigned N = FT->getNumParams(); N != 0; ++AI, --N) Args.push_back(&*AI); - CallInst *CI = - CallInst::Create(&F->getArgumentList().front(), Args, "", BB); + CallInst *CI = CallInst::Create(&*F->arg_begin(), Args, "", BB); ReturnInst *RI; if (FT->getReturnType()->isVoidTy()) RI = ReturnInst::Create(*Ctx, BB); @@ -595,7 +593,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, DFSanVisitor(DFSF).visitCallInst(*CI); if (!FT->getReturnType()->isVoidTy()) new StoreInst(DFSF.getShadow(RI->getReturnValue()), - &F->getArgumentList().back(), RI); + 
&*std::prev(F->arg_end()), RI); } return C; @@ -622,33 +620,33 @@ bool DataFlowSanitizer::runOnModule(Module &M) { DFSanUnionFn = Mod->getOrInsertFunction("__dfsan_union", DFSanUnionFnTy); if (Function *F = dyn_cast<Function>(DFSanUnionFn)) { - F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); - F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); - F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); - F->addAttribute(1, Attribute::ZExt); - F->addAttribute(2, Attribute::ZExt); + F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); + F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); + F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); + F->addParamAttr(0, Attribute::ZExt); + F->addParamAttr(1, Attribute::ZExt); } DFSanCheckedUnionFn = Mod->getOrInsertFunction("dfsan_union", DFSanUnionFnTy); if (Function *F = dyn_cast<Function>(DFSanCheckedUnionFn)) { - F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); - F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); - F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); - F->addAttribute(1, Attribute::ZExt); - F->addAttribute(2, Attribute::ZExt); + F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); + F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); + F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); + F->addParamAttr(0, Attribute::ZExt); + F->addParamAttr(1, Attribute::ZExt); } DFSanUnionLoadFn = Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy); if (Function *F = dyn_cast<Function>(DFSanUnionLoadFn)) { - F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); - F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadOnly); - F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); + F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); + F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); + F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); } DFSanUnimplementedFn = Mod->getOrInsertFunction("__dfsan_unimplemented", DFSanUnimplementedFnTy); DFSanSetLabelFn = Mod->getOrInsertFunction("__dfsan_set_label", DFSanSetLabelFnTy); if (Function *F = dyn_cast<Function>(DFSanSetLabelFn)) { - F->addAttribute(1, Attribute::ZExt); + F->addParamAttr(0, Attribute::ZExt); } DFSanNonzeroLabelFn = Mod->getOrInsertFunction("__dfsan_nonzero_label", DFSanNonzeroLabelFnTy); @@ -694,9 +692,8 @@ bool DataFlowSanitizer::runOnModule(Module &M) { } } - AttrBuilder B; - B.addAttribute(Attribute::ReadOnly).addAttribute(Attribute::ReadNone); - ReadOnlyNoneAttrs = AttributeSet::get(*Ctx, AttributeSet::FunctionIndex, B); + ReadOnlyNoneAttrs.addAttribute(Attribute::ReadOnly) + .addAttribute(Attribute::ReadNone); // First, change the ABI of every function in the module. ABI-listed // functions keep their original ABI and get a wrapper function. 
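One subtlety hiding in the mechanical-looking attribute churn above: the old index-based interface counted the return value as index 0 and parameters from 1, while the new addParamAttr numbers parameters from 0. That is why addAttribute(1, ZExt) becomes addParamAttr(0, ZExt). Minimal before/after:

#include "llvm/IR/Function.h"
using namespace llvm;
void markFirstParamZExt(Function *F) {
  // Old: F->addAttribute(1, Attribute::ZExt);  // index 1 == first parameter
  F->addParamAttr(0, Attribute::ZExt);          // param 0 == first parameter
}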
@@ -717,9 +714,8 @@ bool DataFlowSanitizer::runOnModule(Module &M) { Function *NewF = Function::Create(NewFT, F.getLinkage(), "", &M); NewF->copyAttributesFrom(&F); NewF->removeAttributes( - AttributeSet::ReturnIndex, - AttributeSet::get(NewF->getContext(), AttributeSet::ReturnIndex, - AttributeFuncs::typeIncompatible(NewFT->getReturnType()))); + AttributeList::ReturnIndex, + AttributeFuncs::typeIncompatible(NewFT->getReturnType())); for (Function::arg_iterator FArg = F.arg_begin(), NewFArg = NewF->arg_begin(), FArgEnd = F.arg_end(); @@ -758,7 +754,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) { &F, std::string("dfsw$") + std::string(F.getName()), GlobalValue::LinkOnceODRLinkage, NewFT); if (getInstrumentedABI() == IA_TLS) - NewF->removeAttributes(AttributeSet::FunctionIndex, ReadOnlyNoneAttrs); + NewF->removeAttributes(AttributeList::FunctionIndex, ReadOnlyNoneAttrs); Value *WrappedFnCst = ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT)); @@ -906,7 +902,7 @@ Value *DFSanFunction::getShadow(Value *V) { break; } case DataFlowSanitizer::IA_Args: { - unsigned ArgIdx = A->getArgNo() + F->getArgumentList().size() / 2; + unsigned ArgIdx = A->getArgNo() + F->arg_size() / 2; Function::arg_iterator i = F->arg_begin(); while (ArgIdx--) ++i; @@ -983,9 +979,9 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) { IRBuilder<> IRB(Pos); if (AvoidNewBlocks) { CallInst *Call = IRB.CreateCall(DFS.DFSanCheckedUnionFn, {V1, V2}); - Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); - Call->addAttribute(1, Attribute::ZExt); - Call->addAttribute(2, Attribute::ZExt); + Call->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); + Call->addParamAttr(0, Attribute::ZExt); + Call->addParamAttr(1, Attribute::ZExt); CCS.Block = Pos->getParent(); CCS.Shadow = Call; @@ -996,9 +992,9 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) { Ne, Pos, /*Unreachable=*/false, DFS.ColdCallWeights, &DT)); IRBuilder<> ThenIRB(BI); CallInst *Call = ThenIRB.CreateCall(DFS.DFSanUnionFn, {V1, V2}); - Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); - Call->addAttribute(1, Attribute::ZExt); - Call->addAttribute(2, Attribute::ZExt); + Call->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); + Call->addParamAttr(0, Attribute::ZExt); + Call->addParamAttr(1, Attribute::ZExt); BasicBlock *Tail = BI->getSuccessor(0); PHINode *Phi = PHINode::Create(DFS.ShadowTy, 2, "", &Tail->front()); @@ -1099,7 +1095,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, CallInst *FallbackCall = FallbackIRB.CreateCall( DFS.DFSanUnionLoadFn, {ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)}); - FallbackCall->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); + FallbackCall->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); // Compare each of the shadows stored in the loaded 64 bits to each other, // by computing (WideShadow rotl ShadowWidth) == WideShadow. @@ -1156,7 +1152,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, IRBuilder<> IRB(Pos); CallInst *FallbackCall = IRB.CreateCall( DFS.DFSanUnionLoadFn, {ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)}); - FallbackCall->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); + FallbackCall->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); return FallbackCall; } @@ -1446,7 +1442,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) { // Custom functions returning non-void will write to the return label. 
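For context on the "custom" path handled here: per the DataFlowSanitizer design doc referenced at the top of this file, a custom wrapper receives the original arguments, then one label per argument, then, for non-void functions, a pointer through which it stores the return label. A sketch for a hypothetical int f(int, int):

typedef unsigned short dfsan_label;  // the 16-bit shadow type, hence the ZExt
extern "C" int __dfsw_f(int a, int b, dfsan_label a_label, dfsan_label b_label,
                        dfsan_label *ret_label) {
  *ret_label = 0;  // e.g. declare the result unlabeled
  return a + b;
}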
if (!FT->getReturnType()->isVoidTy()) { - CustomFn->removeAttributes(AttributeSet::FunctionIndex, + CustomFn->removeAttributes(AttributeList::FunctionIndex, DFSF.DFS.ReadOnlyNoneAttrs); } } @@ -1474,6 +1470,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) { } i = CS.arg_begin(); + const unsigned ShadowArgStart = Args.size(); for (unsigned n = FT->getNumParams(); n != 0; ++i, --n) Args.push_back(DFSF.getShadow(*i)); @@ -1481,7 +1478,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) { auto *LabelVATy = ArrayType::get(DFSF.DFS.ShadowTy, CS.arg_size() - FT->getNumParams()); auto *LabelVAAlloca = new AllocaInst( - LabelVATy, "labelva", &DFSF.F->getEntryBlock().front()); + LabelVATy, getDataLayout().getAllocaAddrSpace(), + "labelva", &DFSF.F->getEntryBlock().front()); for (unsigned n = 0; i != CS.arg_end(); ++i, ++n) { auto LabelVAPtr = IRB.CreateStructGEP(LabelVATy, LabelVAAlloca, n); @@ -1494,8 +1492,9 @@ void DFSanVisitor::visitCallSite(CallSite CS) { if (!FT->getReturnType()->isVoidTy()) { if (!DFSF.LabelReturnAlloca) { DFSF.LabelReturnAlloca = - new AllocaInst(DFSF.DFS.ShadowTy, "labelreturn", - &DFSF.F->getEntryBlock().front()); + new AllocaInst(DFSF.DFS.ShadowTy, + getDataLayout().getAllocaAddrSpace(), + "labelreturn", &DFSF.F->getEntryBlock().front()); } Args.push_back(DFSF.LabelReturnAlloca); } @@ -1507,6 +1506,15 @@ void DFSanVisitor::visitCallSite(CallSite CS) { CustomCI->setCallingConv(CI->getCallingConv()); CustomCI->setAttributes(CI->getAttributes()); + // Update the parameter attributes of the custom call instruction to + // zero extend the shadow parameters. This is required for targets + // which consider ShadowTy an illegal type. + for (unsigned n = 0; n < FT->getNumParams(); n++) { + const unsigned ArgNo = ShadowArgStart + n; + if (CustomCI->getArgOperand(ArgNo)->getType() == DFSF.DFS.ShadowTy) + CustomCI->addParamAttr(ArgNo, Attribute::ZExt); + } + if (!FT->getReturnType()->isVoidTy()) { LoadInst *LabelLoad = IRB.CreateLoad(DFSF.LabelReturnAlloca); DFSF.setShadow(CustomCI, LabelLoad); @@ -1574,7 +1582,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) { unsigned VarArgSize = CS.arg_size() - FT->getNumParams(); ArrayType *VarArgArrayTy = ArrayType::get(DFSF.DFS.ShadowTy, VarArgSize); AllocaInst *VarArgShadow = - new AllocaInst(VarArgArrayTy, "", &DFSF.F->getEntryBlock().front()); + new AllocaInst(VarArgArrayTy, getDataLayout().getAllocaAddrSpace(), + "", &DFSF.F->getEntryBlock().front()); Args.push_back(IRB.CreateConstGEP2_32(VarArgArrayTy, VarArgShadow, 0, 0)); for (unsigned n = 0; i != e; ++i, ++n) { IRB.CreateStore( @@ -1593,7 +1602,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) { } NewCS.setCallingConv(CS.getCallingConv()); NewCS.setAttributes(CS.getAttributes().removeAttributes( - *DFSF.DFS.Ctx, AttributeSet::ReturnIndex, + *DFSF.DFS.Ctx, AttributeList::ReturnIndex, AttributeFuncs::typeIncompatible(NewCS.getInstruction()->getType()))); if (Next) { diff --git a/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp index 05eba6c..6864d29 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp @@ -18,7 +18,6 @@ // The rest is handled by the run-time library. 
//===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -32,6 +31,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -267,35 +267,35 @@ void EfficiencySanitizer::initializeCallbacks(Module &M) { SmallString<32> AlignedLoadName("__esan_aligned_load" + ByteSizeStr); EsanAlignedLoad[Idx] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy())); SmallString<32> AlignedStoreName("__esan_aligned_store" + ByteSizeStr); EsanAlignedStore[Idx] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy())); SmallString<32> UnalignedLoadName("__esan_unaligned_load" + ByteSizeStr); EsanUnalignedLoad[Idx] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy())); SmallString<32> UnalignedStoreName("__esan_unaligned_store" + ByteSizeStr); EsanUnalignedStore[Idx] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy())); } EsanUnalignedLoadN = checkSanitizerInterfaceFunction( M.getOrInsertFunction("__esan_unaligned_loadN", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IntptrTy)); EsanUnalignedStoreN = checkSanitizerInterfaceFunction( M.getOrInsertFunction("__esan_unaligned_storeN", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IntptrTy)); MemmoveFn = checkSanitizerInterfaceFunction( M.getOrInsertFunction("memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IntptrTy)); MemcpyFn = checkSanitizerInterfaceFunction( M.getOrInsertFunction("memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IntptrTy)); MemsetFn = checkSanitizerInterfaceFunction( M.getOrInsertFunction("memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt32Ty(), IntptrTy, nullptr)); + IRB.getInt32Ty(), IntptrTy)); } bool EfficiencySanitizer::shouldIgnoreStructType(StructType *StructTy) { @@ -398,8 +398,8 @@ GlobalVariable *EfficiencySanitizer::createCacheFragInfoGV( // u64 *ArrayCounter; // }; auto *StructInfoTy = - StructType::get(Int8PtrTy, Int32Ty, Int32Ty, Int32PtrTy, Int32PtrTy, - Int8PtrPtrTy, Int64PtrTy, Int64PtrTy, nullptr); + StructType::get(Int8PtrTy, Int32Ty, Int32Ty, Int32PtrTy, Int32PtrTy, + Int8PtrPtrTy, Int64PtrTy, Int64PtrTy); auto *StructInfoPtrTy = StructInfoTy->getPointerTo(); // This structure should be kept consistent with the CacheFragInfo struct // in the runtime library. 
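Reassembling the field comments around StructInfoTy in this hunk, the runtime-side structs being mirrored look approximately like this (the authoritative definitions live in the esan runtime library):

#include <cstdint>
struct StructInfo {
  const char *StructName;      // Int8PtrTy
  uint32_t Size;               // Int32Ty: struct size in bytes
  uint32_t NumFields;          // Int32Ty
  uint32_t *FieldOffset;       // Int32PtrTy
  uint32_t *FieldSize;         // Int32PtrTy
  const char **FieldTypeName;  // Int8PtrPtrTy
  uint64_t *FieldCounters;     // Int64PtrTy
  uint64_t *ArrayCounter;      // Int64PtrTy
};
struct CacheFragInfo {
  const char *UnitName;        // module name
  uint32_t NumStructs;
  StructInfo *Structs;
};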
@@ -408,8 +408,7 @@ GlobalVariable *EfficiencySanitizer::createCacheFragInfoGV( // u32 NumStructs; // StructInfo *Structs; // }; - auto *CacheFragInfoTy = - StructType::get(Int8PtrTy, Int32Ty, StructInfoPtrTy, nullptr); + auto *CacheFragInfoTy = StructType::get(Int8PtrTy, Int32Ty, StructInfoPtrTy); std::vector<StructType *> Vec = M.getIdentifiedStructTypes(); unsigned NumStructs = 0; @@ -457,24 +456,23 @@ GlobalVariable *EfficiencySanitizer::createCacheFragInfoGV( ArrayCounterIdx[0] = ConstantInt::get(Int32Ty, 0); ArrayCounterIdx[1] = ConstantInt::get(Int32Ty, getArrayCounterIdx(StructTy)); - Initializers.push_back( - ConstantStruct::get( - StructInfoTy, - ConstantExpr::getPointerCast(StructCounterName, Int8PtrTy), - ConstantInt::get(Int32Ty, - DL.getStructLayout(StructTy)->getSizeInBytes()), - ConstantInt::get(Int32Ty, StructTy->getNumElements()), - Offset == nullptr ? ConstantPointerNull::get(Int32PtrTy) : - ConstantExpr::getPointerCast(Offset, Int32PtrTy), - Size == nullptr ? ConstantPointerNull::get(Int32PtrTy) : - ConstantExpr::getPointerCast(Size, Int32PtrTy), - TypeName == nullptr ? ConstantPointerNull::get(Int8PtrPtrTy) : - ConstantExpr::getPointerCast(TypeName, Int8PtrPtrTy), - ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, - FieldCounterIdx), - ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, - ArrayCounterIdx), - nullptr)); + Initializers.push_back(ConstantStruct::get( + StructInfoTy, + ConstantExpr::getPointerCast(StructCounterName, Int8PtrTy), + ConstantInt::get(Int32Ty, + DL.getStructLayout(StructTy)->getSizeInBytes()), + ConstantInt::get(Int32Ty, StructTy->getNumElements()), + Offset == nullptr ? ConstantPointerNull::get(Int32PtrTy) + : ConstantExpr::getPointerCast(Offset, Int32PtrTy), + Size == nullptr ? ConstantPointerNull::get(Int32PtrTy) + : ConstantExpr::getPointerCast(Size, Int32PtrTy), + TypeName == nullptr + ? ConstantPointerNull::get(Int8PtrPtrTy) + : ConstantExpr::getPointerCast(TypeName, Int8PtrPtrTy), + ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, + FieldCounterIdx), + ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, + ArrayCounterIdx))); } // Structs. Constant *StructInfo; @@ -491,11 +489,8 @@ GlobalVariable *EfficiencySanitizer::createCacheFragInfoGV( auto *CacheFragInfoGV = new GlobalVariable( M, CacheFragInfoTy, true, GlobalVariable::InternalLinkage, - ConstantStruct::get(CacheFragInfoTy, - UnitName, - ConstantInt::get(Int32Ty, NumStructs), - StructInfo, - nullptr)); + ConstantStruct::get(CacheFragInfoTy, UnitName, + ConstantInt::get(Int32Ty, NumStructs), StructInfo)); return CacheFragInfoGV; } @@ -533,7 +528,7 @@ void EfficiencySanitizer::createDestructor(Module &M, Constant *ToolInfoArg) { IRBuilder<> IRB_Dtor(EsanDtorFunction->getEntryBlock().getTerminator()); Function *EsanExit = checkSanitizerInterfaceFunction( M.getOrInsertFunction(EsanExitName, IRB_Dtor.getVoidTy(), - Int8PtrTy, nullptr)); + Int8PtrTy)); EsanExit->setLinkage(Function::ExternalLinkage); IRB_Dtor.CreateCall(EsanExit, {ToolInfoArg}); appendToGlobalDtors(M, EsanDtorFunction, EsanCtorAndDtorPriority); @@ -757,7 +752,7 @@ bool EfficiencySanitizer::instrumentGetElementPtr(Instruction *I, Module &M) { return false; } Type *SourceTy = GepInst->getSourceElementType(); - StructType *StructTy; + StructType *StructTy = nullptr; ConstantInt *Idx; // Check if GEP calculates address from a struct array. 
if (isa<StructType>(SourceTy)) { diff --git a/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 1ba13bd..4089d81 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -1,4 +1,4 @@ -//===-- IndirectCallPromotion.cpp - Promote indirect calls to direct calls ===// +//===-- IndirectCallPromotion.cpp - Optimizations based on value profiling ===// // // The LLVM Compiler Infrastructure // @@ -17,6 +17,8 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" #include "llvm/Analysis/IndirectCallSiteVisitor.h" #include "llvm/IR/BasicBlock.h" @@ -40,6 +42,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -65,13 +68,13 @@ static cl::opt<bool> DisableICP("disable-icp", cl::init(false), cl::Hidden, // For debug use only. static cl::opt<unsigned> ICPCutOff("icp-cutoff", cl::init(0), cl::Hidden, cl::ZeroOrMore, - cl::desc("Max number of promotions for this compilaiton")); + cl::desc("Max number of promotions for this compilation")); // If ICPCSSkip is non zero, the first ICPCSSkip callsites will be skipped. // For debug use only. static cl::opt<unsigned> ICPCSSkip("icp-csskip", cl::init(0), cl::Hidden, cl::ZeroOrMore, - cl::desc("Skip Callsite up to this number for this compilaiton")); + cl::desc("Skip Callsite up to this number for this compilation")); // Set if the pass is called in LTO optimization. The difference for LTO mode // is the pass won't prefix the source module name to the internal linkage @@ -80,6 +83,12 @@ static cl::opt<bool> ICPLTOMode("icp-lto", cl::init(false), cl::Hidden, cl::desc("Run indirect-call promotion in LTO " "mode")); +// Set if the pass is called in SamplePGO mode. The difference for SamplePGO +// mode is that it will add prof metadata to the created direct call. +static cl::opt<bool> + ICPSamplePGOMode("icp-samplepgo", cl::init(false), cl::Hidden, + cl::desc("Run indirect-call promotion in SamplePGO mode")); + // If the option is set to true, only call instructions will be considered for // transformation -- invoke instructions will be ignored. static cl::opt<bool> @@ -105,8 +114,8 @@ class PGOIndirectCallPromotionLegacyPass : public ModulePass { public: static char ID; - PGOIndirectCallPromotionLegacyPass(bool InLTO = false) - : ModulePass(ID), InLTO(InLTO) { + PGOIndirectCallPromotionLegacyPass(bool InLTO = false, bool SamplePGO = false) + : ModulePass(ID), InLTO(InLTO), SamplePGO(SamplePGO) { initializePGOIndirectCallPromotionLegacyPassPass( *PassRegistry::getPassRegistry()); } @@ -119,6 +128,10 @@ private: // If this pass is called in LTO. We need to special handling the PGOFuncName // for the static variables due to LTO's internalization. bool InLTO; + + // If this pass is called in SamplePGO, we need to add the prof metadata to + // the promoted direct call.
+ bool SamplePGO; }; } // end anonymous namespace @@ -128,8 +141,9 @@ INITIALIZE_PASS(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom", "direct calls.", false, false) -ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO) { - return new PGOIndirectCallPromotionLegacyPass(InLTO); +ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO, + bool SamplePGO) { + return new PGOIndirectCallPromotionLegacyPass(InLTO, SamplePGO); } namespace { @@ -144,17 +158,11 @@ private: // defines. InstrProfSymtab *Symtab; - enum TargetStatus { - OK, // Should be able to promote. - NotAvailableInModule, // Cannot find the target in current module. - ReturnTypeMismatch, // Return type mismatch b/w target and indirect-call. - NumArgsMismatch, // Number of arguments does not match. - ArgTypeMismatch // Type mismatch in the arguments (cannot bitcast). - }; + bool SamplePGO; // Test if we can legally promote this direct-call of Target. - TargetStatus isPromotionLegal(Instruction *Inst, uint64_t Target, - Function *&F); + bool isPromotionLegal(Instruction *Inst, uint64_t Target, Function *&F, + const char **Reason = nullptr); // A struct that records the direct target and it's call count. struct PromotionCandidate { @@ -172,91 +180,77 @@ private: Instruction *Inst, const ArrayRef<InstrProfValueData> &ValueDataRef, uint64_t TotalCount, uint32_t NumCandidates); - // Main function that transforms Inst (either a indirect-call instruction, or - // an invoke instruction , to a conditional call to F. This is like: - // if (Inst.CalledValue == F) - // F(...); - // else - // Inst(...); - // end - // TotalCount is the profile count value that the instruction executes. - // Count is the profile count value that F is the target function. - // These two values are being used to update the branch weight. - void promote(Instruction *Inst, Function *F, uint64_t Count, - uint64_t TotalCount); - // Promote a list of targets for one indirect-call callsite. Return // the number of promotions. uint32_t tryToPromote(Instruction *Inst, const std::vector<PromotionCandidate> &Candidates, uint64_t &TotalCount); - static const char *StatusToString(const TargetStatus S) { - switch (S) { - case OK: - return "OK to promote"; - case NotAvailableInModule: - return "Cannot find the target"; - case ReturnTypeMismatch: - return "Return type mismatch"; - case NumArgsMismatch: - return "The number of arguments mismatch"; - case ArgTypeMismatch: - return "Argument Type mismatch"; - } - llvm_unreachable("Should not reach here"); - } - // Noncopyable ICallPromotionFunc(const ICallPromotionFunc &other) = delete; ICallPromotionFunc &operator=(const ICallPromotionFunc &other) = delete; public: - ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab) - : F(Func), M(Modu), Symtab(Symtab) { - } + ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab, + bool SamplePGO) + : F(Func), M(Modu), Symtab(Symtab), SamplePGO(SamplePGO) {} bool processFunction(); }; } // end anonymous namespace -ICallPromotionFunc::TargetStatus -ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target, - Function *&TargetFunction) { - Function *DirectCallee = Symtab->getFunction(Target); - if (DirectCallee == nullptr) - return NotAvailableInModule; +bool llvm::isLegalToPromote(Instruction *Inst, Function *F, + const char **Reason) { // Check the return type. 
Type *CallRetType = Inst->getType(); if (!CallRetType->isVoidTy()) { - Type *FuncRetType = DirectCallee->getReturnType(); + Type *FuncRetType = F->getReturnType(); if (FuncRetType != CallRetType && - !CastInst::isBitCastable(FuncRetType, CallRetType)) - return ReturnTypeMismatch; + !CastInst::isBitCastable(FuncRetType, CallRetType)) { + if (Reason) + *Reason = "Return type mismatch"; + return false; + } } // Check if the arguments are compatible with the parameters - FunctionType *DirectCalleeType = DirectCallee->getFunctionType(); + FunctionType *DirectCalleeType = F->getFunctionType(); unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); CallSite CS(Inst); unsigned ArgNum = CS.arg_size(); - if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) - return NumArgsMismatch; + if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) { + if (Reason) + *Reason = "The number of arguments mismatch"; + return false; + } for (unsigned I = 0; I < ParamNum; ++I) { Type *PTy = DirectCalleeType->getFunctionParamType(I); Type *ATy = CS.getArgument(I)->getType(); if (PTy == ATy) continue; - if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) - return ArgTypeMismatch; + if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) { + if (Reason) + *Reason = "Argument type mismatch"; + return false; + } } DEBUG(dbgs() << " #" << NumOfPGOICallPromotion << " Promote the icall to " - << Symtab->getFuncName(Target) << "\n"); - TargetFunction = DirectCallee; - return OK; + << F->getName() << "\n"); + return true; +} + +bool ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target, + Function *&TargetFunction, + const char **Reason) { + TargetFunction = Symtab->getFunction(Target); + if (TargetFunction == nullptr) { + *Reason = "Cannot find the target"; + return false; + } + return isLegalToPromote(Inst, TargetFunction, Reason); } // Indirect-call promotion heuristic. The direct targets are sorted based on @@ -296,10 +290,9 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( break; } Function *TargetFunction = nullptr; - TargetStatus Status = isPromotionLegal(Inst, Target, TargetFunction); - if (Status != OK) { + const char *Reason = nullptr; + if (!isPromotionLegal(Inst, Target, TargetFunction, &Reason)) { StringRef TargetFuncName = Symtab->getFuncName(Target); - const char *Reason = StatusToString(Status); DEBUG(dbgs() << " Not promote: " << Reason << "\n"); emitOptimizationRemarkMissed( F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(), @@ -532,8 +525,14 @@ static void insertCallRetPHI(Instruction *Inst, Instruction *CallResult, // Ret = phi(Ret1, Ret2); // It adds type casts for the args do not match the parameters and the return // value. Branch weights metadata also updated. -void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee, - uint64_t Count, uint64_t TotalCount) { +// If \p AttachProfToDirectCall is true, a prof metadata is attached to the +// new direct call to contain \p Count. This is used by SamplePGO inliner to +// check callsite hotness. +// Returns the promoted direct call instruction. +Instruction *llvm::promoteIndirectCall(Instruction *Inst, + Function *DirectCallee, uint64_t Count, + uint64_t TotalCount, + bool AttachProfToDirectCall) { assert(DirectCallee != nullptr); BasicBlock *BB = Inst->getParent(); // Just to suppress the non-debug build warning. 
@@ -548,6 +547,14 @@ void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee, Instruction *NewInst = createDirectCallInst(Inst, DirectCallee, DirectCallBB, MergeBB); + if (AttachProfToDirectCall) { + SmallVector<uint32_t, 1> Weights; + Weights.push_back(Count); + MDBuilder MDB(NewInst->getContext()); + dyn_cast<Instruction>(NewInst->stripPointerCasts()) + ->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + } + // Move Inst from MergeBB to IndirectCallBB. Inst->removeFromParent(); IndirectCallBB->getInstList().insert(IndirectCallBB->getFirstInsertionPt(), @@ -576,9 +583,10 @@ void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee, DEBUG(dbgs() << *BB << *DirectCallBB << *IndirectCallBB << *MergeBB << "\n"); emitOptimizationRemark( - F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(), + BB->getContext(), "pgo-icall-prom", *BB->getParent(), Inst->getDebugLoc(), Twine("Promote indirect call to ") + DirectCallee->getName() + " with count " + Twine(Count) + " out of " + Twine(TotalCount)); + return NewInst; } // Promote indirect-call to conditional direct-call for one callsite. @@ -589,7 +597,7 @@ uint32_t ICallPromotionFunc::tryToPromote( for (auto &C : Candidates) { uint64_t Count = C.Count; - promote(Inst, C.TargetFunction, Count, TotalCount); + promoteIndirectCall(Inst, C.TargetFunction, Count, TotalCount, SamplePGO); assert(TotalCount >= Count); TotalCount -= Count; NumOfPGOICallPromotion++; @@ -630,18 +638,23 @@ bool ICallPromotionFunc::processFunction() { } // A wrapper function that does the actual work. -static bool promoteIndirectCalls(Module &M, bool InLTO) { +static bool promoteIndirectCalls(Module &M, bool InLTO, bool SamplePGO) { if (DisableICP) return false; InstrProfSymtab Symtab; - Symtab.create(M, InLTO); + if (Error E = Symtab.create(M, InLTO)) { + std::string SymtabFailure = toString(std::move(E)); + DEBUG(dbgs() << "Failed to create symtab: " << SymtabFailure << "\n"); + (void)SymtabFailure; + return false; + } bool Changed = false; for (auto &F : M) { if (F.isDeclaration()) continue; if (F.hasFnAttribute(Attribute::OptimizeNone)) continue; - ICallPromotionFunc ICallPromotion(F, &M, &Symtab); + ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO); bool FuncChanged = ICallPromotion.processFunction(); if (ICPDUMPAFTER && FuncChanged) { DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs())); @@ -658,11 +671,14 @@ static bool promoteIndirectCalls(Module &M, bool InLTO) { bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) { // Command-line option has the priority for InLTO. 
- return promoteIndirectCalls(M, InLTO | ICPLTOMode);
+ return promoteIndirectCalls(M, InLTO | ICPLTOMode,
+ SamplePGO | ICPSamplePGOMode); }
-PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, ModuleAnalysisManager &AM) {
- if (!promoteIndirectCalls(M, InLTO | ICPLTOMode))
+PreservedAnalyses PGOIndirectCallPromotion::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ if (!promoteIndirectCalls(M, InLTO | ICPLTOMode,
+ SamplePGO | ICPSamplePGOMode)) return PreservedAnalyses::all(); return PreservedAnalyses::none();
diff --git a/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index adea7e7..db8fa89 100644
--- a/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -14,18 +14,63 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/InstrProfiling.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/ModuleUtils.h"
+#include "llvm/Transforms/Utils/SSAUpdater.h"
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <string> using namespace llvm; #define DEBUG_TYPE "instrprof"
+// The start and end values of precise value profile range for memory
+// intrinsic sizes
+cl::opt<std::string> MemOPSizeRange( "memop-size-range", cl::desc("Set the range of size in memory intrinsic calls to be profiled " "precisely, in a format of <start_val>:<end_val>"), cl::init(""));
+
+// The value that is considered to be a large value in memory intrinsic.
+cl::opt<unsigned> MemOPSizeLarge( "memop-size-large", cl::desc("Set large value threshold in memory intrinsic size profiling. " "Value of 0 disables the large value profiling."), cl::init(8192));
+
 namespace { cl::opt<bool> DoNameCompression("enable-name-compression",
@@ -41,6 +86,7 @@ cl::opt<bool> ValueProfileStaticAlloc( "vp-static-alloc", cl::desc("Do static counter allocation for value profiler"), cl::init(true));
+
 cl::opt<double> NumCountersPerValueSite( "vp-counters-per-site", cl::desc("The average number of profile counters allocated "
 // is usually smaller than 2.
 cl::init(1.0));
+cl::opt<bool> AtomicCounterUpdatePromoted(
+ "atomic-counter-update-promoted", cl::ZeroOrMore,
+ cl::desc("Do counter update using atomic fetch add "
+ " for promoted counters only"),
+ cl::init(false));
+
+// If the option is not specified, the default behavior about whether
+// counter promotion is done depends on how the instrumentation lowering
+// pipeline is set up, i.e., the default value of true of this option
+// does not mean the promotion will be done by default. Explicitly
+// setting this option can override the default behavior.
+cl::opt<bool> DoCounterPromotion("do-counter-promotion", cl::ZeroOrMore,
+ cl::desc("Do counter register promotion"),
+ cl::init(false));
+cl::opt<unsigned> MaxNumOfPromotionsPerLoop(
+ cl::ZeroOrMore, "max-counter-promotions-per-loop", cl::init(20),
+ cl::desc("Max number of counter promotions per loop to avoid"
+ " increasing register pressure too much"));
+
+// A debug option
+cl::opt<int>
+ MaxNumOfPromotions(cl::ZeroOrMore, "max-counter-promotions", cl::init(-1),
+ cl::desc("Max number of allowed counter promotions"));
+
+cl::opt<unsigned> SpeculativeCounterPromotionMaxExiting(
+ cl::ZeroOrMore, "speculative-counter-promotion-max-exiting", cl::init(3),
+ cl::desc("The max number of exiting blocks of a loop to allow "
+ " speculative counter promotion"));
+
+cl::opt<bool> SpeculativeCounterPromotionToLoop(
+ cl::ZeroOrMore, "speculative-counter-promotion-to-loop", cl::init(false),
+ cl::desc("When the option is false, if the target block is in a loop, "
+ "the promotion will be disallowed unless the promoted counter "
+ " update can be further/iteratively promoted into an acyclic "
+ " region."));
+
+cl::opt<bool> IterativeCounterPromotion(
+ cl::ZeroOrMore, "iterative-counter-promotion", cl::init(true),
+ cl::desc("Allow counter promotion across the whole loop nest."));
+
 class InstrProfilingLegacyPass : public ModulePass { InstrProfiling InstrProf; public: static char ID;
+
- InstrProfilingLegacyPass() : ModulePass(ID), InstrProf() {}
+ InstrProfilingLegacyPass() : ModulePass(ID) {} InstrProfilingLegacyPass(const InstrProfOptions &Options) : ModulePass(ID), InstrProf(Options) {}
+
 StringRef getPassName() const override { return "Frontend instrumentation-based coverage lowering"; }
@@ -73,7 +161,184 @@ public: } };
-} // anonymous namespace
+///
+/// A helper class to promote one counter RMW operation in the loop
+/// into a register update.
+///
+/// The RMW update for the counter will be sunk out of the loop after
+/// the transformation.
+///
+class PGOCounterPromoterHelper : public LoadAndStorePromoter {
+public:
+ PGOCounterPromoterHelper(
+ Instruction *L, Instruction *S, SSAUpdater &SSA, Value *Init,
+ BasicBlock *PH, ArrayRef<BasicBlock *> ExitBlocks,
+ ArrayRef<Instruction *> InsertPts,
+ DenseMap<Loop *, SmallVector<LoadStorePair, 8>> &LoopToCands,
+ LoopInfo &LI)
+ : LoadAndStorePromoter({L, S}, SSA), Store(S), ExitBlocks(ExitBlocks),
+ InsertPts(InsertPts), LoopToCandidates(LoopToCands), LI(LI) {
+ assert(isa<LoadInst>(L));
+ assert(isa<StoreInst>(S));
+ SSA.AddAvailableValue(PH, Init);
+ }
+
+ void doExtraRewritesBeforeFinalDeletion() const override {
+ for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
+ BasicBlock *ExitBlock = ExitBlocks[i];
+ Instruction *InsertPos = InsertPts[i];
+ // Get LiveIn value into the ExitBlock. If there are multiple
+ // predecessors, the value is defined by a PHI node in this
+ // block.
+ Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock);
+ Value *Addr = cast<StoreInst>(Store)->getPointerOperand();
+ IRBuilder<> Builder(InsertPos);
+ if (AtomicCounterUpdatePromoted)
+ // Atomic update currently can only be promoted across the current
+ // loop, not the whole loop nest.
+ Builder.CreateAtomicRMW(AtomicRMWInst::Add, Addr, LiveInValue,
+ AtomicOrdering::SequentiallyConsistent);
+ else {
+ LoadInst *OldVal = Builder.CreateLoad(Addr, "pgocount.promoted");
+ auto *NewVal = Builder.CreateAdd(OldVal, LiveInValue);
+ auto *NewStore = Builder.CreateStore(NewVal, Addr);
+
+ // Now update the parent loop's candidate list:
+ if (IterativeCounterPromotion) {
+ auto *TargetLoop = LI.getLoopFor(ExitBlock);
+ if (TargetLoop)
+ LoopToCandidates[TargetLoop].emplace_back(OldVal, NewStore);
+ }
+ }
+ }
+ }
+
+private:
+ Instruction *Store;
+ ArrayRef<BasicBlock *> ExitBlocks;
+ ArrayRef<Instruction *> InsertPts;
+ DenseMap<Loop *, SmallVector<LoadStorePair, 8>> &LoopToCandidates;
+ LoopInfo &LI;
+};
+
+/// A helper class to do register promotion for all profile counter
+/// updates in a loop.
+///
+class PGOCounterPromoter {
+public:
+ PGOCounterPromoter(
+ DenseMap<Loop *, SmallVector<LoadStorePair, 8>> &LoopToCands,
+ Loop &CurLoop, LoopInfo &LI)
+ : LoopToCandidates(LoopToCands), ExitBlocks(), InsertPts(), L(CurLoop),
+ LI(LI) {
+
+ SmallVector<BasicBlock *, 8> LoopExitBlocks;
+ SmallPtrSet<BasicBlock *, 8> BlockSet;
+ L.getExitBlocks(LoopExitBlocks);
+
+ for (BasicBlock *ExitBlock : LoopExitBlocks) {
+ if (BlockSet.insert(ExitBlock).second) {
+ ExitBlocks.push_back(ExitBlock);
+ InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
+ }
+ }
+ }
+
+ bool run(int64_t *NumPromoted) {
+ unsigned MaxProm = getMaxNumOfPromotionsInLoop(&L);
+ if (MaxProm == 0)
+ return false;
+
+ unsigned Promoted = 0;
+ for (auto &Cand : LoopToCandidates[&L]) {
+
+ SmallVector<PHINode *, 4> NewPHIs;
+ SSAUpdater SSA(&NewPHIs);
+ Value *InitVal = ConstantInt::get(Cand.first->getType(), 0);
+
+ PGOCounterPromoterHelper Promoter(Cand.first, Cand.second, SSA, InitVal,
+ L.getLoopPreheader(), ExitBlocks,
+ InsertPts, LoopToCandidates, LI);
+ Promoter.run(SmallVector<Instruction *, 2>({Cand.first, Cand.second}));
+ Promoted++;
+ if (Promoted >= MaxProm)
+ break;
+
+ (*NumPromoted)++;
+ if (MaxNumOfPromotions != -1 && *NumPromoted >= MaxNumOfPromotions)
+ break;
+ }
+
+ DEBUG(dbgs() << Promoted << " counters promoted for loop (depth="
+ << L.getLoopDepth() << ")\n");
+ return Promoted != 0;
+ }
+
+private:
+ bool allowSpeculativeCounterPromotion(Loop *LP) {
+ SmallVector<BasicBlock *, 8> ExitingBlocks;
+ L.getExitingBlocks(ExitingBlocks);
+ // Not considered speculative.
+ if (ExitingBlocks.size() == 1)
+ return true;
+ if (ExitingBlocks.size() > SpeculativeCounterPromotionMaxExiting)
+ return false;
+ return true;
+ }
+
+ // Returns the max number of counter promotions for LP.
+ unsigned getMaxNumOfPromotionsInLoop(Loop *LP) {
+ // We can't insert into a catchswitch.
+ SmallVector<BasicBlock *, 8> LoopExitBlocks;
+ LP->getExitBlocks(LoopExitBlocks);
+ if (llvm::any_of(LoopExitBlocks, [](BasicBlock *Exit) {
+ return isa<CatchSwitchInst>(Exit->getTerminator());
+ }))
+ return 0;
+
+ if (!LP->hasDedicatedExits())
+ return 0;
+
+ BasicBlock *PH = LP->getLoopPreheader();
+ if (!PH)
+ return 0;
+
+ SmallVector<BasicBlock *, 8> ExitingBlocks;
+ LP->getExitingBlocks(ExitingBlocks);
+ // Not considered speculative.
+ if (ExitingBlocks.size() == 1) + return MaxNumOfPromotionsPerLoop; + + if (ExitingBlocks.size() > SpeculativeCounterPromotionMaxExiting) + return 0; + + // Whether the target block is in a loop does not matter: + if (SpeculativeCounterPromotionToLoop) + return MaxNumOfPromotionsPerLoop; + + // Now check the target block: + unsigned MaxProm = MaxNumOfPromotionsPerLoop; + for (auto *TargetBlock : LoopExitBlocks) { + auto *TargetLoop = LI.getLoopFor(TargetBlock); + if (!TargetLoop) + continue; + unsigned MaxPromForTarget = getMaxNumOfPromotionsInLoop(TargetLoop); + unsigned PendingCandsInTarget = LoopToCandidates[TargetLoop].size(); + MaxProm = + std::min(MaxProm, std::max(MaxPromForTarget, PendingCandsInTarget) - + PendingCandsInTarget); + } + return MaxProm; + } + + DenseMap<Loop *, SmallVector<LoadStorePair, 8>> &LoopToCandidates; + SmallVector<BasicBlock *, 8> ExitBlocks; + SmallVector<Instruction *, 8> InsertPts; + Loop &L; + LoopInfo &LI; +}; + +} // end anonymous namespace PreservedAnalyses InstrProfiling::run(Module &M, ModuleAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); @@ -97,35 +362,70 @@ llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options) { return new InstrProfilingLegacyPass(Options); } -bool InstrProfiling::isMachO() const { - return Triple(M->getTargetTriple()).isOSBinFormatMachO(); +static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) { + InstrProfIncrementInst *Inc = dyn_cast<InstrProfIncrementInstStep>(Instr); + if (Inc) + return Inc; + return dyn_cast<InstrProfIncrementInst>(Instr); } -/// Get the section name for the counter variables. -StringRef InstrProfiling::getCountersSection() const { - return getInstrProfCountersSectionName(isMachO()); -} +bool InstrProfiling::lowerIntrinsics(Function *F) { + bool MadeChange = false; + PromotionCandidates.clear(); + for (BasicBlock &BB : *F) { + for (auto I = BB.begin(), E = BB.end(); I != E;) { + auto Instr = I++; + InstrProfIncrementInst *Inc = castToIncrementInst(&*Instr); + if (Inc) { + lowerIncrement(Inc); + MadeChange = true; + } else if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(Instr)) { + lowerValueProfileInst(Ind); + MadeChange = true; + } + } + } -/// Get the section name for the name variables. -StringRef InstrProfiling::getNameSection() const { - return getInstrProfNameSectionName(isMachO()); -} + if (!MadeChange) + return false; -/// Get the section name for the profile data variables. -StringRef InstrProfiling::getDataSection() const { - return getInstrProfDataSectionName(isMachO()); + promoteCounterLoadStores(F); + return true; } -/// Get the section name for the coverage mapping data. 
-StringRef InstrProfiling::getCoverageSection() const { - return getInstrProfCoverageSectionName(isMachO()); +bool InstrProfiling::isCounterPromotionEnabled() const { + if (DoCounterPromotion.getNumOccurrences() > 0) + return DoCounterPromotion; + + return Options.DoCounterPromotion; } -static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) { - InstrProfIncrementInst *Inc = dyn_cast<InstrProfIncrementInstStep>(Instr); - if (Inc) - return Inc; - return dyn_cast<InstrProfIncrementInst>(Instr); +void InstrProfiling::promoteCounterLoadStores(Function *F) { + if (!isCounterPromotionEnabled()) + return; + + DominatorTree DT(*F); + LoopInfo LI(DT); + DenseMap<Loop *, SmallVector<LoadStorePair, 8>> LoopPromotionCandidates; + + for (const auto &LoadStore : PromotionCandidates) { + auto *CounterLoad = LoadStore.first; + auto *CounterStore = LoadStore.second; + BasicBlock *BB = CounterLoad->getParent(); + Loop *ParentLoop = LI.getLoopFor(BB); + if (!ParentLoop) + continue; + LoopPromotionCandidates[ParentLoop].emplace_back(CounterLoad, CounterStore); + } + + SmallVector<Loop *, 4> Loops = LI.getLoopsInPreorder(); + + // Do a post-order traversal of the loops so that counter updates can be + // iteratively hoisted outside the loop nest. + for (auto *Loop : llvm::reverse(Loops)) { + PGOCounterPromoter Promoter(LoopPromotionCandidates, *Loop, LI); + Promoter.run(&TotalCountersPromoted); + } } bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { @@ -137,6 +437,9 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { NamesSize = 0; ProfileDataMap.clear(); UsedVars.clear(); + getMemOPSizeRangeFromOption(MemOPSizeRange, MemOPSizeRangeStart, + MemOPSizeRangeLast); + TT = Triple(M.getTargetTriple()); // We did not know how many value sites there would be inside // the instrumented function. 
This is counting the number of instrumented @@ -157,18 +460,7 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { } for (Function &F : M) - for (BasicBlock &BB : F) - for (auto I = BB.begin(), E = BB.end(); I != E;) { - auto Instr = I++; - InstrProfIncrementInst *Inc = castToIncrementInst(&*Instr); - if (Inc) { - lowerIncrement(Inc); - MadeChange = true; - } else if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(Instr)) { - lowerValueProfileInst(Ind); - MadeChange = true; - } - } + MadeChange |= lowerIntrinsics(&F); if (GlobalVariable *CoverageNamesVar = M.getNamedGlobal(getCoverageUnusedNamesVarName())) { @@ -189,26 +481,42 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { } static Constant *getOrInsertValueProfilingCall(Module &M, - const TargetLibraryInfo &TLI) { + const TargetLibraryInfo &TLI, + bool IsRange = false) { LLVMContext &Ctx = M.getContext(); auto *ReturnTy = Type::getVoidTy(M.getContext()); - Type *ParamTypes[] = { + + Constant *Res; + if (!IsRange) { + Type *ParamTypes[] = { #define VALUE_PROF_FUNC_PARAM(ParamType, ParamName, ParamLLVMType) ParamLLVMType #include "llvm/ProfileData/InstrProfData.inc" - }; - auto *ValueProfilingCallTy = - FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false); - Constant *Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(), - ValueProfilingCallTy); + }; + auto *ValueProfilingCallTy = + FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false); + Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(), + ValueProfilingCallTy); + } else { + Type *RangeParamTypes[] = { +#define VALUE_RANGE_PROF 1 +#define VALUE_PROF_FUNC_PARAM(ParamType, ParamName, ParamLLVMType) ParamLLVMType +#include "llvm/ProfileData/InstrProfData.inc" +#undef VALUE_RANGE_PROF + }; + auto *ValueRangeProfilingCallTy = + FunctionType::get(ReturnTy, makeArrayRef(RangeParamTypes), false); + Res = M.getOrInsertFunction(getInstrProfValueRangeProfFuncName(), + ValueRangeProfilingCallTy); + } + if (Function *FunRes = dyn_cast<Function>(Res)) { if (auto AK = TLI.getExtAttrForI32Param(false)) - FunRes->addAttribute(3, AK); + FunRes->addParamAttr(2, AK); } return Res; } void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) { - GlobalVariable *Name = Ind->getName(); uint64_t ValueKind = Ind->getValueKind()->getZExtValue(); uint64_t Index = Ind->getIndex()->getZExtValue(); @@ -222,7 +530,6 @@ void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) { } void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { - GlobalVariable *Name = Ind->getName(); auto It = ProfileDataMap.find(Name); assert(It != ProfileDataMap.end() && It->second.DataVar && @@ -235,13 +542,27 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { Index += It->second.NumValueSites[Kind]; IRBuilder<> Builder(Ind); - Value *Args[3] = {Ind->getTargetValue(), - Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), - Builder.getInt32(Index)}; - CallInst *Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), - Args); + bool IsRange = (Ind->getValueKind()->getZExtValue() == + llvm::InstrProfValueKind::IPVK_MemOPSize); + CallInst *Call = nullptr; + if (!IsRange) { + Value *Args[3] = {Ind->getTargetValue(), + Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), + Builder.getInt32(Index)}; + Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), Args); + } else { + Value *Args[6] = { + Ind->getTargetValue(), + Builder.CreateBitCast(DataVar, 
Builder.getInt8PtrTy()), + Builder.getInt32(Index), + Builder.getInt64(MemOPSizeRangeStart), + Builder.getInt64(MemOPSizeRangeLast), + Builder.getInt64(MemOPSizeLarge == 0 ? INT64_MIN : MemOPSizeLarge)}; + Call = + Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI, true), Args); + } if (auto AK = TLI->getExtAttrForI32Param(false)) - Call->addAttribute(3, AK); + Call->addParamAttr(2, AK); Ind->replaceAllUsesWith(Call); Ind->eraseFromParent(); } @@ -252,14 +573,16 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { IRBuilder<> Builder(Inc); uint64_t Index = Inc->getIndex()->getZExtValue(); Value *Addr = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Index); - Value *Count = Builder.CreateLoad(Addr, "pgocount"); - Count = Builder.CreateAdd(Count, Inc->getStep()); - Inc->replaceAllUsesWith(Builder.CreateStore(Count, Addr)); + Value *Load = Builder.CreateLoad(Addr, "pgocount"); + auto *Count = Builder.CreateAdd(Load, Inc->getStep()); + auto *Store = Builder.CreateStore(Count, Addr); + Inc->replaceAllUsesWith(Store); + if (isCounterPromotionEnabled()) + PromotionCandidates.emplace_back(cast<Instruction>(Load), Store); Inc->eraseFromParent(); } void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { - ConstantArray *Names = cast<ConstantArray>(CoverageNamesVar->getInitializer()); for (unsigned I = 0, E = Names->getNumOperands(); I < E; ++I) { @@ -270,7 +593,9 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { Name->setLinkage(GlobalValue::PrivateLinkage); ReferencedNames.push_back(Name); + NC->dropAllReferences(); } + CoverageNamesVar->eraseFromParent(); } /// Get the name of a profiling variable for a particular function. @@ -291,14 +616,24 @@ static std::string getVarName(InstrProfIncrementInst *Inc, StringRef Prefix) { static inline bool shouldRecordFunctionAddr(Function *F) { // Check the linkage + bool HasAvailableExternallyLinkage = F->hasAvailableExternallyLinkage(); if (!F->hasLinkOnceLinkage() && !F->hasLocalLinkage() && - !F->hasAvailableExternallyLinkage()) + !HasAvailableExternallyLinkage) return true; + + // A function marked 'alwaysinline' with available_externally linkage can't + // have its address taken. Doing so would create an undefined external ref to + // the function, which would fail to link. + if (HasAvailableExternallyLinkage && + F->hasFnAttribute(Attribute::AlwaysInline)) + return false; + // Prohibit function address recording if the function is both internal and // COMDAT. This avoids the profile data variable referencing internal symbols // in COMDAT. if (F->hasLocalLinkage() && F->hasComdat()) return false; + // Check uses of this function for other than direct calls or invokes to it. // Inline virtual functions have linkeOnceODR linkage. When a key method // exists, the vtable will only be emitted in the TU where the key method @@ -367,7 +702,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { Constant::getNullValue(CounterTy), getVarName(Inc, getInstrProfCountersVarPrefix())); CounterPtr->setVisibility(NamePtr->getVisibility()); - CounterPtr->setSection(getCountersSection()); + CounterPtr->setSection( + getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat())); CounterPtr->setAlignment(8); CounterPtr->setComdat(ProfileVarsComdat); @@ -376,7 +712,6 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { // the current function. 
Constant *ValuesPtrExpr = ConstantPointerNull::get(Int8PtrTy); if (ValueProfileStaticAlloc && !needsRuntimeRegistrationOfSectionRange(*M)) { - uint64_t NS = 0; for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) NS += PD.NumValueSites[Kind]; @@ -388,11 +723,12 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { Constant::getNullValue(ValuesTy), getVarName(Inc, getInstrProfValuesVarPrefix())); ValuesVar->setVisibility(NamePtr->getVisibility()); - ValuesVar->setSection(getInstrProfValuesSectionName(isMachO())); + ValuesVar->setSection( + getInstrProfSectionName(IPSK_vals, TT.getObjectFormat())); ValuesVar->setAlignment(8); ValuesVar->setComdat(ProfileVarsComdat); ValuesPtrExpr = - ConstantExpr::getBitCast(ValuesVar, llvm::Type::getInt8PtrTy(Ctx)); + ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx)); } } @@ -421,7 +757,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { ConstantStruct::get(DataTy, DataVals), getVarName(Inc, getInstrProfDataVarPrefix())); Data->setVisibility(NamePtr->getVisibility()); - Data->setSection(getDataSection()); + Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat())); Data->setAlignment(INSTR_PROF_DATA_ALIGNMENT); Data->setComdat(ProfileVarsComdat); @@ -481,9 +817,10 @@ void InstrProfiling::emitVNodes() { ArrayType *VNodesTy = ArrayType::get(VNodeTy, NumCounters); auto *VNodesVar = new GlobalVariable( - *M, VNodesTy, false, llvm::GlobalValue::PrivateLinkage, + *M, VNodesTy, false, GlobalValue::PrivateLinkage, Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName()); - VNodesVar->setSection(getInstrProfVNodesSectionName(isMachO())); + VNodesVar->setSection( + getInstrProfSectionName(IPSK_vnodes, TT.getObjectFormat())); UsedVars.push_back(VNodesVar); } @@ -496,18 +833,22 @@ void InstrProfiling::emitNameData() { std::string CompressedNameStr; if (Error E = collectPGOFuncNameStrings(ReferencedNames, CompressedNameStr, DoNameCompression)) { - llvm::report_fatal_error(toString(std::move(E)), false); + report_fatal_error(toString(std::move(E)), false); } auto &Ctx = M->getContext(); - auto *NamesVal = llvm::ConstantDataArray::getString( + auto *NamesVal = ConstantDataArray::getString( Ctx, StringRef(CompressedNameStr), false); - NamesVar = new llvm::GlobalVariable(*M, NamesVal->getType(), true, - llvm::GlobalValue::PrivateLinkage, - NamesVal, getInstrProfNamesVarName()); + NamesVar = new GlobalVariable(*M, NamesVal->getType(), true, + GlobalValue::PrivateLinkage, NamesVal, + getInstrProfNamesVarName()); NamesSize = CompressedNameStr.size(); - NamesVar->setSection(getNameSection()); + NamesVar->setSection( + getInstrProfSectionName(IPSK_name, TT.getObjectFormat())); UsedVars.push_back(NamesVar); + + for (auto *NamePtr : ReferencedNames) + NamePtr->eraseFromParent(); } void InstrProfiling::emitRegistration() { @@ -550,7 +891,6 @@ void InstrProfiling::emitRegistration() { } void InstrProfiling::emitRuntimeHook() { - // We expect the linker to be invoked with -u<hook_var> flag for linux, // for which case there is no need to emit the user function. 
if (Triple(M->getTargetTriple()).isOSLinux()) @@ -600,7 +940,6 @@ void InstrProfiling::emitInitialization() { GlobalVariable *ProfileNameVar = new GlobalVariable( *M, ProfileNameConst->getType(), true, GlobalValue::WeakAnyLinkage, ProfileNameConst, INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_NAME_VAR)); - Triple TT(M->getTargetTriple()); if (TT.supportsCOMDAT()) { ProfileNameVar->setLinkage(GlobalValue::ExternalLinkage); ProfileNameVar->setComdat(M->getOrInsertComdat( diff --git a/contrib/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/contrib/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp index 2963d08..7bb62d2 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -63,6 +63,7 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) { initializePGOInstrumentationGenLegacyPassPass(Registry); initializePGOInstrumentationUseLegacyPassPass(Registry); initializePGOIndirectCallPromotionLegacyPassPass(Registry); + initializePGOMemOPSizeOptLegacyPassPass(Registry); initializeInstrProfilingLegacyPassPass(Registry); initializeMemorySanitizerPass(Registry); initializeThreadSanitizerPass(Registry); diff --git a/contrib/llvm/lib/Transforms/Instrumentation/MaximumSpanningTree.h b/contrib/llvm/lib/Transforms/Instrumentation/MaximumSpanningTree.h index 363539b..4eb758c 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/MaximumSpanningTree.h +++ b/contrib/llvm/lib/Transforms/Instrumentation/MaximumSpanningTree.h @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#ifndef LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H -#define LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H +#ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H +#define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/IR/BasicBlock.h" @@ -108,4 +108,4 @@ namespace llvm { } // End llvm namespace -#endif +#endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H diff --git a/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index fafb0fc..b7c6271 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -425,7 +425,7 @@ void MemorySanitizer::initializeCallbacks(Module &M) { // which is not yet implemented. StringRef WarningFnName = Recover ? 
"__msan_warning" : "__msan_warning_noreturn"; - WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy(), nullptr); + WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy()); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { @@ -433,31 +433,31 @@ void MemorySanitizer::initializeCallbacks(Module &M) { std::string FunctionName = "__msan_maybe_warning_" + itostr(AccessSize); MaybeWarningFn[AccessSizeIndex] = M.getOrInsertFunction( FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), - IRB.getInt32Ty(), nullptr); + IRB.getInt32Ty()); FunctionName = "__msan_maybe_store_origin_" + itostr(AccessSize); MaybeStoreOriginFn[AccessSizeIndex] = M.getOrInsertFunction( FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), - IRB.getInt8PtrTy(), IRB.getInt32Ty(), nullptr); + IRB.getInt8PtrTy(), IRB.getInt32Ty()); } MsanSetAllocaOrigin4Fn = M.getOrInsertFunction( "__msan_set_alloca_origin4", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, - IRB.getInt8PtrTy(), IntptrTy, nullptr); + IRB.getInt8PtrTy(), IntptrTy); MsanPoisonStackFn = M.getOrInsertFunction("__msan_poison_stack", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr); + IRB.getInt8PtrTy(), IntptrTy); MsanChainOriginFn = M.getOrInsertFunction( - "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty(), nullptr); + "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty()); MemmoveFn = M.getOrInsertFunction( "__msan_memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr); + IRB.getInt8PtrTy(), IntptrTy); MemcpyFn = M.getOrInsertFunction( "__msan_memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IntptrTy, nullptr); + IntptrTy); MemsetFn = M.getOrInsertFunction( "__msan_memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), - IntptrTy, nullptr); + IntptrTy); // Create globals. RetvalTLS = new GlobalVariable( @@ -1037,15 +1037,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { OriginMap[V] = Origin; } + Constant *getCleanShadow(Type *OrigTy) { + Type *ShadowTy = getShadowTy(OrigTy); + if (!ShadowTy) + return nullptr; + return Constant::getNullValue(ShadowTy); + } + /// \brief Create a clean shadow value for a given value. /// /// Clean shadow (all zeroes) means all bits of the value are defined /// (initialized). Constant *getCleanShadow(Value *V) { - Type *ShadowTy = getShadowTy(V); - if (!ShadowTy) - return nullptr; - return Constant::getNullValue(ShadowTy); + return getCleanShadow(V->getType()); } /// \brief Create a dirty shadow of a given shadow type. 
@@ -1572,13 +1576,16 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *CreateShadowCast(IRBuilder<> &IRB, Value *V, Type *dstTy, bool Signed = false) { Type *srcTy = V->getType(); + size_t srcSizeInBits = VectorOrPrimitiveTypeSizeInBits(srcTy); + size_t dstSizeInBits = VectorOrPrimitiveTypeSizeInBits(dstTy); + if (srcSizeInBits > 1 && dstSizeInBits == 1) + return IRB.CreateICmpNE(V, getCleanShadow(V)); + if (dstTy->isIntegerTy() && srcTy->isIntegerTy()) return IRB.CreateIntCast(V, dstTy, Signed); if (dstTy->isVectorTy() && srcTy->isVectorTy() && dstTy->getVectorNumElements() == srcTy->getVectorNumElements()) return IRB.CreateIntCast(V, dstTy, Signed); - size_t srcSizeInBits = VectorOrPrimitiveTypeSizeInBits(srcTy); - size_t dstSizeInBits = VectorOrPrimitiveTypeSizeInBits(dstTy); Value *V1 = IRB.CreateBitCast(V, Type::getIntNTy(*MS.C, srcSizeInBits)); Value *V2 = IRB.CreateIntCast(V1, Type::getIntNTy(*MS.C, dstSizeInBits), Signed); @@ -1942,7 +1949,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ClCheckAccessAddress) insertShadowCheck(Addr, &I); - // FIXME: use ClStoreCleanOrigin // FIXME: factor out common code from materializeStores if (MS.TrackOrigins) IRB.CreateStore(getOrigin(&I, 1), getOriginPtr(Addr, IRB, 1)); @@ -2081,6 +2087,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { switch (I.getNumArgOperands()) { case 3: assert(isa<ConstantInt>(I.getArgOperand(2)) && "Invalid rounding mode"); + LLVM_FALLTHROUGH; case 2: CopyOp = I.getArgOperand(0); ConvertOp = I.getArgOperand(1); @@ -2325,11 +2332,49 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOriginForNaryOp(I); } + void handleStmxcsr(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value* Addr = I.getArgOperand(0); + Type *Ty = IRB.getInt32Ty(); + Value *ShadowPtr = getShadowPtr(Addr, Ty, IRB); + + IRB.CreateStore(getCleanShadow(Ty), + IRB.CreatePointerCast(ShadowPtr, Ty->getPointerTo())); + + if (ClCheckAccessAddress) + insertShadowCheck(Addr, &I); + } + + void handleLdmxcsr(IntrinsicInst &I) { + if (!InsertChecks) return; + + IRBuilder<> IRB(&I); + Value *Addr = I.getArgOperand(0); + Type *Ty = IRB.getInt32Ty(); + unsigned Alignment = 1; + + if (ClCheckAccessAddress) + insertShadowCheck(Addr, &I); + + Value *Shadow = IRB.CreateAlignedLoad(getShadowPtr(Addr, Ty, IRB), + Alignment, "_ldmxcsr"); + Value *Origin = MS.TrackOrigins + ? 
IRB.CreateLoad(getOriginPtr(Addr, IRB, Alignment)) + : getCleanOrigin(); + insertShadowCheck(Shadow, Origin, &I); + } + void visitIntrinsicInst(IntrinsicInst &I) { switch (I.getIntrinsicID()) { case llvm::Intrinsic::bswap: handleBswap(I); break; + case llvm::Intrinsic::x86_sse_stmxcsr: + handleStmxcsr(I); + break; + case llvm::Intrinsic::x86_sse_ldmxcsr: + handleLdmxcsr(I); + break; case llvm::Intrinsic::x86_avx512_vcvtsd2usi64: case llvm::Intrinsic::x86_avx512_vcvtsd2usi32: case llvm::Intrinsic::x86_avx512_vcvtss2usi64: @@ -2566,10 +2611,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { AttrBuilder B; B.addAttribute(Attribute::ReadOnly) .addAttribute(Attribute::ReadNone); - Func->removeAttributes(AttributeSet::FunctionIndex, - AttributeSet::get(Func->getContext(), - AttributeSet::FunctionIndex, - B)); + Func->removeAttributes(AttributeList::FunctionIndex, B); } maybeMarkSanitizerLibraryCallNoBuiltin(Call, TLI); @@ -2597,12 +2639,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { " Shadow: " << *ArgShadow << "\n"); bool ArgIsInitialized = false; const DataLayout &DL = F.getParent()->getDataLayout(); - if (CS.paramHasAttr(i + 1, Attribute::ByVal)) { + if (CS.paramHasAttr(i, Attribute::ByVal)) { assert(A->getType()->isPointerTy() && "ByVal argument is not a pointer!"); Size = DL.getTypeAllocSize(A->getType()->getPointerElementType()); if (ArgOffset + Size > kParamTLSSize) break; - unsigned ParamAlignment = CS.getParamAlignment(i + 1); + unsigned ParamAlignment = CS.getParamAlignment(i); unsigned Alignment = std::min(ParamAlignment, kShadowTLSAlignment); Store = IRB.CreateMemCpy(ArgShadowBase, getShadowPtr(A, Type::getInt8Ty(*MS.C), IRB), @@ -2690,7 +2732,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } else { Value *Shadow = getShadow(RetVal); IRB.CreateAlignedStore(Shadow, ShadowPtr, kShadowTLSAlignment); - // FIXME: make it conditional if ClStoreCleanOrigin==0 if (MS.TrackOrigins) IRB.CreateStore(getOrigin(RetVal), getOriginPtrForRetval(IRB)); } @@ -2717,15 +2758,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&I, getCleanOrigin()); IRBuilder<> IRB(I.getNextNode()); const DataLayout &DL = F.getParent()->getDataLayout(); - uint64_t Size = DL.getTypeAllocSize(I.getAllocatedType()); + uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType()); + Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize); + if (I.isArrayAllocation()) + Len = IRB.CreateMul(Len, I.getArraySize()); if (PoisonStack && ClPoisonStackWithCall) { IRB.CreateCall(MS.MsanPoisonStackFn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), - ConstantInt::get(MS.IntptrTy, Size)}); + {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); } else { Value *ShadowBase = getShadowPtr(&I, Type::getInt8PtrTy(*MS.C), IRB); Value *PoisonValue = IRB.getInt8(PoisonStack ? 
ClPoisonStackPattern : 0); - IRB.CreateMemSet(ShadowBase, PoisonValue, Size, I.getAlignment()); + IRB.CreateMemSet(ShadowBase, PoisonValue, Len, I.getAlignment()); } if (PoisonStack && MS.TrackOrigins) { @@ -2742,8 +2785,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { StackDescription.str()); IRB.CreateCall(MS.MsanSetAllocaOrigin4Fn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), - ConstantInt::get(MS.IntptrTy, Size), + {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy()), IRB.CreatePointerCast(&F, MS.IntptrTy)}); } @@ -2876,8 +2918,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ClDumpStrictInstructions) dumpInst(I); DEBUG(dbgs() << "DEFAULT: " << I << "\n"); - for (size_t i = 0, n = I.getNumOperands(); i < n; i++) - insertShadowCheck(I.getOperand(i), &I); + for (size_t i = 0, n = I.getNumOperands(); i < n; i++) { + Value *Operand = I.getOperand(i); + if (Operand->getType()->isSized()) + insertShadowCheck(Operand, &I); + } setShadow(&I, getCleanShadow(&I)); setOrigin(&I, getCleanOrigin()); } @@ -2935,7 +2980,7 @@ struct VarArgAMD64Helper : public VarArgHelper { Value *A = *ArgIt; unsigned ArgNo = CS.getArgumentNo(ArgIt); bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams(); - bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal); + bool IsByVal = CS.paramHasAttr(ArgNo, Attribute::ByVal); if (IsByVal) { // ByVal arguments always go to the overflow area. // Fixed arguments passed through the overflow area will be stepped @@ -2994,7 +3039,7 @@ struct VarArgAMD64Helper : public VarArgHelper { } void visitVAStartInst(VAStartInst &I) override { - if (F.getCallingConv() == CallingConv::X86_64_Win64) + if (F.getCallingConv() == CallingConv::Win64) return; IRBuilder<> IRB(&I); VAStartInstrumentationList.push_back(&I); @@ -3008,7 +3053,7 @@ struct VarArgAMD64Helper : public VarArgHelper { } void visitVACopyInst(VACopyInst &I) override { - if (F.getCallingConv() == CallingConv::X86_64_Win64) + if (F.getCallingConv() == CallingConv::Win64) return; IRBuilder<> IRB(&I); Value *VAListTag = I.getArgOperand(0); @@ -3456,12 +3501,12 @@ struct VarArgPowerPC64Helper : public VarArgHelper { Value *A = *ArgIt; unsigned ArgNo = CS.getArgumentNo(ArgIt); bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams(); - bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal); + bool IsByVal = CS.paramHasAttr(ArgNo, Attribute::ByVal); if (IsByVal) { assert(A->getType()->isPointerTy()); Type *RealTy = A->getType()->getPointerElementType(); uint64_t ArgSize = DL.getTypeAllocSize(RealTy); - uint64_t ArgAlign = CS.getParamAlignment(ArgNo + 1); + uint64_t ArgAlign = CS.getParamAlignment(ArgNo); if (ArgAlign < 8) ArgAlign = 8; VAArgOffset = alignTo(VAArgOffset, ArgAlign); @@ -3618,9 +3663,7 @@ bool MemorySanitizer::runOnFunction(Function &F) { AttrBuilder B; B.addAttribute(Attribute::ReadOnly) .addAttribute(Attribute::ReadNone); - F.removeAttributes(AttributeSet::FunctionIndex, - AttributeSet::get(F.getContext(), - AttributeSet::FunctionIndex, B)); + F.removeAttributes(AttributeList::FunctionIndex, B); return Visitor.runOnFunction(); } diff --git a/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 04f9a64..8e4bfc0 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -58,8 +58,10 @@ #include 
"llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" @@ -71,7 +73,9 @@ #include "llvm/ProfileData/InstrProfReader.h" #include "llvm/ProfileData/ProfileCommon.h" #include "llvm/Support/BranchProbability.h" +#include "llvm/Support/DOTGraphTraits.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/GraphWriter.h" #include "llvm/Support/JamCRC.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -87,6 +91,7 @@ using namespace llvm; STATISTIC(NumOfPGOInstrument, "Number of edges instrumented."); STATISTIC(NumOfPGOSelectInsts, "Number of select instruction instrumented."); +STATISTIC(NumOfPGOMemIntrinsics, "Number of mem intrinsics instrumented."); STATISTIC(NumOfPGOEdge, "Number of edges."); STATISTIC(NumOfPGOBB, "Number of basic-blocks."); STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); @@ -116,6 +121,13 @@ static cl::opt<unsigned> MaxNumAnnotations( cl::desc("Max number of annotations for a single indirect " "call callsite")); +// Command line option to set the maximum number of value annotations +// to write to the metadata for a single memop intrinsic. +static cl::opt<unsigned> MaxNumMemOPAnnotations( + "memop-max-annotations", cl::init(4), cl::Hidden, cl::ZeroOrMore, + cl::desc("Max number of preicise value annotations for a single memop" + "intrinsic")); + // Command line option to control appending FunctionHash to the name of a COMDAT // function. This is to avoid the hash mismatch caused by the preinliner. static cl::opt<bool> DoComdatRenaming( @@ -125,26 +137,102 @@ static cl::opt<bool> DoComdatRenaming( // Command line option to enable/disable the warning about missing profile // information. -static cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function", - cl::init(false), - cl::Hidden); +static cl::opt<bool> + PGOWarnMissing("pgo-warn-missing-function", cl::init(false), cl::Hidden, + cl::desc("Use this option to turn on/off " + "warnings about missing profile data for " + "functions.")); // Command line option to enable/disable the warning about a hash mismatch in // the profile data. -static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), - cl::Hidden); +static cl::opt<bool> + NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden, + cl::desc("Use this option to turn off/on " + "warnings about profile cfg mismatch.")); // Command line option to enable/disable the warning about a hash mismatch in // the profile data for Comdat functions, which often turns out to be false // positive due to the pre-instrumentation inline. -static cl::opt<bool> NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat", - cl::init(true), cl::Hidden); +static cl::opt<bool> + NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat", cl::init(true), + cl::Hidden, + cl::desc("The option is used to turn on/off " + "warnings about hash mismatch for comdat " + "functions.")); // Command line option to enable/disable select instruction instrumentation. -static cl::opt<bool> PGOInstrSelect("pgo-instr-select", cl::init(true), - cl::Hidden); +static cl::opt<bool> + PGOInstrSelect("pgo-instr-select", cl::init(true), cl::Hidden, + cl::desc("Use this option to turn on/off SELECT " + "instruction instrumentation. 
")); + +// Command line option to turn on CFG dot dump of raw profile counts +static cl::opt<bool> + PGOViewRawCounts("pgo-view-raw-counts", cl::init(false), cl::Hidden, + cl::desc("A boolean option to show CFG dag " + "with raw profile counts from " + "profile data. See also option " + "-pgo-view-counts. To limit graph " + "display to only one function, use " + "filtering option -view-bfi-func-name.")); + +// Command line option to enable/disable memop intrinsic call.size profiling. +static cl::opt<bool> + PGOInstrMemOP("pgo-instr-memop", cl::init(true), cl::Hidden, + cl::desc("Use this option to turn on/off " + "memory intrinsic size profiling.")); + +// Emit branch probability as optimization remarks. +static cl::opt<bool> + EmitBranchProbability("pgo-emit-branch-prob", cl::init(false), cl::Hidden, + cl::desc("When this option is on, the annotated " + "branch probability will be emitted as " + " optimization remarks: -Rpass-analysis=" + "pgo-instr-use")); + +// Command line option to turn on CFG dot dump after profile annotation. +// Defined in Analysis/BlockFrequencyInfo.cpp: -pgo-view-counts +extern cl::opt<bool> PGOViewCounts; + +// Command line option to specify the name of the function for CFG dump +// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= +extern cl::opt<std::string> ViewBlockFreqFuncName; + namespace { +// Return a string describing the branch condition that can be +// used in static branch probability heuristics: +std::string getBranchCondString(Instruction *TI) { + BranchInst *BI = dyn_cast<BranchInst>(TI); + if (!BI || !BI->isConditional()) + return std::string(); + + Value *Cond = BI->getCondition(); + ICmpInst *CI = dyn_cast<ICmpInst>(Cond); + if (!CI) + return std::string(); + + std::string result; + raw_string_ostream OS(result); + OS << CmpInst::getPredicateName(CI->getPredicate()) << "_"; + CI->getOperand(0)->getType()->print(OS, true); + + Value *RHS = CI->getOperand(1); + ConstantInt *CV = dyn_cast<ConstantInt>(RHS); + if (CV) { + if (CV->isZero()) + OS << "_Zero"; + else if (CV->isOne()) + OS << "_One"; + else if (CV->isMinusOne()) + OS << "_MinusOne"; + else + OS << "_Const"; + } + OS.flush(); + return result; +} + /// The select instruction visitor plays three roles specified /// by the mode. In \c VM_counting mode, it simply counts the number of /// select instructions. In \c VM_instrument mode, it inserts code to count @@ -167,6 +255,7 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { SelectInstVisitor(Function &Func) : F(Func) {} void countSelects(Function &Func) { + NSIs = 0; Mode = VM_counting; visit(Func); } @@ -196,9 +285,54 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { void annotateOneSelectInst(SelectInst &SI); // Visit \p SI instruction and perform tasks according to visit mode. void visitSelectInst(SelectInst &SI); + // Return the number of select instructions. This needs be called after + // countSelects(). unsigned getNumOfSelectInsts() const { return NSIs; } }; +/// Instruction Visitor class to visit memory intrinsic calls. +struct MemIntrinsicVisitor : public InstVisitor<MemIntrinsicVisitor> { + Function &F; + unsigned NMemIs = 0; // Number of memIntrinsics instrumented. + VisitMode Mode = VM_counting; // Visiting mode. + unsigned CurCtrId = 0; // Current counter index. 
+ unsigned TotalNumCtrs = 0; // Total number of counters + GlobalVariable *FuncNameVar = nullptr; + uint64_t FuncHash = 0; + PGOUseFunc *UseFunc = nullptr; + std::vector<Instruction *> Candidates; + + MemIntrinsicVisitor(Function &Func) : F(Func) {} + + void countMemIntrinsics(Function &Func) { + NMemIs = 0; + Mode = VM_counting; + visit(Func); + } + + void instrumentMemIntrinsics(Function &Func, unsigned TotalNC, + GlobalVariable *FNV, uint64_t FHash) { + Mode = VM_instrument; + TotalNumCtrs = TotalNC; + FuncHash = FHash; + FuncNameVar = FNV; + visit(Func); + } + + std::vector<Instruction *> findMemIntrinsics(Function &Func) { + Candidates.clear(); + Mode = VM_annotate; + visit(Func); + return Candidates; + } + + // Visit the IR stream and annotate all mem intrinsic call instructions. + void instrumentOneMemIntrinsic(MemIntrinsic &MI); + // Visit \p MI instruction and perform tasks according to visit mode. + void visitMemIntrinsic(MemIntrinsic &SI); + unsigned getNumOfMemIntrinsics() const { return NMemIs; } +}; + class PGOInstrumentationGenLegacyPass : public ModulePass { public: static char ID; @@ -316,8 +450,9 @@ private: std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers; public: - std::vector<Instruction *> IndirectCallSites; + std::vector<std::vector<Instruction *>> ValueSites; SelectInstVisitor SIVisitor; + MemIntrinsicVisitor MIVisitor; std::string FuncName; GlobalVariable *FuncNameVar; // CFG hash value for this function. @@ -347,13 +482,16 @@ public: std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr) - : F(Func), ComdatMembers(ComdatMembers), SIVisitor(Func), FunctionHash(0), - MST(F, BPI, BFI) { + : F(Func), ComdatMembers(ComdatMembers), ValueSites(IPVK_Last + 1), + SIVisitor(Func), MIVisitor(Func), FunctionHash(0), MST(F, BPI, BFI) { // This should be done before CFG hash computation. SIVisitor.countSelects(Func); + MIVisitor.countMemIntrinsics(Func); NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); - IndirectCallSites = findIndirectCallSites(Func); + NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); + ValueSites[IPVK_IndirectCallTarget] = findIndirectCallSites(Func); + ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func); FuncName = getPGOFuncName(F); computeCFGHash(); @@ -405,7 +543,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { } JC.update(Indexes); FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | - (uint64_t)IndirectCallSites.size() << 48 | + (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 | (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); } @@ -552,7 +690,7 @@ static void instrumentOneFunc( return; unsigned NumIndirectCallSites = 0; - for (auto &I : FuncInfo.IndirectCallSites) { + for (auto &I : FuncInfo.ValueSites[IPVK_IndirectCallTarget]) { CallSite CS(I); Value *Callee = CS.getCalledValue(); DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " @@ -565,10 +703,14 @@ static void instrumentOneFunc( {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), Builder.getInt64(FuncInfo.FunctionHash), Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()), - Builder.getInt32(llvm::InstrProfValueKind::IPVK_IndirectCallTarget), + Builder.getInt32(IPVK_IndirectCallTarget), Builder.getInt32(NumIndirectCallSites++)}); } NumOfPGOICall += NumIndirectCallSites; + + // Now instrument memop intrinsic calls. 
+ FuncInfo.MIVisitor.instrumentMemIntrinsics(
+ F, NumCounters, FuncInfo.FuncNameVar, FuncInfo.FunctionHash); } // This class represents a CFG edge in profile use compilation.
@@ -653,8 +795,11 @@ public: // Set the branch weights based on the count values. void setBranchWeights();
- // Annotate the indirect call sites.
- void annotateIndirectCallSites();
+ // Annotate the value profile call sites for all value kinds.
+ void annotateValueSites();
+
+ // Annotate the value profile call sites for one value kind.
+ void annotateValueSites(uint32_t Kind);
 // The hotness of the function from the profile count. enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot };
@@ -677,6 +822,8 @@ public: return FuncInfo.findBBInfo(BB); }
+
+ Function &getFunc() const { return F; }
+
 private: Function &F; Module *M;
@@ -761,7 +908,7 @@ void PGOUseFunc::setInstrumentedCounts( NewEdge1.InMST = true; getBBInfo(InstrBB).setBBInfoCount(CountValue); }
- ProfileCountSize = CountFromProfile.size(); 
+ ProfileCountSize = CountFromProfile.size();
 CountPosition = I; }
@@ -932,21 +1079,6 @@ void PGOUseFunc::populateCounters() { DEBUG(FuncInfo.dumpInfo("after reading profile.")); }
-static void setProfMetadata(Module *M, Instruction *TI,
- ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
- MDBuilder MDB(M->getContext());
- assert(MaxCount > 0 && "Bad max count");
- uint64_t Scale = calculateCountScale(MaxCount);
- SmallVector<unsigned, 4> Weights;
- for (const auto &ECI : EdgeCounts)
- Weights.push_back(scaleBranchCount(ECI, Scale));
-
- DEBUG(dbgs() << "Weight is: ";
- for (const auto &W : Weights) { dbgs() << W << " "; }
- dbgs() << "\n";);
- TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
-}
-
 // Assign the scaled count values to the BB with multiple out edges. void PGOUseFunc::setBranchWeights() { // Generate MD_prof metadata for every branch instruction.
@@ -990,8 +1122,8 @@ void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step), {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
- Builder.getInt64(FuncHash),
- Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step});
+ Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs),
+ Builder.getInt32(*CurCtrIdx), Step});
 ++(*CurCtrIdx); }
@@ -1020,9 +1152,9 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) { if (SI.getCondition()->getType()->isVectorTy()) return;
- NSIs++;
 switch (Mode) { case VM_counting:
+ NSIs++;
 return; case VM_instrument: instrumentOneSelectInst(SI);
@@ -1035,35 +1167,79 @@ llvm_unreachable("Unknown visiting mode"); }
-// Traverse all the indirect callsites and annotate the instructions.
-void PGOUseFunc::annotateIndirectCallSites() {
+void MemIntrinsicVisitor::instrumentOneMemIntrinsic(MemIntrinsic &MI) {
+ Module *M = F.getParent();
+ IRBuilder<> Builder(&MI);
+ Type *Int64Ty = Builder.getInt64Ty();
+ Type *I8PtrTy = Builder.getInt8PtrTy();
+ Value *Length = MI.getLength();
+ assert(!dyn_cast<ConstantInt>(Length));
+ Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
+ {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
+ Builder.getInt64(FuncHash), Builder.CreateZExtOrTrunc(Length, Int64Ty),
+ Builder.getInt32(IPVK_MemOPSize), Builder.getInt32(CurCtrId)});
+ ++CurCtrId;
+}
+
+void MemIntrinsicVisitor::visitMemIntrinsic(MemIntrinsic &MI) {
+ if (!PGOInstrMemOP)
+ return;
+ Value *Length = MI.getLength();
+ // Do not instrument constant-length calls.
+ if (dyn_cast<ConstantInt>(Length))
+ return;
+
+ switch (Mode) {
+ case VM_counting:
+ NMemIs++;
+ return;
+ case VM_instrument:
+ instrumentOneMemIntrinsic(MI);
+ return;
+ case VM_annotate:
+ Candidates.push_back(&MI);
+ return;
+ }
+ llvm_unreachable("Unknown visiting mode");
+}
+
+// Traverse all value sites and annotate the instructions for all value kinds.
+void PGOUseFunc::annotateValueSites() {
 if (DisableValueProfiling) return; // Create the PGOFuncName meta data. createPGOFuncNameMetadata(F, FuncInfo.FuncName);
- unsigned IndirectCallSiteIndex = 0;
- auto &IndirectCallSites = FuncInfo.IndirectCallSites;
- unsigned NumValueSites =
- ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget);
- if (NumValueSites != IndirectCallSites.size()) {
- std::string Msg =
- std::string("Inconsistent number of indirect call sites: ") +
- F.getName().str();
+ for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
+ annotateValueSites(Kind);
+}
+
+// Annotate the instructions for a specific value kind.
+void PGOUseFunc::annotateValueSites(uint32_t Kind) {
+ unsigned ValueSiteIndex = 0;
+ auto &ValueSites = FuncInfo.ValueSites[Kind];
+ unsigned NumValueSites = ProfileRecord.getNumValueSites(Kind);
+ if (NumValueSites != ValueSites.size()) {
 auto &Ctx = M->getContext();
- Ctx.diagnose(
- DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
+ Ctx.diagnose(DiagnosticInfoPGOProfile(
+ M->getName().data(),
+ Twine("Inconsistent number of value sites for kind = ") + Twine(Kind) +
+ " in " + F.getName().str(),
+ DS_Warning));
 return; }
- for (auto &I : IndirectCallSites) {
- DEBUG(dbgs() << "Read one indirect call instrumentation: Index="
- << IndirectCallSiteIndex << " out of " << NumValueSites
- << "\n");
- annotateValueSite(*M, *I, ProfileRecord, IPVK_IndirectCallTarget,
- IndirectCallSiteIndex, MaxNumAnnotations);
- IndirectCallSiteIndex++;
+ for (auto &I : ValueSites) {
+ DEBUG(dbgs() << "Read one value site profile (kind = " << Kind
+ << "): Index = " << ValueSiteIndex << " out of "
+ << NumValueSites << "\n");
+ annotateValueSite(*M, *I, ProfileRecord,
+ static_cast<InstrProfValueKind>(Kind), ValueSiteIndex,
+ Kind == IPVK_MemOPSize ?
MaxNumMemOPAnnotations + : MaxNumAnnotations); + ValueSiteIndex++; } } } // end anonymous namespace @@ -1196,12 +1372,29 @@ static bool annotateAllFunctions( continue; Func.populateCounters(); Func.setBranchWeights(); - Func.annotateIndirectCallSites(); + Func.annotateValueSites(); PGOUseFunc::FuncFreqAttr FreqAttr = Func.getFuncFreqAttr(); if (FreqAttr == PGOUseFunc::FFA_Cold) ColdFunctions.push_back(&F); else if (FreqAttr == PGOUseFunc::FFA_Hot) HotFunctions.push_back(&F); + if (PGOViewCounts && (ViewBlockFreqFuncName.empty() || + F.getName().equals(ViewBlockFreqFuncName))) { + LoopInfo LI{DominatorTree(F)}; + std::unique_ptr<BranchProbabilityInfo> NewBPI = + llvm::make_unique<BranchProbabilityInfo>(F, LI); + std::unique_ptr<BlockFrequencyInfo> NewBFI = + llvm::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI); + + NewBFI->view(); + } + if (PGOViewRawCounts && (ViewBlockFreqFuncName.empty() || + F.getName().equals(ViewBlockFreqFuncName))) { + if (ViewBlockFreqFuncName.empty()) + WriteGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); + else + ViewGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); + } } M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext())); // Set function hotness attribute from the profile. @@ -1257,3 +1450,113 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) { return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI); } + +namespace llvm { +void setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts, + uint64_t MaxCount) { + MDBuilder MDB(M->getContext()); + assert(MaxCount > 0 && "Bad max count"); + uint64_t Scale = calculateCountScale(MaxCount); + SmallVector<unsigned, 4> Weights; + for (const auto &ECI : EdgeCounts) + Weights.push_back(scaleBranchCount(ECI, Scale)); + + DEBUG(dbgs() << "Weight is: "; + for (const auto &W : Weights) { dbgs() << W << " "; } + dbgs() << "\n";); + TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + if (EmitBranchProbability) { + std::string BrCondStr = getBranchCondString(TI); + if (BrCondStr.empty()) + return; + + unsigned WSum = + std::accumulate(Weights.begin(), Weights.end(), 0, + [](unsigned w1, unsigned w2) { return w1 + w2; }); + uint64_t TotalCount = + std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), 0, + [](uint64_t c1, uint64_t c2) { return c1 + c2; }); + BranchProbability BP(Weights[0], WSum); + std::string BranchProbStr; + raw_string_ostream OS(BranchProbStr); + OS << BP; + OS << " (total count : " << TotalCount << ")"; + OS.flush(); + Function *F = TI->getParent()->getParent(); + emitOptimizationRemarkAnalysis( + F->getContext(), "pgo-use-annot", *F, TI->getDebugLoc(), + Twine(BrCondStr) + + " is true with probability : " + Twine(BranchProbStr)); + } +} + +template <> struct GraphTraits<PGOUseFunc *> { + typedef const BasicBlock *NodeRef; + typedef succ_const_iterator ChildIteratorType; + typedef pointer_iterator<Function::const_iterator> nodes_iterator; + + static NodeRef getEntryNode(const PGOUseFunc *G) { + return &G->getFunc().front(); + } + static ChildIteratorType child_begin(const NodeRef N) { + return succ_begin(N); + } + static ChildIteratorType child_end(const NodeRef N) { return succ_end(N); } + static nodes_iterator nodes_begin(const PGOUseFunc *G) { + return nodes_iterator(G->getFunc().begin()); + } + static nodes_iterator nodes_end(const PGOUseFunc *G) { + return nodes_iterator(G->getFunc().end()); + } +}; + +static std::string getSimpleNodeName(const BasicBlock *Node) { + if 
(!Node->getName().empty()) + return Node->getName(); + + std::string SimpleNodeName; + raw_string_ostream OS(SimpleNodeName); + Node->printAsOperand(OS, false); + return OS.str(); +} + +template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits { + explicit DOTGraphTraits(bool isSimple = false) + : DefaultDOTGraphTraits(isSimple) {} + + static std::string getGraphName(const PGOUseFunc *G) { + return G->getFunc().getName(); + } + + std::string getNodeLabel(const BasicBlock *Node, const PGOUseFunc *Graph) { + std::string Result; + raw_string_ostream OS(Result); + + OS << getSimpleNodeName(Node) << ":\\l"; + UseBBInfo *BI = Graph->findBBInfo(Node); + OS << "Count : "; + if (BI && BI->CountValid) + OS << BI->CountValue << "\\l"; + else + OS << "Unknown\\l"; + + if (!PGOInstrSelect) + return Result; + + for (auto BI = Node->begin(); BI != Node->end(); ++BI) { + auto *I = &*BI; + if (!isa<SelectInst>(I)) + continue; + // Display scaled counts for SELECT instruction: + OS << "SELECT : { T = "; + uint64_t TC, FC; + bool HasProf = I->extractProfMetadata(TC, FC); + if (!HasProf) + OS << "Unknown, F = Unknown }\\l"; + else + OS << TC << ", F = " << FC << " }\\l"; + } + return Result; + } +}; +} // namespace llvm diff --git a/contrib/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/contrib/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp new file mode 100644 index 0000000..0bc9ddf --- /dev/null +++ b/contrib/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -0,0 +1,419 @@ +//===-- PGOMemOPSizeOpt.cpp - Optimizations based on value profiling ===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the transformation that optimizes memory intrinsics +// such as memcpy using the size value profile. When memory intrinsic size +// value profile metadata is available, a single memory intrinsic is expanded +// to a sequence of guarded specialized versions that are called with the +// hottest size(s), for later expansion into more optimal inline sequences. 
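In source terms the rewrite this file performs looks roughly like the following (variable names and the hot size invented):

    if (n == 8)
      memcpy(d, s, 8);   /* constant size; can later be expanded inline */
    else
      memcpy(d, s, n);   /* cold fallback keeps the original call */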
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h" +#include "llvm/PassRegistry.h" +#include "llvm/PassSupport.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/PGOInstrumentation.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <cassert> +#include <cstdint> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "pgo-memop-opt" + +STATISTIC(NumOfPGOMemOPOpt, "Number of memop intrinsics optimized."); +STATISTIC(NumOfPGOMemOPAnnotate, "Number of memop intrinsics annotated."); + +// The minimum call count to optimize memory intrinsic calls. +static cl::opt<unsigned> + MemOPCountThreshold("pgo-memop-count-threshold", cl::Hidden, cl::ZeroOrMore, + cl::init(1000), + cl::desc("The minimum count to optimize memory " + "intrinsic calls")); + +// Command line option to disable memory intrinsic optimization. The default is +// false. This is for debugging purposes. +static cl::opt<bool> DisableMemOPOPT("disable-memop-opt", cl::init(false), + cl::Hidden, cl::desc("Disable optimize")); + +// The percent threshold to optimize memory intrinsic calls. +static cl::opt<unsigned> + MemOPPercentThreshold("pgo-memop-percent-threshold", cl::init(40), + cl::Hidden, cl::ZeroOrMore, + cl::desc("The percentage threshold for the " + "memory intrinsic calls optimization")); + +// Maximum number of versions for optimizing memory intrinsic call. +static cl::opt<unsigned> + MemOPMaxVersion("pgo-memop-max-version", cl::init(3), cl::Hidden, + cl::ZeroOrMore, + cl::desc("The max version for the optimized memory " + " intrinsic calls")); + +// Scale the counts from the annotation using the BB count value. +static cl::opt<bool> + MemOPScaleCount("pgo-memop-scale-count", cl::init(true), cl::Hidden, + cl::desc("Scale the memop size counts using the basic " + " block count value")); + +// This option sets the range of precise profile memop sizes.
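// (The string is parsed below by getMemOPSizeRangeFromOption into an
// inclusive [PreciseRangeStart, PreciseRangeLast] pair; the format is
// "start:last", so an assumed value of "0:8" counts sizes 0 through 8
// precisely and leaves larger sizes to the grouped buckets.)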
+extern cl::opt<std::string> MemOPSizeRange; + +// This option sets the value that groups large memop sizes. +extern cl::opt<unsigned> MemOPSizeLarge; + +namespace { +class PGOMemOPSizeOptLegacyPass : public FunctionPass { +public: + static char ID; + + PGOMemOPSizeOptLegacyPass() : FunctionPass(ID) { + initializePGOMemOPSizeOptLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { return "PGOMemOPSize"; } + +private: + bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<BlockFrequencyInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; +} // end anonymous namespace + +char PGOMemOPSizeOptLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt", + "Optimize memory intrinsic using its size value profile", + false, false) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_END(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt", + "Optimize memory intrinsic using its size value profile", + false, false) + +FunctionPass *llvm::createPGOMemOPSizeOptLegacyPass() { + return new PGOMemOPSizeOptLegacyPass(); +} + +namespace { +class MemOPSizeOpt : public InstVisitor<MemOPSizeOpt> { +public: + MemOPSizeOpt(Function &Func, BlockFrequencyInfo &BFI) + : Func(Func), BFI(BFI), Changed(false) { + ValueDataArray = + llvm::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2); + // Get the MemOPSize range information from option MemOPSizeRange. + getMemOPSizeRangeFromOption(MemOPSizeRange, PreciseRangeStart, + PreciseRangeLast); + } + bool isChanged() const { return Changed; } + void perform() { + WorkList.clear(); + visit(Func); + + for (auto &MI : WorkList) { + ++NumOfPGOMemOPAnnotate; + if (perform(MI)) { + Changed = true; + ++NumOfPGOMemOPOpt; + DEBUG(dbgs() << "MemOP call: " << MI->getCalledFunction()->getName() + << " is transformed.\n"); + } + } + } + + void visitMemIntrinsic(MemIntrinsic &MI) { + Value *Length = MI.getLength(); + // Do not perform the optimization on constant-length calls. + if (dyn_cast<ConstantInt>(Length)) + return; + WorkList.push_back(&MI); + } + +private: + Function &Func; + BlockFrequencyInfo &BFI; + bool Changed; + std::vector<MemIntrinsic *> WorkList; + // Start of the precise range. + int64_t PreciseRangeStart; + // Last value of the precise range. + int64_t PreciseRangeLast; + // The space to read the profile annotation. + std::unique_ptr<InstrProfValueData[]> ValueDataArray; + bool perform(MemIntrinsic *MI); + + // This kind shows which group the value falls in. For PreciseValue, we have + // the profile count for that value. LargeGroup groups the values that are in + // range [LargeValue, +inf). NonLargeGroup groups the rest of values.
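// A worked example of the grouping, assuming a precise range of [0, 8] and
// MemOPSizeLarge == 4096 (both values illustrative):
//   getMemOPSizeKind(5)    == PreciseValue   // inside the precise range
//   getMemOPSizeKind(9)    == NonLargeGroup  // PreciseRangeLast + 1 is the
//                                            // sentinel for the mid group
//   getMemOPSizeKind(4096) == LargeGroup     // the MemOPSizeLarge sentinel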
+ enum MemOPSizeKind { PreciseValue, NonLargeGroup, LargeGroup }; + + MemOPSizeKind getMemOPSizeKind(int64_t Value) const { + if (Value == MemOPSizeLarge && MemOPSizeLarge != 0) + return LargeGroup; + if (Value == PreciseRangeLast + 1) + return NonLargeGroup; + return PreciseValue; + } +}; + +static const char *getMIName(const MemIntrinsic *MI) { + switch (MI->getIntrinsicID()) { + case Intrinsic::memcpy: + return "memcpy"; + case Intrinsic::memmove: + return "memmove"; + case Intrinsic::memset: + return "memset"; + default: + return "unknown"; + } +} + +static bool isProfitable(uint64_t Count, uint64_t TotalCount) { + assert(Count <= TotalCount); + if (Count < MemOPCountThreshold) + return false; + if (Count < TotalCount * MemOPPercentThreshold / 100) + return false; + return true; +} + +static inline uint64_t getScaledCount(uint64_t Count, uint64_t Num, + uint64_t Denom) { + if (!MemOPScaleCount) + return Count; + bool Overflowed; + uint64_t ScaleCount = SaturatingMultiply(Count, Num, &Overflowed); + return ScaleCount / Denom; +} + +bool MemOPSizeOpt::perform(MemIntrinsic *MI) { + assert(MI); + if (MI->getIntrinsicID() == Intrinsic::memmove) + return false; + + uint32_t NumVals, MaxNumPromotions = MemOPMaxVersion + 2; + uint64_t TotalCount; + if (!getValueProfDataFromInst(*MI, IPVK_MemOPSize, MaxNumPromotions, + ValueDataArray.get(), NumVals, TotalCount)) + return false; + + uint64_t ActualCount = TotalCount; + uint64_t SavedTotalCount = TotalCount; + if (MemOPScaleCount) { + auto BBEdgeCount = BFI.getBlockProfileCount(MI->getParent()); + if (!BBEdgeCount) + return false; + ActualCount = *BBEdgeCount; + } + + ArrayRef<InstrProfValueData> VDs(ValueDataArray.get(), NumVals); + DEBUG(dbgs() << "Read one memory intrinsic profile with count " << ActualCount + << "\n"); + DEBUG( + for (auto &VD + : VDs) { dbgs() << " (" << VD.Value << "," << VD.Count << ")\n"; }); + + if (ActualCount < MemOPCountThreshold) + return false; + // Skip if the total value profiled count is 0, in which case we can't + // scale up the counts properly (and there is no profitable transformation). + if (TotalCount == 0) + return false; + + TotalCount = ActualCount; + if (MemOPScaleCount) + DEBUG(dbgs() << "Scale counts: numerator = " << ActualCount + << " denominator = " << SavedTotalCount << "\n"); + + // Keep track of the count of the default case: + uint64_t RemainCount = TotalCount; + uint64_t SavedRemainCount = SavedTotalCount; + SmallVector<uint64_t, 16> SizeIds; + SmallVector<uint64_t, 16> CaseCounts; + uint64_t MaxCount = 0; + unsigned Version = 0; + // Default case is in the front -- save the slot here. + CaseCounts.push_back(0); + for (auto &VD : VDs) { + int64_t V = VD.Value; + uint64_t C = VD.Count; + if (MemOPScaleCount) + C = getScaledCount(C, ActualCount, SavedTotalCount); + + // We only care about precise values here. + if (getMemOPSizeKind(V) != PreciseValue) + continue; + + // ValueCounts are sorted on the count. Break at the first unprofitable + // value.
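// Illustration with the default thresholds above (count >= 1000 and
// count >= 40% of the remaining total), numbers invented: for
// TotalCount = 10000 and sorted counts {6000, 2500, 900}, 6000 is taken
// (>= 4000), 2500 is taken against the remaining 4000 (>= 1600), and the
// scan stops at 900, which is under MemOPCountThreshold.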
+ if (!isProfitable(C, RemainCount)) + break; + + SizeIds.push_back(V); + CaseCounts.push_back(C); + if (C > MaxCount) + MaxCount = C; + + assert(RemainCount >= C); + RemainCount -= C; + assert(SavedRemainCount >= VD.Count); + SavedRemainCount -= VD.Count; + + if (++Version > MemOPMaxVersion && MemOPMaxVersion != 0) + break; + } + + if (Version == 0) + return false; + + CaseCounts[0] = RemainCount; + if (RemainCount > MaxCount) + MaxCount = RemainCount; + + uint64_t SumForOpt = TotalCount - RemainCount; + + DEBUG(dbgs() << "Optimize one memory intrinsic call to " << Version + << " Versions (covering " << SumForOpt << " out of " + << TotalCount << ")\n"); + + // mem_op(..., size) + // ==> + // switch (size) { + // case s1: + // mem_op(..., s1); + // goto merge_bb; + // case s2: + // mem_op(..., s2); + // goto merge_bb; + // ... + // default: + // mem_op(..., size); + // goto merge_bb; + // } + // merge_bb: + + BasicBlock *BB = MI->getParent(); + DEBUG(dbgs() << "\n\n== Basic Block Before ==\n"); + DEBUG(dbgs() << *BB << "\n"); + auto OrigBBFreq = BFI.getBlockFreq(BB); + + BasicBlock *DefaultBB = SplitBlock(BB, MI); + BasicBlock::iterator It(*MI); + ++It; + assert(It != DefaultBB->end()); + BasicBlock *MergeBB = SplitBlock(DefaultBB, &(*It)); + MergeBB->setName("MemOP.Merge"); + BFI.setBlockFreq(MergeBB, OrigBBFreq.getFrequency()); + DefaultBB->setName("MemOP.Default"); + + auto &Ctx = Func.getContext(); + IRBuilder<> IRB(BB); + BB->getTerminator()->eraseFromParent(); + Value *SizeVar = MI->getLength(); + SwitchInst *SI = IRB.CreateSwitch(SizeVar, DefaultBB, SizeIds.size()); + + // Clear the value profile data. + MI->setMetadata(LLVMContext::MD_prof, nullptr); + // If all values were promoted, we don't need the MD_prof metadata. + if (SavedRemainCount > 0 || Version != NumVals) + // Otherwise we need to write the unpromoted records back. + annotateValueSite(*Func.getParent(), *MI, VDs.slice(Version), + SavedRemainCount, IPVK_MemOPSize, NumVals); + + DEBUG(dbgs() << "\n\n== Basic Block After ==\n"); + + for (uint64_t SizeId : SizeIds) { + ConstantInt *CaseSizeId = ConstantInt::get(Type::getInt64Ty(Ctx), SizeId); + BasicBlock *CaseBB = BasicBlock::Create( + Ctx, Twine("MemOP.Case.") + Twine(SizeId), &Func, DefaultBB); + Instruction *NewInst = MI->clone(); + // Fix the length argument.
+ dyn_cast<MemIntrinsic>(NewInst)->setLength(CaseSizeId); + CaseBB->getInstList().push_back(NewInst); + IRBuilder<> IRBCase(CaseBB); + IRBCase.CreateBr(MergeBB); + SI->addCase(CaseSizeId, CaseBB); + DEBUG(dbgs() << *CaseBB << "\n"); + } + setProfMetadata(Func.getParent(), SI, CaseCounts, MaxCount); + + DEBUG(dbgs() << *BB << "\n"); + DEBUG(dbgs() << *DefaultBB << "\n"); + DEBUG(dbgs() << *MergeBB << "\n"); + + emitOptimizationRemark(Func.getContext(), "memop-opt", Func, + MI->getDebugLoc(), + Twine("optimize ") + getMIName(MI) + " with count " + + Twine(SumForOpt) + " out of " + Twine(TotalCount) + + " for " + Twine(Version) + " versions"); + + return true; +} +} // namespace + +static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI) { + if (DisableMemOPOPT) + return false; + + if (F.hasFnAttribute(Attribute::OptimizeForSize)) + return false; + MemOPSizeOpt MemOPSizeOpt(F, BFI); + MemOPSizeOpt.perform(); + return MemOPSizeOpt.isChanged(); +} + +bool PGOMemOPSizeOptLegacyPass::runOnFunction(Function &F) { + BlockFrequencyInfo &BFI = + getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); + return PGOMemOPSizeOptImpl(F, BFI); +} + +namespace llvm { +char &PGOMemOPSizeOptID = PGOMemOPSizeOptLegacyPass::ID; + +PreservedAnalyses PGOMemOPSizeOpt::run(Function &F, + FunctionAnalysisManager &FAM) { + auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); + bool Changed = PGOMemOPSizeOptImpl(F, BFI); + if (!Changed) + return PreservedAnalyses::all(); + auto PA = PreservedAnalyses(); + PA.preserve<GlobalsAA>(); + return PA; +} +} // namespace llvm diff --git a/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 5b4b1fb..06fe075 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -7,24 +7,7 @@ // //===----------------------------------------------------------------------===// // -// Coverage instrumentation that works with AddressSanitizer -// and potentially with other Sanitizers. -// -// We create a Guard variable with the same linkage -// as the function and inject this code into the entry block (SCK_Function) -// or all blocks (SCK_BB): -// if (Guard < 0) { -// __sanitizer_cov(&Guard); -// } -// The accesses to Guard are atomic. The rest of the logic is -// in __sanitizer_cov (it's fine to call it more than once). -// -// With SCK_Edge we also split critical edges this effectively -// instrumenting all edges. -// -// This coverage implementation provides very limited data: -// it only tells if a given function (block) was ever executed. No counters. -// But for many use cases this is what we need and the added slowdown small. +// Coverage instrumentation done on LLVM IR level, works with Sanitizers. 
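The model after this rewrite is callback- and section-based: each function gets private guard (and, optionally, 8-bit counter) arrays placed in named sections, and a module constructor hands the section bounds to the runtime once. The runtime-facing interface is, roughly (see the compiler-rt side for the authoritative declarations):

    // called once per module from the generated sancov.module_ctor
    void __sanitizer_cov_trace_pc_guard_init(uint32_t *start, uint32_t *stop);
    // called at every instrumented block/edge
    void __sanitizer_cov_trace_pc_guard(uint32_t *guard);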
// //===----------------------------------------------------------------------===// @@ -56,16 +39,8 @@ using namespace llvm; #define DEBUG_TYPE "sancov" -static const char *const SanCovModuleInitName = "__sanitizer_cov_module_init"; -static const char *const SanCovName = "__sanitizer_cov"; -static const char *const SanCovWithCheckName = "__sanitizer_cov_with_check"; -static const char *const SanCovIndirCallName = "__sanitizer_cov_indir_call16"; static const char *const SanCovTracePCIndirName = "__sanitizer_cov_trace_pc_indir"; -static const char *const SanCovTraceEnterName = - "__sanitizer_cov_trace_func_enter"; -static const char *const SanCovTraceBBName = - "__sanitizer_cov_trace_basic_block"; static const char *const SanCovTracePCName = "__sanitizer_cov_trace_pc"; static const char *const SanCovTraceCmp1 = "__sanitizer_cov_trace_cmp1"; static const char *const SanCovTraceCmp2 = "__sanitizer_cov_trace_cmp2"; @@ -78,39 +53,34 @@ static const char *const SanCovTraceSwitchName = "__sanitizer_cov_trace_switch"; static const char *const SanCovModuleCtorName = "sancov.module_ctor"; static const uint64_t SanCtorAndDtorPriority = 2; -static const char *const SanCovTracePCGuardSection = "__sancov_guards"; static const char *const SanCovTracePCGuardName = "__sanitizer_cov_trace_pc_guard"; static const char *const SanCovTracePCGuardInitName = "__sanitizer_cov_trace_pc_guard_init"; +static const char *const SanCov8bitCountersInitName = + "__sanitizer_cov_8bit_counters_init"; + +static const char *const SanCovGuardsSectionName = "sancov_guards"; +static const char *const SanCovCountersSectionName = "sancov_cntrs"; static cl::opt<int> ClCoverageLevel( "sanitizer-coverage-level", cl::desc("Sanitizer Coverage. 0: none, 1: entry block, 2: all blocks, " - "3: all blocks and critical edges, " - "4: above plus indirect calls"), + "3: all blocks and critical edges"), cl::Hidden, cl::init(0)); -static cl::opt<unsigned> ClCoverageBlockThreshold( - "sanitizer-coverage-block-threshold", - cl::desc("Use a callback with a guard check inside it if there are" - " more than this number of blocks."), - cl::Hidden, cl::init(500)); - -static cl::opt<bool> - ClExperimentalTracing("sanitizer-coverage-experimental-tracing", - cl::desc("Experimental basic-block tracing: insert " - "callbacks at every basic block"), - cl::Hidden, cl::init(false)); - -static cl::opt<bool> ClExperimentalTracePC("sanitizer-coverage-trace-pc", - cl::desc("Experimental pc tracing"), - cl::Hidden, cl::init(false)); +static cl::opt<bool> ClTracePC("sanitizer-coverage-trace-pc", + cl::desc("Experimental pc tracing"), cl::Hidden, + cl::init(false)); static cl::opt<bool> ClTracePCGuard("sanitizer-coverage-trace-pc-guard", cl::desc("pc tracing with a guard"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClInline8bitCounters("sanitizer-coverage-inline-8bit-counters", + cl::desc("increments 8-bit counter for every edge"), + cl::Hidden, cl::init(false)); + static cl::opt<bool> ClCMPTracing("sanitizer-coverage-trace-compares", cl::desc("Tracing of CMP and similar instructions"), @@ -129,16 +99,6 @@ static cl::opt<bool> cl::desc("Reduce the number of instrumented blocks"), cl::Hidden, cl::init(true)); -// Experimental 8-bit counters used as an additional search heuristic during -// coverage-guided fuzzing. -// The counters are not thread-friendly: -// - contention on these counters may cause significant slowdown; -// - the counter updates are racy and the results may be inaccurate. -// They are also inaccurate due to 8-bit integer overflow. 
-static cl::opt<bool> ClUse8bitCounters("sanitizer-coverage-8bit-counters", - cl::desc("Experimental 8-bit counters"), - cl::Hidden, cl::init(false)); - namespace { SanitizerCoverageOptions getOptions(int LegacyCoverageLevel) { @@ -169,13 +129,15 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { SanitizerCoverageOptions CLOpts = getOptions(ClCoverageLevel); Options.CoverageType = std::max(Options.CoverageType, CLOpts.CoverageType); Options.IndirectCalls |= CLOpts.IndirectCalls; - Options.TraceBB |= ClExperimentalTracing; Options.TraceCmp |= ClCMPTracing; Options.TraceDiv |= ClDIVTracing; Options.TraceGep |= ClGEPTracing; - Options.Use8bitCounters |= ClUse8bitCounters; - Options.TracePC |= ClExperimentalTracePC; + Options.TracePC |= ClTracePC; Options.TracePCGuard |= ClTracePCGuard; + Options.Inline8bitCounters |= ClInline8bitCounters; + if (!Options.TracePCGuard && !Options.TracePC && !Options.Inline8bitCounters) + Options.TracePCGuard = true; // TracePCGuard is default. + Options.NoPrune |= !ClPruneBlocks; return Options; } @@ -207,90 +169,128 @@ private: void InjectTraceForSwitch(Function &F, ArrayRef<Instruction *> SwitchTraceTargets); bool InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks); - void CreateFunctionGuardArray(size_t NumGuards, Function &F); - void SetNoSanitizeMetadata(Instruction *I); - void InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx, - bool UseCalls); - unsigned NumberOfInstrumentedBlocks() { - return SanCovFunction->getNumUses() + - SanCovWithCheckFunction->getNumUses() + SanCovTraceBB->getNumUses() + - SanCovTraceEnter->getNumUses(); + GlobalVariable *CreateFunctionLocalArrayInSection(size_t NumElements, + Function &F, Type *Ty, + const char *Section); + void CreateFunctionLocalArrays(size_t NumGuards, Function &F); + void InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx); + void CreateInitCallForSection(Module &M, const char *InitFunctionName, + Type *Ty, const std::string &Section); + + void SetNoSanitizeMetadata(Instruction *I) { + I->setMetadata(I->getModule()->getMDKindID("nosanitize"), + MDNode::get(*C, None)); } - Function *SanCovFunction; - Function *SanCovWithCheckFunction; - Function *SanCovIndirCallFunction, *SanCovTracePCIndir; - Function *SanCovTraceEnter, *SanCovTraceBB, *SanCovTracePC, *SanCovTracePCGuard; + + std::string getSectionName(const std::string &Section) const; + std::string getSectionStart(const std::string &Section) const; + std::string getSectionEnd(const std::string &Section) const; + Function *SanCovTracePCIndir; + Function *SanCovTracePC, *SanCovTracePCGuard; Function *SanCovTraceCmpFunction[4]; Function *SanCovTraceDivFunction[2]; Function *SanCovTraceGepFunction; Function *SanCovTraceSwitchFunction; InlineAsm *EmptyAsm; - Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy; + Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy, + *Int8Ty, *Int8PtrTy; Module *CurModule; + Triple TargetTriple; LLVMContext *C; const DataLayout *DL; - GlobalVariable *GuardArray; GlobalVariable *FunctionGuardArray; // for trace-pc-guard. - GlobalVariable *EightBitCounterArray; - bool HasSancovGuardsSection; + GlobalVariable *Function8bitCounterArray; // for inline-8bit-counters. 
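// (With inline-8bit-counters, the per-block code emitted further down is a
// plain, deliberately racy increment -- IR sketch, names invented:
//   %c   = load i8, i8* %counter        ; !nosanitize
//   %inc = add i8 %c, 1
//   store i8 %inc, i8* %counter         ; !nosanitize
// Racy updates can lose increments, but they are cheap, which is the point.)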
SanitizerCoverageOptions Options; }; } // namespace +void SanitizerCoverageModule::CreateInitCallForSection( + Module &M, const char *InitFunctionName, Type *Ty, + const std::string &Section) { + IRBuilder<> IRB(M.getContext()); + Function *CtorFunc; + GlobalVariable *SecStart = + new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, nullptr, + getSectionStart(Section)); + SecStart->setVisibility(GlobalValue::HiddenVisibility); + GlobalVariable *SecEnd = + new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, + nullptr, getSectionEnd(Section)); + SecEnd->setVisibility(GlobalValue::HiddenVisibility); + + std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( + M, SanCovModuleCtorName, InitFunctionName, {Ty, Ty}, + {IRB.CreatePointerCast(SecStart, Ty), IRB.CreatePointerCast(SecEnd, Ty)}); + + if (TargetTriple.supportsCOMDAT()) { + // Use comdat to dedup CtorFunc. + CtorFunc->setComdat(M.getOrInsertComdat(SanCovModuleCtorName)); + appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority, CtorFunc); + } else { + appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); + } +} + bool SanitizerCoverageModule::runOnModule(Module &M) { if (Options.CoverageType == SanitizerCoverageOptions::SCK_None) return false; C = &(M.getContext()); DL = &M.getDataLayout(); CurModule = &M; - HasSancovGuardsSection = false; + TargetTriple = Triple(M.getTargetTriple()); + FunctionGuardArray = nullptr; + Function8bitCounterArray = nullptr; IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits()); IntptrPtrTy = PointerType::getUnqual(IntptrTy); Type *VoidTy = Type::getVoidTy(*C); IRBuilder<> IRB(*C); - Type *Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty()); Int64PtrTy = PointerType::getUnqual(IRB.getInt64Ty()); Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); + Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty()); Int64Ty = IRB.getInt64Ty(); Int32Ty = IRB.getInt32Ty(); + Int8Ty = IRB.getInt8Ty(); - SanCovFunction = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovName, VoidTy, Int32PtrTy, nullptr)); - SanCovWithCheckFunction = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovWithCheckName, VoidTy, Int32PtrTy, nullptr)); SanCovTracePCIndir = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy, nullptr)); - SanCovIndirCallFunction = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovIndirCallName, VoidTy, IntptrTy, IntptrTy, nullptr)); + M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy)); SanCovTraceCmpFunction[0] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty(), nullptr)); + SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty())); SanCovTraceCmpFunction[1] = checkSanitizerInterfaceFunction( M.getOrInsertFunction(SanCovTraceCmp2, VoidTy, IRB.getInt16Ty(), - IRB.getInt16Ty(), nullptr)); + IRB.getInt16Ty())); SanCovTraceCmpFunction[2] = checkSanitizerInterfaceFunction( M.getOrInsertFunction(SanCovTraceCmp4, VoidTy, IRB.getInt32Ty(), - IRB.getInt32Ty(), nullptr)); + IRB.getInt32Ty())); SanCovTraceCmpFunction[3] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty, nullptr)); + SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty)); SanCovTraceDivFunction[0] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceDiv4, VoidTy, IRB.getInt32Ty(), nullptr)); + SanCovTraceDiv4, VoidTy, IRB.getInt32Ty())); SanCovTraceDivFunction[1] = 
checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceDiv8, VoidTy, Int64Ty, nullptr)); + SanCovTraceDiv8, VoidTy, Int64Ty)); SanCovTraceGepFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceGep, VoidTy, IntptrTy, nullptr)); + SanCovTraceGep, VoidTy, IntptrTy)); SanCovTraceSwitchFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy, nullptr)); + SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy)); + // Make sure smaller parameters are zero-extended to i64 as required by the + // x86_64 ABI. + if (TargetTriple.getArch() == Triple::x86_64) { + for (int i = 0; i < 3; i++) { + SanCovTraceCmpFunction[i]->addParamAttr(0, Attribute::ZExt); + SanCovTraceCmpFunction[i]->addParamAttr(1, Attribute::ZExt); + } + SanCovTraceDivFunction[0]->addParamAttr(0, Attribute::ZExt); + } + // We insert an empty inline asm after cov callbacks to avoid callback merge. EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), @@ -298,102 +298,19 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { /*hasSideEffects=*/true); SanCovTracePC = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTracePCName, VoidTy, nullptr)); + M.getOrInsertFunction(SanCovTracePCName, VoidTy)); SanCovTracePCGuard = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTracePCGuardName, VoidTy, Int32PtrTy, nullptr)); - SanCovTraceEnter = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTraceEnterName, VoidTy, Int32PtrTy, nullptr)); - SanCovTraceBB = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTraceBBName, VoidTy, Int32PtrTy, nullptr)); - - // At this point we create a dummy array of guards because we don't - // know how many elements we will need. - Type *Int32Ty = IRB.getInt32Ty(); - Type *Int8Ty = IRB.getInt8Ty(); - - if (!Options.TracePCGuard) - GuardArray = - new GlobalVariable(M, Int32Ty, false, GlobalValue::ExternalLinkage, - nullptr, "__sancov_gen_cov_tmp"); - if (Options.Use8bitCounters) - EightBitCounterArray = - new GlobalVariable(M, Int8Ty, false, GlobalVariable::ExternalLinkage, - nullptr, "__sancov_gen_cov_tmp"); + SanCovTracePCGuardName, VoidTy, Int32PtrTy)); for (auto &F : M) runOnFunction(F); - auto N = NumberOfInstrumentedBlocks(); - - GlobalVariable *RealGuardArray = nullptr; - if (!Options.TracePCGuard) { - // Now we know how many elements we need. Create an array of guards - // with one extra element at the beginning for the size. - Type *Int32ArrayNTy = ArrayType::get(Int32Ty, N + 1); - RealGuardArray = new GlobalVariable( - M, Int32ArrayNTy, false, GlobalValue::PrivateLinkage, - Constant::getNullValue(Int32ArrayNTy), "__sancov_gen_cov"); - - // Replace the dummy array with the real one. - GuardArray->replaceAllUsesWith( - IRB.CreatePointerCast(RealGuardArray, Int32PtrTy)); - GuardArray->eraseFromParent(); - } - - GlobalVariable *RealEightBitCounterArray; - if (Options.Use8bitCounters) { - // Make sure the array is 16-aligned. 
- static const int CounterAlignment = 16; - Type *Int8ArrayNTy = ArrayType::get(Int8Ty, alignTo(N, CounterAlignment)); - RealEightBitCounterArray = new GlobalVariable( - M, Int8ArrayNTy, false, GlobalValue::PrivateLinkage, - Constant::getNullValue(Int8ArrayNTy), "__sancov_gen_cov_counter"); - RealEightBitCounterArray->setAlignment(CounterAlignment); - EightBitCounterArray->replaceAllUsesWith( - IRB.CreatePointerCast(RealEightBitCounterArray, Int8PtrTy)); - EightBitCounterArray->eraseFromParent(); - } - - // Create variable for module (compilation unit) name - Constant *ModNameStrConst = - ConstantDataArray::getString(M.getContext(), M.getName(), true); - GlobalVariable *ModuleName = new GlobalVariable( - M, ModNameStrConst->getType(), true, GlobalValue::PrivateLinkage, - ModNameStrConst, "__sancov_gen_modname"); - if (Options.TracePCGuard) { - if (HasSancovGuardsSection) { - Function *CtorFunc; - std::string SectionName(SanCovTracePCGuardSection); - GlobalVariable *Bounds[2]; - const char *Prefix[2] = {"__start_", "__stop_"}; - for (int i = 0; i < 2; i++) { - Bounds[i] = new GlobalVariable(M, Int32PtrTy, false, - GlobalVariable::ExternalLinkage, nullptr, - Prefix[i] + SectionName); - Bounds[i]->setVisibility(GlobalValue::HiddenVisibility); - } - std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( - M, SanCovModuleCtorName, SanCovTracePCGuardInitName, - {Int32PtrTy, Int32PtrTy}, - {IRB.CreatePointerCast(Bounds[0], Int32PtrTy), - IRB.CreatePointerCast(Bounds[1], Int32PtrTy)}); - - appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); - } - } else if (!Options.TracePC) { - Function *CtorFunc; - std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( - M, SanCovModuleCtorName, SanCovModuleInitName, - {Int32PtrTy, IntptrTy, Int8PtrTy, Int8PtrTy}, - {IRB.CreatePointerCast(RealGuardArray, Int32PtrTy), - ConstantInt::get(IntptrTy, N), - Options.Use8bitCounters - ? IRB.CreatePointerCast(RealEightBitCounterArray, Int8PtrTy) - : Constant::getNullValue(Int8PtrTy), - IRB.CreatePointerCast(ModuleName, Int8PtrTy)}); - - appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); - } + if (FunctionGuardArray) + CreateInitCallForSection(M, SanCovTracePCGuardInitName, Int32PtrTy, + SanCovGuardsSectionName); + if (Function8bitCounterArray) + CreateInitCallForSection(M, SanCov8bitCountersInitName, Int8PtrTy, + SanCovCountersSectionName); return true; } @@ -425,8 +342,10 @@ static bool isFullPostDominator(const BasicBlock *BB, return true; } -static bool shouldInstrumentBlock(const Function& F, const BasicBlock *BB, const DominatorTree *DT, - const PostDominatorTree *PDT) { +static bool shouldInstrumentBlock(const Function &F, const BasicBlock *BB, + const DominatorTree *DT, + const PostDominatorTree *PDT, + const SanitizerCoverageOptions &Options) { // Don't insert coverage for unreachable blocks: we will never call // __sanitizer_cov() for them, so counting them in // NumberOfInstrumentedBlocks() might complicate calculation of code coverage @@ -435,10 +354,18 @@ static bool shouldInstrumentBlock(const Function& F, const BasicBlock *BB, const if (isa<UnreachableInst>(BB->getTerminator())) return false; - if (!ClPruneBlocks || &F.getEntryBlock() == BB) + // Don't insert coverage into blocks without a valid insertion point + // (catchswitch blocks). 
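// (A catchswitch must be the only non-PHI instruction of its block and is
// itself the terminator, so getFirstInsertionPt() has nowhere to point and
// returns end() -- exactly what the check below detects.)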
+ if (BB->getFirstInsertionPt() == BB->end()) + return false; + + if (Options.NoPrune || &F.getEntryBlock() == BB) return true; - return !(isFullDominator(BB, DT) || isFullPostDominator(BB, PDT)); + // Do not instrument full dominators, or full post-dominators with multiple + // predecessors. + return !isFullDominator(BB, DT) + && !(isFullPostDominator(BB, PDT) && !BB->getSinglePredecessor()); } bool SanitizerCoverageModule::runOnFunction(Function &F) { @@ -474,7 +401,7 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { &getAnalysis<PostDominatorTreeWrapperPass>(F).getPostDomTree(); for (auto &BB : F) { - if (shouldInstrumentBlock(F, &BB, DT, PDT)) + if (shouldInstrumentBlock(F, &BB, DT, PDT, Options)) BlocksToInstrument.push_back(&BB); for (auto &Inst : BB) { if (Options.IndirectCalls) { @@ -507,17 +434,26 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { InjectTraceForGep(F, GepTraceTargets); return true; } -void SanitizerCoverageModule::CreateFunctionGuardArray(size_t NumGuards, - Function &F) { - if (!Options.TracePCGuard) return; - HasSancovGuardsSection = true; - ArrayType *ArrayOfInt32Ty = ArrayType::get(Int32Ty, NumGuards); - FunctionGuardArray = new GlobalVariable( - *CurModule, ArrayOfInt32Ty, false, GlobalVariable::PrivateLinkage, - Constant::getNullValue(ArrayOfInt32Ty), "__sancov_gen_"); + +GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection( + size_t NumElements, Function &F, Type *Ty, const char *Section) { + ArrayType *ArrayTy = ArrayType::get(Ty, NumElements); + auto Array = new GlobalVariable( + *CurModule, ArrayTy, false, GlobalVariable::PrivateLinkage, + Constant::getNullValue(ArrayTy), "__sancov_gen_"); if (auto Comdat = F.getComdat()) - FunctionGuardArray->setComdat(Comdat); - FunctionGuardArray->setSection(SanCovTracePCGuardSection); + Array->setComdat(Comdat); + Array->setSection(getSectionName(Section)); + return Array; +} +void SanitizerCoverageModule::CreateFunctionLocalArrays(size_t NumGuards, + Function &F) { + if (Options.TracePCGuard) + FunctionGuardArray = CreateFunctionLocalArrayInSection( + NumGuards, F, Int32Ty, SanCovGuardsSectionName); + if (Options.Inline8bitCounters) + Function8bitCounterArray = CreateFunctionLocalArrayInSection( + NumGuards, F, Int8Ty, SanCovCountersSectionName); } bool SanitizerCoverageModule::InjectCoverage(Function &F, @@ -527,14 +463,13 @@ bool SanitizerCoverageModule::InjectCoverage(Function &F, case SanitizerCoverageOptions::SCK_None: return false; case SanitizerCoverageOptions::SCK_Function: - CreateFunctionGuardArray(1, F); - InjectCoverageAtBlock(F, F.getEntryBlock(), 0, false); + CreateFunctionLocalArrays(1, F); + InjectCoverageAtBlock(F, F.getEntryBlock(), 0); return true; default: { - bool UseCalls = ClCoverageBlockThreshold < AllBlocks.size(); - CreateFunctionGuardArray(AllBlocks.size(), F); + CreateFunctionLocalArrays(AllBlocks.size(), F); for (size_t i = 0, N = AllBlocks.size(); i < N; i++) - InjectCoverageAtBlock(F, *AllBlocks[i], i, UseCalls); + InjectCoverageAtBlock(F, *AllBlocks[i], i); return true; } } @@ -551,26 +486,14 @@ void SanitizerCoverageModule::InjectCoverageForIndirectCalls( Function &F, ArrayRef<Instruction *> IndirCalls) { if (IndirCalls.empty()) return; - const int CacheSize = 16; - const int CacheAlignment = 64; // Align for better performance. 
- Type *Ty = ArrayType::get(IntptrTy, CacheSize); + assert(Options.TracePC || Options.TracePCGuard || Options.Inline8bitCounters); for (auto I : IndirCalls) { IRBuilder<> IRB(I); CallSite CS(I); Value *Callee = CS.getCalledValue(); if (isa<InlineAsm>(Callee)) continue; - GlobalVariable *CalleeCache = new GlobalVariable( - *F.getParent(), Ty, false, GlobalValue::PrivateLinkage, - Constant::getNullValue(Ty), "__sancov_gen_callee_cache"); - CalleeCache->setAlignment(CacheAlignment); - if (Options.TracePC || Options.TracePCGuard) - IRB.CreateCall(SanCovTracePCIndir, - IRB.CreatePointerCast(Callee, IntptrTy)); - else - IRB.CreateCall(SanCovIndirCallFunction, - {IRB.CreatePointerCast(Callee, IntptrTy), - IRB.CreatePointerCast(CalleeCache, IntptrTy)}); + IRB.CreateCall(SanCovTracePCIndir, IRB.CreatePointerCast(Callee, IntptrTy)); } } @@ -670,13 +593,8 @@ void SanitizerCoverageModule::InjectTraceForCmp( } } -void SanitizerCoverageModule::SetNoSanitizeMetadata(Instruction *I) { - I->setMetadata(I->getModule()->getMDKindID("nosanitize"), - MDNode::get(*C, None)); -} - void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, - size_t Idx, bool UseCalls) { + size_t Idx) { BasicBlock::iterator IP = BB.getFirstInsertionPt(); bool IsEntryBB = &BB == &F.getEntryBlock(); DebugLoc EntryLoc; @@ -696,65 +614,51 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, if (Options.TracePC) { IRB.CreateCall(SanCovTracePC); // gets the PC using GET_CALLER_PC. IRB.CreateCall(EmptyAsm, {}); // Avoids callback merge. - } else if (Options.TracePCGuard) { + } + if (Options.TracePCGuard) { auto GuardPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePointerCast(FunctionGuardArray, IntptrTy), ConstantInt::get(IntptrTy, Idx * 4)), Int32PtrTy); - if (!UseCalls) { - auto GuardLoad = IRB.CreateLoad(GuardPtr); - GuardLoad->setAtomic(AtomicOrdering::Monotonic); - GuardLoad->setAlignment(8); - SetNoSanitizeMetadata(GuardLoad); // Don't instrument with e.g. asan. - auto Cmp = IRB.CreateICmpNE( - GuardLoad, Constant::getNullValue(GuardLoad->getType())); - auto Ins = SplitBlockAndInsertIfThen( - Cmp, &*IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); - IRB.SetInsertPoint(Ins); - IRB.SetCurrentDebugLocation(EntryLoc); - } IRB.CreateCall(SanCovTracePCGuard, GuardPtr); IRB.CreateCall(EmptyAsm, {}); // Avoids callback merge. - } else { - Value *GuardP = IRB.CreateAdd( - IRB.CreatePointerCast(GuardArray, IntptrTy), - ConstantInt::get(IntptrTy, (1 + NumberOfInstrumentedBlocks()) * 4)); - GuardP = IRB.CreateIntToPtr(GuardP, Int32PtrTy); - if (Options.TraceBB) { - IRB.CreateCall(IsEntryBB ? SanCovTraceEnter : SanCovTraceBB, GuardP); - } else if (UseCalls) { - IRB.CreateCall(SanCovWithCheckFunction, GuardP); - } else { - LoadInst *Load = IRB.CreateLoad(GuardP); - Load->setAtomic(AtomicOrdering::Monotonic); - Load->setAlignment(4); - SetNoSanitizeMetadata(Load); - Value *Cmp = - IRB.CreateICmpSGE(Constant::getNullValue(Load->getType()), Load); - Instruction *Ins = SplitBlockAndInsertIfThen( - Cmp, &*IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); - IRB.SetInsertPoint(Ins); - IRB.SetCurrentDebugLocation(EntryLoc); - // __sanitizer_cov gets the PC of the instruction using GET_CALLER_PC. - IRB.CreateCall(SanCovFunction, GuardP); - IRB.CreateCall(EmptyAsm, {}); // Avoids callback merge. 
- } } - - if (Options.Use8bitCounters) { - IRB.SetInsertPoint(&*IP); - Value *P = IRB.CreateAdd( - IRB.CreatePointerCast(EightBitCounterArray, IntptrTy), - ConstantInt::get(IntptrTy, NumberOfInstrumentedBlocks() - 1)); - P = IRB.CreateIntToPtr(P, IRB.getInt8PtrTy()); - LoadInst *LI = IRB.CreateLoad(P); - Value *Inc = IRB.CreateAdd(LI, ConstantInt::get(IRB.getInt8Ty(), 1)); - StoreInst *SI = IRB.CreateStore(Inc, P); - SetNoSanitizeMetadata(LI); - SetNoSanitizeMetadata(SI); + if (Options.Inline8bitCounters) { + auto CounterPtr = IRB.CreateGEP( + Function8bitCounterArray, + {ConstantInt::get(IntptrTy, 0), ConstantInt::get(IntptrTy, Idx)}); + auto Load = IRB.CreateLoad(CounterPtr); + auto Inc = IRB.CreateAdd(Load, ConstantInt::get(Int8Ty, 1)); + auto Store = IRB.CreateStore(Inc, CounterPtr); + SetNoSanitizeMetadata(Load); + SetNoSanitizeMetadata(Store); } } +std::string +SanitizerCoverageModule::getSectionName(const std::string &Section) const { + if (TargetTriple.getObjectFormat() == Triple::COFF) + return ".SCOV$M"; + if (TargetTriple.isOSBinFormatMachO()) + return "__DATA,__" + Section; + return "__" + Section; +} + +std::string +SanitizerCoverageModule::getSectionStart(const std::string &Section) const { + if (TargetTriple.isOSBinFormatMachO()) + return "\1section$start$__DATA$__" + Section; + return "__start___" + Section; +} + +std::string +SanitizerCoverageModule::getSectionEnd(const std::string &Section) const { + if (TargetTriple.isOSBinFormatMachO()) + return "\1section$end$__DATA$__" + Section; + return "__stop___" + Section; +} + + char SanitizerCoverageModule::ID = 0; INITIALIZE_PASS_BEGIN(SanitizerCoverageModule, "sancov", "SanitizerCoverage: TODO." diff --git a/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 52035c7..ec69044 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -19,7 +19,6 @@ // The rest is handled by the run-time library. //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" @@ -42,6 +41,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/EscapeEnumerator.h" #include "llvm/Transforms/Utils/Local.h" @@ -155,17 +155,18 @@ FunctionPass *llvm::createThreadSanitizerPass() { void ThreadSanitizer::initializeCallbacks(Module &M) { IRBuilder<> IRB(M.getContext()); - AttributeSet Attr; - Attr = Attr.addAttribute(M.getContext(), AttributeSet::FunctionIndex, Attribute::NoUnwind); + AttributeList Attr; + Attr = Attr.addAttribute(M.getContext(), AttributeList::FunctionIndex, + Attribute::NoUnwind); // Initialize the callbacks. 
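// (The recurring change in this hunk: this tree's Module::getOrInsertFunction
// now takes the parameter types as a variadic template pack rather than a
// nullptr-terminated C vararg list, so the trailing nullptr argument is
// dropped at every call site.)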
TsanFuncEntry = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); TsanFuncExit = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy(), nullptr)); + M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy())); TsanIgnoreBegin = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy(), nullptr)); + "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy())); TsanIgnoreEnd = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_ignore_thread_end", Attr, IRB.getVoidTy(), nullptr)); + "__tsan_ignore_thread_end", Attr, IRB.getVoidTy())); OrdTy = IRB.getInt32Ty(); for (size_t i = 0; i < kNumberOfAccessSizes; ++i) { const unsigned ByteSize = 1U << i; @@ -174,31 +175,31 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { std::string BitSizeStr = utostr(BitSize); SmallString<32> ReadName("__tsan_read" + ByteSizeStr); TsanRead[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); SmallString<32> WriteName("__tsan_write" + ByteSizeStr); TsanWrite[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); SmallString<64> UnalignedReadName("__tsan_unaligned_read" + ByteSizeStr); TsanUnalignedRead[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); SmallString<64> UnalignedWriteName("__tsan_unaligned_write" + ByteSizeStr); TsanUnalignedWrite[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); Type *Ty = Type::getIntNTy(M.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load"); TsanAtomicLoad[i] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy, nullptr)); + M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy)); SmallString<32> AtomicStoreName("__tsan_atomic" + BitSizeStr + "_store"); TsanAtomicStore[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy, nullptr)); + AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy)); for (int op = AtomicRMWInst::FIRST_BINOP; op <= AtomicRMWInst::LAST_BINOP; ++op) { @@ -222,33 +223,33 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { continue; SmallString<32> RMWName("__tsan_atomic" + itostr(BitSize) + NamePart); TsanAtomicRMW[op][i] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy, nullptr)); + M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy)); } SmallString<32> AtomicCASName("__tsan_atomic" + BitSizeStr + "_compare_exchange_val"); TsanAtomicCAS[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, nullptr)); + AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy)); } TsanVptrUpdate = checkSanitizerInterfaceFunction( 
M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), nullptr)); + IRB.getInt8PtrTy(), IRB.getInt8PtrTy())); TsanVptrLoad = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); TsanAtomicThreadFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr)); + "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy)); TsanAtomicSignalFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr)); + "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy)); MemmoveFn = checkSanitizerInterfaceFunction( M.getOrInsertFunction("memmove", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IntptrTy)); MemcpyFn = checkSanitizerInterfaceFunction( M.getOrInsertFunction("memcpy", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy, nullptr)); + IRB.getInt8PtrTy(), IntptrTy)); MemsetFn = checkSanitizerInterfaceFunction( M.getOrInsertFunction("memset", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt32Ty(), IntptrTy, nullptr)); + IRB.getInt32Ty(), IntptrTy)); } bool ThreadSanitizer::doInitialization(Module &M) { @@ -271,7 +272,7 @@ static bool isVtableAccess(Instruction *I) { // Do not instrument known races/"benign races" that come from compiler // instrumentation. The user has no way of suppressing them. -static bool shouldInstrumentReadWriteFromAddress(Value *Addr) { +static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) { // Peel off GEPs and BitCasts. Addr = Addr->stripInBoundsOffsets(); @@ -279,8 +280,9 @@ static bool shouldInstrumentReadWriteFromAddress(Value *Addr) { if (GV->hasSection()) { StringRef SectionName = GV->getSection(); // Check if the global is in the PGO counters section. - if (SectionName.endswith(getInstrProfCountersSectionName( - /*AddSegment=*/false))) + auto OF = Triple(M->getTargetTriple()).getObjectFormat(); + if (SectionName.endswith( + getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false))) return false; } @@ -342,13 +344,13 @@ void ThreadSanitizer::chooseInstructionsToInstrument( for (Instruction *I : reverse(Local)) { if (StoreInst *Store = dyn_cast<StoreInst>(I)) { Value *Addr = Store->getPointerOperand(); - if (!shouldInstrumentReadWriteFromAddress(Addr)) + if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr)) continue; WriteTargets.insert(Addr); } else { LoadInst *Load = cast<LoadInst>(I); Value *Addr = Load->getPointerOperand(); - if (!shouldInstrumentReadWriteFromAddress(Addr)) + if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr)) continue; if (WriteTargets.count(Addr)) { // We will write to this temp, so no reason to analyze the read. @@ -377,10 +379,11 @@ void ThreadSanitizer::chooseInstructionsToInstrument( } static bool isAtomic(Instruction *I) { + // TODO: Ask TTI whether synchronization scope is between threads.
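// For example, in post-SyncScope IR syntax the checks below skip
//   load atomic i32, i32* %p syncscope("singlethread") seq_cst, align 4
// but still instrument the cross-thread form
//   load atomic i32, i32* %p seq_cst, align 4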
if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->isAtomic() && LI->getSynchScope() == CrossThread; + return LI->isAtomic() && LI->getSyncScopeID() != SyncScope::SingleThread; if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->isAtomic() && SI->getSynchScope() == CrossThread; + return SI->isAtomic() && SI->getSyncScopeID() != SyncScope::SingleThread; if (isa<AtomicRMWInst>(I)) return true; if (isa<AtomicCmpXchgInst>(I)) @@ -674,7 +677,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { I->eraseFromParent(); } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) { Value *Args[] = {createOrdering(&IRB, FI->getOrdering())}; - Function *F = FI->getSynchScope() == SingleThread ? + Function *F = FI->getSyncScopeID() == SyncScope::SingleThread ? TsanAtomicSignalFence : TsanAtomicThreadFence; CallInst *C = CallInst::Create(F, Args); ReplaceInstWithInst(I, C); diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h index c748272..cb3b575 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h +++ b/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h @@ -127,9 +127,8 @@ private: LLVMContext &C = TheModule->getContext(); Type *Params[] = { PointerType::getUnqual(Type::getInt8Ty(C)) }; - AttributeSet Attr = - AttributeSet().addAttribute(C, AttributeSet::FunctionIndex, - Attribute::NoUnwind); + AttributeList Attr = AttributeList().addAttribute( + C, AttributeList::FunctionIndex, Attribute::NoUnwind); FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params, /*isVarArg=*/false); return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); @@ -144,10 +143,10 @@ private: Type *I8X = PointerType::getUnqual(Type::getInt8Ty(C)); Type *Params[] = { I8X }; FunctionType *Fty = FunctionType::get(I8X, Params, /*isVarArg=*/false); - AttributeSet Attr = AttributeSet(); + AttributeList Attr = AttributeList(); if (NoUnwind) - Attr = Attr.addAttribute(C, AttributeSet::FunctionIndex, + Attr = Attr.addAttribute(C, AttributeList::FunctionIndex, Attribute::NoUnwind); return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); @@ -162,10 +161,9 @@ private: Type *I8XX = PointerType::getUnqual(I8X); Type *Params[] = { I8XX, I8X }; - AttributeSet Attr = - AttributeSet().addAttribute(C, AttributeSet::FunctionIndex, - Attribute::NoUnwind); - Attr = Attr.addAttribute(C, 1, Attribute::NoCapture); + AttributeList Attr = AttributeList().addAttribute( + C, AttributeList::FunctionIndex, Attribute::NoUnwind); + Attr = Attr.addParamAttribute(C, 0, Attribute::NoCapture); FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params, /*isVarArg=*/false); diff --git a/contrib/llvm/lib/Transforms/ObjCARC/BlotMapVector.h b/contrib/llvm/lib/Transforms/ObjCARC/BlotMapVector.h index ef075bd..9c5cf6f 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/BlotMapVector.h +++ b/contrib/llvm/lib/Transforms/ObjCARC/BlotMapVector.h @@ -8,8 +8,8 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/DenseMap.h" -#include <vector> #include <algorithm> +#include <vector> namespace llvm { /// \brief An associative container with fast insertion-order (deterministic) diff --git a/contrib/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp b/contrib/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp index 9d78e5a..4648050 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp @@ 
-20,8 +20,8 @@ /// //===----------------------------------------------------------------------===// -#include "ObjCARC.h" #include "DependencyAnalysis.h" +#include "ObjCARC.h" #include "ProvenanceAnalysis.h" #include "llvm/IR/CFG.h" diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARC.h b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARC.h index f02b75f..cd9b3d9 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARC.h +++ b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARC.h @@ -69,6 +69,19 @@ static inline void EraseInstruction(Instruction *CI) { RecursivelyDeleteTriviallyDeadInstructions(OldArg); } +/// If Inst is a ReturnRV and its operand is a call or invoke, return the +/// operand. Otherwise return null. +static inline const Instruction *getreturnRVOperand(const Instruction &Inst, + ARCInstKind Class) { + if (Class != ARCInstKind::RetainRV) + return nullptr; + + const auto *Opnd = Inst.getOperand(0)->stripPointerCasts(); + if (const auto *C = dyn_cast<CallInst>(Opnd)) + return C; + return dyn_cast<InvokeInst>(Opnd); +} + } // end namespace objcarc } // end namespace llvm diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp index 23c1f59..e70e759 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -26,9 +26,9 @@ // TODO: ObjCARCContract could insert PHI nodes when uses aren't // dominated by single calls. -#include "ObjCARC.h" #include "ARCRuntimeEntryPoints.h" #include "DependencyAnalysis.h" +#include "ObjCARC.h" #include "ProvenanceAnalysis.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Dominators.h" @@ -394,6 +394,7 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong(Instruction *Release, DEBUG(llvm::dbgs() << " New Store Strong: " << *StoreStrong << "\n"); + if (&*Iter == Retain) ++Iter; if (&*Iter == Store) ++Iter; Store->eraseFromParent(); Release->eraseFromParent(); diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index 136d54a..8f3a33f 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -24,10 +24,10 @@ /// //===----------------------------------------------------------------------===// -#include "ObjCARC.h" #include "ARCRuntimeEntryPoints.h" #include "BlotMapVector.h" #include "DependencyAnalysis.h" +#include "ObjCARC.h" #include "ProvenanceAnalysis.h" #include "PtrState.h" #include "llvm/ADT/DenseMap.h" @@ -85,41 +85,6 @@ static const Value *FindSingleUseIdentifiedObject(const Value *Arg) { return nullptr; } -/// This is a wrapper around getUnderlyingObjCPtr along the lines of -/// GetUnderlyingObjects except that it returns early when it sees the first -/// alloca. 
-static inline bool AreAnyUnderlyingObjectsAnAlloca(const Value *V, - const DataLayout &DL) { - SmallPtrSet<const Value *, 4> Visited; - SmallVector<const Value *, 4> Worklist; - Worklist.push_back(V); - do { - const Value *P = Worklist.pop_back_val(); - P = GetUnderlyingObjCPtr(P, DL); - - if (isa<AllocaInst>(P)) - return true; - - if (!Visited.insert(P).second) - continue; - - if (const SelectInst *SI = dyn_cast<const SelectInst>(P)) { - Worklist.push_back(SI->getTrueValue()); - Worklist.push_back(SI->getFalseValue()); - continue; - } - - if (const PHINode *PN = dyn_cast<const PHINode>(P)) { - for (Value *IncValue : PN->incoming_values()) - Worklist.push_back(IncValue); - continue; - } - } while (!Worklist.empty()); - - return false; -} - - /// @} /// /// \defgroup ARCOpt ARC Optimization. @@ -481,9 +446,6 @@ namespace { /// MDKind identifiers. ARCMDKindCache MDKindCache; - // This is used to track if a pointer is stored into an alloca. - DenseSet<const Value *> MultiOwnersSet; - /// A flag indicating whether this optimization pass should run. bool Run; @@ -524,8 +486,7 @@ namespace { PairUpRetainsAndReleases(DenseMap<const BasicBlock *, BBState> &BBStates, BlotMapVector<Value *, RRInfo> &Retains, DenseMap<Value *, RRInfo> &Releases, Module *M, - SmallVectorImpl<Instruction *> &NewRetains, - SmallVectorImpl<Instruction *> &NewReleases, + Instruction * Retain, SmallVectorImpl<Instruction *> &DeadInsts, RRInfo &RetainsToMove, RRInfo &ReleasesToMove, Value *Arg, bool KnownSafe, @@ -1155,29 +1116,6 @@ bool ObjCARCOpt::VisitInstructionBottomUp( case ARCInstKind::None: // These are irrelevant. return NestingDetected; - case ARCInstKind::User: - // If we have a store into an alloca of a pointer we are tracking, the - // pointer has multiple owners implying that we must be more conservative. - // - // This comes up in the context of a pointer being ``KnownSafe''. In the - // presence of a block being initialized, the frontend will emit the - // objc_retain on the original pointer and the release on the pointer loaded - // from the alloca. The optimizer will through the provenance analysis - // realize that the two are related, but since we only require KnownSafe in - // one direction, will match the inner retain on the original pointer with - // the guard release on the original pointer. This is fixed by ensuring that - // in the presence of allocas we only unconditionally remove pointers if - // both our retain and our release are KnownSafe. - if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - const DataLayout &DL = BB->getModule()->getDataLayout(); - if (AreAnyUnderlyingObjectsAnAlloca(SI->getPointerOperand(), DL)) { - auto I = MyStates.findPtrBottomUpState( - GetRCIdentityRoot(SI->getValueOperand())); - if (I != MyStates.bottom_up_ptr_end()) - MultiOwnersSet.insert(I->first); - } - } - break; default: break; } @@ -1540,8 +1478,7 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( DenseMap<const BasicBlock *, BBState> &BBStates, BlotMapVector<Value *, RRInfo> &Retains, DenseMap<Value *, RRInfo> &Releases, Module *M, - SmallVectorImpl<Instruction *> &NewRetains, - SmallVectorImpl<Instruction *> &NewReleases, + Instruction *Retain, SmallVectorImpl<Instruction *> &DeadInsts, RRInfo &RetainsToMove, RRInfo &ReleasesToMove, Value *Arg, bool KnownSafe, bool &AnyPairsCompletelyEliminated) { @@ -1549,7 +1486,6 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( // is already incremented, we can similarly ignore possible decrements unless // we are dealing with a retainable object with multiple provenance sources. 
bool KnownSafeTD = true, KnownSafeBU = true; - bool MultipleOwners = false; bool CFGHazardAfflicted = false; // Connect the dots between the top-down-collected RetainsToMove and @@ -1561,14 +1497,13 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( unsigned OldCount = 0; unsigned NewCount = 0; bool FirstRelease = true; - for (;;) { + for (SmallVector<Instruction *, 4> NewRetains{Retain};;) { + SmallVector<Instruction *, 4> NewReleases; for (Instruction *NewRetain : NewRetains) { auto It = Retains.find(NewRetain); assert(It != Retains.end()); const RRInfo &NewRetainRRI = It->second; KnownSafeTD &= NewRetainRRI.KnownSafe; - MultipleOwners = - MultipleOwners || MultiOwnersSet.count(GetArgRCIdentityRoot(NewRetain)); for (Instruction *NewRetainRelease : NewRetainRRI.Calls) { auto Jt = Releases.find(NewRetainRelease); if (Jt == Releases.end()) @@ -1691,7 +1626,6 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( } } } - NewReleases.clear(); if (NewRetains.empty()) break; } @@ -1745,10 +1679,6 @@ bool ObjCARCOpt::PerformCodePlacement( DEBUG(dbgs() << "\n== ObjCARCOpt::PerformCodePlacement ==\n"); bool AnyPairsCompletelyEliminated = false; - RRInfo RetainsToMove; - RRInfo ReleasesToMove; - SmallVector<Instruction *, 4> NewRetains; - SmallVector<Instruction *, 4> NewReleases; SmallVector<Instruction *, 8> DeadInsts; // Visit each retain. @@ -1780,9 +1710,10 @@ bool ObjCARCOpt::PerformCodePlacement( // Connect the dots between the top-down-collected RetainsToMove and // bottom-up-collected ReleasesToMove to form sets of related calls. - NewRetains.push_back(Retain); + RRInfo RetainsToMove, ReleasesToMove; + bool PerformMoveCalls = PairUpRetainsAndReleases( - BBStates, Retains, Releases, M, NewRetains, NewReleases, DeadInsts, + BBStates, Retains, Releases, M, Retain, DeadInsts, RetainsToMove, ReleasesToMove, Arg, KnownSafe, AnyPairsCompletelyEliminated); @@ -1792,12 +1723,6 @@ bool ObjCARCOpt::PerformCodePlacement( MoveCalls(Arg, RetainsToMove, ReleasesToMove, Retains, Releases, DeadInsts, M); } - - // Clean up state for next retain. - NewReleases.clear(); - NewRetains.clear(); - RetainsToMove.clear(); - ReleasesToMove.clear(); } // Now that we're done moving everything, we can delete the newly dead @@ -1987,9 +1912,6 @@ bool ObjCARCOpt::OptimizeSequences(Function &F) { Releases, F.getParent()); - // Cleanup. 
- MultiOwnersSet.clear(); - return AnyPairsCompletelyEliminated && NestingDetected; } diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp index 9ffdfb4..62fc52f 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp @@ -22,8 +22,8 @@ /// //===----------------------------------------------------------------------===// -#include "ObjCARC.h" #include "ProvenanceAnalysis.h" +#include "ObjCARC.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp b/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp index c274e81..870a5f6 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp @@ -8,13 +8,13 @@ //===----------------------------------------------------------------------===// #include "ProvenanceAnalysis.h" -#include "llvm/Pass.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/Passes.h" -#include "llvm/IR/InstIterator.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" +#include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; diff --git a/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp b/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp index a5afc8a..d13e941 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp @@ -244,6 +244,18 @@ void BottomUpPtrState::HandlePotentialUse(BasicBlock *BB, Instruction *Inst, const Value *Ptr, ProvenanceAnalysis &PA, ARCInstKind Class) { + auto SetSeqAndInsertReverseInsertPt = [&](Sequence NewSeq){ + assert(!HasReverseInsertPts()); + SetSeq(NewSeq); + // If this is an invoke instruction, we're scanning it as part of + // one of its successor blocks, since we can't insert code after it + // in its own block, and we don't want to split critical edges. + if (isa<InvokeInst>(Inst)) + InsertReverseInsertPt(&*BB->getFirstInsertionPt()); + else + InsertReverseInsertPt(&*++Inst->getIterator()); + }; + // Check for possible direct uses. switch (GetSeq()) { case S_Release: @@ -251,26 +263,18 @@ void BottomUpPtrState::HandlePotentialUse(BasicBlock *BB, Instruction *Inst, if (CanUse(Inst, Ptr, PA, Class)) { DEBUG(dbgs() << " CanUse: Seq: " << GetSeq() << "; " << *Ptr << "\n"); - assert(!HasReverseInsertPts()); - // If this is an invoke instruction, we're scanning it as part of - // one of its successor blocks, since we can't insert code after it - // in its own block, and we don't want to split critical edges. - if (isa<InvokeInst>(Inst)) - InsertReverseInsertPt(&*BB->getFirstInsertionPt()); - else - InsertReverseInsertPt(&*++Inst->getIterator()); - SetSeq(S_Use); + SetSeqAndInsertReverseInsertPt(S_Use); } else if (Seq == S_Release && IsUser(Class)) { DEBUG(dbgs() << " PreciseReleaseUse: Seq: " << GetSeq() << "; " << *Ptr << "\n"); // Non-movable releases depend on any possible objc pointer use. - SetSeq(S_Stop); - assert(!HasReverseInsertPts()); - // As above; handle invoke specially. 
- if (isa<InvokeInst>(Inst)) - InsertReverseInsertPt(&*BB->getFirstInsertionPt()); - else - InsertReverseInsertPt(&*++Inst->getIterator()); + SetSeqAndInsertReverseInsertPt(S_Stop); + } else if (const auto *Call = getreturnRVOperand(*Inst, Class)) { + if (CanUse(Call, Ptr, PA, GetBasicARCInstKind(Call))) { + DEBUG(dbgs() << " ReleaseUse: Seq: " << GetSeq() << "; " + << *Ptr << "\n"); + SetSeqAndInsertReverseInsertPt(S_Stop); + } } break; case S_Stop: @@ -351,8 +355,10 @@ bool TopDownPtrState::HandlePotentialAlterRefCount(Instruction *Inst, const Value *Ptr, ProvenanceAnalysis &PA, ARCInstKind Class) { - // Check for possible releases. - if (!CanAlterRefCount(Inst, Ptr, PA, Class)) + // Check for possible releases. Treat clang.arc.use as a releasing instruction + // to prevent sinking a retain past it. + if (!CanAlterRefCount(Inst, Ptr, PA, Class) && + Class != ARCInstKind::IntrinsicUser) return false; DEBUG(dbgs() << " CanAlterRefCount: Seq: " << GetSeq() << "; " << *Ptr diff --git a/contrib/llvm/lib/Transforms/ObjCARC/PtrState.h b/contrib/llvm/lib/Transforms/ObjCARC/PtrState.h index 9749e44..87298fa 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/PtrState.h +++ b/contrib/llvm/lib/Transforms/ObjCARC/PtrState.h @@ -21,8 +21,8 @@ #include "llvm/Analysis/ObjCARCInstKind.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Value.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" namespace llvm { namespace objcarc { diff --git a/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp b/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp index adc903c..5b467dc 100644 --- a/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -41,8 +41,8 @@ using namespace llvm; STATISTIC(NumRemoved, "Number of instructions removed"); STATISTIC(NumBranchesRemoved, "Number of branch instructions removed"); -// This is a tempoary option until we change the interface -// to this pass based on optimization level. +// This is a temporary option until we change the interface to this pass based +// on optimization level. static cl::opt<bool> RemoveControlFlowFlag("adce-remove-control-flow", cl::init(true), cl::Hidden); @@ -110,7 +110,7 @@ class AggressiveDeadCodeElimination { /// The set of blocks which we have determined whose control /// dependence sources must be live and which have not had - /// those dependences analyized. + /// those dependences analyzed. SmallPtrSet<BasicBlock *, 16> NewLiveBlocks; /// Set up auxiliary data structures for Instructions and BasicBlocks and @@ -145,7 +145,7 @@ class AggressiveDeadCodeElimination { /// was removed. bool removeDeadInstructions(); - /// Identify connected sections of the control flow grap which have + /// Identify connected sections of the control flow graph which have /// dead terminators and rewrite the control flow graph to remove them. 
void updateDeadRegions();
@@ -234,7 +234,7 @@ void AggressiveDeadCodeElimination::initialize() { return Iter != end() && Iter->second; } } State; - + State.reserve(F.size()); // Iterate over blocks in depth-first pre-order and // treat all edges to a block already seen as loop back edges
@@ -262,25 +262,6 @@ void AggressiveDeadCodeElimination::initialize() { continue; auto *BB = BBInfo.BB; if (!PDT.getNode(BB)) { - markLive(BBInfo.Terminator); - continue; - } - for (auto *Succ : successors(BB)) - if (!PDT.getNode(Succ)) { - markLive(BBInfo.Terminator); - break; - } - } - - // Mark blocks live if there is no path from the block to the - // return of the function or a successor for which this is true. - // This protects IDFCalculator which cannot handle such blocks. - for (auto &BBInfoPair : BlockInfo) { - auto &BBInfo = BBInfoPair.second; - if (BBInfo.terminatorIsLive()) - continue; - auto *BB = BBInfo.BB; - if (!PDT.getNode(BB)) { DEBUG(dbgs() << "Not post-dominated by return: " << BB->getName() << '\n';); markLive(BBInfo.Terminator);
@@ -579,7 +560,7 @@ void AggressiveDeadCodeElimination::updateDeadRegions() { PreferredSucc = Info; } assert((PreferredSucc && PreferredSucc->PostOrder > 0) && - "Failed to find safe successor for dead branc"); + "Failed to find safe successor for dead branch"); bool First = true; for (auto *Succ : successors(BB)) { if (!First || Succ != PreferredSucc->BB)
@@ -594,13 +575,13 @@ void AggressiveDeadCodeElimination::updateDeadRegions() { // reverse top-sort order void AggressiveDeadCodeElimination::computeReversePostOrder() { - - // This provides a post-order numbering of the reverse conrtol flow graph + + // This provides a post-order numbering of the reverse control flow graph // Note that it is incomplete in the presence of infinite loops but we don't // need to number blocks which don't reach the end of the function since // all branches in those blocks are forced live. - - - // For each block without successors, extend the DFS from the bloack + + + // For each block without successors, extend the DFS from the block // backward through the graph SmallPtrSet<BasicBlock*, 16> Visited; unsigned PostOrder = 0;
@@ -644,8 +625,8 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) { if (!AggressiveDeadCodeElimination(F, PDT).performDeadCodeElimination()) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'.
- auto PA = PreservedAnalyses(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index c1df317..99480f1 100644 --- a/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -19,12 +19,11 @@ #define AA_NAME "alignment-from-assumptions" #define DEBUG_TYPE AA_NAME #include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" @@ -35,6 +34,7 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; STATISTIC(NumLoadAlignChanged, @@ -438,19 +438,13 @@ AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) { AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F); DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); - bool Changed = runImpl(F, AC, &SE, &DT); - - // FIXME: We need to invalidate this to avoid PR28400. Is there a better - // solution? - AM.invalidate<ScalarEvolutionAnalysis>(F); - - if (!Changed) + if (!runImpl(F, AC, &SE, &DT)) return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<AAManager>(); PA.preserve<ScalarEvolutionAnalysis>(); PA.preserve<GlobalsAA>(); - PA.preserve<LoopAnalysis>(); - PA.preserve<DominatorTreeAnalysis>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp b/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp index 251b387..2e56186 100644 --- a/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/BDCE.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/DemandedBits.h" @@ -35,6 +36,46 @@ using namespace llvm; STATISTIC(NumRemoved, "Number of instructions removed (unused)"); STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)"); +/// If an instruction is trivialized (dead), then the chain of users of that +/// instruction may need to be cleared of assumptions that can no longer be +/// guaranteed correct. +static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { + assert(I->getType()->isIntegerTy() && "Trivializing a non-integer value?"); + + // Initialize the worklist with eligible direct users. + SmallVector<Instruction *, 16> WorkList; + for (User *JU : I->users()) { + // If all bits of a user are demanded, then we know that nothing below that + // in the def-use chain needs to be changed. + auto *J = dyn_cast<Instruction>(JU); + if (J && !DB.getDemandedBits(J).isAllOnesValue()) + WorkList.push_back(J); + } + + // DFS through subsequent users while tracking visits to avoid cycles. 
+ SmallPtrSet<Instruction *, 16> Visited; + while (!WorkList.empty()) { + Instruction *J = WorkList.pop_back_val(); + + // NSW, NUW, and exact are based on operands that might have changed. + J->dropPoisonGeneratingFlags(); + + // We do not have to worry about llvm.assume or range metadata: + // 1. llvm.assume demands its operand, so trivializing can't change it. + // 2. range metadata only applies to memory accesses which demand all bits. + + Visited.insert(J); + + for (User *KU : J->users()) { + // If all bits of a user are demanded, then we know that nothing below + // that in the def-use chain needs to be changed. + auto *K = dyn_cast<Instruction>(KU); + if (K && !Visited.count(K) && !DB.getDemandedBits(K).isAllOnesValue()) + WorkList.push_back(K); + } + } +} + static bool bitTrackingDCE(Function &F, DemandedBits &DB) { SmallVector<Instruction*, 128> Worklist; bool Changed = false; @@ -51,6 +92,9 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { // replacing all uses with something else. Then, if they don't need to // remain live (because they have side effects, etc.) we can remove them. DEBUG(dbgs() << "BDCE: Trivializing: " << I << " (all bits dead)\n"); + + clearAssumptionsOfUsers(&I, DB); + // FIXME: In theory we could substitute undef here instead of zero. // This should be reconsidered once we settle on the semantics of // undef, poison, etc. @@ -80,8 +124,8 @@ PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) { if (!bitTrackingDCE(F, DB)) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'. - auto PA = PreservedAnalyses(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 3826251..122c931 100644 --- a/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -38,11 +38,13 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include <tuple> using namespace llvm; @@ -53,6 +55,12 @@ using namespace consthoist; STATISTIC(NumConstantsHoisted, "Number of constants hoisted"); STATISTIC(NumConstantsRebased, "Number of constants rebased"); +static cl::opt<bool> ConstHoistWithBlockFrequency( + "consthoist-with-block-frequency", cl::init(true), cl::Hidden, + cl::desc("Enable the use of the block frequency analysis to reduce the " + "chance to execute const materialization more frequently than " + "without hoisting.")); + namespace { /// \brief The constant hoisting pass. 
class ConstantHoistingLegacyPass : public FunctionPass {
@@ -68,6 +76,8 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + if (ConstHoistWithBlockFrequency) + AU.addRequired<BlockFrequencyInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); }
@@ -82,6 +92,7 @@ private: char ConstantHoistingLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist", "Constant Hoisting", false, false) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(ConstantHoistingLegacyPass, "consthoist",
@@ -99,9 +110,13 @@ bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n"); DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); - bool MadeChange = Impl.runImpl( - Fn, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn), - getAnalysis<DominatorTreeWrapperPass>().getDomTree(), Fn.getEntryBlock()); + bool MadeChange = + Impl.runImpl(Fn, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn), + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + ConstHoistWithBlockFrequency + ? &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI() + : nullptr, + Fn.getEntryBlock()); if (MadeChange) { DEBUG(dbgs() << "********** Function after Constant Hoisting: "
@@ -136,37 +151,163 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, if (Idx != ~0U && isa<PHINode>(Inst)) return cast<PHINode>(Inst)->getIncomingBlock(Idx)->getTerminator(); - BasicBlock *IDom = DT->getNode(Inst->getParent())->getIDom()->getBlock(); - return IDom->getTerminator(); + // This must be an EH pad. Iterate over immediate dominators until we find a + // non-EH pad. We need to skip over catchswitch blocks, which are both EH pads + // and terminators. + auto IDom = DT->getNode(Inst->getParent())->getIDom(); + while (IDom->getBlock()->isEHPad()) { + assert(Entry != IDom->getBlock() && "eh pad in entry block"); + IDom = IDom->getIDom(); + } + + return IDom->getBlock()->getTerminator(); +} + +/// \brief Given \p BBs as input, find another set of BBs which collectively +/// dominates \p BBs and has the minimal sum of frequencies. Return the BB +/// set found in \p BBs. +static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, + BasicBlock *Entry, + SmallPtrSet<BasicBlock *, 8> &BBs) { + assert(!BBs.count(Entry) && "Assume Entry is not in BBs"); + // Nodes on the current path to the root. + SmallPtrSet<BasicBlock *, 8> Path; + // Candidates includes any block 'BB' in set 'BBs' that is not strictly + // dominated by any other blocks in set 'BBs', and all nodes in the path + // in the dominator tree from Entry to 'BB'. + SmallPtrSet<BasicBlock *, 16> Candidates; + for (auto BB : BBs) { + Path.clear(); + // Walk up the dominator tree until Entry or another BB in BBs + // is reached. Insert the nodes along the way into Path. + BasicBlock *Node = BB; + // The "Path" is a candidate path to be added into the Candidates set.
+ bool isCandidate = false; + do { + Path.insert(Node); + if (Node == Entry || Candidates.count(Node)) { + isCandidate = true; + break; + } + assert(DT.getNode(Node)->getIDom() && + "Entry doesn't dominate current Node"); + Node = DT.getNode(Node)->getIDom()->getBlock(); + } while (!BBs.count(Node)); + + // If isCandidate is false, Node is another block in BBs dominating + // current 'BB'. Drop the nodes on the Path. + if (!isCandidate) + continue; + + // Add nodes on the Path into Candidates. + Candidates.insert(Path.begin(), Path.end()); + } + + // Sort the nodes in Candidates in top-down order and save the nodes + // in Orders. + unsigned Idx = 0; + SmallVector<BasicBlock *, 16> Orders; + Orders.push_back(Entry); + while (Idx != Orders.size()) { + BasicBlock *Node = Orders[Idx++]; + for (auto ChildDomNode : DT.getNode(Node)->getChildren()) { + if (Candidates.count(ChildDomNode->getBlock())) + Orders.push_back(ChildDomNode->getBlock()); + } + } + + // Visit Orders in bottom-up order. + typedef std::pair<SmallPtrSet<BasicBlock *, 16>, BlockFrequency> + InsertPtsCostPair; + // InsertPtsMap is a map from a BB to the best insertion points for the + // subtree of BB (subtree not including the BB itself). + DenseMap<BasicBlock *, InsertPtsCostPair> InsertPtsMap; + InsertPtsMap.reserve(Orders.size() + 1); + for (auto RIt = Orders.rbegin(); RIt != Orders.rend(); RIt++) { + BasicBlock *Node = *RIt; + bool NodeInBBs = BBs.count(Node); + SmallPtrSet<BasicBlock *, 16> &InsertPts = InsertPtsMap[Node].first; + BlockFrequency &InsertPtsFreq = InsertPtsMap[Node].second; + + // Return the optimal insert points in BBs. + if (Node == Entry) { + BBs.clear(); + if (InsertPtsFreq > BFI.getBlockFreq(Node) || + (InsertPtsFreq == BFI.getBlockFreq(Node) && InsertPts.size() > 1)) + BBs.insert(Entry); + else + BBs.insert(InsertPts.begin(), InsertPts.end()); + break; + } + + BasicBlock *Parent = DT.getNode(Node)->getIDom()->getBlock(); + // Initially, ParentInsertPts is empty and ParentPtsFreq is 0. Every child + // will update its parent's ParentInsertPts and ParentPtsFreq. + SmallPtrSet<BasicBlock *, 16> &ParentInsertPts = InsertPtsMap[Parent].first; + BlockFrequency &ParentPtsFreq = InsertPtsMap[Parent].second; + // Choose to insert in Node or in subtree of Node. + // Don't hoist to EHPad because we may not find a proper place to insert + // in EHPad. + // If the total frequency of InsertPts is the same as the frequency of the + // target Node, and InsertPts contains more than one node, choose hoisting + // to reduce code size. + if (NodeInBBs || + (!Node->isEHPad() && + (InsertPtsFreq > BFI.getBlockFreq(Node) || + (InsertPtsFreq == BFI.getBlockFreq(Node) && InsertPts.size() > 1)))) { + ParentInsertPts.insert(Node); + ParentPtsFreq += BFI.getBlockFreq(Node); + } else { + ParentInsertPts.insert(InsertPts.begin(), InsertPts.end()); + ParentPtsFreq += InsertPtsFreq; + } + } } /// \brief Find an insertion point that dominates all uses. -Instruction *ConstantHoistingPass::findConstantInsertionPoint( +SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint( const ConstantInfo &ConstInfo) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); // Collect all basic blocks.
SmallPtrSet<BasicBlock *, 8> BBs; + SmallPtrSet<Instruction *, 8> InsertPts; for (auto const &RCI : ConstInfo.RebasedConstants) for (auto const &U : RCI.Uses) BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent()); - if (BBs.count(Entry)) - return &Entry->front(); + if (BBs.count(Entry)) { + InsertPts.insert(&Entry->front()); + return InsertPts; + } + + if (BFI) { + findBestInsertionSet(*DT, *BFI, Entry, BBs); + for (auto BB : BBs) { + BasicBlock::iterator InsertPt = BB->begin(); + for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) + ; + InsertPts.insert(&*InsertPt); + } + return InsertPts; + } while (BBs.size() >= 2) { BasicBlock *BB, *BB1, *BB2; BB1 = *BBs.begin(); BB2 = *std::next(BBs.begin()); BB = DT->findNearestCommonDominator(BB1, BB2); - if (BB == Entry) - return &Entry->front(); + if (BB == Entry) { + InsertPts.insert(&Entry->front()); + return InsertPts; + } BBs.erase(BB1); BBs.erase(BB2); BBs.insert(BB); } assert((BBs.size() == 1) && "Expected only one element."); Instruction &FirstInst = (*BBs.begin())->front(); - return findMatInsertPt(&FirstInst); + InsertPts.insert(findMatInsertPt(&FirstInst)); + return InsertPts; } @@ -210,68 +351,65 @@ void ConstantHoistingPass::collectConstantCandidates( } } -/// \brief Scan the instruction for expensive integer constants and record them -/// in the constant candidate vector. -void ConstantHoistingPass::collectConstantCandidates( - ConstCandMapType &ConstCandMap, Instruction *Inst) { - // Skip all cast instructions. They are visited indirectly later on. - if (Inst->isCast()) - return; - - // Can't handle inline asm. Skip it. - if (auto Call = dyn_cast<CallInst>(Inst)) - if (isa<InlineAsm>(Call->getCalledValue())) - return; - // Switch cases must remain constant, and if the value being tested is - // constant the entire thing should disappear. - if (isa<SwitchInst>(Inst)) - return; +/// \brief Check the operand for instruction Inst at index Idx. +void ConstantHoistingPass::collectConstantCandidates( + ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx) { + Value *Opnd = Inst->getOperand(Idx); - // Static allocas (constant size in the entry block) are handled by - // prologue/epilogue insertion so they're free anyway. We definitely don't - // want to make them non-constant. - auto AI = dyn_cast<AllocaInst>(Inst); - if (AI && AI->isStaticAlloca()) + // Visit constant integers. + if (auto ConstInt = dyn_cast<ConstantInt>(Opnd)) { + collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt); return; + } - // Scan all operands. - for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { - Value *Opnd = Inst->getOperand(Idx); + // Visit cast instructions that have constant integers. + if (auto CastInst = dyn_cast<Instruction>(Opnd)) { + // Only visit cast instructions, which have been skipped. All other + // instructions should have already been visited. + if (!CastInst->isCast()) + return; - // Visit constant integers. - if (auto ConstInt = dyn_cast<ConstantInt>(Opnd)) { + if (auto *ConstInt = dyn_cast<ConstantInt>(CastInst->getOperand(0))) { + // Pretend the constant is directly used by the instruction and ignore + // the cast instruction. collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt); - continue; + return; } + } - // Visit cast instructions that have constant integers. - if (auto CastInst = dyn_cast<Instruction>(Opnd)) { - // Only visit cast instructions, which have been skipped. All other - // instructions should have already been visited. 
- if (!CastInst->isCast()) - continue; + // Visit constant expressions that have constant integers. + if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) { + // Only visit constant cast expressions. + if (!ConstExpr->isCast()) + return; - if (auto *ConstInt = dyn_cast<ConstantInt>(CastInst->getOperand(0))) { - // Pretend the constant is directly used by the instruction and ignore - // the cast instruction. - collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt); - continue; - } + if (auto ConstInt = dyn_cast<ConstantInt>(ConstExpr->getOperand(0))) { + // Pretend the constant is directly used by the instruction and ignore + // the constant expression. + collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt); + return; } + } +} - // Visit constant expressions that have constant integers. - if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) { - // Only visit constant cast expressions. - if (!ConstExpr->isCast()) - continue; - if (auto ConstInt = dyn_cast<ConstantInt>(ConstExpr->getOperand(0))) { - // Pretend the constant is directly used by the instruction and ignore - // the constant expression. - collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt); - continue; - } +/// \brief Scan the instruction for expensive integer constants and record them +/// in the constant candidate vector. +void ConstantHoistingPass::collectConstantCandidates( + ConstCandMapType &ConstCandMap, Instruction *Inst) { + // Skip all cast instructions. They are visited indirectly later on. + if (Inst->isCast()) + return; + + // Scan all operands. + for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { + // The cost of materializing the constants (defined in + // `TargetTransformInfo::getIntImmCost`) for instructions which only take + // constant variables is lower than `TargetTransformInfo::TCC_Basic`. So + // it's safe for us to collect constant candidates from all IntrinsicInsts. + if (canReplaceOperandWithVariable(Inst, Idx) || isa<IntrinsicInst>(Inst)) { + collectConstantCandidates(ConstCandMap, Inst, Idx); } } // end of for all operands } @@ -289,8 +427,8 @@ void ConstantHoistingPass::collectConstantCandidates(Function &Fn) { // bit widths (APInt Operator- does not like that). If the value cannot be // represented in uint64 we return an "empty" APInt. This is then interpreted // as the value is not in range. -static llvm::Optional<APInt> calculateOffsetDiff(APInt V1, APInt V2) -{ +static llvm::Optional<APInt> calculateOffsetDiff(const APInt &V1, + const APInt &V2) { llvm::Optional<APInt> Res = None; unsigned BW = V1.getBitWidth() > V2.getBitWidth() ? V1.getBitWidth() : V2.getBitWidth(); @@ -549,29 +687,54 @@ bool ConstantHoistingPass::emitBaseConstants() { bool MadeChange = false; for (auto const &ConstInfo : ConstantVec) { // Hoist and hide the base constant behind a bitcast. 
- Instruction *IP = findConstantInsertionPoint(ConstInfo); - IntegerType *Ty = ConstInfo.BaseConstant->getType(); - Instruction *Base = - new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP); - DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant << ") to BB " - << IP->getParent()->getName() << '\n' << *Base << '\n'); - NumConstantsHoisted++; + SmallPtrSet<Instruction *, 8> IPSet = findConstantInsertionPoint(ConstInfo); + assert(!IPSet.empty() && "IPSet is empty"); + + unsigned UsesNum = 0; + unsigned ReBasesNum = 0; + for (Instruction *IP : IPSet) { + IntegerType *Ty = ConstInfo.BaseConstant->getType(); + Instruction *Base = + new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP); + DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant + << ") to BB " << IP->getParent()->getName() << '\n' + << *Base << '\n'); + + // Emit materialization code for all rebased constants. + unsigned Uses = 0; + for (auto const &RCI : ConstInfo.RebasedConstants) { + for (auto const &U : RCI.Uses) { + Uses++; + BasicBlock *OrigMatInsertBB = + findMatInsertPt(U.Inst, U.OpndIdx)->getParent(); + // If Base constant is to be inserted in multiple places, + // generate rebase for U using the Base dominating U. + if (IPSet.size() == 1 || + DT->dominates(Base->getParent(), OrigMatInsertBB)) { + emitBaseConstants(Base, RCI.Offset, U); + ReBasesNum++; + } + } + } + UsesNum = Uses; - // Emit materialization code for all rebased constants. - for (auto const &RCI : ConstInfo.RebasedConstants) { - NumConstantsRebased++; - for (auto const &U : RCI.Uses) - emitBaseConstants(Base, RCI.Offset, U); + // Use the same debug location as the last user of the constant. + assert(!Base->use_empty() && "The use list is empty!?"); + assert(isa<Instruction>(Base->user_back()) && + "All uses should be instructions."); + Base->setDebugLoc(cast<Instruction>(Base->user_back())->getDebugLoc()); } + (void)UsesNum; + (void)ReBasesNum; + // Expect all uses are rebased after rebase is done. + assert(UsesNum == ReBasesNum && "Not all uses are rebased"); + + NumConstantsHoisted++; - // Use the same debug location as the last user of the constant. - assert(!Base->use_empty() && "The use list is empty!?"); - assert(isa<Instruction>(Base->user_back()) && - "All uses should be instructions."); - Base->setDebugLoc(cast<Instruction>(Base->user_back())->getDebugLoc()); + // Base constant is also included in ConstInfo.RebasedConstants, so + // deduct 1 from ConstInfo.RebasedConstants.size(). + NumConstantsRebased = ConstInfo.RebasedConstants.size() - 1; - // Correct for base constant, which we counted above too. - NumConstantsRebased--; MadeChange = true; } return MadeChange; @@ -587,9 +750,11 @@ void ConstantHoistingPass::deleteDeadCastInst() const { /// \brief Optimize expensive integer constants in the given function. bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, - DominatorTree &DT, BasicBlock &Entry) { + DominatorTree &DT, BlockFrequencyInfo *BFI, + BasicBlock &Entry) { this->TTI = &TTI; this->DT = &DT; + this->BFI = BFI; this->Entry = &Entry; // Collect all constant candidates. collectConstantCandidates(Fn); @@ -620,9 +785,13 @@ PreservedAnalyses ConstantHoistingPass::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); - if (!runImpl(F, TTI, DT, F.getEntryBlock())) + auto BFI = ConstHoistWithBlockFrequency + ? 
&AM.getResult<BlockFrequencyAnalysis>(F) + : nullptr; + if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock())) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'. - return PreservedAnalyses::none(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/ConstantProp.cpp b/contrib/llvm/lib/Transforms/Scalar/ConstantProp.cpp index 9e98219..4fa2789 100644 --- a/contrib/llvm/lib/Transforms/Scalar/ConstantProp.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/ConstantProp.cpp @@ -18,15 +18,15 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Constant.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/Pass.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include <set> using namespace llvm; diff --git a/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 84f9373..2815778 100644 --- a/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" @@ -26,6 +26,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -95,7 +96,8 @@ static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { return true; } -static bool processPHI(PHINode *P, LazyValueInfo *LVI) { +static bool processPHI(PHINode *P, LazyValueInfo *LVI, + const SimplifyQuery &SQ) { bool Changed = false; BasicBlock *BB = P->getParent(); @@ -149,9 +151,7 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI) { Changed = true; } - // FIXME: Provide TLI, DT, AT to SimplifyInstruction. - const DataLayout &DL = BB->getModule()->getDataLayout(); - if (Value *V = SimplifyInstruction(P, DL)) { + if (Value *V = SimplifyInstruction(P, SQ)) { P->replaceAllUsesWith(V); P->eraseFromParent(); Changed = true; @@ -232,12 +232,10 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { pred_iterator PB = pred_begin(BB), PE = pred_end(BB); if (PB == PE) return false; - // Analyse each switch case in turn. This is done in reverse order so that - // removing a case doesn't cause trouble for the iteration. + // Analyse each switch case in turn. bool Changed = false; - for (SwitchInst::CaseIt CI = SI->case_end(), CE = SI->case_begin(); CI-- != CE; - ) { - ConstantInt *Case = CI.getCaseValue(); + for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { + ConstantInt *Case = CI->getCaseValue(); // Check to see if the switch condition is equal to/not equal to the case // value on every incoming edge, equal/not equal being the same each time. 
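The processSwitch rewrite that begins above and continues in the next hunk replaces the old reverse-order case walk with a forward erase-while-iterating loop. The subtle point is that SwitchInst::removeCase returns an iterator to the case that followed the erased one and shifts case_end(), so both iterators must be refreshed after a removal. A minimal standalone sketch of that pattern, using the same API the hunk uses (removeCasesTargeting and DeadBB are illustrative names, not part of this patch):

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Erase every case of SI that branches to DeadBB.
static void removeCasesTargeting(SwitchInst *SI, BasicBlock *DeadBB) {
  BasicBlock *BB = SI->getParent();
  for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) {
    if (CI->getCaseSuccessor() == DeadBB) {
      CI->getCaseSuccessor()->removePredecessor(BB);
      CI = SI->removeCase(CI); // Yields the iterator to the next case.
      CE = SI->case_end();     // The end iterator moved; refresh it.
      continue;
    }
    ++CI; // Only advance when no case was removed.
  }
}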
@@ -270,8 +268,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { if (State == LazyValueInfo::False) { // This case never fires - remove it. - CI.getCaseSuccessor()->removePredecessor(BB); - SI->removeCase(CI); // Does not invalidate the iterator. + CI->getCaseSuccessor()->removePredecessor(BB); + CI = SI->removeCase(CI); + CE = SI->case_end(); // The condition can be modified by removePredecessor's PHI simplification // logic. @@ -279,7 +278,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { ++NumDeadCases; Changed = true; - } else if (State == LazyValueInfo::True) { + continue; + } + if (State == LazyValueInfo::True) { // This case always fires. Arrange for the switch to be turned into an // unconditional branch by replacing the switch condition with the case // value. @@ -288,6 +289,9 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { Changed = true; break; } + + // Increment the case iterator since we didn't delete it. + ++CI; } if (Changed) @@ -300,7 +304,7 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { /// Infer nonnull attributes for the arguments at the specified callsite. static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { - SmallVector<unsigned, 4> Indices; + SmallVector<unsigned, 4> ArgNos; unsigned ArgNo = 0; for (Value *V : CS.args()) { @@ -308,23 +312,24 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { // Try to mark pointer typed parameters as non-null. We skip the // relatively expensive analysis for constants which are obviously either // null or non-null to start with. - if (Type && !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && + if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) && !isa<Constant>(V) && LVI->getPredicateAt(ICmpInst::ICMP_EQ, V, ConstantPointerNull::get(Type), CS.getInstruction()) == LazyValueInfo::False) - Indices.push_back(ArgNo + 1); + ArgNos.push_back(ArgNo); ArgNo++; } assert(ArgNo == CS.arg_size() && "sanity check"); - if (Indices.empty()) + if (ArgNos.empty()) return false; - AttributeSet AS = CS.getAttributes(); + AttributeList AS = CS.getAttributes(); LLVMContext &Ctx = CS.getInstruction()->getContext(); - AS = AS.addAttribute(Ctx, Indices, Attribute::get(Ctx, Attribute::NonNull)); + AS = AS.addParamAttribute(Ctx, ArgNos, + Attribute::get(Ctx, Attribute::NonNull)); CS.setAttributes(AS); return true; @@ -437,9 +442,8 @@ static bool processAdd(BinaryOperator *AddOp, LazyValueInfo *LVI) { bool Changed = false; if (!NUW) { - ConstantRange NUWRange = - LRange.makeGuaranteedNoWrapRegion(BinaryOperator::Add, LRange, - OBO::NoUnsignedWrap); + ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( + BinaryOperator::Add, LRange, OBO::NoUnsignedWrap); if (!NUWRange.isEmptySet()) { bool NewNUW = NUWRange.contains(LazyRRange()); AddOp->setHasNoUnsignedWrap(NewNUW); @@ -447,9 +451,8 @@ static bool processAdd(BinaryOperator *AddOp, LazyValueInfo *LVI) { } } if (!NSW) { - ConstantRange NSWRange = - LRange.makeGuaranteedNoWrapRegion(BinaryOperator::Add, LRange, - OBO::NoSignedWrap); + ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( + BinaryOperator::Add, LRange, OBO::NoSignedWrap); if (!NSWRange.isEmptySet()) { bool NewNSW = NSWRange.contains(LazyRRange()); AddOp->setHasNoSignedWrap(NewNSW); @@ -483,9 +486,8 @@ static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { ConstantInt::getFalse(C->getContext()); } -static bool runImpl(Function &F, LazyValueInfo *LVI) { +static bool runImpl(Function &F, 
LazyValueInfo *LVI, const SimplifyQuery &SQ) { bool FnChanged = false; - // Visiting in a pre-order depth-first traversal causes us to simplify early // blocks before querying later blocks (which require us to analyze early // blocks). Eagerly simplifying shallow blocks means there is strictly less @@ -500,7 +502,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI) { BBChanged |= processSelect(cast<SelectInst>(II), LVI); break; case Instruction::PHI: - BBChanged |= processPHI(cast<PHINode>(II), LVI); + BBChanged |= processPHI(cast<PHINode>(II), LVI, SQ); break; case Instruction::ICmp: case Instruction::FCmp: @@ -548,7 +550,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI) { BBChanged = true; } } - }; + } FnChanged |= BBChanged; } @@ -561,18 +563,14 @@ bool CorrelatedValuePropagation::runOnFunction(Function &F) { return false; LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); - return runImpl(F, LVI); + return runImpl(F, LVI, getBestSimplifyQuery(*this, F)); } PreservedAnalyses CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) { LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F); - bool Changed = runImpl(F, LVI); - - // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better - // solution? - AM.invalidate<LazyValueAnalysis>(F); + bool Changed = runImpl(F, LVI, getBestSimplifyQuery(AM, F)); if (!Changed) return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/Scalar/DCE.cpp b/contrib/llvm/lib/Transforms/Scalar/DCE.cpp index cc2a3cf..fa4806e 100644 --- a/contrib/llvm/lib/Transforms/Scalar/DCE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/DCE.cpp @@ -19,10 +19,10 @@ #include "llvm/Transforms/Scalar/DCE.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/Pass.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -124,9 +124,12 @@ static bool eliminateDeadCode(Function &F, TargetLibraryInfo *TLI) { } PreservedAnalyses DCEPass::run(Function &F, FunctionAnalysisManager &AM) { - if (eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); + if (!eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } namespace { diff --git a/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index 4d4c3ba..1ec38e5 100644 --- a/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -135,13 +135,13 @@ static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) { if (auto CS = CallSite(I)) { if (Function *F = CS.getCalledFunction()) { StringRef FnName = F->getName(); - if (TLI.has(LibFunc::strcpy) && FnName == TLI.getName(LibFunc::strcpy)) + if (TLI.has(LibFunc_strcpy) && FnName == TLI.getName(LibFunc_strcpy)) return true; - if (TLI.has(LibFunc::strncpy) && FnName == TLI.getName(LibFunc::strncpy)) + if (TLI.has(LibFunc_strncpy) && FnName == TLI.getName(LibFunc_strncpy)) return true; - if (TLI.has(LibFunc::strcat) && FnName == TLI.getName(LibFunc::strcat)) + if (TLI.has(LibFunc_strcat) && FnName == TLI.getName(LibFunc_strcat)) 
return true; - if (TLI.has(LibFunc::strncat) && FnName == TLI.getName(LibFunc::strncat)) + if (TLI.has(LibFunc_strncat) && FnName == TLI.getName(LibFunc_strncat)) return true; } } @@ -287,19 +287,14 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL, } namespace { -enum OverwriteResult { - OverwriteBegin, - OverwriteComplete, - OverwriteEnd, - OverwriteUnknown -}; +enum OverwriteResult { OW_Begin, OW_Complete, OW_End, OW_Unknown }; } -/// Return 'OverwriteComplete' if a store to the 'Later' location completely -/// overwrites a store to the 'Earlier' location, 'OverwriteEnd' if the end of -/// the 'Earlier' location is completely overwritten by 'Later', -/// 'OverwriteBegin' if the beginning of the 'Earlier' location is overwritten -/// by 'Later', or 'OverwriteUnknown' if nothing can be determined. +/// Return 'OW_Complete' if a store to the 'Later' location completely +/// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the +/// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the +/// beginning of the 'Earlier' location is overwritten by 'Later', or +/// 'OW_Unknown' if nothing can be determined. static OverwriteResult isOverwrite(const MemoryLocation &Later, const MemoryLocation &Earlier, const DataLayout &DL, @@ -310,7 +305,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // If we don't know the sizes of either access, then we can't do a comparison. if (Later.Size == MemoryLocation::UnknownSize || Earlier.Size == MemoryLocation::UnknownSize) - return OverwriteUnknown; + return OW_Unknown; const Value *P1 = Earlier.Ptr->stripPointerCasts(); const Value *P2 = Later.Ptr->stripPointerCasts(); @@ -320,7 +315,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, if (P1 == P2) { // Make sure that the Later size is >= the Earlier size. if (Later.Size >= Earlier.Size) - return OverwriteComplete; + return OW_Complete; } // Check to see if the later store is to the entire object (either a global, @@ -332,13 +327,13 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // If we can't resolve the same pointers to the same object, then we can't // analyze them at all. if (UO1 != UO2) - return OverwriteUnknown; + return OW_Unknown; // If the "Later" store is to a recognizable object, get its size. uint64_t ObjectSize = getPointerSize(UO2, DL, TLI); if (ObjectSize != MemoryLocation::UnknownSize) if (ObjectSize == Later.Size && ObjectSize >= Earlier.Size) - return OverwriteComplete; + return OW_Complete; // Okay, we have stores to two completely different pointers. Try to // decompose the pointer into a "base + constant_offset" form. If the base @@ -350,7 +345,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // If the base pointers still differ, we have two completely different stores. if (BP1 != BP2) - return OverwriteUnknown; + return OW_Unknown; // The later store completely overlaps the earlier store if: // @@ -370,7 +365,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, if (EarlierOff >= LaterOff && Later.Size >= Earlier.Size && uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size) - return OverwriteComplete; + return OW_Complete; // We may now overlap, although the overlap is not complete. 
There might also // be other incomplete overlaps, and together, they might cover the complete @@ -428,7 +423,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, ") Composite Later [" << ILI->second << ", " << ILI->first << ")\n"); ++NumCompletePartials; - return OverwriteComplete; + return OW_Complete; } } @@ -443,7 +438,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, if (!EnablePartialOverwriteTracking && (LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + Earlier.Size) && int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size))) - return OverwriteEnd; + return OW_End; // Finally, we also need to check if the later store overwrites the beginning // of the earlier store. @@ -458,11 +453,11 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff)) { assert(int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size) && - "Expect to be handled as OverwriteComplete"); - return OverwriteBegin; + "Expect to be handled as OW_Complete"); + return OW_Begin; } // Otherwise, they don't completely overlap. - return OverwriteUnknown; + return OW_Unknown; } /// If 'Inst' might be a self read (i.e. a noop copy of a @@ -551,7 +546,7 @@ static bool memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *I = &*BI; if (I->mayWriteToMemory() && I != SecondI) { auto Res = AA->getModRefInfo(I, MemLoc); - if (Res != MRI_NoModRef) + if (Res & MRI_Mod) return false; } } @@ -909,7 +904,7 @@ static bool tryToShortenBegin(Instruction *EarlierWrite, if (LaterStart <= EarlierStart && LaterStart + LaterSize > EarlierStart) { assert(LaterStart + LaterSize < EarlierStart + EarlierSize && - "Should have been handled as OverwriteComplete"); + "Should have been handled as OW_Complete"); if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart, LaterSize, false)) { IntervalMap.erase(OII); @@ -1105,7 +1100,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, OverwriteResult OR = isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset, DepWrite, IOL); - if (OR == OverwriteComplete) { + if (OR == OW_Complete) { DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DepWrite << "\n KILLER: " << *Inst << '\n'); @@ -1117,15 +1112,15 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // We erased DepWrite; start over. 
InstDep = MD->getDependency(Inst); continue; - } else if ((OR == OverwriteEnd && isShortenableAtTheEnd(DepWrite)) || - ((OR == OverwriteBegin && + } else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) || + ((OR == OW_Begin && isShortenableAtTheBeginning(DepWrite)))) { assert(!EnablePartialOverwriteTracking && "Do not expect to perform " "when partial-overwrite " "tracking is enabled"); int64_t EarlierSize = DepLoc.Size; int64_t LaterSize = Loc.Size; - bool IsOverwriteEnd = (OR == OverwriteEnd); + bool IsOverwriteEnd = (OR == OW_End); MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize, InstWriteOffset, LaterSize, IsOverwriteEnd); } @@ -1186,8 +1181,9 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { if (!eliminateDeadStores(F, AA, MD, DT, TLI)) return PreservedAnalyses::all(); + PreservedAnalyses PA; - PA.preserve<DominatorTreeAnalysis>(); + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); PA.preserve<MemoryDependenceAnalysis>(); return PA; diff --git a/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 16e08ee..c5c9b2c 100644 --- a/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -15,10 +15,13 @@ #include "llvm/Transforms/Scalar/EarlyCSE.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/DataLayout.h" @@ -32,7 +35,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/MemorySSA.h" #include <deque> using namespace llvm; using namespace llvm::PatternMatch; @@ -252,7 +254,9 @@ public: const TargetTransformInfo &TTI; DominatorTree &DT; AssumptionCache &AC; + const SimplifyQuery SQ; MemorySSA *MSSA; + std::unique_ptr<MemorySSAUpdater> MSSAUpdater; typedef RecyclingAllocator< BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value *>> AllocatorTy; typedef ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>, @@ -313,9 +317,12 @@ public: unsigned CurrentGeneration; /// \brief Set up the EarlyCSE runner for a particular function. 
- EarlyCSE(const TargetLibraryInfo &TLI, const TargetTransformInfo &TTI, - DominatorTree &DT, AssumptionCache &AC, MemorySSA *MSSA) - : TLI(TLI), TTI(TTI), DT(DT), AC(AC), MSSA(MSSA), CurrentGeneration(0) {} + EarlyCSE(const DataLayout &DL, const TargetLibraryInfo &TLI, + const TargetTransformInfo &TTI, DominatorTree &DT, + AssumptionCache &AC, MemorySSA *MSSA) + : TLI(TLI), TTI(TTI), DT(DT), AC(AC), SQ(DL, &TLI, &DT, &AC), MSSA(MSSA), + MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)), CurrentGeneration(0) { + } bool run();
@@ -388,7 +395,7 @@ private: ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI) : IsTargetMemInst(false), Inst(Inst) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) - if (TTI.getTgtMemIntrinsic(II, Info) && Info.NumMemRefs == 1) + if (TTI.getTgtMemIntrinsic(II, Info)) IsTargetMemInst = true; } bool isLoad() const {
@@ -400,17 +407,14 @@ private: return isa<StoreInst>(Inst); } bool isAtomic() const { - if (IsTargetMemInst) { - assert(Info.IsSimple && "need to refine IsSimple in TTI"); - return false; - } + if (IsTargetMemInst) + return Info.Ordering != AtomicOrdering::NotAtomic; return Inst->isAtomic(); } bool isUnordered() const { - if (IsTargetMemInst) { - assert(Info.IsSimple && "need to refine IsSimple in TTI"); - return true; - } + if (IsTargetMemInst) + return Info.isUnordered(); + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { return LI->isUnordered(); } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
@@ -421,10 +425,9 @@ private: } bool isVolatile() const { - if (IsTargetMemInst) { - assert(Info.IsSimple && "need to refine IsSimple in TTI"); - return false; - } + if (IsTargetMemInst) + return Info.IsVolatile; + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { return LI->isVolatile(); } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
@@ -504,7 +507,7 @@ private: if (MemoryAccess *MA = MSSA->getMemoryAccess(Inst)) { // Optimize MemoryPhi nodes that may become redundant by having all the // same input values once MA is removed. - SmallVector<MemoryPhi *, 4> PhisToCheck; + SmallSetVector<MemoryPhi *, 4> PhisToCheck; SmallVector<MemoryAccess *, 8> WorkQueue; WorkQueue.push_back(MA); // Process MemoryPhi nodes in FIFO order using an ever-growing vector since
@@ -515,9 +518,9 @@ for (auto *U : WI->users()) if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U)) - PhisToCheck.push_back(MP); + PhisToCheck.insert(MP); - MSSA->removeMemoryAccess(WI); + MSSAUpdater->removeMemoryAccess(WI); for (MemoryPhi *MP : PhisToCheck) { MemoryAccess *FirstIn = MP->getIncomingValue(0);
@@ -559,13 +562,27 @@ bool EarlyCSE::isSameMemGeneration(unsigned EarlierGeneration, if (!MSSA) return false; + // If MemorySSA has determined that one of EarlierInst or LaterInst does not + // read/write memory, then we can safely return true here. + // FIXME: We could be more aggressive when checking doesNotAccessMemory(), + // onlyReadsMemory(), mayReadFromMemory(), and mayWriteToMemory() in this pass + // by also checking the MemorySSA MemoryAccess on the instruction. Initial + // experiments suggest this isn't worthwhile, at least for C/C++ code compiled + // with the default optimization pipeline.
+  auto *EarlierMA = MSSA->getMemoryAccess(EarlierInst);
+  if (!EarlierMA)
+    return true;
+  auto *LaterMA = MSSA->getMemoryAccess(LaterInst);
+  if (!LaterMA)
+    return true;
+
   // Since we know LaterDef dominates LaterInst and EarlierInst dominates
   // LaterInst, if LaterDef dominates EarlierInst then it can't occur between
   // EarlierInst and LaterInst and neither can any other write that potentially
   // clobbers LaterInst.
   MemoryAccess *LaterDef =
       MSSA->getWalker()->getClobberingMemoryAccess(LaterInst);
-  return MSSA->dominates(LaterDef, MSSA->getMemoryAccess(EarlierInst));
+  return MSSA->dominates(LaterDef, EarlierMA);
 }
 
 bool EarlyCSE::processNode(DomTreeNode *Node) {
@@ -587,27 +604,28 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
   // which reaches this block where the condition might hold a different
   // value.  Since we're adding this to the scoped hash table (like any other
   // def), it will have been popped if we encounter a future merge block.
-  if (BasicBlock *Pred = BB->getSinglePredecessor())
-    if (auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
-      if (BI->isConditional())
-        if (auto *CondInst = dyn_cast<Instruction>(BI->getCondition()))
-          if (SimpleValue::canHandle(CondInst)) {
-            assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB);
-            auto *ConditionalConstant = (BI->getSuccessor(0) == BB) ?
-              ConstantInt::getTrue(BB->getContext()) :
-              ConstantInt::getFalse(BB->getContext());
-            AvailableValues.insert(CondInst, ConditionalConstant);
-            DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '"
-                  << CondInst->getName() << "' as " << *ConditionalConstant
-                  << " in " << BB->getName() << "\n");
-            // Replace all dominated uses with the known value.
-            if (unsigned Count =
-                    replaceDominatedUsesWith(CondInst, ConditionalConstant, DT,
-                                             BasicBlockEdge(Pred, BB))) {
-              Changed = true;
-              NumCSECVP = NumCSECVP + Count;
-            }
-          }
+  if (BasicBlock *Pred = BB->getSinglePredecessor()) {
+    auto *BI = dyn_cast<BranchInst>(Pred->getTerminator());
+    if (BI && BI->isConditional()) {
+      auto *CondInst = dyn_cast<Instruction>(BI->getCondition());
+      if (CondInst && SimpleValue::canHandle(CondInst)) {
+        assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB);
+        auto *TorF = (BI->getSuccessor(0) == BB)
+                         ? ConstantInt::getTrue(BB->getContext())
+                         : ConstantInt::getFalse(BB->getContext());
+        AvailableValues.insert(CondInst, TorF);
+        DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '"
+                     << CondInst->getName() << "' as " << *TorF << " in "
+                     << BB->getName() << "\n");
+        // Replace all dominated uses with the known value.
+        if (unsigned Count = replaceDominatedUsesWith(
+                CondInst, TorF, DT, BasicBlockEdge(Pred, BB))) {
+          Changed = true;
+          NumCSECVP += Count;
+        }
+      }
+    }
+  }
 
   /// LastStore - Keep track of the last non-volatile store that we saw... for
   /// as long as there is no instruction that reads memory.  If we see a store
   /// to the same location, we delete the dead store.  This zaps trivial dead
   /// stores which can occur in bitfield code among other things.
   Instruction *LastStore = nullptr;
 
-  const DataLayout &DL = BB->getModule()->getDataLayout();
-
   // See if any instructions in the block can be eliminated.  If so, do it.  If
   // not, add them to AvailableValues.
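As a minimal illustration of the conditional propagation added to processNode above (block and value names hypothetical): a block with a single conditionally-branching predecessor sees the branch condition as a constant, and dominated uses fold away.

    entry:
      %cond = icmp eq i32 %x, 0
      br i1 %cond, label %then, label %else
    then:                                 ; single predecessor: entry
      %r = select i1 %cond, i32 1, i32 2  ; %cond is known true here -> %r = 1
      ret i32 %r
    else:
      ret i32 2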
for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
@@ -634,10 +650,16 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
 
     // Skip assume intrinsics, they don't really have side effects (although
     // they're marked as such to ensure preservation of control dependencies),
-    // and this pass will not disturb any of the assumption's control
-    // dependencies.
+    // and this pass will not bother to remove them. However, we should mark
+    // their conditions as true for all dominated blocks.
     if (match(Inst, m_Intrinsic<Intrinsic::assume>())) {
-      DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n');
+      auto *CondI =
+          dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0));
+      if (CondI && SimpleValue::canHandle(CondI)) {
+        DEBUG(dbgs() << "EarlyCSE considering assumption: " << *Inst << '\n');
+        AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));
+      } else
+        DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n');
       continue;
     }
 
@@ -657,10 +679,25 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
     if (match(Inst, m_Intrinsic<Intrinsic::experimental_guard>())) {
       if (auto *CondI =
               dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) {
-        // The condition we're on guarding here is true for all dominated
-        // locations.
-        if (SimpleValue::canHandle(CondI))
+        if (SimpleValue::canHandle(CondI)) {
+          // Do we already know the actual value of this condition?
+          if (auto *KnownCond = AvailableValues.lookup(CondI)) {
+            // Is the condition known to be true?
+            if (isa<ConstantInt>(KnownCond) &&
+                cast<ConstantInt>(KnownCond)->isOne()) {
+              DEBUG(dbgs() << "EarlyCSE removing guard: " << *Inst << '\n');
+              removeMSSA(Inst);
+              Inst->eraseFromParent();
+              Changed = true;
+              continue;
+            } else
+              // Use the known value if it wasn't true.
+              cast<CallInst>(Inst)->setArgOperand(0, KnownCond);
+          }
+          // The condition we're guarding on here is true for all dominated
+          // locations.
           AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));
+        }
       }
 
       // Guard intrinsics read all memory, but don't write any memory.
@@ -672,7 +709,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
 
     // If the instruction can be simplified (e.g. X+0 = X) then replace it with
     // its simpler value.
-    if (Value *V = SimplifyInstruction(Inst, DL, &TLI, &DT, &AC)) {
+    if (Value *V = SimplifyInstruction(Inst, SQ)) {
       DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << "  to: " << *V << '\n');
       bool Killed = false;
       if (!Inst->use_empty()) {
@@ -761,12 +798,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
       continue;
     }
 
-    // If this instruction may read from memory, forget LastStore.
-    // Load/store intrinsics will indicate both a read and a write to
-    // memory.  The target may override this (e.g. so that a store intrinsic
-    // does not read from memory, and thus will be treated the same as a
-    // regular store for commoning purposes).
+    // If this instruction may read from memory or throw (and potentially read
+    // from memory in the exception handler), forget LastStore. Load/store
+    // intrinsics will indicate both a read and a write to memory. The target
+    // may override this (e.g. so that a store intrinsic does not read from
+    // memory, and thus will be treated the same as a regular store for
+    // commoning purposes).
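A sketch of the guard handling above, in hypothetical IR: the first guard records its condition as true in AvailableValues, so an identical dominated guard is recognized as redundant and erased.

    %c = icmp ult i32 %i, %n
    call void (i1, ...) @llvm.experimental.guard(i1 %c) [ "deopt"() ]
    ; ... no intervening redefinition of %c ...
    call void (i1, ...) @llvm.experimental.guard(i1 %c) [ "deopt"() ]
    ; ^ removed: AvailableValues.lookup(%c) yields true when this is visited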
+ if ((Inst->mayReadFromMemory() || Inst->mayThrow()) && !(MemInst.isValid() && !MemInst.mayReadFromMemory())) LastStore = nullptr; @@ -962,15 +1000,13 @@ PreservedAnalyses EarlyCSEPass::run(Function &F, auto *MSSA = UseMemorySSA ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA() : nullptr; - EarlyCSE CSE(TLI, TTI, DT, AC, MSSA); + EarlyCSE CSE(F.getParent()->getDataLayout(), TLI, TTI, DT, AC, MSSA); if (!CSE.run()) return PreservedAnalyses::all(); - // CSE preserves the dominator tree because it doesn't mutate the CFG. - // FIXME: Bundle this with other CFG-preservation. PreservedAnalyses PA; - PA.preserve<DominatorTreeAnalysis>(); + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); if (UseMemorySSA) PA.preserve<MemorySSAAnalysis>(); @@ -1008,7 +1044,7 @@ public: auto *MSSA = UseMemorySSA ? &getAnalysis<MemorySSAWrapperPass>().getMSSA() : nullptr; - EarlyCSE CSE(TLI, TTI, DT, AC, MSSA); + EarlyCSE CSE(F.getParent()->getDataLayout(), TLI, TTI, DT, AC, MSSA); return CSE.run(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp b/contrib/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp index 185cdbd..063df77 100644 --- a/contrib/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp @@ -11,10 +11,10 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/IR/CFG.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; diff --git a/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp b/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp index 545036d..b105ece 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -137,13 +137,13 @@ void Float2IntPass::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) { } // Helper - mark I as having been traversed, having range R. -ConstantRange Float2IntPass::seen(Instruction *I, ConstantRange R) { +void Float2IntPass::seen(Instruction *I, ConstantRange R) { DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n"); - if (SeenInsts.find(I) != SeenInsts.end()) - SeenInsts.find(I)->second = R; + auto IT = SeenInsts.find(I); + if (IT != SeenInsts.end()) + IT->second = std::move(R); else - SeenInsts.insert(std::make_pair(I, R)); - return R; + SeenInsts.insert(std::make_pair(I, std::move(R))); } // Helper - get a range representing a poison value. @@ -516,11 +516,10 @@ FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); } PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &) { if (!runImpl(F)) return PreservedAnalyses::all(); - else { - // FIXME: This should also 'preserve the CFG'. 
- PreservedAnalyses PA; - PA.preserve<GlobalsAA>(); - return PA; - } + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + return PA; } } // End namespace llvm diff --git a/contrib/llvm/lib/Transforms/Scalar/GVN.cpp b/contrib/llvm/lib/Transforms/Scalar/GVN.cpp index 0137378..ea28705 100644 --- a/contrib/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/GVN.cpp @@ -36,7 +36,6 @@ #include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/PHITransAddr.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/GlobalVariable.h" @@ -51,9 +50,12 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/VNCoercion.h" + #include <vector> using namespace llvm; using namespace llvm::gvn; +using namespace llvm::VNCoercion; using namespace PatternMatch; #define DEBUG_TYPE "gvn" @@ -595,11 +597,12 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); PA.preserve<GlobalsAA>(); + PA.preserve<TargetLibraryAnalysis>(); return PA; } -LLVM_DUMP_METHOD -void GVN::dump(DenseMap<uint32_t, Value*>& d) { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void GVN::dump(DenseMap<uint32_t, Value*>& d) const { errs() << "{\n"; for (DenseMap<uint32_t, Value*>::iterator I = d.begin(), E = d.end(); I != E; ++I) { @@ -608,6 +611,7 @@ void GVN::dump(DenseMap<uint32_t, Value*>& d) { } errs() << "}\n"; } +#endif /// Return true if we can prove that the value /// we're analyzing is fully available in the specified block. As we go, keep @@ -690,442 +694,6 @@ SpeculationFailure: } -/// Return true if CoerceAvailableValueToLoadType will succeed. -static bool CanCoerceMustAliasedValueToLoad(Value *StoredVal, - Type *LoadTy, - const DataLayout &DL) { - // If the loaded or stored value is an first class array or struct, don't try - // to transform them. We need to be able to bitcast to integer. - if (LoadTy->isStructTy() || LoadTy->isArrayTy() || - StoredVal->getType()->isStructTy() || - StoredVal->getType()->isArrayTy()) - return false; - - // The store has to be at least as big as the load. - if (DL.getTypeSizeInBits(StoredVal->getType()) < - DL.getTypeSizeInBits(LoadTy)) - return false; - - return true; -} - -/// If we saw a store of a value to memory, and -/// then a load from a must-aliased pointer of a different type, try to coerce -/// the stored value. LoadedTy is the type of the load we want to replace. -/// IRB is IRBuilder used to insert new instructions. -/// -/// If we can't do it, return null. -static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, - IRBuilder<> &IRB, - const DataLayout &DL) { - assert(CanCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) && - "precondition violation - materialization can't fail"); - - if (auto *C = dyn_cast<Constant>(StoredVal)) - if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) - StoredVal = FoldedStoredVal; - - // If this is already the right type, just return it. - Type *StoredValTy = StoredVal->getType(); - - uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy); - uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy); - - // If the store and reload are the same size, we can always reuse it. 
- if (StoredValSize == LoadedValSize) { - // Pointer to Pointer -> use bitcast. - if (StoredValTy->getScalarType()->isPointerTy() && - LoadedTy->getScalarType()->isPointerTy()) { - StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy); - } else { - // Convert source pointers to integers, which can be bitcast. - if (StoredValTy->getScalarType()->isPointerTy()) { - StoredValTy = DL.getIntPtrType(StoredValTy); - StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy); - } - - Type *TypeToCastTo = LoadedTy; - if (TypeToCastTo->getScalarType()->isPointerTy()) - TypeToCastTo = DL.getIntPtrType(TypeToCastTo); - - if (StoredValTy != TypeToCastTo) - StoredVal = IRB.CreateBitCast(StoredVal, TypeToCastTo); - - // Cast to pointer if the load needs a pointer type. - if (LoadedTy->getScalarType()->isPointerTy()) - StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy); - } - - if (auto *C = dyn_cast<ConstantExpr>(StoredVal)) - if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) - StoredVal = FoldedStoredVal; - - return StoredVal; - } - - // If the loaded value is smaller than the available value, then we can - // extract out a piece from it. If the available value is too small, then we - // can't do anything. - assert(StoredValSize >= LoadedValSize && - "CanCoerceMustAliasedValueToLoad fail"); - - // Convert source pointers to integers, which can be manipulated. - if (StoredValTy->getScalarType()->isPointerTy()) { - StoredValTy = DL.getIntPtrType(StoredValTy); - StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy); - } - - // Convert vectors and fp to integer, which can be manipulated. - if (!StoredValTy->isIntegerTy()) { - StoredValTy = IntegerType::get(StoredValTy->getContext(), StoredValSize); - StoredVal = IRB.CreateBitCast(StoredVal, StoredValTy); - } - - // If this is a big-endian system, we need to shift the value down to the low - // bits so that a truncate will work. - if (DL.isBigEndian()) { - uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy) - - DL.getTypeStoreSizeInBits(LoadedTy); - StoredVal = IRB.CreateLShr(StoredVal, ShiftAmt, "tmp"); - } - - // Truncate the integer to the right size now. - Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadedValSize); - StoredVal = IRB.CreateTrunc(StoredVal, NewIntTy, "trunc"); - - if (LoadedTy != NewIntTy) { - // If the result is a pointer, inttoptr. - if (LoadedTy->getScalarType()->isPointerTy()) - StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy, "inttoptr"); - else - // Otherwise, bitcast. - StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy, "bitcast"); - } - - if (auto *C = dyn_cast<Constant>(StoredVal)) - if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) - StoredVal = FoldedStoredVal; - - return StoredVal; -} - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering memory write (store, -/// memset, memcpy, memmove). This means that the write *may* provide bits used -/// by the load but we can't be sure because the pointers don't mustalias. -/// -/// Check this case to see if there is anything more we can do before we give -/// up. This returns -1 if we have to give up, or a byte number in the stored -/// value of the piece that feeds the load. -static int AnalyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, - Value *WritePtr, - uint64_t WriteSizeInBits, - const DataLayout &DL) { - // If the loaded or stored value is a first class array or struct, don't try - // to transform them. We need to be able to bitcast to integer. 
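For orientation, a worked example (hypothetical IR) of the coercion this deleted helper performed, now provided by VNCoercion: a wider store can feed a narrower must-aliased load, with a shift on big-endian targets so the truncate keeps the bytes the load actually reads.

    store i64 %v, i64* %p
    %q = bitcast i64* %p to i32*
    %x = load i32, i32* %q           ; reads 4 of the 8 stored bytes
    ; little-endian forwarding:  %x = trunc i64 %v to i32
    ; big-endian forwarding:     %t = lshr i64 %v, 32
    ;                            %x = trunc i64 %t to i32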
- if (LoadTy->isStructTy() || LoadTy->isArrayTy()) - return -1; - - int64_t StoreOffset = 0, LoadOffset = 0; - Value *StoreBase = - GetPointerBaseWithConstantOffset(WritePtr, StoreOffset, DL); - Value *LoadBase = GetPointerBaseWithConstantOffset(LoadPtr, LoadOffset, DL); - if (StoreBase != LoadBase) - return -1; - - // If the load and store are to the exact same address, they should have been - // a must alias. AA must have gotten confused. - // FIXME: Study to see if/when this happens. One case is forwarding a memset - // to a load from the base of the memset. - - // If the load and store don't overlap at all, the store doesn't provide - // anything to the load. In this case, they really don't alias at all, AA - // must have gotten confused. - uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy); - - if ((WriteSizeInBits & 7) | (LoadSize & 7)) - return -1; - uint64_t StoreSize = WriteSizeInBits / 8; // Convert to bytes. - LoadSize /= 8; - - - bool isAAFailure = false; - if (StoreOffset < LoadOffset) - isAAFailure = StoreOffset+int64_t(StoreSize) <= LoadOffset; - else - isAAFailure = LoadOffset+int64_t(LoadSize) <= StoreOffset; - - if (isAAFailure) - return -1; - - // If the Load isn't completely contained within the stored bits, we don't - // have all the bits to feed it. We could do something crazy in the future - // (issue a smaller load then merge the bits in) but this seems unlikely to be - // valuable. - if (StoreOffset > LoadOffset || - StoreOffset+StoreSize < LoadOffset+LoadSize) - return -1; - - // Okay, we can do this transformation. Return the number of bytes into the - // store that the load is. - return LoadOffset-StoreOffset; -} - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering store. -static int AnalyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr, - StoreInst *DepSI) { - // Cannot handle reading from store of first-class aggregate yet. - if (DepSI->getValueOperand()->getType()->isStructTy() || - DepSI->getValueOperand()->getType()->isArrayTy()) - return -1; - - const DataLayout &DL = DepSI->getModule()->getDataLayout(); - Value *StorePtr = DepSI->getPointerOperand(); - uint64_t StoreSize =DL.getTypeSizeInBits(DepSI->getValueOperand()->getType()); - return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, - StorePtr, StoreSize, DL); -} - -/// This function is called when we have a -/// memdep query of a load that ends up being clobbered by another load. See if -/// the other load can feed into the second load. -static int AnalyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, - LoadInst *DepLI, const DataLayout &DL){ - // Cannot handle reading from store of first-class aggregate yet. - if (DepLI->getType()->isStructTy() || DepLI->getType()->isArrayTy()) - return -1; - - Value *DepPtr = DepLI->getPointerOperand(); - uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()); - int R = AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL); - if (R != -1) return R; - - // If we have a load/load clobber an DepLI can be widened to cover this load, - // then we should widen it! 
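A worked instance of the offset arithmetic above (pointers hypothetical): with a common base, the store must cover the load on both ends, and the returned value is the load's byte offset within the stored bits.

    store i32 %v, i32* %p           ; StoreOffset = 0, StoreSize = 4
    %pc = bitcast i32* %p to i8*
    %qa = getelementptr i8, i8* %pc, i64 2
    %x  = load i8, i8* %qa          ; LoadOffset = 2, LoadSize = 1
    ; store covers bytes [0,4), load covers [2,3)
    ; => return LoadOffset - StoreOffset = 2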
- int64_t LoadOffs = 0; - const Value *LoadBase = - GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL); - unsigned LoadSize = DL.getTypeStoreSize(LoadTy); - - unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize( - LoadBase, LoadOffs, LoadSize, DepLI); - if (Size == 0) return -1; - - // Check non-obvious conditions enforced by MDA which we rely on for being - // able to materialize this potentially available value - assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!"); - assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load"); - - return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size*8, DL); -} - - - -static int AnalyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, - MemIntrinsic *MI, - const DataLayout &DL) { - // If the mem operation is a non-constant size, we can't handle it. - ConstantInt *SizeCst = dyn_cast<ConstantInt>(MI->getLength()); - if (!SizeCst) return -1; - uint64_t MemSizeInBits = SizeCst->getZExtValue()*8; - - // If this is memset, we just need to see if the offset is valid in the size - // of the memset.. - if (MI->getIntrinsicID() == Intrinsic::memset) - return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(), - MemSizeInBits, DL); - - // If we have a memcpy/memmove, the only case we can handle is if this is a - // copy from constant memory. In that case, we can read directly from the - // constant memory. - MemTransferInst *MTI = cast<MemTransferInst>(MI); - - Constant *Src = dyn_cast<Constant>(MTI->getSource()); - if (!Src) return -1; - - GlobalVariable *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(Src, DL)); - if (!GV || !GV->isConstant()) return -1; - - // See if the access is within the bounds of the transfer. - int Offset = AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, - MI->getDest(), MemSizeInBits, DL); - if (Offset == -1) - return Offset; - - unsigned AS = Src->getType()->getPointerAddressSpace(); - // Otherwise, see if we can constant fold a load from the constant with the - // offset applied as appropriate. - Src = ConstantExpr::getBitCast(Src, - Type::getInt8PtrTy(Src->getContext(), AS)); - Constant *OffsetCst = - ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset); - Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src, - OffsetCst); - Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); - if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL)) - return Offset; - return -1; -} - - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering store. This means -/// that the store provides bits used by the load but we the pointers don't -/// mustalias. Check this case to see if there is anything more we can do -/// before we give up. -static Value *GetStoreValueForLoad(Value *SrcVal, unsigned Offset, - Type *LoadTy, - Instruction *InsertPt, const DataLayout &DL){ - LLVMContext &Ctx = SrcVal->getType()->getContext(); - - uint64_t StoreSize = (DL.getTypeSizeInBits(SrcVal->getType()) + 7) / 8; - uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy) + 7) / 8; - - IRBuilder<> Builder(InsertPt); - - // Compute which bits of the stored value are being used by the load. Convert - // to an integer type to start with. 
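The memset case above in a small example (hypothetical IR, using the memset signature of this era, which still carries an alignment argument): any load fully inside the memset range reads the splatted fill byte, independent of its offset.

    call void @llvm.memset.p0i8.i64(i8* %p, i8 42, i64 16, i32 1, i1 false)
    %q = bitcast i8* %p to i32*
    %x = load i32, i32* %q      ; forwarded as the 4-byte splat 0x2A2A2A2A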
- if (SrcVal->getType()->getScalarType()->isPointerTy()) - SrcVal = Builder.CreatePtrToInt(SrcVal, - DL.getIntPtrType(SrcVal->getType())); - if (!SrcVal->getType()->isIntegerTy()) - SrcVal = Builder.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize*8)); - - // Shift the bits to the least significant depending on endianness. - unsigned ShiftAmt; - if (DL.isLittleEndian()) - ShiftAmt = Offset*8; - else - ShiftAmt = (StoreSize-LoadSize-Offset)*8; - - if (ShiftAmt) - SrcVal = Builder.CreateLShr(SrcVal, ShiftAmt); - - if (LoadSize != StoreSize) - SrcVal = Builder.CreateTrunc(SrcVal, IntegerType::get(Ctx, LoadSize*8)); - - return CoerceAvailableValueToLoadType(SrcVal, LoadTy, Builder, DL); -} - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering load. This means -/// that the load *may* provide bits used by the load but we can't be sure -/// because the pointers don't mustalias. Check this case to see if there is -/// anything more we can do before we give up. -static Value *GetLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, - Type *LoadTy, Instruction *InsertPt, - GVN &gvn) { - const DataLayout &DL = SrcVal->getModule()->getDataLayout(); - // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to - // widen SrcVal out to a larger load. - unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType()); - unsigned LoadSize = DL.getTypeStoreSize(LoadTy); - if (Offset+LoadSize > SrcValStoreSize) { - assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!"); - assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load"); - // If we have a load/load clobber an DepLI can be widened to cover this - // load, then we should widen it to the next power of 2 size big enough! - unsigned NewLoadSize = Offset+LoadSize; - if (!isPowerOf2_32(NewLoadSize)) - NewLoadSize = NextPowerOf2(NewLoadSize); - - Value *PtrVal = SrcVal->getPointerOperand(); - - // Insert the new load after the old load. This ensures that subsequent - // memdep queries will find the new load. We can't easily remove the old - // load completely because it is already in the value numbering table. - IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal)); - Type *DestPTy = - IntegerType::get(LoadTy->getContext(), NewLoadSize*8); - DestPTy = PointerType::get(DestPTy, - PtrVal->getType()->getPointerAddressSpace()); - Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc()); - PtrVal = Builder.CreateBitCast(PtrVal, DestPTy); - LoadInst *NewLoad = Builder.CreateLoad(PtrVal); - NewLoad->takeName(SrcVal); - NewLoad->setAlignment(SrcVal->getAlignment()); - - DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n"); - DEBUG(dbgs() << "TO: " << *NewLoad << "\n"); - - // Replace uses of the original load with the wider load. On a big endian - // system, we need to shift down to get the relevant bits. - Value *RV = NewLoad; - if (DL.isBigEndian()) - RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8); - RV = Builder.CreateTrunc(RV, SrcVal->getType()); - SrcVal->replaceAllUsesWith(RV); - - // We would like to use gvn.markInstructionForDeletion here, but we can't - // because the load is already memoized into the leader map table that GVN - // tracks. It is potentially possible to remove the load from the table, - // but then there all of the operations based on it would need to be - // rehashed. Just leave the dead load around. 
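The widening path above, sketched in hypothetical IR: when Offset+LoadSize overruns the clobbering load, that load is grown to the next power-of-two size, and its original uses are rewritten through a truncate (preceded by a lshr on big-endian targets).

    %a = load i8, i8* %p        ; SrcVal: 1 byte, but Offset+LoadSize = 2
    ; widened, NewLoadSize = 2:
    %wp = bitcast i8* %p to i16*
    %w  = load i16, i16* %wp    ; takes over %a's name and alignment
    %a2 = trunc i16 %w to i8    ; replaces all uses of %a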
-    gvn.getMemDep().removeInstruction(SrcVal);
-    SrcVal = NewLoad;
-  }
-
-  return GetStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL);
-}
-
-
-/// This function is called when we have a
-/// memdep query of a load that ends up being a clobbering mem intrinsic.
-static Value *GetMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset,
-                                     Type *LoadTy, Instruction *InsertPt,
-                                     const DataLayout &DL){
-  LLVMContext &Ctx = LoadTy->getContext();
-  uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy)/8;
-
-  IRBuilder<> Builder(InsertPt);
-
-  // We know that this method is only called when the mem transfer fully
-  // provides the bits for the load.
-  if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) {
-    // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and
-    // independently of what the offset is.
-    Value *Val = MSI->getValue();
-    if (LoadSize != 1)
-      Val = Builder.CreateZExt(Val, IntegerType::get(Ctx, LoadSize*8));
-
-    Value *OneElt = Val;
-
-    // Splat the value out to the right number of bits.
-    for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize; ) {
-      // If we can double the number of bytes set, do it.
-      if (NumBytesSet*2 <= LoadSize) {
-        Value *ShVal = Builder.CreateShl(Val, NumBytesSet*8);
-        Val = Builder.CreateOr(Val, ShVal);
-        NumBytesSet <<= 1;
-        continue;
-      }
-
-      // Otherwise insert one byte at a time.
-      Value *ShVal = Builder.CreateShl(Val, 1*8);
-      Val = Builder.CreateOr(OneElt, ShVal);
-      ++NumBytesSet;
-    }
-
-    return CoerceAvailableValueToLoadType(Val, LoadTy, Builder, DL);
-  }
-
-  // Otherwise, this is a memcpy/memmove from a constant global.
-  MemTransferInst *MTI = cast<MemTransferInst>(SrcInst);
-  Constant *Src = cast<Constant>(MTI->getSource());
-  unsigned AS = Src->getType()->getPointerAddressSpace();
-
-  // Otherwise, see if we can constant fold a load from the constant with the
-  // offset applied as appropriate.
-  Src = ConstantExpr::getBitCast(Src,
-                                 Type::getInt8PtrTy(Src->getContext(), AS));
-  Constant *OffsetCst =
-      ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset);
-  Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src,
-                                       OffsetCst);
-  Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS));
-  return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL);
-}
 
 /// Given a set of loads specified by ValuesPerBlock,
@@ -1171,7 +739,7 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI,
   if (isSimpleValue()) {
     Res = getSimpleValue();
     if (Res->getType() != LoadTy) {
-      Res = GetStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL);
+      Res = getStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL);
 
       DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset << "  "
                    << *getSimpleValue() << '\n'
@@ -1182,14 +750,20 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI,
     if (Load->getType() == LoadTy && Offset == 0) {
       Res = Load;
     } else {
-      Res = GetLoadValueForLoad(Load, Offset, LoadTy, InsertPt, gvn);
-
+      Res = getLoadValueForLoad(Load, Offset, LoadTy, InsertPt, DL);
+      // We would like to use gvn.markInstructionForDeletion here, but we can't
+      // because the load is already memoized into the leader map table that GVN
+      // tracks. It is potentially possible to remove the load from the table,
+      // but then all of the operations based on it would need to be
+      // rehashed. Just leave the dead load around.
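To make the splat loop in the deleted GetMemInstValueForLoad concrete: for fill byte 0x2A and a 4-byte load, Val starts as 0x2A, the first doubling step ors in (Val << 8) to give 0x2A2A, and the second ors in (Val << 16) to give 0x2A2A2A2A. The byte-at-a-time branch only runs when doubling would overshoot, e.g. for the final byte of a 3-byte load.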
+ gvn.getMemDep().removeInstruction(Load); DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset << " " << *getCoercedLoadValue() << '\n' - << *Res << '\n' << "\n\n\n"); + << *Res << '\n' + << "\n\n\n"); } } else if (isMemIntrinValue()) { - Res = GetMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy, + Res = getMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy, InsertPt, DL); DEBUG(dbgs() << "GVN COERCED NONLOCAL MEM INTRIN:\nOffset: " << Offset << " " << *getMemIntrinValue() << '\n' @@ -1258,7 +832,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // Can't forward from non-atomic to atomic without violating memory model. if (Address && LI->isAtomic() <= DepSI->isAtomic()) { int Offset = - AnalyzeLoadFromClobberingStore(LI->getType(), Address, DepSI); + analyzeLoadFromClobberingStore(LI->getType(), Address, DepSI, DL); if (Offset != -1) { Res = AvailableValue::get(DepSI->getValueOperand(), Offset); return true; @@ -1276,7 +850,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // Can't forward from non-atomic to atomic without violating memory model. if (DepLI != LI && Address && LI->isAtomic() <= DepLI->isAtomic()) { int Offset = - AnalyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL); + analyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL); if (Offset != -1) { Res = AvailableValue::getLoad(DepLI, Offset); @@ -1289,7 +863,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // forward a value on from it. if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInfo.getInst())) { if (Address && !LI->isAtomic()) { - int Offset = AnalyzeLoadFromClobberingMemInst(LI->getType(), Address, + int Offset = analyzeLoadFromClobberingMemInst(LI->getType(), Address, DepMI, DL); if (Offset != -1) { Res = AvailableValue::getMI(DepMI, Offset); @@ -1334,7 +908,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // different types if we have to. If the stored value is larger or equal to // the loaded value, we can reuse it. if (S->getValueOperand()->getType() != LI->getType() && - !CanCoerceMustAliasedValueToLoad(S->getValueOperand(), + !canCoerceMustAliasedValueToLoad(S->getValueOperand(), LI->getType(), DL)) return false; @@ -1351,7 +925,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // If the stored value is larger or equal to the loaded value, we can reuse // it. if (LD->getType() != LI->getType() && - !CanCoerceMustAliasedValueToLoad(LD, LI->getType(), DL)) + !canCoerceMustAliasedValueToLoad(LD, LI->getType(), DL)) return false; // Can't forward from non-atomic to atomic without violating memory model. @@ -1592,8 +1166,9 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, auto *NewLoad = new LoadInst(LoadPtr, LI->getName()+".pre", LI->isVolatile(), LI->getAlignment(), - LI->getOrdering(), LI->getSynchScope(), + LI->getOrdering(), LI->getSyncScopeID(), UnavailablePred->getTerminator()); + NewLoad->setDebugLoc(LI->getDebugLoc()); // Transfer the old load's AA tags to the new load. 
AAMDNodes Tags; @@ -1628,7 +1203,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, V->takeName(LI); if (Instruction *I = dyn_cast<Instruction>(V)) I->setDebugLoc(LI->getDebugLoc()); - if (V->getType()->getScalarType()->isPointerTy()) + if (V->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(V); markInstructionForDeletion(LI); ORE->emit(OptimizationRemark(DEBUG_TYPE, "LoadPRE", LI) @@ -1713,9 +1288,9 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { // If instruction I has debug info, then we should not update it. // Also, if I has a null DebugLoc, then it is still potentially incorrect // to propagate LI's DebugLoc because LI may not post-dominate I. - if (LI->getDebugLoc() && ValuesPerBlock.size() != 1) + if (LI->getDebugLoc() && LI->getParent() == I->getParent()) I->setDebugLoc(LI->getDebugLoc()); - if (V->getType()->getScalarType()->isPointerTy()) + if (V->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(V); markInstructionForDeletion(LI); ++NumGVNLoad; @@ -1795,7 +1370,7 @@ static void patchReplacementInstruction(Instruction *I, Value *Repl) { // Patch the replacement so that it is not more restrictive than the value // being replaced. - // Note that if 'I' is a load being replaced by some operation, + // Note that if 'I' is a load being replaced by some operation, // for example, by an arithmetic operation, then andIRFlags() // would just erase all math flags from the original arithmetic // operation, which is clearly not wanted and not needed. @@ -1869,7 +1444,7 @@ bool GVN::processLoad(LoadInst *L) { reportLoadElim(L, AvailableValue, ORE); // Tell MDA to rexamine the reused pointer since we might have more // information after forwarding it. - if (MD && AvailableValue->getType()->getScalarType()->isPointerTy()) + if (MD && AvailableValue->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(AvailableValue); return true; } @@ -2024,7 +1599,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, // RHS neither 'true' nor 'false' - bail out. continue; // Whether RHS equals 'true'. Otherwise it equals 'false'. - bool isKnownTrue = CI->isAllOnesValue(); + bool isKnownTrue = CI->isMinusOne(); bool isKnownFalse = !isKnownTrue; // If "A && B" is known true then both A and B are known true. If "A || B" @@ -2113,7 +1688,7 @@ bool GVN::processInstruction(Instruction *I) { // example if it determines that %y is equal to %x then the instruction // "%z = and i32 %x, %y" becomes "%z = and i32 %x, %x" which we now simplify. const DataLayout &DL = I->getModule()->getDataLayout(); - if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { + if (Value *V = SimplifyInstruction(I, {DL, TLI, DT, AC})) { bool Changed = false; if (!I->use_empty()) { I->replaceAllUsesWith(V); @@ -2124,7 +1699,7 @@ bool GVN::processInstruction(Instruction *I) { Changed = true; } if (Changed) { - if (MD && V->getType()->getScalarType()->isPointerTy()) + if (MD && V->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(V); ++NumGVNSimpl; return true; @@ -2187,11 +1762,11 @@ bool GVN::processInstruction(Instruction *I) { for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) { - BasicBlock *Dst = i.getCaseSuccessor(); + BasicBlock *Dst = i->getCaseSuccessor(); // If there is only a single edge, propagate the case value into it. 
if (SwitchEdges.lookup(Dst) == 1) { BasicBlockEdge E(Parent, Dst); - Changed |= propagateEquality(SwitchCond, i.getCaseValue(), E, true); + Changed |= propagateEquality(SwitchCond, i->getCaseValue(), E, true); } } return Changed; @@ -2235,7 +1810,7 @@ bool GVN::processInstruction(Instruction *I) { // Remove it! patchAndReplaceAllUsesWith(I, Repl); - if (MD && Repl->getType()->getScalarType()->isPointerTy()) + if (MD && Repl->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(Repl); markInstructionForDeletion(I); return true; @@ -2483,7 +2058,7 @@ bool GVN::performScalarPRE(Instruction *CurInst) { if (!performScalarPREInsertion(PREInstr, PREPred, ValNo)) { // If we failed insertion, make sure we remove the instruction. DEBUG(verifyRemoved(PREInstr)); - delete PREInstr; + PREInstr->deleteValue(); return false; } } @@ -2509,7 +2084,7 @@ bool GVN::performScalarPRE(Instruction *CurInst) { addToLeaderTable(ValNo, Phi, CurrentBlock); Phi->setDebugLoc(CurInst->getDebugLoc()); CurInst->replaceAllUsesWith(Phi); - if (MD && Phi->getType()->getScalarType()->isPointerTy()) + if (MD && Phi->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(Phi); VN.erase(CurInst); removeFromLeaderTable(ValNo, CurInst, CurrentBlock); @@ -2581,21 +2156,12 @@ bool GVN::iterateOnFunction(Function &F) { // Top-down walk of the dominator tree bool Changed = false; - // Save the blocks this function have before transformation begins. GVN may - // split critical edge, and hence may invalidate the RPO/DT iterator. - // - std::vector<BasicBlock *> BBVect; - BBVect.reserve(256); // Needed for value numbering with phi construction to work. + // RPOT walks the graph in its constructor and will not be invalidated during + // processBlock. ReversePostOrderTraversal<Function *> RPOT(&F); - for (ReversePostOrderTraversal<Function *>::rpo_iterator RI = RPOT.begin(), - RE = RPOT.end(); - RI != RE; ++RI) - BBVect.push_back(*RI); - - for (std::vector<BasicBlock *>::iterator I = BBVect.begin(), E = BBVect.end(); - I != E; I++) - Changed |= processBlock(*I); + for (BasicBlock *BB : RPOT) + Changed |= processBlock(BB); return Changed; } @@ -2783,6 +2349,7 @@ public: AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<TargetLibraryInfoWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp index f8e1d2e..29de792 100644 --- a/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -17,16 +17,40 @@ // is disabled in the following cases. // 1. Scalars across calls. // 2. geps when corresponding load/store cannot be hoisted. +// +// TODO: Hoist from >2 successors. Currently GVNHoist will not hoist stores +// in this case because it works on two instructions at a time. 
+// entry:
+//   switch i32 %c1, label %exit1 [
+//     i32 0, label %sw0
+//     i32 1, label %sw1
+//   ]
+//
+// sw0:
+//   store i32 1, i32* @G
+//   br label %exit
+//
+// sw1:
+//   store i32 1, i32* @G
+//   br label %exit
+//
+// exit1:
+//   store i32 1, i32* @G
+//   ret void
+// exit:
+//   ret void
 //===----------------------------------------------------------------------===//
 
-#include "llvm/Transforms/Scalar/GVN.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/GVN.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Utils/MemorySSA.h"
 
 using namespace llvm;
@@ -60,7 +84,7 @@ static cl::opt<int>
              cl::desc("Maximum length of dependent chains to hoist "
                       "(default = 10, unlimited = -1)"));
 
-namespace {
+namespace llvm {
 
 // Provides a sorting function based on the execution order of two instructions.
 struct SortByDFSIn {
@@ -72,13 +96,6 @@ public:
 
   // Returns true when A executes before B.
   bool operator()(const Instruction *A, const Instruction *B) const {
-    // FIXME: libc++ has a std::sort() algorithm that will call the compare
-    // function on the same element. Once PR20837 is fixed and some more years
-    // pass by and all the buildbots have moved to a corrected std::sort(),
-    // enable the following assert:
-    //
-    // assert(A != B);
-
     const BasicBlock *BA = A->getParent();
     const BasicBlock *BB = B->getParent();
     unsigned ADFS, BDFS;
@@ -202,6 +219,7 @@ public:
   GVNHoist(DominatorTree *DT, AliasAnalysis *AA, MemoryDependenceResults *MD,
            MemorySSA *MSSA)
       : DT(DT), AA(AA), MD(MD), MSSA(MSSA),
+        MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)),
        HoistingGeps(false), HoistedCtr(0) { }
 
@@ -249,9 +267,11 @@ private:
   AliasAnalysis *AA;
   MemoryDependenceResults *MD;
   MemorySSA *MSSA;
+  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
   const bool HoistingGeps;
   DenseMap<const Value *, unsigned> DFSNumber;
   BBSideEffectsSet BBSideEffects;
+  DenseSet<const BasicBlock*> HoistBarrier;
   int HoistedCtr;
 
   enum InsKind { Unknown, Scalar, Load, Store };
@@ -307,8 +327,8 @@ private:
         continue;
       }
 
-      // Check for end of function, calls that do not return, etc.
-      if (!isGuaranteedToTransferExecutionToSuccessor(BB->getTerminator()))
+      // We reached a leaf basic block => not all paths have this instruction.
+      if (!BB->getTerminator()->getNumSuccessors())
         return false;
 
       // When reaching the back-edge of a loop, there may be a path through the
@@ -360,7 +380,7 @@ private:
           ReachedNewPt = true;
         }
       }
-      if (defClobbersUseOrDef(Def, MU, *AA))
+      if (MemorySSAUtil::defClobbersUseOrDef(Def, MU, *AA))
         return true;
     }
 
@@ -387,7 +407,8 @@ private:
     // executed between the execution of NewBB and OldBB. Hoisting an expression
     // from OldBB into NewBB has to be safe on all execution paths.
     for (auto I = idf_begin(OldBB), E = idf_end(OldBB); I != E;) {
-      if (*I == NewBB) {
+      const BasicBlock *BB = *I;
+      if (BB == NewBB) {
         // Stop traversal when reaching HoistPt.
         I.skipChildren();
         continue;
@@ -398,11 +419,17 @@ private:
         return true;
 
       // Impossible to hoist with exceptions on the path.
-      if (hasEH(*I))
+      if (hasEH(BB))
+        return true;
+
+      // A hoist barrier anywhere on the path, other than in OldBB itself,
+      // blocks hoisting. OldBB is exempt because instructions are only
+      // selected for hoisting from the part of a block that precedes its
+      // barrier.
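A minimal case (hypothetical IR; @may_not_return is an illustrative external function) of the hoist barriers introduced in this change: a call that is not guaranteed to transfer execution to its successor turns its block into a barrier, so scanning of the block stops there and nothing may be hoisted across it, while instructions before the barrier remain hoisting candidates.

    bb:
      %x = add i32 %a, %b          ; precedes the barrier: still a candidate
      call void @may_not_return()  ; barrier: the scan of bb stops here
      %y = add i32 %a, %b          ; never considered, and no instruction from
                                   ; blocks reached through bb is hoisted past bb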
+      if ((BB != OldBB) && HoistBarrier.count(BB))
         return true;
 
       // Check that we do not move a store past loads.
-      if (hasMemoryUse(NewPt, Def, *I))
+      if (hasMemoryUse(NewPt, Def, BB))
         return true;
 
       // -1 is unlimited number of blocks on all paths.
@@ -419,17 +446,18 @@ private:
   // Decrement by 1 NBBsOnAllPaths for each block between HoistPt and BB, and
   // return true when the counter NBBsOnAllPaths reaches 0, except when it is
   // initialized to -1 which is unlimited.
-  bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *BB,
+  bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *SrcBB,
                    int &NBBsOnAllPaths) {
-    assert(DT->dominates(HoistPt, BB) && "Invalid path");
+    assert(DT->dominates(HoistPt, SrcBB) && "Invalid path");
 
     // Walk all basic blocks reachable in depth-first iteration on
     // the inverse CFG from BBInsn to NewHoistPt. These blocks are all the
     // blocks that may be executed between the execution of NewHoistPt and
     // BBInsn. Hoisting an expression from BBInsn into NewHoistPt has to be safe
    // on all execution paths.
-    for (auto I = idf_begin(BB), E = idf_end(BB); I != E;) {
-      if (*I == HoistPt) {
+    for (auto I = idf_begin(SrcBB), E = idf_end(SrcBB); I != E;) {
+      const BasicBlock *BB = *I;
+      if (BB == HoistPt) {
         // Stop traversal when reaching NewHoistPt.
         I.skipChildren();
         continue;
@@ -440,7 +468,13 @@ private:
         return true;
 
       // Impossible to hoist with exceptions on the path.
-      if (hasEH(*I))
+      if (hasEH(BB))
+        return true;
+
+      // A hoist barrier anywhere on the path, other than in SrcBB itself,
+      // blocks hoisting. SrcBB is exempt because instructions are only
+      // selected for hoisting from the part of a block that precedes its
+      // barrier.
+      if ((BB != SrcBB) && HoistBarrier.count(BB))
         return true;
 
       // -1 is unlimited number of blocks on all paths.
@@ -626,6 +660,8 @@ private:
     // Compute the insertion point and the list of expressions to be hoisted.
     SmallVecInsn InstructionsToHoist;
     for (auto I : V)
+      // We don't need to check for hoist barriers here: if I->getParent()
+      // contains a barrier, then I precedes that barrier.
       if (!hasEH(I->getParent()))
         InstructionsToHoist.push_back(I);
 
@@ -809,9 +845,9 @@ private:
         // legal when the ld/st is not moved past its current definition.
         MemoryAccess *Def = OldMemAcc->getDefiningAccess();
         NewMemAcc =
-            MSSA->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End);
+            MSSAUpdater->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End);
         OldMemAcc->replaceAllUsesWith(NewMemAcc);
-        MSSA->removeMemoryAccess(OldMemAcc);
+        MSSAUpdater->removeMemoryAccess(OldMemAcc);
       }
     }
 
@@ -850,7 +886,7 @@ private:
           // Update the uses of the old MSSA access with NewMemAcc.
           MemoryAccess *OldMA = MSSA->getMemoryAccess(I);
           OldMA->replaceAllUsesWith(NewMemAcc);
-          MSSA->removeMemoryAccess(OldMA);
+          MSSAUpdater->removeMemoryAccess(OldMA);
         }
 
         Repl->andIRFlags(I);
@@ -872,7 +908,7 @@ private:
         auto In = Phi->incoming_values();
         if (all_of(In, [&](Use &U) { return U == NewMemAcc; })) {
           Phi->replaceAllUsesWith(NewMemAcc);
-          MSSA->removeMemoryAccess(Phi);
+          MSSAUpdater->removeMemoryAccess(Phi);
         }
       }
     }
@@ -896,6 +932,12 @@ private:
     for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {
       int InstructionNb = 0;
       for (Instruction &I1 : *BB) {
+        // If I1 cannot guarantee progress, subsequent instructions
+        // in BB cannot be hoisted anyway.
+        if (!isGuaranteedToTransferExecutionToSuccessor(&I1)) {
+          HoistBarrier.insert(BB);
+          break;
+        }
         // Only hoist the first instructions in BB up to MaxDepthInBB. Hoisting
         // deeper may increase the register pressure and compilation time.
if (MaxDepthInBB != -1 && InstructionNb++ >= MaxDepthInBB) @@ -969,6 +1011,7 @@ public: AU.addRequired<MemorySSAWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<MemorySSAWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); } }; } // namespace @@ -985,6 +1028,7 @@ PreservedAnalyses GVNHoistPass::run(Function &F, FunctionAnalysisManager &AM) { PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); PA.preserve<MemorySSAAnalysis>(); + PA.preserve<GlobalsAA>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/GVNSink.cpp b/contrib/llvm/lib/Transforms/Scalar/GVNSink.cpp new file mode 100644 index 0000000..5fd2dfc --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -0,0 +1,883 @@ +//===- GVNSink.cpp - sink expressions into successors -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// \file GVNSink.cpp +/// This pass attempts to sink instructions into successors, reducing static +/// instruction count and enabling if-conversion. +/// +/// We use a variant of global value numbering to decide what can be sunk. +/// Consider: +/// +/// [ %a1 = add i32 %b, 1 ] [ %c1 = add i32 %d, 1 ] +/// [ %a2 = xor i32 %a1, 1 ] [ %c2 = xor i32 %c1, 1 ] +/// \ / +/// [ %e = phi i32 %a2, %c2 ] +/// [ add i32 %e, 4 ] +/// +/// +/// GVN would number %a1 and %c1 differently because they compute different +/// results - the VN of an instruction is a function of its opcode and the +/// transitive closure of its operands. This is the key property for hoisting +/// and CSE. +/// +/// What we want when sinking however is for a numbering that is a function of +/// the *uses* of an instruction, which allows us to answer the question "if I +/// replace %a1 with %c1, will it contribute in an equivalent way to all +/// successive instructions?". The PostValueTable class in GVN provides this +/// mapping. 
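Concretely, in the diagram above a use-based numbering can give VN(%a1) == VN(%c1): each is consumed only by an xor-by-1 whose result feeds the same PHI, so replacing %a1 with %c1 after sinking leaves every downstream computation equivalent, whereas operand-based GVN must number them differently because %b and %d differ.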
+/// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/GVNExpression.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include <unordered_set> +using namespace llvm; + +#define DEBUG_TYPE "gvn-sink" + +STATISTIC(NumRemoved, "Number of instructions removed"); + +namespace llvm { +namespace GVNExpression { + +LLVM_DUMP_METHOD void Expression::dump() const { + print(dbgs()); + dbgs() << "\n"; +} + +} +} + +namespace { + +static bool isMemoryInst(const Instruction *I) { + return isa<LoadInst>(I) || isa<StoreInst>(I) || + (isa<InvokeInst>(I) && !cast<InvokeInst>(I)->doesNotAccessMemory()) || + (isa<CallInst>(I) && !cast<CallInst>(I)->doesNotAccessMemory()); +} + +/// Iterates through instructions in a set of blocks in reverse order from the +/// first non-terminator. For example (assume all blocks have size n): +/// LockstepReverseIterator I([B1, B2, B3]); +/// *I-- = [B1[n], B2[n], B3[n]]; +/// *I-- = [B1[n-1], B2[n-1], B3[n-1]]; +/// *I-- = [B1[n-2], B2[n-2], B3[n-2]]; +/// ... +/// +/// It continues until all blocks have been exhausted. Use \c getActiveBlocks() +/// to +/// determine which blocks are still going and the order they appear in the +/// list returned by operator*. +class LockstepReverseIterator { + ArrayRef<BasicBlock *> Blocks; + SmallPtrSet<BasicBlock *, 4> ActiveBlocks; + SmallVector<Instruction *, 4> Insts; + bool Fail; + +public: + LockstepReverseIterator(ArrayRef<BasicBlock *> Blocks) : Blocks(Blocks) { + reset(); + } + + void reset() { + Fail = false; + ActiveBlocks.clear(); + for (BasicBlock *BB : Blocks) + ActiveBlocks.insert(BB); + Insts.clear(); + for (BasicBlock *BB : Blocks) { + if (BB->size() <= 1) { + // Block wasn't big enough - only contained a terminator. 
+ ActiveBlocks.erase(BB); + continue; + } + Insts.push_back(BB->getTerminator()->getPrevNode()); + } + if (Insts.empty()) + Fail = true; + } + + bool isValid() const { return !Fail; } + ArrayRef<Instruction *> operator*() const { return Insts; } + SmallPtrSet<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; } + + void restrictToBlocks(SmallPtrSetImpl<BasicBlock *> &Blocks) { + for (auto II = Insts.begin(); II != Insts.end();) { + if (std::find(Blocks.begin(), Blocks.end(), (*II)->getParent()) == + Blocks.end()) { + ActiveBlocks.erase((*II)->getParent()); + II = Insts.erase(II); + } else { + ++II; + } + } + } + + void operator--() { + if (Fail) + return; + SmallVector<Instruction *, 4> NewInsts; + for (auto *Inst : Insts) { + if (Inst == &Inst->getParent()->front()) + ActiveBlocks.erase(Inst->getParent()); + else + NewInsts.push_back(Inst->getPrevNode()); + } + if (NewInsts.empty()) { + Fail = true; + return; + } + Insts = NewInsts; + } +}; + +//===----------------------------------------------------------------------===// + +/// Candidate solution for sinking. There may be different ways to +/// sink instructions, differing in the number of instructions sunk, +/// the number of predecessors sunk from and the number of PHIs +/// required. +struct SinkingInstructionCandidate { + unsigned NumBlocks; + unsigned NumInstructions; + unsigned NumPHIs; + unsigned NumMemoryInsts; + int Cost = -1; + SmallVector<BasicBlock *, 4> Blocks; + + void calculateCost(unsigned NumOrigPHIs, unsigned NumOrigBlocks) { + unsigned NumExtraPHIs = NumPHIs - NumOrigPHIs; + unsigned SplitEdgeCost = (NumOrigBlocks > NumBlocks) ? 2 : 0; + Cost = (NumInstructions * (NumBlocks - 1)) - + (NumExtraPHIs * + NumExtraPHIs) // PHIs are expensive, so make sure they're worth it. + - SplitEdgeCost; + } + bool operator>(const SinkingInstructionCandidate &Other) const { + return Cost > Other.Cost; + } +}; + +#ifndef NDEBUG +llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + const SinkingInstructionCandidate &C) { + OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks + << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">"; + return OS; +} +#endif + +//===----------------------------------------------------------------------===// + +/// Describes a PHI node that may or may not exist. These track the PHIs +/// that must be created if we sunk a sequence of instructions. It provides +/// a hash function for efficient equality comparisons. +class ModelledPHI { + SmallVector<Value *, 4> Values; + SmallVector<BasicBlock *, 4> Blocks; + +public: + ModelledPHI() {} + ModelledPHI(const PHINode *PN) { + for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) + Blocks.push_back(PN->getIncomingBlock(I)); + std::sort(Blocks.begin(), Blocks.end()); + + // This assumes the PHI is already well-formed and there aren't conflicting + // incoming values for the same block. + for (auto *B : Blocks) + Values.push_back(PN->getIncomingValueForBlock(B)); + } + /// Create a dummy ModelledPHI that will compare unequal to any other ModelledPHI + /// without the same ID. + /// \note This is specifically for DenseMapInfo - do not use this! + static ModelledPHI createDummy(size_t ID) { + ModelledPHI M; + M.Values.push_back(reinterpret_cast<Value*>(ID)); + return M; + } + + /// Create a PHI from an array of incoming values and incoming blocks. 
+ template <typename VArray, typename BArray> + ModelledPHI(const VArray &V, const BArray &B) { + std::copy(V.begin(), V.end(), std::back_inserter(Values)); + std::copy(B.begin(), B.end(), std::back_inserter(Blocks)); + } + + /// Create a PHI from [I[OpNum] for I in Insts]. + template <typename BArray> + ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) { + std::copy(B.begin(), B.end(), std::back_inserter(Blocks)); + for (auto *I : Insts) + Values.push_back(I->getOperand(OpNum)); + } + + /// Restrict the PHI's contents down to only \c NewBlocks. + /// \c NewBlocks must be a subset of \c this->Blocks. + void restrictToBlocks(const SmallPtrSetImpl<BasicBlock *> &NewBlocks) { + auto BI = Blocks.begin(); + auto VI = Values.begin(); + while (BI != Blocks.end()) { + assert(VI != Values.end()); + if (std::find(NewBlocks.begin(), NewBlocks.end(), *BI) == + NewBlocks.end()) { + BI = Blocks.erase(BI); + VI = Values.erase(VI); + } else { + ++BI; + ++VI; + } + } + assert(Blocks.size() == NewBlocks.size()); + } + + ArrayRef<Value *> getValues() const { return Values; } + + bool areAllIncomingValuesSame() const { + return all_of(Values, [&](Value *V) { return V == Values[0]; }); + } + bool areAllIncomingValuesSameType() const { + return all_of( + Values, [&](Value *V) { return V->getType() == Values[0]->getType(); }); + } + bool areAnyIncomingValuesConstant() const { + return any_of(Values, [&](Value *V) { return isa<Constant>(V); }); + } + // Hash functor + unsigned hash() const { + return (unsigned)hash_combine_range(Values.begin(), Values.end()); + } + bool operator==(const ModelledPHI &Other) const { + return Values == Other.Values && Blocks == Other.Blocks; + } +}; + +template <typename ModelledPHI> struct DenseMapInfo { + static inline ModelledPHI &getEmptyKey() { + static ModelledPHI Dummy = ModelledPHI::createDummy(0); + return Dummy; + } + static inline ModelledPHI &getTombstoneKey() { + static ModelledPHI Dummy = ModelledPHI::createDummy(1); + return Dummy; + } + static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); } + static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) { + return LHS == RHS; + } +}; + +typedef DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>> ModelledPHISet; + +//===----------------------------------------------------------------------===// +// ValueTable +//===----------------------------------------------------------------------===// +// This is a value number table where the value number is a function of the +// *uses* of a value, rather than its operands. Thus, if VN(A) == VN(B) we know +// that the program would be equivalent if we replaced A with PHI(A, B). +//===----------------------------------------------------------------------===// + +/// A GVN expression describing how an instruction is used. The operands +/// field of BasicExpression is used to store uses, not operands. +/// +/// This class also contains fields for discriminators used when determining +/// equivalence of instructions with sideeffects. 
+class InstructionUseExpr : public GVNExpression::BasicExpression { + unsigned MemoryUseOrder = -1; + bool Volatile = false; + +public: + InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R, + BumpPtrAllocator &A) + : GVNExpression::BasicExpression(I->getNumUses()) { + allocateOperands(R, A); + setOpcode(I->getOpcode()); + setType(I->getType()); + + for (auto &U : I->uses()) + op_push_back(U.getUser()); + std::sort(op_begin(), op_end()); + } + void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; } + void setVolatile(bool V) { Volatile = V; } + + virtual hash_code getHashValue() const { + return hash_combine(GVNExpression::BasicExpression::getHashValue(), + MemoryUseOrder, Volatile); + } + + template <typename Function> hash_code getHashValue(Function MapFn) { + hash_code H = + hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile); + for (auto *V : operands()) + H = hash_combine(H, MapFn(V)); + return H; + } +}; + +class ValueTable { + DenseMap<Value *, uint32_t> ValueNumbering; + DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering; + DenseMap<size_t, uint32_t> HashNumbering; + BumpPtrAllocator Allocator; + ArrayRecycler<Value *> Recycler; + uint32_t nextValueNumber; + + /// Create an expression for I based on its opcode and its uses. If I + /// touches or reads memory, the expression is also based upon its memory + /// order - see \c getMemoryUseOrder(). + InstructionUseExpr *createExpr(Instruction *I) { + InstructionUseExpr *E = + new (Allocator) InstructionUseExpr(I, Recycler, Allocator); + if (isMemoryInst(I)) + E->setMemoryUseOrder(getMemoryUseOrder(I)); + + if (CmpInst *C = dyn_cast<CmpInst>(I)) { + CmpInst::Predicate Predicate = C->getPredicate(); + E->setOpcode((C->getOpcode() << 8) | Predicate); + } + return E; + } + + /// Helper to compute the value number for a memory instruction + /// (LoadInst/StoreInst), including checking the memory ordering and + /// volatility. + template <class Inst> InstructionUseExpr *createMemoryExpr(Inst *I) { + if (isStrongerThanUnordered(I->getOrdering()) || I->isAtomic()) + return nullptr; + InstructionUseExpr *E = createExpr(I); + E->setVolatile(I->isVolatile()); + return E; + } + +public: + /// Returns the value number for the specified value, assigning + /// it a new number if it did not have one before. 
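+ /// Non-instruction values and instructions with unhandled opcodes always
+ /// receive a fresh number; only the opcodes listed in the switch below may
+ /// share a number with another value.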
+ uint32_t lookupOrAdd(Value *V) { + auto VI = ValueNumbering.find(V); + if (VI != ValueNumbering.end()) + return VI->second; + + if (!isa<Instruction>(V)) { + ValueNumbering[V] = nextValueNumber; + return nextValueNumber++; + } + + Instruction *I = cast<Instruction>(V); + InstructionUseExpr *exp = nullptr; + switch (I->getOpcode()) { + case Instruction::Load: + exp = createMemoryExpr(cast<LoadInst>(I)); + break; + case Instruction::Store: + exp = createMemoryExpr(cast<StoreInst>(I)); + break; + case Instruction::Call: + case Instruction::Invoke: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::BitCast: + case Instruction::Select: + case Instruction::ExtractElement: + case Instruction::InsertElement: + case Instruction::ShuffleVector: + case Instruction::InsertValue: + case Instruction::GetElementPtr: + exp = createExpr(I); + break; + default: + break; + } + + if (!exp) { + ValueNumbering[V] = nextValueNumber; + return nextValueNumber++; + } + + uint32_t e = ExpressionNumbering[exp]; + if (!e) { + hash_code H = exp->getHashValue([=](Value *V) { return lookupOrAdd(V); }); + auto I = HashNumbering.find(H); + if (I != HashNumbering.end()) { + e = I->second; + } else { + e = nextValueNumber++; + HashNumbering[H] = e; + ExpressionNumbering[exp] = e; + } + } + ValueNumbering[V] = e; + return e; + } + + /// Returns the value number of the specified value. Fails if the value has + /// not yet been numbered. + uint32_t lookup(Value *V) const { + auto VI = ValueNumbering.find(V); + assert(VI != ValueNumbering.end() && "Value not numbered?"); + return VI->second; + } + + /// Removes all value numberings and resets the value table. + void clear() { + ValueNumbering.clear(); + ExpressionNumbering.clear(); + HashNumbering.clear(); + Recycler.clear(Allocator); + nextValueNumber = 1; + } + + ValueTable() : nextValueNumber(1) {} + + /// \c Inst uses or touches memory. Return an ID describing the memory state + /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2), + /// the exact same memory operations happen after I1 and I2. + /// + /// This is a very hard problem in general, so we use domain-specific + /// knowledge that we only ever check for equivalence between blocks sharing a + /// single immediate successor that is common, and when determining if I1 == + /// I2 we will have already determined that next(I1) == next(I2). This + /// inductive property allows us to simply return the value number of the next + /// instruction that defines memory. 
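+ /// For example, in a block containing only the two stores
+ /// store i32 %a, i32* %p ; S1
+ /// store i32 %b, i32* %q ; S2
+ /// getMemoryUseOrder(S1) is the value number of S2, and getMemoryUseOrder(S2)
+ /// is 0 because no memory-defining instruction follows it.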
+ uint32_t getMemoryUseOrder(Instruction *Inst) { + auto *BB = Inst->getParent(); + for (auto I = std::next(Inst->getIterator()), E = BB->end(); + I != E && !I->isTerminator(); ++I) { + if (!isMemoryInst(&*I)) + continue; + if (isa<LoadInst>(&*I)) + continue; + CallInst *CI = dyn_cast<CallInst>(&*I); + if (CI && CI->onlyReadsMemory()) + continue; + InvokeInst *II = dyn_cast<InvokeInst>(&*I); + if (II && II->onlyReadsMemory()) + continue; + return lookupOrAdd(&*I); + } + return 0; + } +}; + +//===----------------------------------------------------------------------===// + +class GVNSink { +public: + GVNSink() : VN() {} + bool run(Function &F) { + DEBUG(dbgs() << "GVNSink: running on function @" << F.getName() << "\n"); + + unsigned NumSunk = 0; + ReversePostOrderTraversal<Function*> RPOT(&F); + for (auto *N : RPOT) + NumSunk += sinkBB(N); + + return NumSunk > 0; + } + +private: + ValueTable VN; + + bool isInstructionBlacklisted(Instruction *I) { + // These instructions may change or break semantics if moved. + if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) || + I->getType()->isTokenTy()) + return true; + return false; + } + + /// The main heuristic function. Analyze the set of instructions pointed to by + /// LRI and return a candidate solution if these instructions can be sunk, or + /// None otherwise. + Optional<SinkingInstructionCandidate> analyzeInstructionForSinking( + LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum, + ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents); + + /// Create a ModelledPHI for each PHI in BB, adding to PHIs. + void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs, + SmallPtrSetImpl<Value *> &PHIContents) { + for (auto &I : *BB) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) + return; + + auto MPHI = ModelledPHI(PN); + PHIs.insert(MPHI); + for (auto *V : MPHI.getValues()) + PHIContents.insert(V); + } + } + + /// The main instruction sinking driver. Set up state and try and sink + /// instructions into BBEnd from its predecessors. + unsigned sinkBB(BasicBlock *BBEnd); + + /// Perform the actual mechanics of sinking an instruction from Blocks into + /// BBEnd, which is their only successor. + void sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd); + + /// Remove PHIs that all have the same incoming value. 
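+ /// For example, "%p = phi i32 [ %v, %bb1 ], [ %v, %bb2 ]" is replaced by %v;
+ /// a PHI whose only incoming value is itself is replaced by undef.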
+ void foldPointlessPHINodes(BasicBlock *BB) {
+ auto I = BB->begin();
+ while (PHINode *PN = dyn_cast<PHINode>(I++)) {
+ if (!all_of(PN->incoming_values(),
+ [&](const Value *V) { return V == PN->getIncomingValue(0); }))
+ continue;
+ if (PN->getIncomingValue(0) != PN)
+ PN->replaceAllUsesWith(PN->getIncomingValue(0));
+ else
+ PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
+ PN->eraseFromParent();
+ }
+ }
+};
+
+Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking(
+ LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
+ ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents) {
+ auto Insts = *LRI;
+ DEBUG(dbgs() << " -- Analyzing instruction set: [\n"; for (auto *I
+ : Insts) {
+ I->dump();
+ } dbgs() << " ]\n";);
+
+ DenseMap<uint32_t, unsigned> VNums;
+ for (auto *I : Insts) {
+ uint32_t N = VN.lookupOrAdd(I);
+ DEBUG(dbgs() << " VN=" << utohexstr(N) << " for" << *I << "\n");
+ if (N == ~0U)
+ return None;
+ VNums[N]++;
+ }
+ unsigned VNumToSink =
+ std::max_element(VNums.begin(), VNums.end(),
+ [](const std::pair<uint32_t, unsigned> &I,
+ const std::pair<uint32_t, unsigned> &J) {
+ return I.second < J.second;
+ })
+ ->first;
+
+ if (VNums[VNumToSink] == 1)
+ // Can't sink anything!
+ return None;
+
+ // Now restrict the number of incoming blocks down to only those with
+ // VNumToSink.
+ auto &ActivePreds = LRI.getActiveBlocks();
+ unsigned InitialActivePredSize = ActivePreds.size();
+ SmallVector<Instruction *, 4> NewInsts;
+ for (auto *I : Insts) {
+ if (VN.lookup(I) != VNumToSink)
+ ActivePreds.erase(I->getParent());
+ else
+ NewInsts.push_back(I);
+ }
+ for (auto *I : NewInsts)
+ if (isInstructionBlacklisted(I))
+ return None;
+
+ // If we've restricted the incoming blocks, restrict all needed PHIs also
+ // to that set.
+ bool RecomputePHIContents = false;
+ if (ActivePreds.size() != InitialActivePredSize) {
+ ModelledPHISet NewNeededPHIs;
+ for (auto P : NeededPHIs) {
+ P.restrictToBlocks(ActivePreds);
+ NewNeededPHIs.insert(P);
+ }
+ NeededPHIs = NewNeededPHIs;
+ LRI.restrictToBlocks(ActivePreds);
+ RecomputePHIContents = true;
+ }
+
+ // The sunk instruction's results.
+ ModelledPHI NewPHI(NewInsts, ActivePreds);
+
+ // Does sinking this instruction render previous PHIs redundant?
+ if (NeededPHIs.find(NewPHI) != NeededPHIs.end()) {
+ NeededPHIs.erase(NewPHI);
+ RecomputePHIContents = true;
+ }
+
+ if (RecomputePHIContents) {
+ // The needed PHIs have changed, so recompute the set of all needed
+ // values.
+ PHIContents.clear();
+ for (auto &PHI : NeededPHIs)
+ PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
+ }
+
+ // Is this instruction required by a later PHI that doesn't match this PHI?
+ // If so, we can't sink this instruction.
+ for (auto *V : NewPHI.getValues())
+ if (PHIContents.count(V))
+ // V exists in this PHI, but the whole PHI is different to NewPHI
+ // (else it would have been removed earlier). We cannot continue
+ // because this isn't representable.
+ return None;
+
+ // Which operands need PHIs?
+ // FIXME: If any of these fail, we should partition up the candidates to
+ // try and continue making progress.
+ Instruction *I0 = NewInsts[0];
+ for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) {
+ ModelledPHI PHI(NewInsts, OpNum, ActivePreds);
+ if (PHI.areAllIncomingValuesSame())
+ continue;
+ if (!canReplaceOperandWithVariable(I0, OpNum))
+ // We can't create a PHI from this instruction!
+ return None;
+ if (NeededPHIs.count(PHI))
+ continue;
+ if (!PHI.areAllIncomingValuesSameType())
+ return None;
+ // Don't create indirect calls! The called value is the final operand.
+ if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 &&
+ PHI.areAnyIncomingValuesConstant())
+ return None;
+
+ NeededPHIs.reserve(NeededPHIs.size());
+ NeededPHIs.insert(PHI);
+ PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
+ }
+
+ if (isMemoryInst(NewInsts[0]))
+ ++MemoryInstNum;
+
+ SinkingInstructionCandidate Cand;
+ Cand.NumInstructions = ++InstNum;
+ Cand.NumMemoryInsts = MemoryInstNum;
+ Cand.NumBlocks = ActivePreds.size();
+ Cand.NumPHIs = NeededPHIs.size();
+ for (auto *C : ActivePreds)
+ Cand.Blocks.push_back(C);
+
+ return Cand;
+}
+
+unsigned GVNSink::sinkBB(BasicBlock *BBEnd) {
+ DEBUG(dbgs() << "GVNSink: running on basic block ";
+ BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
+ SmallVector<BasicBlock *, 4> Preds;
+ for (auto *B : predecessors(BBEnd)) {
+ auto *T = B->getTerminator();
+ if (isa<BranchInst>(T) || isa<SwitchInst>(T))
+ Preds.push_back(B);
+ else
+ return 0;
+ }
+ if (Preds.size() < 2)
+ return 0;
+ std::sort(Preds.begin(), Preds.end());
+
+ unsigned NumOrigPreds = Preds.size();
+ // We can only sink instructions through unconditional branches.
+ for (auto I = Preds.begin(); I != Preds.end();) {
+ if ((*I)->getTerminator()->getNumSuccessors() != 1)
+ I = Preds.erase(I);
+ else
+ ++I;
+ }
+
+ LockstepReverseIterator LRI(Preds);
+ SmallVector<SinkingInstructionCandidate, 4> Candidates;
+ unsigned InstNum = 0, MemoryInstNum = 0;
+ ModelledPHISet NeededPHIs;
+ SmallPtrSet<Value *, 4> PHIContents;
+ analyzeInitialPHIs(BBEnd, NeededPHIs, PHIContents);
+ unsigned NumOrigPHIs = NeededPHIs.size();
+
+ while (LRI.isValid()) {
+ auto Cand = analyzeInstructionForSinking(LRI, InstNum, MemoryInstNum,
+ NeededPHIs, PHIContents);
+ if (!Cand)
+ break;
+ Cand->calculateCost(NumOrigPHIs, Preds.size());
+ Candidates.emplace_back(*Cand);
+ --LRI;
+ }
+
+ std::stable_sort(
+ Candidates.begin(), Candidates.end(),
+ [](const SinkingInstructionCandidate &A,
+ const SinkingInstructionCandidate &B) { return A > B; });
+ DEBUG(dbgs() << " -- Sinking candidates:\n"; for (auto &C
+ : Candidates) dbgs()
+ << " " << C << "\n";);
+
+ // Pick the top candidate, as long as it is positive!
+ if (Candidates.empty() || Candidates.front().Cost <= 0)
+ return 0;
+ auto C = Candidates.front();
+
+ DEBUG(dbgs() << " -- Sinking: " << C << "\n");
+ BasicBlock *InsertBB = BBEnd;
+ if (C.Blocks.size() < NumOrigPreds) {
+ DEBUG(dbgs() << " -- Splitting edge to "; BBEnd->printAsOperand(dbgs());
+ dbgs() << "\n");
+ InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split");
+ if (!InsertBB) {
+ DEBUG(dbgs() << " -- FAILED to split edge!\n");
+ // Edge couldn't be split.
+ return 0; + } + } + + for (unsigned I = 0; I < C.NumInstructions; ++I) + sinkLastInstruction(C.Blocks, InsertBB); + + return C.NumInstructions; +} + +void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, + BasicBlock *BBEnd) { + SmallVector<Instruction *, 4> Insts; + for (BasicBlock *BB : Blocks) + Insts.push_back(BB->getTerminator()->getPrevNode()); + Instruction *I0 = Insts.front(); + + SmallVector<Value *, 4> NewOperands; + for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) { + bool NeedPHI = any_of(Insts, [&I0, O](const Instruction *I) { + return I->getOperand(O) != I0->getOperand(O); + }); + if (!NeedPHI) { + NewOperands.push_back(I0->getOperand(O)); + continue; + } + + // Create a new PHI in the successor block and populate it. + auto *Op = I0->getOperand(O); + assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!"); + auto *PN = PHINode::Create(Op->getType(), Insts.size(), + Op->getName() + ".sink", &BBEnd->front()); + for (auto *I : Insts) + PN->addIncoming(I->getOperand(O), I->getParent()); + NewOperands.push_back(PN); + } + + // Arbitrarily use I0 as the new "common" instruction; remap its operands + // and move it to the start of the successor block. + for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) + I0->getOperandUse(O).set(NewOperands[O]); + I0->moveBefore(&*BBEnd->getFirstInsertionPt()); + + // Update metadata and IR flags. + for (auto *I : Insts) + if (I != I0) { + combineMetadataForCSE(I0, I); + I0->andIRFlags(I); + } + + for (auto *I : Insts) + if (I != I0) + I->replaceAllUsesWith(I0); + foldPointlessPHINodes(BBEnd); + + // Finally nuke all instructions apart from the common instruction. + for (auto *I : Insts) + if (I != I0) + I->eraseFromParent(); + + NumRemoved += Insts.size() - 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// Pass machinery / boilerplate + +class GVNSinkLegacyPass : public FunctionPass { +public: + static char ID; + + GVNSinkLegacyPass() : FunctionPass(ID) { + initializeGVNSinkLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + GVNSink G; + return G.run(F); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; +} // namespace + +PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) { + GVNSink G; + if (!G.run(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} + +char GVNSinkLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(GVNSinkLegacyPass, "gvn-sink", + "Early GVN sinking of Expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_END(GVNSinkLegacyPass, "gvn-sink", + "Early GVN sinking of Expressions", false, false) + +FunctionPass *llvm::createGVNSinkPass() { return new GVNSinkLegacyPass(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp index b05ef00..fb7c6e1 100644 --- a/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -40,7 +40,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/GuardWidening.h" -#include "llvm/Pass.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/Analysis/LoopInfo.h" @@ -50,7 +49,9 
@@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Scalar.h" using namespace llvm; @@ -536,10 +537,8 @@ bool GuardWideningImpl::parseRangeChecks( Changed = true; } else if (match(Check.getBase(), m_Or(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { - unsigned BitWidth = OpLHS->getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(OpLHS, KnownZero, KnownOne, DL); - if ((OpRHS->getValue() & KnownZero) == OpRHS->getValue()) { + KnownBits Known = computeKnownBits(OpLHS, DL); + if ((OpRHS->getValue() & Known.Zero) == OpRHS->getValue()) { Check.setBase(OpLHS); APInt NewOffset = Check.getOffsetValue() + OpRHS->getValue(); Check.setOffset(ConstantInt::get(Ctx, NewOffset)); @@ -568,8 +567,7 @@ bool GuardWideningImpl::combineRangeChecks( return RC.getBase() == CurrentBase && RC.getLength() == CurrentLength; }; - std::copy_if(Checks.begin(), Checks.end(), - std::back_inserter(CurrentChecks), IsCurrentCheck); + copy_if(Checks, std::back_inserter(CurrentChecks), IsCurrentCheck); Checks.erase(remove_if(Checks, IsCurrentCheck), Checks.end()); assert(CurrentChecks.size() != 0 && "We know we have at least one!"); @@ -613,16 +611,16 @@ bool GuardWideningImpl::combineRangeChecks( // We have a series of f+1 checks as: // // I+k_0 u< L ... Chk_0 - // I_k_1 u< L ... Chk_1 + // I+k_1 u< L ... Chk_1 // ... - // I_k_f u< L ... Chk_(f+1) + // I+k_f u< L ... Chk_f // - // with forall i in [0,f): k_f-k_i u< k_f-k_0 ... Precond_0 + // with forall i in [0,f]: k_f-k_i u< k_f-k_0 ... Precond_0 // k_f-k_0 u< INT_MIN+k_f ... Precond_1 // k_f != k_0 ... Precond_2 // // Claim: - // Chk_0 AND Chk_(f+1) implies all the other checks + // Chk_0 AND Chk_f implies all the other checks // // Informal proof sketch: // @@ -658,8 +656,12 @@ PreservedAnalyses GuardWideningPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - bool Changed = GuardWideningImpl(DT, PDT, LI).run(); - return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); + if (!GuardWideningImpl(DT, PDT, LI).run()) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { diff --git a/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 1752fb7..1078296 100644 --- a/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -86,6 +86,10 @@ static cl::opt<bool> UsePostIncrementRanges( cl::desc("Use post increment control-dependent ranges in IndVarSimplify"), cl::init(true)); +static cl::opt<bool> +DisableLFTR("disable-lftr", cl::Hidden, cl::init(false), + cl::desc("Disable Linear Function Test Replace optimization")); + namespace { struct RewritePhi; @@ -97,7 +101,7 @@ class IndVarSimplify { TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; - SmallVector<WeakVH, 16> DeadInsts; + SmallVector<WeakTrackingVH, 16> DeadInsts; bool Changed = false; bool isValidRewrite(Value *FromVal, Value *ToVal); @@ -231,8 +235,9 @@ static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) { bool isExact = false; // See if we can convert this to an int64_t uint64_t UIntVal; - if (APF.convertToInteger(&UIntVal, 64, true, APFloat::rmTowardZero, - &isExact) != APFloat::opOK || !isExact) + if (APF.convertToInteger(makeMutableArrayRef(UIntVal), 64, true, + APFloat::rmTowardZero, &isExact) != APFloat::opOK || + !isExact) return false; IntVal = UIntVal; return true; @@ -414,8 +419,8 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { Compare->getName()); // In the following deletions, PN may become dead and may be deleted. - // Use a WeakVH to observe whether this happens. - WeakVH WeakPH = PN; + // Use a WeakTrackingVH to observe whether this happens. + WeakTrackingVH WeakPH = PN; // Delete the old floating point exit comparison. The branch starts using the // new comparison. @@ -450,7 +455,7 @@ void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { // BasicBlock *Header = L->getHeader(); - SmallVector<WeakVH, 8> PHIs; + SmallVector<WeakTrackingVH, 8> PHIs; for (BasicBlock::iterator I = Header->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I) PHIs.push_back(PN); @@ -900,13 +905,13 @@ class WidenIV { PHINode *WidePhi; Instruction *WideInc; const SCEV *WideIncExpr; - SmallVectorImpl<WeakVH> &DeadInsts; + SmallVectorImpl<WeakTrackingVH> &DeadInsts; SmallPtrSet<Instruction *,16> Widened; SmallVector<NarrowIVDefUse, 8> NarrowIVUsers; enum ExtendKind { ZeroExtended, SignExtended, Unknown }; - // A map tracking the kind of extension used to widen each narrow IV + // A map tracking the kind of extension used to widen each narrow IV // and narrow IV user. // Key: pointer to a narrow IV or IV user. // Value: the kind of extension used to widen this Instruction. 
@@ -940,20 +945,13 @@ class WidenIV { } public: - WidenIV(const WideIVInfo &WI, LoopInfo *LInfo, - ScalarEvolution *SEv, DominatorTree *DTree, - SmallVectorImpl<WeakVH> &DI, bool HasGuards) : - OrigPhi(WI.NarrowIV), - WideType(WI.WidestNativeType), - LI(LInfo), - L(LI->getLoopFor(OrigPhi->getParent())), - SE(SEv), - DT(DTree), - HasGuards(HasGuards), - WidePhi(nullptr), - WideInc(nullptr), - WideIncExpr(nullptr), - DeadInsts(DI) { + WidenIV(const WideIVInfo &WI, LoopInfo *LInfo, ScalarEvolution *SEv, + DominatorTree *DTree, SmallVectorImpl<WeakTrackingVH> &DI, + bool HasGuards) + : OrigPhi(WI.NarrowIV), WideType(WI.WidestNativeType), LI(LInfo), + L(LI->getLoopFor(OrigPhi->getParent())), SE(SEv), DT(DTree), + HasGuards(HasGuards), WidePhi(nullptr), WideInc(nullptr), + WideIncExpr(nullptr), DeadInsts(DI) { assert(L->getHeader() == OrigPhi->getParent() && "Phi must be an IV"); ExtendKindMap[OrigPhi] = WI.IsSigned ? SignExtended : ZeroExtended; } @@ -1608,7 +1606,7 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef, return; CmpInst::Predicate P = - TrueDest ? Pred : CmpInst::getInversePredicate(Pred); + TrueDest ? Pred : CmpInst::getInversePredicate(Pred); auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS)); auto CmpConstrainedLHSRange = @@ -1634,7 +1632,7 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef, UpdateRangeFromGuards(NarrowUser); BasicBlock *NarrowUserBB = NarrowUser->getParent(); - // If NarrowUserBB is statically unreachable asking dominator queries may + // If NarrowUserBB is statically unreachable asking dominator queries may // yield surprising results. (e.g. the block may not have a dom tree node) if (!DT->isReachableFromEntry(NarrowUserBB)) return; @@ -1829,6 +1827,7 @@ static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L, DominatorTree *DT) { // An IV counter must preserve its type. if (IncI->getNumOperands() == 2) break; + LLVM_FALLTHROUGH; default: return nullptr; } @@ -2152,6 +2151,8 @@ linearFunctionTestReplace(Loop *L, Value *CmpIndVar = IndVar; const SCEV *IVCount = BackedgeTakenCount; + assert(L->getLoopLatch() && "Loop no longer in simplified form?"); + // If the exiting block is the same as the backedge block, we prefer to // compare against the post-incremented value, otherwise we must compare // against the preincremented value. @@ -2376,6 +2377,7 @@ bool IndVarSimplify::run(Loop *L) { // Loop::getCanonicalInductionVariable only supports loops with preheaders, // and we're in trouble if we can't find the induction variable even when // we've manually inserted one. + // - LFTR relies on having a single backedge. if (!L->isLoopSimplifyForm()) return false; @@ -2415,7 +2417,8 @@ bool IndVarSimplify::run(Loop *L) { // If we have a trip count expression, rewrite the loop's exit condition // using it. We can currently only handle loops with a single exit. - if (canExpandBackedgeTakenCount(L, SE, Rewriter) && needsLFTR(L, DT)) { + if (!DisableLFTR && canExpandBackedgeTakenCount(L, SE, Rewriter) && + needsLFTR(L, DT)) { PHINode *IndVar = FindLoopCounter(L, BackedgeTakenCount, SE, DT); if (IndVar) { // Check preconditions for proper SCEVExpander operation. SCEV does not @@ -2492,8 +2495,9 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, if (!IVS.run(&L)) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'. 
- return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + PA.preserveSet<CFGAnalyses>(); + return PA; } namespace { diff --git a/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 8e81541..99b4458 100644 --- a/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -59,8 +59,8 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopSimplify.h" +#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -446,6 +446,15 @@ struct LoopStructure { BasicBlock *LatchExit; unsigned LatchBrExitIdx; + // The loop represented by this instance of LoopStructure is semantically + // equivalent to: + // + // intN_ty inc = IndVarIncreasing ? 1 : -1; + // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT; + // + // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarNext) + // ... body ... + Value *IndVarNext; Value *IndVarStart; Value *LoopExitAt; @@ -789,9 +798,32 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return None; } + const SCEV *StartNext = IndVarNext->getStart(); + const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); + const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); + ConstantInt *One = ConstantInt::get(IndVarTy, 1); // TODO: generalize the predicates here to also match their unsigned variants. if (IsIncreasing) { + bool DecreasedRightValueByOne = false; + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (++i != len) { while (++i < len) { + // ... ---> ... + // } } + Pred = ICmpInst::ICMP_SLT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeSMin(SE, RightSCEV)) { + // while (true) { while (true) { + // if (++i == len) ---> if (++i > len - 1) + // break; break; + // ... ... + // } } + Pred = ICmpInst::ICMP_SGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } + bool FoundExpectedPred = (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) || (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0); @@ -809,11 +841,48 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return None; } - IRBuilder<> B(Preheader->getTerminator()); - RightValue = B.CreateAdd(RightValue, One); - } + if (!SE.isLoopEntryGuardedByCond( + &L, CmpInst::ICMP_SLT, IndVarStart, + SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())))) { + FailureReason = "Induction variable start not bounded by upper limit"; + return None; + } + // We need to increase the right value unless we have already decreased + // it virtually when we replaced EQ with SGT. 
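+ // E.g. a latch of the form "if (++i > len) break" keeps iterating while
+ // i <= len, so the exclusive upper bound to range-check against is len + 1.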
+ if (!DecreasedRightValueByOne) { + IRBuilder<> B(Preheader->getTerminator()); + RightValue = B.CreateAdd(RightValue, One); + } + } else { + if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SLT, IndVarStart, + RightSCEV)) { + FailureReason = "Induction variable start not bounded by upper limit"; + return None; + } + assert(!DecreasedRightValueByOne && + "Right value can be decreased only for LatchBrExitIdx == 0!"); + } } else { + bool IncreasedRightValueByOne = false; + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (--i != len) { while (--i > len) { + // ... ---> ... + // } } + Pred = ICmpInst::ICMP_SGT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeSMax(SE, RightSCEV)) { + // while (true) { while (true) { + // if (--i == len) ---> if (--i < len + 1) + // break; break; + // ... ... + // } } + Pred = ICmpInst::ICMP_SLT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } + bool FoundExpectedPred = (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) || (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0); @@ -831,15 +900,30 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return None; } - IRBuilder<> B(Preheader->getTerminator()); - RightValue = B.CreateSub(RightValue, One); + if (!SE.isLoopEntryGuardedByCond( + &L, CmpInst::ICMP_SGT, IndVarStart, + SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) { + FailureReason = "Induction variable start not bounded by lower limit"; + return None; + } + + // We need to decrease the right value unless we have already increased + // it virtually when we replaced EQ with SLT. + if (!IncreasedRightValueByOne) { + IRBuilder<> B(Preheader->getTerminator()); + RightValue = B.CreateSub(RightValue, One); + } + } else { + if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SGT, IndVarStart, + RightSCEV)) { + FailureReason = "Induction variable start not bounded by lower limit"; + return None; + } + assert(!IncreasedRightValueByOne && + "Right value can be increased only for LatchBrExitIdx == 0!"); } } - const SCEV *StartNext = IndVarNext->getStart(); - const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); - const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); - BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); assert(SE.getLoopDisposition(LatchCount, &L) == @@ -883,20 +967,23 @@ LoopConstrainer::calculateSubRanges() const { // I think we can be more aggressive here and make this nuw / nsw if the // addition that feeds into the icmp for the latch's terminating branch is nuw // / nsw. In any case, a wrapping 2's complement addition is safe. - ConstantInt *One = ConstantInt::get(Ty, 1); const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); bool Increasing = MainLoopStructure.IndVarIncreasing; - // We compute `Smallest` and `Greatest` such that [Smallest, Greatest) is the - // range of values the induction variable takes. + // We compute `Smallest` and `Greatest` such that [Smallest, Greatest), or + // [Smallest, GreatestSeen] is the range of values the induction variable + // takes. 
- const SCEV *Smallest = nullptr, *Greatest = nullptr; + const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr; + const SCEV *One = SE.getOne(Ty); if (Increasing) { Smallest = Start; Greatest = End; + // No overflow, because the range [Smallest, GreatestSeen] is not empty. + GreatestSeen = SE.getMinusSCEV(End, One); } else { // These two computations may sign-overflow. Here is why that is okay: // @@ -914,8 +1001,9 @@ LoopConstrainer::calculateSubRanges() const { // will be an empty range. Returning an empty range is always safe. // - Smallest = SE.getAddExpr(End, SE.getSCEV(One)); - Greatest = SE.getAddExpr(Start, SE.getSCEV(One)); + Smallest = SE.getAddExpr(End, One); + Greatest = SE.getAddExpr(Start, One); + GreatestSeen = Start; } auto Clamp = [this, Smallest, Greatest](const SCEV *S) { @@ -930,7 +1018,7 @@ LoopConstrainer::calculateSubRanges() const { Result.LowLimit = Clamp(Range.getBegin()); bool ProvablyNoPostLoop = - SE.isKnownPredicate(ICmpInst::ICMP_SLE, Greatest, Range.getEnd()); + SE.isKnownPredicate(ICmpInst::ICMP_SLT, GreatestSeen, Range.getEnd()); if (!ProvablyNoPostLoop) Result.HighLimit = Clamp(Range.getEnd()); @@ -1194,7 +1282,12 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, ValueToValueMapTy &VM) { - Loop &New = LPM.addLoop(Parent); + Loop &New = *new Loop(); + if (Parent) + Parent->addChildLoop(&New); + else + LI.addTopLevelLoop(&New); + LPM.addLoop(New); // Add all of the blocks in Original to the new loop. for (auto *BB : Original->blocks()) @@ -1332,28 +1425,35 @@ bool LoopConstrainer::run() { DT.recalculate(F); + // We need to first add all the pre and post loop blocks into the loop + // structures (as part of createClonedLoopStructure), and then update the + // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating + // LI when LoopSimplifyForm is generated. + Loop *PreL = nullptr, *PostL = nullptr; if (!PreLoop.Blocks.empty()) { - auto *L = createClonedLoopStructure( + PreL = createClonedLoopStructure( &OriginalLoop, OriginalLoop.getParentLoop(), PreLoop.Map); - formLCSSARecursively(*L, DT, &LI, &SE); - simplifyLoop(L, &DT, &LI, &SE, nullptr, true); - // Pre loops are slow paths, we do not need to perform any loop - // optimizations on them. - DisableAllLoopOptsOnLoop(*L); } if (!PostLoop.Blocks.empty()) { - auto *L = createClonedLoopStructure( + PostL = createClonedLoopStructure( &OriginalLoop, OriginalLoop.getParentLoop(), PostLoop.Map); + } + + // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. + auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) { formLCSSARecursively(*L, DT, &LI, &SE); simplifyLoop(L, &DT, &LI, &SE, nullptr, true); - // Post loops are slow paths, we do not need to perform any loop + // Pre/post loops are slow paths, we do not need to perform any loop // optimizations on them. 
- DisableAllLoopOptsOnLoop(*L);
- }
-
- formLCSSARecursively(OriginalLoop, DT, &LI, &SE);
- simplifyLoop(&OriginalLoop, &DT, &LI, &SE, nullptr, true);
+ if (!IsOriginalLoop)
+ DisableAllLoopOptsOnLoop(*L);
+ };
+ if (PreL)
+ CanonicalizeLoop(PreL, false);
+ if (PostL)
+ CanonicalizeLoop(PostL, false);
+ CanonicalizeLoop(&OriginalLoop, true);
 return true;
}
diff --git a/contrib/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/contrib/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
new file mode 100644
index 0000000..89b28f0
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -0,0 +1,969 @@
+//===-- InferAddressSpaces.cpp - Infer address spaces ----------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// CUDA C/C++ includes memory space designation as variable type qualifiers (such
+// as __global__ and __shared__). Knowing the space of a memory access allows
+// CUDA compilers to emit faster PTX loads and stores. For example, a load from
+// shared memory can be translated to `ld.shared`, which is roughly 10% faster
+// than a generic `ld` on an NVIDIA Tesla K40c.
+//
+// Unfortunately, type qualifiers only apply to variable declarations, so CUDA
+// compilers must infer the memory space of an address expression from
+// type-qualified variables.
+//
+// LLVM IR uses non-zero (so-called) specific address spaces to represent memory
+// spaces (e.g. addrspace(3) means shared memory). The Clang frontend
+// places only type-qualified variables in specific address spaces, and then
+// conservatively `addrspacecast`s each type-qualified variable to addrspace(0)
+// (the so-called generic address space) for other instructions to use.
+//
+// For example, Clang translates the following CUDA code
+// __shared__ float a[10];
+// float v = a[i];
+// to
+// %0 = addrspacecast [10 x float] addrspace(3)* @a to [10 x float]*
+// %1 = gep [10 x float], [10 x float]* %0, i64 0, i64 %i
+// %v = load float, float* %1 ; emits ld.f32
+// @a is in addrspace(3) since it's type-qualified, but its use from %1 is
+// redirected to %0 (the generic version of @a).
+//
+// The optimization implemented in this file propagates specific address spaces
+// from type-qualified variable declarations to their users. For example, it
+// optimizes the above IR to
+// %1 = gep [10 x float] addrspace(3)* @a, i64 0, i64 %i
+// %v = load float addrspace(3)* %1 ; emits ld.shared.f32
+// propagating the addrspace(3) from @a to %1. As a result, the NVPTX
+// codegen is able to emit ld.shared.f32 for %v.
+//
+// Address space inference works in two steps. First, it uses a data-flow
+// analysis to infer as many generic pointers as possible to point to only one
+// specific address space. In the above example, it can prove that %1 only
+// points to addrspace(3). This algorithm was published in
+// CUDA: Compiling and optimizing for a GPU platform
+// Chakrabarti, Grover, Aarts, Kong, Kudlur, Lin, Marathe, Murphy, Wang
+// ICCS 2012
+//
+// Then, address space inference replaces all refinable generic pointers with
+// equivalent specific pointers.
+//
+// The major challenge of implementing this optimization is handling PHINodes,
+// which may create loops in the data flow graph. This brings two complications.
+//
+// First, the data flow analysis in Step 1 needs to be circular.
For example,
+// %generic.input = addrspacecast float addrspace(3)* %input to float*
+// loop:
+// %y = phi [ %generic.input, %y2 ]
+// %y2 = getelementptr %y, 1
+// %v = load %y2
+// br ..., label %loop, ...
+// proving %y specific requires proving both %generic.input and %y2 specific,
+// but proving %y2 specific circles back to %y. To address this complication,
+// the data flow analysis operates on a lattice:
+// uninitialized > specific address spaces > generic.
+// All address expressions (our implementation only considers phi, bitcast,
+// addrspacecast, and getelementptr) start with the uninitialized address space.
+// The monotone transfer function moves the address space of a pointer down a
+// lattice path from uninitialized to specific and then to generic. A join
+// operation of two different specific address spaces pushes the expression down
+// to the generic address space. The analysis completes once it reaches a fixed
+// point.
+//
+// Second, IR rewriting in Step 2 also needs to be circular. For example,
+// converting %y to addrspace(3) requires the compiler to know the converted
+// %y2, but converting %y2 needs the converted %y. To address this complication,
+// we break these cycles using "undef" placeholders. When converting an
+// instruction `I` to a new address space, if its operand `Op` is not converted
+// yet, we let `I` temporarily use `undef` and fix all the uses of undef later.
+// For instance, our algorithm first converts %y to
+// %y' = phi float addrspace(3)* [ %input, undef ]
+// Then, it converts %y2 to
+// %y2' = getelementptr %y', 1
+// Finally, it fixes the undef in %y' so that
+// %y' = phi float addrspace(3)* [ %input, %y2' ]
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+
+#define DEBUG_TYPE "infer-address-spaces"
+
+using namespace llvm;
+
+namespace {
+static const unsigned UninitializedAddressSpace = ~0u;
+
+using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>;
+
+/// \brief InferAddressSpaces
+class InferAddressSpaces : public FunctionPass {
+ /// Target-specific address space whose uses should be replaced if
+ /// possible.
+ unsigned FlatAddrSpace;
+
+public:
+ static char ID;
+
+ InferAddressSpaces() : FunctionPass(ID) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+ }
+
+ bool runOnFunction(Function &F) override;
+
+private:
+ // Returns the new address space of V if updated; otherwise, returns None.
+ Optional<unsigned>
+ updateAddressSpace(const Value &V,
+ const ValueToAddrSpaceMapTy &InferredAddrSpace) const;
+
+ // Tries to infer the specific address space of each address expression in
+ // Postorder.
+ void inferAddressSpaces(ArrayRef<WeakTrackingVH> Postorder,
+ ValueToAddrSpaceMapTy *InferredAddrSpace) const;
+
+ bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;
+
+ // Changes the flat address expressions in function F to point to specific
+ // address spaces if InferredAddrSpace says so.
Postorder is the postorder of + // all flat expressions in the use-def graph of function F. + bool + rewriteWithNewAddressSpaces(ArrayRef<WeakTrackingVH> Postorder, + const ValueToAddrSpaceMapTy &InferredAddrSpace, + Function *F) const; + + void appendsFlatAddressExpressionToPostorderStack( + Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack, + DenseSet<Value *> &Visited) const; + + bool rewriteIntrinsicOperands(IntrinsicInst *II, + Value *OldV, Value *NewV) const; + void collectRewritableIntrinsicOperands( + IntrinsicInst *II, + std::vector<std::pair<Value *, bool>> &PostorderStack, + DenseSet<Value *> &Visited) const; + + std::vector<WeakTrackingVH> collectFlatAddressExpressions(Function &F) const; + + Value *cloneValueWithNewAddressSpace( + Value *V, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + SmallVectorImpl<const Use *> *UndefUsesToFix) const; + unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const; +}; +} // end anonymous namespace + +char InferAddressSpaces::ID = 0; + +namespace llvm { +void initializeInferAddressSpacesPass(PassRegistry &); +} + +INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", + false, false) + +// Returns true if V is an address expression. +// TODO: Currently, we consider only phi, bitcast, addrspacecast, and +// getelementptr operators. +static bool isAddressExpression(const Value &V) { + if (!isa<Operator>(V)) + return false; + + switch (cast<Operator>(V).getOpcode()) { + case Instruction::PHI: + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::GetElementPtr: + case Instruction::Select: + return true; + default: + return false; + } +} + +// Returns the pointer operands of V. +// +// Precondition: V is an address expression. +static SmallVector<Value *, 2> getPointerOperands(const Value &V) { + const Operator &Op = cast<Operator>(V); + switch (Op.getOpcode()) { + case Instruction::PHI: { + auto IncomingValues = cast<PHINode>(Op).incoming_values(); + return SmallVector<Value *, 2>(IncomingValues.begin(), + IncomingValues.end()); + } + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::GetElementPtr: + return {Op.getOperand(0)}; + case Instruction::Select: + return {Op.getOperand(1), Op.getOperand(2)}; + default: + llvm_unreachable("Unexpected instruction type."); + } +} + +// TODO: Move logic to TTI? +bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, + Value *OldV, + Value *NewV) const { + Module *M = II->getParent()->getParent()->getParent(); + + switch (II->getIntrinsicID()) { + case Intrinsic::amdgcn_atomic_inc: + case Intrinsic::amdgcn_atomic_dec:{ + const ConstantInt *IsVolatile = dyn_cast<ConstantInt>(II->getArgOperand(4)); + if (!IsVolatile || !IsVolatile->isZero()) + return false; + + LLVM_FALLTHROUGH; + } + case Intrinsic::objectsize: { + Type *DestTy = II->getType(); + Type *SrcTy = NewV->getType(); + Function *NewDecl = + Intrinsic::getDeclaration(M, II->getIntrinsicID(), {DestTy, SrcTy}); + II->setArgOperand(0, NewV); + II->setCalledFunction(NewDecl); + return true; + } + default: + return false; + } +} + +// TODO: Move logic to TTI? 
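+// Pushes the rewritable pointer operands of the intrinsics this pass
+// understands (objectsize and the amdgcn atomic inc/dec intrinsics) onto the
+// postorder stack.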
+void InferAddressSpaces::collectRewritableIntrinsicOperands(
+ IntrinsicInst *II, std::vector<std::pair<Value *, bool>> &PostorderStack,
+ DenseSet<Value *> &Visited) const {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::objectsize:
+ case Intrinsic::amdgcn_atomic_inc:
+ case Intrinsic::amdgcn_atomic_dec:
+ appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
+ PostorderStack, Visited);
+ break;
+ default:
+ break;
+ }
+}
+
+// If V is an unvisited flat address expression, appends V to PostorderStack
+// and marks it as visited.
+void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack(
+ Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack,
+ DenseSet<Value *> &Visited) const {
+ assert(V->getType()->isPointerTy());
+
+ // Generic addressing expressions may be hidden in nested constant
+ // expressions.
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+ // TODO: Look in non-address parts, like icmp operands.
+ if (isAddressExpression(*CE) && Visited.insert(CE).second)
+ PostorderStack.push_back(std::make_pair(CE, false));
+
+ return;
+ }
+
+ if (isAddressExpression(*V) &&
+ V->getType()->getPointerAddressSpace() == FlatAddrSpace) {
+ if (Visited.insert(V).second) {
+ PostorderStack.push_back(std::make_pair(V, false));
+
+ Operator *Op = cast<Operator>(V);
+ for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) {
+ if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op->getOperand(I))) {
+ if (isAddressExpression(*CE) && Visited.insert(CE).second)
+ PostorderStack.emplace_back(CE, false);
+ }
+ }
+ }
+ }
+}
+
+// Returns all flat address expressions in function F. The elements are
+// ordered in postorder.
+std::vector<WeakTrackingVH>
+InferAddressSpaces::collectFlatAddressExpressions(Function &F) const {
+ // This function implements a non-recursive postorder traversal of a partial
+ // use-def graph of function F.
+ std::vector<std::pair<Value *, bool>> PostorderStack;
+ // The set of visited expressions.
+ DenseSet<Value *> Visited;
+
+ auto PushPtrOperand = [&](Value *Ptr) {
+ appendsFlatAddressExpressionToPostorderStack(Ptr, PostorderStack,
+ Visited);
+ };
+
+ // Look at operations that may be interesting to accelerate by moving to a
+ // known address space. We aim mainly at loads and stores, but pure
+ // addressing calculations may also be faster.
+ for (Instruction &I : instructions(F)) {
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
+ if (!GEP->getType()->isVectorTy())
+ PushPtrOperand(GEP->getPointerOperand());
+ } else if (auto *LI = dyn_cast<LoadInst>(&I))
+ PushPtrOperand(LI->getPointerOperand());
+ else if (auto *SI = dyn_cast<StoreInst>(&I))
+ PushPtrOperand(SI->getPointerOperand());
+ else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
+ PushPtrOperand(RMW->getPointerOperand());
+ else if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(&I))
+ PushPtrOperand(CmpX->getPointerOperand());
+ else if (auto *MI = dyn_cast<MemIntrinsic>(&I)) {
+ // For memset/memcpy/memmove, any pointer operand can be replaced.
+ PushPtrOperand(MI->getRawDest());
+
+ // Handle 2nd operand for memcpy/memmove.
+ if (auto *MTI = dyn_cast<MemTransferInst>(MI)) + PushPtrOperand(MTI->getRawSource()); + } else if (auto *II = dyn_cast<IntrinsicInst>(&I)) + collectRewritableIntrinsicOperands(II, PostorderStack, Visited); + else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) { + // FIXME: Handle vectors of pointers + if (Cmp->getOperand(0)->getType()->isPointerTy()) { + PushPtrOperand(Cmp->getOperand(0)); + PushPtrOperand(Cmp->getOperand(1)); + } + } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) { + if (!ASC->getType()->isVectorTy()) + PushPtrOperand(ASC->getPointerOperand()); + } + } + + std::vector<WeakTrackingVH> Postorder; // The resultant postorder. + while (!PostorderStack.empty()) { + Value *TopVal = PostorderStack.back().first; + // If the operands of the expression on the top are already explored, + // adds that expression to the resultant postorder. + if (PostorderStack.back().second) { + if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace) + Postorder.push_back(TopVal); + PostorderStack.pop_back(); + continue; + } + // Otherwise, adds its operands to the stack and explores them. + PostorderStack.back().second = true; + for (Value *PtrOperand : getPointerOperands(*TopVal)) { + appendsFlatAddressExpressionToPostorderStack(PtrOperand, PostorderStack, + Visited); + } + } + return Postorder; +} + +// A helper function for cloneInstructionWithNewAddressSpace. Returns the clone +// of OperandUse.get() in the new address space. If the clone is not ready yet, +// returns an undef in the new address space as a placeholder. +static Value *operandWithNewAddressSpaceOrCreateUndef( + const Use &OperandUse, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + SmallVectorImpl<const Use *> *UndefUsesToFix) { + Value *Operand = OperandUse.get(); + + Type *NewPtrTy = + Operand->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + + if (Constant *C = dyn_cast<Constant>(Operand)) + return ConstantExpr::getAddrSpaceCast(C, NewPtrTy); + + if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) + return NewOperand; + + UndefUsesToFix->push_back(&OperandUse); + return UndefValue::get(NewPtrTy); +} + +// Returns a clone of `I` with its operands converted to those specified in +// ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an +// operand whose address space needs to be modified might not exist in +// ValueWithNewAddrSpace. In that case, uses undef as a placeholder operand and +// adds that operand use to UndefUsesToFix so that caller can fix them later. +// +// Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast +// from a pointer whose type already matches. Therefore, this function returns a +// Value* instead of an Instruction*. +static Value *cloneInstructionWithNewAddressSpace( + Instruction *I, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + SmallVectorImpl<const Use *> *UndefUsesToFix) { + Type *NewPtrType = + I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + + if (I->getOpcode() == Instruction::AddrSpaceCast) { + Value *Src = I->getOperand(0); + // Because `I` is flat, the source address space must be specific. + // Therefore, the inferred address space must be the source space, according + // to our algorithm. + assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); + if (Src->getType() != NewPtrType) + return new BitCastInst(Src, NewPtrType); + return Src; + } + + // Computes the converted pointer operands. 
+ SmallVector<Value *, 4> NewPointerOperands; + for (const Use &OperandUse : I->operands()) { + if (!OperandUse.get()->getType()->isPointerTy()) + NewPointerOperands.push_back(nullptr); + else + NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef( + OperandUse, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix)); + } + + switch (I->getOpcode()) { + case Instruction::BitCast: + return new BitCastInst(NewPointerOperands[0], NewPtrType); + case Instruction::PHI: { + assert(I->getType()->isPointerTy()); + PHINode *PHI = cast<PHINode>(I); + PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues()); + for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) { + unsigned OperandNo = PHINode::getOperandNumForIncomingValue(Index); + NewPHI->addIncoming(NewPointerOperands[OperandNo], + PHI->getIncomingBlock(Index)); + } + return NewPHI; + } + case Instruction::GetElementPtr: { + GetElementPtrInst *GEP = cast<GetElementPtrInst>(I); + GetElementPtrInst *NewGEP = GetElementPtrInst::Create( + GEP->getSourceElementType(), NewPointerOperands[0], + SmallVector<Value *, 4>(GEP->idx_begin(), GEP->idx_end())); + NewGEP->setIsInBounds(GEP->isInBounds()); + return NewGEP; + } + case Instruction::Select: { + assert(I->getType()->isPointerTy()); + return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], + NewPointerOperands[2], "", nullptr, I); + } + default: + llvm_unreachable("Unexpected opcode"); + } +} + +// Similar to cloneInstructionWithNewAddressSpace, returns a clone of the +// constant expression `CE` with its operands replaced as specified in +// ValueWithNewAddrSpace. +static Value *cloneConstantExprWithNewAddressSpace( + ConstantExpr *CE, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace) { + Type *TargetType = + CE->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + + if (CE->getOpcode() == Instruction::AddrSpaceCast) { + // Because CE is flat, the source address space must be specific. + // Therefore, the inferred address space must be the source space according + // to our algorithm. + assert(CE->getOperand(0)->getType()->getPointerAddressSpace() == + NewAddrSpace); + return ConstantExpr::getBitCast(CE->getOperand(0), TargetType); + } + + if (CE->getOpcode() == Instruction::BitCast) { + if (Value *NewOperand = ValueWithNewAddrSpace.lookup(CE->getOperand(0))) + return ConstantExpr::getBitCast(cast<Constant>(NewOperand), TargetType); + return ConstantExpr::getAddrSpaceCast(CE, TargetType); + } + + if (CE->getOpcode() == Instruction::Select) { + Constant *Src0 = CE->getOperand(1); + Constant *Src1 = CE->getOperand(2); + if (Src0->getType()->getPointerAddressSpace() == + Src1->getType()->getPointerAddressSpace()) { + + return ConstantExpr::getSelect( + CE->getOperand(0), ConstantExpr::getAddrSpaceCast(Src0, TargetType), + ConstantExpr::getAddrSpaceCast(Src1, TargetType)); + } + } + + // Computes the operands of the new constant expression. + bool IsNew = false; + SmallVector<Constant *, 4> NewOperands; + for (unsigned Index = 0; Index < CE->getNumOperands(); ++Index) { + Constant *Operand = CE->getOperand(Index); + // If the address space of `Operand` needs to be modified, the new operand + // with the new address space should already be in ValueWithNewAddrSpace + // because (1) the constant expressions we consider (i.e. addrspacecast, + // bitcast, and getelementptr) do not incur cycles in the data flow graph + // and (2) this function is called on constant expressions in postorder. 
+ if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) {
+ IsNew = true;
+ NewOperands.push_back(cast<Constant>(NewOperand));
+ } else {
+ // Otherwise, reuses the old operand.
+ NewOperands.push_back(Operand);
+ }
+ }
+
+ // If !IsNew, we will replace the Value with itself. However, replaced values
+ // are assumed to be wrapped in an addrspacecast later, so drop it now.
+ if (!IsNew)
+ return nullptr;
+
+ if (CE->getOpcode() == Instruction::GetElementPtr) {
+ // Needs to specify the source type while constructing a getelementptr
+ // constant expression.
+ return CE->getWithOperands(
+ NewOperands, TargetType, /*OnlyIfReduced=*/false,
+ NewOperands[0]->getType()->getPointerElementType());
+ }
+
+ return CE->getWithOperands(NewOperands, TargetType);
+}
+
+// Returns a clone of the value `V`, with its operands replaced as specified in
+// ValueWithNewAddrSpace. This function is called on every flat address
+// expression whose address space needs to be modified, in postorder.
+//
+// See cloneInstructionWithNewAddressSpace for the meaning of UndefUsesToFix.
+Value *InferAddressSpaces::cloneValueWithNewAddressSpace(
+ Value *V, unsigned NewAddrSpace,
+ const ValueToValueMapTy &ValueWithNewAddrSpace,
+ SmallVectorImpl<const Use *> *UndefUsesToFix) const {
+ // All values in Postorder are flat address expressions.
+ assert(isAddressExpression(*V) &&
+ V->getType()->getPointerAddressSpace() == FlatAddrSpace);
+
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ Value *NewV = cloneInstructionWithNewAddressSpace(
+ I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix);
+ if (Instruction *NewI = dyn_cast<Instruction>(NewV)) {
+ if (NewI->getParent() == nullptr) {
+ NewI->insertBefore(I);
+ NewI->takeName(I);
+ }
+ }
+ return NewV;
+ }
+
+ return cloneConstantExprWithNewAddressSpace(
+ cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace);
+}
+
+// Defines the join operation on the address space lattice (see the file header
+// comments).
+unsigned InferAddressSpaces::joinAddressSpaces(unsigned AS1,
+ unsigned AS2) const {
+ if (AS1 == FlatAddrSpace || AS2 == FlatAddrSpace)
+ return FlatAddrSpace;
+
+ if (AS1 == UninitializedAddressSpace)
+ return AS2;
+ if (AS2 == UninitializedAddressSpace)
+ return AS1;
+
+ // The join of two different specific address spaces is flat.
+ return (AS1 == AS2) ? AS1 : FlatAddrSpace;
+}
+
+bool InferAddressSpaces::runOnFunction(Function &F) {
+ if (skipFunction(F))
+ return false;
+
+ const TargetTransformInfo &TTI =
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ FlatAddrSpace = TTI.getFlatAddressSpace();
+ if (FlatAddrSpace == UninitializedAddressSpace)
+ return false;
+
+ // Collects all flat address expressions in postorder.
+ std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(F);
+
+ // Runs a data-flow analysis to refine the address spaces of every expression
+ // in Postorder.
+ ValueToAddrSpaceMapTy InferredAddrSpace;
+ inferAddressSpaces(Postorder, &InferredAddrSpace);
+
+ // Changes the address spaces of the flat address expressions that are
+ // inferred to point to a specific address space.
+ return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, &F);
+}
+
+// Constants need to be tracked through RAUW to handle cases with nested
+// constant expressions, so wrap values in WeakTrackingVH.
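+// The worklist below runs to a fixed point. Termination is guaranteed because
+// updateAddressSpace only ever moves a value down the lattice (uninitialized
+// -> specific -> flat), so each value can be re-queued only a bounded number
+// of times.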
+void InferAddressSpaces::inferAddressSpaces( + ArrayRef<WeakTrackingVH> Postorder, + ValueToAddrSpaceMapTy *InferredAddrSpace) const { + SetVector<Value *> Worklist(Postorder.begin(), Postorder.end()); + // Initially, all expressions are in the uninitialized address space. + for (Value *V : Postorder) + (*InferredAddrSpace)[V] = UninitializedAddressSpace; + + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + + // Tries to update the address space of the stack top according to the + // address spaces of its operands. + DEBUG(dbgs() << "Updating the address space of\n " << *V << '\n'); + Optional<unsigned> NewAS = updateAddressSpace(*V, *InferredAddrSpace); + if (!NewAS.hasValue()) + continue; + // If any updates are made, grabs its users to the worklist because + // their address spaces can also be possibly updated. + DEBUG(dbgs() << " to " << NewAS.getValue() << '\n'); + (*InferredAddrSpace)[V] = NewAS.getValue(); + + for (Value *User : V->users()) { + // Skip if User is already in the worklist. + if (Worklist.count(User)) + continue; + + auto Pos = InferredAddrSpace->find(User); + // Our algorithm only updates the address spaces of flat address + // expressions, which are those in InferredAddrSpace. + if (Pos == InferredAddrSpace->end()) + continue; + + // Function updateAddressSpace moves the address space down a lattice + // path. Therefore, nothing to do if User is already inferred as flat (the + // bottom element in the lattice). + if (Pos->second == FlatAddrSpace) + continue; + + Worklist.insert(User); + } + } +} + +Optional<unsigned> InferAddressSpaces::updateAddressSpace( + const Value &V, const ValueToAddrSpaceMapTy &InferredAddrSpace) const { + assert(InferredAddrSpace.count(&V)); + + // The new inferred address space equals the join of the address spaces + // of all its pointer operands. + unsigned NewAS = UninitializedAddressSpace; + + const Operator &Op = cast<Operator>(V); + if (Op.getOpcode() == Instruction::Select) { + Value *Src0 = Op.getOperand(1); + Value *Src1 = Op.getOperand(2); + + auto I = InferredAddrSpace.find(Src0); + unsigned Src0AS = (I != InferredAddrSpace.end()) ? + I->second : Src0->getType()->getPointerAddressSpace(); + + auto J = InferredAddrSpace.find(Src1); + unsigned Src1AS = (J != InferredAddrSpace.end()) ? + J->second : Src1->getType()->getPointerAddressSpace(); + + auto *C0 = dyn_cast<Constant>(Src0); + auto *C1 = dyn_cast<Constant>(Src1); + + // If one of the inputs is a constant, we may be able to do a constant + // addrspacecast of it. Defer inferring the address space until the input + // address space is known. + if ((C1 && Src0AS == UninitializedAddressSpace) || + (C0 && Src1AS == UninitializedAddressSpace)) + return None; + + if (C0 && isSafeToCastConstAddrSpace(C0, Src1AS)) + NewAS = Src1AS; + else if (C1 && isSafeToCastConstAddrSpace(C1, Src0AS)) + NewAS = Src0AS; + else + NewAS = joinAddressSpaces(Src0AS, Src1AS); + } else { + for (Value *PtrOperand : getPointerOperands(V)) { + auto I = InferredAddrSpace.find(PtrOperand); + unsigned OperandAS = I != InferredAddrSpace.end() ? + I->second : PtrOperand->getType()->getPointerAddressSpace(); + + // join(flat, *) = flat. So we can break if NewAS is already flat. 
+      NewAS = joinAddressSpaces(NewAS, OperandAS);
+      if (NewAS == FlatAddrSpace)
+        break;
+    }
+  }
+
+  unsigned OldAS = InferredAddrSpace.lookup(&V);
+  assert(OldAS != FlatAddrSpace);
+  if (OldAS == NewAS)
+    return None;
+  return NewAS;
+}
+
+/// \returns true if \p U is the pointer operand of a memory instruction with
+/// a single pointer operand that can have its address space changed by simply
+/// mutating the use to a new value.
+static bool isSimplePointerUseValidToReplace(Use &U) {
+  User *Inst = U.getUser();
+  unsigned OpNo = U.getOperandNo();
+
+  if (auto *LI = dyn_cast<LoadInst>(Inst))
+    return OpNo == LoadInst::getPointerOperandIndex() && !LI->isVolatile();
+
+  if (auto *SI = dyn_cast<StoreInst>(Inst))
+    return OpNo == StoreInst::getPointerOperandIndex() && !SI->isVolatile();
+
+  if (auto *RMW = dyn_cast<AtomicRMWInst>(Inst))
+    return OpNo == AtomicRMWInst::getPointerOperandIndex() && !RMW->isVolatile();
+
+  if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) {
+    return OpNo == AtomicCmpXchgInst::getPointerOperandIndex() &&
+           !CmpX->isVolatile();
+  }
+
+  return false;
+}
+
+/// Update memory intrinsic uses that require more complex processing than
+/// simple memory instructions. These require re-mangling and may have multiple
+/// pointer operands.
+static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV,
+                                     Value *NewV) {
+  IRBuilder<> B(MI);
+  MDNode *TBAA = MI->getMetadata(LLVMContext::MD_tbaa);
+  MDNode *ScopeMD = MI->getMetadata(LLVMContext::MD_alias_scope);
+  MDNode *NoAliasMD = MI->getMetadata(LLVMContext::MD_noalias);
+
+  if (auto *MSI = dyn_cast<MemSetInst>(MI)) {
+    B.CreateMemSet(NewV, MSI->getValue(),
+                   MSI->getLength(), MSI->getAlignment(),
+                   false, // isVolatile
+                   TBAA, ScopeMD, NoAliasMD);
+  } else if (auto *MTI = dyn_cast<MemTransferInst>(MI)) {
+    Value *Src = MTI->getRawSource();
+    Value *Dest = MTI->getRawDest();
+
+    // Be careful in case this is a self-to-self copy.
+    if (Src == OldV)
+      Src = NewV;
+
+    if (Dest == OldV)
+      Dest = NewV;
+
+    if (isa<MemCpyInst>(MTI)) {
+      MDNode *TBAAStruct = MTI->getMetadata(LLVMContext::MD_tbaa_struct);
+      B.CreateMemCpy(Dest, Src, MTI->getLength(),
+                     MTI->getAlignment(),
+                     false, // isVolatile
+                     TBAA, TBAAStruct, ScopeMD, NoAliasMD);
+    } else {
+      assert(isa<MemMoveInst>(MTI));
+      B.CreateMemMove(Dest, Src, MTI->getLength(),
+                      MTI->getAlignment(),
+                      false, // isVolatile
+                      TBAA, ScopeMD, NoAliasMD);
+    }
+  } else
+    llvm_unreachable("unhandled MemIntrinsic");
+
+  MI->eraseFromParent();
+  return true;
+}
+
+// \returns true if it is OK to change the address space of constant \p C with
+// a ConstantExpr addrspacecast.
+bool InferAddressSpaces::isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const {
+  assert(NewAS != UninitializedAddressSpace);
+
+  unsigned SrcAS = C->getType()->getPointerAddressSpace();
+  if (SrcAS == NewAS || isa<UndefValue>(C))
+    return true;
+
+  // Prevent illegal casts between different non-flat address spaces.
+  if (SrcAS != FlatAddrSpace && NewAS != FlatAddrSpace)
+    return false;
+
+  if (isa<ConstantPointerNull>(C))
+    return true;
+
+  if (auto *Op = dyn_cast<Operator>(C)) {
+    // If we already have a constant addrspacecast, it should be safe to cast it
+    // off.
+ if (Op->getOpcode() == Instruction::AddrSpaceCast) + return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS); + + if (Op->getOpcode() == Instruction::IntToPtr && + Op->getType()->getPointerAddressSpace() == FlatAddrSpace) + return true; + } + + return false; +} + +static Value::use_iterator skipToNextUser(Value::use_iterator I, + Value::use_iterator End) { + User *CurUser = I->getUser(); + ++I; + + while (I != End && I->getUser() == CurUser) + ++I; + + return I; +} + +bool InferAddressSpaces::rewriteWithNewAddressSpaces( + ArrayRef<WeakTrackingVH> Postorder, + const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const { + // For each address expression to be modified, creates a clone of it with its + // pointer operands converted to the new address space. Since the pointer + // operands are converted, the clone is naturally in the new address space by + // construction. + ValueToValueMapTy ValueWithNewAddrSpace; + SmallVector<const Use *, 32> UndefUsesToFix; + for (Value* V : Postorder) { + unsigned NewAddrSpace = InferredAddrSpace.lookup(V); + if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { + ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace( + V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); + } + } + + if (ValueWithNewAddrSpace.empty()) + return false; + + // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace. + for (const Use *UndefUse : UndefUsesToFix) { + User *V = UndefUse->getUser(); + User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V)); + unsigned OperandNo = UndefUse->getOperandNo(); + assert(isa<UndefValue>(NewV->getOperand(OperandNo))); + NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get())); + } + + SmallVector<Instruction *, 16> DeadInstructions; + + // Replaces the uses of the old address expressions with the new ones. + for (const WeakTrackingVH &WVH : Postorder) { + assert(WVH && "value was unexpectedly deleted"); + Value *V = WVH; + Value *NewV = ValueWithNewAddrSpace.lookup(V); + if (NewV == nullptr) + continue; + + DEBUG(dbgs() << "Replacing the uses of " << *V + << "\n with\n " << *NewV << '\n'); + + if (Constant *C = dyn_cast<Constant>(V)) { + Constant *Replace = ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), + C->getType()); + if (C != Replace) { + DEBUG(dbgs() << "Inserting replacement const cast: " + << Replace << ": " << *Replace << '\n'); + C->replaceAllUsesWith(Replace); + V = Replace; + } + } + + Value::use_iterator I, E, Next; + for (I = V->use_begin(), E = V->use_end(); I != E; ) { + Use &U = *I; + + // Some users may see the same pointer operand in multiple operands. Skip + // to the next instruction. + I = skipToNextUser(I, E); + + if (isSimplePointerUseValidToReplace(U)) { + // If V is used as the pointer operand of a compatible memory operation, + // sets the pointer operand to NewV. This replacement does not change + // the element type, so the resultant load/store is still valid. + U.set(NewV); + continue; + } + + User *CurUser = U.getUser(); + // Handle more complex cases like intrinsic that need to be remangled. + if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) { + if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV)) + continue; + } + + if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) { + if (rewriteIntrinsicOperands(II, V, NewV)) + continue; + } + + if (isa<Instruction>(CurUser)) { + if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) { + // If we can infer that both pointers are in the same addrspace, + // transform e.g. 
+ // %cmp = icmp eq float* %p, %q + // into + // %cmp = icmp eq float addrspace(3)* %new_p, %new_q + + unsigned NewAS = NewV->getType()->getPointerAddressSpace(); + int SrcIdx = U.getOperandNo(); + int OtherIdx = (SrcIdx == 0) ? 1 : 0; + Value *OtherSrc = Cmp->getOperand(OtherIdx); + + if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) { + if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) { + Cmp->setOperand(OtherIdx, OtherNewV); + Cmp->setOperand(SrcIdx, NewV); + continue; + } + } + + // Even if the type mismatches, we can cast the constant. + if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) { + if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) { + Cmp->setOperand(SrcIdx, NewV); + Cmp->setOperand(OtherIdx, + ConstantExpr::getAddrSpaceCast(KOtherSrc, NewV->getType())); + continue; + } + } + } + + if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) { + unsigned NewAS = NewV->getType()->getPointerAddressSpace(); + if (ASC->getDestAddressSpace() == NewAS) { + ASC->replaceAllUsesWith(NewV); + DeadInstructions.push_back(ASC); + continue; + } + } + + // Otherwise, replaces the use with flat(NewV). + if (Instruction *I = dyn_cast<Instruction>(V)) { + BasicBlock::iterator InsertPos = std::next(I->getIterator()); + while (isa<PHINode>(InsertPos)) + ++InsertPos; + U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos)); + } else { + U.set(ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), + V->getType())); + } + } + } + + if (V->use_empty()) { + if (Instruction *I = dyn_cast<Instruction>(V)) + DeadInstructions.push_back(I); + } + } + + for (Instruction *I : DeadInstructions) + RecursivelyDeleteTriviallyDeadInstructions(I); + + return true; +} + +FunctionPass *llvm::createInferAddressSpacesPass() { + return new InferAddressSpaces(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp index 1870c3d..dc9143b 100644 --- a/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -12,29 +12,33 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/JumpThreading.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> @@ -60,6 +64,11 @@ ImplicationSearchThreshold( "condition to use to thread over a weaker condition"), cl::init(3), 
cl::Hidden); +static cl::opt<bool> PrintLVIAfterJumpThreading( + "print-lvi-after-jump-threading", + cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false), + cl::Hidden); + namespace { /// This pass performs 'jump threading', which looks at blocks that have /// multiple predecessors and multiple successors. If one or more of the @@ -89,8 +98,10 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + if (PrintLVIAfterJumpThreading) + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<LazyValueInfoWrapperPass>(); - AU.addPreserved<LazyValueInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } @@ -104,6 +115,7 @@ INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading", "Jump Threading", false, false) INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_END(JumpThreading, "jump-threading", "Jump Threading", false, false) @@ -121,16 +133,24 @@ bool JumpThreading::runOnFunction(Function &F) { return false; auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); std::unique_ptr<BlockFrequencyInfo> BFI; std::unique_ptr<BranchProbabilityInfo> BPI; bool HasProfileData = F.getEntryCount().hasValue(); if (HasProfileData) { LoopInfo LI{DominatorTree(F)}; - BPI.reset(new BranchProbabilityInfo(F, LI)); + BPI.reset(new BranchProbabilityInfo(F, LI, TLI)); BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - return Impl.runImpl(F, TLI, LVI, HasProfileData, std::move(BFI), - std::move(BPI)); + + bool Changed = Impl.runImpl(F, TLI, LVI, AA, HasProfileData, std::move(BFI), + std::move(BPI)); + if (PrintLVIAfterJumpThreading) { + dbgs() << "LVI for function '" << F.getName() << "':\n"; + LVI->printLVI(F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + dbgs()); + } + return Changed; } PreservedAnalyses JumpThreadingPass::run(Function &F, @@ -138,20 +158,19 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &LVI = AM.getResult<LazyValueAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + std::unique_ptr<BlockFrequencyInfo> BFI; std::unique_ptr<BranchProbabilityInfo> BPI; bool HasProfileData = F.getEntryCount().hasValue(); if (HasProfileData) { LoopInfo LI{DominatorTree(F)}; - BPI.reset(new BranchProbabilityInfo(F, LI)); + BPI.reset(new BranchProbabilityInfo(F, LI, &TLI)); BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - bool Changed = - runImpl(F, &TLI, &LVI, HasProfileData, std::move(BFI), std::move(BPI)); - // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better - // solution? 
- AM.invalidate<LazyValueAnalysis>(F); + bool Changed = runImpl(F, &TLI, &LVI, &AA, HasProfileData, std::move(BFI), + std::move(BPI)); if (!Changed) return PreservedAnalyses::all(); @@ -161,18 +180,23 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, } bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, - LazyValueInfo *LVI_, bool HasProfileData_, + LazyValueInfo *LVI_, AliasAnalysis *AA_, + bool HasProfileData_, std::unique_ptr<BlockFrequencyInfo> BFI_, std::unique_ptr<BranchProbabilityInfo> BPI_) { DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = TLI_; LVI = LVI_; + AA = AA_; BFI.reset(); BPI.reset(); // When profile data is available, we need to update edge weights after // successful jump threading, which requires both BPI and BFI being available. HasProfileData = HasProfileData_; + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + HasGuards = GuardDecl && !GuardDecl->use_empty(); if (HasProfileData) { BPI = std::move(BPI_); BFI = std::move(BFI_); @@ -219,33 +243,22 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // Can't thread an unconditional jump, but if the block is "almost // empty", we can replace uses of it with uses of the successor and make // this dead. - // We should not eliminate the loop header either, because eliminating - // a loop header might later prevent LoopSimplify from transforming nested - // loops into simplified form. + // We should not eliminate the loop header or latch either, because + // eliminating a loop header or latch might later prevent LoopSimplify + // from transforming nested loops into simplified form. We will rely on + // later passes in backend to clean up empty blocks. if (BI && BI->isUnconditional() && BB != &BB->getParent()->getEntryBlock() && // If the terminator is the only non-phi instruction, try to nuke it. - BB->getFirstNonPHIOrDbg()->isTerminator() && !LoopHeaders.count(BB)) { - // Since TryToSimplifyUncondBranchFromEmptyBlock may delete the - // block, we have to make sure it isn't in the LoopHeaders set. We - // reinsert afterward if needed. - bool ErasedFromLoopHeaders = LoopHeaders.erase(BB); - BasicBlock *Succ = BI->getSuccessor(0); - + BB->getFirstNonPHIOrDbg()->isTerminator() && !LoopHeaders.count(BB) && + !LoopHeaders.count(BI->getSuccessor(0))) { // FIXME: It is always conservatively correct to drop the info // for a block even if it doesn't get erased. This isn't totally // awesome, but it allows us to use AssertingVH to prevent nasty // dangling pointer issues within LazyValueInfo. LVI->eraseBlock(BB); - if (TryToSimplifyUncondBranchFromEmptyBlock(BB)) { + if (TryToSimplifyUncondBranchFromEmptyBlock(BB)) Changed = true; - // If we deleted BB and BB was the header of a loop, then the - // successor is now the header of the loop. - BB = Succ; - } - - if (ErasedFromLoopHeaders) - LoopHeaders.insert(BB); } } EverChanged |= Changed; @@ -255,10 +268,42 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, return EverChanged; } -/// getJumpThreadDuplicationCost - Return the cost of duplicating this block to -/// thread across it. Stop scanning the block when passing the threshold. -static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB, +// Replace uses of Cond with ToVal when safe to do so. If all uses are +// replaced, we can remove Cond. 
We cannot blindly replace all uses of Cond
+// because we may incorrectly replace uses when guards/assumes are uses of
+// `Cond` and we used the guards/assumes to reason about the `Cond` value
+// at the end of the block. RAUW unconditionally replaces all uses
+// including the guards/assumes themselves and the uses before the
+// guard/assume.
+static void ReplaceFoldableUses(Instruction *Cond, Value *ToVal) {
+  assert(Cond->getType() == ToVal->getType());
+  auto *BB = Cond->getParent();
+  // We can unconditionally replace all uses in non-local blocks (i.e. uses
+  // strictly dominated by BB), since LVI information is true from the
+  // terminator of BB.
+  replaceNonLocalUsesWith(Cond, ToVal);
+  for (Instruction &I : reverse(*BB)) {
+    // Reached the Cond whose uses we are trying to replace, so there are no
+    // more uses.
+    if (&I == Cond)
+      break;
+    // We only replace uses in instructions that are guaranteed to reach the
+    // end of BB, where we know Cond is ToVal.
+    if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+      break;
+    I.replaceUsesOfWith(Cond, ToVal);
+  }
+  if (Cond->use_empty() && !Cond->mayHaveSideEffects())
+    Cond->eraseFromParent();
+}
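A minimal illustration, on hypothetical IR (names invented), of the hazard the backwards walk above avoids:

//   %cond = icmp eq i32 %x, 0
//   call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]
//   br i1 %cond, label %taken, label %not.taken
//
// LVI derives "%cond is true" from the guard itself, so the fact only holds
// from the guard to the end of the block. A plain RAUW of %cond with `true`
// would also rewrite the guard's own operand, deleting the very check the
// fact came from; stopping at the first instruction that may not transfer
// execution to its successor keeps the guard and everything above it intact.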
+
+/// Return the cost of duplicating a piece of this block, from the first
+/// non-PHI up to the StopAt instruction, to thread across it. Stop scanning
+/// the block when exceeding the threshold. If duplication is impossible,
+/// returns ~0U.
+static unsigned getJumpThreadDuplicationCost(BasicBlock *BB,
+                                             Instruction *StopAt,
                                              unsigned Threshold) {
+  assert(StopAt->getParent() == BB && "Not an instruction from proper BB?");
   /// Ignore PHI nodes, these will be flattened when duplication happens.
   BasicBlock::const_iterator I(BB->getFirstNonPHI());
@@ -266,15 +311,17 @@
   // branch, so they shouldn't count against the duplication cost.
   unsigned Bonus = 0;
-  const TerminatorInst *BBTerm = BB->getTerminator();
-  // Threading through a switch statement is particularly profitable. If this
-  // block ends in a switch, decrease its cost to make it more likely to happen.
-  if (isa<SwitchInst>(BBTerm))
-    Bonus = 6;
-
-  // The same holds for indirect branches, but slightly more so.
-  if (isa<IndirectBrInst>(BBTerm))
-    Bonus = 8;
+  if (BB->getTerminator() == StopAt) {
+    // Threading through a switch statement is particularly profitable. If this
+    // block ends in a switch, decrease its cost to make it more likely to
+    // happen.
+    if (isa<SwitchInst>(StopAt))
+      Bonus = 6;
+
+    // The same holds for indirect branches, but slightly more so.
+    if (isa<IndirectBrInst>(StopAt))
+      Bonus = 8;
+  }
 
   // Bump the threshold up so the early exit from the loop doesn't skip the
   // terminator-based Size adjustment at the end.
@@ -283,7 +330,7 @@
   // Sum up the cost of each instruction until we get to the terminator. Don't
   // include the terminator because the copy won't include it.
   unsigned Size = 0;
-  for (; !isa<TerminatorInst>(I); ++I) {
+  for (; &*I != StopAt; ++I) {
 
     // Stop scanning the block if we've reached the threshold.
     if (Size > Threshold)
@@ -544,7 +591,12 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors(
 
   // Handle compare with phi operand, where the PHI is defined in this block.
   if (CmpInst *Cmp = dyn_cast<CmpInst>(I)) {
     assert(Preference == WantInteger && "Compares only produce integers");
-    PHINode *PN = dyn_cast<PHINode>(Cmp->getOperand(0));
+    Type *CmpType = Cmp->getType();
+    Value *CmpLHS = Cmp->getOperand(0);
+    Value *CmpRHS = Cmp->getOperand(1);
+    CmpInst::Predicate Pred = Cmp->getPredicate();
+
+    PHINode *PN = dyn_cast<PHINode>(CmpLHS);
     if (PN && PN->getParent() == BB) {
       const DataLayout &DL = PN->getModule()->getDataLayout();
       // We can do this simplification if any comparisons fold to true or false.
@@ -552,15 +604,15 @@
       for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
         BasicBlock *PredBB = PN->getIncomingBlock(i);
         Value *LHS = PN->getIncomingValue(i);
-        Value *RHS = Cmp->getOperand(1)->DoPHITranslation(BB, PredBB);
+        Value *RHS = CmpRHS->DoPHITranslation(BB, PredBB);
 
-        Value *Res = SimplifyCmpInst(Cmp->getPredicate(), LHS, RHS, DL);
+        Value *Res = SimplifyCmpInst(Pred, LHS, RHS, {DL});
         if (!Res) {
           if (!isa<Constant>(RHS))
             continue;
 
           LazyValueInfo::Tristate
-            ResT = LVI->getPredicateOnEdge(Cmp->getPredicate(), LHS,
+            ResT = LVI->getPredicateOnEdge(Pred, LHS,
                                            cast<Constant>(RHS), PredBB, BB,
                                            CxtI ? CxtI : Cmp);
           if (ResT == LazyValueInfo::Unknown)
@@ -577,44 +629,81 @@
 
     // If comparing a live-in value against a constant, see if we know the
     // live-in value on any predecessors.
-    if (isa<Constant>(Cmp->getOperand(1)) && Cmp->getType()->isIntegerTy()) {
-      if (!isa<Instruction>(Cmp->getOperand(0)) ||
-          cast<Instruction>(Cmp->getOperand(0))->getParent() != BB) {
-        Constant *RHSCst = cast<Constant>(Cmp->getOperand(1));
+    if (isa<Constant>(CmpRHS) && !CmpType->isVectorTy()) {
+      Constant *CmpConst = cast<Constant>(CmpRHS);
 
+      if (!isa<Instruction>(CmpLHS) ||
+          cast<Instruction>(CmpLHS)->getParent() != BB) {
         for (BasicBlock *P : predecessors(BB)) {
           // If the value is known by LazyValueInfo to be a constant in a
           // predecessor, use that information to try to thread this block.
           LazyValueInfo::Tristate Res =
-            LVI->getPredicateOnEdge(Cmp->getPredicate(), Cmp->getOperand(0),
-                                    RHSCst, P, BB, CxtI ? CxtI : Cmp);
+            LVI->getPredicateOnEdge(Pred, CmpLHS,
                                     CmpConst, P, BB, CxtI ? CxtI : Cmp);
           if (Res == LazyValueInfo::Unknown)
             continue;
 
-          Constant *ResC = ConstantInt::get(Cmp->getType(), Res);
+          Constant *ResC = ConstantInt::get(CmpType, Res);
           Result.push_back(std::make_pair(ResC, P));
         }
 
         return !Result.empty();
       }
 
+      // InstCombine can fold some forms of constant range checks into
+      // (icmp (add (x, C1)), C2). See if we have such a thing with
+      // x as a live-in.
+      {
+        using namespace PatternMatch;
+        Value *AddLHS;
+        ConstantInt *AddConst;
+        if (isa<ConstantInt>(CmpConst) &&
+            match(CmpLHS, m_Add(m_Value(AddLHS), m_ConstantInt(AddConst)))) {
+          if (!isa<Instruction>(AddLHS) ||
+              cast<Instruction>(AddLHS)->getParent() != BB) {
+            for (BasicBlock *P : predecessors(BB)) {
+              // If the value is known by LazyValueInfo to be a ConstantRange in
+              // a predecessor, use that information to try to thread this
+              // block.
+              ConstantRange CR = LVI->getConstantRangeOnEdge(
+                  AddLHS, P, BB, CxtI ? CxtI : cast<Instruction>(CmpLHS));
+              // Propagate the range through the addition.
+              CR = CR.add(AddConst->getValue());
+
+              // Get the range where the compare returns true.
+              ConstantRange CmpRange = ConstantRange::makeExactICmpRegion(
+                  Pred, cast<ConstantInt>(CmpConst)->getValue());
+
+              Constant *ResC;
+              if (CmpRange.contains(CR))
+                ResC = ConstantInt::getTrue(CmpType);
+              else if (CmpRange.inverse().contains(CR))
+                ResC = ConstantInt::getFalse(CmpType);
+              else
+                continue;
+
+              Result.push_back(std::make_pair(ResC, P));
+            }
+
+            return !Result.empty();
+          }
+        }
+      }
+
       // Try to find a constant value for the LHS of a comparison,
       // and evaluate it statically if we can.
-      if (Constant *CmpConst = dyn_cast<Constant>(Cmp->getOperand(1))) {
-        PredValueInfoTy LHSVals;
-        ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals,
-                                        WantInteger, CxtI);
-
-        for (const auto &LHSVal : LHSVals) {
-          Constant *V = LHSVal.first;
-          Constant *Folded = ConstantExpr::getCompare(Cmp->getPredicate(),
-                                                      V, CmpConst);
-          if (Constant *KC = getKnownConstant(Folded, WantInteger))
-            Result.push_back(std::make_pair(KC, LHSVal.second));
-        }
+      PredValueInfoTy LHSVals;
+      ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals,
                                       WantInteger, CxtI);
 
-        return !Result.empty();
+      for (const auto &LHSVal : LHSVals) {
+        Constant *V = LHSVal.first;
+        Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);
+        if (Constant *KC = getKnownConstant(Folded, WantInteger))
+          Result.push_back(std::make_pair(KC, LHSVal.second));
       }
+
+      return !Result.empty();
     }
   }
@@ -722,6 +811,37 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
         LVI->eraseBlock(SinglePred);
         MergeBasicBlockIntoOnlyPred(BB);
 
+        // Now that BB is merged into SinglePred (i.e. SinglePred code followed
+        // by BB code within one basic block `BB`), we need to invalidate the
+        // LVI information associated with BB, because the LVI information
+        // need not be true for all of BB after the merge. For example,
+        // before the merge, LVI info and code is as follows:
+        // SinglePred: <LVI info1 for %p val>
+        // %y = use of %p
+        // call @exit() // need not transfer execution to successor.
+        // assume(%p) // from this point on %p is true
+        // br label %BB
+        // BB: <LVI info2 for %p val, i.e. %p is true>
+        // %x = use of %p
+        // br label exit
+        //
+        // Note that this LVI info for blocks BB and SinglePred is correct for
+        // %p (info2 and info1 respectively). After the merge and the deletion
+        // of LVI info1 for SinglePred, we have the following code:
+        // BB: <LVI info2 for %p val>
+        // %y = use of %p
+        // call @exit()
+        // assume(%p)
+        // %x = use of %p <-- LVI info2 is correct from here onwards.
+        // br label exit
+        // LVI info2 for BB is incorrect at the beginning of BB.
 
+        // Invalidate LVI information for BB if the LVI is not provably true
+        // for all of BB.
+        if (any_of(*BB, [](Instruction &I) {
+              return !isGuaranteedToTransferExecutionToSuccessor(&I);
+            }))
+          LVI->eraseBlock(BB);
         return true;
       }
   }
@@ -729,6 +849,10 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
   if (TryToUnfoldSelectInCurrBB(BB))
     return true;
 
+  // See if we can propagate guards to predecessors.
+  if (HasGuards && ProcessGuards(BB))
+    return true;
+
   // What kind of constant we're looking for.
   ConstantPreference Preference = WantInteger;
@@ -804,7 +928,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
     return false;
   }
 
-
   if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst)) {
     // If we're branching on a conditional, LVI might be able to determine
     // its value at the branch instruction.
We only handle comparisons
@@ -812,7 +935,12 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
     // TODO: This should be extended to handle switches as well.
     BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
     Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1));
-    if (CondBr && CondConst && CondBr->isConditional()) {
+    if (CondBr && CondConst) {
+      // We should have returned as soon as we turn a conditional branch to
+      // unconditional, because it's no longer interesting as far as jump
+      // threading is concerned.
+      assert(CondBr->isConditional() && "Threading on unconditional terminator");
+
       LazyValueInfo::Tristate Ret =
         LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0),
                             CondConst, CondBr);
@@ -824,21 +952,27 @@
           CondBr->eraseFromParent();
           if (CondCmp->use_empty())
            CondCmp->eraseFromParent();
+          // We can safely replace *some* uses of the CondInst if it has
+          // exactly one value as returned by LVI. RAUW is incorrect in the
+          // presence of guards and assumes that have `Cond` as a use. This
+          // is because we use the guards/assumes to reason about the `Cond`
+          // value at the end of the block, but RAUW unconditionally replaces
+          // all uses including the guards/assumes themselves and the uses
+          // before the guard/assume.
           else if (CondCmp->getParent() == BB) {
-            // If the fact we just learned is true for all uses of the
-            // condition, replace it with a constant value
            auto *CI = Ret == LazyValueInfo::True ?
              ConstantInt::getTrue(CondCmp->getType()) :
              ConstantInt::getFalse(CondCmp->getType());
-            CondCmp->replaceAllUsesWith(CI);
-            CondCmp->eraseFromParent();
+            ReplaceFoldableUses(CondCmp, CI);
          }
          return true;
        }
-    }
 
-    if (CondBr && CondConst && TryToUnfoldSelect(CondCmp, BB))
-      return true;
+      // We did not manage to simplify this branch, try to see whether
+      // CondCmp depends on a known phi-select pattern.
+      if (TryToUnfoldSelect(CondCmp, BB))
+        return true;
+    }
  }
 
  // Check for some cases that are worth simplifying. Right now we want to look
@@ -857,7 +991,6 @@
    if (SimplifyPartiallyRedundantLoad(LI))
      return true;
 
-
  // Handle a variety of cases where we are branching on something derived from
  // a PHI node in the current block. If we can prove that any predecessors
  // compute a predictable value based on a PHI node, thread those predecessors.
@@ -871,7 +1004,6 @@
    if (PN->getParent() == BB && isa<BranchInst>(BB->getTerminator()))
      return ProcessBranchOnPHI(PN);
 
-
  // If this is an otherwise-unfoldable branch on a XOR, see if we can simplify.
  if (CondInst->getOpcode() == Instruction::Xor &&
      CondInst->getParent() == BB && isa<BranchInst>(BB->getTerminator()))
@@ -920,6 +1052,14 @@ bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) {
  return false;
}
 
+/// Return true if Op is an instruction defined in the given block.
+static bool isOpDefinedInBlock(Value *Op, BasicBlock *BB) {
+  if (Instruction *OpInst = dyn_cast<Instruction>(Op))
+    if (OpInst->getParent() == BB)
+      return true;
+  return false;
+}
+
/// SimplifyPartiallyRedundantLoad - If LI is an obviously partially redundant
/// load instruction, eliminate it by replacing it with a PHI node.
This is an
/// important optimization that encourages jump threading, and needs to be run
@@ -942,18 +1082,17 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
 
   Value *LoadedPtr = LI->getOperand(0);
 
-  // If the loaded operand is defined in the LoadBB, it can't be available.
-  // TODO: Could do simple PHI translation, that would be fun :)
-  if (Instruction *PtrOp = dyn_cast<Instruction>(LoadedPtr))
-    if (PtrOp->getParent() == LoadBB)
-      return false;
+  // If the loaded operand is defined in the LoadBB and it's not a phi,
+  // it can't be available in predecessors.
+  if (isOpDefinedInBlock(LoadedPtr, LoadBB) && !isa<PHINode>(LoadedPtr))
+    return false;
 
   // Scan a few instructions up from the load, to see if it is obviously live at
   // the entry to its block.
   BasicBlock::iterator BBIt(LI);
   bool IsLoadCSE;
-  if (Value *AvailableVal =
-      FindAvailableLoadedValue(LI, LoadBB, BBIt, DefMaxInstsToScan, nullptr, &IsLoadCSE)) {
+  if (Value *AvailableVal = FindAvailableLoadedValue(
+          LI, LoadBB, BBIt, DefMaxInstsToScan, AA, &IsLoadCSE)) {
     // If the value of the load is locally available within the block, just use
     // it. This frequently occurs for reg2mem'd allocas.
@@ -997,12 +1136,34 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
     if (!PredsScanned.insert(PredBB).second)
       continue;
 
-    // Scan the predecessor to see if the value is available in the pred.
     BBIt = PredBB->end();
-    Value *PredAvailable = FindAvailableLoadedValue(LI, PredBB, BBIt,
-                                                    DefMaxInstsToScan,
-                                                    nullptr,
-                                                    &IsLoadCSE);
+    unsigned NumScanedInst = 0;
+    Value *PredAvailable = nullptr;
+    // NOTE: We don't CSE loads that are volatile or anything stronger than
+    // unordered; that should have been checked when we entered the function.
+    assert(LI->isUnordered() && "Attempting to CSE volatile or atomic loads");
+    // If this is a load on a phi pointer, phi-translate it and search
+    // for available load/store to the pointer in predecessors.
+    Value *Ptr = LoadedPtr->DoPHITranslation(LoadBB, PredBB);
+    PredAvailable = FindAvailablePtrLoadStore(
+        Ptr, LI->getType(), LI->isAtomic(), PredBB, BBIt, DefMaxInstsToScan,
+        AA, &IsLoadCSE, &NumScanedInst);
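A minimal sketch (the standalone helper below is hypothetical, not part of the patch) of the phi-translation step this scan relies on:

// If the loaded pointer is a PHI defined in LoadBB, scanning a predecessor
// must use that predecessor's incoming value; any other pointer is returned
// unchanged. Value::DoPHITranslation implements exactly this fallback.
static Value *translateLoadPtr(Value *LoadedPtr, BasicBlock *LoadBB,
                               BasicBlock *PredBB) {
  return LoadedPtr->DoPHITranslation(LoadBB, PredBB);
}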
+
+    // If PredBB has a single predecessor, continue scanning through the
+    // single predecessor.
+    BasicBlock *SinglePredBB = PredBB;
+    while (!PredAvailable && SinglePredBB && BBIt == SinglePredBB->begin() &&
+           NumScanedInst < DefMaxInstsToScan) {
+      SinglePredBB = SinglePredBB->getSinglePredecessor();
+      if (SinglePredBB) {
+        BBIt = SinglePredBB->end();
+        PredAvailable = FindAvailablePtrLoadStore(
+            Ptr, LI->getType(), LI->isAtomic(), SinglePredBB, BBIt,
+            (DefMaxInstsToScan - NumScanedInst), AA, &IsLoadCSE,
+            &NumScanedInst);
+      }
+    }
+
     if (!PredAvailable) {
       OneUnavailablePred = PredBB;
       continue;
@@ -1062,10 +1223,10 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
   if (UnavailablePred) {
     assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 &&
            "Can't handle critical edge here!");
-    LoadInst *NewVal =
-        new LoadInst(LoadedPtr, LI->getName() + ".pr", false,
-                     LI->getAlignment(), LI->getOrdering(), LI->getSynchScope(),
-                     UnavailablePred->getTerminator());
+    LoadInst *NewVal = new LoadInst(
+        LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred),
+        LI->getName() + ".pr", false, LI->getAlignment(), LI->getOrdering(),
+        LI->getSyncScopeID(), UnavailablePred->getTerminator());
     NewVal->setDebugLoc(LI->getDebugLoc());
     if (AATags)
       NewVal->setAAMetadata(AATags);
@@ -1210,37 +1371,53 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
   BasicBlock *OnlyDest = nullptr;
   BasicBlock *MultipleDestSentinel = (BasicBlock*)(intptr_t)~0ULL;
+  Constant *OnlyVal = nullptr;
+  Constant *MultipleVal = (Constant *)(intptr_t)~0ULL;
 
+  unsigned PredWithKnownDest = 0;
   for (const auto &PredValue : PredValues) {
     BasicBlock *Pred = PredValue.second;
     if (!SeenPreds.insert(Pred).second)
       continue;  // Duplicate predecessor entry.
 
-    // If the predecessor ends with an indirect goto, we can't change its
-    // destination.
-    if (isa<IndirectBrInst>(Pred->getTerminator()))
-      continue;
-
     Constant *Val = PredValue.first;
 
     BasicBlock *DestBB;
     if (isa<UndefValue>(Val))
       DestBB = nullptr;
-    else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()))
+    else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
+      assert(isa<ConstantInt>(Val) && "Expecting a constant integer");
       DestBB = BI->getSuccessor(cast<ConstantInt>(Val)->isZero());
-    else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
-      DestBB = SI->findCaseValue(cast<ConstantInt>(Val)).getCaseSuccessor();
+    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
+      assert(isa<ConstantInt>(Val) && "Expecting a constant integer");
+      DestBB = SI->findCaseValue(cast<ConstantInt>(Val))->getCaseSuccessor();
     } else {
       assert(isa<IndirectBrInst>(BB->getTerminator())
              && "Unexpected terminator");
+      assert(isa<BlockAddress>(Val) && "Expecting a constant blockaddress");
       DestBB = cast<BlockAddress>(Val)->getBasicBlock();
     }
 
     // If we have exactly one destination, remember it for efficiency below.
-    if (PredToDestList.empty())
+    if (PredToDestList.empty()) {
       OnlyDest = DestBB;
-    else if (OnlyDest != DestBB)
-      OnlyDest = MultipleDestSentinel;
+      OnlyVal = Val;
+    } else {
+      if (OnlyDest != DestBB)
+        OnlyDest = MultipleDestSentinel;
+      // It is possible we have the same destination but a different value,
+      // e.g. the default case in a switchinst.
+      if (Val != OnlyVal)
+        OnlyVal = MultipleVal;
+    }
+
+    // We know where this predecessor is going.
+    ++PredWithKnownDest;
+
+    // If the predecessor ends with an indirect goto, we can't change its
+    // destination.
+    if (isa<IndirectBrInst>(Pred->getTerminator()))
+      continue;
 
     PredToDestList.push_back(std::make_pair(Pred, DestBB));
   }
@@ -1249,6 +1426,45 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,
   if (PredToDestList.empty())
     return false;
 
+  // If all the predecessors go to a single known successor, we want to fold,
+  // not thread. By doing so, we do not need to duplicate the current block and
+  // also miss potential opportunities in case we don't/can't duplicate.
+  if (OnlyDest && OnlyDest != MultipleDestSentinel) {
+    if (PredWithKnownDest ==
+        (size_t)std::distance(pred_begin(BB), pred_end(BB))) {
+      bool SeenFirstBranchToOnlyDest = false;
+      for (BasicBlock *SuccBB : successors(BB)) {
+        if (SuccBB == OnlyDest && !SeenFirstBranchToOnlyDest)
+          SeenFirstBranchToOnlyDest = true; // Don't modify the first branch.
+        else
+          SuccBB->removePredecessor(BB, true); // This is an unreachable successor.
+      }
+
+      // Finally update the terminator.
+      TerminatorInst *Term = BB->getTerminator();
+      BranchInst::Create(OnlyDest, Term);
+      Term->eraseFromParent();
+
+      // If the condition is now dead due to the removal of the old terminator,
+      // erase it.
+      if (auto *CondInst = dyn_cast<Instruction>(Cond)) {
+        if (CondInst->use_empty() && !CondInst->mayHaveSideEffects())
+          CondInst->eraseFromParent();
+        // We can safely replace *some* uses of the CondInst if it has
+        // exactly one value as returned by LVI. RAUW is incorrect in the
+        // presence of guards and assumes that have `Cond` as a use. This
+        // is because we use the guards/assumes to reason about the `Cond`
+        // value at the end of the block, but RAUW unconditionally replaces
+        // all uses including the guards/assumes themselves and the uses
+        // before the guard/assume.
+        else if (OnlyVal && OnlyVal != MultipleVal &&
+                 CondInst->getParent() == BB)
+          ReplaceFoldableUses(CondInst, OnlyVal);
+      }
+      return true;
+    }
+  }
+
   // Determine which is the most common successor. If we have many inputs and
   // this block is a switch, we want to start by threading the batch that goes
   // to the most popular destination first. If we only know about one
@@ -1468,7 +1684,8 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB,
     return false;
   }
 
-  unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold);
+  unsigned JumpThreadCost =
+      getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
   if (JumpThreadCost > BBDupThreshold) {
     DEBUG(dbgs() << "  Not threading BB '" << BB->getName()
           << "' - Cost is too high: " << JumpThreadCost << "\n");
@@ -1756,7 +1973,8 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(
     return false;
   }
 
-  unsigned DuplicationCost = getJumpThreadDuplicationCost(BB, BBDupThreshold);
+  unsigned DuplicationCost =
+      getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
   if (DuplicationCost > BBDupThreshold) {
     DEBUG(dbgs() << "  Not duplicating BB '" << BB->getName()
           << "' - Cost is too high: " << DuplicationCost << "\n");
@@ -1811,11 +2029,12 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(
     // If this instruction can be simplified after the operands are updated,
     // just use the simplified value instead. This frequently happens due to
    // phi translation.
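One concrete case of that simplification, on hypothetical IR:

// Duplicating `%r = and i1 %p, %q` into a predecessor where the
// phi-translated %p is the constant `true` lets SimplifyInstruction fold the
// copy to plain %q, so no new instruction needs to be emitted there at all.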
-    if (Value *IV =
-            SimplifyInstruction(New, BB->getModule()->getDataLayout())) {
+    if (Value *IV = SimplifyInstruction(
+            New,
+            {BB->getModule()->getDataLayout(), TLI, nullptr, nullptr, New})) {
       ValueMapping[&*BI] = IV;
       if (!New->mayHaveSideEffects()) {
-        delete New;
+        New->deleteValue();
         New = nullptr;
       }
     } else {
@@ -1888,10 +2107,10 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(
 /// TryToUnfoldSelect - Look for blocks of the form
 /// bb1:
 ///   %a = select
-///   br bb
+///   br bb2
 ///
 /// bb2:
-///   %p = phi [%a, %bb] ...
+///   %p = phi [%a, %bb1] ...
 ///   %c = icmp %p
 ///   br i1 %c
 ///
@@ -1963,11 +2182,19 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) {
   return false;
 }
 
-/// TryToUnfoldSelectInCurrBB - Look for PHI/Select in the same BB of the form
+/// TryToUnfoldSelectInCurrBB - Look for PHI/Select or PHI/CMP/Select in the
+/// same BB in the form
 /// bb:
 ///   %p = phi [false, %bb1], [true, %bb2], [false, %bb3], [true, %bb4], ...
-///   %s = select p, trueval, falseval
+///   %s = select %p, trueval, falseval
 ///
+/// or
+///
+/// bb:
+///   %p = phi [0, %bb1], [1, %bb2], [0, %bb3], [1, %bb4], ...
+///   %c = cmp %p, 0
+///   %s = select %c, trueval, falseval
+///
 /// And expand the select into a branch structure. This later enables
 /// jump-threading over bb in this pass.
 ///
@@ -1981,43 +2208,180 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) {
   if (LoopHeaders.count(BB))
     return false;
 
-  // Look for a Phi/Select pair in the same basic block. The Phi feeds the
-  // condition of the Select and at least one of the incoming values is a
-  // constant.
   for (BasicBlock::iterator BI = BB->begin();
        PHINode *PN = dyn_cast<PHINode>(BI); ++BI) {
-    unsigned NumPHIValues = PN->getNumIncomingValues();
-    if (NumPHIValues == 0 || !PN->hasOneUse())
+    // Look for a Phi having at least one constant incoming value.
+    if (llvm::all_of(PN->incoming_values(),
+                     [](Value *V) { return !isa<ConstantInt>(V); }))
       continue;
 
-    SelectInst *SI = dyn_cast<SelectInst>(PN->user_back());
-    if (!SI || SI->getParent() != BB)
-      continue;
+    auto isUnfoldCandidate = [BB](SelectInst *SI, Value *V) {
+      // Check if SI is in BB and uses V as its condition.
+      if (SI->getParent() != BB)
+        return false;
+      Value *Cond = SI->getCondition();
+      return (Cond && Cond == V && Cond->getType()->isIntegerTy(1));
+    };
 
-    Value *Cond = SI->getCondition();
-    if (!Cond || Cond != PN || !Cond->getType()->isIntegerTy(1))
+    SelectInst *SI = nullptr;
+    for (Use &U : PN->uses()) {
+      if (ICmpInst *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
+        // Look for an ICmp in BB that compares PN with a constant and is the
+        // condition of a Select.
+        if (Cmp->getParent() == BB && Cmp->hasOneUse() &&
+            isa<ConstantInt>(Cmp->getOperand(1 - U.getOperandNo())))
+          if (SelectInst *SelectI = dyn_cast<SelectInst>(Cmp->user_back()))
+            if (isUnfoldCandidate(SelectI, Cmp->use_begin()->get())) {
+              SI = SelectI;
+              break;
+            }
+      } else if (SelectInst *SelectI = dyn_cast<SelectInst>(U.getUser())) {
+        // Look for a Select in BB that uses PN as condition.
+        if (isUnfoldCandidate(SelectI, U.get())) {
+          SI = SelectI;
+          break;
+        }
+      }
+    }
+
+    if (!SI)
       continue;
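For reference, the shape of the expansion performed next, on hypothetical blocks:

// Before:
//   bb:
//     %s = select i1 %c, %tv, %fv
// After:
//   bb:
//     br i1 %c, label %bb.then, label %bb.split
//   bb.then:
//     br label %bb.split
//   bb.split:
//     %s = phi [ %tv, %bb.then ], [ %fv, %bb ]
// The select's condition now feeds a conditional branch that the rest of the
// pass can thread over.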
+    // Expand the select.
+    TerminatorInst *Term =
+        SplitBlockAndInsertIfThen(SI->getCondition(), SI, false);
+    PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI);
+    NewPN->addIncoming(SI->getTrueValue(), Term->getParent());
+    NewPN->addIncoming(SI->getFalseValue(), BB);
+    SI->replaceAllUsesWith(NewPN);
+    SI->eraseFromParent();
+    return true;
+  }
+  return false;
+}
 
-    bool HasConst = false;
-    for (unsigned i = 0; i != NumPHIValues; ++i) {
-      if (PN->getIncomingBlock(i) == BB)
-        return false;
-      if (isa<ConstantInt>(PN->getIncomingValue(i)))
-        HasConst = true;
-    }
+/// Try to propagate a guard from the current BB into one of its predecessors
+/// in case another branch of execution implies that the condition of this
+/// guard is always true. Currently we only process the simplest case that
+/// looks like:
+///
+/// Start:
+///   %cond = ...
+///   br i1 %cond, label %T1, label %F1
+/// T1:
+///   br label %Merge
+/// F1:
+///   br label %Merge
+/// Merge:
+///   %condGuard = ...
+///   call void(i1, ...) @llvm.experimental.guard( i1 %condGuard )[ "deopt"() ]
+///
+/// And cond either implies condGuard or !condGuard. In this case all the
+/// instructions before the guard can be duplicated in both branches, and the
+/// guard is then threaded to one of them.
+bool JumpThreadingPass::ProcessGuards(BasicBlock *BB) {
+  using namespace PatternMatch;
+  // We only want to deal with two predecessors.
+  BasicBlock *Pred1, *Pred2;
+  auto PI = pred_begin(BB), PE = pred_end(BB);
+  if (PI == PE)
+    return false;
+  Pred1 = *PI++;
+  if (PI == PE)
+    return false;
+  Pred2 = *PI++;
+  if (PI != PE)
+    return false;
+  if (Pred1 == Pred2)
+    return false;
 
-    if (HasConst) {
-      // Expand the select.
-      TerminatorInst *Term =
-          SplitBlockAndInsertIfThen(SI->getCondition(), SI, false);
-      PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI);
-      NewPN->addIncoming(SI->getTrueValue(), Term->getParent());
-      NewPN->addIncoming(SI->getFalseValue(), BB);
-      SI->replaceAllUsesWith(NewPN);
-      SI->eraseFromParent();
-      return true;
+  // Try to thread one of the guards of the block.
+  // TODO: Look deeper than the immediate predecessor?
+  auto *Parent = Pred1->getSinglePredecessor();
+  if (!Parent || Parent != Pred2->getSinglePredecessor())
+    return false;
+
+  if (auto *BI = dyn_cast<BranchInst>(Parent->getTerminator()))
+    for (auto &I : *BB)
+      if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>()))
+        if (ThreadGuard(BB, cast<IntrinsicInst>(&I), BI))
+          return true;
+
+  return false;
+}
+
+/// Try to propagate the guard from BB which is the lower block of a diamond
+/// to one of its branches, in case the diamond's condition implies the
+/// guard's condition.
+bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard,
+                                    BranchInst *BI) {
+  assert(BI->getNumSuccessors() == 2 && "Wrong number of successors?");
+  assert(BI->isConditional() && "Unconditional branch has 2 successors?");
+  Value *GuardCond = Guard->getArgOperand(0);
+  Value *BranchCond = BI->getCondition();
+  BasicBlock *TrueDest = BI->getSuccessor(0);
+  BasicBlock *FalseDest = BI->getSuccessor(1);
+
+  auto &DL = BB->getModule()->getDataLayout();
+  bool TrueDestIsSafe = false;
+  bool FalseDestIsSafe = false;
+
+  // True dest is safe if BranchCond => GuardCond.
+  auto Impl = isImpliedCondition(BranchCond, GuardCond, DL);
+  if (Impl && *Impl)
+    TrueDestIsSafe = true;
+  else {
+    // False dest is safe if !BranchCond => GuardCond.
+ Impl = + isImpliedCondition(BranchCond, GuardCond, DL, /* InvertAPred */ true); + if (Impl && *Impl) + FalseDestIsSafe = true; + } + + if (!TrueDestIsSafe && !FalseDestIsSafe) + return false; + + BasicBlock *UnguardedBlock = TrueDestIsSafe ? TrueDest : FalseDest; + BasicBlock *GuardedBlock = FalseDestIsSafe ? TrueDest : FalseDest; + + ValueToValueMapTy UnguardedMapping, GuardedMapping; + Instruction *AfterGuard = Guard->getNextNode(); + unsigned Cost = getJumpThreadDuplicationCost(BB, AfterGuard, BBDupThreshold); + if (Cost > BBDupThreshold) + return false; + // Duplicate all instructions before the guard and the guard itself to the + // branch where implication is not proved. + GuardedBlock = DuplicateInstructionsInSplitBetween( + BB, GuardedBlock, AfterGuard, GuardedMapping); + assert(GuardedBlock && "Could not create the guarded block?"); + // Duplicate all instructions before the guard in the unguarded branch. + // Since we have successfully duplicated the guarded block and this block + // has fewer instructions, we expect it to succeed. + UnguardedBlock = DuplicateInstructionsInSplitBetween(BB, UnguardedBlock, + Guard, UnguardedMapping); + assert(UnguardedBlock && "Could not create the unguarded block?"); + DEBUG(dbgs() << "Moved guard " << *Guard << " to block " + << GuardedBlock->getName() << "\n"); + + // Some instructions before the guard may still have uses. For them, we need + // to create Phi nodes merging their copies in both guarded and unguarded + // branches. Those instructions that have no uses can be just removed. + SmallVector<Instruction *, 4> ToRemove; + for (auto BI = BB->begin(); &*BI != AfterGuard; ++BI) + if (!isa<PHINode>(&*BI)) + ToRemove.push_back(&*BI); + + Instruction *InsertionPoint = &*BB->getFirstInsertionPt(); + assert(InsertionPoint && "Empty block?"); + // Substitute with Phis & remove. + for (auto *Inst : reverse(ToRemove)) { + if (!Inst->use_empty()) { + PHINode *NewPN = PHINode::Create(Inst->getType(), 2); + NewPN->addIncoming(UnguardedMapping[Inst], UnguardedBlock); + NewPN->addIncoming(GuardedMapping[Inst], GuardedBlock); + NewPN->insertBefore(InsertionPoint); + Inst->replaceAllUsesWith(NewPN); } + Inst->eraseFromParent(); } - - return false; + return true; } diff --git a/contrib/llvm/lib/Transforms/Scalar/LICM.cpp b/contrib/llvm/lib/Transforms/Scalar/LICM.cpp index f51d11c..37b9c4b 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LICM.cpp @@ -77,10 +77,16 @@ STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk"); STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk"); STATISTIC(NumPromoted, "Number of memory locations promoted to registers"); +/// Memory promotion is enabled by default. 
static cl::opt<bool> - DisablePromotion("disable-licm-promotion", cl::Hidden, + DisablePromotion("disable-licm-promotion", cl::Hidden, cl::init(false), cl::desc("Disable memory promotion in LICM pass")); +static cl::opt<uint32_t> MaxNumUsesTraversed( + "licm-max-num-uses-traversed", cl::Hidden, cl::init(8), + cl::desc("Max num uses visited for identifying load " + "invariance in loop using invariant start (default = 8)")); + static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo); @@ -201,9 +207,9 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.SE, ORE, true)) return PreservedAnalyses::all(); - // FIXME: There is no setPreservesCFG in the new PM. When that becomes - // available, it should be used here. - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + PA.preserveSet<CFGAnalyses>(); + return PA; } char LegacyLICMPass::ID = 0; @@ -425,6 +431,29 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, continue; } + // Attempt to remove floating point division out of the loop by converting + // it to a reciprocal multiplication. + if (I.getOpcode() == Instruction::FDiv && + CurLoop->isLoopInvariant(I.getOperand(1)) && + I.hasAllowReciprocal()) { + auto Divisor = I.getOperand(1); + auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); + auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); + ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); + ReciprocalDivisor->insertBefore(&I); + + auto Product = BinaryOperator::CreateFMul(I.getOperand(0), + ReciprocalDivisor); + Product->setFastMathFlags(I.getFastMathFlags()); + Product->insertAfter(&I); + I.replaceAllUsesWith(Product); + I.eraseFromParent(); + + hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE); + Changed = true; + continue; + } + // Try hoisting the instruction out to the preheader. We can only do this // if all of the operands of the instruction are loop invariant and if it // is safe to hoist the instruction. @@ -461,7 +490,10 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow; // Iterate over loop instructions and compute safety info. - for (Loop::block_iterator BB = CurLoop->block_begin(), + // Skip header as it has been computed and stored in HeaderMayThrow. + // The first block in loopinfo.Blocks is guaranteed to be the header. + assert(Header == *CurLoop->getBlocks().begin() && "First block must be header"); + for (Loop::block_iterator BB = std::next(CurLoop->block_begin()), BBE = CurLoop->block_end(); (BB != BBE) && !SafetyInfo->MayThrow; ++BB) for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); @@ -477,6 +509,59 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { SafetyInfo->BlockColors = colorEHFunclets(*Fn); } +// Return true if LI is invariant within scope of the loop. LI is invariant if +// CurLoop is dominated by an invariant.start representing the same memory location +// and size as the memory location LI loads from, and also the invariant.start +// has no uses. 
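An illustration of that pattern on hypothetical IR (a 4-byte location guarded by an unused invariant.start that dominates the loop):

//   %i8p = bitcast i32* %p to i8*
//   %inv = call {}* @llvm.invariant.start.p0i8(i64 4, i8* %i8p)
//   br label %loop
// loop:
//   %v = load i32, i32* %p   ; invariant: covered by the invariant.start,
//   ...                      ; whose result is unused and whose block
//                            ; dominates the loop header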
+static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT,
+                                  Loop *CurLoop) {
+  Value *Addr = LI->getOperand(0);
+  const DataLayout &DL = LI->getModule()->getDataLayout();
+  const uint32_t LocSizeInBits = DL.getTypeSizeInBits(
+      cast<PointerType>(Addr->getType())->getElementType());
+
+  // If the type is i8 addrspace(x)*, we know this is the type of the
+  // llvm.invariant.start operand.
+  auto *PtrInt8Ty = PointerType::get(Type::getInt8Ty(LI->getContext()),
+                                     LI->getPointerAddressSpace());
+  unsigned BitcastsVisited = 0;
+  // Look through bitcasts until we reach the i8* type (this is the
+  // invariant.start operand type).
+  while (Addr->getType() != PtrInt8Ty) {
+    auto *BC = dyn_cast<BitCastInst>(Addr);
+    // Avoid traversing a high number of bitcast uses.
+    if (++BitcastsVisited > MaxNumUsesTraversed || !BC)
+      return false;
+    Addr = BC->getOperand(0);
+  }
+
+  unsigned UsesVisited = 0;
+  // Traverse all uses of the load operand value, to see if invariant.start is
+  // one of the uses, and whether it dominates the load instruction.
+  for (auto *U : Addr->users()) {
+    // Avoid traversing a load operand with a high number of users.
+    if (++UsesVisited > MaxNumUsesTraversed)
+      return false;
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+    // If there are escaping uses of the invariant.start instruction, the load
+    // may be non-invariant.
+    if (!II || II->getIntrinsicID() != Intrinsic::invariant_start ||
+        !II->use_empty())
+      continue;
+    unsigned InvariantSizeInBits =
+        cast<ConstantInt>(II->getArgOperand(0))->getSExtValue() * 8;
+    // Confirm the invariant.start location size contains the load operand size
+    // in bits. Also, the invariant.start should dominate the load, and we
+    // should not hoist the load out of a loop that contains this dominating
+    // invariant.start.
+    if (LocSizeInBits <= InvariantSizeInBits &&
+        DT->properlyDominates(II->getParent(), CurLoop->getHeader()))
+      return true;
+  }
+
+  return false;
+}
+
 bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
                               Loop *CurLoop, AliasSetTracker *CurAST,
                               LoopSafetyInfo *SafetyInfo,
@@ -493,6 +578,10 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
     if (LI->getMetadata(LLVMContext::MD_invariant_load))
       return true;
 
+    // This checks for an invariant.start dominating the load.
+    if (isLoadInvariantInLoop(LI, DT, CurLoop))
+      return true;
+
     // Don't hoist loads which have may-aliased stores in loop.
     uint64_t Size = 0;
     if (LI->getType()->isSized())
@@ -782,7 +871,7 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop,
   DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I
                << "\n");
   ORE->emit(OptimizationRemark(DEBUG_TYPE, "Hoisted", &I)
-            << "hosting " << ore::NV("Inst", &I));
+            << "hoisting " << ore::NV("Inst", &I));
 
   // Metadata can be dependent on conditions we are hoisting above.
   // Metadata can be dependent on conditions we are hoisting above.
   // Conservatively strip all metadata on the instruction unless we were
@@ -852,6 +941,7 @@ class LoopPromoter : public LoadAndStorePromoter {
   LoopInfo &LI;
   DebugLoc DL;
   int Alignment;
+  bool UnorderedAtomic;
   AAMDNodes AATags;

   Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const {
@@ -875,10 +965,11 @@ public:
               SmallVectorImpl<BasicBlock *> &LEB,
               SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC,
               AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment,
-              const AAMDNodes &AATags)
+              bool UnorderedAtomic, const AAMDNodes &AATags)
       : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA),
         LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast),
-        LI(li), DL(std::move(dl)), Alignment(alignment), AATags(AATags) {}
+        LI(li), DL(std::move(dl)), Alignment(alignment),
+        UnorderedAtomic(UnorderedAtomic), AATags(AATags) {}

   bool isInstInList(Instruction *I,
                     const SmallVectorImpl<Instruction *> &) const override {
@@ -902,6 +993,8 @@ public:
       Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock);
       Instruction *InsertPos = LoopInsertPts[i];
       StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos);
+      if (UnorderedAtomic)
+        NewSI->setOrdering(AtomicOrdering::Unordered);
       NewSI->setAlignment(Alignment);
       NewSI->setDebugLoc(DL);
       if (AATags)
@@ -992,18 +1085,41 @@ bool llvm::promoteLoopAccessesToScalars(
   // We start with an alignment of one and try to find instructions that allow
   // us to prove better alignment.
   unsigned Alignment = 1;
+  // Keep track of which types of access we see.
+  bool SawUnorderedAtomic = false;
+  bool SawNotAtomic = false;
   AAMDNodes AATags;

   const DataLayout &MDL = Preheader->getModule()->getDataLayout();

+  // Do we know this object does not escape?
+  bool IsKnownNonEscapingObject = false;
   if (SafetyInfo->MayThrow) {
     // If a loop can throw, we have to insert a store along each unwind edge.
     // That said, we can't actually make the unwind edge explicit. Therefore,
     // we have to prove that the store is dead along the unwind edge.
     //
-    // Currently, this code just special-cases alloca instructions.
-    if (!isa<AllocaInst>(GetUnderlyingObject(SomePtr, MDL)))
-      return false;
+    // If the underlying object is not an alloca, nor a pointer that does not
+    // escape, then we cannot effectively prove that the store is dead along
+    // the unwind edge, i.e. the caller of this function could have ways to
+    // access the pointed-to object.
+    Value *Object = GetUnderlyingObject(SomePtr, MDL);
+    // If this is a base pointer we do not understand, simply bail.
+    // We only handle allocas and return values from alloc-like functions
+    // right now.
+    if (!isa<AllocaInst>(Object)) {
+      if (!isAllocLikeFn(Object, TLI))
+        return false;
+      // If this is an alloc-like function, there are more constraints we need
+      // to verify. More specifically, we must make sure that the pointer
+      // cannot escape.
+      //
+      // NOTE: PointerMayBeCaptured is not enough, as the pointer may have
+      // escaped even though it is not captured by the enclosing function.
+      // Standard allocation functions like malloc, calloc, and operator new
+      // return values which can be assumed not to have previously escaped.
+      if (PointerMayBeCaptured(Object, true, true))
+        return false;
+      IsKnownNonEscapingObject = true;
+    }
   }

   // Check that all of the pointers in the alias set have the same type. We
@@ -1029,8 +1145,11 @@ bool llvm::promoteLoopAccessesToScalars(
   // it.
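A sketch of the kind of access the new UnorderedAtomic plumbing can promote. This is a hypothetical IRBuilder snippet (emitUnorderedStore is an illustrative helper); note that C++ relaxed atomics lower to the stronger monotonic ordering, which promotion still rejects, so unordered accesses mostly come from Java-style frontends:

// Create an unordered atomic store; if every access to the location in the
// loop is unordered like this one, promotion can keep the value in a
// register and emit a single unordered store on each exit edge.
llvm::StoreInst *emitUnorderedStore(llvm::IRBuilder<> &B, llvm::Value *V,
                                    llvm::Value *Ptr, unsigned Align) {
  llvm::StoreInst *SI = B.CreateStore(V, Ptr);
  SI->setAtomic(llvm::AtomicOrdering::Unordered);
  SI->setAlignment(Align); // atomics must carry an explicit alignment
  return SI;
}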
if (LoadInst *Load = dyn_cast<LoadInst>(UI)) { assert(!Load->isVolatile() && "AST broken"); - if (!Load->isSimple()) + if (!Load->isUnordered()) return false; + + SawUnorderedAtomic |= Load->isAtomic(); + SawNotAtomic |= !Load->isAtomic(); if (!DereferenceableInPH) DereferenceableInPH = isSafeToExecuteUnconditionally( @@ -1041,9 +1160,12 @@ bool llvm::promoteLoopAccessesToScalars( if (UI->getOperand(1) != ASIV) continue; assert(!Store->isVolatile() && "AST broken"); - if (!Store->isSimple()) + if (!Store->isUnordered()) return false; + SawUnorderedAtomic |= Store->isAtomic(); + SawNotAtomic |= !Store->isAtomic(); + // If the store is guaranteed to execute, both properties are satisfied. // We may want to check if a store is guaranteed to execute even if we // already know that promotion is safe, since it may have higher @@ -1096,6 +1218,12 @@ bool llvm::promoteLoopAccessesToScalars( } } + // If we found both an unordered atomic instruction and a non-atomic memory + // access, bail. We can't blindly promote non-atomic to atomic since we + // might not be able to lower the result. We can't downgrade since that + // would violate memory model. Also, align 0 is an error for atomics. + if (SawUnorderedAtomic && SawNotAtomic) + return false; // If we couldn't prove we can hoist the load, bail. if (!DereferenceableInPH) @@ -1106,10 +1234,15 @@ bool llvm::promoteLoopAccessesToScalars( // stores along paths which originally didn't have them without violating the // memory model. if (!SafeToInsertStore) { - Value *Object = GetUnderlyingObject(SomePtr, MDL); - SafeToInsertStore = - (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) && + // If this is a known non-escaping object, it is safe to insert the stores. + if (IsKnownNonEscapingObject) + SafeToInsertStore = true; + else { + Value *Object = GetUnderlyingObject(SomePtr, MDL); + SafeToInsertStore = + (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) && !PointerMayBeCaptured(Object, true, true); + } } // If we've still failed to prove we can sink the store, give up. @@ -1134,12 +1267,15 @@ bool llvm::promoteLoopAccessesToScalars( SmallVector<PHINode *, 16> NewPHIs; SSAUpdater SSA(&NewPHIs); LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, - InsertPts, PIC, *CurAST, *LI, DL, Alignment, AATags); + InsertPts, PIC, *CurAST, *LI, DL, Alignment, + SawUnorderedAtomic, AATags); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. LoadInst *PreheaderLoad = new LoadInst( SomePtr, SomePtr->getName() + ".promoted", Preheader->getTerminator()); + if (SawUnorderedAtomic) + PreheaderLoad->setOrdering(AtomicOrdering::Unordered); PreheaderLoad->setAlignment(Alignment); PreheaderLoad->setDebugLoc(DL); if (AATags) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoadCombine.cpp b/contrib/llvm/lib/Transforms/Scalar/LoadCombine.cpp deleted file mode 100644 index 389f1c5..0000000 --- a/contrib/llvm/lib/Transforms/Scalar/LoadCombine.cpp +++ /dev/null @@ -1,280 +0,0 @@ -//===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -/// \file -/// This transformation combines adjacent loads. 
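For context on the file deletion that follows: LoadCombine merged adjacent narrow loads into one wide load and re-extracted the pieces. A hypothetical source-level before/after (assuming <cstring> and <cstdint>):

// Before: two adjacent 32-bit loads from the same base.
//   uint32_t Lo = P[0], Hi = P[1];
// After combining: one 64-bit load, halves re-extracted (little-endian view).
uint64_t loadCombined(const uint32_t *P) {
  uint64_t Wide;
  std::memcpy(&Wide, P, sizeof(Wide)); // a single 8-byte load
  return (Wide & 0xffffffffu) + (Wide >> 32);
}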
-/// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Scalar.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AliasSetTracker.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/TargetFolder.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" -#include "llvm/Pass.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Support/raw_ostream.h" - -using namespace llvm; - -#define DEBUG_TYPE "load-combine" - -STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining"); -STATISTIC(NumLoadsCombined, "Number of loads combined"); - -#define LDCOMBINE_NAME "Combine Adjacent Loads" - -namespace { -struct PointerOffsetPair { - Value *Pointer; - APInt Offset; -}; - -struct LoadPOPPair { - LoadInst *Load; - PointerOffsetPair POP; - /// \brief The new load needs to be created before the first load in IR order. - unsigned InsertOrder; -}; - -class LoadCombine : public BasicBlockPass { - LLVMContext *C; - AliasAnalysis *AA; - -public: - LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) { - initializeLoadCombinePass(*PassRegistry::getPassRegistry()); - } - - using llvm::Pass::doInitialization; - bool doInitialization(Function &) override; - bool runOnBasicBlock(BasicBlock &BB) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } - - StringRef getPassName() const override { return LDCOMBINE_NAME; } - static char ID; - - typedef IRBuilder<TargetFolder> BuilderTy; - -private: - BuilderTy *Builder; - - PointerOffsetPair getPointerOffsetPair(LoadInst &); - bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &); - bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &); - bool combineLoads(SmallVectorImpl<LoadPOPPair> &); -}; -} - -bool LoadCombine::doInitialization(Function &F) { - DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n"); - C = &F.getContext(); - return true; -} - -PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) { - auto &DL = LI.getModule()->getDataLayout(); - - PointerOffsetPair POP; - POP.Pointer = LI.getPointerOperand(); - unsigned BitWidth = DL.getPointerSizeInBits(LI.getPointerAddressSpace()); - POP.Offset = APInt(BitWidth, 0); - - while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) { - if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) { - APInt LastOffset = POP.Offset; - if (!GEP->accumulateConstantOffset(DL, POP.Offset)) { - // Can't handle GEPs with variable indices. 
- POP.Offset = LastOffset; - return POP; - } - POP.Pointer = GEP->getPointerOperand(); - } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) { - POP.Pointer = BC->getOperand(0); - } - } - return POP; -} - -bool LoadCombine::combineLoads( - DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) { - bool Combined = false; - for (auto &Loads : LoadMap) { - if (Loads.second.size() < 2) - continue; - std::sort(Loads.second.begin(), Loads.second.end(), - [](const LoadPOPPair &A, const LoadPOPPair &B) { - return A.POP.Offset.slt(B.POP.Offset); - }); - if (aggregateLoads(Loads.second)) - Combined = true; - } - return Combined; -} - -/// \brief Try to aggregate loads from a sorted list of loads to be combined. -/// -/// It is guaranteed that no writes occur between any of the loads. All loads -/// have the same base pointer. There are at least two loads. -bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) { - assert(Loads.size() >= 2 && "Insufficient loads!"); - LoadInst *BaseLoad = nullptr; - SmallVector<LoadPOPPair, 8> AggregateLoads; - bool Combined = false; - bool ValidPrevOffset = false; - APInt PrevOffset; - uint64_t PrevSize = 0; - for (auto &L : Loads) { - if (ValidPrevOffset == false) { - BaseLoad = L.Load; - PrevOffset = L.POP.Offset; - PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize( - L.Load->getType()); - AggregateLoads.push_back(L); - ValidPrevOffset = true; - continue; - } - if (L.Load->getAlignment() > BaseLoad->getAlignment()) - continue; - APInt PrevEnd = PrevOffset + PrevSize; - if (L.POP.Offset.sgt(PrevEnd)) { - // No other load will be combinable - if (combineLoads(AggregateLoads)) - Combined = true; - AggregateLoads.clear(); - ValidPrevOffset = false; - continue; - } - if (L.POP.Offset != PrevEnd) - // This load is offset less than the size of the last load. - // FIXME: We may want to handle this case. - continue; - PrevOffset = L.POP.Offset; - PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize( - L.Load->getType()); - AggregateLoads.push_back(L); - } - if (combineLoads(AggregateLoads)) - Combined = true; - return Combined; -} - -/// \brief Given a list of combinable load. Combine the maximum number of them. -bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) { - // Remove loads from the end while the size is not a power of 2. - unsigned TotalSize = 0; - for (const auto &L : Loads) - TotalSize += L.Load->getType()->getPrimitiveSizeInBits(); - while (TotalSize != 0 && !isPowerOf2_32(TotalSize)) - TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits(); - if (Loads.size() < 2) - return false; - - DEBUG({ - dbgs() << "***** Combining Loads ******\n"; - for (const auto &L : Loads) { - dbgs() << L.POP.Offset << ": " << *L.Load << "\n"; - } - }); - - // Find first load. This is where we put the new load. 
- LoadPOPPair FirstLP; - FirstLP.InsertOrder = -1u; - for (const auto &L : Loads) - if (L.InsertOrder < FirstLP.InsertOrder) - FirstLP = L; - - unsigned AddressSpace = - FirstLP.POP.Pointer->getType()->getPointerAddressSpace(); - - Builder->SetInsertPoint(FirstLP.Load); - Value *Ptr = Builder->CreateConstGEP1_64( - Builder->CreatePointerCast(Loads[0].POP.Pointer, - Builder->getInt8PtrTy(AddressSpace)), - Loads[0].POP.Offset.getSExtValue()); - LoadInst *NewLoad = new LoadInst( - Builder->CreatePointerCast( - Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize), - Ptr->getType()->getPointerAddressSpace())), - Twine(Loads[0].Load->getName()) + ".combined", false, - Loads[0].Load->getAlignment(), FirstLP.Load); - - for (const auto &L : Loads) { - Builder->SetInsertPoint(L.Load); - Value *V = Builder->CreateExtractInteger( - L.Load->getModule()->getDataLayout(), NewLoad, - cast<IntegerType>(L.Load->getType()), - (L.POP.Offset - Loads[0].POP.Offset).getZExtValue(), "combine.extract"); - L.Load->replaceAllUsesWith(V); - } - - NumLoadsCombined = NumLoadsCombined + Loads.size(); - return true; -} - -bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { - if (skipBasicBlock(BB)) - return false; - - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - - IRBuilder<TargetFolder> TheBuilder( - BB.getContext(), TargetFolder(BB.getModule()->getDataLayout())); - Builder = &TheBuilder; - - DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap; - AliasSetTracker AST(*AA); - - bool Combined = false; - unsigned Index = 0; - for (auto &I : BB) { - if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) { - if (combineLoads(LoadMap)) - Combined = true; - LoadMap.clear(); - AST.clear(); - continue; - } - LoadInst *LI = dyn_cast<LoadInst>(&I); - if (!LI) - continue; - ++NumLoadsAnalyzed; - if (!LI->isSimple() || !LI->getType()->isIntegerTy()) - continue; - auto POP = getPointerOffsetPair(*LI); - if (!POP.Pointer) - continue; - LoadMap[POP.Pointer].push_back({LI, std::move(POP), Index++}); - AST.add(LI); - } - if (combineLoads(LoadMap)) - Combined = true; - return Combined; -} - -char LoadCombine::ID = 0; - -BasicBlockPass *llvm::createLoadCombinePass() { - return new LoadCombine(); -} - -INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index cca75a3..ac4dd44 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -29,32 +30,45 @@ using namespace llvm; STATISTIC(NumDeleted, "Number of loops deleted"); -/// isLoopDead - Determined if a loop is dead. This assumes that we've already -/// checked for unique exit and exiting blocks, and that the code is in LCSSA -/// form. 
-bool LoopDeletionPass::isLoopDead(Loop *L, ScalarEvolution &SE, - SmallVectorImpl<BasicBlock *> &exitingBlocks, - SmallVectorImpl<BasicBlock *> &exitBlocks, - bool &Changed, BasicBlock *Preheader) { - BasicBlock *exitBlock = exitBlocks[0]; - +/// This function deletes dead loops. The caller of this function needs to +/// guarantee that the loop is infact dead. Here we handle two kinds of dead +/// loop. The first kind (\p isLoopDead) is where only invariant values from +/// within the loop are used outside of it. The second kind (\p +/// isLoopNeverExecuted) is where the loop is provably never executed. We can +/// always remove never executed loops since they will not cause any difference +/// to program behaviour. +/// +/// This also updates the relevant analysis information in \p DT, \p SE, and \p +/// LI. It also updates the loop PM if an updater struct is provided. +// TODO: This function will be used by loop-simplifyCFG as well. So, move this +// to LoopUtils.cpp +static void deleteDeadLoop(Loop *L, DominatorTree &DT, ScalarEvolution &SE, + LoopInfo &LI, LPMUpdater *Updater = nullptr); +/// Determines if a loop is dead. +/// +/// This assumes that we've already checked for unique exit and exiting blocks, +/// and that the code is in LCSSA form. +static bool isLoopDead(Loop *L, ScalarEvolution &SE, + SmallVectorImpl<BasicBlock *> &ExitingBlocks, + BasicBlock *ExitBlock, bool &Changed, + BasicBlock *Preheader) { // Make sure that all PHI entries coming from the loop are loop invariant. // Because the code is in LCSSA form, any values used outside of the loop // must pass through a PHI in the exit block, meaning that this check is // sufficient to guarantee that no loop-variant values are used outside // of the loop. - BasicBlock::iterator BI = exitBlock->begin(); + BasicBlock::iterator BI = ExitBlock->begin(); bool AllEntriesInvariant = true; bool AllOutgoingValuesSame = true; while (PHINode *P = dyn_cast<PHINode>(BI)) { - Value *incoming = P->getIncomingValueForBlock(exitingBlocks[0]); + Value *incoming = P->getIncomingValueForBlock(ExitingBlocks[0]); // Make sure all exiting blocks produce the same incoming value for the exit // block. If there are different incoming values for different exiting // blocks, then it is impossible to statically determine which value should // be used. AllOutgoingValuesSame = - all_of(makeArrayRef(exitingBlocks).slice(1), [&](BasicBlock *BB) { + all_of(makeArrayRef(ExitingBlocks).slice(1), [&](BasicBlock *BB) { return incoming == P->getIncomingValueForBlock(BB); }); @@ -78,95 +92,187 @@ bool LoopDeletionPass::isLoopDead(Loop *L, ScalarEvolution &SE, // Make sure that no instructions in the block have potential side-effects. // This includes instructions that could write to memory, and loads that are - // marked volatile. This could be made more aggressive by using aliasing - // information to identify readonly and readnone calls. - for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end(); - LI != LE; ++LI) { - for (Instruction &I : **LI) { - if (I.mayHaveSideEffects()) - return false; - } - } - + // marked volatile. + for (auto &I : L->blocks()) + if (any_of(*I, [](Instruction &I) { return I.mayHaveSideEffects(); })) + return false; return true; } -/// Remove dead loops, by which we mean loops that do not impact the observable -/// behavior of the program other than finite running time. Note we do ensure -/// that this never remove a loop that might be infinite, as doing so could -/// change the halting/non-halting nature of a program. 
NOTE: This entire -/// process relies pretty heavily on LoopSimplify and LCSSA in order to make -/// various safety checks work. -bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE, - LoopInfo &loopInfo) { - assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); +/// This function returns true if there is no viable path from the +/// entry block to the header of \p L. Right now, it only does +/// a local search to save compile time. +static bool isLoopNeverExecuted(Loop *L) { + using namespace PatternMatch; - // We can only remove the loop if there is a preheader that we can - // branch from after removing it. - BasicBlock *preheader = L->getLoopPreheader(); - if (!preheader) - return false; + auto *Preheader = L->getLoopPreheader(); + // TODO: We can relax this constraint, since we just need a loop + // predecessor. + assert(Preheader && "Needs preheader!"); - // If LoopSimplify form is not available, stay out of trouble. - if (!L->hasDedicatedExits()) + if (Preheader == &Preheader->getParent()->getEntryBlock()) return false; + // All predecessors of the preheader should have a constant conditional + // branch, with the loop's preheader as not-taken. + for (auto *Pred: predecessors(Preheader)) { + BasicBlock *Taken, *NotTaken; + ConstantInt *Cond; + if (!match(Pred->getTerminator(), + m_Br(m_ConstantInt(Cond), Taken, NotTaken))) + return false; + if (!Cond->getZExtValue()) + std::swap(Taken, NotTaken); + if (Taken == Preheader) + return false; + } + assert(!pred_empty(Preheader) && + "Preheader should have predecessors at this point!"); + // All the predecessors have the loop preheader as not-taken target. + return true; +} +/// Remove a loop if it is dead. +/// +/// A loop is considered dead if it does not impact the observable behavior of +/// the program other than finite running time. This never removes a loop that +/// might be infinite (unless it is never executed), as doing so could change +/// the halting/non-halting nature of a program. +/// +/// This entire process relies pretty heavily on LoopSimplify form and LCSSA in +/// order to make various safety checks work. +/// +/// \returns true if any changes were made. This may mutate the loop even if it +/// is unable to delete it due to hoisting trivially loop invariant +/// instructions out of the loop. +static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE, + LoopInfo &LI, LPMUpdater *Updater = nullptr) { + assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); + + // We can only remove the loop if there is a preheader that we can branch from + // after removing it. Also, if LoopSimplify form is not available, stay out + // of trouble. + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader || !L->hasDedicatedExits()) { + DEBUG(dbgs() + << "Deletion requires Loop with preheader and dedicated exits.\n"); + return false; + } // We can't remove loops that contain subloops. If the subloops were dead, // they would already have been removed in earlier executions of this pass. 
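isLoopNeverExecuted only does a local scan: every predecessor of the preheader must branch on a constant with the preheader on the not-taken side. A hypothetical source shape that reaches this state after earlier constant folding:

int sumIfEnabled(const int *A, int N) {
  const bool Enabled = false;      // folds to a constant-false branch
  int S = 0;
  if (Enabled)
    for (int I = 0; I < N; ++I)    // provably never executed: deleted,
      S += A[I];                   // with exit-block phis rewritten to undef
  return S;
}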
- if (L->begin() != L->end()) + if (L->begin() != L->end()) { + DEBUG(dbgs() << "Loop contains subloops.\n"); return false; + } - SmallVector<BasicBlock *, 4> exitingBlocks; - L->getExitingBlocks(exitingBlocks); - SmallVector<BasicBlock *, 4> exitBlocks; - L->getUniqueExitBlocks(exitBlocks); + BasicBlock *ExitBlock = L->getUniqueExitBlock(); + + if (ExitBlock && isLoopNeverExecuted(L)) { + DEBUG(dbgs() << "Loop is proven to never execute, delete it!"); + // Set incoming value to undef for phi nodes in the exit block. + BasicBlock::iterator BI = ExitBlock->begin(); + while (PHINode *P = dyn_cast<PHINode>(BI)) { + for (unsigned i = 0; i < P->getNumIncomingValues(); i++) + P->setIncomingValue(i, UndefValue::get(P->getType())); + BI++; + } + deleteDeadLoop(L, DT, SE, LI, Updater); + ++NumDeleted; + return true; + } + + // The remaining checks below are for a loop being dead because all statements + // in the loop are invariant. + SmallVector<BasicBlock *, 4> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); // We require that the loop only have a single exit block. Otherwise, we'd // be in the situation of needing to be able to solve statically which exit // block will be branched to, or trying to preserve the branching logic in // a loop invariant manner. - if (exitBlocks.size() != 1) + if (!ExitBlock) { + DEBUG(dbgs() << "Deletion requires single exit block\n"); return false; - + } // Finally, we have to check that the loop really is dead. bool Changed = false; - if (!isLoopDead(L, SE, exitingBlocks, exitBlocks, Changed, preheader)) + if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader)) { + DEBUG(dbgs() << "Loop is not invariant, cannot delete.\n"); return Changed; + } // Don't remove loops for which we can't solve the trip count. // They could be infinite, in which case we'd be changing program behavior. const SCEV *S = SE.getMaxBackedgeTakenCount(L); - if (isa<SCEVCouldNotCompute>(S)) + if (isa<SCEVCouldNotCompute>(S)) { + DEBUG(dbgs() << "Could not compute SCEV MaxBackedgeTakenCount.\n"); return Changed; + } + + DEBUG(dbgs() << "Loop is invariant, delete it!"); + deleteDeadLoop(L, DT, SE, LI, Updater); + ++NumDeleted; + + return true; +} + +static void deleteDeadLoop(Loop *L, DominatorTree &DT, ScalarEvolution &SE, + LoopInfo &LI, LPMUpdater *Updater) { + assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); + auto *Preheader = L->getLoopPreheader(); + assert(Preheader && "Preheader should exist!"); // Now that we know the removal is safe, remove the loop by changing the // branch from the preheader to go to the single exit block. - BasicBlock *exitBlock = exitBlocks[0]; - + // // Because we're deleting a large chunk of code at once, the sequence in which - // we remove things is very important to avoid invalidation issues. Don't - // mess with this unless you have good reason and know what you're doing. + // we remove things is very important to avoid invalidation issues. + + // If we have an LPM updater, tell it about the loop being removed. + if (Updater) + Updater->markLoopAsDeleted(*L); // Tell ScalarEvolution that the loop is deleted. Do this before // deleting the loop so that ScalarEvolution can look at the loop // to determine what it needs to clean up. SE.forgetLoop(L); - // Connect the preheader directly to the exit block. 
- TerminatorInst *TI = preheader->getTerminator(); - TI->replaceUsesOfWith(L->getHeader(), exitBlock); + auto *ExitBlock = L->getUniqueExitBlock(); + assert(ExitBlock && "Should have a unique exit block!"); - // Rewrite phis in the exit block to get their inputs from - // the preheader instead of the exiting block. - BasicBlock *exitingBlock = exitingBlocks[0]; - BasicBlock::iterator BI = exitBlock->begin(); + assert(L->hasDedicatedExits() && "Loop should have dedicated exits!"); + + // Connect the preheader directly to the exit block. + // Even when the loop is never executed, we cannot remove the edge from the + // source block to the exit block. Consider the case where the unexecuted loop + // branches back to an outer loop. If we deleted the loop and removed the edge + // coming to this inner loop, this will break the outer loop structure (by + // deleting the backedge of the outer loop). If the outer loop is indeed a + // non-loop, it will be deleted in a future iteration of loop deletion pass. + Preheader->getTerminator()->replaceUsesOfWith(L->getHeader(), ExitBlock); + + // Rewrite phis in the exit block to get their inputs from the Preheader + // instead of the exiting block. + BasicBlock::iterator BI = ExitBlock->begin(); while (PHINode *P = dyn_cast<PHINode>(BI)) { - int j = P->getBasicBlockIndex(exitingBlock); - assert(j >= 0 && "Can't find exiting block in exit block's phi node!"); - P->setIncomingBlock(j, preheader); - for (unsigned i = 1; i < exitingBlocks.size(); ++i) - P->removeIncomingValue(exitingBlocks[i]); + // Set the zero'th element of Phi to be from the preheader and remove all + // other incoming values. Given the loop has dedicated exits, all other + // incoming values must be from the exiting blocks. + int PredIndex = 0; + P->setIncomingBlock(PredIndex, Preheader); + // Removes all incoming values from all other exiting blocks (including + // duplicate values from an exiting block). + // Nuke all entries except the zero'th entry which is the preheader entry. + // NOTE! We need to remove Incoming Values in the reverse order as done + // below, to keep the indices valid for deletion (removeIncomingValues + // updates getNumIncomingValues and shifts all values down into the operand + // being deleted). + for (unsigned i = 0, e = P->getNumIncomingValues() - 1; i != e; ++i) + P->removeIncomingValue(e-i, false); + + assert((P->getNumIncomingValues() == 1 && + P->getIncomingBlock(PredIndex) == Preheader) && + "Should have exactly one value and that's from the preheader!"); ++BI; } @@ -175,11 +281,11 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE, SmallVector<DomTreeNode*, 8> ChildNodes; for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end(); LI != LE; ++LI) { - // Move all of the block's children to be children of the preheader, which + // Move all of the block's children to be children of the Preheader, which // allows us to remove the domtree entry for the block. 
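The NOTE in the phi-rewriting loop above is a general index-stability trick: remove entries in descending index order so earlier removals never shift an index still to be visited. A standalone hypothetical illustration (assuming <vector>):

// Keep only element 0 by erasing indices e, e-1, ..., 1; ascending-order
// erasure would shift later elements down and skip some of them.
void keepFirstOnly(std::vector<int> &V) {
  if (V.empty())
    return;
  for (size_t i = 0, e = V.size() - 1; i != e; ++i)
    V.erase(V.begin() + (e - i));
}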
ChildNodes.insert(ChildNodes.begin(), DT[*LI]->begin(), DT[*LI]->end()); for (DomTreeNode *ChildNode : ChildNodes) { - DT.changeImmediateDominator(ChildNode, DT[preheader]); + DT.changeImmediateDominator(ChildNode, DT[Preheader]); } ChildNodes.clear(); @@ -204,22 +310,19 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE, SmallPtrSet<BasicBlock *, 8> blocks; blocks.insert(L->block_begin(), L->block_end()); for (BasicBlock *BB : blocks) - loopInfo.removeBlock(BB); + LI.removeBlock(BB); // The last step is to update LoopInfo now that we've eliminated this loop. - loopInfo.markAsRemoved(L); - Changed = true; - - ++NumDeleted; - - return Changed; + LI.markAsRemoved(L); } PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, - LPMUpdater &) { - bool Changed = runImpl(&L, AR.DT, AR.SE, AR.LI); - if (!Changed) + LPMUpdater &Updater) { + + DEBUG(dbgs() << "Analyzing Loop for deletion: "); + DEBUG(L.dump()); + if (!deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI, &Updater)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); @@ -254,11 +357,11 @@ Pass *llvm::createLoopDeletionPass() { return new LoopDeletionLegacyPass(); } bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &) { if (skipLoop(L)) return false; - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopInfo &loopInfo = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LoopDeletionPass Impl; - return Impl.runImpl(L, DT, SE, loopInfo); + DEBUG(dbgs() << "Analyzing Loop for deletion: "); + DEBUG(L->dump()); + return deleteLoopIfDead(L, DT, SE, LI); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 19716b2..3624bba 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -812,29 +812,29 @@ private: const RuntimePointerChecking *RtPtrChecking) { SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; - std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { - for (unsigned PtrIdx1 : Check.first->Members) - for (unsigned PtrIdx2 : Check.second->Members) - // Only include this check if there is a pair of pointers - // that require checking and the pointers fall into - // separate partitions. - // - // (Note that we already know at this point that the two - // pointer groups need checking but it doesn't follow - // that each pair of pointers within the two groups need - // checking as well. - // - // In other words we don't want to include a check just - // because there is a pair of pointers between the two - // pointer groups that require checks and a different - // pair whose pointers fall into different partitions.) - if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && - !RuntimePointerChecking::arePointersInSamePartition( - PtrToPartition, PtrIdx1, PtrIdx2)) - return true; - return false; - }); + copy_if(AllChecks, std::back_inserter(Checks), + [&](const RuntimePointerChecking::PointerCheck &Check) { + for (unsigned PtrIdx1 : Check.first->Members) + for (unsigned PtrIdx2 : Check.second->Members) + // Only include this check if there is a pair of pointers + // that require checking and the pointers fall into + // separate partitions. 
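The mechanical change in the LoopDistribute hunk below is from std::copy_if over explicit begin/end iterators to llvm::copy_if from ADT/STLExtras.h, which takes the whole range. A hypothetical usage snippet (assuming the usual LLVM includes and namespace):

SmallVector<int, 8> Src = {1, 2, 3, 4};
SmallVector<int, 8> Evens;
// Range-based form; equivalent to std::copy_if(Src.begin(), Src.end(), ...).
copy_if(Src, std::back_inserter(Evens),
        [](int V) { return V % 2 == 0; }); // Evens == {2, 4}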
+ // + // (Note that we already know at this point that the two + // pointer groups need checking but it doesn't follow + // that each pair of pointers within the two groups need + // checking as well. + // + // In other words we don't want to include a check just + // because there is a pair of pointers between the two + // pointer groups that require checks and a different + // pair whose pointers fall into different partitions.) + if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && + !RuntimePointerChecking::arePointersInSamePartition( + PtrToPartition, PtrIdx1, PtrIdx2)) + return true; + return false; + }); return Checks; } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 5fec51c..4a6a35c 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -110,6 +110,16 @@ private: bool HasMemset; bool HasMemsetPattern; bool HasMemcpy; + /// Return code for isLegalStore() + enum LegalStoreKind { + None = 0, + Memset, + MemsetPattern, + Memcpy, + UnorderedAtomicMemcpy, + DontUse // Dummy retval never to be used. Allows catching errors in retval + // handling. + }; /// \name Countable Loop Idiom Handling /// @{ @@ -119,8 +129,7 @@ private: SmallVectorImpl<BasicBlock *> &ExitBlocks); void collectStores(BasicBlock *BB); - bool isLegalStore(StoreInst *SI, bool &ForMemset, bool &ForMemsetPattern, - bool &ForMemcpy); + LegalStoreKind isLegalStore(StoreInst *SI); bool processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEV *BECount, bool ForMemset); bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount); @@ -144,6 +153,10 @@ private: bool recognizePopcount(); void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst, PHINode *CntPhi, Value *Var); + bool recognizeAndInsertCTLZ(); + void transformLoopToCountable(BasicBlock *PreCondBB, Instruction *CntInst, + PHINode *CntPhi, Value *Var, const DebugLoc DL, + bool ZeroCheck, bool IsCntPhiUsedOutsideLoop); /// @} }; @@ -236,9 +249,9 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) { ApplyCodeSizeHeuristics = L->getHeader()->getParent()->optForSize() && UseLIRCodeSizeHeurs; - HasMemset = TLI->has(LibFunc::memset); - HasMemsetPattern = TLI->has(LibFunc::memset_pattern16); - HasMemcpy = TLI->has(LibFunc::memcpy); + HasMemset = TLI->has(LibFunc_memset); + HasMemsetPattern = TLI->has(LibFunc_memset_pattern16); + HasMemcpy = TLI->has(LibFunc_memcpy); if (HasMemset || HasMemsetPattern || HasMemcpy) if (SE->hasLoopInvariantBackedgeTakenCount(L)) @@ -339,15 +352,24 @@ static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) { return ConstantArray::get(AT, std::vector<Constant *>(ArraySize, C)); } -bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, - bool &ForMemsetPattern, bool &ForMemcpy) { +LoopIdiomRecognize::LegalStoreKind +LoopIdiomRecognize::isLegalStore(StoreInst *SI) { + // Don't touch volatile stores. - if (!SI->isSimple()) - return false; + if (SI->isVolatile()) + return LegalStoreKind::None; + // We only want simple or unordered-atomic stores. + if (!SI->isUnordered()) + return LegalStoreKind::None; + + // Don't convert stores of non-integral pointer types to memsets (which stores + // integers). + if (DL->isNonIntegralPointerType(SI->getValueOperand()->getType())) + return LegalStoreKind::None; // Avoid merging nontemporal stores. 
if (SI->getMetadata(LLVMContext::MD_nontemporal)) - return false; + return LegalStoreKind::None; Value *StoredVal = SI->getValueOperand(); Value *StorePtr = SI->getPointerOperand(); @@ -355,7 +377,7 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, // Reject stores that are so large that they overflow an unsigned. uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType()); if ((SizeInBits & 7) || (SizeInBits >> 32) != 0) - return false; + return LegalStoreKind::None; // See if the pointer expression is an AddRec like {base,+,1} on the current // loop, which indicates a strided store. If we have something else, it's a @@ -363,11 +385,11 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, const SCEVAddRecExpr *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine()) - return false; + return LegalStoreKind::None; // Check to see if we have a constant stride. if (!isa<SCEVConstant>(StoreEv->getOperand(1))) - return false; + return LegalStoreKind::None; // See if the store can be turned into a memset. @@ -378,22 +400,23 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, Value *SplatValue = isBytewiseValue(StoredVal); Constant *PatternValue = nullptr; + // Note: memset and memset_pattern on unordered-atomic is yet not supported + bool UnorderedAtomic = SI->isUnordered() && !SI->isSimple(); + // If we're allowed to form a memset, and the stored value would be // acceptable for memset, use it. - if (HasMemset && SplatValue && + if (!UnorderedAtomic && HasMemset && SplatValue && // Verify that the stored value is loop invariant. If not, we can't // promote the memset. CurLoop->isLoopInvariant(SplatValue)) { // It looks like we can use SplatValue. - ForMemset = true; - return true; - } else if (HasMemsetPattern && + return LegalStoreKind::Memset; + } else if (!UnorderedAtomic && HasMemsetPattern && // Don't create memset_pattern16s with address spaces. StorePtr->getType()->getPointerAddressSpace() == 0 && (PatternValue = getMemSetPatternValue(StoredVal, DL))) { // It looks like we can use PatternValue! - ForMemsetPattern = true; - return true; + return LegalStoreKind::MemsetPattern; } // Otherwise, see if the store can be turned into a memcpy. @@ -403,12 +426,17 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, APInt Stride = getStoreStride(StoreEv); unsigned StoreSize = getStoreSizeInBytes(SI, DL); if (StoreSize != Stride && StoreSize != -Stride) - return false; + return LegalStoreKind::None; // The store must be feeding a non-volatile load. LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand()); - if (!LI || !LI->isSimple()) - return false; + + // Only allow non-volatile loads + if (!LI || LI->isVolatile()) + return LegalStoreKind::None; + // Only allow simple or unordered-atomic loads + if (!LI->isUnordered()) + return LegalStoreKind::None; // See if the pointer expression is an AddRec like {base,+,1} on the current // loop, which indicates a strided load. If we have something else, it's a @@ -416,18 +444,19 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand())); if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine()) - return false; + return LegalStoreKind::None; // The store and load must share the same stride. 
if (StoreEv->getOperand(1) != LoadEv->getOperand(1)) - return false; + return LegalStoreKind::None; // Success. This store can be converted into a memcpy. - ForMemcpy = true; - return true; + UnorderedAtomic = UnorderedAtomic || LI->isAtomic(); + return UnorderedAtomic ? LegalStoreKind::UnorderedAtomicMemcpy + : LegalStoreKind::Memcpy; } // This store can't be transformed into a memset/memcpy. - return false; + return LegalStoreKind::None; } void LoopIdiomRecognize::collectStores(BasicBlock *BB) { @@ -439,24 +468,29 @@ void LoopIdiomRecognize::collectStores(BasicBlock *BB) { if (!SI) continue; - bool ForMemset = false; - bool ForMemsetPattern = false; - bool ForMemcpy = false; // Make sure this is a strided store with a constant stride. - if (!isLegalStore(SI, ForMemset, ForMemsetPattern, ForMemcpy)) - continue; - - // Save the store locations. - if (ForMemset) { + switch (isLegalStore(SI)) { + case LegalStoreKind::None: + // Nothing to do + break; + case LegalStoreKind::Memset: { // Find the base pointer. Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL); StoreRefsForMemset[Ptr].push_back(SI); - } else if (ForMemsetPattern) { + } break; + case LegalStoreKind::MemsetPattern: { // Find the base pointer. Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL); StoreRefsForMemsetPattern[Ptr].push_back(SI); - } else if (ForMemcpy) + } break; + case LegalStoreKind::Memcpy: + case LegalStoreKind::UnorderedAtomicMemcpy: StoreRefsForMemcpy.push_back(SI); + break; + default: + assert(false && "unhandled return value"); + break; + } } } @@ -494,7 +528,7 @@ bool LoopIdiomRecognize::runOnLoopBlock( Instruction *Inst = &*I++; // Look for memset instructions, which may be optimized to a larger memset. if (MemSetInst *MSI = dyn_cast<MemSetInst>(Inst)) { - WeakVH InstPtr(&*I); + WeakTrackingVH InstPtr(&*I); if (!processLoopMemSet(MSI, BECount)) continue; MadeChange = true; @@ -778,6 +812,11 @@ bool LoopIdiomRecognize::processLoopStridedStore( if (NegStride) Start = getStartForNegStride(Start, BECount, IntPtr, StoreSize, SE); + // TODO: ideally we should still be able to generate memset if SCEV expander + // is taught to generate the dependencies at the latest point. + if (!isSafeToExpand(Start, *SE)) + return false; + // Okay, we have a strided store "p[i]" of a splattable value. We can turn // this into a memset in the loop preheader now if we want. However, this // would be unsafe to do if there is anything else in the loop that may read @@ -809,6 +848,11 @@ bool LoopIdiomRecognize::processLoopStridedStore( SCEV::FlagNUW); } + // TODO: ideally we should still be able to generate memset if SCEV expander + // is taught to generate the dependencies at the latest point. + if (!isSafeToExpand(NumBytesS, *SE)) + return false; + Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtr, Preheader->getTerminator()); @@ -823,7 +867,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( Module *M = TheStore->getModule(); Value *MSP = M->getOrInsertFunction("memset_pattern16", Builder.getVoidTy(), - Int8PtrTy, Int8PtrTy, IntPtr, (void *)nullptr); + Int8PtrTy, Int8PtrTy, IntPtr); inferLibFuncAttributes(*M->getFunction("memset_pattern16"), *TLI); // Otherwise we should form a memset_pattern16. PatternValue is known to be @@ -851,10 +895,10 @@ bool LoopIdiomRecognize::processLoopStridedStore( /// If the stored value is a strided load in the same loop with the same stride /// this may be transformable into a memcpy. 
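The new path above emits llvm.memcpy.element.unordered.atomic when the load/store pair is unordered-atomic. LLVM's unordered ordering has no direct C++ spelling (relaxed atomics lower to the stronger monotonic ordering), so the following is only a rough, hypothetical analogue of the loop shape being targeted:

#include <atomic>

// Element-wise atomic copy; with unordered semantics on the accesses this
// becomes the element-atomic memcpy, provided alignment is at least the
// element size and the target lib call exists for that size.
void copyElems(std::atomic<int> *D, const std::atomic<int> *S, int N) {
  for (int I = 0; I < N; ++I)
    D[I].store(S[I].load(std::memory_order_relaxed),
               std::memory_order_relaxed);
}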
This kicks in for stuff like -/// for (i) A[i] = B[i]; +/// for (i) A[i] = B[i]; bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount) { - assert(SI->isSimple() && "Expected only non-volatile stores."); + assert(SI->isUnordered() && "Expected only non-volatile non-ordered stores."); Value *StorePtr = SI->getPointerOperand(); const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); @@ -864,7 +908,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // The store must be feeding a non-volatile load. LoadInst *LI = cast<LoadInst>(SI->getValueOperand()); - assert(LI->isSimple() && "Expected only non-volatile stores."); + assert(LI->isUnordered() && "Expected only non-volatile non-ordered loads."); // See if the pointer expression is an AddRec like {base,+,1} on the current // loop, which indicates a strided load. If we have something else, it's a @@ -938,6 +982,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *NumBytesS = SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW); + if (StoreSize != 1) NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize), SCEV::FlagNUW); @@ -945,9 +990,37 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator()); - CallInst *NewCall = - Builder.CreateMemCpy(StoreBasePtr, LoadBasePtr, NumBytes, - std::min(SI->getAlignment(), LI->getAlignment())); + unsigned Align = std::min(SI->getAlignment(), LI->getAlignment()); + CallInst *NewCall = nullptr; + // Check whether to generate an unordered atomic memcpy: + // If the load or store are atomic, then they must neccessarily be unordered + // by previous checks. + if (!SI->isAtomic() && !LI->isAtomic()) + NewCall = Builder.CreateMemCpy(StoreBasePtr, LoadBasePtr, NumBytes, Align); + else { + // We cannot allow unaligned ops for unordered load/store, so reject + // anything where the alignment isn't at least the element size. + if (Align < StoreSize) + return false; + + // If the element.atomic memcpy is not lowered into explicit + // loads/stores later, then it will be lowered into an element-size + // specific lib call. If the lib call doesn't exist for our store size, then + // we shouldn't generate the memcpy. + if (StoreSize > TTI->getAtomicMemIntrinsicMaxElementSize()) + return false; + + NewCall = Builder.CreateElementUnorderedAtomicMemCpy( + StoreBasePtr, LoadBasePtr, NumBytes, StoreSize); + + // Propagate alignment info onto the pointer args. Note that unordered + // atomic loads/stores are *required* by the spec to have an alignment + // but non-atomic loads/stores may not. + NewCall->addParamAttr(0, Attribute::getWithAlignment(NewCall->getContext(), + SI->getAlignment())); + NewCall->addParamAttr(1, Attribute::getWithAlignment(NewCall->getContext(), + LI->getAlignment())); + } NewCall->setDebugLoc(SI->getDebugLoc()); DEBUG(dbgs() << " Formed memcpy: " << *NewCall << "\n" @@ -979,7 +1052,7 @@ bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset, } bool LoopIdiomRecognize::runOnNoncountableLoop() { - return recognizePopcount(); + return recognizePopcount() || recognizeAndInsertCTLZ(); } /// Check if the given conditional branch is based on the comparison between @@ -1007,6 +1080,17 @@ static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry) { return nullptr; } +// Check if the recurrence variable `VarX` is in the right form to create +// the idiom. 
Returns the value coerced to a PHINode if so. +static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX, + BasicBlock *LoopEntry) { + auto *PhiX = dyn_cast<PHINode>(VarX); + if (PhiX && PhiX->getParent() == LoopEntry && + (PhiX->getOperand(0) == DefX || PhiX->getOperand(1) == DefX)) + return PhiX; + return nullptr; +} + /// Return true iff the idiom is detected in the loop. /// /// Additionally: @@ -1076,19 +1160,15 @@ static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB, if (!Dec || !((SubInst->getOpcode() == Instruction::Sub && Dec->isOne()) || (SubInst->getOpcode() == Instruction::Add && - Dec->isAllOnesValue()))) { + Dec->isMinusOne()))) { return false; } } // step 3: Check the recurrence of variable X - { - PhiX = dyn_cast<PHINode>(VarX1); - if (!PhiX || - (PhiX->getOperand(0) != DefX2 && PhiX->getOperand(1) != DefX2)) { - return false; - } - } + PhiX = getRecurrenceVar(VarX1, DefX2, LoopEntry); + if (!PhiX) + return false; // step 4: Find the instruction which count the population: cnt2 = cnt1 + 1 { @@ -1104,8 +1184,8 @@ static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB, if (!Inc || !Inc->isOne()) continue; - PHINode *Phi = dyn_cast<PHINode>(Inst->getOperand(0)); - if (!Phi || Phi->getParent() != LoopEntry) + PHINode *Phi = getRecurrenceVar(Inst->getOperand(0), Inst, LoopEntry); + if (!Phi) continue; // Check if the result of the instruction is live of the loop. @@ -1144,6 +1224,169 @@ static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB, return true; } +/// Return true if the idiom is detected in the loop. +/// +/// Additionally: +/// 1) \p CntInst is set to the instruction Counting Leading Zeros (CTLZ) +/// or nullptr if there is no such. +/// 2) \p CntPhi is set to the corresponding phi node +/// or nullptr if there is no such. +/// 3) \p Var is set to the value whose CTLZ could be used. +/// 4) \p DefX is set to the instruction calculating Loop exit condition. +/// +/// The core idiom we are trying to detect is: +/// \code +/// if (x0 == 0) +/// goto loop-exit // the precondition of the loop +/// cnt0 = init-val; +/// do { +/// x = phi (x0, x.next); //PhiX +/// cnt = phi(cnt0, cnt.next); +/// +/// cnt.next = cnt + 1; +/// ... +/// x.next = x >> 1; // DefX +/// ... +/// } while(x.next != 0); +/// +/// loop-exit: +/// \endcode +static bool detectCTLZIdiom(Loop *CurLoop, PHINode *&PhiX, + Instruction *&CntInst, PHINode *&CntPhi, + Instruction *&DefX) { + BasicBlock *LoopEntry; + Value *VarX = nullptr; + + DefX = nullptr; + PhiX = nullptr; + CntInst = nullptr; + CntPhi = nullptr; + LoopEntry = *(CurLoop->block_begin()); + + // step 1: Check if the loop-back branch is in desirable form. + if (Value *T = matchCondition( + dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry)) + DefX = dyn_cast<Instruction>(T); + else + return false; + + // step 2: detect instructions corresponding to "x.next = x >> 1" + if (!DefX || DefX->getOpcode() != Instruction::AShr) + return false; + if (ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1))) + if (!Shft || !Shft->isOne()) + return false; + VarX = DefX->getOperand(0); + + // step 3: Check the recurrence of variable X + PhiX = getRecurrenceVar(VarX, DefX, LoopEntry); + if (!PhiX) + return false; + + // step 4: Find the instruction which count the CTLZ: cnt.next = cnt + 1 + // TODO: We can skip the step. If loop trip count is known (CTLZ), + // then all uses of "cnt.next" could be optimized to the trip count + // plus "cnt0". Currently it is not optimized. 
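A hypothetical source loop in the shape detectCTLZIdiom matches (the shift must be an AShr fed back through the header phi, so assume X >= 0; a negative X would never reach zero):

// The trip count becomes BitWidth - ctlz(X), so the loop turns countable
// and, if the body is otherwise empty, can later be deleted outright.
int shiftCount(int X) {   // precondition: X >= 0
  int N = 0;
  do {
    ++N;        // cnt.next = cnt + 1
    X >>= 1;    // x.next = x ashr 1
  } while (X != 0);
  return N;
}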
+ // This step could be used to detect POPCNT instruction: + // cnt.next = cnt + (x.next & 1) + for (BasicBlock::iterator Iter = LoopEntry->getFirstNonPHI()->getIterator(), + IterE = LoopEntry->end(); + Iter != IterE; Iter++) { + Instruction *Inst = &*Iter; + if (Inst->getOpcode() != Instruction::Add) + continue; + + ConstantInt *Inc = dyn_cast<ConstantInt>(Inst->getOperand(1)); + if (!Inc || !Inc->isOne()) + continue; + + PHINode *Phi = getRecurrenceVar(Inst->getOperand(0), Inst, LoopEntry); + if (!Phi) + continue; + + CntInst = Inst; + CntPhi = Phi; + break; + } + if (!CntInst) + return false; + + return true; +} + +/// Recognize CTLZ idiom in a non-countable loop and convert the loop +/// to countable (with CTLZ trip count). +/// If CTLZ inserted as a new trip count returns true; otherwise, returns false. +bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { + // Give up if the loop has multiple blocks or multiple backedges. + if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1) + return false; + + Instruction *CntInst, *DefX; + PHINode *CntPhi, *PhiX; + if (!detectCTLZIdiom(CurLoop, PhiX, CntInst, CntPhi, DefX)) + return false; + + bool IsCntPhiUsedOutsideLoop = false; + for (User *U : CntPhi->users()) + if (!CurLoop->contains(dyn_cast<Instruction>(U))) { + IsCntPhiUsedOutsideLoop = true; + break; + } + bool IsCntInstUsedOutsideLoop = false; + for (User *U : CntInst->users()) + if (!CurLoop->contains(dyn_cast<Instruction>(U))) { + IsCntInstUsedOutsideLoop = true; + break; + } + // If both CntInst and CntPhi are used outside the loop the profitability + // is questionable. + if (IsCntInstUsedOutsideLoop && IsCntPhiUsedOutsideLoop) + return false; + + // For some CPUs result of CTLZ(X) intrinsic is undefined + // when X is 0. If we can not guarantee X != 0, we need to check this + // when expand. + bool ZeroCheck = false; + // It is safe to assume Preheader exist as it was checked in + // parent function RunOnLoop. + BasicBlock *PH = CurLoop->getLoopPreheader(); + Value *InitX = PhiX->getIncomingValueForBlock(PH); + // If we check X != 0 before entering the loop we don't need a zero + // check in CTLZ intrinsic, but only if Cnt Phi is not used outside of the + // loop (if it is used we count CTLZ(X >> 1)). + if (!IsCntPhiUsedOutsideLoop) + if (BasicBlock *PreCondBB = PH->getSinglePredecessor()) + if (BranchInst *PreCondBr = + dyn_cast<BranchInst>(PreCondBB->getTerminator())) { + if (matchCondition(PreCondBr, PH) == InitX) + ZeroCheck = true; + } + + // Check if CTLZ intrinsic is profitable. Assume it is always profitable + // if we delete the loop (the loop has only 6 instructions): + // %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ] + // %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ] + // %shr = ashr %n.addr.0, 1 + // %tobool = icmp eq %shr, 0 + // %inc = add nsw %i.0, 1 + // br i1 %tobool + + IRBuilder<> Builder(PH->getTerminator()); + SmallVector<const Value *, 2> Ops = + {InitX, ZeroCheck ? Builder.getTrue() : Builder.getFalse()}; + ArrayRef<const Value *> Args(Ops); + if (CurLoop->getHeader()->size() != 6 && + TTI->getIntrinsicCost(Intrinsic::ctlz, InitX->getType(), Args) > + TargetTransformInfo::TCC_Basic) + return false; + + const DebugLoc DL = DefX->getDebugLoc(); + transformLoopToCountable(PH, CntInst, CntPhi, InitX, DL, ZeroCheck, + IsCntPhiUsedOutsideLoop); + return true; +} + /// Recognizes a population count idiom in a non-countable loop. 
/// /// If detected, transforms the relevant code to issue the popcount intrinsic @@ -1207,6 +1450,134 @@ static CallInst *createPopcntIntrinsic(IRBuilder<> &IRBuilder, Value *Val, return CI; } +static CallInst *createCTLZIntrinsic(IRBuilder<> &IRBuilder, Value *Val, + const DebugLoc &DL, bool ZeroCheck) { + Value *Ops[] = {Val, ZeroCheck ? IRBuilder.getTrue() : IRBuilder.getFalse()}; + Type *Tys[] = {Val->getType()}; + + Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent(); + Value *Func = Intrinsic::getDeclaration(M, Intrinsic::ctlz, Tys); + CallInst *CI = IRBuilder.CreateCall(Func, Ops); + CI->setDebugLoc(DL); + + return CI; +} + +/// Transform the following loop: +/// loop: +/// CntPhi = PHI [Cnt0, CntInst] +/// PhiX = PHI [InitX, DefX] +/// CntInst = CntPhi + 1 +/// DefX = PhiX >> 1 +// LOOP_BODY +/// Br: loop if (DefX != 0) +/// Use(CntPhi) or Use(CntInst) +/// +/// Into: +/// If CntPhi used outside the loop: +/// CountPrev = BitWidth(InitX) - CTLZ(InitX >> 1) +/// Count = CountPrev + 1 +/// else +/// Count = BitWidth(InitX) - CTLZ(InitX) +/// loop: +/// CntPhi = PHI [Cnt0, CntInst] +/// PhiX = PHI [InitX, DefX] +/// PhiCount = PHI [Count, Dec] +/// CntInst = CntPhi + 1 +/// DefX = PhiX >> 1 +/// Dec = PhiCount - 1 +/// LOOP_BODY +/// Br: loop if (Dec != 0) +/// Use(CountPrev + Cnt0) // Use(CntPhi) +/// or +/// Use(Count + Cnt0) // Use(CntInst) +/// +/// If LOOP_BODY is empty the loop will be deleted. +/// If CntInst and DefX are not used in LOOP_BODY they will be removed. +void LoopIdiomRecognize::transformLoopToCountable( + BasicBlock *Preheader, Instruction *CntInst, PHINode *CntPhi, Value *InitX, + const DebugLoc DL, bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) { + BranchInst *PreheaderBr = dyn_cast<BranchInst>(Preheader->getTerminator()); + + // Step 1: Insert the CTLZ instruction at the end of the preheader block + // Count = BitWidth - CTLZ(InitX); + // If there are uses of CntPhi create: + // CountPrev = BitWidth - CTLZ(InitX >> 1); + IRBuilder<> Builder(PreheaderBr); + Builder.SetCurrentDebugLocation(DL); + Value *CTLZ, *Count, *CountPrev, *NewCount, *InitXNext; + + if (IsCntPhiUsedOutsideLoop) + InitXNext = Builder.CreateAShr(InitX, + ConstantInt::get(InitX->getType(), 1)); + else + InitXNext = InitX; + CTLZ = createCTLZIntrinsic(Builder, InitXNext, DL, ZeroCheck); + Count = Builder.CreateSub( + ConstantInt::get(CTLZ->getType(), + CTLZ->getType()->getIntegerBitWidth()), + CTLZ); + if (IsCntPhiUsedOutsideLoop) { + CountPrev = Count; + Count = Builder.CreateAdd( + CountPrev, + ConstantInt::get(CountPrev->getType(), 1)); + } + if (IsCntPhiUsedOutsideLoop) + NewCount = Builder.CreateZExtOrTrunc(CountPrev, + cast<IntegerType>(CntInst->getType())); + else + NewCount = Builder.CreateZExtOrTrunc(Count, + cast<IntegerType>(CntInst->getType())); + + // If the CTLZ counter's initial value is not zero, insert Add Inst. + Value *CntInitVal = CntPhi->getIncomingValueForBlock(Preheader); + ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal); + if (!InitConst || !InitConst->isZero()) + NewCount = Builder.CreateAdd(NewCount, CntInitVal); + + // Step 2: Insert new IV and loop condition: + // loop: + // ... + // PhiCount = PHI [Count, Dec] + // ... + // Dec = PhiCount - 1 + // ... 
+ // Br: loop if (Dec != 0) + BasicBlock *Body = *(CurLoop->block_begin()); + auto *LbBr = dyn_cast<BranchInst>(Body->getTerminator()); + ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition()); + Type *Ty = Count->getType(); + + PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi", &Body->front()); + + Builder.SetInsertPoint(LbCond); + Instruction *TcDec = cast<Instruction>( + Builder.CreateSub(TcPhi, ConstantInt::get(Ty, 1), + "tcdec", false, true)); + + TcPhi->addIncoming(Count, Preheader); + TcPhi->addIncoming(TcDec, Body); + + CmpInst::Predicate Pred = + (LbBr->getSuccessor(0) == Body) ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ; + LbCond->setPredicate(Pred); + LbCond->setOperand(0, TcDec); + LbCond->setOperand(1, ConstantInt::get(Ty, 0)); + + // Step 3: All the references to the original counter outside + // the loop are replaced with the NewCount -- the value returned from + // __builtin_ctlz(x). + if (IsCntPhiUsedOutsideLoop) + CntPhi->replaceUsesOutsideBlock(NewCount, Body); + else + CntInst->replaceUsesOutsideBlock(NewCount, Body); + + // step 4: Forget the "non-computable" trip-count SCEV associated with the + // loop. The loop would otherwise not be deleted even if it becomes empty. + SE->forgetLoop(CurLoop); +} + void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst, PHINode *CntPhi, Value *Var) { diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index 69102d1..af09556 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -77,7 +77,7 @@ static bool SimplifyLoopInst(Loop *L, DominatorTree *DT, LoopInfo *LI, // Don't bother simplifying unused instructions. if (!I->use_empty()) { - Value *V = SimplifyInstruction(I, DL, TLI, DT, AC); + Value *V = SimplifyInstruction(I, {DL, TLI, DT, AC}); if (V && LI->replacementPreservesLCSSAForm(I, V)) { // Mark all uses for resimplification next time round the loop. 
for (User *U : I->users()) @@ -189,7 +189,9 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, if (!SimplifyLoopInst(&L, &AR.DT, &AR.LI, &AR.AC, &AR.TLI)) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + PA.preserveSet<CFGAnalyses>(); + return PA; } char LoopInstSimplifyLegacyPass::ID = 0; diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index e9f84ed..2e0d8e0 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -39,7 +40,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" + using namespace llvm; #define DEBUG_TYPE "loop-interchange" @@ -323,9 +324,10 @@ static PHINode *getInductionVariable(Loop *L, ScalarEvolution *SE) { class LoopInterchangeLegality { public: LoopInterchangeLegality(Loop *Outer, Loop *Inner, ScalarEvolution *SE, - LoopInfo *LI, DominatorTree *DT, bool PreserveLCSSA) + LoopInfo *LI, DominatorTree *DT, bool PreserveLCSSA, + OptimizationRemarkEmitter *ORE) : OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), - PreserveLCSSA(PreserveLCSSA), InnerLoopHasReduction(false) {} + PreserveLCSSA(PreserveLCSSA), ORE(ORE), InnerLoopHasReduction(false) {} /// Check if the loops can be interchanged. bool canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId, @@ -353,6 +355,8 @@ private: LoopInfo *LI; DominatorTree *DT; bool PreserveLCSSA; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; bool InnerLoopHasReduction; }; @@ -361,8 +365,9 @@ private: /// loop. class LoopInterchangeProfitability { public: - LoopInterchangeProfitability(Loop *Outer, Loop *Inner, ScalarEvolution *SE) - : OuterLoop(Outer), InnerLoop(Inner), SE(SE) {} + LoopInterchangeProfitability(Loop *Outer, Loop *Inner, ScalarEvolution *SE, + OptimizationRemarkEmitter *ORE) + : OuterLoop(Outer), InnerLoop(Inner), SE(SE), ORE(ORE) {} /// Check if the loop interchange is profitable. bool isProfitable(unsigned InnerLoopId, unsigned OuterLoopId, @@ -376,6 +381,8 @@ private: /// Scev analysis. ScalarEvolution *SE; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; }; /// LoopInterchangeTransform interchanges the loop. @@ -422,6 +429,9 @@ struct LoopInterchange : public FunctionPass { DependenceInfo *DI; DominatorTree *DT; bool PreserveLCSSA; + /// Interface to emit optimization remarks. 
+ OptimizationRemarkEmitter *ORE; + LoopInterchange() : FunctionPass(ID), SE(nullptr), LI(nullptr), DI(nullptr), DT(nullptr) { initializeLoopInterchangePass(*PassRegistry::getPassRegistry()); } @@ -435,6 +445,7 @@ struct LoopInterchange : public FunctionPass { AU.addRequired<DependenceAnalysisWrapperPass>(); AU.addRequiredID(LoopSimplifyID); AU.addRequiredID(LCSSAID); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } bool runOnFunction(Function &F) override { @@ -446,6 +457,7 @@ struct LoopInterchange : public FunctionPass { DI = &getAnalysis<DependenceAnalysisWrapperPass>().getDI(); auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); DT = DTWP ? &DTWP->getDomTree() : nullptr; + ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); // Build up a worklist of loop pairs to analyze. @@ -575,18 +587,23 @@ struct LoopInterchange : public FunctionPass { Loop *OuterLoop = LoopList[OuterLoopId]; LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, LI, DT, - PreserveLCSSA); + PreserveLCSSA, ORE); if (!LIL.canInterchangeLoops(InnerLoopId, OuterLoopId, DependencyMatrix)) { DEBUG(dbgs() << "Not interchanging loops. Cannot prove legality\n"); return false; } DEBUG(dbgs() << "Loops are legal to interchange\n"); - LoopInterchangeProfitability LIP(OuterLoop, InnerLoop, SE); + LoopInterchangeProfitability LIP(OuterLoop, InnerLoop, SE, ORE); if (!LIP.isProfitable(InnerLoopId, OuterLoopId, DependencyMatrix)) { DEBUG(dbgs() << "Interchanging loops not profitable\n"); return false; } + ORE->emit(OptimizationRemark(DEBUG_TYPE, "Interchanged", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Loop interchanged with enclosing loop."); + LoopInterchangeTransform LIT(OuterLoop, InnerLoop, SE, LI, DT, LoopNestExit, LIL.hasInnerLoopReduction()); LIT.transform(); @@ -757,13 +774,28 @@ bool LoopInterchangeLegality::currentLimitations() { PHINode *InnerInductionVar; SmallVector<PHINode *, 8> Inductions; SmallVector<PHINode *, 8> Reductions; - if (!findInductionAndReductions(InnerLoop, Inductions, Reductions)) + if (!findInductionAndReductions(InnerLoop, Inductions, Reductions)) { + DEBUG(dbgs() << "Only inner loops with induction or reduction PHI nodes " + << "are supported currently.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "UnsupportedPHIInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with induction or reduction PHI nodes can be" + " interchanged currently."); return true; + } // TODO: Currently we handle only loops with 1 induction variable. if (Inductions.size() != 1) { DEBUG(dbgs() << "We currently only support loops with 1 induction variable."
<< "Failed to interchange due to current limitation\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "MultiInductionInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with 1 induction variable can be " + "interchanged currently."); return true; } if (Reductions.size() > 0) @@ -771,32 +803,80 @@ bool LoopInterchangeLegality::currentLimitations() { InnerInductionVar = Inductions.pop_back_val(); Reductions.clear(); - if (!findInductionAndReductions(OuterLoop, Inductions, Reductions)) + if (!findInductionAndReductions(OuterLoop, Inductions, Reductions)) { + DEBUG(dbgs() << "Only outer loops with induction or reduction PHI nodes " + << "are supported currently.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "UnsupportedPHIOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with induction or reduction PHI nodes can be" + " interchanged currently."); return true; + } // Outer loop cannot have reduction because then loops will not be tightly // nested. - if (!Reductions.empty()) + if (!Reductions.empty()) { + DEBUG(dbgs() << "Outer loops with reductions are not supported " + << "currently.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "ReductionsOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Outer loops with reductions cannot be interchanged " + "currently."); return true; + } // TODO: Currently we handle only loops with 1 induction variable. - if (Inductions.size() != 1) + if (Inductions.size() != 1) { + DEBUG(dbgs() << "Loops with more than 1 induction variable are not " + << "supported currently.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "MultiIndutionOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with 1 induction variable can be " + "interchanged currently."); return true; + } // TODO: Triangular loops are not handled for now. if (!isLoopStructureUnderstood(InnerInductionVar)) { DEBUG(dbgs() << "Loop structure not understood by pass\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "UnsupportedStructureInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Inner loop structure not understood currently."); return true; } // TODO: We only handle LCSSA PHI's corresponding to reduction for now.
BasicBlock *LoopExitBlock = getLoopLatchExitBlock(OuterLoopLatch, OuterLoopHeader); - if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, true)) + if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, true)) { + DEBUG(dbgs() << "Can only handle LCSSA PHIs in outer loops currently.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "NoLCSSAPHIOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with LCSSA PHIs can be interchanged " + "currently."); return true; + } LoopExitBlock = getLoopLatchExitBlock(InnerLoopLatch, InnerLoopHeader); - if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, false)) + if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, false)) { + DEBUG(dbgs() << "Can only handle LCSSA PHIs in inner loops currently.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "NoLCSSAPHIOuterInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with LCSSA PHIs can be interchanged " + "currently."); return true; + } // TODO: Current limitation: Since we split the inner loop latch at the point // where the induction variable is incremented (induction.next), we cannot have @@ -816,8 +896,16 @@ bool LoopInterchangeLegality::currentLimitations() { InnerIndexVarInc = dyn_cast<Instruction>(InnerInductionVar->getIncomingValue(0)); - if (!InnerIndexVarInc) + if (!InnerIndexVarInc) { + DEBUG(dbgs() << "Did not find an instruction to increment the induction " + << "variable.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "NoIncrementInInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "The inner loop does not increment the induction variable."); return true; + } // Since we split the inner loop latch on this induction variable, make sure // we do not have any instruction between the induction variable and branch @@ -827,19 +915,35 @@ bool LoopInterchangeLegality::currentLimitations() { for (const Instruction &I : reverse(*InnerLoopLatch)) { if (isa<BranchInst>(I) || isa<CmpInst>(I) || isa<TruncInst>(I)) continue; + // We found an instruction. If this is not the induction variable then it is not // safe to split this loop latch. - if (!I.isIdenticalTo(InnerIndexVarInc)) + if (!I.isIdenticalTo(InnerIndexVarInc)) { + DEBUG(dbgs() << "Found unsupported instructions between induction " + << "variable increment and branch.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "UnsupportedInsBetweenInduction", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Found unsupported instruction between induction variable " + "increment and branch."); return true; + } FoundInduction = true; break; } // The loop latch ended and we didn't find the induction variable; return as // current limitation.
- if (!FoundInduction) + if (!FoundInduction) { + DEBUG(dbgs() << "Did not find the induction variable.\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "NoIndutionVariable", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Did not find the induction variable."); return true; - + } return false; } @@ -851,6 +955,11 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, DEBUG(dbgs() << "Failed interchange InnerLoopId = " << InnerLoopId << " and OuterLoopId = " << OuterLoopId << " due to dependence\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "Dependence", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Cannot interchange loops due to dependences."); return false; } @@ -886,6 +995,12 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, // Check if the loops are tightly nested. if (!tightlyNested(OuterLoop, InnerLoop)) { DEBUG(dbgs() << "Loops not tightly nested\n"); + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "NotTightlyNested", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Cannot interchange loops because they are not tightly " + "nested."); return false; } @@ -981,9 +1096,18 @@ bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, // It is not profitable as per current cache profitability model. But check if // we can move this loop outside to improve parallelism. - bool ImprovesPar = - isProfitableForVectorization(InnerLoopId, OuterLoopId, DepMatrix); - return ImprovesPar; + if (isProfitableForVectorization(InnerLoopId, OuterLoopId, DepMatrix)) + return true; + + ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, + "InterchangeNotProfitable", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Interchanging loops is too costly (cost=" + << ore::NV("Cost", Cost) << ", threshold=" + << ore::NV("Threshold", LoopInterchangeCostThreshold) << + ") and it does not improve parallelism."); + return false; } void LoopInterchangeTransform::removeChildLoop(Loop *OuterLoop, @@ -1267,6 +1391,7 @@ INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(LoopInterchange, "loop-interchange", "Interchanges loops for cache reuse", false, false) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 8fb5801..20b37c4 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -20,13 +20,14 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopLoadElimination.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" @@ -45,9 +46,9 @@ #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/LoopVersioning.h" -#include <forward_list> -#include <cassert> #include <algorithm> +#include <cassert> +#include <forward_list> #include <set> #include <tuple> #include <utility> @@ -196,8 
+197,7 @@ public: continue; // Only propagate the value if they are of the same type. - if (Store->getPointerOperand()->getType() != - Load->getPointerOperand()->getType()) + if (Store->getPointerOperandType() != Load->getPointerOperandType()) continue; Candidates.emplace_front(Load, Store); @@ -373,15 +373,15 @@ public: const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; - std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { - for (auto PtrIdx1 : Check.first->Members) - for (auto PtrIdx2 : Check.second->Members) - if (needsChecking(PtrIdx1, PtrIdx2, - PtrsWrittenOnFwdingPath, CandLoadPtrs)) - return true; - return false; - }); + copy_if(AllChecks, std::back_inserter(Checks), + [&](const RuntimePointerChecking::PointerCheck &Check) { + for (auto PtrIdx1 : Check.first->Members) + for (auto PtrIdx2 : Check.second->Members) + if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, + CandLoadPtrs)) + return true; + return false; + }); DEBUG(dbgs() << "\nPointer Checks (count: " << Checks.size() << "):\n"); DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); @@ -558,6 +558,32 @@ private: PredicatedScalarEvolution PSE; }; +static bool +eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, + function_ref<const LoopAccessInfo &(Loop &)> GetLAI) { + // Build up a worklist of inner-loops to transform to avoid iterator + // invalidation. + // FIXME: This logic comes from other passes that actually change the loop + // nest structure. It isn't clear this is necessary (or useful) for a pass + // which merely optimizes the use of loads in a loop. + SmallVector<Loop *, 8> Worklist; + + for (Loop *TopLevelLoop : LI) + for (Loop *L : depth_first(TopLevelLoop)) + // We only handle inner-most loops. + if (L->empty()) + Worklist.push_back(L); + + // Now walk the identified inner loops. + bool Changed = false; + for (Loop *L : Worklist) { + // The actual work is performed by LoadEliminationForLoop. + LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT); + Changed |= LEL.processLoop(); + } + return Changed; +} + /// \brief The pass. Most of the work is delegated to the per-loop /// LoadEliminationForLoop class. class LoopLoadElimination : public FunctionPass { @@ -570,32 +596,14 @@ public: if (skipFunction(F)) return false; - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - - // Build up a worklist of inner-loops to vectorize. This is necessary as the - // act of distributing a loop creates new loops and can invalidate iterators - // across the loops. - SmallVector<Loop *, 8> Worklist; - - for (Loop *TopLevelLoop : *LI) - for (Loop *L : depth_first(TopLevelLoop)) - // We only handle inner-most loops. - if (L->empty()) - Worklist.push_back(L); - - // Now walk the identified inner loops. - bool Changed = false; - for (Loop *L : Worklist) { - const LoopAccessInfo &LAI = LAA->getInfo(L); - // The actual work is performed by LoadEliminationForLoop. - LoadEliminationForLoop LEL(L, LI, LAI, DT); - Changed |= LEL.processLoop(); - } + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &LAA = getAnalysis<LoopAccessLegacyAnalysis>(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); // Process each loop nest in the function.
- return Changed; + return eliminateLoadsAcrossLoops( + F, LI, DT, + [&LAA](Loop &L) -> const LoopAccessInfo & { return LAA.getInfo(&L); }); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -631,4 +639,28 @@ FunctionPass *createLoopLoadEliminationPass() { return new LoopLoadElimination(); } +PreservedAnalyses LoopLoadEliminationPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + + auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); + bool Changed = eliminateLoadsAcrossLoops( + F, LI, DT, [&](Loop &L) -> const LoopAccessInfo & { + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI}; + return LAM.getResult<LoopAccessAnalysis>(L, AR); + }); + + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + return PA; +} + } // end namespace llvm diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index 028f4bb..10f6fcd 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -42,6 +42,13 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, break; } +#ifndef NDEBUG + // Verify the loop structure and LCSSA form before visiting the loop. + L.verifyLoop(); + assert(L.isRecursivelyLCSSAForm(AR.DT, AR.LI) && + "Loops must remain in LCSSA form!"); +#endif + // Update the analysis manager as each pass runs and potentially // invalidates analyses. AM.invalidate(L, PassPA); diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopPredication.cpp new file mode 100644 index 0000000..9b12ba1 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -0,0 +1,330 @@ +//===-- LoopPredication.cpp - Guard based loop predication pass -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// The LoopPredication pass tries to convert loop variant range checks to loop +// invariant by widening checks across loop iterations. For example, it will +// convert +// +// for (i = 0; i < n; i++) { +// guard(i < len); +// ... +// } +// +// to +// +// for (i = 0; i < n; i++) { +// guard(n - 1 < len); +// ... +// } +// +// After this transformation the condition of the guard is loop invariant, so +// loop-unswitch can later unswitch the loop by this condition which basically +// predicates the loop by the widened condition: +// +// if (n - 1 < len) +// for (i = 0; i < n; i++) { +// ... 
+// } +// else +// deoptimize +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +#define DEBUG_TYPE "loop-predication" + +using namespace llvm; + +namespace { +class LoopPredication { + /// Represents an induction variable check: + /// icmp Pred, <induction variable>, <loop invariant limit> + struct LoopICmp { + ICmpInst::Predicate Pred; + const SCEVAddRecExpr *IV; + const SCEV *Limit; + LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, + const SCEV *Limit) + : Pred(Pred), IV(IV), Limit(Limit) {} + LoopICmp() {} + }; + + ScalarEvolution *SE; + + Loop *L; + const DataLayout *DL; + BasicBlock *Preheader; + + Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); + + Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + Instruction *InsertAt); + + Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, + IRBuilder<> &Builder); + bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + +public: + LoopPredication(ScalarEvolution *SE) : SE(SE){}; + bool runOnLoop(Loop *L); +}; + +class LoopPredicationLegacyPass : public LoopPass { +public: + static char ID; + LoopPredicationLegacyPass() : LoopPass(ID) { + initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + getLoopAnalysisUsage(AU); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + LoopPredication LP(SE); + return LP.runOnLoop(L); + } +}; + +char LoopPredicationLegacyPass::ID = 0; +} // end namespace llvm + +INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", + "Loop predication", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", + "Loop predication", false, false) + +Pass *llvm::createLoopPredicationPass() { + return new LoopPredicationLegacyPass(); +} + +PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + LoopPredication LP(&AR.SE); + if (!LP.runOnLoop(&L)) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +Optional<LoopPredication::LoopICmp> +LoopPredication::parseLoopICmp(ICmpInst *ICI) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + const SCEV *LHSS = SE->getSCEV(LHS); + if (isa<SCEVCouldNotCompute>(LHSS)) + return None; + const SCEV *RHSS = SE->getSCEV(RHS); + if (isa<SCEVCouldNotCompute>(RHSS)) + return None; + + // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV + if (SE->isLoopInvariant(LHSS, L)) { + std::swap(LHS, RHS); + std::swap(LHSS, RHSS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + const 
SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS); + if (!AR || AR->getLoop() != L) + return None; + + return LoopICmp(Pred, AR, RHSS); +} + +Value *LoopPredication::expandCheck(SCEVExpander &Expander, + IRBuilder<> &Builder, + ICmpInst::Predicate Pred, const SCEV *LHS, + const SCEV *RHS, Instruction *InsertAt) { + Type *Ty = LHS->getType(); + assert(Ty == RHS->getType() && "expandCheck operands have different types?"); + Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt); + Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt); + return Builder.CreateICmp(Pred, LHSV, RHSV); +} + +/// If ICI can be widened to a loop invariant condition, emits the loop +/// invariant condition in the loop preheader and returns it; otherwise +/// returns None. +Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, + SCEVExpander &Expander, + IRBuilder<> &Builder) { + DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); + DEBUG(ICI->dump()); + + auto RangeCheck = parseLoopICmp(ICI); + if (!RangeCheck) { + DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); + return None; + } + + ICmpInst::Predicate Pred = RangeCheck->Pred; + const SCEVAddRecExpr *IndexAR = RangeCheck->IV; + const SCEV *RHSS = RangeCheck->Limit; + + auto CanExpand = [this](const SCEV *S) { + return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); + }; + if (!CanExpand(RHSS)) + return None; + + DEBUG(dbgs() << "IndexAR: "); + DEBUG(IndexAR->dump()); + + bool IsIncreasing = false; + if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing)) + return None; + + // If the predicate is increasing, the condition can change from false to + // true as the loop progresses; in this case take the value on the first + // iteration for the widened check. Otherwise the condition can change from + // true to false as the loop progresses, so take the value on the last + // iteration. + const SCEV *NewLHSS = IsIncreasing + ? IndexAR->getStart() + : SE->getSCEVAtScope(IndexAR, L->getParentLoop()); + if (NewLHSS == IndexAR) { + DEBUG(dbgs() << "Can't compute NewLHSS!\n"); + return None; + } + + DEBUG(dbgs() << "NewLHSS: "); + DEBUG(NewLHSS->dump()); + + if (!CanExpand(NewLHSS)) + return None; + + DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. Expand!\n"); + + Instruction *InsertAt = Preheader->getTerminator(); + return expandCheck(Expander, Builder, Pred, NewLHSS, RHSS, InsertAt); +} + +bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, + SCEVExpander &Expander) { + DEBUG(dbgs() << "Processing guard:\n"); + DEBUG(Guard->dump()); + + IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); + + // The guard condition is expected to be in the form: + // cond1 && cond2 && cond3 ... + // Iterate over subconditions looking for icmp conditions which can be + // widened across loop iterations. While widening these conditions, remember + // the resulting list of subconditions in the Checks vector.
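The splitting described above is a standard worklist walk over a tree of and-conditions. A minimal standalone sketch of the same idea (types hypothetical; the real code also tracks a Visited set, since the condition graph can share subtrees):

#include <vector>
struct Cond { bool IsAnd; Cond *LHS; Cond *RHS; };
// Flatten cond1 && cond2 && cond3 ... into its leaf conditions.
std::vector<Cond *> collectSubconditions(Cond *Root) {
  std::vector<Cond *> Worklist{Root}, Checks;
  while (!Worklist.empty()) {
    Cond *C = Worklist.back();
    Worklist.pop_back();
    if (C->IsAnd) { // descend into both sides of an 'and'
      Worklist.push_back(C->LHS);
      Worklist.push_back(C->RHS);
    } else {        // a leaf: a candidate for widening
      Checks.push_back(C);
    }
  }
  return Checks;
}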
+ SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0)); + SmallPtrSet<Value *, 4> Visited; + + SmallVector<Value *, 4> Checks; + + unsigned NumWidened = 0; + do { + Value *Condition = Worklist.pop_back_val(); + if (!Visited.insert(Condition).second) + continue; + + Value *LHS, *RHS; + using namespace llvm::PatternMatch; + if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) { + Worklist.push_back(LHS); + Worklist.push_back(RHS); + continue; + } + + if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { + if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) { + Checks.push_back(NewRangeCheck.getValue()); + NumWidened++; + continue; + } + } + + // Save the condition as is if we can't widen it + Checks.push_back(Condition); + } while (Worklist.size() != 0); + + if (NumWidened == 0) + return false; + + // Emit the new guard condition + Builder.SetInsertPoint(Guard); + Value *LastCheck = nullptr; + for (auto *Check : Checks) + if (!LastCheck) + LastCheck = Check; + else + LastCheck = Builder.CreateAnd(LastCheck, Check); + Guard->setOperand(0, LastCheck); + + DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); + return true; +} + +bool LoopPredication::runOnLoop(Loop *Loop) { + L = Loop; + + DEBUG(dbgs() << "Analyzing "); + DEBUG(L->dump()); + + Module *M = L->getHeader()->getModule(); + + // There is nothing to do if the module doesn't use guards + auto *GuardDecl = + M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard)); + if (!GuardDecl || GuardDecl->use_empty()) + return false; + + DL = &M->getDataLayout(); + + Preheader = L->getLoopPreheader(); + if (!Preheader) + return false; + + // Collect all the guards into a vector and process later, so as not + // to invalidate the instruction iterator. + SmallVector<IntrinsicInst *, 4> Guards; + for (const auto BB : L->blocks()) + for (auto &I : *BB) + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::experimental_guard) + Guards.push_back(II); + + if (Guards.empty()) + return false; + + SCEVExpander Expander(*SE, *DL, "loop-predication"); + + bool Changed = false; + for (auto *Guard : Guards) + Changed |= widenGuardConditions(Guard, Expander); + + return Changed; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp index 86058fe..fc0216e 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -11,10 +11,9 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -31,6 +30,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -557,7 +557,7 @@ bool LoopReroll::isLoopControlIV(Loop *L, Instruction *IV) { Instruction *UUser = dyn_cast<Instruction>(UU); // Skip SExt if we are extending an nsw value // TODO: Allow ZExt too - if (BO->hasNoSignedWrap() && UUser && UUser->getNumUses() == 1 && + if (BO->hasNoSignedWrap() && UUser && UUser->hasOneUse() && isa<SExtInst>(UUser)) UUser = 
dyn_cast<Instruction>(*(UUser->user_begin())); if (!isCompareUsedByBranch(UUser)) @@ -852,7 +852,7 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { for (auto &KV : Roots) { if (KV.first == 0) continue; - if (KV.second->getNumUses() != NumBaseUses) { + if (!KV.second->hasNUses(NumBaseUses)) { DEBUG(dbgs() << "LRR: Aborting - Root and Base #users not the same: " << "#Base=" << NumBaseUses << ", #Root=" << KV.second->getNumUses() << "\n"); @@ -867,7 +867,7 @@ void LoopReroll::DAGRootTracker:: findRootsRecursive(Instruction *I, SmallInstructionSet SubsumedInsts) { // Does the user look like it could be part of a root set? // All its users must be simple arithmetic ops. - if (I->getNumUses() > IL_MaxRerollIterations) + if (I->hasNUsesOrMore(IL_MaxRerollIterations + 1)) return; if (I != IV && findRootsBase(I, SubsumedInsts)) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp index cc83069..3506ac3 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -58,13 +58,14 @@ class LoopRotate { AssumptionCache *AC; DominatorTree *DT; ScalarEvolution *SE; + const SimplifyQuery &SQ; public: LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC, - DominatorTree *DT, ScalarEvolution *SE) - : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE) { - } + DominatorTree *DT, ScalarEvolution *SE, const SimplifyQuery &SQ) + : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE), + SQ(SQ) {} bool processLoop(Loop *L); private: @@ -79,7 +80,8 @@ private: /// to merge the two values. Do this now. static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, BasicBlock *OrigPreheader, - ValueToValueMapTy &ValueMap) { + ValueToValueMapTy &ValueMap, + SmallVectorImpl<PHINode*> *InsertedPHIs) { // Remove PHI node entries that are no longer live. BasicBlock::iterator I, E = OrigHeader->end(); for (I = OrigHeader->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I) @@ -87,7 +89,7 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, // Now fix up users of the instructions in OrigHeader, inserting PHI nodes // as necessary. - SSAUpdater SSA; + SSAUpdater SSA(InsertedPHIs); for (I = OrigHeader->begin(); I != E; ++I) { Value *OrigHeaderVal = &*I; @@ -174,6 +176,38 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, } } +/// Propagate dbg.value intrinsics through the newly inserted Phis. +static void insertDebugValues(BasicBlock *OrigHeader, + SmallVectorImpl<PHINode*> &InsertedPHIs) { + ValueToValueMapTy DbgValueMap; + + // Map existing PHI nodes to their dbg.values. + for (auto &I : *OrigHeader) { + if (auto DbgII = dyn_cast<DbgInfoIntrinsic>(&I)) { + if (auto *Loc = dyn_cast_or_null<PHINode>(DbgII->getVariableLocation())) + DbgValueMap.insert({Loc, DbgII}); + } + } + + // Then iterate through the new PHIs and look to see if they use one of the + // previously mapped PHIs. If so, insert a new dbg.value intrinsic that will + // propagate the info through the new PHI. 
+ LLVMContext &C = OrigHeader->getContext(); + for (auto PHI : InsertedPHIs) { + for (auto VI : PHI->operand_values()) { + auto V = DbgValueMap.find(VI); + if (V != DbgValueMap.end()) { + auto *DbgII = cast<DbgInfoIntrinsic>(V->second); + Instruction *NewDbgII = DbgII->clone(); + auto PhiMAV = MetadataAsValue::get(C, ValueAsMetadata::get(PHI)); + NewDbgII->setOperand(0, PhiMAV); + BasicBlock *Parent = PHI->getParent(); + NewDbgII->insertBefore(Parent->getFirstNonPHIOrDbgOrLifetime()); + } + } + } +} + /// Rotate loop LP. Return true if the loop is rotated. /// /// \param SimplifiedLatch is true if the latch was just folded into the final @@ -278,8 +312,6 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { for (; PHINode *PN = dyn_cast<PHINode>(I); ++I) ValueMap[PN] = PN->getIncomingValueForBlock(OrigPreheader); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - // For the rest of the instructions, either hoist to the OrigPreheader if // possible or create a clone in the OldPreHeader if not. TerminatorInst *LoopEntryBranch = OrigPreheader->getTerminator(); @@ -309,14 +341,13 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // With the operands remapped, see if the instruction constant folds or is // otherwise simplifyable. This commonly occurs because the entry from PHI // nodes allows icmps and other instructions to fold. - // FIXME: Provide TLI, DT, AC to SimplifyInstruction. - Value *V = SimplifyInstruction(C, DL); + Value *V = SimplifyInstruction(C, SQ); if (V && LI->replacementPreservesLCSSAForm(C, V)) { // If so, then delete the temporary instruction and stick the folded value // in the map. ValueMap[Inst] = V; if (!C->mayHaveSideEffects()) { - delete C; + C->deleteValue(); C = nullptr; } } else { @@ -347,9 +378,18 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // remove the corresponding incoming values from the PHI nodes in OrigHeader. LoopEntryBranch->eraseFromParent(); + + SmallVector<PHINode*, 2> InsertedPHIs; // If there were any uses of instructions in the duplicated block outside the // loop, update them, inserting PHI nodes as required - RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap); + RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap, + &InsertedPHIs); + + // Attach dbg.value intrinsics to the new phis if that phi uses a value that + // previously had debug metadata attached. This keeps the debug info + // up-to-date in the loop body. + if (!InsertedPHIs.empty()) + insertDebugValues(OrigHeader, InsertedPHIs); // NewHeader is now the header of the loop. L->moveToHeader(NewHeader); @@ -445,10 +485,22 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { DomTreeNode *Node = HeaderChildren[I]; BasicBlock *BB = Node->getBlock(); - pred_iterator PI = pred_begin(BB); - BasicBlock *NearestDom = *PI; - for (pred_iterator PE = pred_end(BB); PI != PE; ++PI) - NearestDom = DT->findNearestCommonDominator(NearestDom, *PI); + BasicBlock *NearestDom = nullptr; + for (BasicBlock *Pred : predecessors(BB)) { + // Consider only reachable basic blocks. + if (!DT->getNode(Pred)) + continue; + + if (!NearestDom) { + NearestDom = Pred; + continue; + } + + NearestDom = DT->findNearestCommonDominator(NearestDom, Pred); + assert(NearestDom && "No NearestCommonDominator found"); + } + + assert(NearestDom && "Nearest dominator not found"); // Remember if this changes the DomTree. 
if (Node->getIDom()->getBlock() != NearestDom) { @@ -629,11 +681,15 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { int Threshold = EnableHeaderDuplication ? DefaultRotationThreshold : 0; - LoopRotate LR(Threshold, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE); + const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); + const SimplifyQuery SQ = getBestSimplifyQuery(AR, DL); + LoopRotate LR(Threshold, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, + SQ); bool Changed = LR.processLoop(&L); if (!Changed) return PreservedAnalyses::all(); + return getLoopPassPreservedAnalyses(); } @@ -671,7 +727,8 @@ public: auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; - LoopRotate LR(MaxHeaderSize, LI, TTI, AC, DT, SE); + const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); + LoopRotate LR(MaxHeaderSize, LI, TTI, AC, DT, SE, SQ); return LR.processLoop(L); } }; diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 1606121..35c05e8 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -40,7 +40,7 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI) { bool Changed = false; // Copy blocks into a temporary array to avoid iterator invalidation issues // as we remove them. - SmallVector<WeakVH, 16> Blocks(L.blocks()); + SmallVector<WeakTrackingVH, 16> Blocks(L.blocks()); for (auto &Block : Blocks) { // Attempt to merge blocks in the trivial case. Don't modify blocks which @@ -69,6 +69,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, LPMUpdater &) { if (!simplifyLoopCFG(L, AR.DT, AR.LI)) return PreservedAnalyses::all(); + return getLoopPassPreservedAnalyses(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp index f3f4152..c9d55b4 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -1,4 +1,4 @@ -//===-- LoopSink.cpp - Loop Sink Pass ------------------------===// +//===-- LoopSink.cpp - Loop Sink Pass -------------------------------------===// // // The LLVM Compiler Infrastructure // @@ -28,8 +28,10 @@ // InsertBBs = UseBBs - DomBBs + BB // For BB in InsertBBs: // Insert I at BB's beginning +// //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopSink.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" @@ -297,6 +299,42 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, return Changed; } +PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { + LoopInfo &LI = FAM.getResult<LoopAnalysis>(F); + // Nothing to do if there are no loops. + if (LI.empty()) + return PreservedAnalyses::all(); + + AAResults &AA = FAM.getResult<AAManager>(F); + DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); + BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); + + // We want to do a postorder walk over the loops. Since loops are a tree this + // is equivalent to a reversed preorder walk and preorder is easy to compute + // without recursion. 
Since we reverse the preorder, we will visit siblings + // in reverse program order. This isn't expected to matter at all but is more + // consistent with sinking algorithms which generally work bottom-up. + SmallVector<Loop *, 4> PreorderLoops = LI.getLoopsInPreorder(); + + bool Changed = false; + do { + Loop &L = *PreorderLoops.pop_back_val(); + + // Note that we don't pass SCEV here because it is only used to invalidate + // loops in SCEV and we don't preserve (or request) SCEV at all making that + // unnecessary. + Changed |= sinkLoopInvariantInstructions(L, AA, LI, DT, BFI, + /*ScalarEvolution*/ nullptr); + } while (!PreorderLoops.empty()); + + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} + namespace { struct LegacyLoopSinkPass : public LoopPass { static char ID; diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 194587a..3638da1 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -129,6 +129,24 @@ static cl::opt<bool> EnablePhiElim( "enable-lsr-phielim", cl::Hidden, cl::init(true), cl::desc("Enable LSR phi elimination")); +// This flag adds instruction count to the solution cost comparison. +static cl::opt<bool> InsnsCost( + "lsr-insns-cost", cl::Hidden, cl::init(false), + cl::desc("Add instruction count to the LSR cost model")); + +// Flag to choose how to narrow a complex LSR solution. +static cl::opt<bool> LSRExpNarrow( + "lsr-exp-narrow", cl::Hidden, cl::init(false), + cl::desc("Narrow a complex LSR solution using the" + " expected number of registers")); + +// Flag to narrow search space by filtering non-optimal formulae with +// the same ScaledReg and Scale. +static cl::opt<bool> FilterSameScaledReg( + "lsr-filter-same-scaled-reg", cl::Hidden, cl::init(true), + cl::desc("Narrow LSR search space by filtering non-optimal formulae" + " with the same ScaledReg and Scale")); + #ifndef NDEBUG // Stress test IV chain generation. static cl::opt<bool> StressIVChain( @@ -181,10 +199,11 @@ void RegSortData::print(raw_ostream &OS) const { OS << "[NumUses=" << UsedByIndices.count() << ']'; } -LLVM_DUMP_METHOD -void RegSortData::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void RegSortData::dump() const { print(errs()); errs() << '\n'; } +#endif namespace { @@ -295,9 +314,13 @@ struct Formula { /// canonical representation of a formula is /// 1. BaseRegs.size > 1 implies ScaledReg != NULL and /// 2. ScaledReg != NULL implies Scale != 1 || !BaseRegs.empty(). + /// 3. The reg containing the recurrent expr related to the current loop in + /// the formula should be put in the ScaledReg. /// #1 enforces that the scaled register is always used when at least two /// registers are needed by the formula: e.g., reg1 + reg2 is reg1 + 1 * reg2. /// #2 enforces that 1 * reg is reg. + /// #3 ensures invariant regs with respect to current loop can be combined + /// together in LSR codegen. /// This invariant can be temporarily broken while building a formula. /// However, every formula inserted into the LSRInstance must be in canonical /// form.
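Invariant #3 amounts to one extra swap during canonicalization (done in the next hunk). A standalone sketch, under the assumption that each register is tagged with whether it is an addrec of the current loop (types hypothetical):

#include <algorithm>
#include <vector>
struct Reg { bool IsCurLoopAddRec; };
// Move the register that varies with the current loop into the scaled
// slot so the loop-invariant registers stay together in BaseRegs.
void putAddRecInScaledSlot(Reg *&ScaledReg, std::vector<Reg *> &BaseRegs) {
  if (!ScaledReg || ScaledReg->IsCurLoopAddRec)
    return; // nothing to do, or already canonical
  auto I = std::find_if(BaseRegs.begin(), BaseRegs.end(),
                        [](Reg *R) { return R->IsCurLoopAddRec; });
  if (I != BaseRegs.end())
    std::swap(ScaledReg, *I);
}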
@@ -318,12 +341,14 @@ struct Formula { void initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE); - bool isCanonical() const; + bool isCanonical(const Loop &L) const; - void canonicalize(); + void canonicalize(const Loop &L); bool unscale(); + bool hasZeroEnd() const; + size_t getNumRegs() const; Type *getType() const; @@ -410,16 +435,35 @@ void Formula::initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE) { BaseRegs.push_back(Sum); HasBaseReg = true; } - canonicalize(); + canonicalize(*L); } /// \brief Check whether or not this formula satisfies the canonical /// representation. /// \see Formula::BaseRegs. -bool Formula::isCanonical() const { - if (ScaledReg) - return Scale != 1 || !BaseRegs.empty(); - return BaseRegs.size() <= 1; +bool Formula::isCanonical(const Loop &L) const { + if (!ScaledReg) + return BaseRegs.size() <= 1; + + if (Scale != 1) + return true; + + if (Scale == 1 && BaseRegs.empty()) + return false; + + const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg); + if (SAR && SAR->getLoop() == &L) + return true; + + // If ScaledReg is not a recurrent expr, or it is but its loop is not the + // current loop, while BaseRegs contains a recurrent expr reg related to the + // current loop, we want to swap the reg in BaseRegs with ScaledReg. + auto I = + find_if(make_range(BaseRegs.begin(), BaseRegs.end()), [&](const SCEV *S) { + return isa<const SCEVAddRecExpr>(S) && + (cast<SCEVAddRecExpr>(S)->getLoop() == &L); + }); + return I == BaseRegs.end(); } /// \brief Helper method to morph a formula into its canonical representation. @@ -428,21 +472,33 @@ bool Formula::isCanonical() const { /// field. Otherwise, we would have to do special cases everywhere in LSR /// to treat reg1 + reg2 + ... the same way as reg1 + 1*reg2 + ... /// On the other hand, 1*reg should be canonicalized into reg. -void Formula::canonicalize() { - if (isCanonical()) +void Formula::canonicalize(const Loop &L) { + if (isCanonical(L)) return; // So far we did not need this case. This is easy to implement but it is // useless to maintain dead code. Besides, it could hurt compile time. assert(!BaseRegs.empty() && "1*reg => reg, should not be needed."); + // Keep the invariant sum in BaseRegs and one of the variant sums in ScaledReg. - ScaledReg = BaseRegs.back(); - BaseRegs.pop_back(); - Scale = 1; - size_t BaseRegsSize = BaseRegs.size(); - size_t Try = 0; - // If ScaledReg is an invariant, try to find a variant expression. - while (Try < BaseRegsSize && !isa<SCEVAddRecExpr>(ScaledReg)) - std::swap(ScaledReg, BaseRegs[Try++]); + if (!ScaledReg) { + ScaledReg = BaseRegs.back(); + BaseRegs.pop_back(); + Scale = 1; + } + + // If ScaledReg is an invariant with respect to L, find the reg from + // BaseRegs containing the recurrent expr related to Loop L. Swap the + // reg with ScaledReg. + const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg); + if (!SAR || SAR->getLoop() != &L) { + auto I = find_if(make_range(BaseRegs.begin(), BaseRegs.end()), + [&](const SCEV *S) { + return isa<const SCEVAddRecExpr>(S) && + (cast<SCEVAddRecExpr>(S)->getLoop() == &L); + }); + if (I != BaseRegs.end()) + std::swap(ScaledReg, *I); + } } /// \brief Get rid of the scale in the formula. @@ -458,6 +514,14 @@ bool Formula::unscale() { return true; } +bool Formula::hasZeroEnd() const { + if (UnfoldedOffset || BaseOffset) + return false; + if (BaseRegs.size() != 1 || ScaledReg) + return false; + return true; +} + /// Return the total number of register operands used by this formula.
This does /// not include register uses implied by non-constant addrec strides. size_t Formula::getNumRegs() const { @@ -534,10 +598,11 @@ void Formula::print(raw_ostream &OS) const { } } -LLVM_DUMP_METHOD -void Formula::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void Formula::dump() const { print(errs()); errs() << '\n'; } +#endif /// Return true if the given addrec can be sign-extended without changing its /// value. @@ -711,7 +776,7 @@ static GlobalValue *ExtractSymbol(const SCEV *&S, ScalarEvolution &SE) { static bool isAddressUse(Instruction *Inst, Value *OperandVal) { bool isAddress = isa<LoadInst>(Inst); if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - if (SI->getOperand(1) == OperandVal) + if (SI->getPointerOperand() == OperandVal) isAddress = true; } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { // Addressing modes can also be folded into prefetches and a variety @@ -723,6 +788,12 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) { isAddress = true; break; } + } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) { + if (RMW->getPointerOperand() == OperandVal) + isAddress = true; + } else if (AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { + if (CmpX->getPointerOperand() == OperandVal) + isAddress = true; } return isAddress; } @@ -735,6 +806,10 @@ static MemAccessTy getAccessType(const Instruction *Inst) { AccessTy.AddrSpace = SI->getPointerAddressSpace(); } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) { AccessTy.AddrSpace = LI->getPointerAddressSpace(); + } else if (const AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) { + AccessTy.AddrSpace = RMW->getPointerAddressSpace(); + } else if (const AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { + AccessTy.AddrSpace = CmpX->getPointerAddressSpace(); } // All pointers have the same requirements, so canonicalize them to an @@ -832,7 +907,7 @@ static bool isHighCostExpansion(const SCEV *S, /// If any of the instructions is the specified set are trivially dead, delete /// them and see if this makes any of their operands subsequently dead. static bool -DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakVH> &DeadInsts) { +DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakTrackingVH> &DeadInsts) { bool Changed = false; while (!DeadInsts.empty()) { @@ -875,44 +950,44 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, const LSRUse &LU, const Formula &F); // Get the cost of the scaling factor used in F for LU. static unsigned getScalingFactorCost(const TargetTransformInfo &TTI, - const LSRUse &LU, const Formula &F); + const LSRUse &LU, const Formula &F, + const Loop &L); namespace { /// This class is used to measure and compare candidate formulae. class Cost { - /// TODO: Some of these could be merged. Also, a lexical ordering - /// isn't always optimal. 
- unsigned NumRegs; - unsigned AddRecCost; - unsigned NumIVMuls; - unsigned NumBaseAdds; - unsigned ImmCost; - unsigned SetupCost; - unsigned ScaleCost; + TargetTransformInfo::LSRCost C; public: - Cost() - : NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0), - SetupCost(0), ScaleCost(0) {} + Cost() { + C.Insns = 0; + C.NumRegs = 0; + C.AddRecCost = 0; + C.NumIVMuls = 0; + C.NumBaseAdds = 0; + C.ImmCost = 0; + C.SetupCost = 0; + C.ScaleCost = 0; + } - bool operator<(const Cost &Other) const; + bool isLess(Cost &Other, const TargetTransformInfo &TTI); void Lose(); #ifndef NDEBUG // Once any of the metrics loses, they must all remain losers. bool isValid() { - return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds - | ImmCost | SetupCost | ScaleCost) != ~0u) - || ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds - & ImmCost & SetupCost & ScaleCost) == ~0u); + return ((C.Insns | C.NumRegs | C.AddRecCost | C.NumIVMuls | C.NumBaseAdds + | C.ImmCost | C.SetupCost | C.ScaleCost) != ~0u) + || ((C.Insns & C.NumRegs & C.AddRecCost & C.NumIVMuls & C.NumBaseAdds + & C.ImmCost & C.SetupCost & C.ScaleCost) == ~0u); } #endif bool isLoser() { assert(isValid() && "invalid cost"); - return NumRegs == ~0u; + return C.NumRegs == ~0u; } void RateFormula(const TargetTransformInfo &TTI, @@ -1067,7 +1142,8 @@ public: } bool HasFormulaWithSameRegs(const Formula &F) const; - bool InsertFormula(const Formula &F); + float getNotSelectedProbability(const SCEV *Reg) const; + bool InsertFormula(const Formula &F, const Loop &L); void DeleteFormula(Formula &F); void RecomputeRegs(size_t LUIdx, RegUseTracker &Reguses); @@ -1083,20 +1159,26 @@ void Cost::RateRegister(const SCEV *Reg, const Loop *L, ScalarEvolution &SE, DominatorTree &DT) { if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) { - // If this is an addrec for another loop, don't second-guess its addrec phi - // nodes. LSR isn't currently smart enough to reason about more than one - // loop at a time. LSR has already run on inner loops, will not run on outer - // loops, and cannot be expected to change sibling loops. + // If this is an addrec for another loop, it should be an invariant + // with respect to L since L is the innermost loop (at least + // for now LSR only handles innermost loops). if (AR->getLoop() != L) { // If the AddRec exists, consider its register free and leave it alone. if (isExistingPhi(AR, SE)) return; - // Otherwise, do not consider this formula at all. - Lose(); + // It is bad to allow LSR for the current loop to add induction variables + // for its sibling loops. + if (!AR->getLoop()->contains(L)) { + Lose(); + return; + } + + // Otherwise, it will be an invariant with respect to Loop L. + ++C.NumRegs; return; } - AddRecCost += 1; /// TODO: This should be a function of the stride. + C.AddRecCost += 1; /// TODO: This should be a function of the stride. // Add the step value register, if it needs one. // TODO: The non-affine case isn't precisely modeled here. @@ -1108,7 +1190,7 @@ void Cost::RateRegister(const SCEV *Reg, } } } - ++NumRegs; + ++C.NumRegs; // Rough heuristic; favor registers which don't require extra setup // instructions in the preheader.
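Concretely (a hypothetical source-level illustration of that heuristic): a base register such as n * 4 below costs one preheader instruction to materialize, while a plain argument or an addrec whose start is a constant does not, so the former bumps SetupCost:

void addStep(long *p, long n) {
  long step = n * 4; // materialized once in the loop preheader: SetupCost
  for (long i = 0; i < n; ++i)
    p[i] += step;    // &p[i] is an addrec of this loop: no setup needed
}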
@@ -1117,9 +1199,9 @@ void Cost::RateRegister(const SCEV *Reg, !(isa<SCEVAddRecExpr>(Reg) && (isa<SCEVUnknown>(cast<SCEVAddRecExpr>(Reg)->getStart()) || isa<SCEVConstant>(cast<SCEVAddRecExpr>(Reg)->getStart())))) - ++SetupCost; + ++C.SetupCost; - NumIVMuls += isa<SCEVMulExpr>(Reg) && + C.NumIVMuls += isa<SCEVMulExpr>(Reg) && SE.hasComputableLoopEvolution(Reg, L); } @@ -1150,8 +1232,11 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, ScalarEvolution &SE, DominatorTree &DT, const LSRUse &LU, SmallPtrSetImpl<const SCEV *> *LoserRegs) { - assert(F.isCanonical() && "Cost is accurate only for canonical formula"); + assert(F.isCanonical(*L) && "Cost is accurate only for canonical formula"); // Tally up the registers. + unsigned PrevAddRecCost = C.AddRecCost; + unsigned PrevNumRegs = C.NumRegs; + unsigned PrevNumBaseAdds = C.NumBaseAdds; if (const SCEV *ScaledReg = F.ScaledReg) { if (VisitedRegs.count(ScaledReg)) { Lose(); @@ -1176,73 +1261,113 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, if (NumBaseParts > 1) // Do not count the base and a possible second register if the target // allows to fold 2 registers. - NumBaseAdds += + C.NumBaseAdds += NumBaseParts - (1 + (F.Scale && isAMCompletelyFolded(TTI, LU, F))); - NumBaseAdds += (F.UnfoldedOffset != 0); + C.NumBaseAdds += (F.UnfoldedOffset != 0); // Accumulate non-free scaling amounts. - ScaleCost += getScalingFactorCost(TTI, LU, F); + C.ScaleCost += getScalingFactorCost(TTI, LU, F, *L); // Tally up the non-zero immediates. for (const LSRFixup &Fixup : LU.Fixups) { int64_t O = Fixup.Offset; int64_t Offset = (uint64_t)O + F.BaseOffset; if (F.BaseGV) - ImmCost += 64; // Handle symbolic values conservatively. + C.ImmCost += 64; // Handle symbolic values conservatively. // TODO: This should probably be the pointer size. else if (Offset != 0) - ImmCost += APInt(64, Offset, true).getMinSignedBits(); + C.ImmCost += APInt(64, Offset, true).getMinSignedBits(); // Check with target if this offset with this instruction is // specifically not supported. if ((isa<LoadInst>(Fixup.UserInst) || isa<StoreInst>(Fixup.UserInst)) && !TTI.isFoldableMemAccessOffset(Fixup.UserInst, Offset)) - NumBaseAdds++; + C.NumBaseAdds++; + } + + // If we don't count instruction cost exit here. + if (!InsnsCost) { + assert(isValid() && "invalid cost"); + return; + } + + // Treat every new register that exceeds TTI.getNumberOfRegisters() - 1 as + // additional instruction (at least fill). + unsigned TTIRegNum = TTI.getNumberOfRegisters(false) - 1; + if (C.NumRegs > TTIRegNum) { + // Cost already exceeded TTIRegNum, then only newly added register can add + // new instructions. + if (PrevNumRegs > TTIRegNum) + C.Insns += (C.NumRegs - PrevNumRegs); + else + C.Insns += (C.NumRegs - TTIRegNum); } + + // If ICmpZero formula ends with not 0, it could not be replaced by + // just add or sub. We'll need to compare final result of AddRec. + // That means we'll need an additional instruction. + // For -10 + {0, +, 1}: + // i = i + 1; + // cmp i, 10 + // + // For {-10, +, 1}: + // i = i + 1; + if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd()) + C.Insns++; + // Each new AddRec adds 1 instruction to calculation. + C.Insns += (C.AddRecCost - PrevAddRecCost); + + // BaseAdds adds instructions for unfolded registers. + if (LU.Kind != LSRUse::ICmpZero) + C.Insns += C.NumBaseAdds - PrevNumBaseAdds; assert(isValid() && "invalid cost"); } /// Set this cost to a losing value. 
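The instruction-count term added to RateFormula above can be restated compactly. A standalone sketch of just that arithmetic, where TTIRegNum stands for TTI.getNumberOfRegisters(false) - 1 (names hypothetical):

#include <algorithm>
// Registers past what the target offers are assumed spilled, so each one
// over the limit costs at least a fill; only registers newly added by
// this formula (beyond PrevNumRegs) can add new instructions.
unsigned extraSpillInsns(unsigned PrevNumRegs, unsigned NumRegs,
                         unsigned TTIRegNum) {
  if (NumRegs <= TTIRegNum)
    return 0;
  return NumRegs - std::max(PrevNumRegs, TTIRegNum);
}
// e.g. with TTIRegNum = 7: growing from 6 to 9 registers costs 2 extra
// instructions, growing from 8 to 9 costs 1.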
void Cost::Lose() { - NumRegs = ~0u; - AddRecCost = ~0u; - NumIVMuls = ~0u; - NumBaseAdds = ~0u; - ImmCost = ~0u; - SetupCost = ~0u; - ScaleCost = ~0u; + C.Insns = ~0u; + C.NumRegs = ~0u; + C.AddRecCost = ~0u; + C.NumIVMuls = ~0u; + C.NumBaseAdds = ~0u; + C.ImmCost = ~0u; + C.SetupCost = ~0u; + C.ScaleCost = ~0u; } /// Choose the lower cost. -bool Cost::operator<(const Cost &Other) const { - return std::tie(NumRegs, AddRecCost, NumIVMuls, NumBaseAdds, ScaleCost, - ImmCost, SetupCost) < - std::tie(Other.NumRegs, Other.AddRecCost, Other.NumIVMuls, - Other.NumBaseAdds, Other.ScaleCost, Other.ImmCost, - Other.SetupCost); +bool Cost::isLess(Cost &Other, const TargetTransformInfo &TTI) { + if (InsnsCost.getNumOccurrences() > 0 && InsnsCost && + C.Insns != Other.C.Insns) + return C.Insns < Other.C.Insns; + return TTI.isLSRCostLess(C, Other.C); } void Cost::print(raw_ostream &OS) const { - OS << NumRegs << " reg" << (NumRegs == 1 ? "" : "s"); - if (AddRecCost != 0) - OS << ", with addrec cost " << AddRecCost; - if (NumIVMuls != 0) - OS << ", plus " << NumIVMuls << " IV mul" << (NumIVMuls == 1 ? "" : "s"); - if (NumBaseAdds != 0) - OS << ", plus " << NumBaseAdds << " base add" - << (NumBaseAdds == 1 ? "" : "s"); - if (ScaleCost != 0) - OS << ", plus " << ScaleCost << " scale cost"; - if (ImmCost != 0) - OS << ", plus " << ImmCost << " imm cost"; - if (SetupCost != 0) - OS << ", plus " << SetupCost << " setup cost"; + if (InsnsCost) + OS << C.Insns << " instruction" << (C.Insns == 1 ? " " : "s "); + OS << C.NumRegs << " reg" << (C.NumRegs == 1 ? "" : "s"); + if (C.AddRecCost != 0) + OS << ", with addrec cost " << C.AddRecCost; + if (C.NumIVMuls != 0) + OS << ", plus " << C.NumIVMuls << " IV mul" + << (C.NumIVMuls == 1 ? "" : "s"); + if (C.NumBaseAdds != 0) + OS << ", plus " << C.NumBaseAdds << " base add" + << (C.NumBaseAdds == 1 ? "" : "s"); + if (C.ScaleCost != 0) + OS << ", plus " << C.ScaleCost << " scale cost"; + if (C.ImmCost != 0) + OS << ", plus " << C.ImmCost << " imm cost"; + if (C.SetupCost != 0) + OS << ", plus " << C.SetupCost << " setup cost"; } -LLVM_DUMP_METHOD -void Cost::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void Cost::dump() const { print(errs()); errs() << '\n'; } +#endif LSRFixup::LSRFixup() : UserInst(nullptr), OperandValToReplace(nullptr), @@ -1285,10 +1410,11 @@ void LSRFixup::print(raw_ostream &OS) const { OS << ", Offset=" << Offset; } -LLVM_DUMP_METHOD -void LSRFixup::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LSRFixup::dump() const { print(errs()); errs() << '\n'; } +#endif /// Test whether this use as a formula which has the same registers as the given /// formula. @@ -1300,10 +1426,19 @@ bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const { return Uniquifier.count(Key); } +/// The function returns a probability of selecting formula without Reg. +float LSRUse::getNotSelectedProbability(const SCEV *Reg) const { + unsigned FNum = 0; + for (const Formula &F : Formulae) + if (F.referencesReg(Reg)) + FNum++; + return ((float)(Formulae.size() - FNum)) / Formulae.size(); +} + /// If the given formula has not yet been inserted, add it to the list, and /// return true. Return false otherwise. The formula must be in canonical form. 
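// [Editorial sketch, not part of the patch] Cost::isLess above compares raw
// instruction counts only when -lsr-insns-cost is passed explicitly and
// otherwise defers to TTI::isLSRCostLess. To the best of my knowledge the
// target-independent default of that hook is the same lexicographic tuple
// compare the deleted operator< performed, roughly:
#include <tuple>

struct LSRCostSketch { // hypothetical mirror of TargetTransformInfo::LSRCost
  unsigned Insns, NumRegs, AddRecCost, NumIVMuls, NumBaseAdds, ImmCost,
      SetupCost, ScaleCost;
};

static bool defaultLSRCostLess(const LSRCostSketch &C1,
                               const LSRCostSketch &C2) {
  return std::tie(C1.NumRegs, C1.AddRecCost, C1.NumIVMuls, C1.NumBaseAdds,
                  C1.ScaleCost, C1.ImmCost, C1.SetupCost) <
         std::tie(C2.NumRegs, C2.AddRecCost, C2.NumIVMuls, C2.NumBaseAdds,
                  C2.ScaleCost, C2.ImmCost, C2.SetupCost);
}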
-bool LSRUse::InsertFormula(const Formula &F) { - assert(F.isCanonical() && "Invalid canonical representation"); +bool LSRUse::InsertFormula(const Formula &F, const Loop &L) { + assert(F.isCanonical(L) && "Invalid canonical representation"); if (!Formulae.empty() && RigidFormula) return false; @@ -1391,10 +1526,11 @@ void LSRUse::print(raw_ostream &OS) const { OS << ", widest fixup type: " << *WidestFixupType; } -LLVM_DUMP_METHOD -void LSRUse::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LSRUse::dump() const { print(errs()); errs() << '\n'; } +#endif static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, @@ -1472,7 +1608,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, int64_t MinOffset, int64_t MaxOffset, LSRUse::KindType Kind, MemAccessTy AccessTy, - const Formula &F) { + const Formula &F, const Loop &L) { // For the purpose of isAMCompletelyFolded either having a canonical formula // or a scale not equal to zero is correct. // Problems may arise from non canonical formulae having a scale == 0. @@ -1480,7 +1616,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, // However, when we generate the scaled formulae, we first check that the // scaling factor is profitable before computing the actual ScaledReg for // compile time sake. - assert((F.isCanonical() || F.Scale != 0)); + assert((F.isCanonical(L) || F.Scale != 0)); return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy, F.BaseGV, F.BaseOffset, F.HasBaseReg, F.Scale); } @@ -1515,14 +1651,15 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, } static unsigned getScalingFactorCost(const TargetTransformInfo &TTI, - const LSRUse &LU, const Formula &F) { + const LSRUse &LU, const Formula &F, + const Loop &L) { if (!F.Scale) return 0; // If the use is not completely folded in that instruction, we will have to // pay an extra cost only for scale != 1. 
if (!isAMCompletelyFolded(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, - LU.AccessTy, F)) + LU.AccessTy, F, L)) return F.Scale != 1; switch (LU.Kind) { @@ -1718,7 +1855,7 @@ class LSRInstance { void FinalizeChain(IVChain &Chain); void CollectChains(); void GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts); + SmallVectorImpl<WeakTrackingVH> &DeadInsts); void CollectInterestingTypesAndFactors(); void CollectFixupsAndInitialFormulae(); @@ -1772,6 +1909,8 @@ class LSRInstance { void NarrowSearchSpaceByDetectingSupersets(); void NarrowSearchSpaceByCollapsingUnrolledCode(); void NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); + void NarrowSearchSpaceByFilterFormulaWithSameScaledReg(); + void NarrowSearchSpaceByDeletingCostlyFormulas(); void NarrowSearchSpaceByPickingWinnerRegs(); void NarrowSearchSpaceUsingHeuristics(); @@ -1792,19 +1931,15 @@ class LSRInstance { const LSRUse &LU, SCEVExpander &Rewriter) const; - Value *Expand(const LSRUse &LU, const LSRFixup &LF, - const Formula &F, - BasicBlock::iterator IP, - SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts) const; + Value *Expand(const LSRUse &LU, const LSRFixup &LF, const Formula &F, + BasicBlock::iterator IP, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const; void RewriteForPHI(PHINode *PN, const LSRUse &LU, const LSRFixup &LF, - const Formula &F, - SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts) const; - void Rewrite(const LSRUse &LU, const LSRFixup &LF, - const Formula &F, + const Formula &F, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const; + void Rewrite(const LSRUse &LU, const LSRFixup &LF, const Formula &F, SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts) const; + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const; void ImplementSolution(const SmallVectorImpl<const Formula *> &Solution); public: @@ -2191,7 +2326,7 @@ LSRInstance::OptimizeLoopTermCond() { dyn_cast_or_null<SCEVConstant>(getExactSDiv(B, A, SE))) { const ConstantInt *C = D->getValue(); // Stride of one or negative one can have reuse with non-addresses. - if (C->isOne() || C->isAllOnesValue()) + if (C->isOne() || C->isMinusOne()) goto decline_post_inc; // Avoid weird situations. if (C->getValue().getMinSignedBits() >= 64 || @@ -2492,7 +2627,12 @@ static Value *getWideOperand(Value *Oper) { static bool isCompatibleIVType(Value *LVal, Value *RVal) { Type *LType = LVal->getType(); Type *RType = RVal->getType(); - return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy()); + return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy() && + // Different address spaces means (possibly) + // different types of the pointer implementation, + // e.g. i16 vs i32 so disallow that. + (LType->getPointerAddressSpace() == + RType->getPointerAddressSpace())); } /// Return an approximation of this SCEV expression's "base", or NULL for any @@ -2881,7 +3021,7 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, /// Generate an add or subtract for each IVInc in a chain to materialize the IV /// user's operand from the previous IV user's operand. void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts) { + SmallVectorImpl<WeakTrackingVH> &DeadInsts) { // Find the new IVOperand for the head of the chain. It may have been replaced // by LSR. 
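// [Editorial sketch, not part of the patch] The WeakVH to WeakTrackingVH
// renames threaded through the signatures above keep the old semantics under
// a more explicit name: as I understand the split, a WeakTrackingVH follows
// replaceAllUsesWith and reads as null once its value is deleted. Minimal
// usage pattern, assuming LLVM's own headers:
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/ValueHandle.h"

static void recordReplaced(llvm::Instruction *Old, llvm::Value *New,
                           llvm::SmallVectorImpl<llvm::WeakTrackingVH> &DeadInsts) {
  DeadInsts.push_back(Old);     // the handle starts out pointing at Old
  Old->replaceAllUsesWith(New); // RAUW is tracked: the handle now yields New
  Old->eraseFromParent();       // deletion nulls handles still on Old, so the
                                // vector never holds a dangling pointer
}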
const IVInc &Head = Chain.Incs[0]; @@ -2989,8 +3129,10 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { User::op_iterator UseI = find(UserInst->operands(), U.getOperandValToReplace()); assert(UseI != UserInst->op_end() && "cannot find IV operand"); - if (IVIncSet.count(UseI)) + if (IVIncSet.count(UseI)) { + DEBUG(dbgs() << "Use is in profitable chain: " << **UseI << '\n'); continue; + } LSRUse::KindType Kind = LSRUse::Basic; MemAccessTy AccessTy; @@ -3025,8 +3167,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { if (SE.isLoopInvariant(N, L) && isSafeToExpand(N, SE)) { // S is normalized, so normalize N before folding it into S // to keep the result normalized. - N = TransformForPostIncUse(Normalize, N, CI, nullptr, - TmpPostIncLoops, SE, DT); + N = normalizeForPostIncUse(N, TmpPostIncLoops, SE); Kind = LSRUse::ICmpZero; S = SE.getMinusSCEV(N, S); } @@ -3108,7 +3249,8 @@ bool LSRInstance::InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F) { // Do not insert formula that we will not be able to expand. assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F) && "Formula is illegal"); - if (!LU.InsertFormula(F)) + + if (!LU.InsertFormula(F, *L)) return false; CountRegisters(F, LUIdx); @@ -3347,7 +3489,7 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, F.BaseRegs.push_back(*J); // We may have changed the number of register in base regs, adjust the // formula accordingly. - F.canonicalize(); + F.canonicalize(*L); if (InsertFormula(LU, LUIdx, F)) // If that formula hadn't been seen before, recurse to find more like @@ -3359,7 +3501,7 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, /// Split out subexpressions from adds and the bases of addrecs. void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx, Formula Base, unsigned Depth) { - assert(Base.isCanonical() && "Input must be in the canonical form"); + assert(Base.isCanonical(*L) && "Input must be in the canonical form"); // Arbitrarily cap recursion to protect compile time. if (Depth >= 3) return; @@ -3400,7 +3542,7 @@ void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx, // rather than proceed with zero in a register. if (!Sum->isZero()) { F.BaseRegs.push_back(Sum); - F.canonicalize(); + F.canonicalize(*L); (void)InsertFormula(LU, LUIdx, F); } } @@ -3457,7 +3599,7 @@ void LSRInstance::GenerateConstantOffsetsImpl( F.ScaledReg = nullptr; } else F.deleteBaseReg(F.BaseRegs[Idx]); - F.canonicalize(); + F.canonicalize(*L); } else if (IsScaledReg) F.ScaledReg = NewG; else @@ -3620,10 +3762,10 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) { if (LU.Kind == LSRUse::ICmpZero && !Base.HasBaseReg && Base.BaseOffset == 0 && !Base.BaseGV) continue; - // For each addrec base reg, apply the scale, if possible. - for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) - if (const SCEVAddRecExpr *AR = - dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i])) { + // For each addrec base reg, if its loop is current loop, apply the scale. + for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) { + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i]); + if (AR && (AR->getLoop() == L || LU.AllFixupsOutsideLoop)) { const SCEV *FactorS = SE.getConstant(IntTy, Factor); if (FactorS->isZero()) continue; @@ -3637,11 +3779,17 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) { // The canonical representation of 1*reg is reg, which is already in // Base. 
In that case, do not try to insert the formula, it will be // rejected anyway. - if (F.Scale == 1 && F.BaseRegs.empty()) + if (F.Scale == 1 && (F.BaseRegs.empty() || + (AR->getLoop() != L && LU.AllFixupsOutsideLoop))) continue; + // If AllFixupsOutsideLoop is true and F.Scale is 1, we may generate + // non canonical Formula with ScaledReg's loop not being L. + if (F.Scale == 1 && LU.AllFixupsOutsideLoop) + F.canonicalize(*L); (void)InsertFormula(LU, LUIdx, F); } } + } } } @@ -3668,6 +3816,7 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { if (!F.hasRegsUsedByUsesOtherThan(LUIdx, RegUses)) continue; + F.canonicalize(*L); (void)InsertFormula(LU, LUIdx, F); } } @@ -3697,10 +3846,11 @@ void WorkItem::print(raw_ostream &OS) const { << " , add offset " << Imm; } -LLVM_DUMP_METHOD -void WorkItem::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void WorkItem::dump() const { print(errs()); errs() << '\n'; } +#endif /// Look for registers which are a constant distance apart and try to form reuse /// opportunities between them. @@ -3764,8 +3914,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { // Compute the difference between the two. int64_t Imm = (uint64_t)JImm - M->first; - for (int LUIdx = UsedByIndices.find_first(); LUIdx != -1; - LUIdx = UsedByIndices.find_next(LUIdx)) + for (unsigned LUIdx : UsedByIndices.set_bits()) // Make a memo of this use, offset, and register tuple. if (UniqueItems.insert(std::make_pair(LUIdx, Imm)).second) WorkItems.push_back(WorkItem(LUIdx, Imm, OrigReg)); @@ -3821,7 +3970,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { continue; // OK, looks good. - NewF.canonicalize(); + NewF.canonicalize(*this->L); (void)InsertFormula(LU, LUIdx, NewF); } else { // Use the immediate in a base register. @@ -3853,7 +4002,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { goto skip_formula; // Ok, looks good. - NewF.canonicalize(); + NewF.canonicalize(*this->L); (void)InsertFormula(LU, LUIdx, NewF); break; skip_formula:; @@ -3967,7 +4116,7 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { Cost CostBest; Regs.clear(); CostBest.RateFormula(TTI, Best, Regs, VisitedRegs, L, SE, DT, LU); - if (CostF < CostBest) + if (CostF.isLess(CostBest, TTI)) std::swap(F, Best); DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); dbgs() << "\n" @@ -4165,6 +4314,242 @@ void LSRInstance::NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(){ } } +/// If a LSRUse has multiple formulae with the same ScaledReg and Scale. +/// Pick the best one and delete the others. +/// This narrowing heuristic is to keep as many formulae with different +/// Scale and ScaledReg pair as possible while narrowing the search space. +/// The benefit is that it is more likely to find out a better solution +/// from a formulae set with more Scale and ScaledReg variations than +/// a formulae set with the same Scale and ScaledReg. The picking winner +/// reg heurstic will often keep the formulae with the same Scale and +/// ScaledReg and filter others, and we want to avoid that if possible. +void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { + if (EstimateSearchSpaceComplexity() < ComplexityLimit) + return; + + DEBUG(dbgs() << "The search space is too complex.\n" + "Narrowing the search space by choosing the best Formula " + "from the Formulae with the same Scale and ScaledReg.\n"); + + // Map the "Scale * ScaledReg" pair to the best formula of current LSRUse. 
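// [Editorial sketch, not part of the patch] The filtering implemented just
// below keys each formula on the pair (ScaledReg, Scale) and keeps one best
// representative per key; the real tie-break first counts how widely each
// base register is shared across uses and only then compares rated costs,
// which this sketch collapses into a single Cost scalar. Hypothetical types:
#include <cstddef>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

struct FormulaSketch {
  const void *ScaledReg; // null means the formula is unscaled
  std::int64_t Scale;
  unsigned Cost;
};

static std::vector<FormulaSketch>
keepBestPerScaledReg(const std::vector<FormulaSketch> &Formulae) {
  std::map<std::pair<const void *, std::int64_t>, std::size_t> BestIdx;
  std::vector<FormulaSketch> Out;
  for (const FormulaSketch &F : Formulae) {
    if (!F.ScaledReg) { // unscaled formulae are never filtered by this pass
      Out.push_back(F);
      continue;
    }
    auto Key = std::make_pair(F.ScaledReg, F.Scale);
    auto It = BestIdx.find(Key);
    if (It == BestIdx.end()) {
      BestIdx.emplace(Key, Out.size());
      Out.push_back(F);
    } else if (F.Cost < Out[It->second].Cost) {
      Out[It->second] = F; // better representative for this (reg, scale) key
    }
  }
  return Out;
}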
+ typedef DenseMap<std::pair<const SCEV *, int64_t>, size_t> BestFormulaeTy; + BestFormulaeTy BestFormulae; +#ifndef NDEBUG + bool ChangedFormulae = false; +#endif + DenseSet<const SCEV *> VisitedRegs; + SmallPtrSet<const SCEV *, 16> Regs; + + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + DEBUG(dbgs() << "Filtering for use "; LU.print(dbgs()); dbgs() << '\n'); + + // Return true if Formula FA is better than Formula FB. + auto IsBetterThan = [&](Formula &FA, Formula &FB) { + // First we will try to choose the Formula with fewer new registers. + // For a register used by current Formula, the more the register is + // shared among LSRUses, the less we increase the register number + // counter of the formula. + size_t FARegNum = 0; + for (const SCEV *Reg : FA.BaseRegs) { + const SmallBitVector &UsedByIndices = RegUses.getUsedByIndices(Reg); + FARegNum += (NumUses - UsedByIndices.count() + 1); + } + size_t FBRegNum = 0; + for (const SCEV *Reg : FB.BaseRegs) { + const SmallBitVector &UsedByIndices = RegUses.getUsedByIndices(Reg); + FBRegNum += (NumUses - UsedByIndices.count() + 1); + } + if (FARegNum != FBRegNum) + return FARegNum < FBRegNum; + + // If the new register numbers are the same, choose the Formula with + // less Cost. + Cost CostFA, CostFB; + Regs.clear(); + CostFA.RateFormula(TTI, FA, Regs, VisitedRegs, L, SE, DT, LU); + Regs.clear(); + CostFB.RateFormula(TTI, FB, Regs, VisitedRegs, L, SE, DT, LU); + return CostFA.isLess(CostFB, TTI); + }; + + bool Any = false; + for (size_t FIdx = 0, NumForms = LU.Formulae.size(); FIdx != NumForms; + ++FIdx) { + Formula &F = LU.Formulae[FIdx]; + if (!F.ScaledReg) + continue; + auto P = BestFormulae.insert({{F.ScaledReg, F.Scale}, FIdx}); + if (P.second) + continue; + + Formula &Best = LU.Formulae[P.first->second]; + if (IsBetterThan(F, Best)) + std::swap(F, Best); + DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); + dbgs() << "\n" + " in favor of formula "; + Best.print(dbgs()); dbgs() << '\n'); +#ifndef NDEBUG + ChangedFormulae = true; +#endif + LU.DeleteFormula(F); + --FIdx; + --NumForms; + Any = true; + } + if (Any) + LU.RecomputeRegs(LUIdx, RegUses); + + // Reset this to prepare for the next use. + BestFormulae.clear(); + } + + DEBUG(if (ChangedFormulae) { + dbgs() << "\n" + "After filtering out undesirable candidates:\n"; + print_uses(dbgs()); + }); +} + +/// The function delete formulas with high registers number expectation. +/// Assuming we don't know the value of each formula (already delete +/// all inefficient), generate probability of not selecting for each +/// register. +/// For example, +/// Use1: +/// reg(a) + reg({0,+,1}) +/// reg(a) + reg({-1,+,1}) + 1 +/// reg({a,+,1}) +/// Use2: +/// reg(b) + reg({0,+,1}) +/// reg(b) + reg({-1,+,1}) + 1 +/// reg({b,+,1}) +/// Use3: +/// reg(c) + reg(b) + reg({0,+,1}) +/// reg(c) + reg({b,+,1}) +/// +/// Probability of not selecting +/// Use1 Use2 Use3 +/// reg(a) (1/3) * 1 * 1 +/// reg(b) 1 * (1/3) * (1/2) +/// reg({0,+,1}) (2/3) * (2/3) * (1/2) +/// reg({-1,+,1}) (2/3) * (2/3) * 1 +/// reg({a,+,1}) (2/3) * 1 * 1 +/// reg({b,+,1}) 1 * (2/3) * (2/3) +/// reg(c) 1 * 1 * 0 +/// +/// Now count registers number mathematical expectation for each formula: +/// Note that for each use we exclude probability if not selecting for the use. +/// For example for Use1 probability for reg(a) would be just 1 * 1 (excluding +/// probabilty 1/3 of not selecting for Use1). 
+/// Use1: +/// reg(a) + reg({0,+,1}) 1 + 1/3 -- to be deleted +/// reg(a) + reg({-1,+,1}) + 1 1 + 4/9 -- to be deleted +/// reg({a,+,1}) 1 +/// Use2: +/// reg(b) + reg({0,+,1}) 1/2 + 1/3 -- to be deleted +/// reg(b) + reg({-1,+,1}) + 1 1/2 + 2/3 -- to be deleted +/// reg({b,+,1}) 2/3 +/// Use3: +/// reg(c) + reg(b) + reg({0,+,1}) 1 + 1/3 + 4/9 -- to be deleted +/// reg(c) + reg({b,+,1}) 1 + 2/3 + +void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { + if (EstimateSearchSpaceComplexity() < ComplexityLimit) + return; + // Ok, we have too many of formulae on our hands to conveniently handle. + // Use a rough heuristic to thin out the list. + + // Set of Regs wich will be 100% used in final solution. + // Used in each formula of a solution (in example above this is reg(c)). + // We can skip them in calculations. + SmallPtrSet<const SCEV *, 4> UniqRegs; + DEBUG(dbgs() << "The search space is too complex.\n"); + + // Map each register to probability of not selecting + DenseMap <const SCEV *, float> RegNumMap; + for (const SCEV *Reg : RegUses) { + if (UniqRegs.count(Reg)) + continue; + float PNotSel = 1; + for (const LSRUse &LU : Uses) { + if (!LU.Regs.count(Reg)) + continue; + float P = LU.getNotSelectedProbability(Reg); + if (P != 0.0) + PNotSel *= P; + else + UniqRegs.insert(Reg); + } + RegNumMap.insert(std::make_pair(Reg, PNotSel)); + } + + DEBUG(dbgs() << "Narrowing the search space by deleting costly formulas\n"); + + // Delete formulas where registers number expectation is high. + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + // If nothing to delete - continue. + if (LU.Formulae.size() < 2) + continue; + // This is temporary solution to test performance. Float should be + // replaced with round independent type (based on integers) to avoid + // different results for different target builds. + float FMinRegNum = LU.Formulae[0].getNumRegs(); + float FMinARegNum = LU.Formulae[0].getNumRegs(); + size_t MinIdx = 0; + for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) { + Formula &F = LU.Formulae[i]; + float FRegNum = 0; + float FARegNum = 0; + for (const SCEV *BaseReg : F.BaseRegs) { + if (UniqRegs.count(BaseReg)) + continue; + FRegNum += RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg); + if (isa<SCEVAddRecExpr>(BaseReg)) + FARegNum += + RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg); + } + if (const SCEV *ScaledReg = F.ScaledReg) { + if (!UniqRegs.count(ScaledReg)) { + FRegNum += + RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg); + if (isa<SCEVAddRecExpr>(ScaledReg)) + FARegNum += + RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg); + } + } + if (FMinRegNum > FRegNum || + (FMinRegNum == FRegNum && FMinARegNum > FARegNum)) { + FMinRegNum = FRegNum; + FMinARegNum = FARegNum; + MinIdx = i; + } + } + DEBUG(dbgs() << " The formula "; LU.Formulae[MinIdx].print(dbgs()); + dbgs() << " with min reg num " << FMinRegNum << '\n'); + if (MinIdx != 0) + std::swap(LU.Formulae[MinIdx], LU.Formulae[0]); + while (LU.Formulae.size() != 1) { + DEBUG(dbgs() << " Deleting "; LU.Formulae.back().print(dbgs()); + dbgs() << '\n'); + LU.Formulae.pop_back(); + } + LU.RecomputeRegs(LUIdx, RegUses); + assert(LU.Formulae.size() == 1 && "Should be exactly 1 min regs formula"); + Formula &F = LU.Formulae[0]; + DEBUG(dbgs() << " Leaving only "; F.print(dbgs()); dbgs() << '\n'); + // When we choose the formula, the regs become unique. 
+ UniqRegs.insert(F.BaseRegs.begin(), F.BaseRegs.end()); + if (F.ScaledReg) + UniqRegs.insert(F.ScaledReg); + } + DEBUG(dbgs() << "After pre-selection:\n"; + print_uses(dbgs())); +} + + /// Pick a register which seems likely to be profitable, and then in any use /// which has any reference to that register, delete all formulae which do not /// reference that register. @@ -4237,7 +4622,12 @@ void LSRInstance::NarrowSearchSpaceUsingHeuristics() { NarrowSearchSpaceByDetectingSupersets(); NarrowSearchSpaceByCollapsingUnrolledCode(); NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); - NarrowSearchSpaceByPickingWinnerRegs(); + if (FilterSameScaledReg) + NarrowSearchSpaceByFilterFormulaWithSameScaledReg(); + if (LSRExpNarrow) + NarrowSearchSpaceByDeletingCostlyFormulas(); + else + NarrowSearchSpaceByPickingWinnerRegs(); } /// This is the recursive solver. @@ -4294,7 +4684,7 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, NewCost = CurCost; NewRegs = CurRegs; NewCost.RateFormula(TTI, F, NewRegs, VisitedRegs, L, SE, DT, LU); - if (NewCost < SolutionCost) { + if (NewCost.isLess(SolutionCost, TTI)) { Workspace.push_back(&F); if (Workspace.size() != Uses.size()) { SolveRecurse(Solution, SolutionCost, Workspace, NewCost, @@ -4476,12 +4866,10 @@ LSRInstance::AdjustInsertPositionForExpand(BasicBlock::iterator LowestIP, /// Emit instructions for the leading candidate expression for this LSRUse (this /// is called "expanding"). -Value *LSRInstance::Expand(const LSRUse &LU, - const LSRFixup &LF, - const Formula &F, - BasicBlock::iterator IP, +Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, + const Formula &F, BasicBlock::iterator IP, SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts) const { + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const { if (LU.RigidFormula) return LF.OperandValToReplace; @@ -4515,11 +4903,7 @@ Value *LSRInstance::Expand(const LSRUse &LU, assert(!Reg->isZero() && "Zero allocated in a base register!"); // If we're expanding for a post-inc user, make the post-inc adjustment. - PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops); - Reg = TransformForPostIncUse(Denormalize, Reg, - LF.UserInst, LF.OperandValToReplace, - Loops, SE, DT); - + Reg = denormalizeForPostIncUse(Reg, LF.PostIncLoops, SE); Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, nullptr))); } @@ -4530,9 +4914,7 @@ Value *LSRInstance::Expand(const LSRUse &LU, // If we're expanding for a post-inc user, make the post-inc adjustment. PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops); - ScaledS = TransformForPostIncUse(Denormalize, ScaledS, - LF.UserInst, LF.OperandValToReplace, - Loops, SE, DT); + ScaledS = denormalizeForPostIncUse(ScaledS, Loops, SE); if (LU.Kind == LSRUse::ICmpZero) { // Expand ScaleReg as if it was part of the base regs. @@ -4662,12 +5044,9 @@ Value *LSRInstance::Expand(const LSRUse &LU, /// Helper for Rewrite. PHI nodes are special because the use of their operands /// effectively happens in their predecessor blocks, so the expression may need /// to be expanded in multiple places. 
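// [Editorial sketch, not part of the patch] The probability bookkeeping that
// drives the deletion heuristic above (and the worked tables in its comment),
// reduced to plain containers: each use owns candidate formulae, each formula
// is the set of register ids it references, and a zero factor means some use
// is forced to select the register. Assumes every use has at least one
// formula:
#include <cstddef>
#include <vector>

using FormulaRegs = std::vector<int>;         // register ids one formula uses
using UseFormulae = std::vector<FormulaRegs>; // candidate formulae of one use

static float notSelectedProbability(const UseFormulae &U, int Reg) {
  std::size_t WithReg = 0;
  for (const FormulaRegs &F : U)
    for (int R : F)
      if (R == Reg) {
        ++WithReg;
        break;
      }
  return float(U.size() - WithReg) / float(U.size());
}

static float overallNotSelected(const std::vector<UseFormulae> &Uses, int Reg) {
  float P = 1.0f;
  for (const UseFormulae &U : Uses) {
    float PU = notSelectedProbability(U, Reg);
    if (PU == 0.0f)
      return 0.0f; // every formula of some use needs Reg: guaranteed live
    P *= PU;       // uses that never mention Reg contribute a factor of 1
  }
  return P;
}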
-void LSRInstance::RewriteForPHI(PHINode *PN, - const LSRUse &LU, - const LSRFixup &LF, - const Formula &F, - SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts) const { +void LSRInstance::RewriteForPHI( + PHINode *PN, const LSRUse &LU, const LSRFixup &LF, const Formula &F, + SCEVExpander &Rewriter, SmallVectorImpl<WeakTrackingVH> &DeadInsts) const { DenseMap<BasicBlock *, Value *> Inserted; for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (PN->getIncomingValue(i) == LF.OperandValToReplace) { @@ -4739,11 +5118,9 @@ void LSRInstance::RewriteForPHI(PHINode *PN, /// Emit instructions for the leading candidate expression for this LSRUse (this /// is called "expanding"), and update the UserInst to reference the newly /// expanded value. -void LSRInstance::Rewrite(const LSRUse &LU, - const LSRFixup &LF, - const Formula &F, - SCEVExpander &Rewriter, - SmallVectorImpl<WeakVH> &DeadInsts) const { +void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, + const Formula &F, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const { // First, find an insertion point that dominates UserInst. For PHI nodes, // find the nearest block which dominates all the relevant uses. if (PHINode *PN = dyn_cast<PHINode>(LF.UserInst)) { @@ -4781,7 +5158,7 @@ void LSRInstance::ImplementSolution( const SmallVectorImpl<const Formula *> &Solution) { // Keep track of instructions we may have made dead, so that // we can remove them after we are done working. - SmallVector<WeakVH, 16> DeadInsts; + SmallVector<WeakTrackingVH, 16> DeadInsts; SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr"); @@ -4975,10 +5352,11 @@ void LSRInstance::print(raw_ostream &OS) const { print_uses(OS); } -LLVM_DUMP_METHOD -void LSRInstance::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LSRInstance::dump() const { print(errs()); errs() << '\n'; } +#endif namespace { @@ -5030,7 +5408,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, // Remove any extra phis created by processing inner loops. 
Changed |= DeleteDeadPHIs(L->getHeader()); if (EnablePhiElim && L->isLoopSimplifyForm()) { - SmallVector<WeakVH, 16> DeadInsts; + SmallVector<WeakTrackingVH, 16> DeadInsts; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); SCEVExpander Rewriter(SE, DL, "lsr"); #ifndef NDEBUG diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index c7f9122..530a684 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -44,7 +44,11 @@ using namespace llvm; static cl::opt<unsigned> UnrollThreshold("unroll-threshold", cl::Hidden, - cl::desc("The baseline cost threshold for loop unrolling")); + cl::desc("The cost threshold for loop unrolling")); + +static cl::opt<unsigned> UnrollPartialThreshold( + "unroll-partial-threshold", cl::Hidden, + cl::desc("The cost threshold for partial loop unrolling")); static cl::opt<unsigned> UnrollMaxPercentThresholdBoost( "unroll-max-percent-threshold-boost", cl::init(400), cl::Hidden, @@ -106,10 +110,19 @@ static cl::opt<unsigned> FlatLoopTripCountThreshold( "aggressively unrolled.")); static cl::opt<bool> - UnrollAllowPeeling("unroll-allow-peeling", cl::Hidden, + UnrollAllowPeeling("unroll-allow-peeling", cl::init(true), cl::Hidden, cl::desc("Allows loops to be peeled when the dynamic " "trip count is known to be low.")); +// This option isn't ever intended to be enabled, it serves to allow +// experiments to check the assumptions about when this kind of revisit is +// necessary. +static cl::opt<bool> UnrollRevisitChildLoops( + "unroll-revisit-child-loops", cl::Hidden, + cl::desc("Enqueue and re-visit child loops in the loop PM after unrolling. " + "This shouldn't typically be needed as child loops (or their " + "clones) were already visited.")); + /// A magic value for use with the Threshold parameter to indicate /// that the loop unroll should be performed regardless of how much /// code expansion would result. @@ -118,16 +131,17 @@ static const unsigned NoThreshold = UINT_MAX; /// Gather the various unrolling parameters based on the defaults, compiler /// flags, TTI overrides and user specified parameters. static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( - Loop *L, const TargetTransformInfo &TTI, Optional<unsigned> UserThreshold, - Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, - Optional<bool> UserRuntime, Optional<bool> UserUpperBound) { + Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, int OptLevel, + Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, + Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, + Optional<bool> UserUpperBound) { TargetTransformInfo::UnrollingPreferences UP; // Set up the defaults - UP.Threshold = 150; + UP.Threshold = OptLevel > 2 ? 
300 : 150; UP.MaxPercentThresholdBoost = 400; UP.OptSizeThreshold = 0; - UP.PartialThreshold = UP.Threshold; + UP.PartialThreshold = 150; UP.PartialOptSizeThreshold = 0; UP.Count = 0; UP.PeelCount = 0; @@ -141,10 +155,10 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.AllowExpensiveTripCount = false; UP.Force = false; UP.UpperBound = false; - UP.AllowPeeling = false; + UP.AllowPeeling = true; // Override with any target specific settings - TTI.getUnrollingPreferences(L, UP); + TTI.getUnrollingPreferences(L, SE, UP); // Apply size attributes if (L->getHeader()->getParent()->optForSize()) { @@ -153,10 +167,10 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( } // Apply any user values specified by cl::opt - if (UnrollThreshold.getNumOccurrences() > 0) { + if (UnrollThreshold.getNumOccurrences() > 0) UP.Threshold = UnrollThreshold; - UP.PartialThreshold = UnrollThreshold; - } + if (UnrollPartialThreshold.getNumOccurrences() > 0) + UP.PartialThreshold = UnrollPartialThreshold; if (UnrollMaxPercentThresholdBoost.getNumOccurrences() > 0) UP.MaxPercentThresholdBoost = UnrollMaxPercentThresholdBoost; if (UnrollMaxCount.getNumOccurrences() > 0) @@ -495,7 +509,7 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, KnownSucc = SI->getSuccessor(0); else if (ConstantInt *SimpleCondVal = dyn_cast<ConstantInt>(SimpleCond)) - KnownSucc = SI->findCaseValue(SimpleCondVal).getCaseSuccessor(); + KnownSucc = SI->findCaseValue(SimpleCondVal)->getCaseSuccessor(); } } if (KnownSucc) { @@ -685,7 +699,7 @@ static uint64_t getUnrolledLoopSize( // Calculates unroll count and writes it to UP.Count. static bool computeUnrollCount( Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI, - ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, unsigned &TripCount, + ScalarEvolution &SE, OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount, unsigned &TripMultiple, unsigned LoopSize, TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) { // Check for explicit Count. @@ -756,7 +770,7 @@ static bool computeUnrollCount( // helps to remove a significant number of instructions. // To check that, run additional analysis on the loop. if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( - L, FullUnrollTripCount, DT, *SE, TTI, + L, FullUnrollTripCount, DT, SE, TTI, UP.Threshold * UP.MaxPercentThresholdBoost / 100)) { unsigned Boost = getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost); @@ -770,7 +784,15 @@ static bool computeUnrollCount( } } - // 4rd priority is partial unrolling. + // 4th priority is loop peeling + computePeelCount(L, LoopSize, UP, TripCount); + if (UP.PeelCount) { + UP.Runtime = false; + UP.Count = 1; + return ExplicitUnroll; + } + + // 5th priority is partial unrolling. // Try partial unroll only when TripCount could be staticaly calculated. 
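// [Editorial sketch, not part of the patch] Two consequences of the changes
// above: the default full-unroll budget now depends on the optimization
// level, and peeling moved ahead of partial unrolling, so computeUnrollCount
// roughly tries: 1) explicit/pragma count, 2) full unroll, 3) cost-analysis
// boosted full unroll, 4) peeling, 5) partial unroll, 6) runtime unroll.
// The new default as a pure function (size-optimization overrides are
// applied separately when the function is optForSize()):
static unsigned defaultUnrollThreshold(int OptLevel) {
  return OptLevel > 2 ? 300 : 150; // only -O3 and up get the larger budget
}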
if (TripCount) { UP.Partial |= ExplicitUnroll; @@ -814,6 +836,8 @@ static bool computeUnrollCount( } else { UP.Count = TripCount; } + if (UP.Count > UP.MaxCount) + UP.Count = UP.MaxCount; if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount && UP.Count != TripCount) ORE->emit( @@ -833,14 +857,6 @@ static bool computeUnrollCount( << "Unable to fully unroll loop as directed by unroll(full) pragma " "because loop has a runtime trip count."); - // 5th priority is loop peeling - computePeelCount(L, LoopSize, UP); - if (UP.PeelCount) { - UP.Runtime = false; - UP.Count = 1; - return ExplicitUnroll; - } - // 6th priority is runtime unrolling. // Don't unroll a runtime trip count loop when it is disabled. if (HasRuntimeUnrollDisablePragma(L)) { @@ -912,9 +928,9 @@ static bool computeUnrollCount( } static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, - ScalarEvolution *SE, const TargetTransformInfo &TTI, + ScalarEvolution &SE, const TargetTransformInfo &TTI, AssumptionCache &AC, OptimizationRemarkEmitter &ORE, - bool PreserveLCSSA, + bool PreserveLCSSA, int OptLevel, Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial, @@ -934,8 +950,8 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, bool NotDuplicatable; bool Convergent; TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( - L, TTI, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, - ProvidedRuntime, ProvidedUpperBound); + L, SE, TTI, OptLevel, ProvidedThreshold, ProvidedCount, + ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound); // Exit early if unrolling is disabled. if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0)) return false; @@ -963,8 +979,8 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, if (!ExitingBlock || !L->isLoopExiting(ExitingBlock)) ExitingBlock = L->getExitingBlock(); if (ExitingBlock) { - TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); - TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); + TripCount = SE.getSmallConstantTripCount(L, ExitingBlock); + TripMultiple = SE.getSmallConstantTripMultiple(L, ExitingBlock); } // If the loop contains a convergent operation, the prelude we'd add @@ -986,8 +1002,8 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // count. bool MaxOrZero = false; if (!TripCount) { - MaxTripCount = SE->getSmallConstantMaxTripCount(L); - MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L); + MaxTripCount = SE.getSmallConstantMaxTripCount(L); + MaxOrZero = SE.isBackedgeTakenCountMaxOrZero(L); // We can unroll by the upper bound amount if it's generally allowed or if // we know that the loop is executed either the upper bound or zero times. // (MaxOrZero unrolling keeps only the first loop test, so the number of @@ -1016,7 +1032,7 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // Unroll the loop. 
if (!UnrollLoop(L, UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, UseUpperBound, MaxOrZero, - TripMultiple, UP.PeelCount, LI, SE, &DT, &AC, &ORE, + TripMultiple, UP.PeelCount, LI, &SE, &DT, &AC, &ORE, PreserveLCSSA)) return false; @@ -1034,16 +1050,17 @@ namespace { class LoopUnroll : public LoopPass { public: static char ID; // Pass ID, replacement for typeid - LoopUnroll(Optional<unsigned> Threshold = None, + LoopUnroll(int OptLevel = 2, Optional<unsigned> Threshold = None, Optional<unsigned> Count = None, Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, Optional<bool> UpperBound = None) - : LoopPass(ID), ProvidedCount(std::move(Count)), + : LoopPass(ID), OptLevel(OptLevel), ProvidedCount(std::move(Count)), ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound) { initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); } + int OptLevel; Optional<unsigned> ProvidedCount; Optional<unsigned> ProvidedThreshold; Optional<bool> ProvidedAllowPartial; @@ -1058,7 +1075,7 @@ public: auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); @@ -1068,7 +1085,7 @@ public: OptimizationRemarkEmitter ORE(&F); bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); - return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, + return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel, ProvidedCount, ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound); @@ -1094,26 +1111,27 @@ INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) -Pass *llvm::createLoopUnrollPass(int Threshold, int Count, int AllowPartial, - int Runtime, int UpperBound) { +Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count, + int AllowPartial, int Runtime, + int UpperBound) { // TODO: It would make more sense for this function to take the optionals // directly, but that's dangerous since it would silently break out of tree // callers. - return new LoopUnroll(Threshold == -1 ? None : Optional<unsigned>(Threshold), - Count == -1 ? None : Optional<unsigned>(Count), - AllowPartial == -1 ? None - : Optional<bool>(AllowPartial), - Runtime == -1 ? None : Optional<bool>(Runtime), - UpperBound == -1 ? None : Optional<bool>(UpperBound)); + return new LoopUnroll( + OptLevel, Threshold == -1 ? None : Optional<unsigned>(Threshold), + Count == -1 ? None : Optional<unsigned>(Count), + AllowPartial == -1 ? None : Optional<bool>(AllowPartial), + Runtime == -1 ? None : Optional<bool>(Runtime), + UpperBound == -1 ? 
None : Optional<bool>(UpperBound)); } -Pass *llvm::createSimpleLoopUnrollPass() { - return llvm::createLoopUnrollPass(-1, -1, 0, 0, 0); +Pass *llvm::createSimpleLoopUnrollPass(int OptLevel) { + return llvm::createLoopUnrollPass(OptLevel, -1, -1, 0, 0, 0); } PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, - LPMUpdater &) { + LPMUpdater &Updater) { const auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); Function *F = L.getHeader()->getParent(); @@ -1124,12 +1142,84 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, report_fatal_error("LoopUnrollPass: OptimizationRemarkEmitterAnalysis not " "cached at a higher level"); - bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, &AR.SE, AR.TTI, AR.AC, *ORE, - /*PreserveLCSSA*/ true, ProvidedCount, - ProvidedThreshold, ProvidedAllowPartial, - ProvidedRuntime, ProvidedUpperBound); - + // Keep track of the previous loop structure so we can identify new loops + // created by unrolling. + Loop *ParentL = L.getParentLoop(); + SmallPtrSet<Loop *, 4> OldLoops; + if (ParentL) + OldLoops.insert(ParentL->begin(), ParentL->end()); + else + OldLoops.insert(AR.LI.begin(), AR.LI.end()); + + // The API here is quite complex to call, but there are only two interesting + // states we support: partial and full (or "simple") unrolling. However, to + // enable these things we actually pass "None" in for the optional to avoid + // providing an explicit choice. + Optional<bool> AllowPartialParam, RuntimeParam, UpperBoundParam; + if (!AllowPartialUnrolling) + AllowPartialParam = RuntimeParam = UpperBoundParam = false; + bool Changed = tryToUnrollLoop( + &L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, + /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, + /*Threshold*/ None, AllowPartialParam, RuntimeParam, UpperBoundParam); if (!Changed) return PreservedAnalyses::all(); + + // The parent must not be damaged by unrolling! +#ifndef NDEBUG + if (ParentL) + ParentL->verifyLoop(); +#endif + + // Unrolling can do several things to introduce new loops into a loop nest: + // - Partial unrolling clones child loops within the current loop. If it + // uses a remainder, then it can also create any number of sibling loops. + // - Full unrolling clones child loops within the current loop but then + // removes the current loop making all of the children appear to be new + // sibling loops. + // - Loop peeling can directly introduce new sibling loops by peeling one + // iteration. + // + // When a new loop appears as a sibling loop, either from peeling an + // iteration or fully unrolling, its nesting structure has fundamentally + // changed and we want to revisit it to reflect that. + // + // When unrolling has removed the current loop, we need to tell the + // infrastructure that it is gone. + // + // Finally, we support a debugging/testing mode where we revisit child loops + // as well. These are not expected to require further optimizations as either + // they or the loop they were cloned from have been directly visited already. + // But the debugging mode allows us to check this assumption. + bool IsCurrentLoopValid = false; + SmallVector<Loop *, 4> SibLoops; + if (ParentL) + SibLoops.append(ParentL->begin(), ParentL->end()); + else + SibLoops.append(AR.LI.begin(), AR.LI.end()); + erase_if(SibLoops, [&](Loop *SibLoop) { + if (SibLoop == &L) { + IsCurrentLoopValid = true; + return true; + } + + // Otherwise erase the loop from the list if it was in the old loops. 
+ return OldLoops.count(SibLoop) != 0; + }); + Updater.addSiblingLoops(SibLoops); + + if (!IsCurrentLoopValid) { + Updater.markLoopAsDeleted(L); + } else { + // We can only walk child loops if the current loop remained valid. + if (UnrollRevisitChildLoops) { + // Walk *all* of the child loops. This is a highly speculative mode + // anyways so look for any simplifications that arose from partial + // unrolling or peeling off of iterations. + SmallVector<Loop *, 4> ChildLoops(L.begin(), L.end()); + Updater.addChildLoops(ChildLoops); + } + } + return getLoopPassPreservedAnalyses(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index 76fe918..d0c96fa 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// // // This pass transforms loops that contain branches on loop-invariant conditions -// to have multiple loops. For example, it turns the left into the right code: +// to multiple loops. For example, it turns the left into the right code: // // for (...) if (lic) // A for (...) @@ -26,32 +26,34 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/BlockFrequencyInfoImpl.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/BranchProbabilityInfo.h" -#include "llvm/Support/BranchProbability.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -77,19 +79,6 @@ static cl::opt<unsigned> Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"), cl::init(100), cl::Hidden); -static cl::opt<bool> -LoopUnswitchWithBlockFrequency("loop-unswitch-with-block-frequency", - cl::init(false), cl::Hidden, - cl::desc("Enable the use of the block frequency analysis to access PGO " - "heuristics to minimize code growth in cold regions.")); - -static cl::opt<unsigned> -ColdnessThreshold("loop-unswitch-coldness-threshold", cl::init(1), cl::Hidden, - cl::desc("Coldness threshold in percentage. 
The loop header frequency " - "(relative to the entry frequency) is compared with this " - "threshold to determine if non-trivial unswitching should be " - "enabled.")); - namespace { class LUAnalysisCache { @@ -174,13 +163,6 @@ namespace { LUAnalysisCache BranchesInfo; - bool EnabledPGO; - - // BFI and ColdEntryFreq are only used when PGO and - // LoopUnswitchWithBlockFrequency are enabled. - BlockFrequencyInfo BFI; - BlockFrequency ColdEntryFreq; - bool OptimizeForSize; bool redoLoop; @@ -199,12 +181,14 @@ namespace { // NewBlocks contained cloned copy of basic blocks from LoopBlocks. std::vector<BasicBlock*> NewBlocks; + bool hasBranchDivergence; + public: static char ID; // Pass ID, replacement for typeid - explicit LoopUnswitch(bool Os = false) : + explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) : LoopPass(ID), OptimizeForSize(Os), redoLoop(false), currentLoop(nullptr), DT(nullptr), loopHeader(nullptr), - loopPreheader(nullptr) { + loopPreheader(nullptr), hasBranchDivergence(hasBranchDivergence) { initializeLoopUnswitchPass(*PassRegistry::getPassRegistry()); } @@ -217,6 +201,8 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + if (hasBranchDivergence) + AU.addRequired<DivergenceAnalysis>(); getLoopAnalysisUsage(AU); } @@ -255,6 +241,11 @@ namespace { TerminatorInst *TI); void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); + + /// Given that the Invariant is not equal to Val. Simplify instructions + /// in the loop. + Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, + Constant *Val); }; } @@ -381,16 +372,35 @@ INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis) INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) -Pass *llvm::createLoopUnswitchPass(bool Os) { - return new LoopUnswitch(Os); +Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) { + return new LoopUnswitch(Os, hasBranchDivergence); } +/// Operator chain lattice. +enum OperatorChain { + OC_OpChainNone, ///< There is no operator. + OC_OpChainOr, ///< There are only ORs. + OC_OpChainAnd, ///< There are only ANDs. + OC_OpChainMixed ///< There are ANDs and ORs. +}; + /// Cond is a condition that occurs in L. If it is invariant in the loop, or has /// an invariant piece, return the invariant. Otherwise, return null. +// +/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a +/// mixed operator chain, as we can not reliably find a value which will simplify +/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0 +/// to simplify the chain. +/// +/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to +/// simplify the condition itself to a loop variant condition, but at the +/// cost of creating an entirely new loop. static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + OperatorChain &ParentChain, DenseMap<Value *, Value *> &Cache) { auto CacheIt = Cache.find(Cond); if (CacheIt != Cache.end()) @@ -414,21 +424,53 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, return Cond; } + // Walk up the operator chain to find partial invariant conditions. 
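// [Editorial sketch, not part of the patch] The OperatorChain values above
// form a small lattice: None absorbs the first opcode seen, matching opcodes
// keep the state, and any mismatch collapses to Mixed. A stand-alone mirror
// of the transition used by the walk below:
enum class Chain { None, Or, And, Mixed };

static Chain joinChain(Chain Parent, bool OpIsAnd) {
  switch (Parent) {
  case Chain::None:  return OpIsAnd ? Chain::And : Chain::Or;
  case Chain::Or:    return OpIsAnd ? Chain::Mixed : Chain::Or;
  case Chain::And:   return OpIsAnd ? Chain::And : Chain::Mixed;
  case Chain::Mixed: return Chain::Mixed;
  }
  return Chain::Mixed; // unreachable; keeps -Wreturn-type quiet
}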
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond)) if (BO->getOpcode() == Instruction::And || BO->getOpcode() == Instruction::Or) { - // If either the left or right side is invariant, we can unswitch on this, - // which will cause the branch to go away in one loop and the condition to - // simplify in the other one. - if (Value *LHS = - FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) { - Cache[Cond] = LHS; - return LHS; + // Given the previous operator, compute the current operator chain status. + OperatorChain NewChain; + switch (ParentChain) { + case OC_OpChainNone: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainOr; + break; + case OC_OpChainOr: + NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr : + OC_OpChainMixed; + break; + case OC_OpChainAnd: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainMixed; + break; + case OC_OpChainMixed: + NewChain = OC_OpChainMixed; + break; } - if (Value *RHS = - FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) { - Cache[Cond] = RHS; - return RHS; + + // If we reach a Mixed state, we do not want to keep walking up as we can not + // reliably find a value that will simplify the chain. With this check, we + // will return null on the first sight of mixed chain and the caller will + // either backtrack to find partial LIV in other operand or return null. + if (NewChain != OC_OpChainMixed) { + // Update the current operator chain type before we search up the chain. + ParentChain = NewChain; + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. + if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = LHS; + return LHS; + } + // We did not manage to find a partial LIV in operand(0). Backtrack and try + // operand(1). + ParentChain = NewChain; + if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = RHS; + return RHS; + } } } @@ -436,9 +478,21 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, return nullptr; } -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { +/// Cond is a condition that occurs in L. If it is invariant in the loop, or has +/// an invariant piece, return the invariant along with the operator chain type. +/// Otherwise, return null. +static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond, + Loop *L, + bool &Changed) { DenseMap<Value *, Value *> Cache; - return FindLIVLoopCondition(Cond, L, Changed, Cache); + OperatorChain OpChain = OC_OpChainNone; + Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache); + + // In case we do find a LIV, it can not be obtained by walking up a mixed + // operator chain. 
+ assert((!FCond || OpChain != OC_OpChainMixed) && + "Do not expect a partial LIV with mixed operator chain"); + return {FCond, OpChain}; } bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { @@ -457,19 +511,6 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { if (SanitizeMemory) computeLoopSafetyInfo(&SafetyInfo, L); - EnabledPGO = F->getEntryCount().hasValue(); - - if (LoopUnswitchWithBlockFrequency && EnabledPGO) { - BranchProbabilityInfo BPI(*F, *LI); - BFI.calculate(*L->getHeader()->getParent(), BPI, *LI); - - // Use BranchProbability to compute a minimum frequency based on - // function entry baseline frequency. Loops with headers below this - // frequency are considered as cold. - const BranchProbability ColdProb(ColdnessThreshold, 100); - ColdEntryFreq = BlockFrequency(BFI.getEntryFreq()) * ColdProb; - } - bool Changed = false; do { assert(currentLoop->isLCSSAForm(*DT)); @@ -581,19 +622,9 @@ bool LoopUnswitch::processCurrentLoop() { loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) return false; - if (LoopUnswitchWithBlockFrequency && EnabledPGO) { - // Compute the weighted frequency of the hottest block in the - // loop (loopHeader in this case since inner loops should be - // processed before outer loop). If it is less than ColdFrequency, - // we should not unswitch. - BlockFrequency LoopEntryFreq = BFI.getBlockFreq(loopHeader); - if (LoopEntryFreq < ColdEntryFreq) - return false; - } - for (IntrinsicInst *Guard : Guards) { Value *LoopCond = - FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed); + FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { // NB! Unswitching (if successful) could have erased some of the @@ -634,7 +665,7 @@ bool LoopUnswitch::processCurrentLoop() { // See if this, or some part of it, is loop invariant. If so, we can // unswitch on it if we desire. Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; @@ -642,24 +673,48 @@ bool LoopUnswitch::processCurrentLoop() { } } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + Value *SC = SI->getCondition(); + Value *LoopCond; + OperatorChain OpChain; + std::tie(LoopCond, OpChain) = + FindLIVLoopCondition(SC, currentLoop, Changed); + unsigned NumCases = SI->getNumCases(); if (LoopCond && NumCases) { // Find a value to unswitch on: // FIXME: this should chose the most expensive case! // FIXME: scan for a case with a non-critical edge? Constant *UnswitchVal = nullptr; - - // Do not process same value again and again. - // At this point we have some cases already unswitched and - // some not yet unswitched. Let's find the first not yet unswitched one. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { - Constant *UnswitchValCandidate = i.getCaseValue(); - if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { - UnswitchVal = UnswitchValCandidate; - break; + // Find a case value such that at least one case value is unswitched + // out. + if (OpChain == OC_OpChainAnd) { + // If the chain only has ANDs and the switch has a case value of 0. + // Dropping in a 0 to the chain will unswitch out the 0-casevalue. 
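// [Editorial sketch, not part of the patch] The algebra behind choosing 0
// here and ~0 in the OR case just below: 0 is absorbing for an AND-only
// chain and ~0 for an OR-only chain, so substituting the chosen case value
// pins the whole switch condition to a known constant. For 8-bit values:
#include <cstdint>

static bool chainAbsorption(std::uint8_t A, std::uint8_t B) {
  bool AndPinned = ((0 & A) & B) == 0;      // AND chain collapses to 0
  bool OrPinned = ((0xFF | A) | B) == 0xFF; // OR chain collapses to ~0
  return AndPinned && OrPinned;             // holds for all A, B
}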
+ auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllZero)) + continue; + // We are unswitching 0 out. + UnswitchVal = AllZero; + } else if (OpChain == OC_OpChainOr) { + // If the chain only has ORs and the switch has a case value of ~0. + // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue. + auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllOne)) + continue; + // We are unswitching ~0 out. + UnswitchVal = AllOne; + } else { + assert(OpChain == OC_OpChainNone && + "Expect to unswitch on trivial chain"); + // Do not process same value again and again. + // At this point we have some cases already unswitched and + // some not yet unswitched. Let's find the first not yet unswitched one. + for (auto Case : SI->cases()) { + Constant *UnswitchValCandidate = Case.getCaseValue(); + if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { + UnswitchVal = UnswitchValCandidate; + break; + } } } @@ -668,6 +723,11 @@ bool LoopUnswitch::processCurrentLoop() { if (UnswitchIfProfitable(LoopCond, UnswitchVal)) { ++NumSwitches; + // In case of a full LIV, UnswitchVal is the value we unswitched out. + // In case of a partial LIV, we only unswitch when its an AND-chain + // or OR-chain. In both cases switch input value simplifies to + // UnswitchVal. + BranchesInfo.setUnswitched(SI, UnswitchVal); return true; } } @@ -678,7 +738,7 @@ bool LoopUnswitch::processCurrentLoop() { BBI != E; ++BBI) if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { ++NumSelects; @@ -753,6 +813,15 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, << ". Cost too high.\n"); return false; } + if (hasBranchDivergence && + getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) { + DEBUG(dbgs() << "NOT unswitching loop %" + << currentLoop->getHeader()->getName() + << " at non-trivial condition '" << *Val + << "' == " << *LoopCond << "\n" + << ". Condition is divergent.\n"); + return false; + } UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI); return true; @@ -762,7 +831,12 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, /// mapping the blocks with the specified map. static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, LoopInfo *LI, LPPassManager *LPM) { - Loop &New = LPM->addLoop(PL); + Loop &New = *new Loop(); + if (PL) + PL->addChildLoop(&New); + else + LI->addTopLevelLoop(&New); + LPM->addLoop(New); // Add all of the blocks in L to the new loop. for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); @@ -899,7 +973,6 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { if (I.mayHaveSideEffects()) return false; - // FIXME: add check for constant foldable switch instructions. if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) { if (BI->isUnconditional()) { CurrentBB = BI->getSuccessor(0); @@ -911,7 +984,16 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { // Found a trivial condition candidate: non-foldable conditional branch. break; } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { + // At this point, any constant-foldable instructions should have probably + // been folded. 
+ ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition()); + if (!Cond) + break; + // Find the target block we are definitely going to. + CurrentBB = SI->findCaseValue(Cond)->getCaseSuccessor(); } else { + // We do not understand these terminator instructions. break; } @@ -929,7 +1011,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { return false; Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -960,7 +1042,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { // If this isn't switching on an invariant condition, we can't unswitch it. Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -973,13 +1055,12 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { // this. // Note that we can't trivially unswitch on the default case or // on already unswitched cases. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { + for (auto Case : SI->cases()) { BasicBlock *LoopExitCandidate; - if ((LoopExitCandidate = isTrivialLoopExitBlock(currentLoop, - i.getCaseSuccessor()))) { + if ((LoopExitCandidate = + isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) { // Okay, we found a trivial case, remember the value that is trivial. - ConstantInt *CaseVal = i.getCaseValue(); + ConstantInt *CaseVal = Case.getCaseValue(); // Check that it was not unswitched before, since already unswitched // trivial vals are looks trivial too. @@ -998,6 +1079,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, nullptr); + + // We are only unswitching full LIV. + BranchesInfo.setUnswitched(SI, CondVal); ++NumSwitches; return true; } @@ -1152,11 +1236,12 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, LoopProcessWorklist.push_back(NewLoop); redoLoop = true; - // Keep a WeakVH holding onto LIC. If the first call to RewriteLoopBody + // Keep a WeakTrackingVH holding onto LIC. If the first call to + // RewriteLoopBody // deletes the instruction (for example by simplifying a PHI that feeds into // the condition that we're unswitching on), we don't rewrite the second // iteration. - WeakVH LICHandle(LIC); + WeakTrackingVH LICHandle(LIC); // Now we rewrite the original code to know that the condition is true and the // new code to know that the condition is false. @@ -1183,7 +1268,7 @@ static void RemoveFromWorklist(Instruction *I, static void ReplaceUsesOfWith(Instruction *I, Value *V, std::vector<Instruction*> &Worklist, Loop *L, LPPassManager *LPM) { - DEBUG(dbgs() << "Replace with '" << *V << "': " << *I); + DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n"); // Add uses to the worklist, which may be dead now. 
for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
@@ -1196,7 +1281,8 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V,
LPM->deleteSimpleAnalysisValue(I, L);
RemoveFromWorklist(I, Worklist);
I->replaceAllUsesWith(V);
- I->eraseFromParent();
+ if (!I->mayHaveSideEffects())
+ I->eraseFromParent();
++NumSimplify;
}
@@ -1253,18 +1339,38 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
if (!UI || !L->contains(UI))
continue;
- Worklist.push_back(UI);
+ // At this point, we know LIC is definitely not Val. Try to use some simple
+ // logic to simplify the user w.r.t. the context.
+ if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) {
+ if (LI->replacementPreservesLCSSAForm(UI, Replacement)) {
+ // This in-loop instruction has been simplified w.r.t. its context,
+ // i.e. LIC != Val, make sure we propagate its replacement value to
+ // all its users.
+ //
+ // We cannot delete UI, the LIC user, yet, because that would invalidate
+ // the LIC->users() iterator. However, we can make this instruction
+ // dead by replacing all its users and push it onto the worklist so that
+ // it can be properly deleted and its operands simplified.
+ UI->replaceAllUsesWith(Replacement);
+ }
+ }
- // TODO: We could do other simplifications, for example, turning
- // 'icmp eq LIC, Val' -> false.
+ // This is a LIC user, push it into the worklist so that SimplifyCode can
+ // attempt to simplify it.
+ Worklist.push_back(UI);
// If we know that LIC is not Val, use this info to simplify code.
SwitchInst *SI = dyn_cast<SwitchInst>(UI);
if (!SI || !isa<ConstantInt>(Val))
continue;
- SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val));
+ // NOTE: if a case value for the switch is unswitched out, we record it
+ // after the unswitch finishes. We cannot record it here as the switch
+ // is not a direct user of the partial LIV.
+ SwitchInst::CaseHandle DeadCase =
+ *SI->findCaseValue(cast<ConstantInt>(Val));
// Default case is live for multiple values.
- if (DeadCase == *SI->case_default())
+ if (DeadCase == *SI->case_default())
+ continue;
// Found a dead case value. Don't remove PHI nodes in the
// successor if they become single-entry, those PHI nodes may
@@ -1274,8 +1380,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
BasicBlock *SISucc = DeadCase.getCaseSuccessor();
BasicBlock *Latch = L->getLoopLatch();
- BranchesInfo.setUnswitched(SI, Val);
-
if (!SI->findCaseDest(SISucc)) continue; // Edge is critical.
// If the DeadCase successor dominates the loop latch, then the
// transformation isn't safe since it will delete the sole predecessor edge
@@ -1334,7 +1438,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
// Simple DCE.
if (isInstructionTriviallyDead(I)) {
- DEBUG(dbgs() << "Remove dead instruction '" << *I);
+ DEBUG(dbgs() << "Remove dead instruction '" << *I << "\n");
// Add uses to the worklist, which may be dead now.
for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
@@ -1397,3 +1501,27 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
}
}
}
+
+/// Simple simplifications we can do given the information that Cond is
+/// definitely not equal to Val.
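For instance (an illustrative sketch, not an exhaustive list), in the loop copy entered under the assumption LIC != 42, with 42 standing in for Val:

  icmp eq %LIC, 42  -->  false
  icmp ne %LIC, 42  -->  true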
+Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst, + Value *Invariant, + Constant *Val) { + // icmp eq cond, val -> false + ICmpInst *CI = dyn_cast<ICmpInst>(Inst); + if (CI && CI->isEquality()) { + Value *Op0 = CI->getOperand(0); + Value *Op1 = CI->getOperand(1); + if ((Op0 == Invariant && Op1 == Val) || (Op0 == Val && Op1 == Invariant)) { + LLVMContext &Ctx = Inst->getContext(); + if (CI->getPredicate() == CmpInst::ICMP_EQ) + return ConstantInt::getFalse(Ctx); + else + return ConstantInt::getTrue(Ctx); + } + } + + // FIXME: there may be other opportunities, e.g. comparison with floating + // point, or Invariant - Val != 0, etc. + return nullptr; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerAtomic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerAtomic.cpp index 08e60b1..6f77c5b 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LowerAtomic.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LowerAtomic.cpp @@ -155,8 +155,7 @@ public: } bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; + // Don't skip optnone functions; atomics still need to be lowered. FunctionAnalysisManager DummyFAM; auto PA = Impl.run(F, DummyFAM); return !PA.areAllPreserved(); diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 52975ef..46f8a35 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -67,11 +68,11 @@ static bool handleSwitchExpect(SwitchInst &SI) { if (!ExpectedValue) return false; - SwitchInst::CaseIt Case = SI.findCaseValue(ExpectedValue); + SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue); unsigned n = SI.getNumCases(); // +1 for default case. SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight); - if (Case == SI.case_default()) + if (Case == *SI.case_default()) Weights[0] = LikelyBranchWeight; else Weights[Case.getCaseIndex() + 1] = LikelyBranchWeight; @@ -83,6 +84,151 @@ static bool handleSwitchExpect(SwitchInst &SI) { return true; } +/// Handler for PHINodes that define the value argument to an +/// @llvm.expect call. +/// +/// If the operand of the phi has a constant value and it 'contradicts' +/// with the expected value of phi def, then the corresponding incoming +/// edge of the phi is unlikely to be taken. Using that information, +/// the branch probability info for the originating branch can be inferred. +static void handlePhiDef(CallInst *Expect) { + Value &Arg = *Expect->getArgOperand(0); + ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(Expect->getArgOperand(1)); + if (!ExpectedValue) + return; + const APInt &ExpectedPhiValue = ExpectedValue->getValue(); + + // Walk up in backward a list of instructions that + // have 'copy' semantics by 'stripping' the copies + // until a PHI node or an instruction of unknown kind + // is reached. Negation via xor is also handled. 
+ // + // C = PHI(...); + // B = C; + // A = B; + // D = __builtin_expect(A, 0); + // + Value *V = &Arg; + SmallVector<Instruction *, 4> Operations; + while (!isa<PHINode>(V)) { + if (ZExtInst *ZExt = dyn_cast<ZExtInst>(V)) { + V = ZExt->getOperand(0); + Operations.push_back(ZExt); + continue; + } + + if (SExtInst *SExt = dyn_cast<SExtInst>(V)) { + V = SExt->getOperand(0); + Operations.push_back(SExt); + continue; + } + + BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V); + if (!BinOp || BinOp->getOpcode() != Instruction::Xor) + return; + + ConstantInt *CInt = dyn_cast<ConstantInt>(BinOp->getOperand(1)); + if (!CInt) + return; + + V = BinOp->getOperand(0); + Operations.push_back(BinOp); + } + + // Executes the recorded operations on input 'Value'. + auto ApplyOperations = [&](const APInt &Value) { + APInt Result = Value; + for (auto Op : llvm::reverse(Operations)) { + switch (Op->getOpcode()) { + case Instruction::Xor: + Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue(); + break; + case Instruction::ZExt: + Result = Result.zext(Op->getType()->getIntegerBitWidth()); + break; + case Instruction::SExt: + Result = Result.sext(Op->getType()->getIntegerBitWidth()); + break; + default: + llvm_unreachable("Unexpected operation"); + } + } + return Result; + }; + + auto *PhiDef = dyn_cast<PHINode>(V); + + // Get the first dominating conditional branch of the operand + // i's incoming block. + auto GetDomConditional = [&](unsigned i) -> BranchInst * { + BasicBlock *BB = PhiDef->getIncomingBlock(i); + BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (BI && BI->isConditional()) + return BI; + BB = BB->getSinglePredecessor(); + if (!BB) + return nullptr; + BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BI || BI->isUnconditional()) + return nullptr; + return BI; + }; + + // Now walk through all Phi operands to find phi oprerands with values + // conflicting with the expected phi output value. Any such operand + // indicates the incoming edge to that operand is unlikely. + for (unsigned i = 0, e = PhiDef->getNumIncomingValues(); i != e; ++i) { + + Value *PhiOpnd = PhiDef->getIncomingValue(i); + ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd); + if (!CI) + continue; + + // Not an interesting case when IsUnlikely is false -- we can not infer + // anything useful when the operand value matches the expected phi + // output. + if (ExpectedPhiValue == ApplyOperations(CI->getValue())) + continue; + + BranchInst *BI = GetDomConditional(i); + if (!BI) + continue; + + MDBuilder MDB(PhiDef->getContext()); + + // There are two situations in which an operand of the PhiDef comes + // from a given successor of a branch instruction BI. + // 1) When the incoming block of the operand is the successor block; + // 2) When the incoming block is BI's enclosing block and the + // successor is the PhiDef's enclosing block. + // + // Returns true if the operand which comes from OpndIncomingBB + // comes from outgoing edge of BI that leads to Succ block. + auto *OpndIncomingBB = PhiDef->getIncomingBlock(i); + auto IsOpndComingFromSuccessor = [&](BasicBlock *Succ) { + if (OpndIncomingBB == Succ) + // If this successor is the incoming block for this + // Phi operand, then this successor does lead to the Phi. + return true; + if (OpndIncomingBB == BI->getParent() && Succ == PhiDef->getParent()) + // Otherwise, if the edge is directly from the branch + // to the Phi, this successor is the one feeding this + // Phi operand. 
+ return true; + return false; + }; + + if (IsOpndComingFromSuccessor(BI->getSuccessor(1))) + BI->setMetadata( + LLVMContext::MD_prof, + MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight)); + else if (IsOpndComingFromSuccessor(BI->getSuccessor(0))) + BI->setMetadata( + LLVMContext::MD_prof, + MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight)); + } +} + // Handle both BranchInst and SelectInst. template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { @@ -98,10 +244,18 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { CallInst *CI; ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition()); + CmpInst::Predicate Predicate; + ConstantInt *CmpConstOperand = nullptr; if (!CmpI) { CI = dyn_cast<CallInst>(BSI.getCondition()); + Predicate = CmpInst::ICMP_NE; } else { - if (CmpI->getPredicate() != CmpInst::ICMP_NE) + Predicate = CmpI->getPredicate(); + if (Predicate != CmpInst::ICMP_NE && Predicate != CmpInst::ICMP_EQ) + return false; + + CmpConstOperand = dyn_cast<ConstantInt>(CmpI->getOperand(1)); + if (!CmpConstOperand) return false; CI = dyn_cast<CallInst>(CmpI->getOperand(0)); } @@ -109,6 +263,13 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { if (!CI) return false; + uint64_t ValueComparedTo = 0; + if (CmpConstOperand) { + if (CmpConstOperand->getBitWidth() > 64) + return false; + ValueComparedTo = CmpConstOperand->getZExtValue(); + } + Function *Fn = CI->getCalledFunction(); if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect) return false; @@ -121,9 +282,8 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { MDBuilder MDB(CI->getContext()); MDNode *Node; - // If expect value is equal to 1 it means that we are more likely to take - // branch 0, in other case more likely is branch 1. 
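The replacement below generalizes the removed boolean-only heuristic: with an explicit comparison against a constant, the likely successor is the one consistent with the expected value, for both == and != forms. A hedged usage example (compiles with GCC/Clang; the function and values are hypothetical):

int classify(int x) {
  if (__builtin_expect(x, 42) == 42) // expected value matches: branch likely
    return 1;
  if (__builtin_expect(x, 42) == 7)  // expected value differs: branch unlikely
    return 2;
  if (__builtin_expect(x, 0) != 0)   // expect 0, test != 0: branch unlikely
    return 3;
  return 0;
}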
- if (ExpectedValue->isOne()) + if ((ExpectedValue->getZExtValue() == ValueComparedTo) == + (Predicate == CmpInst::ICMP_EQ)) Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight); else Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight); @@ -173,6 +333,10 @@ static bool lowerExpectIntrinsic(Function &F) { Function *Fn = CI->getCalledFunction(); if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) { + // Before erasing the llvm.expect, walk backward to find + // phi that define llvm.expect's first arg, and + // infer branch probability: + handlePhiDef(CI); Value *Exp = CI->getArgOperand(0); CI->replaceAllUsesWith(Exp); CI->eraseFromParent(); diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp index 4f41371..070114a 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -17,10 +17,10 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" diff --git a/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 1b59014..7896396 100644 --- a/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -13,19 +13,48 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/MemCpyOptimizer.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> +#include <cassert> +#include <cstdint> + using namespace llvm; #define DEBUG_TYPE "memcpyopt" @@ -119,6 +148,7 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset, return true; } +namespace { /// Represents a range of memset'd bytes with the ByteVal value. 
/// This allows us to analyze stores like: @@ -130,7 +160,6 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset, /// the first store, we make a range [1, 2). The second store extends the range /// to [0, 2). The third makes a new range [2, 3). The fourth store joins the /// two ranges into [0, 3) which is memset'able. -namespace { struct MemsetRange { // Start/End - A semi range that describes the span that this range covers. // The range is closed at the start and open at the end: [Start, End). @@ -148,7 +177,8 @@ struct MemsetRange { bool isProfitableToUseMemset(const DataLayout &DL) const; }; -} // end anon namespace + +} // end anonymous namespace bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { // If we found more than 4 stores to merge or 16 bytes, use memset. @@ -192,13 +222,14 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { return TheStores.size() > NumPointerStores+NumByteStores; } - namespace { + class MemsetRanges { /// A sorted list of the memset ranges. SmallVector<MemsetRange, 8> Ranges; typedef SmallVectorImpl<MemsetRange>::iterator range_iterator; const DataLayout &DL; + public: MemsetRanges(const DataLayout &DL) : DL(DL) {} @@ -231,8 +262,7 @@ public: }; -} // end anon namespace - +} // end anonymous namespace /// Add a new store to the MemsetRanges data structure. This adds a /// new range for the specified store at the specified offset, merging into @@ -299,48 +329,36 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, //===----------------------------------------------------------------------===// namespace { - class MemCpyOptLegacyPass : public FunctionPass { - MemCpyOptPass Impl; - public: - static char ID; // Pass identification, replacement for typeid - MemCpyOptLegacyPass() : FunctionPass(ID) { - initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry()); - } - bool runOnFunction(Function &F) override; - - private: - // This transformation requires dominator postdominator info - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<MemoryDependenceWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<MemoryDependenceWrapperPass>(); - } +class MemCpyOptLegacyPass : public FunctionPass { + MemCpyOptPass Impl; - // Helper functions - bool processStore(StoreInst *SI, BasicBlock::iterator &BBI); - bool processMemSet(MemSetInst *SI, BasicBlock::iterator &BBI); - bool processMemCpy(MemCpyInst *M); - bool processMemMove(MemMoveInst *M); - bool performCallSlotOptzn(Instruction *cpy, Value *cpyDst, Value *cpySrc, - uint64_t cpyLen, unsigned cpyAlign, CallInst *C); - bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep); - bool processMemSetMemCpyDependence(MemCpyInst *M, MemSetInst *MDep); - bool performMemCpyToMemSetOptzn(MemCpyInst *M, MemSetInst *MDep); - bool processByValArgument(CallSite CS, unsigned ArgNo); - Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr, - Value *ByteVal); - - bool iterateOnFunction(Function &F); - }; +public: + static char ID; // Pass identification, replacement for typeid - char MemCpyOptLegacyPass::ID = 0; -} + MemCpyOptLegacyPass() : FunctionPass(ID) { + initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; 
+ +private: + // This transformation requires dominator postdominator info + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<MemoryDependenceWrapperPass>(); + } +}; + +char MemCpyOptLegacyPass::ID = 0; + +} // end anonymous namespace /// The public interface to this file... FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); } @@ -523,14 +541,15 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, if (Args.erase(C)) NeedLift = true; else if (MayAlias) { - NeedLift = any_of(MemLocs, [C, &AA](const MemoryLocation &ML) { + NeedLift = llvm::any_of(MemLocs, [C, &AA](const MemoryLocation &ML) { return AA.getModRefInfo(C, ML); }); if (!NeedLift) - NeedLift = any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) { - return AA.getModRefInfo(C, CS); - }); + NeedLift = + llvm::any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) { + return AA.getModRefInfo(C, CS); + }); } if (!NeedLift) @@ -567,7 +586,7 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, } // We made it, we need to lift - for (auto *I : reverse(ToLift)) { + for (auto *I : llvm::reverse(ToLift)) { DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n"); I->moveBefore(P); } @@ -761,7 +780,6 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { return false; } - /// Takes a memcpy and a call that it depends on, /// and checks for the possibility of a call slot optimization by having /// the call write its result directly into the destination of the memcpy. @@ -914,6 +932,17 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, if (MR != MRI_NoModRef) return false; + // We can't create address space casts here because we don't know if they're + // safe for the target. + if (cpySrc->getType()->getPointerAddressSpace() != + cpyDest->getType()->getPointerAddressSpace()) + return false; + for (unsigned i = 0; i < CS.arg_size(); ++i) + if (CS.getArgument(i)->stripPointerCasts() == cpySrc && + cpySrc->getType()->getPointerAddressSpace() != + CS.getArgument(i)->getType()->getPointerAddressSpace()) + return false; + // All the checks have passed, so do the transformation. bool changedArgument = false; for (unsigned i = 0; i < CS.arg_size(); ++i) @@ -1240,7 +1269,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { bool MemCpyOptPass::processMemMove(MemMoveInst *M) { AliasAnalysis &AA = LookupAliasAnalysis(); - if (!TLI->has(LibFunc::memmove)) + if (!TLI->has(LibFunc_memmove)) return false; // See if the pointers alias. @@ -1294,7 +1323,7 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { // Get the alignment of the byval. If the call doesn't specify the alignment, // then it is some target specific value that we can't know. 
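For orientation, the transformation processByValArgument is building toward can be sketched as follows (names are placeholders, not from the patch); the alignment check below and the new address-space check are among its legality conditions:

  memcpy(tmp <- src)      ; copy made only to satisfy the byval convention
  call foo(byval tmp)
    ==>
  call foo(byval src)     ; legal only if src is unmodified in between, is
                          ; sufficiently aligned, and is in the same address
                          ; space as the byval argument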
- unsigned ByValAlign = CS.getParamAlignment(ArgNo+1); + unsigned ByValAlign = CS.getParamAlignment(ArgNo); if (ByValAlign == 0) return false; // If it is greater than the memcpy, then we check to see if we can force the @@ -1306,6 +1335,11 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { CS.getInstruction(), &AC, &DT) < ByValAlign) return false; + // The address space of the memcpy source must match the byval argument + if (MDep->getSource()->getType()->getPointerAddressSpace() != + ByValArg->getType()->getPointerAddressSpace()) + return false; + // Verify that the copied-from memory doesn't change in between the memcpy and // the byval call. // memcpy(a <- b) @@ -1375,7 +1409,6 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) { } PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { - auto &MD = AM.getResult<MemoryDependenceAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); @@ -1393,7 +1426,9 @@ PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { LookupAssumptionCache, LookupDomTree); if (!MadeChange) return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); PA.preserve<MemoryDependenceAnalysis>(); return PA; @@ -1414,10 +1449,10 @@ bool MemCpyOptPass::runImpl( // If we don't have at least memset and memcpy, there is little point of doing // anything here. These are required by a freestanding implementation, so if // even they are disabled, there is no point in trying hard. - if (!TLI->has(LibFunc::memset) || !TLI->has(LibFunc::memcpy)) + if (!TLI->has(LibFunc_memset) || !TLI->has(LibFunc_memcpy)) return false; - while (1) { + while (true) { if (!iterateOnFunction(F)) break; MadeChange = true; diff --git a/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 6a64c6b..6727cf0 100644 --- a/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -19,6 +19,8 @@ // thinks it safe to do so. This optimization helps with eg. hiding load // latencies, triggering if-conversion, and reducing static code size. // +// NOTE: This code no longer performs load hoisting, it is subsumed by GVNHoist. 
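With load hoisting gone, the pass's remaining job is sinking stores out of diamonds. A minimal source-level sketch of the pattern it targets (illustrative only):

void diamond(bool c, int a, int b, int *p) {
  if (c)
    *p = a; // store in the 'then' arm
  else
    *p = b; // matching store in the 'else' arm
  // After sinking, this tail block holds a single store of
  // phi(a, b), i.e. effectively *p = c ? a : b.
}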
+// //===----------------------------------------------------------------------===// // // @@ -87,7 +89,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" using namespace llvm; @@ -118,16 +119,6 @@ private: void removeInstruction(Instruction *Inst); BasicBlock *getDiamondTail(BasicBlock *BB); bool isDiamondHead(BasicBlock *BB); - // Routines for hoisting loads - bool isLoadHoistBarrierInRange(const Instruction &Start, - const Instruction &End, LoadInst *LI, - bool SafeToLoadUnconditionally); - LoadInst *canHoistFromBlock(BasicBlock *BB, LoadInst *LI); - void hoistInstruction(BasicBlock *BB, Instruction *HoistCand, - Instruction *ElseInst); - bool isSafeToHoist(Instruction *I) const; - bool hoistLoad(BasicBlock *BB, LoadInst *HoistCand, LoadInst *ElseInst); - bool mergeLoads(BasicBlock *BB); // Routines for sinking stores StoreInst *canSinkFromBlock(BasicBlock *BB, StoreInst *SI); PHINode *getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1); @@ -188,169 +179,6 @@ bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) { return true; } -/// -/// \brief True when instruction is a hoist barrier for a load -/// -/// Whenever an instruction could possibly modify the value -/// being loaded or protect against the load from happening -/// it is considered a hoist barrier. -/// -bool MergedLoadStoreMotion::isLoadHoistBarrierInRange( - const Instruction &Start, const Instruction &End, LoadInst *LI, - bool SafeToLoadUnconditionally) { - if (!SafeToLoadUnconditionally) - for (const Instruction &Inst : - make_range(Start.getIterator(), End.getIterator())) - if (!isGuaranteedToTransferExecutionToSuccessor(&Inst)) - return true; - MemoryLocation Loc = MemoryLocation::get(LI); - return AA->canInstructionRangeModRef(Start, End, Loc, MRI_Mod); -} - -/// -/// \brief Decide if a load can be hoisted -/// -/// When there is a load in \p BB to the same address as \p LI -/// and it can be hoisted from \p BB, return that load. -/// Otherwise return Null. 
-/// -LoadInst *MergedLoadStoreMotion::canHoistFromBlock(BasicBlock *BB1, - LoadInst *Load0) { - BasicBlock *BB0 = Load0->getParent(); - BasicBlock *Head = BB0->getSinglePredecessor(); - bool SafeToLoadUnconditionally = isSafeToLoadUnconditionally( - Load0->getPointerOperand(), Load0->getAlignment(), - Load0->getModule()->getDataLayout(), - /*ScanFrom=*/Head->getTerminator()); - for (BasicBlock::iterator BBI = BB1->begin(), BBE = BB1->end(); BBI != BBE; - ++BBI) { - Instruction *Inst = &*BBI; - - // Only merge and hoist loads when their result in used only in BB - auto *Load1 = dyn_cast<LoadInst>(Inst); - if (!Load1 || Inst->isUsedOutsideOfBlock(BB1)) - continue; - - MemoryLocation Loc0 = MemoryLocation::get(Load0); - MemoryLocation Loc1 = MemoryLocation::get(Load1); - if (Load0->isSameOperationAs(Load1) && AA->isMustAlias(Loc0, Loc1) && - !isLoadHoistBarrierInRange(BB1->front(), *Load1, Load1, - SafeToLoadUnconditionally) && - !isLoadHoistBarrierInRange(BB0->front(), *Load0, Load0, - SafeToLoadUnconditionally)) { - return Load1; - } - } - return nullptr; -} - -/// -/// \brief Merge two equivalent instructions \p HoistCand and \p ElseInst into -/// \p BB -/// -/// BB is the head of a diamond -/// -void MergedLoadStoreMotion::hoistInstruction(BasicBlock *BB, - Instruction *HoistCand, - Instruction *ElseInst) { - DEBUG(dbgs() << " Hoist Instruction into BB \n"; BB->dump(); - dbgs() << "Instruction Left\n"; HoistCand->dump(); dbgs() << "\n"; - dbgs() << "Instruction Right\n"; ElseInst->dump(); dbgs() << "\n"); - // Hoist the instruction. - assert(HoistCand->getParent() != BB); - - // Intersect optional metadata. - HoistCand->andIRFlags(ElseInst); - HoistCand->dropUnknownNonDebugMetadata(); - - // Prepend point for instruction insert - Instruction *HoistPt = BB->getTerminator(); - - // Merged instruction - Instruction *HoistedInst = HoistCand->clone(); - - // Hoist instruction. - HoistedInst->insertBefore(HoistPt); - - HoistCand->replaceAllUsesWith(HoistedInst); - removeInstruction(HoistCand); - // Replace the else block instruction. - ElseInst->replaceAllUsesWith(HoistedInst); - removeInstruction(ElseInst); -} - -/// -/// \brief Return true if no operand of \p I is defined in I's parent block -/// -bool MergedLoadStoreMotion::isSafeToHoist(Instruction *I) const { - BasicBlock *Parent = I->getParent(); - for (Use &U : I->operands()) - if (auto *Instr = dyn_cast<Instruction>(&U)) - if (Instr->getParent() == Parent) - return false; - return true; -} - -/// -/// \brief Merge two equivalent loads and GEPs and hoist into diamond head -/// -bool MergedLoadStoreMotion::hoistLoad(BasicBlock *BB, LoadInst *L0, - LoadInst *L1) { - // Only one definition? - auto *A0 = dyn_cast<Instruction>(L0->getPointerOperand()); - auto *A1 = dyn_cast<Instruction>(L1->getPointerOperand()); - if (A0 && A1 && A0->isIdenticalTo(A1) && isSafeToHoist(A0) && - A0->hasOneUse() && (A0->getParent() == L0->getParent()) && - A1->hasOneUse() && (A1->getParent() == L1->getParent()) && - isa<GetElementPtrInst>(A0)) { - DEBUG(dbgs() << "Hoist Instruction into BB \n"; BB->dump(); - dbgs() << "Instruction Left\n"; L0->dump(); dbgs() << "\n"; - dbgs() << "Instruction Right\n"; L1->dump(); dbgs() << "\n"); - hoistInstruction(BB, A0, A1); - hoistInstruction(BB, L0, L1); - return true; - } - return false; -} - -/// -/// \brief Try to hoist two loads to same address into diamond header -/// -/// Starting from a diamond head block, iterate over the instructions in one -/// successor block and try to match a load in the second successor. 
-/// -bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) { - bool MergedLoads = false; - assert(isDiamondHead(BB)); - BranchInst *BI = cast<BranchInst>(BB->getTerminator()); - BasicBlock *Succ0 = BI->getSuccessor(0); - BasicBlock *Succ1 = BI->getSuccessor(1); - // #Instructions in Succ1 for Compile Time Control - int Size1 = Succ1->size(); - int NLoads = 0; - for (BasicBlock::iterator BBI = Succ0->begin(), BBE = Succ0->end(); - BBI != BBE;) { - Instruction *I = &*BBI; - ++BBI; - - // Don't move non-simple (atomic, volatile) loads. - auto *L0 = dyn_cast<LoadInst>(I); - if (!L0 || !L0->isSimple() || L0->isUsedOutsideOfBlock(Succ0)) - continue; - - ++NLoads; - if (NLoads * Size1 >= MagicCompileTimeControl) - break; - if (LoadInst *L1 = canHoistFromBlock(Succ1, L0)) { - bool Res = hoistLoad(BB, L0, L1); - MergedLoads |= Res; - // Don't attempt to hoist above loads that had not been hoisted. - if (!Res) - break; - } - } - return MergedLoads; -} /// /// \brief True when instruction is a sink barrier for a store @@ -410,7 +238,7 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, &BB->front()); NewPN->addIncoming(Opd1, S0->getParent()); NewPN->addIncoming(Opd2, S1->getParent()); - if (MD && NewPN->getType()->getScalarType()->isPointerTy()) + if (MD && NewPN->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(NewPN); return NewPN; } @@ -534,7 +362,6 @@ bool MergedLoadStoreMotion::run(Function &F, MemoryDependenceResults *MD, // Hoist equivalent loads and sink stores // outside diamonds when possible if (isDiamondHead(BB)) { - Changed |= mergeLoads(BB); Changed |= mergeStores(getDiamondTail(BB)); } } @@ -596,8 +423,8 @@ MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) { if (!Impl.run(F, MD, AA)) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'. PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); PA.preserve<MemoryDependenceAnalysis>(); return PA; diff --git a/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index 0a3bf7b..d0bfe36 100644 --- a/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -156,20 +156,12 @@ PreservedAnalyses NaryReassociatePass::run(Function &F, auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F); auto *TTI = &AM.getResult<TargetIRAnalysis>(F); - bool Changed = runImpl(F, AC, DT, SE, TLI, TTI); - - // FIXME: We need to invalidate this to avoid PR28400. Is there a better - // solution? - AM.invalidate<ScalarEvolutionAnalysis>(F); - - if (!Changed) + if (!runImpl(F, AC, DT, SE, TLI, TTI)) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'. PreservedAnalyses PA; - PA.preserve<DominatorTreeAnalysis>(); + PA.preserveSet<CFGAnalyses>(); PA.preserve<ScalarEvolutionAnalysis>(); - PA.preserve<TargetLibraryAnalysis>(); return PA; } @@ -219,7 +211,8 @@ bool NaryReassociatePass::doOneIteration(Function &F) { Changed = true; SE->forgetValue(&*I); I->replaceAllUsesWith(NewI); - // If SeenExprs constains I's WeakVH, that entry will be replaced with + // If SeenExprs constains I's WeakTrackingVH, that entry will be + // replaced with // nullptr. RecursivelyDeleteTriviallyDeadInstructions(&*I, TLI); I = NewI->getIterator(); @@ -227,7 +220,7 @@ bool NaryReassociatePass::doOneIteration(Function &F) { // Add the rewritten instruction to SeenExprs; the original instruction // is deleted. 
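The WeakVH-to-WeakTrackingVH changes in this file follow an upstream rename of the value-handle types. A brief sketch of the semantics SeenExprs relies on (a fragment under the usual llvm::ValueHandle behavior, not code from this patch):

  llvm::WeakTrackingVH H(I);    // start tracking instruction I
  I->replaceAllUsesWith(NewI);  // RAUW is followed: H now refers to NewI
  // If the tracked value is deleted instead, H nulls itself, so a stale
  // entry can be tested safely:
  if (llvm::Value *V = H) { /* V is still alive */ }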
const SCEV *NewSCEV = SE->getSCEV(&*I); - SeenExprs[NewSCEV].push_back(WeakVH(&*I)); + SeenExprs[NewSCEV].push_back(WeakTrackingVH(&*I)); // Ideally, NewSCEV should equal OldSCEV because tryReassociate(I) // is equivalent to I. However, ScalarEvolution::getSCEV may // weaken nsw causing NewSCEV not to equal OldSCEV. For example, suppose @@ -247,7 +240,7 @@ bool NaryReassociatePass::doOneIteration(Function &F) { // // This improvement is exercised in @reassociate_gep_nsw in nary-gep.ll. if (NewSCEV != OldSCEV) - SeenExprs[OldSCEV].push_back(WeakVH(&*I)); + SeenExprs[OldSCEV].push_back(WeakTrackingVH(&*I)); } } } @@ -502,7 +495,8 @@ NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr, // future instruction either. Therefore, we pop it out of the stack. This // optimization makes the algorithm O(n). while (!Candidates.empty()) { - // Candidates stores WeakVHs, so a candidate can be nullptr if it's removed + // Candidates stores WeakTrackingVHs, so a candidate can be nullptr if it's + // removed // during rewriting. if (Value *Candidate = Candidates.back()) { Instruction *CandidateInstruction = cast<Instruction>(Candidate); diff --git a/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp b/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp index 57e6e3d..9d01856 100644 --- a/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -17,6 +17,37 @@ /// "A Sparse Algorithm for Predicated Global Value Numbering" from /// Karthik Gargi. /// +/// A brief overview of the algorithm: The algorithm is essentially the same as +/// the standard RPO value numbering algorithm (a good reference is the paper +/// "SCC based value numbering" by L. Taylor Simpson) with one major difference: +/// The RPO algorithm proceeds, on every iteration, to process every reachable +/// block and every instruction in that block. This is because the standard RPO +/// algorithm does not track what things have the same value number, it only +/// tracks what the value number of a given operation is (the mapping is +/// operation -> value number). Thus, when a value number of an operation +/// changes, it must reprocess everything to ensure all uses of a value number +/// get updated properly. In constrast, the sparse algorithm we use *also* +/// tracks what operations have a given value number (IE it also tracks the +/// reverse mapping from value number -> operations with that value number), so +/// that it only needs to reprocess the instructions that are affected when +/// something's value number changes. The vast majority of complexity and code +/// in this file is devoted to tracking what value numbers could change for what +/// instructions when various things happen. The rest of the algorithm is +/// devoted to performing symbolic evaluation, forward propagation, and +/// simplification of operations based on the value numbers deduced so far +/// +/// In order to make the GVN mostly-complete, we use a technique derived from +/// "Detection of Redundant Expressions: A Complete and Polynomial-time +/// Algorithm in SSA" by R.R. Pai. The source of incompleteness in most SSA +/// based GVN algorithms is related to their inability to detect equivalence +/// between phi of ops (IE phi(a+b, c+d)) and op of phis (phi(a,c) + phi(b, d)). +/// We resolve this issue by generating the equivalent "phi of ops" form for +/// each op of phis we see, in a way that only takes polynomial time to resolve. +/// +/// We also do not perform elimination by using any published algorithm. 
All +/// published algorithms are O(Instructions). Instead, we use a technique that +/// is O(number of operations with the same value number), enabling us to skip +/// trying to eliminate things that have unique value numbers. //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/NewGVN.h" @@ -30,7 +61,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SparseBitVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -40,13 +70,10 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/GlobalVariable.h" @@ -55,24 +82,25 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/PredIteratorCache.h" #include "llvm/IR/Type.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVNExpression.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/MemorySSA.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/PredicateInfo.h" +#include "llvm/Transforms/Utils/VNCoercion.h" +#include <numeric> #include <unordered_map> #include <utility> #include <vector> using namespace llvm; using namespace PatternMatch; using namespace llvm::GVNExpression; - +using namespace llvm::VNCoercion; #define DEBUG_TYPE "newgvn" STATISTIC(NumGVNInstrDeleted, "Number of instructions deleted"); @@ -85,8 +113,19 @@ STATISTIC(NumGVNLeaderChanges, "Number of leader changes"); STATISTIC(NumGVNSortedLeaderChanges, "Number of sorted leader changes"); STATISTIC(NumGVNAvoidedSortedLeaderChanges, "Number of avoided sorted leader changes"); -STATISTIC(NumGVNNotMostDominatingLeader, - "Number of times a member dominated it's new classes' leader"); +STATISTIC(NumGVNDeadStores, "Number of redundant/dead stores eliminated"); +STATISTIC(NumGVNPHIOfOpsCreated, "Number of PHI of ops created"); +STATISTIC(NumGVNPHIOfOpsEliminations, + "Number of things eliminated using PHI of ops"); +DEBUG_COUNTER(VNCounter, "newgvn-vn", + "Controls which instructions are value numbered") +DEBUG_COUNTER(PHIOfOpsCounter, "newgvn-phi", + "Controls which instructions we create phi of ops for") +// Currently store defining access refinement is too slow due to basicaa being +// egregiously slow. This flag lets us keep it working while we work on this +// issue. +static cl::opt<bool> EnableStoreRefinement("enable-store-refinement", + cl::init(false), cl::Hidden); //===----------------------------------------------------------------------===// // GVN Pass @@ -105,6 +144,79 @@ PHIExpression::~PHIExpression() = default; } } +// Tarjan's SCC finding algorithm with Nuutila's improvements +// SCCIterator is actually fairly complex for the simple thing we want. 
+// It also wants to hand us SCC's that are unrelated to the phi node we ask +// about, and have us process them there or risk redoing work. +// Graph traits over a filter iterator also doesn't work that well here. +// This SCC finder is specialized to walk use-def chains, and only follows +// instructions, +// not generic values (arguments, etc). +struct TarjanSCC { + + TarjanSCC() : Components(1) {} + + void Start(const Instruction *Start) { + if (Root.lookup(Start) == 0) + FindSCC(Start); + } + + const SmallPtrSetImpl<const Value *> &getComponentFor(const Value *V) const { + unsigned ComponentID = ValueToComponent.lookup(V); + + assert(ComponentID > 0 && + "Asking for a component for a value we never processed"); + return Components[ComponentID]; + } + +private: + void FindSCC(const Instruction *I) { + Root[I] = ++DFSNum; + // Store the DFS Number we had before it possibly gets incremented. + unsigned int OurDFS = DFSNum; + for (auto &Op : I->operands()) { + if (auto *InstOp = dyn_cast<Instruction>(Op)) { + if (Root.lookup(Op) == 0) + FindSCC(InstOp); + if (!InComponent.count(Op)) + Root[I] = std::min(Root.lookup(I), Root.lookup(Op)); + } + } + // See if we really were the root of a component, by seeing if we still have + // our DFSNumber. If we do, we are the root of the component, and we have + // completed a component. If we do not, we are not the root of a component, + // and belong on the component stack. + if (Root.lookup(I) == OurDFS) { + unsigned ComponentID = Components.size(); + Components.resize(Components.size() + 1); + auto &Component = Components.back(); + Component.insert(I); + DEBUG(dbgs() << "Component root is " << *I << "\n"); + InComponent.insert(I); + ValueToComponent[I] = ComponentID; + // Pop a component off the stack and label it. + while (!Stack.empty() && Root.lookup(Stack.back()) >= OurDFS) { + auto *Member = Stack.back(); + DEBUG(dbgs() << "Component member is " << *Member << "\n"); + Component.insert(Member); + InComponent.insert(Member); + ValueToComponent[Member] = ComponentID; + Stack.pop_back(); + } + } else { + // Part of a component, push to stack + Stack.push_back(I); + } + } + unsigned int DFSNum = 1; + SmallPtrSet<const Value *, 8> InComponent; + DenseMap<const Value *, unsigned int> Root; + SmallVector<const Value *, 8> Stack; + // Store the components as vector of ptr sets, because we need the topo order + // of SCC's, but not individual member order + SmallVector<SmallPtrSet<const Value *, 8>, 8> Components; + DenseMap<const Value *, unsigned> ValueToComponent; +}; // Congruence classes represent the set of expressions/instructions // that are all the same *during some scope in the function*. // That is, because of the way we perform equality propagation, and @@ -115,46 +227,166 @@ PHIExpression::~PHIExpression() = default; // For any Value in the Member set, it is valid to replace any dominated member // with that Value. // -// Every congruence class has a leader, and the leader is used to -// symbolize instructions in a canonical way (IE every operand of an -// instruction that is a member of the same congruence class will -// always be replaced with leader during symbolization). -// To simplify symbolization, we keep the leader as a constant if class can be -// proved to be a constant value. -// Otherwise, the leader is a randomly chosen member of the value set, it does -// not matter which one is chosen. -// Each congruence class also has a defining expression, -// though the expression may be null. 
If it exists, it can be used for forward -// propagation and reassociation of values. -// -struct CongruenceClass { - using MemberSet = SmallPtrSet<Value *, 4>; +// Every congruence class has a leader, and the leader is used to symbolize +// instructions in a canonical way (IE every operand of an instruction that is a +// member of the same congruence class will always be replaced with leader +// during symbolization). To simplify symbolization, we keep the leader as a +// constant if class can be proved to be a constant value. Otherwise, the +// leader is the member of the value set with the smallest DFS number. Each +// congruence class also has a defining expression, though the expression may be +// null. If it exists, it can be used for forward propagation and reassociation +// of values. + +// For memory, we also track a representative MemoryAccess, and a set of memory +// members for MemoryPhis (which have no real instructions). Note that for +// memory, it seems tempting to try to split the memory members into a +// MemoryCongruenceClass or something. Unfortunately, this does not work +// easily. The value numbering of a given memory expression depends on the +// leader of the memory congruence class, and the leader of memory congruence +// class depends on the value numbering of a given memory expression. This +// leads to wasted propagation, and in some cases, missed optimization. For +// example: If we had value numbered two stores together before, but now do not, +// we move them to a new value congruence class. This in turn will move at one +// of the memorydefs to a new memory congruence class. Which in turn, affects +// the value numbering of the stores we just value numbered (because the memory +// congruence class is part of the value number). So while theoretically +// possible to split them up, it turns out to be *incredibly* complicated to get +// it to work right, because of the interdependency. While structurally +// slightly messier, it is algorithmically much simpler and faster to do what we +// do here, and track them both at once in the same class. +// Note: The default iterators for this class iterate over values +class CongruenceClass { +public: + using MemberType = Value; + using MemberSet = SmallPtrSet<MemberType *, 4>; + using MemoryMemberType = MemoryPhi; + using MemoryMemberSet = SmallPtrSet<const MemoryMemberType *, 2>; + + explicit CongruenceClass(unsigned ID) : ID(ID) {} + CongruenceClass(unsigned ID, Value *Leader, const Expression *E) + : ID(ID), RepLeader(Leader), DefiningExpr(E) {} + unsigned getID() const { return ID; } + // True if this class has no members left. This is mainly used for assertion + // purposes, and for skipping empty classes. + bool isDead() const { + // If it's both dead from a value perspective, and dead from a memory + // perspective, it's really dead. 
+ return empty() && memory_empty();
+ }
+ // Leader functions
+ Value *getLeader() const { return RepLeader; }
+ void setLeader(Value *Leader) { RepLeader = Leader; }
+ const std::pair<Value *, unsigned int> &getNextLeader() const {
+ return NextLeader;
+ }
+ void resetNextLeader() { NextLeader = {nullptr, ~0U}; }
+
+ void addPossibleNextLeader(std::pair<Value *, unsigned int> LeaderPair) {
+ if (LeaderPair.second < NextLeader.second)
+ NextLeader = LeaderPair;
+ }
+
+ Value *getStoredValue() const { return RepStoredValue; }
+ void setStoredValue(Value *Leader) { RepStoredValue = Leader; }
+ const MemoryAccess *getMemoryLeader() const { return RepMemoryAccess; }
+ void setMemoryLeader(const MemoryAccess *Leader) { RepMemoryAccess = Leader; }
+
+ // Forward propagation info
+ const Expression *getDefiningExpr() const { return DefiningExpr; }
+
+ // Value member set
+ bool empty() const { return Members.empty(); }
+ unsigned size() const { return Members.size(); }
+ MemberSet::const_iterator begin() const { return Members.begin(); }
+ MemberSet::const_iterator end() const { return Members.end(); }
+ void insert(MemberType *M) { Members.insert(M); }
+ void erase(MemberType *M) { Members.erase(M); }
+ void swap(MemberSet &Other) { Members.swap(Other); }
+
+ // Memory member set
+ bool memory_empty() const { return MemoryMembers.empty(); }
+ unsigned memory_size() const { return MemoryMembers.size(); }
+ MemoryMemberSet::const_iterator memory_begin() const {
+ return MemoryMembers.begin();
+ }
+ MemoryMemberSet::const_iterator memory_end() const {
+ return MemoryMembers.end();
+ }
+ iterator_range<MemoryMemberSet::const_iterator> memory() const {
+ return make_range(memory_begin(), memory_end());
+ }
+ void memory_insert(const MemoryMemberType *M) { MemoryMembers.insert(M); }
+ void memory_erase(const MemoryMemberType *M) { MemoryMembers.erase(M); }
+
+ // Store count
+ unsigned getStoreCount() const { return StoreCount; }
+ void incStoreCount() { ++StoreCount; }
+ void decStoreCount() {
+ assert(StoreCount != 0 && "Store count went negative");
+ --StoreCount;
+ }
+
+ // True if this class has no memory members.
+ bool definesNoMemory() const { return StoreCount == 0 && memory_empty(); }
+
+ // Return true if two congruence classes are equivalent to each other. This
+ // means that every field but the ID number and the dead field is equivalent.
+ bool isEquivalentTo(const CongruenceClass *Other) const {
+ if (!Other)
+ return false;
+ if (this == Other)
+ return true;
+
+ if (std::tie(StoreCount, RepLeader, RepStoredValue, RepMemoryAccess) !=
+ std::tie(Other->StoreCount, Other->RepLeader, Other->RepStoredValue,
+ Other->RepMemoryAccess))
+ return false;
+ if (DefiningExpr != Other->DefiningExpr)
+ if (!DefiningExpr || !Other->DefiningExpr ||
+ *DefiningExpr != *Other->DefiningExpr)
+ return false;
+ // We need some ordered set; note that B must be built from Other's
+ // members, or the comparison would trivially succeed.
+ std::set<Value *> AMembers(Members.begin(), Members.end());
+ std::set<Value *> BMembers(Other->Members.begin(), Other->Members.end());
+ return AMembers == BMembers;
+ }
+
+private:
+ unsigned ID;
// Representative leader.
Value *RepLeader = nullptr;
+ // The most dominating leader after our current leader; tracked separately
+ // because the member set is not sorted and is expensive to keep sorted all
+ // the time.
+ std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U};
+ // If this is represented by a store, the value of the store.
+ Value *RepStoredValue = nullptr;
+ // If this class contains MemoryDefs or MemoryPhis, this is the leading memory
+ // access.
+ const MemoryAccess *RepMemoryAccess = nullptr; // Defining Expression. const Expression *DefiningExpr = nullptr; // Actual members of this class. MemberSet Members; - - // True if this class has no members left. This is mainly used for assertion - // purposes, and for skipping empty classes. - bool Dead = false; - + // This is the set of MemoryPhis that exist in the class. MemoryDefs and + // MemoryUses have real instructions representing them, so we only need to + // track MemoryPhis here. + MemoryMemberSet MemoryMembers; // Number of stores in this congruence class. // This is used so we can detect store equivalence changes properly. int StoreCount = 0; - - // The most dominating leader after our current leader, because the member set - // is not sorted and is expensive to keep sorted all the time. - std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U}; - - explicit CongruenceClass(unsigned ID) : ID(ID) {} - CongruenceClass(unsigned ID, Value *Leader, const Expression *E) - : ID(ID), RepLeader(Leader), DefiningExpr(E) {} }; namespace llvm { +struct ExactEqualsExpression { + const Expression &E; + explicit ExactEqualsExpression(const Expression &E) : E(E) {} + hash_code getComputedHash() const { return E.getComputedHash(); } + bool operator==(const Expression &Other) const { + return E.exactlyEquals(Other); + } +}; + template <> struct DenseMapInfo<const Expression *> { static const Expression *getEmptyKey() { auto Val = static_cast<uintptr_t>(-1); @@ -166,51 +398,144 @@ template <> struct DenseMapInfo<const Expression *> { Val <<= PointerLikeTypeTraits<const Expression *>::NumLowBitsAvailable; return reinterpret_cast<const Expression *>(Val); } - static unsigned getHashValue(const Expression *V) { - return static_cast<unsigned>(V->getHashValue()); + static unsigned getHashValue(const Expression *E) { + return E->getComputedHash(); } + static unsigned getHashValue(const ExactEqualsExpression &E) { + return E.getComputedHash(); + } + static bool isEqual(const ExactEqualsExpression &LHS, const Expression *RHS) { + if (RHS == getTombstoneKey() || RHS == getEmptyKey()) + return false; + return LHS == *RHS; + } + static bool isEqual(const Expression *LHS, const Expression *RHS) { if (LHS == RHS) return true; if (LHS == getTombstoneKey() || RHS == getTombstoneKey() || LHS == getEmptyKey() || RHS == getEmptyKey()) return false; + // Compare hashes before equality. This is *not* what the hashtable does, + // since it is computing it modulo the number of buckets, whereas we are + // using the full hash keyspace. Since the hashes are precomputed, this + // check is *much* faster than equality. + if (LHS->getComputedHash() != RHS->getComputedHash()) + return false; return *LHS == *RHS; } }; } // end namespace llvm -class NewGVN : public FunctionPass { +namespace { +class NewGVN { + Function &F; DominatorTree *DT; - const DataLayout *DL; const TargetLibraryInfo *TLI; - AssumptionCache *AC; AliasAnalysis *AA; MemorySSA *MSSA; MemorySSAWalker *MSSAWalker; - BumpPtrAllocator ExpressionAllocator; - ArrayRecycler<Value *> ArgRecycler; + const DataLayout &DL; + std::unique_ptr<PredicateInfo> PredInfo; + + // These are the only two things the create* functions should have + // side-effects on due to allocating memory. 
+ mutable BumpPtrAllocator ExpressionAllocator; + mutable ArrayRecycler<Value *> ArgRecycler; + mutable TarjanSCC SCCFinder; + const SimplifyQuery SQ; + + // Number of function arguments, used by ranking + unsigned int NumFuncArgs; + + // RPOOrdering of basic blocks + DenseMap<const DomTreeNode *, unsigned> RPOOrdering; // Congruence class info. - CongruenceClass *InitialClass; + + // This class is called INITIAL in the paper. It is the class everything + // startsout in, and represents any value. Being an optimistic analysis, + // anything in the TOP class has the value TOP, which is indeterminate and + // equivalent to everything. + CongruenceClass *TOPClass; std::vector<CongruenceClass *> CongruenceClasses; unsigned NextCongruenceNum; // Value Mappings. DenseMap<Value *, CongruenceClass *> ValueToClass; DenseMap<Value *, const Expression *> ValueToExpression; + // Value PHI handling, used to make equivalence between phi(op, op) and + // op(phi, phi). + // These mappings just store various data that would normally be part of the + // IR. + DenseSet<const Instruction *> PHINodeUses; + // Map a temporary instruction we created to a parent block. + DenseMap<const Value *, BasicBlock *> TempToBlock; + // Map between the temporary phis we created and the real instructions they + // are known equivalent to. + DenseMap<const Value *, PHINode *> RealToTemp; + // In order to know when we should re-process instructions that have + // phi-of-ops, we track the set of expressions that they needed as + // leaders. When we discover new leaders for those expressions, we process the + // associated phi-of-op instructions again in case they have changed. The + // other way they may change is if they had leaders, and those leaders + // disappear. However, at the point they have leaders, there are uses of the + // relevant operands in the created phi node, and so they will get reprocessed + // through the normal user marking we perform. + mutable DenseMap<const Value *, SmallPtrSet<Value *, 2>> AdditionalUsers; + DenseMap<const Expression *, SmallPtrSet<Instruction *, 2>> + ExpressionToPhiOfOps; + // Map from basic block to the temporary operations we created + DenseMap<const BasicBlock *, SmallVector<PHINode *, 8>> PHIOfOpsPHIs; + // Map from temporary operation to MemoryAccess. + DenseMap<const Instruction *, MemoryUseOrDef *> TempToMemory; + // Set of all temporary instructions we created. + DenseSet<Instruction *> AllTempInstructions; + + // Mapping from predicate info we used to the instructions we used it with. + // In order to correctly ensure propagation, we must keep track of what + // comparisons we used, so that when the values of the comparisons change, we + // propagate the information to the places we used the comparison. + mutable DenseMap<const Value *, SmallPtrSet<Instruction *, 2>> + PredicateToUsers; + // the same reasoning as PredicateToUsers. When we skip MemoryAccesses for + // stores, we no longer can rely solely on the def-use chains of MemorySSA. + mutable DenseMap<const MemoryAccess *, SmallPtrSet<MemoryAccess *, 2>> + MemoryToUsers; // A table storing which memorydefs/phis represent a memory state provably // equivalent to another memory state. // We could use the congruence class machinery, but the MemoryAccess's are // abstract memory states, so they can only ever be equivalent to each other, // and not to constants, etc. 
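A hedged illustration of what memory-state equivalence means here (schematic, not from the patch): two MemoryDefs that provably produce the same memory state can share a class, so anything keyed on either defining access value-numbers alike.

  store i32 %v, i32* %p   ; MemoryDef(1)
  store i32 %v, i32* %p   ; MemoryDef(2), redundant: same state as (1)
  ; both defs may live in one memory congruence class, but neither could
  ; ever be congruent to a non-memory value such as a constant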
-  DenseMap<const MemoryAccess *, MemoryAccess *> MemoryAccessEquiv;
-
+  DenseMap<const MemoryAccess *, CongruenceClass *> MemoryAccessToClass;
+
+  // We could, if we wanted, build MemoryPhiExpressions and
+  // MemoryVariableExpressions, etc, and value number them the same way we
+  // value number phi expressions. For the moment, this seems like overkill.
+  // They can only exist in one of three states: they can be TOP (equal to
+  // everything), Equivalent to something else, or unique. Because we do not
+  // create expressions for them, we need to simulate leader change not just
+  // when they change class, but when they change state. Note: We can do the
+  // same thing for phis, and avoid having phi expressions if we wanted. We
+  // should eventually unify in one direction or the other, so this is a
+  // little bit of an experiment to see which turns out easier to maintain.
+  enum MemoryPhiState { MPS_Invalid, MPS_TOP, MPS_Equivalent, MPS_Unique };
+  DenseMap<const MemoryPhi *, MemoryPhiState> MemoryPhiState;
+
+  enum InstCycleState { ICS_Unknown, ICS_CycleFree, ICS_Cycle };
+  mutable DenseMap<const Instruction *, InstCycleState> InstCycleState;

   // Expression to class mapping.
   using ExpressionClassMap = DenseMap<const Expression *, CongruenceClass *>;
   ExpressionClassMap ExpressionToClass;

+  // We keep a single expression that represents currently dead expressions.
+  // For dead expressions we can prove will stay dead, we mark them with DFS
+  // number zero. However, it's possible in the case of phi nodes for us to
+  // assume/prove all arguments are dead during fixpointing. We use
+  // DeadExpression for that case.
+  DeadExpression *SingletonDeadExpression = nullptr;
+
   // Which values have changed as a result of leader changes.
   SmallPtrSet<Value *, 8> LeaderChanges;
@@ -231,8 +556,6 @@ class NewGVN : public FunctionPass {
   BitVector TouchedInstructions;

   DenseMap<const BasicBlock *, std::pair<unsigned, unsigned>> BlockInstRange;

-  DenseMap<const DomTreeNode *, std::pair<unsigned, unsigned>>
-      DominatedInstRange;

 #ifndef NDEBUG
   // Debugging for how many times each block and instruction got processed.
@@ -240,56 +563,47 @@ class NewGVN : public FunctionPass {
 #endif

   // DFS info.
-  DenseMap<const BasicBlock *, std::pair<int, int>> DFSDomMap;
+  // This contains a mapping from Instructions to DFS numbers.
+  // The numbering starts at 1. An instruction with DFS number zero
+  // means that the instruction is dead.
   DenseMap<const Value *, unsigned> InstrDFS;
+
+  // This contains the mapping from DFS numbers to instructions.
   SmallVector<Value *, 32> DFSToInstr;

   // Deletion info.
   SmallPtrSet<Instruction *, 8> InstructionsToErase;

 public:
-  static char ID; // Pass identification, replacement for typeid.
-  NewGVN() : FunctionPass(ID) {
-    initializeNewGVNPass(*PassRegistry::getPassRegistry());
+  NewGVN(Function &F, DominatorTree *DT, AssumptionCache *AC,
+         TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA,
+         const DataLayout &DL)
+      : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL),
+        PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)), SQ(DL, TLI, DT, AC) {
   }
-
-  bool runOnFunction(Function &F) override;
-  bool runGVN(Function &F, DominatorTree *DT, AssumptionCache *AC,
-              TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA);
+  bool runGVN();

 private:
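The TOP class is the optimistic half of the partition refinement NewGVN performs: everything starts congruent, and classes are split apart as evidence accumulates, until a fixed point. A toy standalone model of that fixpoint loop (illustrative types and encoding; none of NewGVN's leader or touched-list machinery):

// Toy optimistic value numbering: start all values in TOP (class 0) and
// repeatedly re-key each value by (opcode, operand classes) until stable.
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Inst { std::string Op; std::vector<int> Operands; }; // operand = index

int main() {
  // v0 = x, v1 = x, v2 = add v0 v0, v3 = add v1 v1
  std::vector<Inst> Insts = {
      {"x", {}}, {"x", {}}, {"add", {0, 0}}, {"add", {1, 1}}};
  std::vector<int> Class(Insts.size(), 0); // everything starts in TOP
  for (bool Changed = true; Changed;) {
    Changed = false;
    std::map<std::pair<std::string, std::vector<int>>, int> ExprToClass;
    int NextClass = 1;
    for (size_t I = 0; I < Insts.size(); ++I) {
      std::vector<int> OpClasses;
      for (int O : Insts[I].Operands)
        OpClasses.push_back(Class[O]);
      auto It = ExprToClass.emplace(
          std::make_pair(Insts[I].Op, OpClasses), NextClass);
      if (It.second)
        ++NextClass;
      if (Class[I] != It.first->second) {
        Class[I] = It.first->second;
        Changed = true;
      }
    }
  }
  std::printf("v2 and v3 congruent: %s\n", Class[2] == Class[3] ? "yes" : "no");
}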
-  // This transformation requires dominator postdominator info.
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<AssumptionCacheTracker>();
-    AU.addRequired<DominatorTreeWrapperPass>();
-    AU.addRequired<TargetLibraryInfoWrapperPass>();
-    AU.addRequired<MemorySSAWrapperPass>();
-    AU.addRequired<AAResultsWrapperPass>();
-
-    AU.addPreserved<DominatorTreeWrapperPass>();
-    AU.addPreserved<GlobalsAAWrapperPass>();
-  }

   // Expression handling.
-  const Expression *createExpression(Instruction *, const BasicBlock *);
-  const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *,
-                                           const BasicBlock *);
-  PHIExpression *createPHIExpression(Instruction *);
-  const VariableExpression *createVariableExpression(Value *);
-  const ConstantExpression *createConstantExpression(Constant *);
-  const Expression *createVariableOrConstant(Value *V, const BasicBlock *B);
-  const UnknownExpression *createUnknownExpression(Instruction *);
-  const StoreExpression *createStoreExpression(StoreInst *, MemoryAccess *,
-                                               const BasicBlock *);
+  const Expression *createExpression(Instruction *) const;
+  const Expression *createBinaryExpression(unsigned, Type *, Value *,
+                                           Value *) const;
+  PHIExpression *createPHIExpression(Instruction *, bool &HasBackedge,
+                                     bool &OriginalOpsConstant) const;
+  const DeadExpression *createDeadExpression() const;
+  const VariableExpression *createVariableExpression(Value *) const;
+  const ConstantExpression *createConstantExpression(Constant *) const;
+  const Expression *createVariableOrConstant(Value *V) const;
+  const UnknownExpression *createUnknownExpression(Instruction *) const;
+  const StoreExpression *createStoreExpression(StoreInst *,
+                                               const MemoryAccess *) const;
   LoadExpression *createLoadExpression(Type *, Value *, LoadInst *,
-                                       MemoryAccess *, const BasicBlock *);
-
-  const CallExpression *createCallExpression(CallInst *, MemoryAccess *,
-                                             const BasicBlock *);
+                                       const MemoryAccess *) const;
+  const CallExpression *createCallExpression(CallInst *,
+                                             const MemoryAccess *) const;
   const AggregateValueExpression *
-  createAggregateValueExpression(Instruction *, const BasicBlock *);
-  bool setBasicExpressionInfo(Instruction *, BasicExpression *,
-                              const BasicBlock *);
+  createAggregateValueExpression(Instruction *) const;
+  bool setBasicExpressionInfo(Instruction *, BasicExpression *) const;

   // Congruence class handling.
   CongruenceClass *createCongruenceClass(Value *Leader, const Expression *E) {
@@ -298,13 +612,28 @@ private:
     return result;
   }

+  CongruenceClass *createMemoryClass(MemoryAccess *MA) {
+    auto *CC = createCongruenceClass(nullptr, nullptr);
+    CC->setMemoryLeader(MA);
+    return CC;
+  }
+  CongruenceClass *ensureLeaderOfMemoryClass(MemoryAccess *MA) {
+    auto *CC = getMemoryClass(MA);
+    if (CC->getMemoryLeader() != MA)
+      CC = createMemoryClass(MA);
+    return CC;
+  }
+
   CongruenceClass *createSingletonCongruenceClass(Value *Member) {
     CongruenceClass *CClass = createCongruenceClass(Member, nullptr);
-    CClass->Members.insert(Member);
+    CClass->insert(Member);
     ValueToClass[Member] = CClass;
     return CClass;
   }

   void initializeCongruenceClasses(Function &F);
+  const Expression *makePossiblePhiOfOps(Instruction *,
+                                         SmallPtrSetImpl<Value *> &);
+  void addPhiOfOps(PHINode *Op, BasicBlock *BB, Instruction *ExistingValue);

   // Value number an Instruction or MemoryPhi.
   void valueNumberMemoryPhi(MemoryPhi *);
@@ -312,78 +641,128 @@

   // Symbolic evaluation.
   const Expression *checkSimplificationResults(Expression *, Instruction *,
-                                               Value *);
-  const Expression *performSymbolicEvaluation(Value *, const BasicBlock *);
-  const Expression *performSymbolicLoadEvaluation(Instruction *,
-                                                  const BasicBlock *);
-  const Expression *performSymbolicStoreEvaluation(Instruction *,
-                                                   const BasicBlock *);
-  const Expression *performSymbolicCallEvaluation(Instruction *,
-                                                  const BasicBlock *);
-  const Expression *performSymbolicPHIEvaluation(Instruction *,
-                                                 const BasicBlock *);
-  bool setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To);
-  const Expression *performSymbolicAggrValueEvaluation(Instruction *,
-                                                       const BasicBlock *);
+                                               Value *) const;
+  const Expression *performSymbolicEvaluation(Value *,
+                                              SmallPtrSetImpl<Value *> &) const;
+  const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *,
+                                                Instruction *,
+                                                MemoryAccess *) const;
+  const Expression *performSymbolicLoadEvaluation(Instruction *) const;
+  const Expression *performSymbolicStoreEvaluation(Instruction *) const;
+  const Expression *performSymbolicCallEvaluation(Instruction *) const;
+  const Expression *performSymbolicPHIEvaluation(Instruction *) const;
+  const Expression *performSymbolicAggrValueEvaluation(Instruction *) const;
+  const Expression *performSymbolicCmpEvaluation(Instruction *) const;
+  const Expression *performSymbolicPredicateInfoEvaluation(Instruction *) const;

   // Congruence finding.
-  // Templated to allow them to work both on BB's and BB-edges.
-  template <class T>
-  Value *lookupOperandLeader(Value *, const User *, const T &) const;
+  bool someEquivalentDominates(const Instruction *, const Instruction *) const;
+  Value *lookupOperandLeader(Value *) const;
   void performCongruenceFinding(Instruction *, const Expression *);
-  void moveValueToNewCongruenceClass(Instruction *, CongruenceClass *,
-                                     CongruenceClass *);
+  void moveValueToNewCongruenceClass(Instruction *, const Expression *,
+                                     CongruenceClass *, CongruenceClass *);
+  void moveMemoryToNewCongruenceClass(Instruction *, MemoryAccess *,
+                                      CongruenceClass *, CongruenceClass *);
+  Value *getNextValueLeader(CongruenceClass *) const;
+  const MemoryAccess *getNextMemoryLeader(CongruenceClass *) const;
+  bool setMemoryClass(const MemoryAccess *From, CongruenceClass *To);
+  CongruenceClass *getMemoryClass(const MemoryAccess *MA) const;
+  const MemoryAccess *lookupMemoryLeader(const MemoryAccess *) const;
+  bool isMemoryAccessTOP(const MemoryAccess *) const;
+
+  // Ranking
+  unsigned int getRank(const Value *) const;
+  bool shouldSwapOperands(const Value *, const Value *) const;
+
   // Reachability handling.
   void updateReachableEdge(BasicBlock *, BasicBlock *);
   void processOutgoingEdges(TerminatorInst *, BasicBlock *);
-  bool isOnlyReachableViaThisEdge(const BasicBlockEdge &) const;
-  Value *findConditionEquivalence(Value *, BasicBlock *) const;
-  MemoryAccess *lookupMemoryAccessEquiv(MemoryAccess *) const;
+  Value *findConditionEquivalence(Value *) const;

   // Elimination.
   struct ValueDFS;
-  void convertDenseToDFSOrdered(CongruenceClass::MemberSet &,
-                                SmallVectorImpl<ValueDFS> &);
+  void convertClassToDFSOrdered(const CongruenceClass &,
+                                SmallVectorImpl<ValueDFS> &,
+                                DenseMap<const Value *, unsigned int> &,
+                                SmallPtrSetImpl<Instruction *> &) const;
+  void convertClassToLoadsAndStores(const CongruenceClass &,
+                                    SmallVectorImpl<ValueDFS> &) const;
   bool eliminateInstructions(Function &);
   void replaceInstruction(Instruction *, Value *);
   void markInstructionForDeletion(Instruction *);
   void deleteInstructionsInBlock(BasicBlock *);
+  Value *findPhiOfOpsLeader(const Expression *E, const BasicBlock *BB) const;

   // New instruction creation.
   void handleNewInstruction(Instruction *){};

   // Various instruction touch utilities
+  template <typename Map, typename KeyType, typename Func>
+  void for_each_found(Map &, const KeyType &, Func);
+  template <typename Map, typename KeyType>
+  void touchAndErase(Map &, const KeyType &);
   void markUsersTouched(Value *);
-  void markMemoryUsersTouched(MemoryAccess *);
-  void markLeaderChangeTouched(CongruenceClass *CC);
+  void markMemoryUsersTouched(const MemoryAccess *);
+  void markMemoryDefTouched(const MemoryAccess *);
+  void markPredicateUsersTouched(Instruction *);
+  void markValueLeaderChangeTouched(CongruenceClass *CC);
+  void markMemoryLeaderChangeTouched(CongruenceClass *CC);
+  void markPhiOfOpsChanged(const Expression *E);
+  void addPredicateUsers(const PredicateBase *, Instruction *) const;
+  void addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const;
+  void addAdditionalUsers(Value *To, Value *User) const;
+
+  // Main loop of value numbering
+  void iterateTouchedInstructions();

   // Utilities.
   void cleanupTables();
   std::pair<unsigned, unsigned> assignDFSNumbers(BasicBlock *, unsigned);
-  void updateProcessedCount(Value *V);
+  void updateProcessedCount(const Value *V);
   void verifyMemoryCongruency() const;
-  bool singleReachablePHIPath(const MemoryAccess *, const MemoryAccess *) const;
-};
-
-char NewGVN::ID = 0;
+  void verifyIterationSettled(Function &F);
+  void verifyStoreExpressions() const;
+  bool singleReachablePHIPath(SmallPtrSet<const MemoryAccess *, 8> &,
+                              const MemoryAccess *, const MemoryAccess *) const;
+  BasicBlock *getBlockForValue(Value *V) const;
+  void deleteExpression(const Expression *E) const;
+  MemoryUseOrDef *getMemoryAccess(const Instruction *) const;
+  MemoryAccess *getDefiningAccess(const MemoryAccess *) const;
+  MemoryPhi *getMemoryAccess(const BasicBlock *) const;
+  template <class T, class Range> T *getMinDFSOfRange(const Range &) const;
+  unsigned InstrToDFSNum(const Value *V) const {
+    assert(isa<Instruction>(V) &&
+           "This should not be used for MemoryAccesses");
+    return InstrDFS.lookup(V);
+  }

-// createGVNPass - The public interface to this file.
-FunctionPass *llvm::createNewGVNPass() { return new NewGVN(); }
+  unsigned InstrToDFSNum(const MemoryAccess *MA) const {
+    return MemoryToDFSNum(MA);
+  }
+
+  Value *InstrFromDFSNum(unsigned DFSNum) { return DFSToInstr[DFSNum]; }
+
+  // Given a MemoryAccess, return the relevant instruction DFS number. Note:
+  // This deliberately takes a value so it can be used with Use's, which will
+  // auto-convert to Value's but not to MemoryAccess's.
+  unsigned MemoryToDFSNum(const Value *MA) const {
+    assert(isa<MemoryAccess>(MA) &&
+           "This should not be used with instructions");
+    return isa<MemoryUseOrDef>(MA)
+               ? InstrToDFSNum(cast<MemoryUseOrDef>(MA)->getMemoryInst())
+               : InstrDFS.lookup(MA);
+  }
+  bool isCycleFree(const Instruction *) const;
+  bool isBackedge(BasicBlock *From, BasicBlock *To) const;
+  // Debug counter info. When verifying, we have to reset the value numbering
+  // debug counter to the same state it started in to get the same results.
+  std::pair<int, int> StartingVNCounter;
+};
+} // end anonymous namespace

 template <typename T>
 static bool equalsLoadStoreHelper(const T &LHS, const Expression &RHS) {
-  if ((!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS)) ||
-      !LHS.BasicExpression::equals(RHS)) {
+  if (!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS))
     return false;
-  } else if (const auto *L = dyn_cast<LoadExpression>(&RHS)) {
-    if (LHS.getDefiningAccess() != L->getDefiningAccess())
-      return false;
-  } else if (const auto *S = dyn_cast<StoreExpression>(&RHS)) {
-    if (LHS.getDefiningAccess() != S->getDefiningAccess())
-      return false;
-  }
-  return true;
+  return LHS.MemoryExpression::equals(RHS);
 }

 bool LoadExpression::equals(const Expression &Other) const {
@@ -391,7 +770,22 @@ bool LoadExpression::equals(const Expression &Other) const {
 }

 bool StoreExpression::equals(const Expression &Other) const {
-  return equalsLoadStoreHelper(*this, Other);
+  if (!equalsLoadStoreHelper(*this, Other))
+    return false;
+  // Make sure that store vs store includes the value operand.
+  if (const auto *S = dyn_cast<StoreExpression>(&Other))
+    if (getStoredValue() != S->getStoredValue())
+      return false;
+  return true;
+}
+
+// Determine if the edge From->To is a backedge
+bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const {
+  if (From == To)
+    return true;
+  auto *FromDTN = DT->getNode(From);
+  auto *ToDTN = DT->getNode(To);
+  return RPOOrdering.lookup(FromDTN) >= RPOOrdering.lookup(ToDTN);
 }

 #ifndef NDEBUG
@@ -400,17 +794,45 @@ static std::string getBlockName(const BasicBlock *B) {
 }
 #endif

-INITIALIZE_PASS_BEGIN(NewGVN, "newgvn", "Global Value Numbering", false, false)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_END(NewGVN, "newgvn", "Global Value Numbering", false, false)
+// Get a MemoryAccess for an instruction, fake or real.
+MemoryUseOrDef *NewGVN::getMemoryAccess(const Instruction *I) const {
+  auto *Result = MSSA->getMemoryAccess(I);
+  return Result ? Result : TempToMemory.lookup(I);
+}

-PHIExpression *NewGVN::createPHIExpression(Instruction *I) {
-  BasicBlock *PHIBlock = I->getParent();
+// Get a MemoryPhi for a basic block. These are all real.
+MemoryPhi *NewGVN::getMemoryAccess(const BasicBlock *BB) const {
+  return MSSA->getMemoryAccess(BB);
+}
+
+// Get the basic block from an instruction/memory value.
+BasicBlock *NewGVN::getBlockForValue(Value *V) const {
+  if (auto *I = dyn_cast<Instruction>(V)) {
+    auto *Parent = I->getParent();
+    if (Parent)
+      return Parent;
+    Parent = TempToBlock.lookup(V);
+    assert(Parent && "Every fake instruction should have a block");
+    return Parent;
+  }
+
+  auto *MP = dyn_cast<MemoryPhi>(V);
+  assert(MP && "Should have been an instruction or a MemoryPhi");
+  return MP->getBlock();
+}
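isBackedge above reduces backedge detection to a comparison of reverse-post-order (RPO) numbers: an edge From->To is treated as a backedge when From's RPO number is >= To's. A standalone sketch of the same idea on a small adjacency-list CFG (illustrative graph representation, not LLVM's):

// Compute RPO numbers via DFS postorder, then test edges against them.
#include <cstdio>
#include <vector>

static void postorder(int N, const std::vector<std::vector<int>> &Succ,
                      std::vector<bool> &Seen, std::vector<int> &Out) {
  Seen[N] = true;
  for (int S : Succ[N])
    if (!Seen[S])
      postorder(S, Succ, Seen, Out);
  Out.push_back(N);
}

int main() {
  // CFG: 0 -> 1 -> 2 -> 1 (loop), 2 -> 3
  std::vector<std::vector<int>> Succ = {{1}, {2}, {1, 3}, {}};
  std::vector<bool> Seen(4);
  std::vector<int> PO;
  postorder(0, Succ, Seen, PO);
  std::vector<int> RPO(4);
  for (size_t I = 0; I < PO.size(); ++I)
    RPO[PO[PO.size() - 1 - I]] = I; // reverse the postorder
  std::printf("2->1 backedge: %s\n", RPO[2] >= RPO[1] ? "yes" : "no");
  std::printf("1->2 backedge: %s\n", RPO[1] >= RPO[2] ? "yes" : "no");
}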
+
+// Delete a definitely dead expression, so it can be reused by the expression
+// allocator. Some of these are not in creation functions, so we have to
+// accept const versions.
+void NewGVN::deleteExpression(const Expression *E) const {
+  assert(isa<BasicExpression>(E));
+  auto *BE = cast<BasicExpression>(E);
+  const_cast<BasicExpression *>(BE)->deallocateOperands(ArgRecycler);
+  ExpressionAllocator.Deallocate(E);
+}
+
+PHIExpression *NewGVN::createPHIExpression(Instruction *I, bool &HasBackedge,
+                                           bool &OriginalOpsConstant) const {
+  BasicBlock *PHIBlock = getBlockForValue(I);
   auto *PN = cast<PHINode>(I);
   auto *E =
       new (ExpressionAllocator) PHIExpression(PN->getNumOperands(), PHIBlock);
@@ -419,28 +841,47 @@ PHIExpression *NewGVN::createPHIExpression(Instruction *I) {
   E->setType(I->getType());
   E->setOpcode(I->getOpcode());

-  auto ReachablePhiArg = [&](const Use &U) {
-    return ReachableBlocks.count(PN->getIncomingBlock(U));
-  };
-
-  // Filter out unreachable operands
-  auto Filtered = make_filter_range(PN->operands(), ReachablePhiArg);
-
+  // NewGVN assumes the operands of a PHI node are in a consistent order
+  // across PHIs. LLVM doesn't seem to always guarantee this. While we need
+  // to fix this in LLVM at some point, we don't want GVN to find wrong
+  // congruences. Therefore, here we sort uses in predecessor order.
+  // We're sorting the values by pointer. In theory this might be a source of
+  // non-determinism, but here we don't rely on the ordering for anything
+  // significant, e.g. we don't create new instructions based on it so we're
+  // fine.
+  SmallVector<const Use *, 4> PHIOperands;
+  for (const Use &U : PN->operands())
+    PHIOperands.push_back(&U);
+  std::sort(PHIOperands.begin(), PHIOperands.end(),
+            [&](const Use *U1, const Use *U2) {
+              return PN->getIncomingBlock(*U1) < PN->getIncomingBlock(*U2);
+            });
+
+  // Filter out unreachable phi operands.
+  auto Filtered = make_filter_range(PHIOperands, [&](const Use *U) {
+    if (*U == PN)
+      return false;
+    if (!ReachableEdges.count({PN->getIncomingBlock(*U), PHIBlock}))
+      return false;
+    // Things in TOPClass are equivalent to everything.
+    if (ValueToClass.lookup(*U) == TOPClass)
+      return false;
+    return lookupOperandLeader(*U) != PN;
+  });
   std::transform(Filtered.begin(), Filtered.end(), op_inserter(E),
-                 [&](const Use &U) -> Value * {
-                   // Don't try to transform self-defined phis.
-                   if (U == PN)
-                     return PN;
-                   const BasicBlockEdge BBE(PN->getIncomingBlock(U), PHIBlock);
-                   return lookupOperandLeader(U, I, BBE);
+                 [&](const Use *U) -> Value * {
+                   auto *BB = PN->getIncomingBlock(*U);
+                   HasBackedge = HasBackedge || isBackedge(BB, PHIBlock);
+                   OriginalOpsConstant =
+                       OriginalOpsConstant && isa<Constant>(*U);
+                   return lookupOperandLeader(*U);
                 });
   return E;
 }
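Sorting phi operands by incoming block, as createPHIExpression now does, makes phi expressions canonical: two phis over the same incoming values, written with their edges in different orders, produce identical expressions. A standalone sketch, with integer block ids standing in for the BasicBlock pointers used above:

// Canonicalize phi operands by their incoming-block key.
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using PhiOperand = std::pair<int, int>; // (incoming block id, value id)

static std::vector<PhiOperand> canonicalize(std::vector<PhiOperand> Ops) {
  std::sort(Ops.begin(), Ops.end(),
            [](const PhiOperand &A, const PhiOperand &B) {
              return A.first < B.first; // order by incoming block only
            });
  return Ops;
}

int main() {
  // The same phi written with its incoming edges in two different orders.
  std::vector<PhiOperand> P1 = {{1, 10}, {2, 20}};
  std::vector<PhiOperand> P2 = {{2, 20}, {1, 10}};
  assert(canonicalize(P1) == canonicalize(P2));
}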
 // Set basic expression info (Arguments, type, opcode) for Expression
 // E from Instruction I in block B.
-bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,
-                                    const BasicBlock *B) {
+bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E) const {
   bool AllConstant = true;
   if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
     E->setType(GEP->getSourceElementType());
@@ -452,8 +893,8 @@ bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,
   // Transform the operand array into an operand leader array, and keep track
   // of whether all members are constant.
   std::transform(I->op_begin(), I->op_end(), op_inserter(E), [&](Value *O) {
-    auto Operand = lookupOperandLeader(O, I, B);
-    AllConstant &= isa<Constant>(Operand);
+    auto Operand = lookupOperandLeader(O);
+    AllConstant = AllConstant && isa<Constant>(Operand);
     return Operand;
   });

@@ -461,8 +902,8 @@ bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E,
 }

 const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T,
-                                                 Value *Arg1, Value *Arg2,
-                                                 const BasicBlock *B) {
+                                                 Value *Arg1,
+                                                 Value *Arg2) const {
   auto *E = new (ExpressionAllocator) BasicExpression(2);

   E->setType(T);
@@ -473,14 +914,13 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T,
     // of their operands get the same value number by sorting the operand value
     // numbers. Since all commutative instructions have two operands it is more
    // efficient to sort by hand rather than using, say, std::sort.
-    if (Arg1 > Arg2)
+    if (shouldSwapOperands(Arg1, Arg2))
       std::swap(Arg1, Arg2);
   }
-  E->op_push_back(lookupOperandLeader(Arg1, nullptr, B));
-  E->op_push_back(lookupOperandLeader(Arg2, nullptr, B));
+  E->op_push_back(lookupOperandLeader(Arg1));
+  E->op_push_back(lookupOperandLeader(Arg2));

-  Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), *DL, TLI,
-                           DT, AC);
+  Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), SQ);
   if (const Expression *SimplifiedE = checkSimplificationResults(E, nullptr, V))
     return SimplifiedE;
   return E;
@@ -492,7 +932,8 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T,
 // TODO: Once finished, this should not take an Instruction, we only
 // use it for printing.
 const Expression *NewGVN::checkSimplificationResults(Expression *E,
-                                                     Instruction *I, Value *V) {
+                                                     Instruction *I,
+                                                     Value *V) const {
   if (!V)
     return nullptr;
   if (auto *C = dyn_cast<Constant>(V)) {
@@ -502,40 +943,40 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E,
     NumGVNOpsSimplified++;
     assert(isa<BasicExpression>(E) &&
            "We should always have had a basic expression here");
-
-    cast<BasicExpression>(E)->deallocateOperands(ArgRecycler);
-    ExpressionAllocator.Deallocate(E);
+    deleteExpression(E);
     return createConstantExpression(C);
   } else if (isa<Argument>(V) || isa<GlobalVariable>(V)) {
     if (I)
       DEBUG(dbgs() << "Simplified " << *I << " to "
                    << " variable " << *V << "\n");
-    cast<BasicExpression>(E)->deallocateOperands(ArgRecycler);
-    ExpressionAllocator.Deallocate(E);
+    deleteExpression(E);
     return createVariableExpression(V);
   }
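The shouldSwapOperands calls above replace raw pointer comparison with a semantic rank, so commutative operands land in a deterministic canonical order and a + b value-numbers together with b + a. A standalone sketch of one plausible ranking (the categories mirror the idea of ranking constants below arguments below instructions; the exact numbering is illustrative, not NewGVN's getRank):

// Canonical operand ordering for commutative operations.
#include <cassert>

enum Kind { Constant = 0, Argument = 1, Instruction = 2 };

struct Operand { Kind K; unsigned DFSNum; };

static unsigned rank(const Operand &Op) {
  // Lower ranks are "simpler" values; instructions are ranked by position.
  return Op.K == Instruction ? 3 + Op.DFSNum : Op.K;
}

static bool shouldSwapOperands(const Operand &A, const Operand &B) {
  // Canonical form: the lower-ranked operand goes first.
  return rank(A) > rank(B);
}

int main() {
  Operand C{Constant, 0}, X{Instruction, 7};
  // x + 5 and 5 + x both canonicalize to (5, x).
  assert(shouldSwapOperands(X, C) && !shouldSwapOperands(C, X));
}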
   CongruenceClass *CC = ValueToClass.lookup(V);
-  if (CC && CC->DefiningExpr) {
+  if (CC && CC->getDefiningExpr()) {
+    // If we simplified to something else, we need to communicate
+    // that we're users of the value we simplified to.
+    if (I != V) {
+      // Don't add temporary instructions to the user lists.
+      if (!AllTempInstructions.count(I))
+        addAdditionalUsers(V, I);
+    }
+
     if (I)
       DEBUG(dbgs() << "Simplified " << *I << " to "
-                   << " expression " << *V << "\n");
+                   << " expression " << *CC->getDefiningExpr() << "\n");
     NumGVNOpsSimplified++;
-    assert(isa<BasicExpression>(E) &&
-           "We should always have had a basic expression here");
-    cast<BasicExpression>(E)->deallocateOperands(ArgRecycler);
-    ExpressionAllocator.Deallocate(E);
-    return CC->DefiningExpr;
+    deleteExpression(E);
+    return CC->getDefiningExpr();
   }
   return nullptr;
 }

-const Expression *NewGVN::createExpression(Instruction *I,
-                                           const BasicBlock *B) {
-
+const Expression *NewGVN::createExpression(Instruction *I) const {
   auto *E =
       new (ExpressionAllocator) BasicExpression(I->getNumOperands());

-  bool AllConstant = setBasicExpressionInfo(I, E, B);
+  bool AllConstant = setBasicExpressionInfo(I, E);

   if (I->isCommutative()) {
     // Ensure that commutative instructions that only differ by a permutation
@@ -543,7 +984,7 @@ const Expression *NewGVN::createExpression(Instruction *I,
     // numbers. Since all commutative instructions have two operands it is more
     // efficient to sort by hand rather than using, say, std::sort.
     assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!");
-    if (E->getOperand(0) > E->getOperand(1))
+    if (shouldSwapOperands(E->getOperand(0), E->getOperand(1)))
       E->swapOperands(0, 1);
   }

@@ -559,48 +1000,43 @@ const Expression *NewGVN::createExpression(Instruction *I,
     // Sort the operand value numbers so x<y and y>x get the same value
     // number.
     CmpInst::Predicate Predicate = CI->getPredicate();
-    if (E->getOperand(0) > E->getOperand(1)) {
+    if (shouldSwapOperands(E->getOperand(0), E->getOperand(1))) {
       E->swapOperands(0, 1);
       Predicate = CmpInst::getSwappedPredicate(Predicate);
     }
     E->setOpcode((CI->getOpcode() << 8) | Predicate);
     // TODO: 25% of our time is spent in SimplifyCmpInst with pointer operands
-    // TODO: Since we noop bitcasts, we may need to check types before
-    // simplifying, so that we don't end up simplifying based on a wrong
-    // type assumption.
-    // We should clean this up so we can use constants of the
-    // wrong type
     assert(I->getOperand(0)->getType() == I->getOperand(1)->getType() &&
            "Wrong types on cmp instruction");
-    if ((E->getOperand(0)->getType() == I->getOperand(0)->getType() &&
-         E->getOperand(1)->getType() == I->getOperand(1)->getType())) {
-      Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1),
-                                 *DL, TLI, DT, AC);
-      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
-        return SimplifiedE;
-    }
+    assert((E->getOperand(0)->getType() == I->getOperand(0)->getType() &&
+            E->getOperand(1)->getType() == I->getOperand(1)->getType()));
+    Value *V =
+        SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), SQ);
+    if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
+      return SimplifiedE;
   } else if (isa<SelectInst>(I)) {
     if (isa<Constant>(E->getOperand(0)) ||
-        (E->getOperand(1)->getType() == I->getOperand(1)->getType() &&
-         E->getOperand(2)->getType() == I->getOperand(2)->getType())) {
+        E->getOperand(0) == E->getOperand(1)) {
+      assert(E->getOperand(1)->getType() == I->getOperand(1)->getType() &&
+             E->getOperand(2)->getType() == I->getOperand(2)->getType());
       Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1),
-                                    E->getOperand(2), *DL, TLI, DT, AC);
+                                    E->getOperand(2), SQ);
       if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
         return SimplifiedE;
     }
   } else if (I->isBinaryOp()) {
-    Value *V = SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1),
-                             *DL, TLI, DT, AC);
+    Value *V =
+        SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), SQ);
     if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
       return SimplifiedE;
   } else if (auto *BI = dyn_cast<BitCastInst>(I)) {
-    Value *V = SimplifyInstruction(BI, *DL, TLI, DT, AC);
+    Value *V =
+        SimplifyCastInst(BI->getOpcode(), BI->getOperand(0), BI->getType(), SQ);
     if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
       return SimplifiedE;
   } else if (isa<GetElementPtrInst>(I)) {
-    Value *V = SimplifyGEPInst(E->getType(),
-                               ArrayRef<Value *>(E->op_begin(), E->op_end()),
-                               *DL, TLI, DT, AC);
+    Value *V = SimplifyGEPInst(
+        E->getType(), ArrayRef<Value *>(E->op_begin(), E->op_end()), SQ);
     if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
       return SimplifiedE;
   } else if (AllConstant) {
@@ -615,7 +1051,7 @@ const Expression *NewGVN::createExpression(Instruction *I,
     for (Value *Arg : E->operands())
       C.emplace_back(cast<Constant>(Arg));

-    if (Value *V = ConstantFoldInstOperands(I, C, *DL, TLI))
+    if (Value *V = ConstantFoldInstOperands(I, C, DL, TLI))
       if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
         return SimplifiedE;
   }
@@ -623,18 +1059,18 @@ const Expression *NewGVN::createExpression(Instruction *I,
 }

 const AggregateValueExpression *
-NewGVN::createAggregateValueExpression(Instruction *I, const BasicBlock *B) {
+NewGVN::createAggregateValueExpression(Instruction *I) const {
   if (auto *II = dyn_cast<InsertValueInst>(I)) {
     auto *E = new (ExpressionAllocator)
         AggregateValueExpression(I->getNumOperands(), II->getNumIndices());
-    setBasicExpressionInfo(I, E, B);
+    setBasicExpressionInfo(I, E);
     E->allocateIntOperands(ExpressionAllocator);
     std::copy(II->idx_begin(), II->idx_end(), int_op_inserter(E));
     return E;
   } else if (auto *EI = dyn_cast<ExtractValueInst>(I)) {
     auto *E = new (ExpressionAllocator)
         AggregateValueExpression(I->getNumOperands(), EI->getNumIndices());
-    setBasicExpressionInfo(EI, E, B);
+    setBasicExpressionInfo(EI, E);
     E->allocateIntOperands(ExpressionAllocator);
     std::copy(EI->idx_begin(), EI->idx_end(), int_op_inserter(E));
     return E;
@@ -642,67 +1078,120 @@ NewGVN::createAggregateValueExpression(Instruction *I, const BasicBlock *B) {
   llvm_unreachable("Unhandled type of aggregate value operation");
 }

-const VariableExpression *NewGVN::createVariableExpression(Value *V) {
+const DeadExpression *NewGVN::createDeadExpression() const {
+  // DeadExpression has no arguments and all DeadExpression's are the same,
+  // so we only need one of them.
+  return SingletonDeadExpression;
+}
+
+const VariableExpression *NewGVN::createVariableExpression(Value *V) const {
   auto *E = new (ExpressionAllocator) VariableExpression(V);
   E->setOpcode(V->getValueID());
   return E;
 }

-const Expression *NewGVN::createVariableOrConstant(Value *V,
-                                                   const BasicBlock *B) {
-  auto Leader = lookupOperandLeader(V, nullptr, B);
-  if (auto *C = dyn_cast<Constant>(Leader))
+const Expression *NewGVN::createVariableOrConstant(Value *V) const {
+  if (auto *C = dyn_cast<Constant>(V))
     return createConstantExpression(C);
-  return createVariableExpression(Leader);
+  return createVariableExpression(V);
 }

-const ConstantExpression *NewGVN::createConstantExpression(Constant *C) {
+const ConstantExpression *NewGVN::createConstantExpression(Constant *C) const {
   auto *E = new (ExpressionAllocator) ConstantExpression(C);
   E->setOpcode(C->getValueID());
   return E;
 }

-const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) {
+const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) const {
   auto *E = new (ExpressionAllocator) UnknownExpression(I);
   E->setOpcode(I->getOpcode());
   return E;
 }

-const CallExpression *NewGVN::createCallExpression(CallInst *CI,
-                                                   MemoryAccess *HV,
-                                                   const BasicBlock *B) {
+const CallExpression *
+NewGVN::createCallExpression(CallInst *CI, const MemoryAccess *MA) const {
   // FIXME: Add operand bundles for calls.
   auto *E =
-      new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, HV);
-  setBasicExpressionInfo(CI, E, B);
+      new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, MA);
+  setBasicExpressionInfo(CI, E);
   return E;
 }

+// Return true if some equivalent of instruction Inst dominates instruction U.
+bool NewGVN::someEquivalentDominates(const Instruction *Inst,
+                                     const Instruction *U) const {
+  auto *CC = ValueToClass.lookup(Inst);
+  // This must be an instruction because we are only called from phi nodes
+  // in the case that the value it needs to check against is an instruction.
+
+  // The most likely candidates for dominance are the leader and the next
+  // leader. The leader or next leader will dominate in all cases where there
+  // is an equivalent that is higher up in the dom tree.
+  // We can't *only* check them, however, because the
+  // dominator tree could have an infinite number of non-dominating siblings
+  // with instructions that are in the right congruence class.
+  //       A
+  // B C D E F G
+  // |
+  // H
+  // Instruction U could be in H, with equivalents in every other sibling.
+  // Depending on the rpo order picked, the leader could be the equivalent in
+  // any of these siblings.
+  if (!CC)
+    return false;
+  if (DT->dominates(cast<Instruction>(CC->getLeader()), U))
+    return true;
+  if (CC->getNextLeader().first &&
+      DT->dominates(cast<Instruction>(CC->getNextLeader().first), U))
+    return true;
+  return llvm::any_of(*CC, [&](const Value *Member) {
+    return Member != CC->getLeader() &&
+           DT->dominates(cast<Instruction>(Member), U);
+  });
+}
+
 // See if we have a congruence class and leader for this operand, and if so,
 // return it. Otherwise, return the operand itself.
-template <class T>
-Value *NewGVN::lookupOperandLeader(Value *V, const User *U, const T &B) const {
+Value *NewGVN::lookupOperandLeader(Value *V) const {
   CongruenceClass *CC = ValueToClass.lookup(V);
-  if (CC && (CC != InitialClass))
-    return CC->RepLeader;
+  if (CC) {
+    // Everything in TOP is represented by undef, as it can be any value.
+    // We do have to make sure we get the type right though, so we can't set
+    // the RepLeader to undef.
+    if (CC == TOPClass)
+      return UndefValue::get(V->getType());
+    return CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader();
+  }
+
   return V;
 }

-MemoryAccess *NewGVN::lookupMemoryAccessEquiv(MemoryAccess *MA) const {
-  MemoryAccess *Result = MemoryAccessEquiv.lookup(MA);
-  return Result ? Result : MA;
+const MemoryAccess *NewGVN::lookupMemoryLeader(const MemoryAccess *MA) const {
+  auto *CC = getMemoryClass(MA);
+  assert(CC->getMemoryLeader() &&
+         "Every MemoryAccess should be mapped to a congruence class with a "
+         "representative memory access");
+  return CC->getMemoryLeader();
+}
+
+// Return true if the MemoryAccess is really equivalent to everything. This is
+// equivalent to the lattice value "TOP" in most lattices. This is the initial
+// state of all MemoryAccesses.
+bool NewGVN::isMemoryAccessTOP(const MemoryAccess *MA) const {
+  return getMemoryClass(MA) == TOPClass;
 }
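lookupOperandLeader above is the one place TOP leaks into operand values: members of TOP are reported as a typed undef, while ordinary classes report their leader (or, for store classes, the stored value). A standalone model of that lookup (illustrative types and string-named values, not LLVM's):

// Leader lookup with TOP-as-undef semantics.
#include <string>
#include <unordered_map>

struct CongClass {
  bool IsTop = false;
  std::string Leader;      // representative value for the class
  std::string StoredValue; // non-empty only for store classes
};

static std::string lookupOperandLeader(
    const std::string &V,
    const std::unordered_map<std::string, const CongClass *> &ValueToClass) {
  auto It = ValueToClass.find(V);
  if (It == ValueToClass.end())
    return V; // no class yet: the value leads itself
  const CongClass *CC = It->second;
  if (CC->IsTop)
    return "undef"; // TOP is equivalent to everything
  return CC->StoredValue.empty() ? CC->Leader : CC->StoredValue;
}

int main() {
  CongClass Top{true, "", ""}, Stored{false, "%s", "%v"};
  std::unordered_map<std::string, const CongClass *> M{{"%a", &Top},
                                                       {"%s2", &Stored}};
  // %a is in TOP -> undef; %s2 is a store class -> its stored value %v.
  return lookupOperandLeader("%a", M) == "undef" &&
                 lookupOperandLeader("%s2", M) == "%v" &&
                 lookupOperandLeader("%x", M) == "%x"
             ? 0
             : 1;
}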
 LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,
-                                             LoadInst *LI, MemoryAccess *DA,
-                                             const BasicBlock *B) {
-  auto *E = new (ExpressionAllocator) LoadExpression(1, LI, DA);
+                                             LoadInst *LI,
+                                             const MemoryAccess *MA) const {
+  auto *E =
+      new (ExpressionAllocator) LoadExpression(1, LI, lookupMemoryLeader(MA));
   E->allocateOperands(ArgRecycler, ExpressionAllocator);
   E->setType(LoadType);

   // Give store and loads same opcode so they value number together.
   E->setOpcode(0);
-  E->op_push_back(lookupOperandLeader(PointerOp, LI, B));
+  E->op_push_back(PointerOp);
   if (LI)
     E->setAlignment(LI->getAlignment());

@@ -712,17 +1201,17 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,
   return E;
 }

-const StoreExpression *NewGVN::createStoreExpression(StoreInst *SI,
-                                                     MemoryAccess *DA,
-                                                     const BasicBlock *B) {
-  auto *E =
-      new (ExpressionAllocator) StoreExpression(SI->getNumOperands(), SI, DA);
+const StoreExpression *
+NewGVN::createStoreExpression(StoreInst *SI, const MemoryAccess *MA) const {
+  auto *StoredValueLeader = lookupOperandLeader(SI->getValueOperand());
+  auto *E = new (ExpressionAllocator)
+      StoreExpression(SI->getNumOperands(), SI, StoredValueLeader, MA);
   E->allocateOperands(ArgRecycler, ExpressionAllocator);
   E->setType(SI->getValueOperand()->getType());

   // Give store and loads same opcode so they value number together.
   E->setOpcode(0);
-  E->op_push_back(lookupOperandLeader(SI->getPointerOperand(), SI, B));
+  E->op_push_back(lookupOperandLeader(SI->getPointerOperand()));

   // TODO: Value number heap versions. We may be able to discover
   // things alias analysis can't on its own (IE that a store and a
@@ -730,44 +1219,136 @@ const StoreExpression *NewGVN::createStoreExpression(StoreInst *SI,
   return E;
 }

-// Utility function to check whether the congruence class has a member other
-// than the given instruction.
-bool hasMemberOtherThanUs(const CongruenceClass *CC, Instruction *I) {
-  // Either it has more than one store, in which case it must contain something
-  // other than us (because it's indexed by value), or if it only has one store
-  // right now, that member should not be us.
-  return CC->StoreCount > 1 || CC->Members.count(I) == 0;
-}
-
-const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I,
-                                                         const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I) const {
   // Unlike loads, we never try to eliminate stores, so we do not check if they
   // are simple and avoid value numbering them.
   auto *SI = cast<StoreInst>(I);
-  MemoryAccess *StoreAccess = MSSA->getMemoryAccess(SI);
-  // See if we are defined by a previous store expression, it already has a
-  // value, and it's the same value as our current store. FIXME: Right now, we
-  // only do this for simple stores, we should expand to cover memcpys, etc.
+  auto *StoreAccess = getMemoryAccess(SI);
+  // Get the expression, if any, for the RHS of the MemoryDef.
+  const MemoryAccess *StoreRHS = StoreAccess->getDefiningAccess();
+  if (EnableStoreRefinement)
+    StoreRHS = MSSAWalker->getClobberingMemoryAccess(StoreAccess);
+  // If we bypassed the use-def chains, make sure we add a use.
+  if (StoreRHS != StoreAccess->getDefiningAccess())
+    addMemoryUsers(StoreRHS, StoreAccess);
+  StoreRHS = lookupMemoryLeader(StoreRHS);
+  // If we are defined by ourselves, use the live on entry def.
+  if (StoreRHS == StoreAccess)
+    StoreRHS = MSSA->getLiveOnEntryDef();
+
   if (SI->isSimple()) {
-    // Get the expression, if any, for the RHS of the MemoryDef.
-    MemoryAccess *StoreRHS = lookupMemoryAccessEquiv(
-        cast<MemoryDef>(StoreAccess)->getDefiningAccess());
-    const Expression *OldStore = createStoreExpression(SI, StoreRHS, B);
-    CongruenceClass *CC = ExpressionToClass.lookup(OldStore);
-    // Basically, check if the congruence class the store is in is defined by a
-    // store that isn't us, and has the same value. MemorySSA takes care of
-    // ensuring the store has the same memory state as us already.
-    if (CC && CC->DefiningExpr && isa<StoreExpression>(CC->DefiningExpr) &&
-        CC->RepLeader == lookupOperandLeader(SI->getValueOperand(), SI, B) &&
-        hasMemberOtherThanUs(CC, I))
-      return createStoreExpression(SI, StoreRHS, B);
+    // See if we are defined by a previous store expression, it already has a
+    // value, and it's the same value as our current store. FIXME: Right now,
+    // we only do this for simple stores, we should expand to cover memcpys,
+    // etc.
+    const auto *LastStore = createStoreExpression(SI, StoreRHS);
+    const auto *LastCC = ExpressionToClass.lookup(LastStore);
+    // We really want to check whether the expression we matched was a store.
+    // No easy way to do that. However, we can check that the class we found
+    // has a store, which, assuming the value numbering state is not corrupt,
+    // is sufficient, because we must also be equivalent to that store's
+    // expression for it to be in the same class as the load.
+    if (LastCC && LastCC->getStoredValue() == LastStore->getStoredValue())
+      return LastStore;
+    // Also check if our value operand is defined by a load of the same memory
+    // location, and the memory state is the same as it was then (otherwise,
+    // it could have been overwritten later. See test32 in
+    // transforms/DeadStoreElimination/simple.ll).
+    if (auto *LI = dyn_cast<LoadInst>(LastStore->getStoredValue()))
+      if ((lookupOperandLeader(LI->getPointerOperand()) ==
+           LastStore->getOperand(0)) &&
+          (lookupMemoryLeader(getMemoryAccess(LI)->getDefiningAccess()) ==
+           StoreRHS))
+        return LastStore;
+    deleteExpression(LastStore);
+  }
+
+  // If the store is not equivalent to anything, value number it as a store
+  // that produces a unique memory state (instead of using its MemoryUse, we
+  // use its MemoryDef).
+  return createStoreExpression(SI, StoreAccess);
+}
+
+// See if we can extract the value of a loaded pointer from a load, a store,
+// or a memory instruction.
+const Expression *
+NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr,
+                                    LoadInst *LI, Instruction *DepInst,
+                                    MemoryAccess *DefiningAccess) const {
+  assert((!LI || LI->isSimple()) && "Not a simple load");
+  if (auto *DepSI = dyn_cast<StoreInst>(DepInst)) {
+    // Can't forward from non-atomic to atomic without violating memory model.
+    // Also don't need to coerce if they are the same type, we will just
+    // propagate.
+    if (LI->isAtomic() > DepSI->isAtomic() ||
+        LoadType == DepSI->getValueOperand()->getType())
+      return nullptr;
+    int Offset = analyzeLoadFromClobberingStore(LoadType, LoadPtr, DepSI, DL);
+    if (Offset >= 0) {
+      if (auto *C = dyn_cast<Constant>(
+              lookupOperandLeader(DepSI->getValueOperand()))) {
+        DEBUG(dbgs() << "Coercing load from store " << *DepSI
+                     << " to constant " << *C << "\n");
+        return createConstantExpression(
+            getConstantStoreValueForLoad(C, Offset, LoadType, DL));
+      }
+    }

+  } else if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInst)) {
+    // Can't forward from non-atomic to atomic without violating memory model.
+    if (LI->isAtomic() > DepLI->isAtomic())
+      return nullptr;
+    int Offset = analyzeLoadFromClobberingLoad(LoadType, LoadPtr, DepLI, DL);
+    if (Offset >= 0) {
+      // We can coerce a constant load into a load.
+      if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI)))
+        if (auto *PossibleConstant =
+                getConstantLoadValueForLoad(C, Offset, LoadType, DL)) {
+          DEBUG(dbgs() << "Coercing load from load " << *LI << " to constant "
                       << *PossibleConstant << "\n");
+          return createConstantExpression(PossibleConstant);
+        }
+    }

+  } else if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInst)) {
+    int Offset = analyzeLoadFromClobberingMemInst(LoadType, LoadPtr, DepMI, DL);
+    if (Offset >= 0) {
+      if (auto *PossibleConstant =
+              getConstantMemInstValueForLoad(DepMI, Offset, LoadType, DL)) {
+        DEBUG(dbgs() << "Coercing load from meminst " << *DepMI
+                     << " to constant " << *PossibleConstant << "\n");
+        return createConstantExpression(PossibleConstant);
+      }
+    }
+  }
+
+  // All of the below are only true if the loaded pointer is produced
+  // by the dependent instruction.
+  if (LoadPtr != lookupOperandLeader(DepInst) &&
+      !AA->isMustAlias(LoadPtr, DepInst))
+    return nullptr;
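performSymbolicLoadCoercion above forwards a value from a clobbering store even when the load reads at a different type or offset, by carving the loaded bytes out of the stored constant. A standalone sketch of just the byte-carving step (little-endian byte order assumed; LLVM's getConstantStoreValueForLoad handles the general, layout-aware case):

// Store-to-load forwarding with coercion: read byte `Offset` of a stored i32.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint8_t coerceLoadFromStore(uint32_t StoredValue, unsigned Offset) {
  uint8_t Bytes[sizeof(StoredValue)];
  std::memcpy(Bytes, &StoredValue, sizeof(StoredValue));
  return Bytes[Offset]; // the i8 load at byte `Offset` of the stored i32
}

int main() {
  // store i32 0x11223344, i32* %p ;  %v = load i8 at byte offset 1 of %p
  std::printf("forwarded byte: 0x%x\n", coerceLoadFromStore(0x11223344, 1));
  // On a little-endian target this prints 0x33.
}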
+  // If this load really doesn't depend on anything, then we must be loading
+  // an undef value. This can happen when loading for a fresh allocation with
+  // no intervening stores, for example. Note that this is only true in the
+  // case that the result of the allocation is pointer equal to the load ptr.
+  if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) {
+    return createConstantExpression(UndefValue::get(LoadType));
+  }
+  // If this load occurs right after a lifetime begin,
+  // then the loaded value is undefined.
+  else if (auto *II = dyn_cast<IntrinsicInst>(DepInst)) {
+    if (II->getIntrinsicID() == Intrinsic::lifetime_start)
+      return createConstantExpression(UndefValue::get(LoadType));
+  }
+  // If this load follows a calloc (which zero initializes memory),
+  // then the loaded value is zero.
+  else if (isCallocLikeFn(DepInst, TLI)) {
+    return createConstantExpression(Constant::getNullValue(LoadType));
   }

-  return createStoreExpression(SI, StoreAccess, B);
+  return nullptr;
 }

-const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
-                                                        const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I) const {
   auto *LI = cast<LoadInst>(I);

   // We can eliminate in favor of non-simple loads, but we won't be able to
@@ -775,12 +1356,13 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
   if (!LI->isSimple())
     return nullptr;

-  Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand(), I, B);
+  Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand());
   // Load of undef is undef.
   if (isa<UndefValue>(LoadAddressLeader))
     return createConstantExpression(UndefValue::get(LI->getType()));
-
-  MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(I);
+  MemoryAccess *OriginalAccess = getMemoryAccess(I);
+  MemoryAccess *DefiningAccess =
+      MSSAWalker->getClobberingMemoryAccess(OriginalAccess);

   if (!MSSA->isLiveOnEntryDef(DefiningAccess)) {
     if (auto *MD = dyn_cast<MemoryDef>(DefiningAccess)) {
@@ -788,88 +1370,263 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I,
       // If the defining instruction is not reachable, replace with undef.
       if (!ReachableBlocks.count(DefiningInst->getParent()))
         return createConstantExpression(UndefValue::get(LI->getType()));
+      // This will handle stores and memory insts. We only do this if the
+      // defining access has a different type, or it is a pointer produced by
+      // certain memory operations that cause the memory to have a fixed value
+      // (IE things like calloc).
+      if (const auto *CoercionResult =
+              performSymbolicLoadCoercion(LI->getType(), LoadAddressLeader, LI,
+                                          DefiningInst, DefiningAccess))
+        return CoercionResult;
     }
   }

-  const Expression *E =
-      createLoadExpression(LI->getType(), LI->getPointerOperand(), LI,
-                           lookupMemoryAccessEquiv(DefiningAccess), B);
+  const Expression *E = createLoadExpression(LI->getType(), LoadAddressLeader,
+                                             LI, DefiningAccess);
   return E;
 }

+const Expression *
+NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const {
+  auto *PI = PredInfo->getPredicateInfoFor(I);
+  if (!PI)
+    return nullptr;
+
+  DEBUG(dbgs() << "Found predicate info from instruction!\n");
+
+  auto *PWC = dyn_cast<PredicateWithCondition>(PI);
+  if (!PWC)
+    return nullptr;
+
+  auto *CopyOf = I->getOperand(0);
+  auto *Cond = PWC->Condition;
+
+  // If this is a copy of the condition, it must be either true or false
+  // depending on the predicate info type and edge.
+  if (CopyOf == Cond) {
+    // We should not need to add predicate users because the predicate info is
+    // already a use of this operand.
+    if (isa<PredicateAssume>(PI))
+      return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+    if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+      if (PBranch->TrueEdge)
+        return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+      return createConstantExpression(ConstantInt::getFalse(Cond->getType()));
+    }
+    if (auto *PSwitch = dyn_cast<PredicateSwitch>(PI))
+      return createConstantExpression(cast<Constant>(PSwitch->CaseValue));
+  }
+
+  // Not a copy of the condition, so see what the predicates tell us about
+  // this value. First, though, we check to make sure the value is actually a
+  // copy of one of the condition operands. It's possible, in certain cases,
+  // for it to be a copy of a predicateinfo copy. In particular, if two branch
+  // operations use the same condition, and one branch dominates the other, we
+  // will end up with a copy of a copy. This is currently a small deficiency
+  // in predicateinfo. What will end up happening here is that we will value
+  // number both copies the same anyway.
+
+  // Everything below relies on the condition being a comparison.
+  auto *Cmp = dyn_cast<CmpInst>(Cond);
+  if (!Cmp)
+    return nullptr;
+
+  if (CopyOf != Cmp->getOperand(0) && CopyOf != Cmp->getOperand(1)) {
+    DEBUG(dbgs() << "Copy is not of any condition operands!\n");
+    return nullptr;
+  }
+  Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0));
+  Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1));
+  bool SwappedOps = false;
+  // Sort the ops.
+  if (shouldSwapOperands(FirstOp, SecondOp)) {
+    std::swap(FirstOp, SecondOp);
+    SwappedOps = true;
+  }
+  CmpInst::Predicate Predicate =
+      SwappedOps ? Cmp->getSwappedPredicate() : Cmp->getPredicate();
+
+  if (isa<PredicateAssume>(PI)) {
+    // If the comparison is true when the operands are equal, then we know
+    // the operands are equal, because assumes must always be true.
+    if (CmpInst::isTrueWhenEqual(Predicate)) {
+      addPredicateUsers(PI, I);
+      addAdditionalUsers(Cmp->getOperand(0), I);
+      return createVariableOrConstant(FirstOp);
+    }
+  }
+  if (const auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+    // If we are *not* a copy of the comparison, we may be equal to the other
+    // operand when the predicate implies something about equality of
+    // operations. In particular, if the comparison is true/false when the
+    // operands are equal, and we are on the right edge, we know this
+    // operation is equal to something.
+    if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) ||
+        (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) {
+      addPredicateUsers(PI, I);
+      addAdditionalUsers(Cmp->getOperand(0), I);
+      return createVariableOrConstant(FirstOp);
+    }
+    // Handle the special case of floating point.
+    if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) ||
+         (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) &&
+        isa<ConstantFP>(FirstOp) && !cast<ConstantFP>(FirstOp)->isZero()) {
+      addPredicateUsers(PI, I);
+      addAdditionalUsers(Cmp->getOperand(0), I);
+      return createConstantExpression(cast<Constant>(FirstOp));
+    }
+  }
+  return nullptr;
+}
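The equality rule above says: on the true edge of br (a == b), a renamed copy of one operand can be replaced by the other, because that edge is only reachable when the equality held (and symmetrically for the false edge of a != b). A standalone model of just that rule (plain ints stand in for value leaders; this is a simplification of the predicate-info machinery, not its API):

// Fold a renamed copy of a comparison operand on an equality-implying edge.
#include <cassert>

enum Pred { EQ, NE };

struct Branch { Pred P; int Op0, Op1; bool TrueEdge; };

// Returns the known-equal operand for CopyOf, or CopyOf itself if nothing
// is implied on this edge.
static int foldCopyOnEdge(const Branch &B, int CopyOf) {
  bool ImpliesEqual = (B.TrueEdge && B.P == EQ) || (!B.TrueEdge && B.P == NE);
  if (!ImpliesEqual)
    return CopyOf;
  if (CopyOf == B.Op0) return B.Op1;
  if (CopyOf == B.Op1) return B.Op0;
  return CopyOf;
}

int main() {
  Branch TrueEQ{EQ, /*a=*/1, /*b=*/2, /*TrueEdge=*/true};
  assert(foldCopyOnEdge(TrueEQ, 1) == 2); // on the eq edge, a == b
}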
 // Evaluate read only and pure calls, and create an expression result.
-const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I,
-                                                        const BasicBlock *B) {
+const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) const {
   auto *CI = cast<CallInst>(I);
-  if (AA->doesNotAccessMemory(CI))
-    return createCallExpression(CI, nullptr, B);
-  if (AA->onlyReadsMemory(CI)) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    // Intrinsics with the returned attribute are copies of arguments.
+    if (auto *ReturnedValue = II->getReturnedArgOperand()) {
+      if (II->getIntrinsicID() == Intrinsic::ssa_copy)
+        if (const auto *Result = performSymbolicPredicateInfoEvaluation(I))
+          return Result;
+      return createVariableOrConstant(ReturnedValue);
+    }
+  }
+  if (AA->doesNotAccessMemory(CI)) {
+    return createCallExpression(CI, TOPClass->getMemoryLeader());
+  } else if (AA->onlyReadsMemory(CI)) {
     MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI);
-    return createCallExpression(CI, lookupMemoryAccessEquiv(DefiningAccess), B);
+    return createCallExpression(CI, DefiningAccess);
   }
   return nullptr;
 }

-// Update the memory access equivalence table to say that From is equal to To,
+// Retrieve the memory class for a given MemoryAccess.
+CongruenceClass *NewGVN::getMemoryClass(const MemoryAccess *MA) const {
+  auto *Result = MemoryAccessToClass.lookup(MA);
+  assert(Result && "Should have found memory class");
+  return Result;
+}
+
+// Update the MemoryAccess equivalence table to say that From is equal to To,
 // and return true if this is different from what already existed in the table.
-bool NewGVN::setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To) {
-  DEBUG(dbgs() << "Setting " << *From << " equivalent to ");
-  if (!To)
-    DEBUG(dbgs() << "itself");
-  else
-    DEBUG(dbgs() << *To);
-  DEBUG(dbgs() << "\n");
-  auto LookupResult = MemoryAccessEquiv.find(From);
+bool NewGVN::setMemoryClass(const MemoryAccess *From,
+                            CongruenceClass *NewClass) {
+  assert(NewClass &&
+         "Every MemoryAccess should be getting mapped to a non-null class");
+  DEBUG(dbgs() << "Setting " << *From);
+  DEBUG(dbgs() << " equivalent to congruence class ");
+  DEBUG(dbgs() << NewClass->getID() << " with current MemoryAccess leader ");
+  DEBUG(dbgs() << *NewClass->getMemoryLeader() << "\n");
+
+  auto LookupResult = MemoryAccessToClass.find(From);
   bool Changed = false;
   // If it's already in the table, see if the value changed.
-  if (LookupResult != MemoryAccessEquiv.end()) {
-    if (To && LookupResult->second != To) {
+  if (LookupResult != MemoryAccessToClass.end()) {
+    auto *OldClass = LookupResult->second;
+    if (OldClass != NewClass) {
+      // If this is a phi, we have to handle memory member updates.
+      if (auto *MP = dyn_cast<MemoryPhi>(From)) {
+        OldClass->memory_erase(MP);
+        NewClass->memory_insert(MP);
+        // This may have killed the class if it had no non-memory members.
+        if (OldClass->getMemoryLeader() == From) {
+          if (OldClass->definesNoMemory()) {
+            OldClass->setMemoryLeader(nullptr);
+          } else {
+            OldClass->setMemoryLeader(getNextMemoryLeader(OldClass));
+            DEBUG(dbgs() << "Memory class leader change for class "
                         << OldClass->getID() << " to "
                         << *OldClass->getMemoryLeader()
                         << " due to removal of a memory member " << *From
                         << "\n");
+            markMemoryLeaderChangeTouched(OldClass);
+          }
+        }
+      }
       // It wasn't equivalent before, and now it is.
-      LookupResult->second = To;
-      Changed = true;
-    } else if (!To) {
-      // It used to be equivalent to something, and now it's not.
-      MemoryAccessEquiv.erase(LookupResult);
+      LookupResult->second = NewClass;
       Changed = true;
     }
-  } else {
-    assert(!To &&
-           "Memory equivalence should never change from nothing to something");
   }
   return Changed;
 }

+// Determine if an instruction is cycle-free. That means the values in the
+// instruction don't depend on any expressions that can change value as a
+// result of the instruction. For example, a non-cycle-free instruction would
+// be v = phi(0, v+1).
+bool NewGVN::isCycleFree(const Instruction *I) const {
+  // In order to compute cycle-freeness, we do SCC finding on the instruction,
+  // and see what kind of SCC it ends up in. If it is a singleton, it is
+  // cycle-free. If it is not in a singleton, it is only cycle-free if the
+  // other members are all phi nodes (as they do not compute anything, they
+  // are copies).
+  auto ICS = InstCycleState.lookup(I);
+  if (ICS == ICS_Unknown) {
+    SCCFinder.Start(I);
+    auto &SCC = SCCFinder.getComponentFor(I);
+    // It's cycle-free if its size is 1 or the SCC is *only* phi nodes.
+    if (SCC.size() == 1)
+      InstCycleState.insert({I, ICS_CycleFree});
+    else {
+      bool AllPhis =
+          llvm::all_of(SCC, [](const Value *V) { return isa<PHINode>(V); });
+      ICS = AllPhis ? ICS_CycleFree : ICS_Cycle;
+      for (auto *Member : SCC)
+        if (auto *MemberPhi = dyn_cast<PHINode>(Member))
+          InstCycleState.insert({MemberPhi, ICS});
+    }
+  }
+  if (ICS == ICS_Cycle)
+    return false;
+  return true;
+}
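isCycleFree above classifies a value by the strongly connected component it belongs to in the value graph: singleton SCCs (and, in NewGVN, all-phi SCCs) are cycle-free. A standalone Tarjan SCC over an illustrative integer dependency graph shows the classification; this sketch is not LLVM's TarjanSCC helper:

// Tarjan SCC; a value is cycle-free here if its component is a singleton.
#include <algorithm>
#include <cstdio>
#include <vector>

struct SCC {
  const std::vector<std::vector<int>> &Succ;
  std::vector<int> Index, Low, Comp, Stack;
  std::vector<bool> OnStack;
  int Next = 0, NComp = 0;
  explicit SCC(const std::vector<std::vector<int>> &S)
      : Succ(S), Index(S.size(), -1), Low(S.size()), Comp(S.size(), -1),
        OnStack(S.size(), false) {
    for (int V = 0; V < (int)S.size(); ++V)
      if (Index[V] < 0)
        dfs(V);
  }
  void dfs(int V) {
    Index[V] = Low[V] = Next++;
    Stack.push_back(V);
    OnStack[V] = true;
    for (int W : Succ[V]) {
      if (Index[W] < 0) {
        dfs(W);
        Low[V] = std::min(Low[V], Low[W]);
      } else if (OnStack[W])
        Low[V] = std::min(Low[V], Index[W]);
    }
    if (Low[V] == Index[V]) { // V is the root of a component; pop it off
      int W;
      do {
        W = Stack.back();
        Stack.pop_back();
        OnStack[W] = false;
        Comp[W] = NComp;
      } while (W != V);
      ++NComp;
    }
  }
};

int main() {
  // v0 = phi(v2, c), v1 = v0 + 1, v2 = v1 (a value cycle), v3 = v0 + 2
  std::vector<std::vector<int>> Succ = {{2}, {0}, {1}, {0}};
  SCC S(Succ);
  auto SizeOf = [&](int V) {
    return std::count(S.Comp.begin(), S.Comp.end(), S.Comp[V]);
  };
  std::printf("v3 cycle-free: %s\n", SizeOf(3) == 1 ? "yes" : "no");
  std::printf("v0 cycle-free: %s\n", SizeOf(0) == 1 ? "yes" : "no");
}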
 // Evaluate PHI nodes symbolically, and create an expression result.
-const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,
-                                                       const BasicBlock *B) {
-  auto *E = cast<PHIExpression>(createPHIExpression(I));
+const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I) const {
+  // True if one of the incoming phi edges is a backedge.
+  bool HasBackedge = false;
+  // AllConstant tracks whether all the *original* phi operands were constant.
+  // This is really shorthand for "this phi cannot cycle due to forward
+  // propagation": any change in the value of the phi is guaranteed not to
+  // later change the value of the phi. IE it can't be v = phi(undef, v+1).
+  bool AllConstant = true;
+  auto *E =
+      cast<PHIExpression>(createPHIExpression(I, HasBackedge, AllConstant));
   // We match the semantics of SimplifyPhiNode from InstructionSimplify here.
-
-  // See if all arguaments are the same.
+  // See if all arguments are the same.
   // We track if any were undef because they need special handling.
   bool HasUndef = false;
-  auto Filtered = make_filter_range(E->operands(), [&](const Value *Arg) {
-    if (Arg == I)
-      return false;
+  auto Filtered = make_filter_range(E->operands(), [&](Value *Arg) {
     if (isa<UndefValue>(Arg)) {
       HasUndef = true;
       return false;
     }
     return true;
   });
-  // If we are left with no operands, it's undef
+  // If we are left with no operands, it's dead.
   if (Filtered.begin() == Filtered.end()) {
-    DEBUG(dbgs() << "Simplified PHI node " << *I << " to undef"
-                 << "\n");
-    E->deallocateOperands(ArgRecycler);
-    ExpressionAllocator.Deallocate(E);
-    return createConstantExpression(UndefValue::get(I->getType()));
+    // If it has undef at this point, it means there are no non-undef
+    // arguments, and thus, the value of the phi node must be undef.
+    if (HasUndef) {
+      DEBUG(dbgs() << "PHI Node " << *I
+                   << " has no non-undef arguments, valuing it as undef\n");
+      return createConstantExpression(UndefValue::get(I->getType()));
+    }
+
+    DEBUG(dbgs() << "No arguments of PHI node " << *I << " are live\n");
+    deleteExpression(E);
+    return createDeadExpression();
   }
+  unsigned NumOps = 0;
   Value *AllSameValue = *(Filtered.begin());
   ++Filtered.begin();
   // Can't use std::equal here, sadly, because filter.begin moves.
-  if (llvm::all_of(Filtered, [AllSameValue](const Value *V) {
-        return V == AllSameValue;
+  if (llvm::all_of(Filtered, [&](Value *Arg) {
+        ++NumOps;
+        return Arg == AllSameValue;
       })) {
     // In LLVM's non-standard representation of phi nodes, it's possible to have
     // phi nodes with cycles (IE dependent on other phis that are .... dependent
@@ -881,27 +1638,38 @@ const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I,
     // We also special case undef, so that if we have an undef, we can't use the
     // common value unless it dominates the phi block.
     if (HasUndef) {
+      // If we have undef and at least one other value, this is really a
+      // multivalued phi, and we need to know if it's cycle free in order to
+      // evaluate whether we can ignore the undef. The other parts of this are
+      // just shortcuts. If there is no backedge, or all operands are
+      // constants, or all operands are ignored but the undef, it also must be
+      // cycle free.
+      if (!AllConstant && HasBackedge && NumOps > 0 &&
+          !isa<UndefValue>(AllSameValue) && !isCycleFree(I))
+        return E;
+
       // Only have to check for instructions
       if (auto *AllSameInst = dyn_cast<Instruction>(AllSameValue))
-        if (!DT->dominates(AllSameInst, I))
+        if (!someEquivalentDominates(AllSameInst, I))
          return E;
     }
-
+    // Can't simplify to something that comes later in the iteration.
+    // Otherwise, when and if it changes congruence class, we will never catch
+    // up. We will always be a class behind it.
+    if (isa<Instruction>(AllSameValue) &&
+        InstrToDFSNum(AllSameValue) > InstrToDFSNum(I))
+      return E;
     NumGVNPhisAllSame++;
     DEBUG(dbgs() << "Simplified PHI node " << *I << " to " << *AllSameValue
                  << "\n");
-    E->deallocateOperands(ArgRecycler);
-    ExpressionAllocator.Deallocate(E);
-    if (auto *C = dyn_cast<Constant>(AllSameValue))
-      return createConstantExpression(C);
-    return createVariableExpression(AllSameValue);
+    deleteExpression(E);
+    return createVariableOrConstant(AllSameValue);
   }
   return E;
 }

 const Expression *
-NewGVN::performSymbolicAggrValueEvaluation(Instruction *I,
-                                           const BasicBlock *B) {
+NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) const {
   if (auto *EI = dyn_cast<ExtractValueInst>(I)) {
     auto *II = dyn_cast<IntrinsicInst>(EI->getAggregateOperand());
     if (II && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) {
@@ -931,19 +1699,140 @@ NewGVN::performSymbolicAggrValueEvaluation(Instruction *I,
       // expression.
       assert(II->getNumArgOperands() == 2 &&
             "Expect two args for recognised intrinsics.");
-      return createBinaryExpression(Opcode, EI->getType(),
-                                    II->getArgOperand(0),
-                                    II->getArgOperand(1), B);
+      return createBinaryExpression(
+          Opcode, EI->getType(), II->getArgOperand(0), II->getArgOperand(1));
     }
   }

-  return createAggregateValueExpression(I, B);
+  return createAggregateValueExpression(I);
+}
+
+const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
+  auto *CI = dyn_cast<CmpInst>(I);
+  // See if our operands are equal to those of a previous predicate, and if
+  // so, if it implies true or false.
+  auto Op0 = lookupOperandLeader(CI->getOperand(0));
+  auto Op1 = lookupOperandLeader(CI->getOperand(1));
+  auto OurPredicate = CI->getPredicate();
+  if (shouldSwapOperands(Op0, Op1)) {
+    std::swap(Op0, Op1);
+    OurPredicate = CI->getSwappedPredicate();
+  }
+
+  // Avoid processing the same info twice.
+  const PredicateBase *LastPredInfo = nullptr;
+  // See if we know something about the comparison itself, like it is the
+  // target of an assume.
+  auto *CmpPI = PredInfo->getPredicateInfoFor(I);
+  if (dyn_cast_or_null<PredicateAssume>(CmpPI))
+    return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+
+  if (Op0 == Op1) {
+    // This condition does not depend on predicates, no need to add users.
+    if (CI->isTrueWhenEqual())
+      return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+    else if (CI->isFalseWhenEqual())
+      return createConstantExpression(ConstantInt::getFalse(CI->getType()));
+  }
+
+  // NOTE: Because we are comparing both operands here and below, and using
+  // previous comparisons, we rely on the fact that PredicateInfo knows to mark
+  // comparisons that use renamed operands as users of the earlier comparisons.
+  // It is *not* enough to just mark PredicateInfo renamed operands as users of
+  // the earlier comparisons, because the *other* operand may have changed in a
+  // previous iteration.
+  // Example:
+  // icmp slt %a, %b
+  // %b.0 = ssa.copy(%b)
+  // false branch:
+  // icmp slt %c, %b.0
+  //
+  // %c and %a may start out equal, and thus, the code below will say the
+  // second icmp is false. %c may become equal to something else, and in that
+  // case the second icmp *must* be reexamined, but would not be if only the
+  // renamed operands are considered users of the icmp.
+  //
+  // *Currently* we only check one level of comparisons back, and only mark one
+  // level back as touched when changes happen. If you modify this code to look
+  // back farther through comparisons, you *must* mark the appropriate
+  // comparisons as users in PredicateInfo.cpp, or you will cause bugs.
+
+  // See if we know something just from the operands themselves: our operands
+  // may have predicate info, so that we may be able to derive something from
+  // a previous comparison.
+  for (const auto &Op : CI->operands()) {
+    auto *PI = PredInfo->getPredicateInfoFor(Op);
+    if (const auto *PBranch = dyn_cast_or_null<PredicateBranch>(PI)) {
+      if (PI == LastPredInfo)
+        continue;
+      LastPredInfo = PI;
+
+      // TODO: Along the false edge, we may know more things too, like icmp of
+      // same operands is false.
+      // TODO: We only handle actual comparison conditions below, not and/or.
+      auto *BranchCond = dyn_cast<CmpInst>(PBranch->Condition);
+      if (!BranchCond)
+        continue;
+      auto *BranchOp0 = lookupOperandLeader(BranchCond->getOperand(0));
+      auto *BranchOp1 = lookupOperandLeader(BranchCond->getOperand(1));
+      auto BranchPredicate = BranchCond->getPredicate();
+      if (shouldSwapOperands(BranchOp0, BranchOp1)) {
+        std::swap(BranchOp0, BranchOp1);
+        BranchPredicate = BranchCond->getSwappedPredicate();
+      }
+      if (BranchOp0 == Op0 && BranchOp1 == Op1) {
+        if (PBranch->TrueEdge) {
+          // If we know the previous predicate is true and we are in the true
+          // edge then we may be implied true or false.
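// The isImplied checks used next can be pictured with a small standalone
// sketch over signed predicates (plain C++, assumed semantics loosely
// mirroring CmpInst::isImpliedTrueByMatchingCmp for identical operand pairs;
// names are hypothetical).
#include <cassert>
enum class Pred { SLT, SLE, EQ, NE };
static bool impliedTrueSketch(Pred Known, Pred Query) {
  if (Known == Pred::SLT)  // a < b implies a <= b and a != b.
    return Query == Pred::SLT || Query == Pred::SLE || Query == Pred::NE;
  if (Known == Pred::EQ)   // a == b implies a <= b.
    return Query == Pred::EQ || Query == Pred::SLE;
  return Known == Query;   // Otherwise only the identical predicate follows.
}
int main() { assert(impliedTrueSketch(Pred::SLT, Pred::NE)); } // a<b, so a!=b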
+ if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate, + OurPredicate)) { + addPredicateUsers(PI, I); + return createConstantExpression( + ConstantInt::getTrue(CI->getType())); + } + + if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate, + OurPredicate)) { + addPredicateUsers(PI, I); + return createConstantExpression( + ConstantInt::getFalse(CI->getType())); + } + + } else { + // Just handle the ne and eq cases, where if we have the same + // operands, we may know something. + if (BranchPredicate == OurPredicate) { + addPredicateUsers(PI, I); + // Same predicate, same ops,we know it was false, so this is false. + return createConstantExpression( + ConstantInt::getFalse(CI->getType())); + } else if (BranchPredicate == + CmpInst::getInversePredicate(OurPredicate)) { + addPredicateUsers(PI, I); + // Inverse predicate, we know the other was false, so this is true. + return createConstantExpression( + ConstantInt::getTrue(CI->getType())); + } + } + } + } + } + // Create expression will take care of simplifyCmpInst + return createExpression(I); +} + +// Return true if V is a value that will always be available (IE can +// be placed anywhere) in the function. We don't do globals here +// because they are often worse to put in place. +// TODO: Separate cost from availability +static bool alwaysAvailable(Value *V) { + return isa<Constant>(V) || isa<Argument>(V); } // Substitute and symbolize the value before value numbering. -const Expression *NewGVN::performSymbolicEvaluation(Value *V, - const BasicBlock *B) { +const Expression * +NewGVN::performSymbolicEvaluation(Value *V, + SmallPtrSetImpl<Value *> &Visited) const { const Expression *E = nullptr; if (auto *C = dyn_cast<Constant>(V)) E = createConstantExpression(C); @@ -957,24 +1846,27 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V, switch (I->getOpcode()) { case Instruction::ExtractValue: case Instruction::InsertValue: - E = performSymbolicAggrValueEvaluation(I, B); + E = performSymbolicAggrValueEvaluation(I); break; case Instruction::PHI: - E = performSymbolicPHIEvaluation(I, B); + E = performSymbolicPHIEvaluation(I); break; case Instruction::Call: - E = performSymbolicCallEvaluation(I, B); + E = performSymbolicCallEvaluation(I); break; case Instruction::Store: - E = performSymbolicStoreEvaluation(I, B); + E = performSymbolicStoreEvaluation(I); break; case Instruction::Load: - E = performSymbolicLoadEvaluation(I, B); + E = performSymbolicLoadEvaluation(I); break; case Instruction::BitCast: { - E = createExpression(I, B); + E = createExpression(I); + } break; + case Instruction::ICmp: + case Instruction::FCmp: { + E = performSymbolicCmpEvaluation(I); } break; - case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -993,8 +1885,6 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V, case Instruction::And: case Instruction::Or: case Instruction::Xor: - case Instruction::ICmp: - case Instruction::FCmp: case Instruction::Trunc: case Instruction::ZExt: case Instruction::SExt: @@ -1011,7 +1901,7 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V, case Instruction::InsertElement: case Instruction::ShuffleVector: case Instruction::GetElementPtr: - E = createExpression(I, B); + E = createExpression(I); break; default: return nullptr; @@ -1020,147 +1910,308 @@ const Expression *NewGVN::performSymbolicEvaluation(Value *V, return E; } -// There is an edge from 'Src' to 'Dst'. Return true if every path from -// the entry block to 'Dst' passes via this edge. 
In particular 'Dst'
-// must not be reachable via another edge from 'Src'.
-bool NewGVN::isOnlyReachableViaThisEdge(const BasicBlockEdge &E) const {
+// Look up a container in a map, and then call a function for each thing in the
+// found container.
+template <typename Map, typename KeyType, typename Func>
+void NewGVN::for_each_found(Map &M, const KeyType &Key, Func F) {
+  const auto Result = M.find_as(Key);
+  if (Result != M.end())
+    for (typename Map::mapped_type::value_type Mapped : Result->second)
+      F(Mapped);
+}
+
+// Look up a container of values/instructions in a map, and touch all the
+// instructions in the container. Then erase the value from the map.
+template <typename Map, typename KeyType>
+void NewGVN::touchAndErase(Map &M, const KeyType &Key) {
+  const auto Result = M.find_as(Key);
+  if (Result != M.end()) {
+    for (const typename Map::mapped_type::value_type Mapped : Result->second)
+      TouchedInstructions.set(InstrToDFSNum(Mapped));
+    M.erase(Result);
+  }
+}
 
-  // While in theory it is interesting to consider the case in which Dst has
-  // more than one predecessor, because Dst might be part of a loop which is
-  // only reachable from Src, in practice it is pointless since at the time
-  // GVN runs all such loops have preheaders, which means that Dst will have
-  // been changed to have only one predecessor, namely Src.
-  const BasicBlock *Pred = E.getEnd()->getSinglePredecessor();
-  const BasicBlock *Src = E.getStart();
-  assert((!Pred || Pred == Src) && "No edge between these basic blocks!");
-  (void)Src;
-  return Pred != nullptr;
+void NewGVN::addAdditionalUsers(Value *To, Value *User) const {
+  if (isa<Instruction>(To))
+    AdditionalUsers[To].insert(User);
 }
 
 void NewGVN::markUsersTouched(Value *V) {
   // Now mark the users as touched.
   for (auto *User : V->users()) {
     assert(isa<Instruction>(User) && "Use of value not within an instruction?");
-    TouchedInstructions.set(InstrDFS[User]);
+    TouchedInstructions.set(InstrToDFSNum(User));
   }
+  touchAndErase(AdditionalUsers, V);
 }
 
-void NewGVN::markMemoryUsersTouched(MemoryAccess *MA) {
-  for (auto U : MA->users()) {
-    if (auto *MUD = dyn_cast<MemoryUseOrDef>(U))
-      TouchedInstructions.set(InstrDFS[MUD->getMemoryInst()]);
-    else
-      TouchedInstructions.set(InstrDFS[U]);
-  }
+void NewGVN::addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const {
+  DEBUG(dbgs() << "Adding memory user " << *U << " to " << *To << "\n");
+  MemoryToUsers[To].insert(U);
+}
+
+void NewGVN::markMemoryDefTouched(const MemoryAccess *MA) {
+  TouchedInstructions.set(MemoryToDFSNum(MA));
+}
+
+void NewGVN::markMemoryUsersTouched(const MemoryAccess *MA) {
+  if (isa<MemoryUse>(MA))
+    return;
+  for (auto U : MA->users())
+    TouchedInstructions.set(MemoryToDFSNum(U));
+  touchAndErase(MemoryToUsers, MA);
+}
+
+// Add I to the set of users of a given predicate.
+void NewGVN::addPredicateUsers(const PredicateBase *PB, Instruction *I) const {
+  // Don't add temporary instructions to the user lists.
+  if (AllTempInstructions.count(I))
+    return;
+
+  if (auto *PBranch = dyn_cast<PredicateBranch>(PB))
+    PredicateToUsers[PBranch->Condition].insert(I);
+  else if (auto *PAssume = dyn_cast<PredicateAssume>(PB))
+    PredicateToUsers[PAssume->Condition].insert(I);
+}
+
+// Touch all the predicates that depend on this instruction.
+void NewGVN::markPredicateUsersTouched(Instruction *I) {
+  touchAndErase(PredicateToUsers, I);
+}
+
+// Mark users affected by a memory leader change.
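// for_each_found and touchAndErase above share a "key to bag of dependents"
// shape; a standalone sketch of the same pattern (plain C++, hypothetical
// names, not part of this patch): visit every dependent once, then drop the
// entry so stale dependencies are not replayed.
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
using DependentsMap = std::unordered_map<std::string, std::vector<int>>;
static void touchAndEraseSketch(DependentsMap &M, const std::string &Key,
                                const std::function<void(int)> &Touch) {
  auto It = M.find(Key);
  if (It == M.end())
    return;
  for (int Dependent : It->second)
    Touch(Dependent); // E.g. set the dependent's bit in a touched bitvector.
  M.erase(It);        // A later change must re-record its dependents.
}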
+void NewGVN::markMemoryLeaderChangeTouched(CongruenceClass *CC) { + for (auto M : CC->memory()) + markMemoryDefTouched(M); } // Touch the instructions that need to be updated after a congruence class has a // leader change, and mark changed values. -void NewGVN::markLeaderChangeTouched(CongruenceClass *CC) { - for (auto M : CC->Members) { +void NewGVN::markValueLeaderChangeTouched(CongruenceClass *CC) { + for (auto M : *CC) { if (auto *I = dyn_cast<Instruction>(M)) - TouchedInstructions.set(InstrDFS[I]); + TouchedInstructions.set(InstrToDFSNum(I)); LeaderChanges.insert(M); } } +// Give a range of things that have instruction DFS numbers, this will return +// the member of the range with the smallest dfs number. +template <class T, class Range> +T *NewGVN::getMinDFSOfRange(const Range &R) const { + std::pair<T *, unsigned> MinDFS = {nullptr, ~0U}; + for (const auto X : R) { + auto DFSNum = InstrToDFSNum(X); + if (DFSNum < MinDFS.second) + MinDFS = {X, DFSNum}; + } + return MinDFS.first; +} + +// This function returns the MemoryAccess that should be the next leader of +// congruence class CC, under the assumption that the current leader is going to +// disappear. +const MemoryAccess *NewGVN::getNextMemoryLeader(CongruenceClass *CC) const { + // TODO: If this ends up to slow, we can maintain a next memory leader like we + // do for regular leaders. + // Make sure there will be a leader to find + assert(!CC->definesNoMemory() && "Can't get next leader if there is none"); + if (CC->getStoreCount() > 0) { + if (auto *NL = dyn_cast_or_null<StoreInst>(CC->getNextLeader().first)) + return getMemoryAccess(NL); + // Find the store with the minimum DFS number. + auto *V = getMinDFSOfRange<Value>(make_filter_range( + *CC, [&](const Value *V) { return isa<StoreInst>(V); })); + return getMemoryAccess(cast<StoreInst>(V)); + } + assert(CC->getStoreCount() == 0); + + // Given our assertion, hitting this part must mean + // !OldClass->memory_empty() + if (CC->memory_size() == 1) + return *CC->memory_begin(); + return getMinDFSOfRange<const MemoryPhi>(CC->memory()); +} + +// This function returns the next value leader of a congruence class, under the +// assumption that the current leader is going away. This should end up being +// the next most dominating member. +Value *NewGVN::getNextValueLeader(CongruenceClass *CC) const { + // We don't need to sort members if there is only 1, and we don't care about + // sorting the TOP class because everything either gets out of it or is + // unreachable. + + if (CC->size() == 1 || CC == TOPClass) { + return *(CC->begin()); + } else if (CC->getNextLeader().first) { + ++NumGVNAvoidedSortedLeaderChanges; + return CC->getNextLeader().first; + } else { + ++NumGVNSortedLeaderChanges; + // NOTE: If this ends up to slow, we can maintain a dual structure for + // member testing/insertion, or keep things mostly sorted, and sort only + // here, or use SparseBitVector or .... + return getMinDFSOfRange<Value>(*CC); + } +} + +// Move a MemoryAccess, currently in OldClass, to NewClass, including updates to +// the memory members, etc for the move. +// +// The invariants of this function are: +// +// - I must be moving to NewClass from OldClass +// - The StoreCount of OldClass and NewClass is expected to have been updated +// for I already if it is is a store. +// - The OldClass memory leader has not been updated yet if I was the leader. 
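// getMinDFSOfRange above implements a simple rule; a standalone sketch of it
// (plain C++, hypothetical types): when a leader disappears, the member with
// the smallest DFS number, i.e. the most dominating one, takes over.
#include <limits>
#include <unordered_map>
#include <vector>
struct MemberSketch { int Id; };
static MemberSketch *
pickNextLeaderSketch(const std::vector<MemberSketch *> &Class,
                     const std::unordered_map<MemberSketch *, unsigned> &DFS) {
  MemberSketch *Best = nullptr;
  unsigned BestDFS = std::numeric_limits<unsigned>::max();
  for (MemberSketch *M : Class) {
    unsigned N = DFS.at(M);
    if (N < BestDFS) { // Smallest DFS number wins.
      Best = M;
      BestDFS = N;
    }
  }
  return Best;
}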
+void NewGVN::moveMemoryToNewCongruenceClass(Instruction *I, + MemoryAccess *InstMA, + CongruenceClass *OldClass, + CongruenceClass *NewClass) { + // If the leader is I, and we had a represenative MemoryAccess, it should + // be the MemoryAccess of OldClass. + assert((!InstMA || !OldClass->getMemoryLeader() || + OldClass->getLeader() != I || + MemoryAccessToClass.lookup(OldClass->getMemoryLeader()) == + MemoryAccessToClass.lookup(InstMA)) && + "Representative MemoryAccess mismatch"); + // First, see what happens to the new class + if (!NewClass->getMemoryLeader()) { + // Should be a new class, or a store becoming a leader of a new class. + assert(NewClass->size() == 1 || + (isa<StoreInst>(I) && NewClass->getStoreCount() == 1)); + NewClass->setMemoryLeader(InstMA); + // Mark it touched if we didn't just create a singleton + DEBUG(dbgs() << "Memory class leader change for class " << NewClass->getID() + << " due to new memory instruction becoming leader\n"); + markMemoryLeaderChangeTouched(NewClass); + } + setMemoryClass(InstMA, NewClass); + // Now, fixup the old class if necessary + if (OldClass->getMemoryLeader() == InstMA) { + if (!OldClass->definesNoMemory()) { + OldClass->setMemoryLeader(getNextMemoryLeader(OldClass)); + DEBUG(dbgs() << "Memory class leader change for class " + << OldClass->getID() << " to " + << *OldClass->getMemoryLeader() + << " due to removal of old leader " << *InstMA << "\n"); + markMemoryLeaderChangeTouched(OldClass); + } else + OldClass->setMemoryLeader(nullptr); + } +} + // Move a value, currently in OldClass, to be part of NewClass -// Update OldClass for the move (including changing leaders, etc) -void NewGVN::moveValueToNewCongruenceClass(Instruction *I, +// Update OldClass and NewClass for the move (including changing leaders, etc). +void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E, CongruenceClass *OldClass, CongruenceClass *NewClass) { - DEBUG(dbgs() << "New congruence class for " << I << " is " << NewClass->ID - << "\n"); - - if (I == OldClass->NextLeader.first) - OldClass->NextLeader = {nullptr, ~0U}; - - // It's possible, though unlikely, for us to discover equivalences such - // that the current leader does not dominate the old one. - // This statistic tracks how often this happens. - // We assert on phi nodes when this happens, currently, for debugging, because - // we want to make sure we name phi node cycles properly. - if (isa<Instruction>(NewClass->RepLeader) && NewClass->RepLeader && - I != NewClass->RepLeader && - DT->properlyDominates( - I->getParent(), - cast<Instruction>(NewClass->RepLeader)->getParent())) { - ++NumGVNNotMostDominatingLeader; - assert(!isa<PHINode>(I) && - "New class for instruction should not be dominated by instruction"); - } - - if (NewClass->RepLeader != I) { - auto DFSNum = InstrDFS.lookup(I); - if (DFSNum < NewClass->NextLeader.second) - NewClass->NextLeader = {I, DFSNum}; - } - - OldClass->Members.erase(I); - NewClass->Members.insert(I); - if (isa<StoreInst>(I)) { - --OldClass->StoreCount; - assert(OldClass->StoreCount >= 0); - ++NewClass->StoreCount; - assert(NewClass->StoreCount > 0); + if (I == OldClass->getNextLeader().first) + OldClass->resetNextLeader(); + + OldClass->erase(I); + NewClass->insert(I); + + if (NewClass->getLeader() != I) + NewClass->addPossibleNextLeader({I, InstrToDFSNum(I)}); + // Handle our special casing of stores. + if (auto *SI = dyn_cast<StoreInst>(I)) { + OldClass->decStoreCount(); + // Okay, so when do we want to make a store a leader of a class? 
+ // If we have a store defined by an earlier load, we want the earlier load + // to lead the class. + // If we have a store defined by something else, we want the store to lead + // the class so everything else gets the "something else" as a value. + // If we have a store as the single member of the class, we want the store + // as the leader + if (NewClass->getStoreCount() == 0 && !NewClass->getStoredValue()) { + // If it's a store expression we are using, it means we are not equivalent + // to something earlier. + if (auto *SE = dyn_cast<StoreExpression>(E)) { + NewClass->setStoredValue(SE->getStoredValue()); + markValueLeaderChangeTouched(NewClass); + // Shift the new class leader to be the store + DEBUG(dbgs() << "Changing leader of congruence class " + << NewClass->getID() << " from " << *NewClass->getLeader() + << " to " << *SI << " because store joined class\n"); + // If we changed the leader, we have to mark it changed because we don't + // know what it will do to symbolic evaluation. + NewClass->setLeader(SI); + } + // We rely on the code below handling the MemoryAccess change. + } + NewClass->incStoreCount(); } + // True if there is no memory instructions left in a class that had memory + // instructions before. + // If it's not a memory use, set the MemoryAccess equivalence + auto *InstMA = dyn_cast_or_null<MemoryDef>(getMemoryAccess(I)); + if (InstMA) + moveMemoryToNewCongruenceClass(I, InstMA, OldClass, NewClass); ValueToClass[I] = NewClass; // See if we destroyed the class or need to swap leaders. - if (OldClass->Members.empty() && OldClass != InitialClass) { - if (OldClass->DefiningExpr) { - OldClass->Dead = true; - DEBUG(dbgs() << "Erasing expression " << OldClass->DefiningExpr + if (OldClass->empty() && OldClass != TOPClass) { + if (OldClass->getDefiningExpr()) { + DEBUG(dbgs() << "Erasing expression " << *OldClass->getDefiningExpr() << " from table\n"); - ExpressionToClass.erase(OldClass->DefiningExpr); + // We erase it as an exact expression to make sure we don't just erase an + // equivalent one. + auto Iter = ExpressionToClass.find_as( + ExactEqualsExpression(*OldClass->getDefiningExpr())); + if (Iter != ExpressionToClass.end()) + ExpressionToClass.erase(Iter); +#ifdef EXPENSIVE_CHECKS + assert( + (*OldClass->getDefiningExpr() != *E || ExpressionToClass.lookup(E)) && + "We erased the expression we just inserted, which should not happen"); +#endif } - } else if (OldClass->RepLeader == I) { + } else if (OldClass->getLeader() == I) { // When the leader changes, the value numbering of // everything may change due to symbolization changes, so we need to // reprocess. - DEBUG(dbgs() << "Leader change!\n"); + DEBUG(dbgs() << "Value class leader change for class " << OldClass->getID() + << "\n"); ++NumGVNLeaderChanges; - // We don't need to sort members if there is only 1, and we don't care about - // sorting the initial class because everything either gets out of it or is - // unreachable. - if (OldClass->Members.size() == 1 || OldClass == InitialClass) { - OldClass->RepLeader = *(OldClass->Members.begin()); - } else if (OldClass->NextLeader.first) { - ++NumGVNAvoidedSortedLeaderChanges; - OldClass->RepLeader = OldClass->NextLeader.first; - OldClass->NextLeader = {nullptr, ~0U}; - } else { - ++NumGVNSortedLeaderChanges; - // TODO: If this ends up to slow, we can maintain a dual structure for - // member testing/insertion, or keep things mostly sorted, and sort only - // here, or .... 
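// The NextLeader machinery used above amortizes the linear scan that follows;
// a sketch of that bookkeeping (plain C++, hypothetical names): each
// insertion offers itself as a candidate, so the lowest-DFS member is already
// known if the current leader leaves.
#include <utility>
struct NextLeaderCacheSketch {
  std::pair<void *, unsigned> Best{nullptr, ~0U};
  void offer(void *V, unsigned DFSNum) {
    if (DFSNum < Best.second) // Keep only the most dominating candidate.
      Best = {V, DFSNum};
  }
  void reset() { Best = {nullptr, ~0U}; } // Invalidate after a leader change.
};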
- std::pair<Value *, unsigned> MinDFS = {nullptr, ~0U}; - for (const auto X : OldClass->Members) { - auto DFSNum = InstrDFS.lookup(X); - if (DFSNum < MinDFS.second) - MinDFS = {X, DFSNum}; - } - OldClass->RepLeader = MinDFS.first; + // Destroy the stored value if there are no more stores to represent it. + // Note that this is basically clean up for the expression removal that + // happens below. If we remove stores from a class, we may leave it as a + // class of equivalent memory phis. + if (OldClass->getStoreCount() == 0) { + if (OldClass->getStoredValue()) + OldClass->setStoredValue(nullptr); } - markLeaderChangeTouched(OldClass); + OldClass->setLeader(getNextValueLeader(OldClass)); + OldClass->resetNextLeader(); + markValueLeaderChangeTouched(OldClass); } } +// For a given expression, mark the phi of ops instructions that could have +// changed as a result. +void NewGVN::markPhiOfOpsChanged(const Expression *E) { + touchAndErase(ExpressionToPhiOfOps, ExactEqualsExpression(*E)); +} + // Perform congruence finding on a given value numbering expression. void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) { - ValueToExpression[I] = E; // This is guaranteed to return something, since it will at least find - // INITIAL. + // TOP. - CongruenceClass *IClass = ValueToClass[I]; + CongruenceClass *IClass = ValueToClass.lookup(I); assert(IClass && "Should have found a IClass"); // Dead classes should have been eliminated from the mapping. - assert(!IClass->Dead && "Found a dead class"); + assert(!IClass->isDead() && "Found a dead class"); - CongruenceClass *EClass; + CongruenceClass *EClass = nullptr; if (const auto *VE = dyn_cast<VariableExpression>(E)) { - EClass = ValueToClass[VE->getVariableValue()]; - } else { + EClass = ValueToClass.lookup(VE->getVariableValue()); + } else if (isa<DeadExpression>(E)) { + EClass = TOPClass; + } + if (!EClass) { auto lookupResult = ExpressionToClass.insert({E, nullptr}); // If it's not in the value table, create a new congruence class. @@ -1171,80 +2222,73 @@ void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) { // Constants and variables should always be made the leader. if (const auto *CE = dyn_cast<ConstantExpression>(E)) { - NewClass->RepLeader = CE->getConstantValue(); + NewClass->setLeader(CE->getConstantValue()); } else if (const auto *SE = dyn_cast<StoreExpression>(E)) { StoreInst *SI = SE->getStoreInst(); - NewClass->RepLeader = - lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent()); + NewClass->setLeader(SI); + NewClass->setStoredValue(SE->getStoredValue()); + // The RepMemoryAccess field will be filled in properly by the + // moveValueToNewCongruenceClass call. 
} else { - NewClass->RepLeader = I; + NewClass->setLeader(I); } assert(!isa<VariableExpression>(E) && "VariableExpression should have been handled already"); EClass = NewClass; DEBUG(dbgs() << "Created new congruence class for " << *I - << " using expression " << *E << " at " << NewClass->ID - << " and leader " << *(NewClass->RepLeader) << "\n"); - DEBUG(dbgs() << "Hash value was " << E->getHashValue() << "\n"); + << " using expression " << *E << " at " << NewClass->getID() + << " and leader " << *(NewClass->getLeader())); + if (NewClass->getStoredValue()) + DEBUG(dbgs() << " and stored value " << *(NewClass->getStoredValue())); + DEBUG(dbgs() << "\n"); } else { EClass = lookupResult.first->second; if (isa<ConstantExpression>(E)) - assert(isa<Constant>(EClass->RepLeader) && + assert((isa<Constant>(EClass->getLeader()) || + (EClass->getStoredValue() && + isa<Constant>(EClass->getStoredValue()))) && "Any class with a constant expression should have a " "constant leader"); assert(EClass && "Somehow don't have an eclass"); - assert(!EClass->Dead && "We accidentally looked up a dead class"); + assert(!EClass->isDead() && "We accidentally looked up a dead class"); } } bool ClassChanged = IClass != EClass; bool LeaderChanged = LeaderChanges.erase(I); if (ClassChanged || LeaderChanged) { - DEBUG(dbgs() << "Found class " << EClass->ID << " for expression " << E + DEBUG(dbgs() << "New class " << EClass->getID() << " for expression " << *E << "\n"); + if (ClassChanged) { + moveValueToNewCongruenceClass(I, E, IClass, EClass); + markPhiOfOpsChanged(E); + } - if (ClassChanged) - moveValueToNewCongruenceClass(I, IClass, EClass); markUsersTouched(I); - if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) { - // If this is a MemoryDef, we need to update the equivalence table. If - // we determined the expression is congruent to a different memory - // state, use that different memory state. If we determined it didn't, - // we update that as well. Right now, we only support store - // expressions. - if (!isa<MemoryUse>(MA) && isa<StoreExpression>(E) && - EClass->Members.size() != 1) { - auto *DefAccess = cast<StoreExpression>(E)->getDefiningAccess(); - setMemoryAccessEquivTo(MA, DefAccess != MA ? DefAccess : nullptr); - } else { - setMemoryAccessEquivTo(MA, nullptr); - } + if (MemoryAccess *MA = getMemoryAccess(I)) markMemoryUsersTouched(MA); - } - } else if (auto *SI = dyn_cast<StoreInst>(I)) { - // There is, sadly, one complicating thing for stores. Stores do not - // produce values, only consume them. However, in order to make loads and - // stores value number the same, we ignore the value operand of the store. - // But the value operand will still be the leader of our class, and thus, it - // may change. Because the store is a use, the store will get reprocessed, - // but nothing will change about it, and so nothing above will catch it - // (since the class will not change). In order to make sure everything ends - // up okay, we need to recheck the leader of the class. Since stores of - // different values value number differently due to different memorydefs, we - // are guaranteed the leader is always the same between stores in the same - // class. 
- DEBUG(dbgs() << "Checking store leader\n"); - auto ProperLeader = - lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent()); - if (EClass->RepLeader != ProperLeader) { - DEBUG(dbgs() << "Store leader changed, fixing\n"); - EClass->RepLeader = ProperLeader; - markLeaderChangeTouched(EClass); - markMemoryUsersTouched(MSSA->getMemoryAccess(SI)); + if (auto *CI = dyn_cast<CmpInst>(I)) + markPredicateUsersTouched(CI); + } + // If we changed the class of the store, we want to ensure nothing finds the + // old store expression. In particular, loads do not compare against stored + // value, so they will find old store expressions (and associated class + // mappings) if we leave them in the table. + if (ClassChanged && isa<StoreInst>(I)) { + auto *OldE = ValueToExpression.lookup(I); + // It could just be that the old class died. We don't want to erase it if we + // just moved classes. + if (OldE && isa<StoreExpression>(OldE) && *E != *OldE) { + // Erase this as an exact expression to ensure we don't erase expressions + // equivalent to it. + auto Iter = ExpressionToClass.find_as(ExactEqualsExpression(*OldE)); + if (Iter != ExpressionToClass.end()) + ExpressionToClass.erase(Iter); } } + ValueToExpression[I] = E; } // Process the fact that Edge (from, to) is reachable, including marking @@ -1266,25 +2310,26 @@ void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) { // impact predicates. Otherwise, only mark the phi nodes as touched, as // they are the only thing that depend on new edges. Anything using their // values will get propagated to if necessary. - if (MemoryAccess *MemPhi = MSSA->getMemoryAccess(To)) - TouchedInstructions.set(InstrDFS[MemPhi]); + if (MemoryAccess *MemPhi = getMemoryAccess(To)) + TouchedInstructions.set(InstrToDFSNum(MemPhi)); auto BI = To->begin(); while (isa<PHINode>(BI)) { - TouchedInstructions.set(InstrDFS[&*BI]); + TouchedInstructions.set(InstrToDFSNum(&*BI)); ++BI; } + for_each_found(PHIOfOpsPHIs, To, [&](const PHINode *I) { + TouchedInstructions.set(InstrToDFSNum(I)); + }); } } } // Given a predicate condition (from a switch, cmp, or whatever) and a block, // see if we know some constant value for it already. -Value *NewGVN::findConditionEquivalence(Value *Cond, BasicBlock *B) const { - auto Result = lookupOperandLeader(Cond, nullptr, B); - if (isa<Constant>(Result)) - return Result; - return nullptr; +Value *NewGVN::findConditionEquivalence(Value *Cond) const { + auto Result = lookupOperandLeader(Cond); + return isa<Constant>(Result) ? Result : nullptr; } // Process the outgoing edges of a block for reachability. 
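// updateReachableEdge and the touched-instruction set together behave like a
// standard reachability worklist; a standalone sketch (plain C++, blocks as
// small integers, every edge treated as live for brevity; hypothetical names):
#include <queue>
#include <set>
#include <vector>
static std::set<int>
reachableSketch(const std::vector<std::vector<int>> &Succs, int Entry) {
  std::set<int> Reachable{Entry};
  std::queue<int> Work;
  Work.push(Entry);
  while (!Work.empty()) {
    int B = Work.front();
    Work.pop();
    for (int S : Succs[B])
      if (Reachable.insert(S).second) // First time this block is proven
        Work.push(S);                 // reachable: queue it for processing.
  }
  return Reachable;
}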
@@ -1293,10 +2338,10 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { BranchInst *BR; if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) { Value *Cond = BR->getCondition(); - Value *CondEvaluated = findConditionEquivalence(Cond, B); + Value *CondEvaluated = findConditionEquivalence(Cond); if (!CondEvaluated) { if (auto *I = dyn_cast<Instruction>(Cond)) { - const Expression *E = createExpression(I, B); + const Expression *E = createExpression(I); if (const auto *CE = dyn_cast<ConstantExpression>(E)) { CondEvaluated = CE->getConstantValue(); } @@ -1329,13 +2374,13 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; Value *SwitchCond = SI->getCondition(); - Value *CondEvaluated = findConditionEquivalence(SwitchCond, B); + Value *CondEvaluated = findConditionEquivalence(SwitchCond); // See if we were able to turn this switch statement into a constant. if (CondEvaluated && isa<ConstantInt>(CondEvaluated)) { auto *CondVal = cast<ConstantInt>(CondEvaluated); // We should be able to get case value for this. - auto CaseVal = SI->findCaseValue(CondVal); - if (CaseVal.getCaseSuccessor() == SI->getDefaultDest()) { + auto Case = *SI->findCaseValue(CondVal); + if (Case.getCaseSuccessor() == SI->getDefaultDest()) { // We proved the value is outside of the range of the case. // We can't do anything other than mark the default dest as reachable, // and go home. @@ -1343,7 +2388,7 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { return; } // Now get where it goes and mark it reachable. - BasicBlock *TargetBlock = CaseVal.getCaseSuccessor(); + BasicBlock *TargetBlock = Case.getCaseSuccessor(); updateReachableEdge(B, TargetBlock); } else { for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { @@ -1361,45 +2406,215 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { } // This also may be a memory defining terminator, in which case, set it - // equivalent to nothing. - if (MemoryAccess *MA = MSSA->getMemoryAccess(TI)) - setMemoryAccessEquivTo(MA, nullptr); + // equivalent only to itself. + // + auto *MA = getMemoryAccess(TI); + if (MA && !isa<MemoryUse>(MA)) { + auto *CC = ensureLeaderOfMemoryClass(MA); + if (setMemoryClass(MA, CC)) + markMemoryUsersTouched(MA); + } } } -// The algorithm initially places the values of the routine in the INITIAL -// congruence -// class. The leader of INITIAL is the undetermined value `TOP`. -// When the algorithm has finished, values still in INITIAL are unreachable. +void NewGVN::addPhiOfOps(PHINode *Op, BasicBlock *BB, + Instruction *ExistingValue) { + InstrDFS[Op] = InstrToDFSNum(ExistingValue); + AllTempInstructions.insert(Op); + PHIOfOpsPHIs[BB].push_back(Op); + TempToBlock[Op] = BB; + RealToTemp[ExistingValue] = Op; +} + +static bool okayForPHIOfOps(const Instruction *I) { + return isa<BinaryOperator>(I) || isa<SelectInst>(I) || isa<CmpInst>(I) || + isa<LoadInst>(I); +} + +// When we see an instruction that is an op of phis, generate the equivalent phi +// of ops form. +const Expression * +NewGVN::makePossiblePhiOfOps(Instruction *I, + SmallPtrSetImpl<Value *> &Visited) { + if (!okayForPHIOfOps(I)) + return nullptr; + + if (!Visited.insert(I).second) + return nullptr; + // For now, we require the instruction be cycle free because we don't + // *always* create a phi of ops for instructions that could be done as phi + // of ops, we only do it if we think it is useful. 
If we did do it all the + // time, we could remove the cycle free check. + if (!isCycleFree(I)) + return nullptr; + + unsigned IDFSNum = InstrToDFSNum(I); + SmallPtrSet<const Value *, 8> ProcessedPHIs; + // TODO: We don't do phi translation on memory accesses because it's + // complicated. For a load, we'd need to be able to simulate a new memoryuse, + // which we don't have a good way of doing ATM. + auto *MemAccess = getMemoryAccess(I); + // If the memory operation is defined by a memory operation this block that + // isn't a MemoryPhi, transforming the pointer backwards through a scalar phi + // can't help, as it would still be killed by that memory operation. + if (MemAccess && !isa<MemoryPhi>(MemAccess->getDefiningAccess()) && + MemAccess->getDefiningAccess()->getBlock() == I->getParent()) + return nullptr; + + // Convert op of phis to phi of ops + for (auto &Op : I->operands()) { + // TODO: We can't handle expressions that must be recursively translated + // IE + // a = phi (b, c) + // f = use a + // g = f + phi of something + // To properly make a phi of ops for g, we'd have to properly translate and + // use the instruction for f. We should add this by splitting out the + // instruction creation we do below. + if (isa<Instruction>(Op) && PHINodeUses.count(cast<Instruction>(Op))) + return nullptr; + if (!isa<PHINode>(Op)) + continue; + auto *OpPHI = cast<PHINode>(Op); + // No point in doing this for one-operand phis. + if (OpPHI->getNumOperands() == 1) + continue; + if (!DebugCounter::shouldExecute(PHIOfOpsCounter)) + return nullptr; + SmallVector<std::pair<Value *, BasicBlock *>, 4> Ops; + auto *PHIBlock = getBlockForValue(OpPHI); + for (auto PredBB : OpPHI->blocks()) { + Value *FoundVal = nullptr; + // We could just skip unreachable edges entirely but it's tricky to do + // with rewriting existing phi nodes. + if (ReachableEdges.count({PredBB, PHIBlock})) { + // Clone the instruction, create an expression from it, and see if we + // have a leader. + Instruction *ValueOp = I->clone(); + if (MemAccess) + TempToMemory.insert({ValueOp, MemAccess}); + + for (auto &Op : ValueOp->operands()) { + Op = Op->DoPHITranslation(PHIBlock, PredBB); + // When this operand changes, it could change whether there is a + // leader for us or not. + addAdditionalUsers(Op, I); + } + // Make sure it's marked as a temporary instruction. + AllTempInstructions.insert(ValueOp); + // and make sure anything that tries to add it's DFS number is + // redirected to the instruction we are making a phi of ops + // for. 
+ InstrDFS.insert({ValueOp, IDFSNum}); + const Expression *E = performSymbolicEvaluation(ValueOp, Visited); + InstrDFS.erase(ValueOp); + AllTempInstructions.erase(ValueOp); + ValueOp->deleteValue(); + if (MemAccess) + TempToMemory.erase(ValueOp); + if (!E) + return nullptr; + FoundVal = findPhiOfOpsLeader(E, PredBB); + if (!FoundVal) { + ExpressionToPhiOfOps[E].insert(I); + return nullptr; + } + if (auto *SI = dyn_cast<StoreInst>(FoundVal)) + FoundVal = SI->getValueOperand(); + } else { + DEBUG(dbgs() << "Skipping phi of ops operand for incoming block " + << getBlockName(PredBB) + << " because the block is unreachable\n"); + FoundVal = UndefValue::get(I->getType()); + } + + Ops.push_back({FoundVal, PredBB}); + DEBUG(dbgs() << "Found phi of ops operand " << *FoundVal << " in " + << getBlockName(PredBB) << "\n"); + } + auto *ValuePHI = RealToTemp.lookup(I); + bool NewPHI = false; + if (!ValuePHI) { + ValuePHI = PHINode::Create(I->getType(), OpPHI->getNumOperands()); + addPhiOfOps(ValuePHI, PHIBlock, I); + NewPHI = true; + NumGVNPHIOfOpsCreated++; + } + if (NewPHI) { + for (auto PHIOp : Ops) + ValuePHI->addIncoming(PHIOp.first, PHIOp.second); + } else { + unsigned int i = 0; + for (auto PHIOp : Ops) { + ValuePHI->setIncomingValue(i, PHIOp.first); + ValuePHI->setIncomingBlock(i, PHIOp.second); + ++i; + } + } + + DEBUG(dbgs() << "Created phi of ops " << *ValuePHI << " for " << *I + << "\n"); + return performSymbolicEvaluation(ValuePHI, Visited); + } + return nullptr; +} + +// The algorithm initially places the values of the routine in the TOP +// congruence class. The leader of TOP is the undetermined value `undef`. +// When the algorithm has finished, values still in TOP are unreachable. void NewGVN::initializeCongruenceClasses(Function &F) { - // FIXME now i can't remember why this is 2 - NextCongruenceNum = 2; - // Initialize all other instructions to be in INITIAL class. - CongruenceClass::MemberSet InitialValues; - InitialClass = createCongruenceClass(nullptr, nullptr); - for (auto &B : F) { - if (auto *MP = MSSA->getMemoryAccess(&B)) - MemoryAccessEquiv.insert({MP, MSSA->getLiveOnEntryDef()}); - - for (auto &I : B) { - InitialValues.insert(&I); - ValueToClass[&I] = InitialClass; - // All memory accesses are equivalent to live on entry to start. They must - // be initialized to something so that initial changes are noticed. For - // the maximal answer, we initialize them all to be the same as - // liveOnEntry. Note that to save time, we only initialize the - // MemoryDef's for stores and all MemoryPhis to be equal. Right now, no - // other expression can generate a memory equivalence. If we start - // handling memcpy/etc, we can expand this. - if (isa<StoreInst>(&I)) { - MemoryAccessEquiv.insert( - {MSSA->getMemoryAccess(&I), MSSA->getLiveOnEntryDef()}); - ++InitialClass->StoreCount; - assert(InitialClass->StoreCount > 0); + NextCongruenceNum = 0; + + // Note that even though we use the live on entry def as a representative + // MemoryAccess, it is *not* the same as the actual live on entry def. We + // have no real equivalemnt to undef for MemoryAccesses, and so we really + // should be checking whether the MemoryAccess is top if we want to know if it + // is equivalent to everything. Otherwise, what this really signifies is that + // the access "it reaches all the way back to the beginning of the function" + + // Initialize all other instructions to be in TOP class. 
+ TOPClass = createCongruenceClass(nullptr, nullptr); + TOPClass->setMemoryLeader(MSSA->getLiveOnEntryDef()); + // The live on entry def gets put into it's own class + MemoryAccessToClass[MSSA->getLiveOnEntryDef()] = + createMemoryClass(MSSA->getLiveOnEntryDef()); + + for (auto DTN : nodes(DT)) { + BasicBlock *BB = DTN->getBlock(); + // All MemoryAccesses are equivalent to live on entry to start. They must + // be initialized to something so that initial changes are noticed. For + // the maximal answer, we initialize them all to be the same as + // liveOnEntry. + auto *MemoryBlockDefs = MSSA->getBlockDefs(BB); + if (MemoryBlockDefs) + for (const auto &Def : *MemoryBlockDefs) { + MemoryAccessToClass[&Def] = TOPClass; + auto *MD = dyn_cast<MemoryDef>(&Def); + // Insert the memory phis into the member list. + if (!MD) { + const MemoryPhi *MP = cast<MemoryPhi>(&Def); + TOPClass->memory_insert(MP); + MemoryPhiState.insert({MP, MPS_TOP}); + } + + if (MD && isa<StoreInst>(MD->getMemoryInst())) + TOPClass->incStoreCount(); } + for (auto &I : *BB) { + // TODO: Move to helper + if (isa<PHINode>(&I)) + for (auto *U : I.users()) + if (auto *UInst = dyn_cast<Instruction>(U)) + if (InstrToDFSNum(UInst) != 0 && okayForPHIOfOps(UInst)) + PHINodeUses.insert(UInst); + // Don't insert void terminators into the class. We don't value number + // them, and they just end up sitting in TOP. + if (isa<TerminatorInst>(I) && I.getType()->isVoidTy()) + continue; + TOPClass->insert(&I); + ValueToClass[&I] = TOPClass; } } - InitialClass->Members.swap(InitialValues); // Initialize arguments to be in their own unique congruence classes for (auto &FA : F.args()) @@ -1408,45 +2623,79 @@ void NewGVN::initializeCongruenceClasses(Function &F) { void NewGVN::cleanupTables() { for (unsigned i = 0, e = CongruenceClasses.size(); i != e; ++i) { - DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->ID << " has " - << CongruenceClasses[i]->Members.size() << " members\n"); + DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->getID() + << " has " << CongruenceClasses[i]->size() << " members\n"); // Make sure we delete the congruence class (probably worth switching to // a unique_ptr at some point. delete CongruenceClasses[i]; CongruenceClasses[i] = nullptr; } + // Destroy the value expressions + SmallVector<Instruction *, 8> TempInst(AllTempInstructions.begin(), + AllTempInstructions.end()); + AllTempInstructions.clear(); + + // We have to drop all references for everything first, so there are no uses + // left as we delete them. + for (auto *I : TempInst) { + I->dropAllReferences(); + } + + while (!TempInst.empty()) { + auto *I = TempInst.back(); + TempInst.pop_back(); + I->deleteValue(); + } + ValueToClass.clear(); ArgRecycler.clear(ExpressionAllocator); ExpressionAllocator.Reset(); CongruenceClasses.clear(); ExpressionToClass.clear(); ValueToExpression.clear(); + RealToTemp.clear(); + AdditionalUsers.clear(); + ExpressionToPhiOfOps.clear(); + TempToBlock.clear(); + TempToMemory.clear(); + PHIOfOpsPHIs.clear(); ReachableBlocks.clear(); ReachableEdges.clear(); #ifndef NDEBUG ProcessedCount.clear(); #endif - DFSDomMap.clear(); InstrDFS.clear(); InstructionsToErase.clear(); - DFSToInstr.clear(); BlockInstRange.clear(); TouchedInstructions.clear(); - DominatedInstRange.clear(); - MemoryAccessEquiv.clear(); + MemoryAccessToClass.clear(); + PredicateToUsers.clear(); + MemoryToUsers.clear(); } +// Assign local DFS number mapping to instructions, and leave space for Value +// PHI's. 
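// assignDFSNumbers below reserves 0 as a "dead, never value number" sentinel;
// a standalone sketch of the numbering (plain C++, hypothetical instruction
// type, not part of this patch):
#include <unordered_map>
#include <vector>
struct InstSketch { bool TriviallyDead; };
static unsigned
numberBlockSketch(const std::vector<InstSketch *> &Block, unsigned Start,
                  std::unordered_map<InstSketch *, unsigned> &DFS) {
  unsigned End = Start;
  for (InstSketch *I : Block)
    DFS[I] = I->TriviallyDead ? 0u : End++; // 0 marks skipped instructions.
  return End; // First DFS number available to the next block.
}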
std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B, unsigned Start) { unsigned End = Start; - if (MemoryAccess *MemPhi = MSSA->getMemoryAccess(B)) { + if (MemoryAccess *MemPhi = getMemoryAccess(B)) { InstrDFS[MemPhi] = End++; DFSToInstr.emplace_back(MemPhi); } + // Then the real block goes next. for (auto &I : *B) { + // There's no need to call isInstructionTriviallyDead more than once on + // an instruction. Therefore, once we know that an instruction is dead + // we change its DFS number so that it doesn't get value numbered. + if (isInstructionTriviallyDead(&I, TLI)) { + InstrDFS[&I] = 0; + DEBUG(dbgs() << "Skipping trivially dead instruction " << I << "\n"); + markInstructionForDeletion(&I); + continue; + } InstrDFS[&I] = End++; DFSToInstr.emplace_back(&I); } @@ -1457,12 +2706,12 @@ std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B, return std::make_pair(Start, End); } -void NewGVN::updateProcessedCount(Value *V) { +void NewGVN::updateProcessedCount(const Value *V) { #ifndef NDEBUG if (ProcessedCount.count(V) == 0) { ProcessedCount.insert({V, 1}); } else { - ProcessedCount[V] += 1; + ++ProcessedCount[V]; assert(ProcessedCount[V] < 100 && "Seem to have processed the same Value a lot"); } @@ -1471,27 +2720,35 @@ void NewGVN::updateProcessedCount(Value *V) { // Evaluate MemoryPhi nodes symbolically, just like PHI nodes void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) { // If all the arguments are the same, the MemoryPhi has the same value as the - // argument. - // Filter out unreachable blocks from our operands. + // argument. Filter out unreachable blocks and self phis from our operands. + // TODO: We could do cycle-checking on the memory phis to allow valueizing for + // self-phi checking. + const BasicBlock *PHIBlock = MP->getBlock(); auto Filtered = make_filter_range(MP->operands(), [&](const Use &U) { - return ReachableBlocks.count(MP->getIncomingBlock(U)); + return cast<MemoryAccess>(U) != MP && + !isMemoryAccessTOP(cast<MemoryAccess>(U)) && + ReachableEdges.count({MP->getIncomingBlock(U), PHIBlock}); }); - - assert(Filtered.begin() != Filtered.end() && - "We should not be processing a MemoryPhi in a completely " - "unreachable block"); + // If all that is left is nothing, our memoryphi is undef. We keep it as + // InitialClass. Note: The only case this should happen is if we have at + // least one self-argument. + if (Filtered.begin() == Filtered.end()) { + if (setMemoryClass(MP, TOPClass)) + markMemoryUsersTouched(MP); + return; + } // Transform the remaining operands into operand leaders. // FIXME: mapped_iterator should have a range version. auto LookupFunc = [&](const Use &U) { - return lookupMemoryAccessEquiv(cast<MemoryAccess>(U)); + return lookupMemoryLeader(cast<MemoryAccess>(U)); }; auto MappedBegin = map_iterator(Filtered.begin(), LookupFunc); auto MappedEnd = map_iterator(Filtered.end(), LookupFunc); // and now check if all the elements are equal. // Sadly, we can't use std::equals since these are random access iterators. - MemoryAccess *AllSameValue = *MappedBegin; + const auto *AllSameValue = *MappedBegin; ++MappedBegin; bool AllEqual = std::all_of( MappedBegin, MappedEnd, @@ -1501,8 +2758,18 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) { DEBUG(dbgs() << "Memory Phi value numbered to " << *AllSameValue << "\n"); else DEBUG(dbgs() << "Memory Phi value numbered to itself\n"); - - if (setMemoryAccessEquivTo(MP, AllEqual ? AllSameValue : nullptr)) + // If it's equal to something, it's in that class. 
Otherwise, it has to be in + // a class where it is the leader (other things may be equivalent to it, but + // it needs to start off in its own class, which means it must have been the + // leader, and it can't have stopped being the leader because it was never + // removed). + CongruenceClass *CC = + AllEqual ? getMemoryClass(AllSameValue) : ensureLeaderOfMemoryClass(MP); + auto OldState = MemoryPhiState.lookup(MP); + assert(OldState != MPS_Invalid && "Invalid memory phi state"); + auto NewState = AllEqual ? MPS_Equivalent : MPS_Unique; + MemoryPhiState[MP] = NewState; + if (setMemoryClass(MP, CC) || OldState != NewState) markMemoryUsersTouched(MP); } @@ -1510,13 +2777,23 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) { // congruence finding, and updating mappings. void NewGVN::valueNumberInstruction(Instruction *I) { DEBUG(dbgs() << "Processing instruction " << *I << "\n"); - if (isInstructionTriviallyDead(I, TLI)) { - DEBUG(dbgs() << "Skipping unused instruction\n"); - markInstructionForDeletion(I); - return; - } if (!I->isTerminator()) { - const auto *Symbolized = performSymbolicEvaluation(I, I->getParent()); + const Expression *Symbolized = nullptr; + SmallPtrSet<Value *, 2> Visited; + if (DebugCounter::shouldExecute(VNCounter)) { + Symbolized = performSymbolicEvaluation(I, Visited); + // Make a phi of ops if necessary + if (Symbolized && !isa<ConstantExpression>(Symbolized) && + !isa<VariableExpression>(Symbolized) && PHINodeUses.count(I)) { + auto *PHIE = makePossiblePhiOfOps(I, Visited); + if (PHIE) + Symbolized = PHIE; + } + + } else { + // Mark the instruction as unused so we don't value number it again. + InstrDFS[I] = 0; + } // If we couldn't come up with a symbolic expression, use the unknown // expression if (Symbolized == nullptr) @@ -1524,7 +2801,8 @@ void NewGVN::valueNumberInstruction(Instruction *I) { performCongruenceFinding(I, Symbolized); } else { // Handle terminators that return values. All of them produce values we - // don't currently understand. + // don't currently understand. We don't place non-value producing + // terminators in a class. if (!I->getType()->isVoidTy()) { auto *Symbolized = createUnknownExpression(I); performCongruenceFinding(I, Symbolized); @@ -1535,76 +2813,126 @@ void NewGVN::valueNumberInstruction(Instruction *I) { // Check if there is a path, using single or equal argument phi nodes, from // First to Second. 
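// singleReachablePHIPath below must tolerate cycles in the def chain; a
// standalone sketch of the guard (plain C++, hypothetical node type): the
// visited set trades precision for guaranteed termination, answering
// optimistically on a cycle just as the verifier does.
#include <unordered_set>
struct NodeSketch { NodeSketch *Def; };
static bool reachesSketch(const NodeSketch *From, const NodeSketch *To) {
  std::unordered_set<const NodeSketch *> Visited;
  for (const NodeSketch *N = From; N; N = N->Def) {
    if (N == To)
      return true;
    if (!Visited.insert(N).second)
      return true; // Revisited a node: cycle; accept rather than loop forever.
  }
  return false;
}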
-bool NewGVN::singleReachablePHIPath(const MemoryAccess *First, - const MemoryAccess *Second) const { +bool NewGVN::singleReachablePHIPath( + SmallPtrSet<const MemoryAccess *, 8> &Visited, const MemoryAccess *First, + const MemoryAccess *Second) const { if (First == Second) return true; - - if (auto *FirstDef = dyn_cast<MemoryUseOrDef>(First)) { - auto *DefAccess = FirstDef->getDefiningAccess(); - return singleReachablePHIPath(DefAccess, Second); - } else { - auto *MP = cast<MemoryPhi>(First); - auto ReachableOperandPred = [&](const Use &U) { - return ReachableBlocks.count(MP->getIncomingBlock(U)); - }; - auto FilteredPhiArgs = - make_filter_range(MP->operands(), ReachableOperandPred); - SmallVector<const Value *, 32> OperandList; - std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(), - std::back_inserter(OperandList)); - bool Okay = OperandList.size() == 1; - if (!Okay) - Okay = std::equal(OperandList.begin(), OperandList.end(), - OperandList.begin()); - if (Okay) - return singleReachablePHIPath(cast<MemoryAccess>(OperandList[0]), Second); + if (MSSA->isLiveOnEntryDef(First)) return false; + + // This is not perfect, but as we're just verifying here, we can live with + // the loss of precision. The real solution would be that of doing strongly + // connected component finding in this routine, and it's probably not worth + // the complexity for the time being. So, we just keep a set of visited + // MemoryAccess and return true when we hit a cycle. + if (Visited.count(First)) + return true; + Visited.insert(First); + + const auto *EndDef = First; + for (auto *ChainDef : optimized_def_chain(First)) { + if (ChainDef == Second) + return true; + if (MSSA->isLiveOnEntryDef(ChainDef)) + return false; + EndDef = ChainDef; } + auto *MP = cast<MemoryPhi>(EndDef); + auto ReachableOperandPred = [&](const Use &U) { + return ReachableEdges.count({MP->getIncomingBlock(U), MP->getBlock()}); + }; + auto FilteredPhiArgs = + make_filter_range(MP->operands(), ReachableOperandPred); + SmallVector<const Value *, 32> OperandList; + std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(), + std::back_inserter(OperandList)); + bool Okay = OperandList.size() == 1; + if (!Okay) + Okay = + std::equal(OperandList.begin(), OperandList.end(), OperandList.begin()); + if (Okay) + return singleReachablePHIPath(Visited, cast<MemoryAccess>(OperandList[0]), + Second); + return false; } // Verify the that the memory equivalence table makes sense relative to the // congruence classes. Note that this checking is not perfect, and is currently -// subject to very rare false negatives. It is only useful for testing/debugging. +// subject to very rare false negatives. It is only useful for +// testing/debugging. 
void NewGVN::verifyMemoryCongruency() const { - // Anything equivalent in the memory access table should be in the same +#ifndef NDEBUG + // Verify that the memory table equivalence and memory member set match + for (const auto *CC : CongruenceClasses) { + if (CC == TOPClass || CC->isDead()) + continue; + if (CC->getStoreCount() != 0) { + assert((CC->getStoredValue() || !isa<StoreInst>(CC->getLeader())) && + "Any class with a store as a leader should have a " + "representative stored value"); + assert(CC->getMemoryLeader() && + "Any congruence class with a store should have a " + "representative access"); + } + + if (CC->getMemoryLeader()) + assert(MemoryAccessToClass.lookup(CC->getMemoryLeader()) == CC && + "Representative MemoryAccess does not appear to be reverse " + "mapped properly"); + for (auto M : CC->memory()) + assert(MemoryAccessToClass.lookup(M) == CC && + "Memory member does not appear to be reverse mapped properly"); + } + + // Anything equivalent in the MemoryAccess table should be in the same // congruence class. // Filter out the unreachable and trivially dead entries, because they may // never have been updated if the instructions were not processed. auto ReachableAccessPred = - [&](const std::pair<const MemoryAccess *, MemoryAccess *> Pair) { + [&](const std::pair<const MemoryAccess *, CongruenceClass *> Pair) { bool Result = ReachableBlocks.count(Pair.first->getBlock()); - if (!Result) + if (!Result || MSSA->isLiveOnEntryDef(Pair.first) || + MemoryToDFSNum(Pair.first) == 0) return false; if (auto *MemDef = dyn_cast<MemoryDef>(Pair.first)) return !isInstructionTriviallyDead(MemDef->getMemoryInst()); + + // We could have phi nodes which operands are all trivially dead, + // so we don't process them. + if (auto *MemPHI = dyn_cast<MemoryPhi>(Pair.first)) { + for (auto &U : MemPHI->incoming_values()) { + if (Instruction *I = dyn_cast<Instruction>(U.get())) { + if (!isInstructionTriviallyDead(I)) + return true; + } + } + return false; + } + return true; }; - auto Filtered = make_filter_range(MemoryAccessEquiv, ReachableAccessPred); + auto Filtered = make_filter_range(MemoryAccessToClass, ReachableAccessPred); for (auto KV : Filtered) { - assert(KV.first != KV.second && - "We added a useless equivalence to the memory equivalence table"); - // Unreachable instructions may not have changed because we never process - // them. - if (!ReachableBlocks.count(KV.first->getBlock())) - continue; if (auto *FirstMUD = dyn_cast<MemoryUseOrDef>(KV.first)) { - auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second); - if (FirstMUD && SecondMUD) - assert((singleReachablePHIPath(FirstMUD, SecondMUD) || - ValueToClass.lookup(FirstMUD->getMemoryInst()) == - ValueToClass.lookup(SecondMUD->getMemoryInst())) && - "The instructions for these memory operations should have " - "been in the same congruence class or reachable through" - "a single argument phi"); + auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second->getMemoryLeader()); + if (FirstMUD && SecondMUD) { + SmallPtrSet<const MemoryAccess *, 8> VisitedMAS; + assert((singleReachablePHIPath(VisitedMAS, FirstMUD, SecondMUD) || + ValueToClass.lookup(FirstMUD->getMemoryInst()) == + ValueToClass.lookup(SecondMUD->getMemoryInst())) && + "The instructions for these memory operations should have " + "been in the same congruence class or reachable through" + "a single argument phi"); + } } else if (auto *FirstMP = dyn_cast<MemoryPhi>(KV.first)) { - // We can only sanely verify that MemoryDefs in the operand list all have // the same class. 
auto ReachableOperandPred = [&](const Use &U) { - return ReachableBlocks.count(FirstMP->getIncomingBlock(U)) && + return ReachableEdges.count( + {FirstMP->getIncomingBlock(U), FirstMP->getBlock()}) && isa<MemoryDef>(U); }; @@ -1622,35 +2950,179 @@ void NewGVN::verifyMemoryCongruency() const { "All MemoryPhi arguments should be in the same class"); } } +#endif +} + +// Verify that the sparse propagation we did actually found the maximal fixpoint +// We do this by storing the value to class mapping, touching all instructions, +// and redoing the iteration to see if anything changed. +void NewGVN::verifyIterationSettled(Function &F) { +#ifndef NDEBUG + DEBUG(dbgs() << "Beginning iteration verification\n"); + if (DebugCounter::isCounterSet(VNCounter)) + DebugCounter::setCounterValue(VNCounter, StartingVNCounter); + + // Note that we have to store the actual classes, as we may change existing + // classes during iteration. This is because our memory iteration propagation + // is not perfect, and so may waste a little work. But it should generate + // exactly the same congruence classes we have now, with different IDs. + std::map<const Value *, CongruenceClass> BeforeIteration; + + for (auto &KV : ValueToClass) { + if (auto *I = dyn_cast<Instruction>(KV.first)) + // Skip unused/dead instructions. + if (InstrToDFSNum(I) == 0) + continue; + BeforeIteration.insert({KV.first, *KV.second}); + } + + TouchedInstructions.set(); + TouchedInstructions.reset(0); + iterateTouchedInstructions(); + DenseSet<std::pair<const CongruenceClass *, const CongruenceClass *>> + EqualClasses; + for (const auto &KV : ValueToClass) { + if (auto *I = dyn_cast<Instruction>(KV.first)) + // Skip unused/dead instructions. + if (InstrToDFSNum(I) == 0) + continue; + // We could sink these uses, but i think this adds a bit of clarity here as + // to what we are comparing. + auto *BeforeCC = &BeforeIteration.find(KV.first)->second; + auto *AfterCC = KV.second; + // Note that the classes can't change at this point, so we memoize the set + // that are equal. + if (!EqualClasses.count({BeforeCC, AfterCC})) { + assert(BeforeCC->isEquivalentTo(AfterCC) && + "Value number changed after main loop completed!"); + EqualClasses.insert({BeforeCC, AfterCC}); + } + } +#endif +} + +// Verify that for each store expression in the expression to class mapping, +// only the latest appears, and multiple ones do not appear. +// Because loads do not use the stored value when doing equality with stores, +// if we don't erase the old store expressions from the table, a load can find +// a no-longer valid StoreExpression. +void NewGVN::verifyStoreExpressions() const { +#ifndef NDEBUG + // This is the only use of this, and it's not worth defining a complicated + // densemapinfo hash/equality function for it. + std::set< + std::pair<const Value *, + std::tuple<const Value *, const CongruenceClass *, Value *>>> + StoreExpressionSet; + for (const auto &KV : ExpressionToClass) { + if (auto *SE = dyn_cast<StoreExpression>(KV.first)) { + // Make sure a version that will conflict with loads is not already there + auto Res = StoreExpressionSet.insert( + {SE->getOperand(0), std::make_tuple(SE->getMemoryLeader(), KV.second, + SE->getStoredValue())}); + bool Okay = Res.second; + // It's okay to have the same expression already in there if it is + // identical in nature. + // This can happen when the leader of the stored value changes over time. 
+      if (!Okay)
+        Okay = (std::get<1>(Res.first->second) == KV.second) &&
+               (lookupOperandLeader(std::get<2>(Res.first->second)) ==
+                lookupOperandLeader(SE->getStoredValue()));
+      assert(Okay && "Stored expression conflict exists in expression table");
+      auto *ValueExpr = ValueToExpression.lookup(SE->getStoreInst());
+      assert(ValueExpr && ValueExpr->equals(*SE) &&
+             "StoreExpression in ExpressionToClass is not latest "
+             "StoreExpression for value");
+    }
+  }
+#endif
+}
+
+// This is the main value numbering loop; it iterates over the initial touched
+// instruction set, propagating value numbers and marking things touched,
+// until the set of touched instructions is completely empty.
+void NewGVN::iterateTouchedInstructions() {
+  unsigned int Iterations = 0;
+  // Figure out where TouchedInstructions starts.
+  int FirstInstr = TouchedInstructions.find_first();
+  // Nothing set, nothing to iterate, just return.
+  if (FirstInstr == -1)
+    return;
+  const BasicBlock *LastBlock = getBlockForValue(InstrFromDFSNum(FirstInstr));
+  while (TouchedInstructions.any()) {
+    ++Iterations;
+    // Walk through all the instructions in all the blocks in RPO.
+    // TODO: As we hit a new block, we should push and pop equalities into a
+    // table lookupOperandLeader can use, to catch things PredicateInfo
+    // might miss, like edge-only equivalences.
+    for (unsigned InstrNum : TouchedInstructions.set_bits()) {
+
+      // This instruction was found to be dead. We don't bother looking
+      // at it again.
+      if (InstrNum == 0) {
+        TouchedInstructions.reset(InstrNum);
+        continue;
+      }
+
+      Value *V = InstrFromDFSNum(InstrNum);
+      const BasicBlock *CurrBlock = getBlockForValue(V);
+
+      // If we hit a new block, do reachability processing.
+      if (CurrBlock != LastBlock) {
+        LastBlock = CurrBlock;
+        bool BlockReachable = ReachableBlocks.count(CurrBlock);
+        const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock);
+
+        // If it's not reachable, erase any touched instructions and move on.
+        if (!BlockReachable) {
+          TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second);
+          DEBUG(dbgs() << "Skipping instructions in block "
+                       << getBlockName(CurrBlock)
+                       << " because it is unreachable\n");
+          continue;
+        }
+        updateProcessedCount(CurrBlock);
+      }
+      // Reset after processing (because we may mark ourselves as touched when
+      // we propagate equalities).
+      TouchedInstructions.reset(InstrNum);
+
+      if (auto *MP = dyn_cast<MemoryPhi>(V)) {
+        DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n");
+        valueNumberMemoryPhi(MP);
+      } else if (auto *I = dyn_cast<Instruction>(V)) {
+        valueNumberInstruction(I);
+      } else {
+        llvm_unreachable("Should have been a MemoryPhi or Instruction");
+      }
+      updateProcessedCount(V);
+    }
+  }
   NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations);
 }
 
 // This is the main transformation entry point.
-bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
-                    TargetLibraryInfo *_TLI, AliasAnalysis *_AA,
-                    MemorySSA *_MSSA) {
+bool NewGVN::runGVN() {
+  if (DebugCounter::isCounterSet(VNCounter))
+    StartingVNCounter = DebugCounter::getCounterValue(VNCounter);
   bool Changed = false;
-  DT = _DT;
-  AC = _AC;
-  TLI = _TLI;
-  AA = _AA;
-  MSSA = _MSSA;
-  DL = &F.getParent()->getDataLayout();
+  NumFuncArgs = F.arg_size();
   MSSAWalker = MSSA->getWalker();
+  SingletonDeadExpression = new (ExpressionAllocator) DeadExpression();
 
   // Count number of instructions for sizing of hash tables, and come
   // up with a global dfs numbering for instructions.
   unsigned ICount = 1;
   // Add an empty instruction to account for the fact that we start at 1.
   DFSToInstr.emplace_back(nullptr);
-  // Note: We want RPO traversal of the blocks, which is not quite the same as
-  // dominator tree order, particularly with regard whether backedges get
-  // visited first or second, given a block with multiple successors.
+  // Note: We want ideal RPO traversal of the blocks, which is not quite the
+  // same as dominator tree order, particularly with regard to whether
+  // backedges get visited first or second, given a block with multiple
+  // successors.
   // If we visit in the wrong order, we will end up performing N times as many
   // iterations.
   // The dominator tree does guarantee that, for a given dom tree node, its
   // parent must occur before it in the RPO ordering. Thus, we only need to sort
   // the siblings.
-  DenseMap<const DomTreeNode *, unsigned> RPOOrdering;
   ReversePostOrderTraversal<Function *> RPOT(&F);
   unsigned Counter = 0;
   for (auto &B : RPOT) {
@@ -1663,33 +3135,21 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
     auto *Node = DT->getNode(B);
     if (Node->getChildren().size() > 1)
       std::sort(Node->begin(), Node->end(),
-                [&RPOOrdering](const DomTreeNode *A, const DomTreeNode *B) {
+                [&](const DomTreeNode *A, const DomTreeNode *B) {
                   return RPOOrdering[A] < RPOOrdering[B];
                 });
   }
 
   // Now a standard depth first ordering of the domtree is equivalent to RPO.
-  auto DFI = df_begin(DT->getRootNode());
-  for (auto DFE = df_end(DT->getRootNode()); DFI != DFE; ++DFI) {
-    BasicBlock *B = DFI->getBlock();
+  for (auto DTN : depth_first(DT->getRootNode())) {
+    BasicBlock *B = DTN->getBlock();
     const auto &BlockRange = assignDFSNumbers(B, ICount);
     BlockInstRange.insert({B, BlockRange});
     ICount += BlockRange.second - BlockRange.first;
   }
-
-  // Handle forward unreachable blocks and figure out which blocks
-  // have single preds.
-  for (auto &B : F) {
-    // Assign numbers to unreachable blocks.
-    if (!DFI.nodeVisited(DT->getNode(&B))) {
-      const auto &BlockRange = assignDFSNumbers(&B, ICount);
-      BlockInstRange.insert({&B, BlockRange});
-      ICount += BlockRange.second - BlockRange.first;
-    }
-  }
+  initializeCongruenceClasses(F);
 
   TouchedInstructions.resize(ICount);
-  DominatedInstRange.reserve(F.size());
   // Ensure we don't end up resizing the expressionToClass map, as
   // that can be quite expensive. At most, we have one expression per
   // instruction.
@@ -1698,65 +3158,15 @@
   // Initialize the touched instructions to include the entry block.
   const auto &InstRange = BlockInstRange.lookup(&F.getEntryBlock());
   TouchedInstructions.set(InstRange.first, InstRange.second);
+  DEBUG(dbgs() << "Block " << getBlockName(&F.getEntryBlock())
+               << " marked reachable\n");
   ReachableBlocks.insert(&F.getEntryBlock());
 
-  initializeCongruenceClasses(F);
-
-  unsigned int Iterations = 0;
-  // We start out in the entry block.
-  BasicBlock *LastBlock = &F.getEntryBlock();
-  while (TouchedInstructions.any()) {
-    ++Iterations;
-    // Walk through all the instructions in all the blocks in RPO.
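The main loop, shown in its old inline form below and factored into iterateTouchedInstructions above, amounts to draining a bitset worklist until a full sweep finds nothing touched. A minimal standalone sketch of the scheme in plain C++ (std::vector<bool> standing in for LLVM's BitVector; illustrative only):

#include <functional>
#include <vector>

// Drain a "touched" worklist to a fixpoint. Processing an entry may touch
// other entries (including ones already passed this sweep), so we keep
// sweeping until a full pass finds nothing set.
void iterateToFixpoint(std::vector<bool> &Touched,
                       const std::function<void(unsigned)> &Process) {
  bool AnySet = true;
  while (AnySet) {
    AnySet = false;
    for (unsigned I = 0, E = Touched.size(); I != E; ++I) {
      if (!Touched[I])
        continue;
      AnySet = true;
      Touched[I] = false; // Reset before processing; Process may re-touch I.
      Process(I);         // May set other bits in Touched via capture.
    }
  }
}

Visiting indices in RPO keeps the number of sweeps small, which is exactly why the block numbering above sorts dominator-tree siblings by RPO position.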
- for (int InstrNum = TouchedInstructions.find_first(); InstrNum != -1; - InstrNum = TouchedInstructions.find_next(InstrNum)) { - assert(InstrNum != 0 && "Bit 0 should never be set, something touched an " - "instruction not in the lookup table"); - Value *V = DFSToInstr[InstrNum]; - BasicBlock *CurrBlock = nullptr; - - if (auto *I = dyn_cast<Instruction>(V)) - CurrBlock = I->getParent(); - else if (auto *MP = dyn_cast<MemoryPhi>(V)) - CurrBlock = MP->getBlock(); - else - llvm_unreachable("DFSToInstr gave us an unknown type of instruction"); - - // If we hit a new block, do reachability processing. - if (CurrBlock != LastBlock) { - LastBlock = CurrBlock; - bool BlockReachable = ReachableBlocks.count(CurrBlock); - const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock); - - // If it's not reachable, erase any touched instructions and move on. - if (!BlockReachable) { - TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second); - DEBUG(dbgs() << "Skipping instructions in block " - << getBlockName(CurrBlock) - << " because it is unreachable\n"); - continue; - } - updateProcessedCount(CurrBlock); - } - - if (auto *MP = dyn_cast<MemoryPhi>(V)) { - DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n"); - valueNumberMemoryPhi(MP); - } else if (auto *I = dyn_cast<Instruction>(V)) { - valueNumberInstruction(I); - } else { - llvm_unreachable("Should have been a MemoryPhi or Instruction"); - } - updateProcessedCount(V); - // Reset after processing (because we may mark ourselves as touched when - // we propagate equalities). - TouchedInstructions.reset(InstrNum); - } - } - NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations); -#ifndef NDEBUG + iterateTouchedInstructions(); verifyMemoryCongruency(); -#endif + verifyIterationSettled(F); + verifyStoreExpressions(); + Changed |= eliminateInstructions(F); // Delete all instructions marked for deletion. @@ -1764,7 +3174,8 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC, if (!ToErase->use_empty()) ToErase->replaceAllUsesWith(UndefValue::get(ToErase->getType())); - ToErase->eraseFromParent(); + if (ToErase->getParent()) + ToErase->eraseFromParent(); } // Delete all unreachable blocks. @@ -1783,59 +3194,15 @@ bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC, return Changed; } -bool NewGVN::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; - return runGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), - &getAnalysis<AAResultsWrapperPass>().getAAResults(), - &getAnalysis<MemorySSAWrapperPass>().getMSSA()); -} - -PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) { - NewGVN Impl; - - // Apparently the order in which we get these results matter for - // the old GVN (see Chandler's comment in GVN.cpp). I'll keep - // the same order here, just in case. 
-  auto &AC = AM.getResult<AssumptionAnalysis>(F);
-  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
-  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
-  auto &AA = AM.getResult<AAManager>(F);
-  auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
-  bool Changed = Impl.runGVN(F, &DT, &AC, &TLI, &AA, &MSSA);
-  if (!Changed)
-    return PreservedAnalyses::all();
-  PreservedAnalyses PA;
-  PA.preserve<DominatorTreeAnalysis>();
-  PA.preserve<GlobalsAA>();
-  return PA;
-}
-
-// Return true if V is a value that will always be available (IE can
-// be placed anywhere) in the function. We don't do globals here
-// because they are often worse to put in place.
-// TODO: Separate cost from availability
-static bool alwaysAvailable(Value *V) {
-  return isa<Constant>(V) || isa<Argument>(V);
-}
-
-// Get the basic block from an instruction/value.
-static BasicBlock *getBlockForValue(Value *V) {
-  if (auto *I = dyn_cast<Instruction>(V))
-    return I->getParent();
-  return nullptr;
-}
-
 struct NewGVN::ValueDFS {
   int DFSIn = 0;
   int DFSOut = 0;
   int LocalNum = 0;
-  // Only one of these will be set.
-  Value *Val = nullptr;
+  // Only one of Def and U will be set.
+  // The bool in the Def tells us whether the Def is the stored value of a
+  // store.
+  PointerIntPair<Value *, 1, bool> Def;
   Use *U = nullptr;
-
   bool operator<(const ValueDFS &Other) const {
     // It's not enough that any given field be less than - we have sets
     // of fields that need to be evaluated together to give a proper ordering.
@@ -1875,89 +3242,163 @@ struct NewGVN::ValueDFS {
     // but .val and .u.
     // It does not matter what order we replace these operands in.
     // You will always end up with the same IR, and this is guaranteed.
-    return std::tie(DFSIn, DFSOut, LocalNum, Val, U) <
-           std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Val,
+    return std::tie(DFSIn, DFSOut, LocalNum, Def, U) <
+           std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Def,
                     Other.U);
   }
 };
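The operator< above relies on std::tie, which packs the fields into a tuple of references that is compared lexicographically, left to right. A self-contained example with illustrative field names:

#include <cassert>
#include <tuple>

struct DFSKey {
  int DFSIn, DFSOut, LocalNum;
  // Earlier fields dominate the comparison; later fields only break ties.
  bool operator<(const DFSKey &O) const {
    return std::tie(DFSIn, DFSOut, LocalNum) <
           std::tie(O.DFSIn, O.DFSOut, O.LocalNum);
  }
};

int main() {
  assert((DFSKey{1, 9, 0} < DFSKey{1, 9, 2}) && "ties break on later fields");
  assert(!(DFSKey{2, 3, 0} < DFSKey{1, 9, 9}) && "earlier fields dominate");
  return 0;
}

Sorting ValueDFS entries this way is what lets the eliminator below walk defs and uses in dominator-tree order.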
-void NewGVN::convertDenseToDFSOrdered(
-    CongruenceClass::MemberSet &Dense,
-    SmallVectorImpl<ValueDFS> &DFSOrderedSet) {
+// This function converts the set of members for a congruence class from
+// values to sets of defs and uses with associated DFS info. The total number
+// of reachable uses for each value is stored in UseCounts, and instructions
+// that seem dead (have no non-dead uses) are stored in ProbablyDead.
+void NewGVN::convertClassToDFSOrdered(
+    const CongruenceClass &Dense, SmallVectorImpl<ValueDFS> &DFSOrderedSet,
+    DenseMap<const Value *, unsigned int> &UseCounts,
+    SmallPtrSetImpl<Instruction *> &ProbablyDead) const {
   for (auto D : Dense) {
     // First add the value.
     BasicBlock *BB = getBlockForValue(D);
     // Constants are handled prior to ever calling this function, so
     // we should only be left with instructions as members.
     assert(BB && "Should have figured out a basic block for value");
-    ValueDFS VD;
-
-    std::pair<int, int> DFSPair = DFSDomMap[BB];
-    assert(DFSPair.first != -1 && DFSPair.second != -1 && "Invalid DFS Pair");
-    VD.DFSIn = DFSPair.first;
-    VD.DFSOut = DFSPair.second;
-    VD.Val = D;
-    // If it's an instruction, use the real local dfs number.
-    if (auto *I = dyn_cast<Instruction>(D))
-      VD.LocalNum = InstrDFS[I];
-    else
-      llvm_unreachable("Should have been an instruction");
-
-    DFSOrderedSet.emplace_back(VD);
+    ValueDFS VDDef;
+    DomTreeNode *DomNode = DT->getNode(BB);
+    VDDef.DFSIn = DomNode->getDFSNumIn();
+    VDDef.DFSOut = DomNode->getDFSNumOut();
+    // If it's a store, use the leader of the value operand if it's always
+    // available; otherwise use the value operand itself. TODO: We could do
+    // dominance checks to find a dominating leader, but not worth it ATM.
+    if (auto *SI = dyn_cast<StoreInst>(D)) {
+      auto Leader = lookupOperandLeader(SI->getValueOperand());
+      if (alwaysAvailable(Leader)) {
+        VDDef.Def.setPointer(Leader);
+      } else {
+        VDDef.Def.setPointer(SI->getValueOperand());
+        VDDef.Def.setInt(true);
+      }
+    } else {
+      VDDef.Def.setPointer(D);
+    }
+    assert(isa<Instruction>(D) &&
+           "The dense set member should always be an instruction");
+    Instruction *Def = cast<Instruction>(D);
+    VDDef.LocalNum = InstrToDFSNum(D);
+    DFSOrderedSet.push_back(VDDef);
+    // If there is a phi node equivalent, add it.
+    if (auto *PN = RealToTemp.lookup(Def)) {
+      auto *PHIE =
+          dyn_cast_or_null<PHIExpression>(ValueToExpression.lookup(Def));
+      if (PHIE) {
+        VDDef.Def.setInt(false);
+        VDDef.Def.setPointer(PN);
+        VDDef.LocalNum = 0;
+        DFSOrderedSet.push_back(VDDef);
+      }
+    }
 
-    // Now add the users.
-    for (auto &U : D->uses()) {
+    unsigned int UseCount = 0;
+    // Now add the uses.
+    for (auto &U : Def->uses()) {
       if (auto *I = dyn_cast<Instruction>(U.getUser())) {
-        ValueDFS VD;
+        // Don't try to replace into dead uses.
+        if (InstructionsToErase.count(I))
+          continue;
+        ValueDFS VDUse;
         // Put the phi node uses in the incoming block.
         BasicBlock *IBlock;
         if (auto *P = dyn_cast<PHINode>(I)) {
           IBlock = P->getIncomingBlock(U);
           // Make phi node users appear last in the incoming block
           // they are from.
-          VD.LocalNum = InstrDFS.size() + 1;
+          VDUse.LocalNum = InstrDFS.size() + 1;
         } else {
-          IBlock = I->getParent();
-          VD.LocalNum = InstrDFS[I];
+          IBlock = getBlockForValue(I);
+          VDUse.LocalNum = InstrToDFSNum(I);
         }
-        std::pair<int, int> DFSPair = DFSDomMap[IBlock];
-        VD.DFSIn = DFSPair.first;
-        VD.DFSOut = DFSPair.second;
-        VD.U = &U;
-        DFSOrderedSet.emplace_back(VD);
+
+        // Skip uses in unreachable blocks, as we're going
+        // to delete them.
+        if (ReachableBlocks.count(IBlock) == 0)
+          continue;
+
+        DomTreeNode *DomNode = DT->getNode(IBlock);
+        VDUse.DFSIn = DomNode->getDFSNumIn();
+        VDUse.DFSOut = DomNode->getDFSNumOut();
+        VDUse.U = &U;
+        ++UseCount;
+        DFSOrderedSet.emplace_back(VDUse);
       }
     }
+
+    // If there are no uses, it's probably dead (though it may have
+    // side-effects, so it is not definitely dead). Otherwise, store the
+    // number of uses so we can track if it becomes dead later.
+    if (UseCount == 0)
+      ProbablyDead.insert(Def);
+    else
+      UseCounts[Def] = UseCount;
   }
 }
 
-static void patchReplacementInstruction(Instruction *I, Value *Repl) {
-  // Patch the replacement so that it is not more restrictive than the value
-  // being replaced.
-  auto *Op = dyn_cast<BinaryOperator>(I);
-  auto *ReplOp = dyn_cast<BinaryOperator>(Repl);
+// This function converts the set of members for a congruence class from
+// values to the set of defs for loads and stores, with associated DFS info.
+void NewGVN::convertClassToLoadsAndStores( + const CongruenceClass &Dense, + SmallVectorImpl<ValueDFS> &LoadsAndStores) const { + for (auto D : Dense) { + if (!isa<LoadInst>(D) && !isa<StoreInst>(D)) + continue; - if (Op && ReplOp) - ReplOp->andIRFlags(Op); + BasicBlock *BB = getBlockForValue(D); + ValueDFS VD; + DomTreeNode *DomNode = DT->getNode(BB); + VD.DFSIn = DomNode->getDFSNumIn(); + VD.DFSOut = DomNode->getDFSNumOut(); + VD.Def.setPointer(D); - if (auto *ReplInst = dyn_cast<Instruction>(Repl)) { - // FIXME: If both the original and replacement value are part of the - // same control-flow region (meaning that the execution of one - // guarentees the executation of the other), then we can combine the - // noalias scopes here and do better than the general conservative - // answer used in combineMetadata(). + // If it's an instruction, use the real local dfs number. + if (auto *I = dyn_cast<Instruction>(D)) + VD.LocalNum = InstrToDFSNum(I); + else + llvm_unreachable("Should have been an instruction"); - // In general, GVN unifies expressions over different control-flow - // regions, and so we need a conservative combination of the noalias - // scopes. - unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_range, - LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, - LLVMContext::MD_invariant_group}; - combineMetadata(ReplInst, I, KnownIDs); + LoadsAndStores.emplace_back(VD); } } +static void patchReplacementInstruction(Instruction *I, Value *Repl) { + auto *ReplInst = dyn_cast<Instruction>(Repl); + if (!ReplInst) + return; + + // Patch the replacement so that it is not more restrictive than the value + // being replaced. + // Note that if 'I' is a load being replaced by some operation, + // for example, by an arithmetic operation, then andIRFlags() + // would just erase all math flags from the original arithmetic + // operation, which is clearly not wanted and not needed. + if (!isa<LoadInst>(I)) + ReplInst->andIRFlags(I); + + // FIXME: If both the original and replacement value are part of the + // same control-flow region (meaning that the execution of one + // guarantees the execution of the other), then we can combine the + // noalias scopes here and do better than the general conservative + // answer used in combineMetadata(). + + // In general, GVN unifies expressions over different control-flow + // regions, and so we need a conservative combination of the noalias + // scopes. + static const unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group}; + combineMetadata(ReplInst, I, KnownIDs); +} + static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { patchReplacementInstruction(I, Repl); I->replaceAllUsesWith(Repl); @@ -1967,10 +3408,6 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { DEBUG(dbgs() << " BasicBlock Dead:" << *BB); ++NumGVNBlocksDeleted; - // Check to see if there are non-terminating instructions to delete. - if (isa<TerminatorInst>(BB->begin())) - return; - // Delete the instructions backwards, as it has a reduced likelihood of having // to update as many def-use and use-def chains. Start after the terminator. 
auto StartPoint = BB->rbegin(); @@ -1987,6 +3424,11 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { Inst.eraseFromParent(); ++NumGVNInstrDeleted; } + // Now insert something that simplifycfg will turn into an unreachable. + Type *Int8Ty = Type::getInt8Ty(BB->getContext()); + new StoreInst(UndefValue::get(Int8Ty), + Constant::getNullValue(Int8Ty->getPointerTo()), + BB->getTerminator()); } void NewGVN::markInstructionForDeletion(Instruction *I) { @@ -2042,6 +3484,37 @@ private: }; } +// Given a value and a basic block we are trying to see if it is available in, +// see if the value has a leader available in that block. +Value *NewGVN::findPhiOfOpsLeader(const Expression *E, + const BasicBlock *BB) const { + // It would already be constant if we could make it constant + if (auto *CE = dyn_cast<ConstantExpression>(E)) + return CE->getConstantValue(); + if (auto *VE = dyn_cast<VariableExpression>(E)) + return VE->getVariableValue(); + + auto *CC = ExpressionToClass.lookup(E); + if (!CC) + return nullptr; + if (alwaysAvailable(CC->getLeader())) + return CC->getLeader(); + + for (auto Member : *CC) { + auto *MemberInst = dyn_cast<Instruction>(Member); + // Anything that isn't an instruction is always available. + if (!MemberInst) + return Member; + // If we are looking for something in the same block as the member, it must + // be a leader because this function is looking for operands for a phi node. + if (MemberInst->getParent() == BB || + DT->dominates(MemberInst->getParent(), BB)) { + return Member; + } + } + return nullptr; +} + bool NewGVN::eliminateInstructions(Function &F) { // This is a non-standard eliminator. The normal way to eliminate is // to walk the dominator tree in order, keeping track of available @@ -2072,73 +3545,91 @@ bool NewGVN::eliminateInstructions(Function &F) { // DFS numbers are updated, we compute some ourselves. DT->updateDFSNumbers(); - for (auto &B : F) { - if (!ReachableBlocks.count(&B)) { - for (const auto S : successors(&B)) { - for (auto II = S->begin(); isa<PHINode>(II); ++II) { - auto &Phi = cast<PHINode>(*II); - DEBUG(dbgs() << "Replacing incoming value of " << *II << " for block " - << getBlockName(&B) - << " with undef due to it being unreachable\n"); - for (auto &Operand : Phi.incoming_values()) - if (Phi.getIncomingBlock(Operand) == &B) - Operand.set(UndefValue::get(Phi.getType())); - } + // Go through all of our phi nodes, and kill the arguments associated with + // unreachable edges. + auto ReplaceUnreachablePHIArgs = [&](PHINode &PHI, BasicBlock *BB) { + for (auto &Operand : PHI.incoming_values()) + if (!ReachableEdges.count({PHI.getIncomingBlock(Operand), BB})) { + DEBUG(dbgs() << "Replacing incoming value of " << PHI << " for block " + << getBlockName(PHI.getIncomingBlock(Operand)) + << " with undef due to it being unreachable\n"); + Operand.set(UndefValue::get(PHI.getType())); } + }; + SmallPtrSet<BasicBlock *, 8> BlocksWithPhis; + for (auto &B : F) + if ((!B.empty() && isa<PHINode>(*B.begin())) || + (PHIOfOpsPHIs.find(&B) != PHIOfOpsPHIs.end())) + BlocksWithPhis.insert(&B); + DenseMap<const BasicBlock *, unsigned> ReachablePredCount; + for (auto KV : ReachableEdges) + ReachablePredCount[KV.getEnd()]++; + for (auto *BB : BlocksWithPhis) + // TODO: It would be faster to use getNumIncomingBlocks() on a phi node in + // the block and subtract the pred count, but it's more complicated. 
+ if (ReachablePredCount.lookup(BB) != + unsigned(std::distance(pred_begin(BB), pred_end(BB)))) { + for (auto II = BB->begin(); isa<PHINode>(II); ++II) { + auto &PHI = cast<PHINode>(*II); + ReplaceUnreachablePHIArgs(PHI, BB); + } + for_each_found(PHIOfOpsPHIs, BB, [&](PHINode *PHI) { + ReplaceUnreachablePHIArgs(*PHI, BB); + }); } - DomTreeNode *Node = DT->getNode(&B); - if (Node) - DFSDomMap[&B] = {Node->getDFSNumIn(), Node->getDFSNumOut()}; - } - for (CongruenceClass *CC : CongruenceClasses) { - // FIXME: We should eventually be able to replace everything still - // in the initial class with undef, as they should be unreachable. - // Right now, initial still contains some things we skip value - // numbering of (UNREACHABLE's, for example). - if (CC == InitialClass || CC->Dead) + // Map to store the use counts + DenseMap<const Value *, unsigned int> UseCounts; + for (auto *CC : reverse(CongruenceClasses)) { + DEBUG(dbgs() << "Eliminating in congruence class " << CC->getID() << "\n"); + // Track the equivalent store info so we can decide whether to try + // dead store elimination. + SmallVector<ValueDFS, 8> PossibleDeadStores; + SmallPtrSet<Instruction *, 8> ProbablyDead; + if (CC->isDead() || CC->empty()) continue; - assert(CC->RepLeader && "We should have had a leader"); + // Everything still in the TOP class is unreachable or dead. + if (CC == TOPClass) { + for (auto M : *CC) { + auto *VTE = ValueToExpression.lookup(M); + if (VTE && isa<DeadExpression>(VTE)) + markInstructionForDeletion(cast<Instruction>(M)); + assert((!ReachableBlocks.count(cast<Instruction>(M)->getParent()) || + InstructionsToErase.count(cast<Instruction>(M))) && + "Everything in TOP should be unreachable or dead at this " + "point"); + } + continue; + } + assert(CC->getLeader() && "We should have had a leader"); // If this is a leader that is always available, and it's a // constant or has no equivalences, just replace everything with // it. We then update the congruence class with whatever members // are left. - if (alwaysAvailable(CC->RepLeader)) { - SmallPtrSet<Value *, 4> MembersLeft; - for (auto M : CC->Members) { - + Value *Leader = + CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader(); + if (alwaysAvailable(Leader)) { + CongruenceClass::MemberSet MembersLeft; + for (auto M : *CC) { Value *Member = M; - // Void things have no uses we can replace. - if (Member == CC->RepLeader || Member->getType()->isVoidTy()) { + if (Member == Leader || !isa<Instruction>(Member) || + Member->getType()->isVoidTy()) { MembersLeft.insert(Member); continue; } - - DEBUG(dbgs() << "Found replacement " << *(CC->RepLeader) << " for " - << *Member << "\n"); - // Due to equality propagation, these may not always be - // instructions, they may be real values. We don't really - // care about trying to replace the non-instructions. - if (auto *I = dyn_cast<Instruction>(Member)) { - assert(CC->RepLeader != I && - "About to accidentally remove our leader"); - replaceInstruction(I, CC->RepLeader); - AnythingReplaced = true; - - continue; - } else { - MembersLeft.insert(I); - } + DEBUG(dbgs() << "Found replacement " << *(Leader) << " for " << *Member + << "\n"); + auto *I = cast<Instruction>(Member); + assert(Leader != I && "About to accidentally remove our leader"); + replaceInstruction(I, Leader); + AnythingReplaced = true; } - CC->Members.swap(MembersLeft); - + CC->swap(MembersLeft); } else { - DEBUG(dbgs() << "Eliminating in congruence class " << CC->ID << "\n"); // If this is a singleton, we can skip it. 
-      if (CC->Members.size() != 1) {
-
+      if (CC->size() != 1 || RealToTemp.lookup(Leader)) {
        // This is a stack because equality replacement/etc may place
        // constants in the middle of the member list, and we want to use
        // those constant values in preference to the current leader, over
@@ -2147,23 +3638,34 @@ bool NewGVN::eliminateInstructions(Function &F) {
 
        // Convert the members to DFS ordered sets and then merge them.
        SmallVector<ValueDFS, 8> DFSOrderedSet;
-        convertDenseToDFSOrdered(CC->Members, DFSOrderedSet);
+        convertClassToDFSOrdered(*CC, DFSOrderedSet, UseCounts, ProbablyDead);
 
        // Sort the whole thing.
        std::sort(DFSOrderedSet.begin(), DFSOrderedSet.end());
-
        for (auto &VD : DFSOrderedSet) {
          int MemberDFSIn = VD.DFSIn;
          int MemberDFSOut = VD.DFSOut;
-          Value *Member = VD.Val;
-          Use *MemberUse = VD.U;
-
-          if (Member) {
-            // We ignore void things because we can't get a value from them.
-            // FIXME: We could actually use this to kill dead stores that are
-            // dominated by equivalent earlier stores.
-            if (Member->getType()->isVoidTy())
-              continue;
+          Value *Def = VD.Def.getPointer();
+          bool FromStore = VD.Def.getInt();
+          Use *U = VD.U;
+          // We ignore void things because we can't get a value from them.
+          if (Def && Def->getType()->isVoidTy())
+            continue;
+          auto *DefInst = dyn_cast_or_null<Instruction>(Def);
+          if (DefInst && AllTempInstructions.count(DefInst)) {
+            auto *PN = cast<PHINode>(DefInst);
+
+            // If this is a value phi and that's the expression we used,
+            // insert it into the program and remove it from the temp
+            // instruction list.
+            AllTempInstructions.erase(PN);
+            auto *DefBlock = getBlockForValue(Def);
+            DEBUG(dbgs() << "Inserting fully real phi of ops" << *Def
+                         << " into block "
+                         << getBlockName(getBlockForValue(Def)) << "\n");
+            PN->insertBefore(&DefBlock->front());
+            Def = PN;
+            NumGVNPHIOfOpsEliminations++;
          }
 
          if (EliminationStack.empty()) {
@@ -2189,69 +3691,251 @@ bool NewGVN::eliminateInstructions(Function &F) {
          // start using, we also push.
          // Otherwise, we walk along, processing members who are
          // dominated by this scope, and eliminate them.
-          bool ShouldPush =
-              Member && (EliminationStack.empty() || isa<Constant>(Member));
          bool OutOfScope =
              !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut);
          if (OutOfScope || ShouldPush) {
            // Sync to our current scope.
            EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut);
-            ShouldPush |= Member && EliminationStack.empty();
+            bool ShouldPush = Def && EliminationStack.empty();
            if (ShouldPush) {
-              EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut);
+              EliminationStack.push_back(Def, MemberDFSIn, MemberDFSOut);
+            }
+          }
+
+          // Skip the Defs; we only want to eliminate on their uses. But mark
+          // dominated defs as dead.
+          if (Def) {
+            // For anything in this case, what and how we value number
+            // guarantees that any side-effects that would have occurred (i.e.
+            // throwing, etc.) can be proven to either still occur (because
+            // it's dominated by something that has the same side-effects), or
+            // never occur. Otherwise, we would not have been able to prove it
+            // value equivalent to something else. For these things, we can
+            // just mark it all dead. Note that this is different from the
+            // "ProbablyDead" set, whose members may not be dominated by
+            // anything, and thus are only easy to prove dead if they are also
+            // side-effect free. Note that because stores are put in terms of
+            // the stored value, we skip stored values here.
+            // If the stored value is really dead, it will still be marked
+            // for deletion when we process it in its own class.
+            if (!EliminationStack.empty() && Def != EliminationStack.back() &&
+                isa<Instruction>(Def) && !FromStore)
+              markInstructionForDeletion(cast<Instruction>(Def));
+            continue;
+          }
+          // At this point, we know it is a Use we are trying to possibly
+          // replace.
+
+          assert(isa<Instruction>(U->get()) &&
+                 "Current def should have been an instruction");
+          assert(isa<Instruction>(U->getUser()) &&
+                 "Current user should have been an instruction");
+
+          // If the thing we are replacing into is already marked to be dead,
+          // this use is dead. Note that this is true regardless of whether
+          // we have anything dominating the use or not. We do this here
+          // because we are already walking all the uses anyway.
+          Instruction *InstUse = cast<Instruction>(U->getUser());
+          if (InstructionsToErase.count(InstUse)) {
+            auto &UseCount = UseCounts[U->get()];
+            if (--UseCount == 0) {
+              ProbablyDead.insert(cast<Instruction>(U->get()));
            }
          }
 
          // If we get to this point, and the stack is empty we must have a use
-          // with nothing we can use to eliminate it, just skip it.
+          // with nothing we can use to eliminate this use, so just skip it.
          if (EliminationStack.empty())
            continue;
 
-          // Skip the Value's, we only want to eliminate on their uses.
-          if (Member)
-            continue;
-          Value *Result = EliminationStack.back();
+          Value *DominatingLeader = EliminationStack.back();
+
+          auto *II = dyn_cast<IntrinsicInst>(DominatingLeader);
+          if (II && II->getIntrinsicID() == Intrinsic::ssa_copy)
+            DominatingLeader = II->getOperand(0);
 
          // Don't replace our existing users with ourselves.
-          if (MemberUse->get() == Result)
+          if (U->get() == DominatingLeader)
            continue;
-
-          DEBUG(dbgs() << "Found replacement " << *Result << " for "
-                       << *MemberUse->get() << " in " << *(MemberUse->getUser())
-                       << "\n");
+          DEBUG(dbgs() << "Found replacement " << *DominatingLeader << " for "
+                       << *U->get() << " in " << *(U->getUser()) << "\n");
 
          // If we replaced something in an instruction, handle the patching of
-          // metadata.
-          if (auto *ReplacedInst = dyn_cast<Instruction>(MemberUse->get()))
-            patchReplacementInstruction(ReplacedInst, Result);
-
-          assert(isa<Instruction>(MemberUse->getUser()));
-          MemberUse->set(Result);
+          // metadata. Skip this if we are replacing predicateinfo with its
+          // original operand, as we already know we can just drop it.
+          auto *ReplacedInst = cast<Instruction>(U->get());
+          auto *PI = PredInfo->getPredicateInfoFor(ReplacedInst);
+          if (!PI || DominatingLeader != PI->OriginalOp)
+            patchReplacementInstruction(ReplacedInst, DominatingLeader);
+          U->set(DominatingLeader);
+          // This is now a use of the dominating leader, which means if the
+          // dominating leader was dead, it's now live!
+          auto &LeaderUseCount = UseCounts[DominatingLeader];
+          // It's about to be alive again.
+          if (LeaderUseCount == 0 && isa<Instruction>(DominatingLeader))
+            ProbablyDead.erase(cast<Instruction>(DominatingLeader));
+          if (LeaderUseCount == 0 && II)
+            ProbablyDead.insert(II);
+          ++LeaderUseCount;
          AnythingReplaced = true;
        }
      }
    }
 
+    // At this point, anything still in the ProbablyDead set is actually dead
+    // if it would be trivially dead.
+    for (auto *I : ProbablyDead)
+      if (wouldInstructionBeTriviallyDead(I))
+        markInstructionForDeletion(I);
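Both the use-elimination walk above and the dead-store sweep below lean on dominator-tree DFS intervals: a leader numbered [In, Out] dominates exactly the members whose intervals nest inside its own. A compact standalone model of such a scope stack in plain C++ (hypothetical DFSScopeStack, not the ValueDFSStack used in this file):

#include <utility>
#include <vector>

// Leaders stay on the stack only while the current member's interval nests
// inside theirs, i.e. while the leader still dominates the member.
struct DFSScopeStack {
  std::vector<std::pair<int, int>> Scopes; // (DFSIn, DFSOut) per leader.

  bool isInScope(int In, int Out) const {
    return !Scopes.empty() && Scopes.back().first <= In &&
           Out <= Scopes.back().second;
  }
  void popUntilScope(int In, int Out) {
    while (!Scopes.empty() && !isInScope(In, Out))
      Scopes.pop_back();
  }
};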
+
+    // Cleanup the congruence class.
-    SmallPtrSet<Value *, 4> MembersLeft;
-    for (Value *Member : CC->Members) {
-      if (Member->getType()->isVoidTy()) {
+    CongruenceClass::MemberSet MembersLeft;
+    for (auto *Member : *CC)
+      if (!isa<Instruction>(Member) ||
+          !InstructionsToErase.count(cast<Instruction>(Member)))
        MembersLeft.insert(Member);
-        continue;
-      }
-
-      if (auto *MemberInst = dyn_cast<Instruction>(Member)) {
-        if (isInstructionTriviallyDead(MemberInst)) {
-          // TODO: Don't mark loads of undefs.
-          markInstructionForDeletion(MemberInst);
-          continue;
+    CC->swap(MembersLeft);
+
+    // If we have possible dead stores to look at, try to eliminate them.
+    if (CC->getStoreCount() > 0) {
+      convertClassToLoadsAndStores(*CC, PossibleDeadStores);
+      std::sort(PossibleDeadStores.begin(), PossibleDeadStores.end());
+      ValueDFSStack EliminationStack;
+      for (auto &VD : PossibleDeadStores) {
+        int MemberDFSIn = VD.DFSIn;
+        int MemberDFSOut = VD.DFSOut;
+        Instruction *Member = cast<Instruction>(VD.Def.getPointer());
+        if (EliminationStack.empty() ||
+            !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut)) {
+          // Sync to our current scope.
+          EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut);
+          if (EliminationStack.empty()) {
+            EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut);
+            continue;
+          }
        }
+        // We already did load elimination, so nothing to do here.
+        if (isa<LoadInst>(Member))
+          continue;
+        assert(!EliminationStack.empty());
+        Instruction *Leader = cast<Instruction>(EliminationStack.back());
+        (void)Leader;
+        assert(DT->dominates(Leader->getParent(), Member->getParent()));
+        // Member is dominated by Leader, and thus dead.
+        DEBUG(dbgs() << "Marking dead store " << *Member
+                     << " that is dominated by " << *Leader << "\n");
+        markInstructionForDeletion(Member);
+        CC->erase(Member);
+        ++NumGVNDeadStores;
      }
-      MembersLeft.insert(Member);
    }
-    CC->Members.swap(MembersLeft);
  }
-
  return AnythingReplaced;
}
+
+// This function provides global ranking of operations so that we can place
+// them in a canonical order. Note that rank alone is not necessarily enough
+// for a complete ordering, as constants all have the same rank. However,
+// generally, we will simplify an operation with all constants so that it
+// doesn't matter what order they appear in.
+unsigned int NewGVN::getRank(const Value *V) const {
+  // Prefer constants to undef to anything else.
+  // Undef is a constant, have to check it first.
+  // Prefer smaller constants to ConstantExprs.
+  if (isa<ConstantExpr>(V))
+    return 2;
+  if (isa<UndefValue>(V))
+    return 1;
+  if (isa<Constant>(V))
+    return 0;
+  else if (auto *A = dyn_cast<Argument>(V))
+    return 3 + A->getArgNo();
+
+  // Need to shift the instruction DFS by number of arguments + 3 to account
+  // for the constant and argument ranking above.
+  unsigned Result = InstrToDFSNum(V);
+  if (Result > 0)
+    return 4 + NumFuncArgs + Result;
+  // Unreachable or something else, just return a really large number.
+  return ~0;
+}
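To make the ranking concrete, a hedged worked example with assumed values: in a function with two arguments, the scheme above gives 0 to a constant, 1 to undef, 2 to a constant expression, 3 and 4 to the arguments, and 4 + NumFuncArgs + DFS number to instructions. Comparing (rank, address) pairs then yields the strict weak ordering that shouldSwapOperands below relies on:

#include <cassert>
#include <cstdint>
#include <utility>

int main() {
  // Hypothetical ranks following the scheme above (NumFuncArgs == 2),
  // with small integers standing in for pointer addresses.
  std::pair<unsigned, uintptr_t> Constant{0, 0x10}, Undef{1, 0x20},
      Arg0{3, 0x30}, InstrDFS7{4 + 2 + 7, 0x40};

  assert(Constant < Undef);    // Lower rank sorts first.
  assert(Arg0 < InstrDFS7);    // Arguments rank below all instructions.
  assert(!(InstrDFS7 < Arg0)); // The ordering is strict.
  return 0;
}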
+
+// This is a function that says whether two commutative operations should
+// have their order swapped when canonicalizing.
+bool NewGVN::shouldSwapOperands(const Value *A, const Value *B) const {
+  // Because we only care about a total ordering, and don't rewrite expressions
+  // in this order, we order by rank, which will give a strict weak ordering to
+  // everything but constants, and then we order by pointer address.
+  return std::make_pair(getRank(A), A) > std::make_pair(getRank(B), B);
+}
+
+namespace {
+class NewGVNLegacyPass : public FunctionPass {
+public:
+  static char ID; // Pass identification, replacement for typeid.
+  NewGVNLegacyPass() : FunctionPass(ID) {
+    initializeNewGVNLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+  bool runOnFunction(Function &F) override;
+
+private:
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<MemorySSAWrapperPass>();
+    AU.addRequired<AAResultsWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addPreserved<GlobalsAAWrapperPass>();
+  }
+};
+} // namespace
+
+bool NewGVNLegacyPass::runOnFunction(Function &F) {
+  if (skipFunction(F))
+    return false;
+  return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+                &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
+                &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
+                &getAnalysis<AAResultsWrapperPass>().getAAResults(),
+                &getAnalysis<MemorySSAWrapperPass>().getMSSA(),
+                F.getParent()->getDataLayout())
+      .runGVN();
+}
+
+INITIALIZE_PASS_BEGIN(NewGVNLegacyPass, "newgvn", "Global Value Numbering",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
+INITIALIZE_PASS_END(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false,
+                    false)
+
+char NewGVNLegacyPass::ID = 0;
+
+// createNewGVNPass - The public interface to this file.
+FunctionPass *llvm::createNewGVNPass() { return new NewGVNLegacyPass(); }
+
+PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) {
+  // Apparently the order in which we get these results matters for
+  // the old GVN (see Chandler's comment in GVN.cpp). I'll keep
+  // the same order here, just in case.
+  auto &AC = AM.getResult<AssumptionAnalysis>(F);
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &AA = AM.getResult<AAManager>(F);
+  auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+  bool Changed =
+      NewGVN(F, &DT, &AC, &TLI, &AA, &MSSA, F.getParent()->getDataLayout())
+          .runGVN();
+  if (!Changed)
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserve<DominatorTreeAnalysis>();
+  PA.preserve<GlobalsAA>();
+  return PA;
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
index 1a7ddc9..1bfecea 100644
--- a/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
+++ b/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
@@ -66,7 +66,7 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc,
   // Add attribute "readnone" so that backend can use a native sqrt instruction
   // for this call. Insert a FP compare instruction and a conditional branch
   // at the end of CurrBB.
- Call->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); + Call->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); CurrBB.getTerminator()->eraseFromParent(); Builder.SetInsertPoint(&CurrBB); Value *FCmp = Builder.CreateFCmpOEQ(Call, Call); @@ -98,14 +98,14 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, // Skip if function either has local linkage or is not a known library // function. - LibFunc::Func LibFunc; + LibFunc LF; if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() || - !TLI->getLibFunc(CalledFunc->getName(), LibFunc)) + !TLI->getLibFunc(CalledFunc->getName(), LF)) continue; - switch (LibFunc) { - case LibFunc::sqrtf: - case LibFunc::sqrt: + switch (LF) { + case LibFunc_sqrtf: + case LibFunc_sqrt: if (TTI->haveFastSqrt(Call->getType()) && optimizeSQRT(Call, CalledFunc, *CurrBB, BB)) break; diff --git a/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp b/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp index 65c814d..e235e5eb 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" @@ -106,11 +107,12 @@ XorOpnd::XorOpnd(Value *V) { I->getOpcode() == Instruction::And)) { Value *V0 = I->getOperand(0); Value *V1 = I->getOperand(1); - if (isa<ConstantInt>(V0)) + const APInt *C; + if (match(V0, PatternMatch::m_APInt(C))) std::swap(V0, V1); - if (ConstantInt *C = dyn_cast<ConstantInt>(V1)) { - ConstPart = C->getValue(); + if (match(V1, PatternMatch::m_APInt(C))) { + ConstPart = *C; SymbolicPart = V0; isOr = (I->getOpcode() == Instruction::Or); return; @@ -119,7 +121,7 @@ XorOpnd::XorOpnd(Value *V) { // view the operand as "V | 0" SymbolicPart = V; - ConstPart = APInt::getNullValue(V->getType()->getIntegerBitWidth()); + ConstPart = APInt::getNullValue(V->getType()->getScalarSizeInBits()); isOr = true; } @@ -955,8 +957,8 @@ static BinaryOperator *ConvertShiftToMul(Instruction *Shl) { /// Scan backwards and forwards among values with the same rank as element i /// to see if X exists. If X does not exist, return i. This is useful when /// scanning for 'x' when we see '-x' because they both get the same rank. -static unsigned FindInOperandList(SmallVectorImpl<ValueEntry> &Ops, unsigned i, - Value *X) { +static unsigned FindInOperandList(const SmallVectorImpl<ValueEntry> &Ops, + unsigned i, Value *X) { unsigned XRank = Ops[i].Rank; unsigned e = Ops.size(); for (unsigned j = i+1; j != e && Ops[j].Rank == XRank; ++j) { @@ -982,7 +984,7 @@ static unsigned FindInOperandList(SmallVectorImpl<ValueEntry> &Ops, unsigned i, /// Emit a tree of add instructions, summing Ops together /// and returning the result. Insert the tree before I. static Value *EmitAddTreeOfValues(Instruction *I, - SmallVectorImpl<WeakVH> &Ops){ + SmallVectorImpl<WeakTrackingVH> &Ops) { if (Ops.size() == 1) return Ops.back(); Value *V1 = Ops.back(); @@ -1069,8 +1071,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { /// /// Ops is the top-level list of add operands we're trying to factor. 
static void FindSingleUseMultiplyFactors(Value *V, - SmallVectorImpl<Value*> &Factors, - const SmallVectorImpl<ValueEntry> &Ops) { + SmallVectorImpl<Value*> &Factors) { BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); if (!BO) { Factors.push_back(V); @@ -1078,8 +1079,8 @@ static void FindSingleUseMultiplyFactors(Value *V, } // Otherwise, add the LHS and RHS to the list of factors. - FindSingleUseMultiplyFactors(BO->getOperand(1), Factors, Ops); - FindSingleUseMultiplyFactors(BO->getOperand(0), Factors, Ops); + FindSingleUseMultiplyFactors(BO->getOperand(1), Factors); + FindSingleUseMultiplyFactors(BO->getOperand(0), Factors); } /// Optimize a series of operands to an 'and', 'or', or 'xor' instruction. @@ -1135,20 +1136,19 @@ static Value *OptimizeAndOrXor(unsigned Opcode, /// instruction. There are two special cases: 1) if the constant operand is 0, /// it will return NULL. 2) if the constant is ~0, the symbolic operand will /// be returned. -static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, +static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, const APInt &ConstOpnd) { - if (ConstOpnd != 0) { - if (!ConstOpnd.isAllOnesValue()) { - LLVMContext &Ctx = Opnd->getType()->getContext(); - Instruction *I; - I = BinaryOperator::CreateAnd(Opnd, ConstantInt::get(Ctx, ConstOpnd), - "and.ra", InsertBefore); - I->setDebugLoc(InsertBefore->getDebugLoc()); - return I; - } + if (ConstOpnd.isNullValue()) + return nullptr; + + if (ConstOpnd.isAllOnesValue()) return Opnd; - } - return nullptr; + + Instruction *I = BinaryOperator::CreateAnd( + Opnd, ConstantInt::get(Opnd->getType(), ConstOpnd), "and.ra", + InsertBefore); + I->setDebugLoc(InsertBefore->getDebugLoc()); + return I; } // Helper function of OptimizeXor(). It tries to simplify "Opnd1 ^ ConstOpnd" @@ -1164,24 +1164,24 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, // = ((x | c1) ^ c1) ^ (c1 ^ c2) // = (x & ~c1) ^ (c1 ^ c2) // It is useful only when c1 == c2. - if (Opnd1->isOrExpr() && Opnd1->getConstPart() != 0) { - if (!Opnd1->getValue()->hasOneUse()) - return false; + if (!Opnd1->isOrExpr() || Opnd1->getConstPart().isNullValue()) + return false; - const APInt &C1 = Opnd1->getConstPart(); - if (C1 != ConstOpnd) - return false; + if (!Opnd1->getValue()->hasOneUse()) + return false; - Value *X = Opnd1->getSymbolicPart(); - Res = createAndInstr(I, X, ~C1); - // ConstOpnd was C2, now C1 ^ C2. - ConstOpnd ^= C1; + const APInt &C1 = Opnd1->getConstPart(); + if (C1 != ConstOpnd) + return false; - if (Instruction *T = dyn_cast<Instruction>(Opnd1->getValue())) - RedoInsts.insert(T); - return true; - } - return false; + Value *X = Opnd1->getSymbolicPart(); + Res = createAndInstr(I, X, ~C1); + // ConstOpnd was C2, now C1 ^ C2. + ConstOpnd ^= C1; + + if (Instruction *T = dyn_cast<Instruction>(Opnd1->getValue())) + RedoInsts.insert(T); + return true; } @@ -1222,8 +1222,8 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt C3((~C1) ^ C2); // Do not increase code size! - if (C3 != 0 && !C3.isAllOnesValue()) { - int NewInstNum = ConstOpnd != 0 ? 1 : 2; + if (!C3.isNullValue() && !C3.isAllOnesValue()) { + int NewInstNum = ConstOpnd.getBoolValue() ? 1 : 2; if (NewInstNum > DeadInstNum) return false; } @@ -1239,8 +1239,8 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt C3 = C1 ^ C2; // Do not increase code size - if (C3 != 0 && !C3.isAllOnesValue()) { - int NewInstNum = ConstOpnd != 0 ? 
1 : 2; + if (!C3.isNullValue() && !C3.isAllOnesValue()) { + int NewInstNum = ConstOpnd.getBoolValue() ? 1 : 2; if (NewInstNum > DeadInstNum) return false; } @@ -1280,17 +1280,20 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, SmallVector<XorOpnd, 8> Opnds; SmallVector<XorOpnd*, 8> OpndPtrs; Type *Ty = Ops[0].Op->getType(); - APInt ConstOpnd(Ty->getIntegerBitWidth(), 0); + APInt ConstOpnd(Ty->getScalarSizeInBits(), 0); // Step 1: Convert ValueEntry to XorOpnd for (unsigned i = 0, e = Ops.size(); i != e; ++i) { Value *V = Ops[i].Op; - if (!isa<ConstantInt>(V)) { + const APInt *C; + // TODO: Support non-splat vectors. + if (match(V, PatternMatch::m_APInt(C))) { + ConstOpnd ^= *C; + } else { XorOpnd O(V); O.setSymbolicRank(getRank(O.getSymbolicPart())); Opnds.push_back(O); - } else - ConstOpnd ^= cast<ConstantInt>(V)->getValue(); + } } // NOTE: From this point on, do *NOT* add/delete element to/from "Opnds". @@ -1328,7 +1331,8 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, Value *CV; // Step 3.1: Try simplifying "CurrOpnd ^ ConstOpnd" - if (ConstOpnd != 0 && CombineXorOpnd(I, CurrOpnd, ConstOpnd, CV)) { + if (!ConstOpnd.isNullValue() && + CombineXorOpnd(I, CurrOpnd, ConstOpnd, CV)) { Changed = true; if (CV) *CurrOpnd = XorOpnd(CV); @@ -1370,17 +1374,17 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, ValueEntry VE(getRank(O.getValue()), O.getValue()); Ops.push_back(VE); } - if (ConstOpnd != 0) { - Value *C = ConstantInt::get(Ty->getContext(), ConstOpnd); + if (!ConstOpnd.isNullValue()) { + Value *C = ConstantInt::get(Ty, ConstOpnd); ValueEntry VE(getRank(C), C); Ops.push_back(VE); } - int Sz = Ops.size(); + unsigned Sz = Ops.size(); if (Sz == 1) return Ops.back().Op; - else if (Sz == 0) { - assert(ConstOpnd == 0); - return ConstantInt::get(Ty->getContext(), ConstOpnd); + if (Sz == 0) { + assert(ConstOpnd.isNullValue()); + return ConstantInt::get(Ty, ConstOpnd); } } @@ -1499,7 +1503,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, // Compute all of the factors of this added value. SmallVector<Value*, 8> Factors; - FindSingleUseMultiplyFactors(BOp, Factors, Ops); + FindSingleUseMultiplyFactors(BOp, Factors); assert(Factors.size() > 1 && "Bad linearize!"); // Add one to FactorOccurrences for each unique factor in this op. @@ -1560,7 +1564,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, ? BinaryOperator::CreateAdd(MaxOccVal, MaxOccVal) : BinaryOperator::CreateFAdd(MaxOccVal, MaxOccVal); - SmallVector<WeakVH, 4> NewMulOps; + SmallVector<WeakTrackingVH, 4> NewMulOps; for (unsigned i = 0; i != Ops.size(); ++i) { // Only try to remove factors from expressions we're allowed to. BinaryOperator *BOp = @@ -1583,7 +1587,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, } // No need for extra uses anymore. - delete DummyInst; + DummyInst->deleteValue(); unsigned NumAddedValues = NewMulOps.size(); Value *V = EmitAddTreeOfValues(I, NewMulOps); @@ -1628,8 +1632,8 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, /// ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)] /// /// \returns Whether any factors have a power greater than one. -bool ReassociatePass::collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, - SmallVectorImpl<Factor> &Factors) { +static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, + SmallVectorImpl<Factor> &Factors) { // FIXME: Have Ops be (ValueEntry, Multiplicity) pairs, simplifying this. // Compute the sum of powers of simplifiable factors. 
unsigned FactorPowerSum = 0; @@ -1890,6 +1894,8 @@ void ReassociatePass::EraseInst(Instruction *I) { Op = Op->user_back(); RedoInsts.insert(Op); } + + MadeChange = true; } // Canonicalize expressions of the following form: @@ -1923,7 +1929,7 @@ Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { // User must be a binary operator with one or more uses. Instruction *User = I->user_back(); - if (!isa<BinaryOperator>(User) || !User->hasNUsesOrMore(1)) + if (!isa<BinaryOperator>(User) || User->use_empty()) return nullptr; unsigned UserOpcode = User->getOpcode(); @@ -1935,6 +1941,12 @@ Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { if (!User->isCommutative() && User->getOperand(1) != I) return nullptr; + // Don't canonicalize x + (-Constant * y) -> x - (Constant * y), if the + // resulting subtract will be broken up later. This can get us into an + // infinite loop during reassociation. + if (UserOpcode == Instruction::FAdd && ShouldBreakUpSubtract(User)) + return nullptr; + // Change the sign of the constant. APFloat Val = CF->getValueAPF(); Val.changeSign(); @@ -2000,11 +2012,6 @@ void ReassociatePass::OptimizeInst(Instruction *I) { if (I->isCommutative()) canonicalizeOperands(I); - // TODO: We should optimize vector Xor instructions, but they are - // currently unsupported. - if (I->getType()->isVectorTy() && I->getOpcode() == Instruction::Xor) - return; - // Don't optimize floating point instructions that don't have unsafe algebra. if (I->getType()->isFPOrFPVectorTy() && !I->hasUnsafeAlgebra()) return; @@ -2147,7 +2154,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { if (I->getOpcode() == Instruction::Mul && cast<Instruction>(I->user_back())->getOpcode() == Instruction::Add && isa<ConstantInt>(Ops.back().Op) && - cast<ConstantInt>(Ops.back().Op)->isAllOnesValue()) { + cast<ConstantInt>(Ops.back().Op)->isMinusOne()) { ValueEntry Tmp = Ops.pop_back_val(); Ops.insert(Ops.begin(), Tmp); } else if (I->getOpcode() == Instruction::FMul && @@ -2236,8 +2243,8 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { ValueRankMap.clear(); if (MadeChange) { - // FIXME: This should also 'preserve the CFG'. 
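The xor folding in the Reassociate hunks above exploits bitwise identities such as (x | c) ^ c == x & ~c. A quick standalone check of that identity over all 8-bit values (plain C++, no LLVM types):

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t C = 0x3c; // An arbitrary constant mask.
  for (unsigned X = 0; X < 256; ++X) {
    uint8_t L = (uint8_t)((X | C) ^ C);
    uint8_t R = (uint8_t)(X & ~C);
    assert(L == R && "(x | c) ^ c should equal x & ~c");
  }
  return 0;
}

CombineXorOpnd applies this (and an analogous and-variant) only when it does not increase instruction count, which is what the NewInstNum/DeadInstNum bookkeeping above guards.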
- auto PA = PreservedAnalyses(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/contrib/llvm/lib/Transforms/Scalar/Reg2Mem.cpp index 615029d..9629568 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Reg2Mem.cpp @@ -16,7 +16,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -25,6 +24,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <list> using namespace llvm; diff --git a/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 1de7420..f19d453 100644 --- a/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -7,20 +7,19 @@ // //===----------------------------------------------------------------------===// // -// Rewrite an existing set of gc.statepoints such that they make potential -// relocations performed by the garbage collector explicit in the IR. +// Rewrite call/invoke instructions so as to make potential relocations +// performed by the garbage collector explicit in the IR. // //===----------------------------------------------------------------------===// -#include "llvm/Pass.h" -#include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/ADT/SetOperations.h" -#include "llvm/ADT/Statistic.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/MapVector.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" @@ -28,15 +27,16 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Statepoint.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" -#include "llvm/Support/Debug.h" +#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -89,10 +89,10 @@ struct RewriteStatepointsForGC : public ModulePass { Changed |= runOnFunction(F); if (Changed) { - // stripNonValidAttributes asserts that shouldRewriteStatepointsIn + // stripNonValidAttributesAndMetadata asserts that shouldRewriteStatepointsIn // returns true for at least one function in the module. Since at least // one function changed, we know that the precondition is satisfied. 
- stripNonValidAttributes(M); + stripNonValidAttributesAndMetadata(M); } return Changed; @@ -105,20 +105,24 @@ struct RewriteStatepointsForGC : public ModulePass { AU.addRequired<TargetTransformInfoWrapperPass>(); } - /// The IR fed into RewriteStatepointsForGC may have had attributes implying - /// dereferenceability that are no longer valid/correct after - /// RewriteStatepointsForGC has run. This is because semantically, after + /// The IR fed into RewriteStatepointsForGC may have had attributes and + /// metadata implying dereferenceability that are no longer valid/correct after + /// RewriteStatepointsForGC has run. This is because semantically, after /// RewriteStatepointsForGC runs, all calls to gc.statepoint "free" the entire - /// heap. stripNonValidAttributes (conservatively) restores correctness - /// by erasing all attributes in the module that externally imply - /// dereferenceability. - /// Similar reasoning also applies to the noalias attributes. gc.statepoint - /// can touch the entire heap including noalias objects. - void stripNonValidAttributes(Module &M); - - // Helpers for stripNonValidAttributes - void stripNonValidAttributesFromBody(Function &F); + /// heap. stripNonValidAttributesAndMetadata (conservatively) restores + /// correctness by erasing all attributes in the module that externally imply + /// dereferenceability. Similar reasoning also applies to the noalias + /// attributes and metadata. gc.statepoint can touch the entire heap including + /// noalias objects. + void stripNonValidAttributesAndMetadata(Module &M); + + // Helpers for stripNonValidAttributesAndMetadata + void stripNonValidAttributesAndMetadataFromBody(Function &F); void stripNonValidAttributesFromPrototype(Function &F); + // Certain metadata on instructions are invalid after running RS4GC. + // Optimizations that run after RS4GC can incorrectly use this metadata to + // optimize functions. We drop such metadata on the instruction. + void stripInvalidMetadataFromInstruction(Instruction &I); }; } // namespace @@ -365,6 +369,11 @@ findBaseDefiningValueOfVector(Value *I) { // for particular sufflevector patterns. return BaseDefiningValueResult(I, false); + // The behavior of getelementptr instructions is the same for vector and + // non-vector data types. + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) + return findBaseDefiningValue(GEP->getPointerOperand()); + // A PHI or Select is a base defining value. The outer findBasePointer // algorithm is responsible for constructing a base value for this BDV. assert((isa<SelectInst>(I) || isa<PHINode>(I)) && @@ -634,7 +643,7 @@ static BDVState meetBDVStateImpl(const BDVState &LHS, const BDVState &RHS) { // Values of type BDVState form a lattice, and this function implements the meet // operation. -static BDVState meetBDVState(BDVState LHS, BDVState RHS) { +static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) { BDVState Result = meetBDVStateImpl(LHS, RHS); assert(Result == meetBDVStateImpl(RHS, LHS) && "Math is wrong: meet does not commute!"); @@ -1123,39 +1132,23 @@ normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent, // Create new attribute set containing only attributes which can be transferred // from original call to the safepoint. 
-static AttributeSet legalizeCallAttributes(AttributeSet AS) { - AttributeSet Ret; - - for (unsigned Slot = 0; Slot < AS.getNumSlots(); Slot++) { - unsigned Index = AS.getSlotIndex(Slot); - - if (Index == AttributeSet::ReturnIndex || - Index == AttributeSet::FunctionIndex) { - - for (Attribute Attr : make_range(AS.begin(Slot), AS.end(Slot))) { - - // Do not allow certain attributes - just skip them - // Safepoint can not be read only or read none. - if (Attr.hasAttribute(Attribute::ReadNone) || - Attr.hasAttribute(Attribute::ReadOnly)) - continue; - - // These attributes control the generation of the gc.statepoint call / - // invoke itself; and once the gc.statepoint is in place, they're of no - // use. - if (isStatepointDirectiveAttr(Attr)) - continue; - - Ret = Ret.addAttributes( - AS.getContext(), Index, - AttributeSet::get(AS.getContext(), Index, AttrBuilder(Attr))); - } - } - - // Just skip parameter attributes for now - } - - return Ret; +static AttributeList legalizeCallAttributes(AttributeList AL) { + if (AL.isEmpty()) + return AL; + + // Remove the readonly, readnone, and statepoint function attributes. + AttrBuilder FnAttrs = AL.getFnAttributes(); + FnAttrs.removeAttribute(Attribute::ReadNone); + FnAttrs.removeAttribute(Attribute::ReadOnly); + for (Attribute A : AL.getFnAttributes()) { + if (isStatepointDirectiveAttr(A)) + FnAttrs.remove(A); + } + + // Just skip parameter and return attributes for now + LLVMContext &Ctx = AL.getContext(); + return AttributeList::get(Ctx, AttributeList::FunctionIndex, + AttributeSet::get(Ctx, FnAttrs)); } /// Helper function to place all gc relocates necessary for the given @@ -1299,12 +1292,11 @@ static StringRef getDeoptLowering(CallSite CS) { const char *DeoptLowering = "deopt-lowering"; if (CS.hasFnAttr(DeoptLowering)) { // FIXME: CallSite has a *really* confusing interface around attributes - // with values. - const AttributeSet &CSAS = CS.getAttributes(); - if (CSAS.hasAttribute(AttributeSet::FunctionIndex, - DeoptLowering)) - return CSAS.getAttribute(AttributeSet::FunctionIndex, - DeoptLowering).getValueAsString(); + // with values. + const AttributeList &CSAS = CS.getAttributes(); + if (CSAS.hasAttribute(AttributeList::FunctionIndex, DeoptLowering)) + return CSAS.getAttribute(AttributeList::FunctionIndex, DeoptLowering) + .getValueAsString(); Function *F = CS.getCalledFunction(); assert(F && F->hasFnAttribute(DeoptLowering)); return F->getFnAttribute(DeoptLowering).getValueAsString(); @@ -1388,7 +1380,6 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // Create the statepoint given all the arguments Instruction *Token = nullptr; - AttributeSet ReturnAttrs; if (CS.isCall()) { CallInst *ToReplace = cast<CallInst>(CS.getInstruction()); CallInst *Call = Builder.CreateGCStatepointCall( @@ -1399,12 +1390,10 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ Call->setCallingConv(ToReplace->getCallingConv()); // Currently we will fail on parameter attributes and on certain - // function attributes. - AttributeSet NewAttrs = legalizeCallAttributes(ToReplace->getAttributes()); - // In case if we can handle this set of attributes - set up function attrs - // directly on statepoint and return attrs later for gc_result intrinsic. - Call->setAttributes(NewAttrs.getFnAttributes()); - ReturnAttrs = NewAttrs.getRetAttributes(); + // function attributes. In case if we can handle this set of attributes - + // set up function attrs directly on statepoint and return attrs later for + // gc_result intrinsic. 
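The AttributeList rework above keys everything by index; for reference, a short sketch of the two constructions it leans on (Ctx an LLVMContext, FnAttrs an AttrBuilder, RetAttrs an AttributeSet, all assumed in scope):

// Function-index-only list, applied to the gc.statepoint call/invoke itself.
AttributeList FnOnly = AttributeList::get(
    Ctx, AttributeList::FunctionIndex, AttributeSet::get(Ctx, FnAttrs));
// Return-index-only list, re-attached later to the gc.result that extracts
// the original call's return value.
AttributeList RetOnly =
    AttributeList::get(Ctx, AttributeList::ReturnIndex, RetAttrs);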
+ Call->setAttributes(legalizeCallAttributes(ToReplace->getAttributes())); Token = Call; @@ -1427,12 +1416,10 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ Invoke->setCallingConv(ToReplace->getCallingConv()); // Currently we will fail on parameter attributes and on certain - // function attributes. - AttributeSet NewAttrs = legalizeCallAttributes(ToReplace->getAttributes()); - // In case if we can handle this set of attributes - set up function attrs - // directly on statepoint and return attrs later for gc_result intrinsic. - Invoke->setAttributes(NewAttrs.getFnAttributes()); - ReturnAttrs = NewAttrs.getRetAttributes(); + // function attributes. In case if we can handle this set of attributes - + // set up function attrs directly on statepoint and return attrs later for + // gc_result intrinsic. + Invoke->setAttributes(legalizeCallAttributes(ToReplace->getAttributes())); Token = Invoke; @@ -1478,7 +1465,9 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ StringRef Name = CS.getInstruction()->hasName() ? CS.getInstruction()->getName() : ""; CallInst *GCResult = Builder.CreateGCResult(Token, CS.getType(), Name); - GCResult->setAttributes(CS.getAttributes().getRetAttributes()); + GCResult->setAttributes( + AttributeList::get(GCResult->getContext(), AttributeList::ReturnIndex, + CS.getAttributes().getRetAttributes())); // We cannot RAUW or delete CS.getInstruction() because it could be in the // live set of some other safepoint, in which case that safepoint's @@ -1615,8 +1604,10 @@ static void relocationViaAlloca( // Emit alloca for "LiveValue" and record it in "allocaMap" and // "PromotableAllocas" + const DataLayout &DL = F.getParent()->getDataLayout(); auto emitAllocaFor = [&](Value *LiveValue) { - AllocaInst *Alloca = new AllocaInst(LiveValue->getType(), "", + AllocaInst *Alloca = new AllocaInst(LiveValue->getType(), + DL.getAllocaAddrSpace(), "", F.getEntryBlock().getFirstNonPHI()); AllocaMap[LiveValue] = Alloca; PromotableAllocas.push_back(Alloca); @@ -1873,7 +1864,7 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain, "non noop cast is found during rematerialization"); Type *SrcTy = CI->getOperand(0)->getType(); - Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy); + Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, CI); } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) { // Cost of the address calculation @@ -1963,7 +1954,7 @@ static void rematerializeLiveValues(CallSite CS, // to identify the newly generated AlternateRootPhi (.base version of phi) // and RootOfChain (the original phi node itself) are the same, so that we // can rematerialize the gep and casts. This is a workaround for the - // deficieny in the findBasePointer algorithm. + // deficiency in the findBasePointer algorithm. if (!AreEquivalentPhiNodes(*OrigRootPhi, *AlternateRootPhi)) continue; // Now that the phi nodes are proved to be the same, assert that @@ -2003,7 +1994,7 @@ static void rematerializeLiveValues(CallSite CS, Instruction *LastClonedValue = nullptr; Instruction *LastValue = nullptr; for (Instruction *Instr: ChainToBase) { - // Only GEP's and casts are suported as we need to be careful to not + // Only GEP's and casts are supported as we need to be careful to not // introduce any new uses of pointers not in the liveset. // Note that it's fine to introduce new uses of pointers which were // otherwise not used after this statepoint. 
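Several hunks in this import (here, and in CoroElide earlier) adopt the AllocaInst constructor that takes an explicit address space. A minimal sketch of the idiom, with Ty and F assumed:

// Address space 0 is no longer assumed; the correct alloca address space
// comes from the module's DataLayout.
const DataLayout &DL = F.getParent()->getDataLayout();
AllocaInst *Slot = new AllocaInst(Ty, DL.getAllocaAddrSpace(), "slot",
                                  F.getEntryBlock().getFirstNonPHI());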
@@ -2107,9 +2098,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // live in the IR. We'll remove all of these when done. SmallVector<CallInst *, 64> Holders; - // Insert a dummy call with all of the arguments to the vm_state we'll need - // for the actual safepoint insertion. This ensures reference arguments in - // the deopt argument list are considered live through the safepoint (and + // Insert a dummy call with all of the deopt operands we'll need for the + // actual safepoint insertion as arguments. This ensures reference operands + // in the deopt argument list are considered live through the safepoint (and // thus makes sure they get relocated.) for (CallSite CS : ToUpdate) { SmallVector<Value *, 64> DeoptValues; @@ -2299,12 +2290,11 @@ static void RemoveNonValidAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH, if (AH.getDereferenceableOrNullBytes(Index)) R.addAttribute(Attribute::get(Ctx, Attribute::DereferenceableOrNull, AH.getDereferenceableOrNullBytes(Index))); - if (AH.doesNotAlias(Index)) + if (AH.getAttributes().hasAttribute(Index, Attribute::NoAlias)) R.addAttribute(Attribute::NoAlias); if (!R.empty()) - AH.setAttributes(AH.getAttributes().removeAttributes( - Ctx, Index, AttributeSet::get(Ctx, Index, R))); + AH.setAttributes(AH.getAttributes().removeAttributes(Ctx, Index, R)); } void @@ -2313,19 +2303,51 @@ RewriteStatepointsForGC::stripNonValidAttributesFromPrototype(Function &F) { for (Argument &A : F.args()) if (isa<PointerType>(A.getType())) - RemoveNonValidAttrAtIndex(Ctx, F, A.getArgNo() + 1); + RemoveNonValidAttrAtIndex(Ctx, F, + A.getArgNo() + AttributeList::FirstArgIndex); if (isa<PointerType>(F.getReturnType())) - RemoveNonValidAttrAtIndex(Ctx, F, AttributeSet::ReturnIndex); + RemoveNonValidAttrAtIndex(Ctx, F, AttributeList::ReturnIndex); +} + +void RewriteStatepointsForGC::stripInvalidMetadataFromInstruction(Instruction &I) { + + if (!isa<LoadInst>(I) && !isa<StoreInst>(I)) + return; + // These are the metadata kinds that are still valid on loads and stores + // after RS4GC. + // The metadata implying dereferenceability and noalias are (conservatively) + // dropped. This is because semantically, after RewriteStatepointsForGC runs, + // all calls to gc.statepoint "free" the entire heap. Also, gc.statepoint can + // touch the entire heap including noalias objects. Note: The reasoning is + // the same as for stripping the analogous dereferenceability and noalias + // attributes. + // We also drop the invariant.load metadata on the load because that metadata + // implies the address operand to the load points to memory that is never + // changed once it becomes dereferenceable. This is no longer true after RS4GC. + // Similar reasoning applies to invariant.group metadata, which applies to + // loads within a group. + unsigned ValidMetadataAfterRS4GC[] = {LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_alias_scope, + LLVMContext::MD_nontemporal, + LLVMContext::MD_nonnull, + LLVMContext::MD_align, + LLVMContext::MD_type}; + + // Drops all metadata on the instruction other than ValidMetadataAfterRS4GC.
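Before the call that follows, a standalone illustration of the whitelisting primitive (I is any instruction; the kinds kept here are illustrative, not the pass's actual list):

// Erase every non-debug metadata kind on I except TBAA and range info;
// debug locations always survive dropUnknownNonDebugMetadata().
unsigned Kept[] = {LLVMContext::MD_tbaa, LLVMContext::MD_range};
I.dropUnknownNonDebugMetadata(Kept);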
+ I.dropUnknownNonDebugMetadata(ValidMetadataAfterRS4GC); + } -void RewriteStatepointsForGC::stripNonValidAttributesFromBody(Function &F) { +void RewriteStatepointsForGC::stripNonValidAttributesAndMetadataFromBody(Function &F) { if (F.empty()) return; LLVMContext &Ctx = F.getContext(); MDBuilder Builder(Ctx); + for (Instruction &I : instructions(F)) { if (const MDNode *MD = I.getMetadata(LLVMContext::MD_tbaa)) { assert(MD->getNumOperands() < 5 && "unrecognized metadata shape!"); @@ -2346,12 +2368,14 @@ void RewriteStatepointsForGC::stripNonValidAttributesFromBody(Function &F) { I.setMetadata(LLVMContext::MD_tbaa, MutableTBAA); } + stripInvalidMetadataFromInstruction(I); + if (CallSite CS = CallSite(&I)) { for (int i = 0, e = CS.arg_size(); i != e; i++) if (isa<PointerType>(CS.getArgument(i)->getType())) - RemoveNonValidAttrAtIndex(Ctx, CS, i + 1); + RemoveNonValidAttrAtIndex(Ctx, CS, i + AttributeList::FirstArgIndex); if (isa<PointerType>(CS.getType())) - RemoveNonValidAttrAtIndex(Ctx, CS, AttributeSet::ReturnIndex); + RemoveNonValidAttrAtIndex(Ctx, CS, AttributeList::ReturnIndex); } } } @@ -2370,7 +2394,7 @@ static bool shouldRewriteStatepointsIn(Function &F) { return false; } -void RewriteStatepointsForGC::stripNonValidAttributes(Module &M) { +void RewriteStatepointsForGC::stripNonValidAttributesAndMetadata(Module &M) { #ifndef NDEBUG assert(any_of(M, shouldRewriteStatepointsIn) && "precondition!"); #endif @@ -2379,7 +2403,7 @@ void RewriteStatepointsForGC::stripNonValidAttributes(Module &M) { stripNonValidAttributesFromPrototype(F); for (Function &F : M) - stripNonValidAttributesFromBody(F); + stripNonValidAttributesAndMetadataFromBody(F); } bool RewriteStatepointsForGC::runOnFunction(Function &F) { diff --git a/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp b/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp index ede381c..4822cf7 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -140,6 +140,14 @@ public: return nullptr; } + /// getBlockAddress - If this is a constant with a BlockAddress value, return + /// it, otherwise return null. + BlockAddress *getBlockAddress() const { + if (isConstant()) + return dyn_cast<BlockAddress>(getConstant()); + return nullptr; + } + void markForcedConstant(Constant *V) { assert(isUnknown() && "Can't force a defined value!"); Val.setInt(forcedconstant); @@ -306,20 +314,14 @@ public: return MRVFunctionsTracked; } - void markOverdefined(Value *V) { - assert(!V->getType()->isStructTy() && - "structs should use markAnythingOverdefined"); - markOverdefined(ValueState[V], V); - } - - /// markAnythingOverdefined - Mark the specified value overdefined. This + /// markOverdefined - Mark the specified value overdefined. This /// works with both scalars and structs. 
- void markAnythingOverdefined(Value *V) { + void markOverdefined(Value *V) { if (auto *STy = dyn_cast<StructType>(V->getType())) for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) markOverdefined(getStructValueState(V, i), V); else - markOverdefined(V); + markOverdefined(ValueState[V], V); } // isStructLatticeConstant - Return true if all the lattice values @@ -513,12 +515,8 @@ private: void visitCmpInst(CmpInst &I); void visitExtractValueInst(ExtractValueInst &EVI); void visitInsertValueInst(InsertValueInst &IVI); - void visitLandingPadInst(LandingPadInst &I) { markAnythingOverdefined(&I); } - void visitFuncletPadInst(FuncletPadInst &FPI) { - markAnythingOverdefined(&FPI); - } void visitCatchSwitchInst(CatchSwitchInst &CPI) { - markAnythingOverdefined(&CPI); + markOverdefined(&CPI); visitTerminatorInst(CPI); } @@ -537,17 +535,11 @@ private: void visitResumeInst (TerminatorInst &I) { /*returns void*/ } void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ } void visitFenceInst (FenceInst &I) { /*returns void*/ } - void visitAtomicCmpXchgInst(AtomicCmpXchgInst &I) { - markAnythingOverdefined(&I); - } - void visitAtomicRMWInst (AtomicRMWInst &I) { markOverdefined(&I); } - void visitAllocaInst (Instruction &I) { markOverdefined(&I); } - void visitVAArgInst (Instruction &I) { markAnythingOverdefined(&I); } - void visitInstruction(Instruction &I) { - // If a new instruction is added to LLVM that we don't handle. + // All the instructions we don't do any special handling for just + // go to overdefined. DEBUG(dbgs() << "SCCP: Don't know how to handle: " << I << '\n'); - markAnythingOverdefined(&I); // Just in case + markOverdefined(&I); } }; @@ -602,14 +594,36 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, return; } - Succs[SI->findCaseValue(CI).getSuccessorIndex()] = true; + Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true; return; } - // TODO: This could be improved if the operand is a [cast of a] BlockAddress. - if (isa<IndirectBrInst>(&TI)) { - // Just mark all destinations executable! - Succs.assign(TI.getNumSuccessors(), true); + // In case of an indirect branch whose address is a blockaddress, we mark + // only the target as executable. + if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) { + // Casts are folded by visitCastInst. + LatticeVal IBRValue = getValueState(IBR->getAddress()); + BlockAddress *Addr = IBRValue.getBlockAddress(); + if (!Addr) { // Overdefined or unknown condition? + // All destinations are executable! + if (!IBRValue.isUnknown()) + Succs.assign(TI.getNumSuccessors(), true); + return; + } + + BasicBlock *T = Addr->getBasicBlock(); + assert(Addr->getFunction() == T->getParent() && + "Block address of a different function?"); + for (unsigned i = 0; i < IBR->getNumSuccessors(); ++i) { + // This is the target. + if (IBR->getDestination(i) == T) { + Succs[i] = true; + return; + } + } + + // If we didn't find our destination in the IBR successor list, then we + // have undefined behavior. It's ok to assume no successor is executable. return; } @@ -659,13 +673,21 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { if (!CI) return !SCValue.isUnknown(); - return SI->findCaseValue(CI).getCaseSuccessor() == To; + return SI->findCaseValue(CI)->getCaseSuccessor() == To; } - // Just mark all destinations executable! - // TODO: This could be improved if the operand is a [cast of a] BlockAddress.
- if (isa<IndirectBrInst>(TI)) - return true; + // In case of an indirect branch whose address is a blockaddress, only + // the edge to the target is feasible. + if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { + LatticeVal IBRValue = getValueState(IBR->getAddress()); + BlockAddress *Addr = IBRValue.getBlockAddress(); + + if (!Addr) + return !IBRValue.isUnknown(); + + // At this point, the indirectbr is branching on a blockaddress. + return Addr->getBasicBlock() == To; + } DEBUG(dbgs() << "Unknown terminator instruction: " << *TI << '\n'); llvm_unreachable("SCCP: Don't know how to handle this terminator!"); @@ -693,7 +715,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // If this PN returns a struct, just mark the result overdefined. // TODO: We could do a lot better than this if code actually uses this. if (PN.getType()->isStructTy()) - return markAnythingOverdefined(&PN); + return markOverdefined(&PN); if (getValueState(&PN).isOverdefined()) return; // Quick exit @@ -803,7 +825,7 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { // If this returns a struct, mark all elements over defined, we don't track // structs in structs. if (EVI.getType()->isStructTy()) - return markAnythingOverdefined(&EVI); + return markOverdefined(&EVI); // If this is extracting from more than one level of struct, we don't know. if (EVI.getNumIndices() != 1) @@ -828,7 +850,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { // If this has more than one index, we can't handle it, drive all results to // undef. if (IVI.getNumIndices() != 1) - return markAnythingOverdefined(&IVI); + return markOverdefined(&IVI); Value *Aggr = IVI.getAggregateOperand(); unsigned Idx = *IVI.idx_begin(); @@ -857,7 +879,7 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { // If this select returns a struct, just mark the result overdefined. // TODO: We could do a lot better than this if code actually uses this. if (I.getType()->isStructTy()) - return markAnythingOverdefined(&I); + return markOverdefined(&I); LatticeVal CondValue = getValueState(I.getCondition()); if (CondValue.isUnknown()) @@ -910,9 +932,16 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { // Otherwise, one of our operands is overdefined. Try to produce something // better than overdefined with some tricks. - - // If this is an AND or OR with 0 or -1, it doesn't matter that the other - // operand is overdefined. + // If this is 0 / Y, it doesn't matter that the second operand is + // overdefined: the result is zero. + if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv) + if (V1State.isConstant() && V1State.getConstant()->isNullValue()) + return markConstant(IV, &I, V1State.getConstant()); + + // If this is: + // -> AND/MUL with 0 + // -> OR with -1 + // it doesn't matter that the other operand is overdefined. if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul || I.getOpcode() == Instruction::Or) { LatticeVal *NonOverdefVal = nullptr; @@ -934,7 +963,7 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { } else { // X or -1 = -1 if (ConstantInt *CI = NonOverdefVal->getConstantInt()) - if (CI->isAllOnesValue()) + if (CI->isMinusOne()) return markConstant(IV, &I, NonOverdefVal->getConstant()); } } @@ -1021,7 +1050,7 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) { void SCCPSolver::visitLoadInst(LoadInst &I) { // If this load is of a struct, just mark the result overdefined.
if (I.getType()->isStructTy()) - return markAnythingOverdefined(&I); + return markOverdefined(&I); LatticeVal PtrVal = getValueState(I.getOperand(0)); if (PtrVal.isUnknown()) return; // The pointer is not resolved yet! @@ -1078,7 +1107,7 @@ CallOverdefined: // Otherwise, if we have a single return value case, and if the function is // a declaration, maybe we can constant fold it. if (F && F->isDeclaration() && !I->getType()->isStructTy() && - canConstantFoldCallTo(F)) { + canConstantFoldCallTo(CS, F)) { SmallVector<Constant*, 8> Operands; for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); @@ -1098,7 +1127,7 @@ CallOverdefined: // If we can constant fold this, mark the result of the call as a // constant. - if (Constant *C = ConstantFoldCall(F, Operands, TLI)) { + if (Constant *C = ConstantFoldCall(CS, F, Operands, TLI)) { // call -> undef. if (isa<UndefValue>(C)) return; @@ -1107,7 +1136,7 @@ CallOverdefined: // Otherwise, we don't know anything about this call, mark it overdefined. - return markAnythingOverdefined(I); + return markOverdefined(I); } // If this is a local function that doesn't have its address taken, mark its @@ -1483,6 +1512,31 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { return true; } + if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { + // Indirect branch with no successors? It's ok to assume it branches + // to no target. + if (IBR->getNumSuccessors() < 1) + continue; + + if (!getValueState(IBR->getAddress()).isUnknown()) + continue; + + // If the input to SCCP is actually a branch on undef, fix the undef to + // the first successor of the indirect branch. + if (isa<UndefValue>(IBR->getAddress())) { + IBR->setAddress(BlockAddress::get(IBR->getSuccessor(0))); + markEdgeExecutable(&BB, IBR->getSuccessor(0)); + return true; + } + + // Otherwise, it is a branch on a symbolic value which is currently + // considered to be undef. Handle this by forcing the branch's input + // value to the first successor. + markForcedConstant(IBR->getAddress(), + BlockAddress::get(IBR->getSuccessor(0))); + return true; + } + if (auto *SI = dyn_cast<SwitchInst>(TI)) { if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown()) continue; @@ -1490,12 +1544,12 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // If the input to SCCP is actually switch on undef, fix the undef to // the first constant. if (isa<UndefValue>(SI->getCondition())) { - SI->setCondition(SI->case_begin().getCaseValue()); - markEdgeExecutable(&BB, SI->case_begin().getCaseSuccessor()); + SI->setCondition(SI->case_begin()->getCaseValue()); + markEdgeExecutable(&BB, SI->case_begin()->getCaseSuccessor()); return true; } - markForcedConstant(SI->getCondition(), SI->case_begin().getCaseValue()); + markForcedConstant(SI->getCondition(), SI->case_begin()->getCaseValue()); return true; } } @@ -1545,7 +1599,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, // Mark all arguments to the function as being overdefined. for (Argument &AI : F.args()) - Solver.markAnythingOverdefined(&AI); + Solver.markOverdefined(&AI); // Solve for constants. bool ResolvedUndefs = true; @@ -1715,8 +1769,9 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, // arguments and return value aggressively, and can assume it is not called // unless we see evidence to the contrary.
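The three indirectbr hunks above (getFeasibleSuccessors, isEdgeFeasible, ResolvedUndefsIn) revolve around one lookup; a condensed, self-contained sketch of it (helper name hypothetical):

// Map a lattice-resolved blockaddress to the single feasible indirectbr
// successor; -1 means "no known constant address", in which case the caller
// falls back to its conservative behavior.
static int feasibleIBRSuccessor(IndirectBrInst *IBR, BlockAddress *Addr) {
  if (!Addr)
    return -1;
  BasicBlock *Target = Addr->getBasicBlock();
  for (unsigned i = 0, e = IBR->getNumSuccessors(); i != e; ++i)
    if (IBR->getDestination(i) == Target)
      return i;
  return -1; // Target not listed: branching there would be UB.
}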
if (F.hasLocalLinkage()) { - if (AddressIsTaken(&F)) + if (F.hasAddressTaken()) { AddressTakenFunctions.insert(&F); + } else { Solver.AddArgumentTrackedFunction(&F); continue; @@ -1728,14 +1783,15 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, // Assume nothing about the incoming arguments. for (Argument &AI : F.args()) - Solver.markAnythingOverdefined(&AI); + Solver.markOverdefined(&AI); } // Loop over global variables. We inform the solver about any internal global // variables that do not have their 'addresses taken'. If they don't have // their addresses taken, we can propagate constants through them. for (GlobalVariable &G : M.globals()) - if (!G.isConstant() && G.hasLocalLinkage() && !AddressIsTaken(&G)) + if (!G.isConstant() && G.hasLocalLinkage() && + G.hasDefinitiveInitializer() && !AddressIsTaken(&G)) Solver.TrackValueOfGlobalVariable(&G); // Solve for constants. @@ -1760,15 +1816,11 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, if (F.isDeclaration()) continue; - if (Solver.isBlockExecutable(&F.front())) { + if (Solver.isBlockExecutable(&F.front())) for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E; - ++AI) { - if (AI->use_empty()) - continue; - if (tryToReplaceWithConstant(Solver, &*AI)) + ++AI) + if (!AI->use_empty() && tryToReplaceWithConstant(Solver, &*AI)) ++IPNumArgsElimed; - } - } for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { if (!Solver.isBlockExecutable(&*BB)) { @@ -1817,32 +1869,9 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, if (!I) continue; bool Folded = ConstantFoldTerminator(I->getParent()); - if (!Folded) { - // The constant folder may not have been able to fold the terminator - // if this is a branch or switch on undef. Fold it manually as a - // branch to the first successor. -#ifndef NDEBUG - if (auto *BI = dyn_cast<BranchInst>(I)) { - assert(BI->isConditional() && isa<UndefValue>(BI->getCondition()) && - "Branch should be foldable!"); - } else if (auto *SI = dyn_cast<SwitchInst>(I)) { - assert(isa<UndefValue>(SI->getCondition()) && "Switch should fold"); - } else { - llvm_unreachable("Didn't fold away reference to block!"); - } -#endif - - // Make this an uncond branch to the first successor. - TerminatorInst *TI = I->getParent()->getTerminator(); - BranchInst::Create(TI->getSuccessor(0), TI); - - // Remove entries in successor phi nodes to remove edges. - for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i) - TI->getSuccessor(i)->removePredecessor(TI->getParent()); - - // Remove the old terminator. - TI->eraseFromParent(); - } + assert(Folded && + "Expect TermInst on constantint or blockaddress to be folded"); + (void) Folded; } // Finally, delete the basic block. diff --git a/contrib/llvm/lib/Transforms/Scalar/SROA.cpp b/contrib/llvm/lib/Transforms/Scalar/SROA.cpp index bfcb155..b9cee5b 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SROA.cpp @@ -25,6 +25,7 @@ #include "llvm/Transforms/Scalar/SROA.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" @@ -325,7 +326,7 @@ private: /// partition. uint64_t BeginOffset, EndOffset; - /// \brief The start end end iterators of this partition. + /// \brief The start and end iterators of this partition. iterator SI, SJ; /// \brief A collection of split slice tails overlapping the partition. 
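The new hasDefinitiveInitializer() guard in the globals loop above matters because local linkage alone does not pin down the initial value. Annotated restatement of the condition (same code as the hunk, comments added):

// An internal global can still be marked externally_initialized, in which
// case its contents may change before the program starts; such globals have
// no definitive initializer and must not be constant-propagated:
//   @g = internal externally_initialized global i32 0
if (!G.isConstant() && G.hasLocalLinkage() &&
    G.hasDefinitiveInitializer() && !AddressIsTaken(&G))
  Solver.TrackValueOfGlobalVariable(&G);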
@@ -1251,7 +1252,7 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) { if (!LI || !LI->isSimple()) return false; - // Both operands to the select need to be dereferencable, either + // Both operands to the select need to be dereferenceable, either // absolutely (e.g. allocas) or at this point because we can see other // accesses to it. if (!isSafeToLoadUnconditionally(TValue, LI->getAlignment(), DL, LI)) @@ -1636,8 +1637,17 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) { return cast<PointerType>(NewTy)->getPointerAddressSpace() == cast<PointerType>(OldTy)->getPointerAddressSpace(); } - if (NewTy->isIntegerTy() || OldTy->isIntegerTy()) - return true; + + // We can convert integers to integral pointers, but not to non-integral + // pointers. + if (OldTy->isIntegerTy()) + return !DL.isNonIntegralPointerType(NewTy); + + // We can convert integral pointers to integers, but non-integral pointers + // need to remain pointers. + if (!DL.isNonIntegralPointerType(OldTy)) + return NewTy->isIntegerTy(); + return false; } @@ -1663,8 +1673,7 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V, // See if we need inttoptr for this type pair. A cast involving both scalars // and vectors requires and additional bitcast. - if (OldTy->getScalarType()->isIntegerTy() && - NewTy->getScalarType()->isPointerTy()) { + if (OldTy->isIntOrIntVectorTy() && NewTy->isPtrOrPtrVectorTy()) { // Expand <2 x i32> to i8* --> <2 x i32> to i64 to i8* if (OldTy->isVectorTy() && !NewTy->isVectorTy()) return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)), @@ -1680,8 +1689,7 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V, // See if we need ptrtoint for this type pair. A cast involving both scalars // and vectors requires and additional bitcast. - if (OldTy->getScalarType()->isPointerTy() && - NewTy->getScalarType()->isIntegerTy()) { + if (OldTy->isPtrOrPtrVectorTy() && NewTy->isIntOrIntVectorTy()) { // Expand <2 x i8*> to i128 --> <2 x i8*> to <2 x i64> to i128 if (OldTy->isVectorTy() && !NewTy->isVectorTy()) return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)), @@ -1825,6 +1833,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // Rank the remaining candidate vector types. This is easy because we know // they're all integer vectors. We sort by ascending number of elements. auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) { + (void)DL; assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) && "Cannot have vector types of different sizes!"); assert(RHSTy->getElementType()->isIntegerTy() && @@ -2185,8 +2194,8 @@ class llvm::sroa::AllocaSliceRewriter Instruction *OldPtr; // Track post-rewrite users which are PHI nodes and Selects. - SmallPtrSetImpl<PHINode *> &PHIUsers; - SmallPtrSetImpl<SelectInst *> &SelectUsers; + SmallSetVector<PHINode *, 8> &PHIUsers; + SmallSetVector<SelectInst *, 8> &SelectUsers; // Utility IR builder, whose name prefix is setup for each visited use, and // the insertion point is set to point to the user. 
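The canConvertValue change above encodes a simple rule for non-integral pointer address spaces (declared via the DataLayout's "ni" spec). Restated as a standalone helper covering just the integer/pointer cases (pointer-to-pointer conversions are handled separately; helper name hypothetical):

// Integers may round-trip only through pointers with a stable integral
// representation; non-integral pointers must stay pointers.
static bool intPtrInterconvertible(const DataLayout &DL, Type *OldTy,
                                   Type *NewTy) {
  if (OldTy->isIntegerTy())                    // int -> ptr
    return !DL.isNonIntegralPointerType(NewTy);
  if (!DL.isNonIntegralPointerType(OldTy))     // ptr -> int
    return NewTy->isIntegerTy();
  return false;
}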
@@ -2198,8 +2207,8 @@ public: uint64_t NewAllocaBeginOffset, uint64_t NewAllocaEndOffset, bool IsIntegerPromotable, VectorType *PromotableVecTy, - SmallPtrSetImpl<PHINode *> &PHIUsers, - SmallPtrSetImpl<SelectInst *> &SelectUsers) + SmallSetVector<PHINode *, 8> &PHIUsers, + SmallSetVector<SelectInst *, 8> &SelectUsers) : DL(DL), AS(AS), Pass(Pass), OldAI(OldAI), NewAI(NewAI), NewAllocaBeginOffset(NewAllocaBeginOffset), NewAllocaEndOffset(NewAllocaEndOffset), @@ -2294,7 +2303,8 @@ private: #endif return getAdjustedPtr(IRB, DL, &NewAI, - APInt(DL.getPointerSizeInBits(), Offset), PointerTy, + APInt(DL.getPointerTypeSizeInBits(PointerTy), Offset), + PointerTy, #ifndef NDEBUG Twine(OldName) + "." #else @@ -2369,6 +2379,8 @@ private: Value *OldOp = LI.getOperand(0); assert(OldOp == OldPtr); + unsigned AS = LI.getPointerAddressSpace(); + Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8) : LI.getType(); const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize; @@ -2386,7 +2398,22 @@ private: LoadInst *NewLI = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), LI.isVolatile(), LI.getName()); if (LI.isVolatile()) - NewLI->setAtomic(LI.getOrdering(), LI.getSynchScope()); + NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); + + // Any !nonnull metadata or !range metadata on the old load is also valid + // on the new load. This is true in some cases even when the loads + // are different types, for example by mapping !nonnull metadata to + // !range metadata by modeling the null pointer constant converted to the + // integer type. + // FIXME: Add support for range metadata here. Currently the utilities + // for this don't propagate range metadata in trivial cases from one + // integer load to another, don't handle non-addrspace-0 null pointers + // correctly, and don't have any support for mapping ranges as the + // integer type becomes wider or narrower. + if (MDNode *N = LI.getMetadata(LLVMContext::MD_nonnull)) + copyNonnullMetadata(LI, N, *NewLI); V = NewLI; // If this is an integer load past the end of the slice (which means the @@ -2401,12 +2428,12 @@ private: "endian_shift"); } } else { - Type *LTy = TargetTy->getPointerTo(); + Type *LTy = TargetTy->getPointerTo(AS); LoadInst *NewLI = IRB.CreateAlignedLoad(getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(TargetTy), LI.isVolatile(), LI.getName()); if (LI.isVolatile()) - NewLI->setAtomic(LI.getOrdering(), LI.getSynchScope()); + NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); V = NewLI; IsPtrAdjusted = true; @@ -2429,12 +2456,12 @@ private: // the computed value, and then replace the placeholder with LI, leaving // LI only used for this computation.
Value *Placeholder = - new LoadInst(UndefValue::get(LI.getType()->getPointerTo())); + new LoadInst(UndefValue::get(LI.getType()->getPointerTo(AS))); V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset, "insert"); LI.replaceAllUsesWith(V); Placeholder->replaceAllUsesWith(&LI); - delete Placeholder; + Placeholder->deleteValue(); } else { LI.replaceAllUsesWith(V); } @@ -2542,13 +2569,14 @@ private: NewSI = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), SI.isVolatile()); } else { - Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo()); + unsigned AS = SI.getPointerAddressSpace(); + Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS)); NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()), SI.isVolatile()); } NewSI->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); if (SI.isVolatile()) - NewSI->setAtomic(SI.getOrdering(), SI.getSynchScope()); + NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); Pass.DeadInsts.insert(&SI); deleteIfTriviallyDead(OldOp); @@ -3561,10 +3589,11 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { int Idx = 0, Size = Offsets.Splits.size(); for (;;) { auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8); - auto *PartPtrTy = PartTy->getPointerTo(LI->getPointerAddressSpace()); + auto AS = LI->getPointerAddressSpace(); + auto *PartPtrTy = PartTy->getPointerTo(AS); LoadInst *PLoad = IRB.CreateAlignedLoad( getAdjustedPtr(IRB, DL, BasePtr, - APInt(DL.getPointerSizeInBits(), PartOffset), + APInt(DL.getPointerSizeInBits(AS), PartOffset), PartPtrTy, BasePtr->getName() + "."), getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false, LI->getName()); @@ -3616,10 +3645,12 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { auto *PartPtrTy = PLoad->getType()->getPointerTo(SI->getPointerAddressSpace()); + auto AS = SI->getPointerAddressSpace(); StoreInst *PStore = IRB.CreateAlignedStore( - PLoad, getAdjustedPtr(IRB, DL, StoreBasePtr, - APInt(DL.getPointerSizeInBits(), PartOffset), - PartPtrTy, StoreBasePtr->getName() + "."), + PLoad, + getAdjustedPtr(IRB, DL, StoreBasePtr, + APInt(DL.getPointerSizeInBits(AS), PartOffset), + PartPtrTy, StoreBasePtr->getName() + "."), getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false); PStore->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access); DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); @@ -3688,7 +3719,8 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { int Idx = 0, Size = Offsets.Splits.size(); for (;;) { auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8); - auto *PartPtrTy = PartTy->getPointerTo(SI->getPointerAddressSpace()); + auto *LoadPartPtrTy = PartTy->getPointerTo(LI->getPointerAddressSpace()); + auto *StorePartPtrTy = PartTy->getPointerTo(SI->getPointerAddressSpace()); // Either lookup a split load or create one. LoadInst *PLoad; @@ -3696,20 +3728,23 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { PLoad = (*SplitLoads)[Idx]; } else { IRB.SetInsertPoint(LI); + auto AS = LI->getPointerAddressSpace(); PLoad = IRB.CreateAlignedLoad( getAdjustedPtr(IRB, DL, LoadBasePtr, - APInt(DL.getPointerSizeInBits(), PartOffset), - PartPtrTy, LoadBasePtr->getName() + "."), + APInt(DL.getPointerSizeInBits(AS), PartOffset), + LoadPartPtrTy, LoadBasePtr->getName() + "."), getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false, LI->getName()); } // And store this partition. 
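A recurring detail in the presplitting fixes above and below: an offset APInt must be sized for the base pointer's own address space. Sketch (DL, BasePtr, PartOffset assumed in scope):

// DL.getPointerSizeInBits() with no argument describes address space 0;
// a base pointer in another address space may have a different width.
unsigned AS = BasePtr->getType()->getPointerAddressSpace();
APInt Offset(DL.getPointerSizeInBits(AS), PartOffset);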
IRB.SetInsertPoint(SI); + auto AS = SI->getPointerAddressSpace(); StoreInst *PStore = IRB.CreateAlignedStore( - PLoad, getAdjustedPtr(IRB, DL, StoreBasePtr, - APInt(DL.getPointerSizeInBits(), PartOffset), - PartPtrTy, StoreBasePtr->getName() + "."), + PLoad, + getAdjustedPtr(IRB, DL, StoreBasePtr, + APInt(DL.getPointerSizeInBits(AS), PartOffset), + StorePartPtrTy, StoreBasePtr->getName() + "."), getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false); // Now build a new slice for the alloca. @@ -3857,7 +3892,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, if (Alignment <= DL.getABITypeAlignment(SliceTy)) Alignment = 0; NewAI = new AllocaInst( - SliceTy, nullptr, Alignment, + SliceTy, AI.getType()->getAddressSpace(), nullptr, Alignment, AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI); ++NumNewAllocas; } @@ -3871,8 +3906,8 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // fact scheduled for promotion. unsigned PPWOldSize = PostPromotionWorklist.size(); unsigned NumUses = 0; - SmallPtrSet<PHINode *, 8> PHIUsers; - SmallPtrSet<SelectInst *, 8> SelectUsers; + SmallSetVector<PHINode *, 8> PHIUsers; + SmallSetVector<SelectInst *, 8> SelectUsers; AllocaSliceRewriter Rewriter(DL, AS, *this, AI, *NewAI, P.beginOffset(), P.endOffset(), IsIntegerPromotable, VecTy, @@ -3888,24 +3923,20 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, } NumAllocaPartitionUses += NumUses; - MaxUsesPerAllocaPartition = - std::max<unsigned>(NumUses, MaxUsesPerAllocaPartition); + MaxUsesPerAllocaPartition.updateMax(NumUses); // Now that we've processed all the slices in the new partition, check if any // PHIs or Selects would block promotion. - for (SmallPtrSetImpl<PHINode *>::iterator I = PHIUsers.begin(), - E = PHIUsers.end(); - I != E; ++I) - if (!isSafePHIToSpeculate(**I)) { + for (PHINode *PHI : PHIUsers) + if (!isSafePHIToSpeculate(*PHI)) { Promotable = false; PHIUsers.clear(); SelectUsers.clear(); break; } - for (SmallPtrSetImpl<SelectInst *>::iterator I = SelectUsers.begin(), - E = SelectUsers.end(); - I != E; ++I) - if (!isSafeSelectToSpeculate(**I)) { + + for (SelectInst *Sel : SelectUsers) + if (!isSafeSelectToSpeculate(*Sel)) { Promotable = false; PHIUsers.clear(); SelectUsers.clear(); @@ -4009,8 +4040,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } NumAllocaPartitions += NumPartitions; - MaxPartitionsPerAlloca = - std::max<unsigned>(NumPartitions, MaxPartitionsPerAlloca); + MaxPartitionsPerAlloca.updateMax(NumPartitions); // Migrate debug information from the old alloca to the new alloca(s) // and the individual partitions. @@ -4184,7 +4214,7 @@ bool SROA::promoteAllocas(Function &F) { NumPromoted += PromotableAllocas.size(); DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); - PromoteMemToReg(PromotableAllocas, *DT, nullptr, AC); + PromoteMemToReg(PromotableAllocas, *DT, AC); PromotableAllocas.clear(); return true; } @@ -4234,9 +4264,8 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, if (!Changed) return PreservedAnalyses::all(); - // FIXME: Even when promoting allocas we should preserve some abstract set of - // CFG-specific analyses. 
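The statistics hunks above move to Statistic::updateMax(), which keeps a running maximum in a single call. Minimal sketch (assumes a DEBUG_TYPE is defined, as the STATISTIC macro requires):

STATISTIC(MaxUsesPerAllocaPartition, "Maximum number of uses of a partition");
// Replaces the old read-modify-write pattern:
//   MaxUsesPerAllocaPartition =
//       std::max<unsigned>(NumUses, MaxUsesPerAllocaPartition);
MaxUsesPerAllocaPartition.updateMax(NumUses);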
PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp b/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp index afe7483..ce6f93e 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -20,11 +20,12 @@ #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/Analysis/TypeBasedAliasAnalysis.h" -#include "llvm/Transforms/Scalar/GVN.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" -#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" using namespace llvm; @@ -43,13 +44,15 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeDSELegacyPassPass(Registry); initializeGuardWideningLegacyPassPass(Registry); initializeGVNLegacyPassPass(Registry); - initializeNewGVNPass(Registry); + initializeNewGVNLegacyPassPass(Registry); initializeEarlyCSELegacyPassPass(Registry); initializeEarlyCSEMemSSALegacyPassPass(Registry); initializeGVNHoistLegacyPassPass(Registry); + initializeGVNSinkLegacyPassPass(Registry); initializeFlattenCFGPassPass(Registry); initializeInductiveRangeCheckEliminationPass(Registry); initializeIndVarSimplifyLegacyPassPass(Registry); + initializeInferAddressSpacesPass(Registry); initializeJumpThreadingPass(Registry); initializeLegacyLICMPassPass(Registry); initializeLegacyLoopSinkPassPass(Registry); @@ -58,6 +61,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoopAccessLegacyAnalysisPass(Registry); initializeLoopInstSimplifyLegacyPassPass(Registry); initializeLoopInterchangePass(Registry); + initializeLoopPredicationLegacyPassPass(Registry); initializeLoopRotateLegacyPassPass(Registry); initializeLoopStrengthReducePass(Registry); initializeLoopRerollPass(Registry); @@ -79,13 +83,14 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeIPSCCPLegacyPassPass(Registry); initializeSROALegacyPassPass(Registry); initializeCFGSimplifyPassPass(Registry); + initializeLateCFGSimplifyPassPass(Registry); initializeStructurizeCFGPass(Registry); + initializeSimpleLoopUnswitchLegacyPassPass(Registry); initializeSinkingLegacyPassPass(Registry); initializeTailCallElimPass(Registry); initializeSeparateConstOffsetFromGEPPass(Registry); initializeSpeculativeExecutionLegacyPassPass(Registry); initializeStraightLineStrengthReducePass(Registry); - initializeLoadCombinePass(Registry); initializePlaceBackedgeSafepointsImplPass(Registry); initializePlaceSafepointsPass(Registry); initializeFloat2IntLegacyPassPass(Registry); @@ -115,6 +120,10 @@ void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createCFGSimplificationPass()); } +void LLVMAddLateCFGSimplificationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLateCFGSimplificationPass()); +} + void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createDeadStoreEliminationPass()); } diff --git a/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 39969e2..d11855f 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -14,12 +14,12 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include 
"llvm/ADT/STLExtras.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -520,12 +520,25 @@ bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) { unsigned NumElems = VT->getNumElements(); unsigned NumIndices = GEPI.getNumIndices(); - Scatterer Base = scatter(&GEPI, GEPI.getOperand(0)); + // The base pointer might be scalar even if it's a vector GEP. In those cases, + // splat the pointer into a vector value, and scatter that vector. + Value *Op0 = GEPI.getOperand(0); + if (!Op0->getType()->isVectorTy()) + Op0 = Builder.CreateVectorSplat(NumElems, Op0); + Scatterer Base = scatter(&GEPI, Op0); SmallVector<Scatterer, 8> Ops; Ops.resize(NumIndices); - for (unsigned I = 0; I < NumIndices; ++I) - Ops[I] = scatter(&GEPI, GEPI.getOperand(I + 1)); + for (unsigned I = 0; I < NumIndices; ++I) { + Value *Op = GEPI.getOperand(I + 1); + + // The indices might be scalars even if it's a vector GEP. In those cases, + // splat the scalar into a vector value, and scatter that vector. + if (!Op->getType()->isVectorTy()) + Op = Builder.CreateVectorSplat(NumElems, Op); + + Ops[I] = scatter(&GEPI, Op); + } ValueVector Res; Res.resize(NumElems); diff --git a/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 4d59453..84675f4 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -156,27 +156,27 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetSubtargetInfo.h" -#include "llvm/IR/IRBuilder.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace llvm::PatternMatch; @@ -1138,7 +1138,7 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { // Add I to DominatingExprs if it's an add/sub that can't sign overflow. 
if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS))) || match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) { - if (isKnownNotFullPoison(I)) { + if (programUndefinedIfFullPoison(I)) { const SCEV *Key = SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); DominatingExprs[Key].push_back(I); diff --git a/contrib/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/contrib/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp new file mode 100644 index 0000000..aaab585 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -0,0 +1,808 @@ +//===- SimpleLoopUnswitch.cpp - Hoist loop-invariant control flow ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/GenericDomTree.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <utility> + +#define DEBUG_TYPE "simple-loop-unswitch" + +using namespace llvm; + +STATISTIC(NumBranches, "Number of branches unswitched"); +STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumTrivial, "Number of unswitches that are trivial"); + +static void replaceLoopUsesWithConstant(Loop &L, Value &LIC, + Constant &Replacement) { + assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); + + // Replace uses of LIC in the loop with the given constant. + for (auto UI = LIC.use_begin(), UE = LIC.use_end(); UI != UE;) { + // Grab the use and walk past it so we can clobber it in the use list. + Use *U = &*UI++; + Instruction *UserI = dyn_cast<Instruction>(U->getUser()); + if (!UserI || !L.contains(UserI)) + continue; + + // Replace this use within the loop body. + *U = &Replacement; + } +} + +/// Update the dominator tree after removing one exiting predecessor of a loop +/// exit block. +static void updateLoopExitIDom(BasicBlock *LoopExitBB, Loop &L, + DominatorTree &DT) { + assert(pred_begin(LoopExitBB) != pred_end(LoopExitBB) && + "Cannot have empty predecessors of the loop exit block if we split " + "off a block to unswitch!"); + + BasicBlock *IDom = *pred_begin(LoopExitBB); + // Walk all of the other predecessors finding the nearest common dominator + // until all predecessors are covered or we reach the loop header. 
The loop + // header necessarily dominates all loop exit blocks in loop simplified form + // so we can early-exit the moment we hit that block. + for (auto PI = std::next(pred_begin(LoopExitBB)), PE = pred_end(LoopExitBB); + PI != PE && IDom != L.getHeader(); ++PI) + IDom = DT.findNearestCommonDominator(IDom, *PI); + + DT.changeImmediateDominator(LoopExitBB, IDom); +} + +/// Update the dominator tree after unswitching a particular former exit block. +/// +/// This handles the full update of the dominator tree after hoisting a block +/// that previously was an exit block (or split off of an exit block) up to be +/// reached from the new immediate dominator of the preheader. +/// +/// The common case is simple -- we just move the unswitched block to have an +/// immediate dominator of the old preheader. But in complex cases, there may +/// be other blocks reachable from the unswitched block that are immediately +/// dominated by some node between the unswitched one and the old preheader. +/// All of these also need to be hoisted in the dominator tree. We also want to +/// minimize queries to the dominator tree because each step of this +/// invalidates any DFS numbers that would make queries fast. +static void updateDTAfterUnswitch(BasicBlock *UnswitchedBB, BasicBlock *OldPH, + DominatorTree &DT) { + DomTreeNode *OldPHNode = DT[OldPH]; + DomTreeNode *UnswitchedNode = DT[UnswitchedBB]; + // If the dominator tree has already been updated for this unswitched node, + // we're done. This makes it easier to use this routine if there are multiple + // paths to the same unswitched destination. + if (UnswitchedNode->getIDom() == OldPHNode) + return; + + // First collect the domtree nodes that we are hoisting over. These are the + // set of nodes which may have children that need to be hoisted as well. + SmallPtrSet<DomTreeNode *, 4> DomChain; + for (auto *IDom = UnswitchedNode->getIDom(); IDom != OldPHNode; + IDom = IDom->getIDom()) + DomChain.insert(IDom); + + // The unswitched block ends up immediately dominated by the old preheader -- + // regardless of whether it is the loop exit block or split off of the loop + // exit block. + DT.changeImmediateDominator(UnswitchedNode, OldPHNode); + + // For everything that moves up the dominator tree, we need to examine the + // dominator frontier to see if it additionally should move up the dominator + // tree. This lambda appends the dominator frontier for a node on the + // worklist. + // + // Note that we don't currently use the IDFCalculator here for two reasons: + // 1) It computes dominator tree levels for the entire function on each run + // of 'compute'. While this isn't terrible, given that we expect to update + // relatively small subtrees of the domtree, it isn't necessarily the right + // tradeoff. + // 2) The interface doesn't fit this usage well. It doesn't operate in + // append-only, and builds several sets that we don't need. + // + // FIXME: Neither of these issues are a big deal and could be addressed with + // some amount of refactoring of IDFCalculator. That would allow us to share + // the core logic here (which is solving the same core problem). 
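updateLoopExitIDom above is an instance of a generally useful recipe: when an edge into a block goes away, its new immediate dominator is the nearest common dominator of the remaining predecessors. Condensed sketch (DT and ExitBB assumed; needs llvm/IR/CFG.h for predecessors()):

BasicBlock *IDom = *pred_begin(ExitBB);
for (BasicBlock *Pred : predecessors(ExitBB))
  IDom = DT.findNearestCommonDominator(IDom, Pred); // NCD(x, x) == x
DT.changeImmediateDominator(ExitBB, IDom);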
+ SmallSetVector<BasicBlock *, 4> Worklist; + SmallVector<DomTreeNode *, 4> DomNodes; + SmallPtrSet<BasicBlock *, 4> DomSet; + auto AppendDomFrontier = [&](DomTreeNode *Node) { + assert(DomNodes.empty() && "Must start with no dominator nodes."); + assert(DomSet.empty() && "Must start with an empty dominator set."); + + // First flatten this subtree into sequence of nodes by doing a pre-order + // walk. + DomNodes.push_back(Node); + // We intentionally re-evaluate the size as each node can add new children. + // Because this is a tree walk, this cannot add any duplicates. + for (int i = 0; i < (int)DomNodes.size(); ++i) + DomNodes.insert(DomNodes.end(), DomNodes[i]->begin(), DomNodes[i]->end()); + + // Now create a set of the basic blocks so we can quickly test for + // dominated successors. We could in theory use the DFS numbers of the + // dominator tree for this, but we want this to remain predictably fast + // even while we mutate the dominator tree in ways that would invalidate + // the DFS numbering. + for (DomTreeNode *InnerN : DomNodes) + DomSet.insert(InnerN->getBlock()); + + // Now re-walk the nodes, appending every successor of every node that isn't + // in the set. Note that we don't append the node itself, even though if it + // is a successor it does not strictly dominate itself and thus it would be + // part of the dominance frontier. The reason we don't append it is that + // the node passed in came *from* the worklist and so it has already been + // processed. + for (DomTreeNode *InnerN : DomNodes) + for (BasicBlock *SuccBB : successors(InnerN->getBlock())) + if (!DomSet.count(SuccBB)) + Worklist.insert(SuccBB); + + DomNodes.clear(); + DomSet.clear(); + }; + + // Append the initial dom frontier nodes. + AppendDomFrontier(UnswitchedNode); + + // Walk the worklist. We grow the list in the loop and so must recompute size. + for (int i = 0; i < (int)Worklist.size(); ++i) { + auto *BB = Worklist[i]; + + DomTreeNode *Node = DT[BB]; + assert(!DomChain.count(Node) && + "Cannot be dominated by a block you can reach!"); + + // If this block had an immediate dominator somewhere in the chain + // we hoisted over, then its position in the domtree needs to move as it is + // reachable from a node hoisted over this chain. + if (!DomChain.count(Node->getIDom())) + continue; + + DT.changeImmediateDominator(Node, OldPHNode); + + // Now add this node's dominator frontier to the worklist as well. + AppendDomFrontier(Node); + } +} + +/// Check that all the LCSSA PHI nodes in the loop exit block have trivial +/// incoming values along this edge. +static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, + BasicBlock &ExitBB) { + for (Instruction &I : ExitBB) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) + // No more PHIs to check. + return true; + + // If the incoming value for this edge isn't loop invariant the unswitch + // won't be trivial. + if (!L.isLoopInvariant(PN->getIncomingValueForBlock(&ExitingBB))) + return false; + } + llvm_unreachable("Basic blocks should never be empty!"); +} + +/// Rewrite the PHI nodes in an unswitched loop exit basic block. +/// +/// Requires that the loop exit and unswitched basic block are the same, and +/// that the exiting block was a unique predecessor of that block. Rewrites the +/// PHI nodes in that block such that what were LCSSA PHI nodes become trivial +/// PHI nodes from the old preheader that now contains the unswitched +/// terminator. 
+static void rewritePHINodesForUnswitchedExitBlock(BasicBlock &UnswitchedBB, + BasicBlock &OldExitingBB, + BasicBlock &OldPH) { + for (Instruction &I : UnswitchedBB) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) + // No more PHIs to check. + break; + + // When the loop exit is directly unswitched we just need to update the + // incoming basic block. We loop to handle weird cases with repeated + // incoming blocks, but expect to typically only have one operand here. + for (auto i : seq<int>(0, PN->getNumOperands())) { + assert(PN->getIncomingBlock(i) == &OldExitingBB && + "Found incoming block different from unique predecessor!"); + PN->setIncomingBlock(i, &OldPH); + } + } +} + +/// Rewrite the PHI nodes in the loop exit basic block and the split off +/// unswitched block. +/// +/// Because the exit block remains an exit from the loop, this rewrites the +/// LCSSA PHI nodes in it to remove the unswitched edge and introduces PHI +/// nodes into the unswitched basic block to select between the value in the +/// old preheader and the loop exit. +static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, + BasicBlock &UnswitchedBB, + BasicBlock &OldExitingBB, + BasicBlock &OldPH) { + assert(&ExitBB != &UnswitchedBB && + "Must have different loop exit and unswitched blocks!"); + Instruction *InsertPt = &*UnswitchedBB.begin(); + for (Instruction &I : ExitBB) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) + // No more PHIs to check. + break; + + auto *NewPN = PHINode::Create(PN->getType(), /*NumReservedValues*/ 2, + PN->getName() + ".split", InsertPt); + + // Walk backwards over the old PHI node's inputs to minimize the cost of + // removing each one. We have to do this weird loop manually so that we + // create the same number of new incoming edges in the new PHI as we expect + // each case-based edge to be included in the unswitched switch in some + // cases. + // FIXME: This is really, really gross. It would be much cleaner if LLVM + // allowed us to create a single entry for a predecessor block without + // having separate entries for each "edge" even though these edges are + // required to produce identical results. + for (int i = PN->getNumIncomingValues() - 1; i >= 0; --i) { + if (PN->getIncomingBlock(i) != &OldExitingBB) + continue; + + Value *Incoming = PN->removeIncomingValue(i); + NewPN->addIncoming(Incoming, &OldPH); + } + + // Now replace the old PHI with the new one and wire the old one in as an + // input to the new one. + PN->replaceAllUsesWith(NewPN); + NewPN->addIncoming(PN, &ExitBB); + } +} + +/// Unswitch a trivial branch if the condition is loop invariant. +/// +/// This routine should only be called when loop code leading to the branch has +/// been validated as trivial (no side effects). This routine checks if the +/// condition is invariant and one of the successors is a loop exit. This +/// allows us to unswitch without duplicating the loop, making it trivial. +/// +/// If this routine fails to unswitch the branch it returns false. +/// +/// If the branch can be unswitched, this routine splits the preheader and +/// hoists the branch above that split. Preserves loop simplified form +/// (splitting the exit block as necessary). It simplifies the branch within +/// the loop to an unconditional branch but doesn't remove it entirely. Further +/// cleanup can be done with some simplify-cfg like pass. 
+static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, + LoopInfo &LI) { + assert(BI.isConditional() && "Can only unswitch a conditional branch!"); + DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); + + Value *LoopCond = BI.getCondition(); + + // Need a trivial loop condition to unswitch. + if (!L.isLoopInvariant(LoopCond)) + return false; + + // FIXME: We should compute this once at the start and update it! + SmallVector<BasicBlock *, 16> ExitBlocks; + L.getExitBlocks(ExitBlocks); + SmallPtrSet<BasicBlock *, 16> ExitBlockSet(ExitBlocks.begin(), + ExitBlocks.end()); + + // Check to see if a successor of the branch is guaranteed to + // exit through a unique exit block without having any + // side-effects. If so, determine the value of Cond that causes + // it to do this. + ConstantInt *CondVal = ConstantInt::getTrue(BI.getContext()); + ConstantInt *Replacement = ConstantInt::getFalse(BI.getContext()); + int LoopExitSuccIdx = 0; + auto *LoopExitBB = BI.getSuccessor(0); + if (!ExitBlockSet.count(LoopExitBB)) { + std::swap(CondVal, Replacement); + LoopExitSuccIdx = 1; + LoopExitBB = BI.getSuccessor(1); + if (!ExitBlockSet.count(LoopExitBB)) + return false; + } + auto *ContinueBB = BI.getSuccessor(1 - LoopExitSuccIdx); + assert(L.contains(ContinueBB) && + "Cannot have both successors exit and still be in the loop!"); + + auto *ParentBB = BI.getParent(); + if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) + return false; + + DEBUG(dbgs() << " unswitching trivial branch when: " << CondVal + << " == " << LoopCond << "\n"); + + // Split the preheader, so that we know that there is a safe place to insert + // the conditional branch. We will change the preheader to have a conditional + // branch on LoopCond. + BasicBlock *OldPH = L.getLoopPreheader(); + BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI); + + // Now that we have a place to insert the conditional branch, create a place + // to branch to: this is the exit block out of the loop that we are + // unswitching. We need to split this if there are other loop predecessors. + // Because the loop is in simplified form, *any* other predecessor is enough. + BasicBlock *UnswitchedBB; + if (BasicBlock *PredBB = LoopExitBB->getUniquePredecessor()) { + (void)PredBB; + assert(PredBB == BI.getParent() && + "A branch's parent isn't a predecessor!"); + UnswitchedBB = LoopExitBB; + } else { + UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI); + } + + // Now splice the branch to gate reaching the new preheader and re-point its + // successors. + OldPH->getInstList().splice(std::prev(OldPH->end()), + BI.getParent()->getInstList(), BI); + OldPH->getTerminator()->eraseFromParent(); + BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); + BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); + + // Create a new unconditional branch that will continue the loop as a new + // terminator. + BranchInst::Create(ContinueBB, ParentBB); + + // Rewrite the relevant PHI nodes. + if (UnswitchedBB == LoopExitBB) + rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH); + else + rewritePHINodesForExitAndUnswitchedBlocks(*LoopExitBB, *UnswitchedBB, + *ParentBB, *OldPH); + + // Now we need to update the dominator tree. + updateDTAfterUnswitch(UnswitchedBB, OldPH, DT); + // But if we split something off of the loop exit block then we also removed + // one of the predecessors for the loop exit block and may need to update its + // idom. 
+ if (UnswitchedBB != LoopExitBB) + updateLoopExitIDom(LoopExitBB, L, DT); + + // Since this is an i1 condition we can also trivially replace uses of it + // within the loop with a constant. + replaceLoopUsesWithConstant(L, *LoopCond, *Replacement); + + ++NumTrivial; + ++NumBranches; + return true; +} + +/// Unswitch a trivial switch if the condition is loop invariant. +/// +/// This routine should only be called when loop code leading to the switch has +/// been validated as trivial (no side effects). This routine checks if the +/// condition is invariant and that at least one of the successors is a loop +/// exit. This allows us to unswitch without duplicating the loop, making it +/// trivial. +/// +/// If this routine fails to unswitch the switch it returns false. +/// +/// If the switch can be unswitched, this routine splits the preheader and +/// copies the switch above that split. If the default case is one of the +/// exiting cases, it copies the non-exiting cases and points them at the new +/// preheader. If the default case is not exiting, it copies the exiting cases +/// and points the default at the preheader. It preserves loop simplified form +/// (splitting the exit blocks as necessary). It simplifies the switch within +/// the loop by removing now-dead cases. If the default case is one of those +/// unswitched, it replaces its destination with a new basic block containing +/// only unreachable. Such basic blocks, while technically loop exits, are not +/// considered for unswitching so this is a stable transform and the same +/// switch will not be revisited. If after unswitching there is only a single +/// in-loop successor, the switch is further simplified to an unconditional +/// branch. Still more cleanup can be done with some simplify-cfg like pass. +static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, + LoopInfo &LI) { + DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); + Value *LoopCond = SI.getCondition(); + + // If this isn't switching on an invariant condition, we can't unswitch it. + if (!L.isLoopInvariant(LoopCond)) + return false; + + auto *ParentBB = SI.getParent(); + + // FIXME: We should compute this once at the start and update it! + SmallVector<BasicBlock *, 16> ExitBlocks; + L.getExitBlocks(ExitBlocks); + SmallPtrSet<BasicBlock *, 16> ExitBlockSet(ExitBlocks.begin(), + ExitBlocks.end()); + + SmallVector<int, 4> ExitCaseIndices; + for (auto Case : SI.cases()) { + auto *SuccBB = Case.getCaseSuccessor(); + if (ExitBlockSet.count(SuccBB) && + areLoopExitPHIsLoopInvariant(L, *ParentBB, *SuccBB)) + ExitCaseIndices.push_back(Case.getCaseIndex()); + } + BasicBlock *DefaultExitBB = nullptr; + if (ExitBlockSet.count(SI.getDefaultDest()) && + areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && + !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) + DefaultExitBB = SI.getDefaultDest(); + else if (ExitCaseIndices.empty()) + return false; + + DEBUG(dbgs() << " unswitching trivial cases...\n"); + + SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases; + ExitCases.reserve(ExitCaseIndices.size()); + // We walk the case indices backwards so that we remove the last case first + // and don't disrupt the earlier indices. + for (unsigned Index : reverse(ExitCaseIndices)) { + auto CaseI = SI.case_begin() + Index; + // Save the value of this case. + ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); + // Delete the unswitched cases. 
+ SI.removeCase(CaseI); + } + + // Check if after this all of the remaining cases point at the same + // successor. + BasicBlock *CommonSuccBB = nullptr; + if (SI.getNumCases() > 0 && + std::all_of(std::next(SI.case_begin()), SI.case_end(), + [&SI](const SwitchInst::CaseHandle &Case) { + return Case.getCaseSuccessor() == + SI.case_begin()->getCaseSuccessor(); + })) + CommonSuccBB = SI.case_begin()->getCaseSuccessor(); + + if (DefaultExitBB) { + // We can't remove the default edge so replace it with an edge to either + // the single common remaining successor (if we have one) or an unreachable + // block. + if (CommonSuccBB) { + SI.setDefaultDest(CommonSuccBB); + } else { + BasicBlock *UnreachableBB = BasicBlock::Create( + ParentBB->getContext(), + Twine(ParentBB->getName()) + ".unreachable_default", + ParentBB->getParent()); + new UnreachableInst(ParentBB->getContext(), UnreachableBB); + SI.setDefaultDest(UnreachableBB); + DT.addNewBlock(UnreachableBB, ParentBB); + } + } else { + // If we're not unswitching the default, we need it to match any cases to + // have a common successor or if we have no cases it is the common + // successor. + if (SI.getNumCases() == 0) + CommonSuccBB = SI.getDefaultDest(); + else if (SI.getDefaultDest() != CommonSuccBB) + CommonSuccBB = nullptr; + } + + // Split the preheader, so that we know that there is a safe place to insert + // the switch. + BasicBlock *OldPH = L.getLoopPreheader(); + BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI); + OldPH->getTerminator()->eraseFromParent(); + + // Now add the unswitched switch. + auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH); + + // Rewrite the IR for the unswitched basic blocks. This requires two steps. + // First, we split any exit blocks with remaining in-loop predecessors. Then + // we update the PHIs in one of two ways depending on if there was a split. + // We walk in reverse so that we split in the same order as the cases + // appeared. This is purely for convenience of reading the resulting IR, but + // it doesn't cost anything really. + SmallPtrSet<BasicBlock *, 2> UnswitchedExitBBs; + SmallDenseMap<BasicBlock *, BasicBlock *, 2> SplitExitBBMap; + // Handle the default exit if necessary. + // FIXME: It'd be great if we could merge this with the loop below but LLVM's + // ranges aren't quite powerful enough yet. + if (DefaultExitBB) { + if (pred_empty(DefaultExitBB)) { + UnswitchedExitBBs.insert(DefaultExitBB); + rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH); + } else { + auto *SplitBB = + SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); + rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, + *ParentBB, *OldPH); + updateLoopExitIDom(DefaultExitBB, L, DT); + DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; + } + } + // Note that we must use a reference in the for loop so that we update the + // container. + for (auto &CasePair : reverse(ExitCases)) { + // Grab a reference to the exit block in the pair so that we can update it. + BasicBlock *ExitBB = CasePair.second; + + // If this case is the last edge into the exit block, we can simply reuse it + // as it will no longer be a loop exit. No mapping necessary. + if (pred_empty(ExitBB)) { + // Only rewrite once. 
+ if (UnswitchedExitBBs.insert(ExitBB).second) + rewritePHINodesForUnswitchedExitBlock(*ExitBB, *ParentBB, *OldPH); + continue; + } + + // Otherwise we need to split the exit block so that we retain an exit + // block from the loop and a target for the unswitched condition. + BasicBlock *&SplitExitBB = SplitExitBBMap[ExitBB]; + if (!SplitExitBB) { + // If this is the first time we see this, do the split and remember it. + SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); + rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, + *ParentBB, *OldPH); + updateLoopExitIDom(ExitBB, L, DT); + } + // Update the case pair to point to the split block. + CasePair.second = SplitExitBB; + } + + // Now add the unswitched cases. We do this in reverse order as we built them + // in reverse order. + for (auto CasePair : reverse(ExitCases)) { + ConstantInt *CaseVal = CasePair.first; + BasicBlock *UnswitchedBB = CasePair.second; + + NewSI->addCase(CaseVal, UnswitchedBB); + updateDTAfterUnswitch(UnswitchedBB, OldPH, DT); + } + + // If the default was unswitched, re-point it and add explicit cases for + // entering the loop. + if (DefaultExitBB) { + NewSI->setDefaultDest(DefaultExitBB); + updateDTAfterUnswitch(DefaultExitBB, OldPH, DT); + + // We removed all the exit cases, so we just copy the cases to the + // unswitched switch. + for (auto Case : SI.cases()) + NewSI->addCase(Case.getCaseValue(), NewPH); + } + + // If we ended up with a common successor for every path through the switch + // after unswitching, rewrite it to an unconditional branch to make it easy + // to recognize. Otherwise we potentially have to recognize the default case + // pointing at unreachable and other complexity. + if (CommonSuccBB) { + BasicBlock *BB = SI.getParent(); + SI.eraseFromParent(); + BranchInst::Create(CommonSuccBB, BB); + } + + DT.verifyDomTree(); + ++NumTrivial; + ++NumSwitches; + return true; +} + +/// This routine scans the loop to find a branch or switch which occurs before +/// any side effects occur. These can potentially be unswitched without +/// duplicating the loop. If a branch or switch is successfully unswitched the +/// scanning continues to see if subsequent branches or switches have become +/// trivial. Once all trivial candidates have been unswitched, this routine +/// returns. +/// +/// The return value indicates whether anything was unswitched (and therefore +/// changed). +static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, + LoopInfo &LI) { + bool Changed = false; + + // If loop header has only one reachable successor we should keep looking for + // trivial condition candidates in the successor as well. An alternative is + // to constant fold conditions and merge successors into loop header (then we + // only need to check header's terminator). The reason for not doing this in + // LoopUnswitch pass is that it could potentially break LoopPassManager's + // invariants. Folding dead branches could either eliminate the current loop + // or make other loops unreachable. LCSSA form might also not be preserved + // after deleting branches. The following code keeps traversing loop header's + // successors until it finds the trivial condition candidate (condition that + // is not a constant). Since unswitching generates branches with constant + // conditions, this scenario could be very common in practice. 
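As a concrete model of that walk, here is a deliberately simplified sketch. The Block and Term types are hypothetical stand-ins for LLVM's CFG types, and it assumes an acyclic chain for brevity; the real loop below additionally carries a visited set and re-runs the unswitch at each candidate.

// Toy terminator classification for the scan: a branch on a constant (skip
// through it), a branch on a loop-invariant value (an unswitch candidate),
// or anything else (give up).
enum class Term { ConstantBr, InvariantBr, Other };

struct Block {
  bool HasSideEffects; // stores, calls, volatile loads, ...
  Term T;
  Block *Succ;         // the single live successor of a constant branch
};

// Walk from the header through constant branches until we find a candidate,
// hit code that would execute before any branch, or see a terminator we do
// not understand.
Block *findTrivialCandidate(Block *Header) {
  for (Block *BB = Header; BB; BB = BB->Succ) {
    if (BB->HasSideEffects)
      return nullptr; // this code runs whether or not we unswitch
    if (BB->T == Term::InvariantBr)
      return BB;      // trivial unswitch candidate
    if (BB->T != Term::ConstantBr)
      return nullptr; // unknown terminator: stop scanning
  }
  return nullptr;
}

The in-tree traversal follows.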
+  BasicBlock *CurrentBB = L.getHeader();
+  SmallPtrSet<BasicBlock *, 8> Visited;
+  Visited.insert(CurrentBB);
+  do {
+    // Check if there are any side-effecting instructions (e.g. stores, calls,
+    // volatile loads) in the part of the loop that the code *would* execute
+    // without unswitching.
+    if (llvm::any_of(*CurrentBB,
+                     [](Instruction &I) { return I.mayHaveSideEffects(); }))
+      return Changed;
+
+    TerminatorInst *CurrentTerm = CurrentBB->getTerminator();
+
+    if (auto *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
+      // Don't bother trying to unswitch past a switch with a constant
+      // condition. This should be removed prior to running this pass by
+      // simplify-cfg.
+      if (isa<Constant>(SI->getCondition()))
+        return Changed;
+
+      if (!unswitchTrivialSwitch(L, *SI, DT, LI))
+        // Couldn't unswitch this one so we're done.
+        return Changed;
+
+      // Mark that we managed to unswitch something.
+      Changed = true;
+
+      // If unswitching turned the terminator into an unconditional branch then
+      // we can continue. The unswitching logic specifically works to fold any
+      // cases it can into an unconditional branch to make it easier to
+      // recognize here.
+      auto *BI = dyn_cast<BranchInst>(CurrentBB->getTerminator());
+      if (!BI || BI->isConditional())
+        return Changed;
+
+      CurrentBB = BI->getSuccessor(0);
+      continue;
+    }
+
+    auto *BI = dyn_cast<BranchInst>(CurrentTerm);
+    if (!BI)
+      // We do not understand other terminator instructions.
+      return Changed;
+
+    // Don't bother trying to unswitch past an unconditional branch or a branch
+    // with a constant value. These should be removed by simplify-cfg prior to
+    // running this pass.
+    if (!BI->isConditional() || isa<Constant>(BI->getCondition()))
+      return Changed;
+
+    // Found a trivial condition candidate: a non-foldable conditional branch.
+    // If we fail to unswitch this, we can't do anything else that is trivial.
+    if (!unswitchTrivialBranch(L, *BI, DT, LI))
+      return Changed;
+
+    // Mark that we managed to unswitch something.
+    Changed = true;
+
+    // We unswitched the branch. This should always leave us with an
+    // unconditional branch that we can follow now.
+    BI = cast<BranchInst>(CurrentBB->getTerminator());
+    assert(!BI->isConditional() &&
+           "Cannot form a conditional branch by unswitching!");
+    CurrentBB = BI->getSuccessor(0);
+
+    // When continuing, if we exit the loop or reach a previously visited
+    // block, then we cannot reach any trivial condition candidates (unfoldable
+    // branch instructions or switch instructions) and no unswitch can happen.
+  } while (L.contains(CurrentBB) && Visited.insert(CurrentBB).second);
+
+  return Changed;
+}
+
+/// Unswitch control flow predicated on loop invariant conditions.
+///
+/// This first hoists all branches or switches which are trivial (i.e., do not
+/// require duplicating any part of the loop) out of the loop body. It then
+/// looks at other loop invariant control flows and tries to unswitch those as
+/// well by cloning the loop if the result is small enough.
+static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                         AssumptionCache &AC) {
+  assert(L.isLCSSAForm(DT) &&
+         "Loops must be in LCSSA form before unswitching.");
+  bool Changed = false;
+
+  // Must be in loop simplified form: we need a preheader and dedicated exits.
+  if (!L.isLoopSimplifyForm())
+    return false;
+
+  // Try trivial unswitching first, before looping over other basic blocks in
+  // the loop.
+  Changed |= unswitchAllTrivialConditions(L, DT, LI);
+
+  // FIXME: Add support for non-trivial unswitching by cloning the loop.
+ + return Changed; +} + +PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + Function &F = *L.getHeader()->getParent(); + (void)F; + + DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); + + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC)) + return PreservedAnalyses::all(); + +#ifndef NDEBUG + // Historically this pass has had issues with the dominator tree so verify it + // in asserts builds. + AR.DT.verifyDomTree(); +#endif + return getLoopPassPreservedAnalyses(); +} + +namespace { + +class SimpleLoopUnswitchLegacyPass : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + + explicit SimpleLoopUnswitchLegacyPass() : LoopPass(ID) { + initializeSimpleLoopUnswitchLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + getLoopAnalysisUsage(AU); + } +}; + +} // end anonymous namespace + +bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { + if (skipLoop(L)) + return false; + + Function &F = *L->getHeader()->getParent(); + + DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L << "\n"); + + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + + bool Changed = unswitchLoop(*L, DT, LI, AC); + +#ifndef NDEBUG + // Historically this pass has had issues with the dominator tree so verify it + // in asserts builds. + DT.verifyDomTree(); +#endif + return Changed; +} + +char SimpleLoopUnswitchLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", + "Simple unswitch loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", + "Simple unswitch loops", false, false) + +Pass *llvm::createSimpleLoopUnswitchLegacyPass() { + return new SimpleLoopUnswitchLegacyPass(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index f2723bd..8754c71 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -130,7 +130,8 @@ static bool mergeEmptyReturnBlocks(Function &F) { /// iterating until no more changes are made. static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, AssumptionCache *AC, - unsigned BonusInstThreshold) { + unsigned BonusInstThreshold, + bool LateSimplifyCFG) { bool Changed = false; bool LocalChange = true; @@ -145,7 +146,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, // Loop over all of the basic blocks and remove them if they are unneeded. 
for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) { - if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders)) { + if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders, LateSimplifyCFG)) { LocalChange = true; ++NumSimpl; } @@ -156,10 +157,12 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, } static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI, - AssumptionCache *AC, int BonusInstThreshold) { + AssumptionCache *AC, int BonusInstThreshold, + bool LateSimplifyCFG) { bool EverChanged = removeUnreachableBlocks(F); EverChanged |= mergeEmptyReturnBlocks(F); - EverChanged |= iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold); + EverChanged |= iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold, + LateSimplifyCFG); // If neither pass changed anything, we're done. if (!EverChanged) return false; @@ -173,7 +176,8 @@ static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI, return true; do { - EverChanged = iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold); + EverChanged = iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold, + LateSimplifyCFG); EverChanged |= removeUnreachableBlocks(F); } while (EverChanged); @@ -181,17 +185,19 @@ static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI, } SimplifyCFGPass::SimplifyCFGPass() - : BonusInstThreshold(UserBonusInstThreshold) {} + : BonusInstThreshold(UserBonusInstThreshold), + LateSimplifyCFG(true) {} -SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold) - : BonusInstThreshold(BonusInstThreshold) {} +SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold, bool LateSimplifyCFG) + : BonusInstThreshold(BonusInstThreshold), + LateSimplifyCFG(LateSimplifyCFG) {} PreservedAnalyses SimplifyCFGPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold)) + if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold, LateSimplifyCFG)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<GlobalsAA>(); @@ -199,16 +205,17 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F, } namespace { -struct CFGSimplifyPass : public FunctionPass { - static char ID; // Pass identification, replacement for typeid +struct BaseCFGSimplifyPass : public FunctionPass { unsigned BonusInstThreshold; std::function<bool(const Function &)> PredicateFtor; + bool LateSimplifyCFG; - CFGSimplifyPass(int T = -1, - std::function<bool(const Function &)> Ftor = nullptr) - : FunctionPass(ID), PredicateFtor(std::move(Ftor)) { + BaseCFGSimplifyPass(int T, bool LateSimplifyCFG, + std::function<bool(const Function &)> Ftor, + char &ID) + : FunctionPass(ID), PredicateFtor(std::move(Ftor)), + LateSimplifyCFG(LateSimplifyCFG) { BonusInstThreshold = (T == -1) ? 
UserBonusInstThreshold : unsigned(T); - initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override { if (skipFunction(F) || (PredicateFtor && !PredicateFtor(F))) @@ -218,7 +225,7 @@ struct CFGSimplifyPass : public FunctionPass { &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - return simplifyFunctionCFG(F, TTI, AC, BonusInstThreshold); + return simplifyFunctionCFG(F, TTI, AC, BonusInstThreshold, LateSimplifyCFG); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -227,6 +234,26 @@ struct CFGSimplifyPass : public FunctionPass { AU.addPreserved<GlobalsAAWrapperPass>(); } }; + +struct CFGSimplifyPass : public BaseCFGSimplifyPass { + static char ID; // Pass identification, replacement for typeid + + CFGSimplifyPass(int T = -1, + std::function<bool(const Function &)> Ftor = nullptr) + : BaseCFGSimplifyPass(T, false, Ftor, ID) { + initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); + } +}; + +struct LateCFGSimplifyPass : public BaseCFGSimplifyPass { + static char ID; // Pass identification, replacement for typeid + + LateCFGSimplifyPass(int T = -1, + std::function<bool(const Function &)> Ftor = nullptr) + : BaseCFGSimplifyPass(T, true, Ftor, ID) { + initializeLateCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); + } +}; } char CFGSimplifyPass::ID = 0; @@ -237,9 +264,24 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, false) +char LateCFGSimplifyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LateCFGSimplifyPass, "latesimplifycfg", + "Simplify the CFG more aggressively", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_END(LateCFGSimplifyPass, "latesimplifycfg", + "Simplify the CFG more aggressively", false, false) + // Public interface to the CFGSimplification pass FunctionPass * llvm::createCFGSimplificationPass(int Threshold, - std::function<bool(const Function &)> Ftor) { + std::function<bool(const Function &)> Ftor) { return new CFGSimplifyPass(Threshold, std::move(Ftor)); } + +// Public interface to the LateCFGSimplification pass +FunctionPass * +llvm::createLateCFGSimplificationPass(int Threshold, + std::function<bool(const Function &)> Ftor) { + return new LateCFGSimplifyPass(Threshold, std::move(Ftor)); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/Sink.cpp b/contrib/llvm/lib/Transforms/Scalar/Sink.cpp index c3f14a0..5210f16 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Sink.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Sink.cpp @@ -114,7 +114,7 @@ static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo, if (SuccToSinkTo->getUniquePredecessor() != Inst->getParent()) { // We cannot sink a load across a critical edge - there may be stores in // other code paths. - if (!isSafeToSpeculativelyExecute(Inst)) + if (isa<LoadInst>(Inst)) return false; // We don't want to sink across a critical edge if we don't dominate the @@ -164,13 +164,14 @@ static bool SinkInstruction(Instruction *Inst, // Instructions can only be sunk if all their uses are in blocks // dominated by one of the successors. - // Look at all the postdominators and see if we can sink it in one. + // Look at all the dominated blocks and see if we can sink it in one. 
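The SinkInstruction change below rests on a dominator-tree invariant: a node's children in the tree are exactly the blocks it immediately dominates, so the removed getIDom() == Inst->getParent() test could never fail. A small sketch of why, assuming a minimal hand-rolled tree node rather than LLVM's DomTreeNode:

#include <vector>

// Minimal dominator-tree node: each child records its parent as its
// immediate dominator when it is attached.
struct DomNode {
  DomNode *IDom = nullptr;
  std::vector<DomNode *> Children;

  void addChild(DomNode *Child) {
    Child->IDom = this;
    Children.push_back(Child);
  }
};

// By construction, filtering a node's children on Child->IDom == &Parent (as
// the old Sink code effectively did) can never reject anything.
bool childrenAreImmediatelyDominated(const DomNode &Parent) {
  for (const DomNode *Child : Parent.Children)
    if (Child->IDom != &Parent)
      return false; // unreachable for a well-formed tree
  return true;
}

With that invariant, the candidate loop drops the redundant check: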
DomTreeNode *DTN = DT.getNode(Inst->getParent()); for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end(); I != E && SuccToSinkTo == nullptr; ++I) { BasicBlock *Candidate = (*I)->getBlock(); - if ((*I)->getIDom()->getBlock() == Inst->getParent() && - IsAcceptableTarget(Inst, Candidate, DT, LI)) + // A node always immediate-dominates its children on the dominator + // tree. + if (IsAcceptableTarget(Inst, Candidate, DT, LI)) SuccToSinkTo = Candidate; } @@ -262,9 +263,8 @@ PreservedAnalyses SinkingPass::run(Function &F, FunctionAnalysisManager &AM) { if (!iterativelySinkInstructions(F, DT, LI, AA)) return PreservedAnalyses::all(); - auto PA = PreservedAnalyses(); - PA.preserve<DominatorTreeAnalysis>(); - PA.preserve<LoopAnalysis>(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); return PA; } diff --git a/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 2be3f5c..8b8d659 100644 --- a/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -693,7 +693,7 @@ bool StraightLineStrengthReduce::runOnFunction(Function &F) { UnlinkedInst->setOperand(I, nullptr); RecursivelyDeleteTriviallyDeadInstructions(Op); } - delete UnlinkedInst; + UnlinkedInst->deleteValue(); } bool Ret = !UnlinkedInstructions.empty(); UnlinkedInstructions.clear(); diff --git a/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index 49ce026..0cccb41 100644 --- a/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -7,7 +7,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SCCIterator.h" @@ -20,6 +19,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/SSAUpdater.h" using namespace llvm; @@ -329,7 +329,7 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) { Loops[Exit] = N->getEntry(); } else { - // Test for sucessors as back edge + // Test for successors as back edge BasicBlock *BB = N->getNodeAs<BasicBlock>(); BranchInst *Term = cast<BranchInst>(BB->getTerminator()); diff --git a/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index a6b9fee..90c5c24 100644 --- a/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -51,13 +51,12 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/TailRecursionElimination.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" @@ -69,6 +68,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include 
"llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" @@ -76,6 +76,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -90,16 +91,10 @@ STATISTIC(NumAccumAdded, "Number of accumulators introduced"); /// If it contains any dynamic allocas, returns false. static bool canTRE(Function &F) { // Because of PR962, we don't TRE dynamic allocas. - for (auto &BB : F) { - for (auto &I : BB) { - if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { - if (!AI->isStaticAlloca()) - return false; - } - } - } - - return true; + return llvm::all_of(instructions(F), [](Instruction &I) { + auto *AI = dyn_cast<AllocaInst>(&I); + return !AI || AI->isStaticAlloca(); + }); } namespace { @@ -321,7 +316,7 @@ static bool markTails(Function &F, bool &AllCallsAreTailCalls) { /// instruction from after the call to before the call, assuming that all /// instructions between the call and this instruction are movable. /// -static bool canMoveAboveCall(Instruction *I, CallInst *CI) { +static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { // FIXME: We can move load/store/call/free instructions above the call if the // call does not mod/ref the memory location being processed. if (I->mayHaveSideEffects()) // This also handles volatile loads. @@ -332,10 +327,10 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI) { if (CI->mayHaveSideEffects()) { // Non-volatile loads may be moved above a call with side effects if it // does not write to memory and the load provably won't trap. - // FIXME: Writes to memory only matter if they may alias the pointer + // Writes to memory only matter if they may alias the pointer // being loaded from. const DataLayout &DL = L->getModule()->getDataLayout(); - if (CI->mayWriteToMemory() || + if ((AA->getModRefInfo(CI, MemoryLocation::get(L)) & MRI_Mod) || !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getAlignment(), DL, L)) return false; @@ -496,7 +491,7 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail) { + AliasAnalysis *AA) { // If we are introducing accumulator recursion to eliminate operations after // the call instruction that are both associative and commutative, the initial // value for the accumulator is placed in this variable. If this value is set @@ -516,7 +511,8 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, // Check that this is the case now. BasicBlock::iterator BBI(CI); for (++BBI; &*BBI != Ret; ++BBI) { - if (canMoveAboveCall(&*BBI, CI)) continue; + if (canMoveAboveCall(&*BBI, CI, AA)) + continue; // If we can't move the instruction above the call, it might be because it // is an associative and commutative operation that could be transformed @@ -675,12 +671,17 @@ static bool foldReturnAndProcessPred(BasicBlock *BB, ReturnInst *Ret, bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, bool CannotTailCallElimCallsMarkedTail, - const TargetTransformInfo *TTI) { + const TargetTransformInfo *TTI, + AliasAnalysis *AA) { bool Change = false; + // Make sure this block is a trivial return block. 
+ assert(BB->getFirstNonPHIOrDbg() == Ret && + "Trying to fold non-trivial return block"); + // If the return block contains nothing but the return and PHI's, // there might be an opportunity to duplicate the return in its - // predecessors and perform TRC there. Look for predecessors that end + // predecessors and perform TRE there. Look for predecessors that end // in unconditional branch and recursive call(s). SmallVector<BranchInst*, 8> UncondBranchPreds; for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { @@ -707,8 +708,7 @@ static bool foldReturnAndProcessPred(BasicBlock *BB, ReturnInst *Ret, BB->eraseFromParent(); eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, - CannotTailCallElimCallsMarkedTail); + ArgumentPHIs, AA); ++NumRetDuped; Change = true; } @@ -721,17 +721,18 @@ static bool processReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, bool CannotTailCallElimCallsMarkedTail, - const TargetTransformInfo *TTI) { + const TargetTransformInfo *TTI, + AliasAnalysis *AA) { CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); if (!CI) return false; return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, - CannotTailCallElimCallsMarkedTail); + ArgumentPHIs, AA); } -static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI) { +static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, + AliasAnalysis *AA) { if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") return false; @@ -766,11 +767,11 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI) if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, !CanTRETailMarkedCall, TTI); + ArgumentPHIs, !CanTRETailMarkedCall, TTI, AA); if (!Change && BB->getFirstNonPHIOrDbg() == Ret) - Change = - foldReturnAndProcessPred(BB, Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, !CanTRETailMarkedCall, TTI); + Change = foldReturnAndProcessPred(BB, Ret, OldEntry, + TailCallsAreMarkedTail, ArgumentPHIs, + !CanTRETailMarkedCall, TTI, AA); MadeChange |= Change; } } @@ -800,6 +801,7 @@ struct TailCallElim : public FunctionPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } @@ -808,7 +810,8 @@ struct TailCallElim : public FunctionPass { return false; return eliminateTailRecursion( - F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F)); + F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F), + &getAnalysis<AAResultsWrapperPass>().getAAResults()); } }; } @@ -829,8 +832,9 @@ PreservedAnalyses TailCallElimPass::run(Function &F, FunctionAnalysisManager &AM) { TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); + AliasAnalysis &AA = AM.getResult<AAManager>(F); - bool Changed = eliminateTailRecursion(F, &TTI); + bool Changed = eliminateTailRecursion(F, &TTI, &AA); if (!Changed) return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp index 2e95926..4c9746b 100644 --- a/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp @@ -102,6 +102,10 @@ 
 FunctionPass *llvm::createAddDiscriminatorsPass() {
   return new AddDiscriminatorsLegacyPass();
 }
+static bool shouldHaveDiscriminator(const Instruction *I) {
+  return !isa<IntrinsicInst>(I) || isa<MemIntrinsic>(I);
+}
+
 /// \brief Assign DWARF discriminators.
 ///
 /// To assign discriminators, we examine the boundaries of every
@@ -176,7 +180,13 @@ static bool addDiscriminators(Function &F) {
       // discriminator for this instruction.
   for (BasicBlock &B : F) {
     for (auto &I : B.getInstList()) {
-      if (isa<IntrinsicInst>(&I))
+      // Not all intrinsic calls should have a discriminator.
+      // We want to avoid a non-deterministic assignment of discriminators at
+      // different debug levels. We still allow discriminators on memory
+      // intrinsic calls because those can be early expanded by SROA into
+      // pairs of loads and stores, and the expanded load/store instructions
+      // should have a valid discriminator.
+      if (!shouldHaveDiscriminator(&I))
         continue;
       const DILocation *DIL = I.getDebugLoc();
       if (!DIL)
@@ -190,8 +200,8 @@
       // discriminator is needed to distinguish both instructions.
       // Only the lowest 7 bits are used to represent a discriminator to fit
       // it in 1 byte ULEB128 representation.
-      unsigned Discriminator = (R.second ? ++LDM[L] : LDM[L]) & 0x7f;
-      I.setDebugLoc(DIL->cloneWithDiscriminator(Discriminator));
+      unsigned Discriminator = R.second ? ++LDM[L] : LDM[L];
+      I.setDebugLoc(DIL->setBaseDiscriminator(Discriminator));
       DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
                    << DIL->getColumn() << ":" << Discriminator << " " << I
                    << "\n");
@@ -207,6 +217,10 @@
     LocationSet CallLocations;
     for (auto &I : B.getInstList()) {
       CallInst *Current = dyn_cast<CallInst>(&I);
+      // We bypass intrinsic calls for the following two reasons:
+      //  1) We want to avoid a non-deterministic assignment of
+      //     discriminators.
+      //  2) We want to minimize the number of base discriminators used.
       if (!Current || isa<IntrinsicInst>(&I))
         continue;
 
@@ -216,8 +230,8 @@
       Location L =
           std::make_pair(CurrentDIL->getFilename(), CurrentDIL->getLine());
       if (!CallLocations.insert(L).second) {
-        Current->setDebugLoc(
-            CurrentDIL->cloneWithDiscriminator((++LDM[L]) & 0x7f));
+        unsigned Discriminator = ++LDM[L];
+        Current->setDebugLoc(CurrentDIL->setBaseDiscriminator(Discriminator));
         Changed = true;
       }
     }
diff --git a/contrib/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/contrib/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index b90349d..3d5cbfc 100644
--- a/contrib/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/contrib/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -78,8 +78,8 @@ void llvm::FoldSingleEntryPHINodes(BasicBlock *BB,
 bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) {
   // Recursively deleting a PHI may cause multiple PHIs to be deleted
-  // or RAUW'd undef, so use an array of WeakVH for the PHIs to delete.
-  SmallVector<WeakVH, 8> PHIs;
+  // or RAUW'd undef, so use an array of WeakTrackingVH for the PHIs to delete.
+  SmallVector<WeakTrackingVH, 8> PHIs;
   for (BasicBlock::iterator I = BB->begin();
        PHINode *PN = dyn_cast<PHINode>(I); ++I)
     PHIs.push_back(PN);
@@ -438,7 +438,7 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB,
   // The new block unconditionally branches to the old block.
BranchInst *BI = BranchInst::Create(BB, NewBB); - BI->setDebugLoc(BB->getFirstNonPHI()->getDebugLoc()); + BI->setDebugLoc(BB->getFirstNonPHIOrDbg()->getDebugLoc()); // Move the edges from Preds to point to NewBB instead of BB. for (unsigned i = 0, e = Preds.size(); i != e; ++i) { @@ -646,9 +646,10 @@ llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore, } if (LI) { - Loop *L = LI->getLoopFor(Head); - L->addBasicBlockToLoop(ThenBlock, *LI); - L->addBasicBlockToLoop(Tail, *LI); + if (Loop *L = LI->getLoopFor(Head)) { + L->addBasicBlockToLoop(ThenBlock, *LI); + L->addBasicBlockToLoop(Tail, *LI); + } } return CheckTerm; diff --git a/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index e61b04f..b60dfb4 100644 --- a/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -58,7 +58,7 @@ static bool setOnlyReadsMemory(Function &F) { static bool setOnlyAccessesArgMemory(Function &F) { if (F.onlyAccessesArgMemory()) return false; - F.setOnlyAccessesArgMemory (); + F.setOnlyAccessesArgMemory(); ++NumArgMemOnly; return true; } @@ -71,632 +71,633 @@ static bool setDoesNotThrow(Function &F) { return true; } -static bool setDoesNotCapture(Function &F, unsigned n) { - if (F.doesNotCapture(n)) +static bool setRetDoesNotAlias(Function &F) { + if (F.hasAttribute(AttributeList::ReturnIndex, Attribute::NoAlias)) return false; - F.setDoesNotCapture(n); - ++NumNoCapture; + F.addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias); + ++NumNoAlias; return true; } -static bool setOnlyReadsMemory(Function &F, unsigned n) { - if (F.onlyReadsMemory(n)) +static bool setDoesNotCapture(Function &F, unsigned ArgNo) { + if (F.hasParamAttribute(ArgNo, Attribute::NoCapture)) return false; - F.setOnlyReadsMemory(n); - ++NumReadOnlyArg; + F.addParamAttr(ArgNo, Attribute::NoCapture); + ++NumNoCapture; return true; } -static bool setDoesNotAlias(Function &F, unsigned n) { - if (F.doesNotAlias(n)) +static bool setOnlyReadsMemory(Function &F, unsigned ArgNo) { + if (F.hasParamAttribute(ArgNo, Attribute::ReadOnly)) return false; - F.setDoesNotAlias(n); - ++NumNoAlias; + F.addParamAttr(ArgNo, Attribute::ReadOnly); + ++NumReadOnlyArg; return true; } -static bool setNonNull(Function &F, unsigned n) { - assert((n != AttributeSet::ReturnIndex || - F.getReturnType()->isPointerTy()) && +static bool setRetNonNull(Function &F) { + assert(F.getReturnType()->isPointerTy() && "nonnull applies only to pointers"); - if (F.getAttributes().hasAttribute(n, Attribute::NonNull)) + if (F.hasAttribute(AttributeList::ReturnIndex, Attribute::NonNull)) return false; - F.addAttribute(n, Attribute::NonNull); + F.addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); ++NumNonNull; return true; } bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { - LibFunc::Func TheLibFunc; + LibFunc TheLibFunc; if (!(TLI.getLibFunc(F, TheLibFunc) && TLI.has(TheLibFunc))) return false; bool Changed = false; switch (TheLibFunc) { - case LibFunc::strlen: + case LibFunc_strlen: + case LibFunc_wcslen: Changed |= setOnlyReadsMemory(F); Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyAccessesArgMemory(F); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::strchr: - case LibFunc::strrchr: + case LibFunc_strchr: + case LibFunc_strrchr: Changed |= setOnlyReadsMemory(F); Changed |= setDoesNotThrow(F); return Changed; - case LibFunc::strtol: - case 
LibFunc::strtod: - case LibFunc::strtof: - case LibFunc::strtoul: - case LibFunc::strtoll: - case LibFunc::strtold: - case LibFunc::strtoull: + case LibFunc_strtol: + case LibFunc_strtod: + case LibFunc_strtof: + case LibFunc_strtoul: + case LibFunc_strtoll: + case LibFunc_strtold: + case LibFunc_strtoull: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::strcpy: - case LibFunc::stpcpy: - case LibFunc::strcat: - case LibFunc::strncat: - case LibFunc::strncpy: - case LibFunc::stpncpy: + case LibFunc_strcpy: + case LibFunc_stpcpy: + case LibFunc_strcat: + case LibFunc_strncat: + case LibFunc_strncpy: + case LibFunc_stpncpy: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::strxfrm: + case LibFunc_strxfrm: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::strcmp: // 0,1 - case LibFunc::strspn: // 0,1 - case LibFunc::strncmp: // 0,1 - case LibFunc::strcspn: // 0,1 - case LibFunc::strcoll: // 0,1 - case LibFunc::strcasecmp: // 0,1 - case LibFunc::strncasecmp: // + case LibFunc_strcmp: // 0,1 + case LibFunc_strspn: // 0,1 + case LibFunc_strncmp: // 0,1 + case LibFunc_strcspn: // 0,1 + case LibFunc_strcoll: // 0,1 + case LibFunc_strcasecmp: // 0,1 + case LibFunc_strncasecmp: // Changed |= setOnlyReadsMemory(F); Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); return Changed; - case LibFunc::strstr: - case LibFunc::strpbrk: + case LibFunc_strstr: + case LibFunc_strpbrk: Changed |= setOnlyReadsMemory(F); Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::strtok: - case LibFunc::strtok_r: - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::scanf: + case LibFunc_strtok: + case LibFunc_strtok_r: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::setbuf: - case LibFunc::setvbuf: + case LibFunc_scanf: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); + return Changed; + case LibFunc_setbuf: + case LibFunc_setvbuf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::strdup: - case LibFunc::strndup: + case LibFunc_strdup: + case LibFunc_strndup: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); + return Changed; + case LibFunc_stat: + case LibFunc_statvfs: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::stat: - case LibFunc::statvfs: + case LibFunc_sscanf: Changed |= setDoesNotThrow(F); + 
Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::sscanf: + case LibFunc_sprintf: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::sprintf: + case LibFunc_snprintf: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 2); Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::snprintf: + case LibFunc_setitimer: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 3); - Changed |= setOnlyReadsMemory(F, 3); - return Changed; - case LibFunc::setitimer: - Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 2); - Changed |= setDoesNotCapture(F, 3); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::system: + case LibFunc_system: // May throw; "system" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::malloc: + case LibFunc_malloc: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); return Changed; - case LibFunc::memcmp: + case LibFunc_memcmp: Changed |= setOnlyReadsMemory(F); Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); return Changed; - case LibFunc::memchr: - case LibFunc::memrchr: + case LibFunc_memchr: + case LibFunc_memrchr: Changed |= setOnlyReadsMemory(F); Changed |= setDoesNotThrow(F); return Changed; - case LibFunc::modf: - case LibFunc::modff: - case LibFunc::modfl: + case LibFunc_modf: + case LibFunc_modff: + case LibFunc_modfl: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::memcpy: - case LibFunc::mempcpy: - case LibFunc::memccpy: - case LibFunc::memmove: + case LibFunc_memcpy: + case LibFunc_mempcpy: + case LibFunc_memccpy: + case LibFunc_memmove: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::memcpy_chk: + case LibFunc_memcpy_chk: Changed |= setDoesNotThrow(F); return Changed; - case LibFunc::memalign: - Changed |= setDoesNotAlias(F, 0); + case LibFunc_memalign: + Changed |= setRetDoesNotAlias(F); return Changed; - case LibFunc::mkdir: + case LibFunc_mkdir: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::mktime: + case LibFunc_mktime: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::realloc: + case LibFunc_realloc: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); + Changed |= setRetDoesNotAlias(F); + Changed |= 
setDoesNotCapture(F, 0); return Changed; - case LibFunc::read: + case LibFunc_read: // May throw; "read" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::rewind: + case LibFunc_rewind: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::rmdir: - case LibFunc::remove: - case LibFunc::realpath: + case LibFunc_rmdir: + case LibFunc_remove: + case LibFunc_realpath: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::rename: + case LibFunc_rename: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::readlink: + case LibFunc_readlink: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::write: + case LibFunc_write: // May throw; "write" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::bcopy: + case LibFunc_bcopy: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::bcmp: + case LibFunc_bcmp: Changed |= setDoesNotThrow(F); Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); return Changed; - case LibFunc::bzero: + case LibFunc_bzero: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::calloc: + case LibFunc_calloc: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); return Changed; - case LibFunc::chmod: - case LibFunc::chown: + case LibFunc_chmod: + case LibFunc_chown: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::ctermid: - case LibFunc::clearerr: - case LibFunc::closedir: + case LibFunc_ctermid: + case LibFunc_clearerr: + case LibFunc_closedir: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::atoi: - case LibFunc::atol: - case LibFunc::atof: - case LibFunc::atoll: + case LibFunc_atoi: + case LibFunc_atol: + case LibFunc_atof: + case LibFunc_atoll: Changed |= setDoesNotThrow(F); Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::access: + case LibFunc_access: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); + return 
Changed; + case LibFunc_fopen: + Changed |= setDoesNotThrow(F); + Changed |= setRetDoesNotAlias(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::fopen: + case LibFunc_fdopen: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::fdopen: + case LibFunc_feof: + case LibFunc_free: + case LibFunc_fseek: + case LibFunc_ftell: + case LibFunc_fgetc: + case LibFunc_fseeko: + case LibFunc_ftello: + case LibFunc_fileno: + case LibFunc_fflush: + case LibFunc_fclose: + case LibFunc_fsetpos: + case LibFunc_flockfile: + case LibFunc_funlockfile: + case LibFunc_ftrylockfile: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::feof: - case LibFunc::free: - case LibFunc::fseek: - case LibFunc::ftell: - case LibFunc::fgetc: - case LibFunc::fseeko: - case LibFunc::ftello: - case LibFunc::fileno: - case LibFunc::fflush: - case LibFunc::fclose: - case LibFunc::fsetpos: - case LibFunc::flockfile: - case LibFunc::funlockfile: - case LibFunc::ftrylockfile: + case LibFunc_ferror: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F); return Changed; - case LibFunc::ferror: + case LibFunc_fputc: + case LibFunc_fstat: + case LibFunc_frexp: + case LibFunc_frexpf: + case LibFunc_frexpl: + case LibFunc_fstatvfs: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F); return Changed; - case LibFunc::fputc: - case LibFunc::fstat: - case LibFunc::frexp: - case LibFunc::frexpf: - case LibFunc::frexpl: - case LibFunc::fstatvfs: + case LibFunc_fgets: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 2); return Changed; - case LibFunc::fgets: + case LibFunc_fread: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 3); return Changed; - case LibFunc::fread: + case LibFunc_fwrite: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 4); + Changed |= setDoesNotCapture(F, 0); + Changed |= setDoesNotCapture(F, 3); + // FIXME: readonly #1? return Changed; - case LibFunc::fwrite: + case LibFunc_fputs: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 4); - // FIXME: readonly #1? 
+ Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::fputs: + case LibFunc_fscanf: + case LibFunc_fprintf: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::fscanf: - case LibFunc::fprintf: + case LibFunc_fgetpos: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::fgetpos: + case LibFunc_getc: + case LibFunc_getlogin_r: + case LibFunc_getc_unlocked: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::getc: - case LibFunc::getlogin_r: - case LibFunc::getc_unlocked: + case LibFunc_getenv: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::getenv: + case LibFunc_gets: + case LibFunc_getchar: Changed |= setDoesNotThrow(F); - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::gets: - case LibFunc::getchar: + case LibFunc_getitimer: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::getitimer: + case LibFunc_getpwnam: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::getpwnam: + case LibFunc_ungetc: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::ungetc: + case LibFunc_uname: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::uname: + case LibFunc_unlink: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::unlink: + case LibFunc_unsetenv: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::unsetenv: + case LibFunc_utime: + case LibFunc_utimes: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::utime: - case LibFunc::utimes: + case LibFunc_putc: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::putc: + case LibFunc_puts: + case LibFunc_printf: + case LibFunc_perror: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::puts: - case LibFunc::printf: - case LibFunc::perror: - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::pread: + case LibFunc_pread: // May throw; "pread" is a valid pthread 
cancellation point. - Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::pwrite: + case LibFunc_pwrite: // May throw; "pwrite" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::putchar: + case LibFunc_putchar: Changed |= setDoesNotThrow(F); return Changed; - case LibFunc::popen: + case LibFunc_popen: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::pclose: + case LibFunc_pclose: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); + return Changed; + case LibFunc_vscanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::vscanf: + case LibFunc_vsscanf: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::vsscanf: + case LibFunc_vfscanf: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::vfscanf: + case LibFunc_valloc: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setRetDoesNotAlias(F); return Changed; - case LibFunc::valloc: + case LibFunc_vprintf: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::vprintf: + case LibFunc_vfprintf: + case LibFunc_vsprintf: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::vfprintf: - case LibFunc::vsprintf: + case LibFunc_vsnprintf: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 2); Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::vsnprintf: - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 3); - Changed |= setOnlyReadsMemory(F, 3); - return Changed; - case LibFunc::open: + case LibFunc_open: // May throw; "open" is a valid pthread cancellation point. 
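A note on the renumbering running through the hunks above: the patch moves from 1-based AttributeSet slots (where index 0 named the return value, so "read"'s buffer was slot 2) to 0-based parameter indices (the buffer is now parameter 1). The following is a minimal sketch of the shape the setDoesNotCapture/setOnlyReadsMemory helpers take under the new convention; it is illustrative, not the exact BuildLibCalls.cpp implementation (the real helpers also bump statistics counters):

    // Sketch, assuming the 0-based parameter numbering adopted above.
    static bool setDoesNotCapture(Function &F, unsigned ArgNo) {
      if (F.hasParamAttribute(ArgNo, Attribute::NoCapture))
        return false; // already set, report no change
      F.addParamAttr(ArgNo, Attribute::NoCapture);
      return true;
    }

    static bool setOnlyReadsMemory(Function &F, unsigned ArgNo) {
      if (F.hasParamAttribute(ArgNo, Attribute::ReadOnly))
        return false;
      F.addParamAttr(ArgNo, Attribute::ReadOnly);
      return true;
    }
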
- Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::opendir: + case LibFunc_opendir: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setRetDoesNotAlias(F); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::tmpfile: + case LibFunc_tmpfile: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); return Changed; - case LibFunc::times: + case LibFunc_times: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::htonl: - case LibFunc::htons: - case LibFunc::ntohl: - case LibFunc::ntohs: + case LibFunc_htonl: + case LibFunc_htons: + case LibFunc_ntohl: + case LibFunc_ntohs: Changed |= setDoesNotThrow(F); Changed |= setDoesNotAccessMemory(F); return Changed; - case LibFunc::lstat: + case LibFunc_lstat: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::lchown: + case LibFunc_lchown: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::qsort: + case LibFunc_qsort: // May throw; places call through function pointer. - Changed |= setDoesNotCapture(F, 4); + Changed |= setDoesNotCapture(F, 3); + return Changed; + case LibFunc_dunder_strdup: + case LibFunc_dunder_strndup: + Changed |= setDoesNotThrow(F); + Changed |= setRetDoesNotAlias(F); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::dunder_strdup: - case LibFunc::dunder_strndup: + case LibFunc_dunder_strtok_r: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); Changed |= setDoesNotCapture(F, 1); Changed |= setOnlyReadsMemory(F, 1); return Changed; - case LibFunc::dunder_strtok_r: + case LibFunc_under_IO_getc: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::under_IO_getc: + case LibFunc_under_IO_putc: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::under_IO_putc: - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::dunder_isoc99_scanf: + case LibFunc_dunder_isoc99_scanf: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::stat64: - case LibFunc::lstat64: - case LibFunc::statvfs64: + case LibFunc_stat64: + case LibFunc_lstat64: + case LibFunc_statvfs64: Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::dunder_isoc99_sscanf: + case LibFunc_dunder_isoc99_sscanf: Changed |= setDoesNotThrow(F); + Changed 
|= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::fopen64: + case LibFunc_fopen64: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 0); Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); return Changed; - case LibFunc::fseeko64: - case LibFunc::ftello64: + case LibFunc_fseeko64: + case LibFunc_ftello64: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 0); return Changed; - case LibFunc::tmpfile64: + case LibFunc_tmpfile64: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); + Changed |= setRetDoesNotAlias(F); return Changed; - case LibFunc::fstat64: - case LibFunc::fstatvfs64: + case LibFunc_fstat64: + case LibFunc_fstatvfs64: Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 1); return Changed; - case LibFunc::open64: + case LibFunc_open64: // May throw; "open" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); + Changed |= setDoesNotCapture(F, 0); + Changed |= setOnlyReadsMemory(F, 0); return Changed; - case LibFunc::gettimeofday: + case LibFunc_gettimeofday: // Currently some platforms have the restrict keyword on the arguments to // gettimeofday. To be conservative, do not add noalias to gettimeofday's // arguments. Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); return Changed; - case LibFunc::Znwj: // new(unsigned int) - case LibFunc::Znwm: // new(unsigned long) - case LibFunc::Znaj: // new[](unsigned int) - case LibFunc::Znam: // new[](unsigned long) - case LibFunc::msvc_new_int: // new(unsigned int) - case LibFunc::msvc_new_longlong: // new(unsigned long long) - case LibFunc::msvc_new_array_int: // new[](unsigned int) - case LibFunc::msvc_new_array_longlong: // new[](unsigned long long) + case LibFunc_Znwj: // new(unsigned int) + case LibFunc_Znwm: // new(unsigned long) + case LibFunc_Znaj: // new[](unsigned int) + case LibFunc_Znam: // new[](unsigned long) + case LibFunc_msvc_new_int: // new(unsigned int) + case LibFunc_msvc_new_longlong: // new(unsigned long long) + case LibFunc_msvc_new_array_int: // new[](unsigned int) + case LibFunc_msvc_new_array_longlong: // new[](unsigned long long) // Operator new always returns a nonnull noalias pointer - Changed |= setNonNull(F, AttributeSet::ReturnIndex); - Changed |= setDoesNotAlias(F, AttributeSet::ReturnIndex); + Changed |= setRetNonNull(F); + Changed |= setRetDoesNotAlias(F); return Changed; //TODO: add LibFunc entries for: - //case LibFunc::memset_pattern4: - //case LibFunc::memset_pattern8: - case LibFunc::memset_pattern16: + //case LibFunc_memset_pattern4: + //case LibFunc_memset_pattern8: + case LibFunc_memset_pattern16: Changed |= setOnlyAccessesArgMemory(F); + Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); + Changed |= setOnlyReadsMemory(F, 1); return Changed; // int __nvvm_reflect(const char *) - case LibFunc::nvvm_reflect: + case 
LibFunc_nvvm_reflect: Changed |= setDoesNotAccessMemory(F); Changed |= setDoesNotThrow(F); return Changed; @@ -717,13 +718,13 @@ Value *llvm::castToCStr(Value *V, IRBuilder<> &B) { Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::strlen)) + if (!TLI->has(LibFunc_strlen)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); Constant *StrLen = M->getOrInsertFunction("strlen", DL.getIntPtrType(Context), - B.getInt8PtrTy(), nullptr); + B.getInt8PtrTy()); inferLibFuncAttributes(*M->getFunction("strlen"), *TLI); CallInst *CI = B.CreateCall(StrLen, castToCStr(Ptr, B), "strlen"); if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) @@ -734,14 +735,14 @@ Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::strchr)) + if (!TLI->has(LibFunc_strchr)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); Type *I8Ptr = B.getInt8PtrTy(); Type *I32Ty = B.getInt32Ty(); Constant *StrChr = - M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty, nullptr); + M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty); inferLibFuncAttributes(*M->getFunction("strchr"), *TLI); CallInst *CI = B.CreateCall( StrChr, {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, "strchr"); @@ -752,14 +753,14 @@ Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B, Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::strncmp)) + if (!TLI->has(LibFunc_strncmp)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *StrNCmp = M->getOrInsertFunction("strncmp", B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), - DL.getIntPtrType(Context), nullptr); + DL.getIntPtrType(Context)); inferLibFuncAttributes(*M->getFunction("strncmp"), *TLI); CallInst *CI = B.CreateCall( StrNCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "strncmp"); @@ -772,12 +773,12 @@ Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, const TargetLibraryInfo *TLI, StringRef Name) { - if (!TLI->has(LibFunc::strcpy)) + if (!TLI->has(LibFunc_strcpy)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); Type *I8Ptr = B.getInt8PtrTy(); - Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, nullptr); + Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr); inferLibFuncAttributes(*M->getFunction(Name), *TLI); CallInst *CI = B.CreateCall(StrCpy, {castToCStr(Dst, B), castToCStr(Src, B)}, Name); @@ -788,13 +789,13 @@ Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, const TargetLibraryInfo *TLI, StringRef Name) { - if (!TLI->has(LibFunc::strncpy)) + if (!TLI->has(LibFunc_strncpy)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); Type *I8Ptr = B.getInt8PtrTy(); Value *StrNCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, - Len->getType(), nullptr); + Len->getType()); inferLibFuncAttributes(*M->getFunction(Name), *TLI); CallInst *CI = B.CreateCall( StrNCpy, {castToCStr(Dst, B), castToCStr(Src, B), Len}, "strncpy"); @@ -806,18 +807,18 @@ Value *llvm::emitStrNCpy(Value 
*Dst, Value *Src, Value *Len, IRBuilder<> &B, Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::memcpy_chk)) + if (!TLI->has(LibFunc_memcpy_chk)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); - AttributeSet AS; - AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - Attribute::NoUnwind); + AttributeList AS; + AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex, + Attribute::NoUnwind); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemCpy = M->getOrInsertFunction( - "__memcpy_chk", AttributeSet::get(M->getContext(), AS), B.getInt8PtrTy(), + "__memcpy_chk", AttributeList::get(M->getContext(), AS), B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context), - DL.getIntPtrType(Context), nullptr); + DL.getIntPtrType(Context)); Dst = castToCStr(Dst, B); Src = castToCStr(Src, B); CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize}); @@ -828,14 +829,14 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::memchr)) + if (!TLI->has(LibFunc_memchr)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemChr = M->getOrInsertFunction("memchr", B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt32Ty(), - DL.getIntPtrType(Context), nullptr); + DL.getIntPtrType(Context)); inferLibFuncAttributes(*M->getFunction("memchr"), *TLI); CallInst *CI = B.CreateCall(MemChr, {castToCStr(Ptr, B), Val, Len}, "memchr"); @@ -847,14 +848,14 @@ Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B, Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::memcmp)) + if (!TLI->has(LibFunc_memcmp)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemCmp = M->getOrInsertFunction("memcmp", B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), - DL.getIntPtrType(Context), nullptr); + DL.getIntPtrType(Context)); inferLibFuncAttributes(*M->getFunction("memcmp"), *TLI); CallInst *CI = B.CreateCall( MemCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "memcmp"); @@ -881,15 +882,21 @@ static void appendTypeSuffix(Value *Op, StringRef &Name, } Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, - const AttributeSet &Attrs) { + const AttributeList &Attrs) { SmallString<20> NameBuffer; appendTypeSuffix(Op, Name, NameBuffer); Module *M = B.GetInsertBlock()->getModule(); Value *Callee = M->getOrInsertFunction(Name, Op->getType(), - Op->getType(), nullptr); + Op->getType()); CallInst *CI = B.CreateCall(Callee, Op, Name); - CI->setAttributes(Attrs); + + // The incoming attribute set may have come from a speculatable intrinsic, but + // is being replaced with a library call which is not allowed to be + // speculatable. 
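The comment above deserves a small illustration: a speculatable intrinsic (llvm.pow, say) may be lowered to a libm call, and the library function can set errno, so the call must not be hoisted past control flow. A minimal sketch of the stripping pattern, using the same AttributeList call the hunk adds (the helper name here is hypothetical):

    // Sketch: drop Speculatable from the function-level attributes before
    // transferring them to an emitted library call.
    static AttributeList stripSpeculatable(LLVMContext &Ctx,
                                           AttributeList Attrs) {
      return Attrs.removeAttribute(Ctx, AttributeList::FunctionIndex,
                                   Attribute::Speculatable);
    }
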
+ CI->setAttributes(Attrs.removeAttribute(B.getContext(), + AttributeList::FunctionIndex, + Attribute::Speculatable)); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -897,13 +904,13 @@ Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, } Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, - IRBuilder<> &B, const AttributeSet &Attrs) { + IRBuilder<> &B, const AttributeList &Attrs) { SmallString<20> NameBuffer; appendTypeSuffix(Op1, Name, NameBuffer); Module *M = B.GetInsertBlock()->getModule(); Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), Op1->getType(), - Op2->getType(), nullptr); + Op2->getType()); CallInst *CI = B.CreateCall(Callee, {Op1, Op2}, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -914,12 +921,12 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::putchar)) + if (!TLI->has(LibFunc_putchar)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); - Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), - B.getInt32Ty(), nullptr); + Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), B.getInt32Ty()); + inferLibFuncAttributes(*M->getFunction("putchar"), *TLI); CallInst *CI = B.CreateCall(PutChar, B.CreateIntCast(Char, B.getInt32Ty(), @@ -934,12 +941,12 @@ Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B, Value *llvm::emitPutS(Value *Str, IRBuilder<> &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::puts)) + if (!TLI->has(LibFunc_puts)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); Value *PutS = - M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy(), nullptr); + M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy()); inferLibFuncAttributes(*M->getFunction("puts"), *TLI); CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), "puts"); if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) @@ -949,12 +956,12 @@ Value *llvm::emitPutS(Value *Str, IRBuilder<> &B, Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::fputc)) + if (!TLI->has(LibFunc_fputc)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); Constant *F = M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(), - File->getType(), nullptr); + File->getType()); if (File->getType()->isPointerTy()) inferLibFuncAttributes(*M->getFunction("fputc"), *TLI); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, @@ -968,13 +975,13 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B, Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::fputs)) + if (!TLI->has(LibFunc_fputs)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); - StringRef FPutsName = TLI->getName(LibFunc::fputs); + StringRef FPutsName = TLI->getName(LibFunc_fputs); Constant *F = M->getOrInsertFunction( - FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType(), nullptr); + FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType()); if (File->getType()->isPointerTy()) inferLibFuncAttributes(*M->getFunction(FPutsName), *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, "fputs"); @@ -986,16 +993,16 @@ Value *llvm::emitFPutS(Value *Str, Value 
*File, IRBuilder<> &B, Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::fwrite)) + if (!TLI->has(LibFunc_fwrite)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); - StringRef FWriteName = TLI->getName(LibFunc::fwrite); + StringRef FWriteName = TLI->getName(LibFunc_fwrite); Constant *F = M->getOrInsertFunction( FWriteName, DL.getIntPtrType(Context), B.getInt8PtrTy(), - DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType(), - nullptr); + DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); + if (File->getType()->isPointerTy()) inferLibFuncAttributes(*M->getFunction(FWriteName), *TLI); CallInst *CI = diff --git a/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp index bc2cef2..83ec7f5 100644 --- a/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -17,9 +17,12 @@ #include "llvm/Transforms/Utils/BypassSlowDivision.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -36,12 +39,21 @@ namespace { : SignedOp(InSignedOp), Dividend(InDividend), Divisor(InDivisor) {} }; - struct DivPhiNodes { - PHINode *Quotient; - PHINode *Remainder; + struct QuotRemPair { + Value *Quotient; + Value *Remainder; - DivPhiNodes(PHINode *InQuotient, PHINode *InRemainder) - : Quotient(InQuotient), Remainder(InRemainder) {} + QuotRemPair(Value *InQuotient, Value *InRemainder) + : Quotient(InQuotient), Remainder(InRemainder) {} + }; + + /// A quotient and remainder, plus a BB from which they logically "originate". + /// If you use Quotient or Remainder in a Phi node, you should use BB as its + /// corresponding predecessor. + struct QuotRemWithBB { + BasicBlock *BB = nullptr; + Value *Quotient = nullptr; + Value *Remainder = nullptr; }; } @@ -69,159 +81,376 @@ namespace llvm { } }; - typedef DenseMap<DivOpInfo, DivPhiNodes> DivCacheTy; + typedef DenseMap<DivOpInfo, QuotRemPair> DivCacheTy; + typedef DenseMap<unsigned, unsigned> BypassWidthsTy; + typedef SmallPtrSet<Instruction *, 4> VisitedSetTy; } -// insertFastDiv - Substitutes the div/rem instruction with code that checks the -// value of the operands and uses a shorter-faster div/rem instruction when -// possible and the longer-slower div/rem instruction otherwise. -static bool insertFastDiv(Instruction *I, IntegerType *BypassType, - bool UseDivOp, bool UseSignedOp, - DivCacheTy &PerBBDivCache) { - Function *F = I->getParent()->getParent(); - // Get instruction operands - Value *Dividend = I->getOperand(0); - Value *Divisor = I->getOperand(1); +namespace { +enum ValueRange { + /// Operand definitely fits into BypassType. No runtime checks are needed. + VALRNG_KNOWN_SHORT, + /// A runtime check is required, as value range is unknown. + VALRNG_UNKNOWN, + /// Operand is unlikely to fit into BypassType. The bypassing should be + /// disabled. 
+  VALRNG_LIKELY_LONG
+};
+
+class FastDivInsertionTask {
+  bool IsValidTask = false;
+  Instruction *SlowDivOrRem = nullptr;
+  IntegerType *BypassType = nullptr;
+  BasicBlock *MainBB = nullptr;
+
+  bool isHashLikeValue(Value *V, VisitedSetTy &Visited);
+  ValueRange getValueRange(Value *Op, VisitedSetTy &Visited);
+  QuotRemWithBB createSlowBB(BasicBlock *Successor);
+  QuotRemWithBB createFastBB(BasicBlock *Successor);
+  QuotRemPair createDivRemPhiNodes(QuotRemWithBB &LHS, QuotRemWithBB &RHS,
+                                   BasicBlock *PhiBB);
+  Value *insertOperandRuntimeCheck(Value *Op1, Value *Op2);
+  Optional<QuotRemPair> insertFastDivAndRem();
+
+  bool isSignedOp() {
+    return SlowDivOrRem->getOpcode() == Instruction::SDiv ||
+           SlowDivOrRem->getOpcode() == Instruction::SRem;
+  }
+  bool isDivisionOp() {
+    return SlowDivOrRem->getOpcode() == Instruction::SDiv ||
+           SlowDivOrRem->getOpcode() == Instruction::UDiv;
+  }
+  Type *getSlowType() { return SlowDivOrRem->getType(); }
+
+public:
+  FastDivInsertionTask(Instruction *I, const BypassWidthsTy &BypassWidths);
+  Value *getReplacement(DivCacheTy &Cache);
+};
+} // anonymous namespace
+
+FastDivInsertionTask::FastDivInsertionTask(Instruction *I,
+                                           const BypassWidthsTy &BypassWidths) {
+  switch (I->getOpcode()) {
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+    SlowDivOrRem = I;
+    break;
+  default:
+    // I is not a div/rem operation.
+    return;
+  }
-  if (isa<ConstantInt>(Divisor)) {
-    // Division by a constant should have been been solved and replaced earlier
-    // in the pipeline.
-    return false;
+  // Skip division on vector types. Only optimize integer instructions.
+  IntegerType *SlowType = dyn_cast<IntegerType>(SlowDivOrRem->getType());
+  if (!SlowType)
+    return;
+
+  // Skip if this bitwidth is not bypassed.
+  auto BI = BypassWidths.find(SlowType->getBitWidth());
+  if (BI == BypassWidths.end())
+    return;
+
+  // Get type for div/rem instruction with bypass bitwidth.
+  IntegerType *BT = IntegerType::get(I->getContext(), BI->second);
+  BypassType = BT;
+
+  // The original basic block.
+  MainBB = I->getParent();
+
+  // The instruction is indeed a slow div or rem operation.
+  IsValidTask = true;
+}
+
+/// Reuses previously-computed dividend or remainder from the current BB if
+/// operands and operation are identical. Otherwise calls insertFastDivAndRem to
+/// perform the optimization and caches the resulting dividend and remainder.
+/// If no replacement can be generated, nullptr is returned.
+Value *FastDivInsertionTask::getReplacement(DivCacheTy &Cache) {
+  // First, make sure that the task is valid.
+  if (!IsValidTask)
+    return nullptr;
+
+  // Then, look for a value in Cache.
+  Value *Dividend = SlowDivOrRem->getOperand(0);
+  Value *Divisor = SlowDivOrRem->getOperand(1);
+  DivOpInfo Key(isSignedOp(), Dividend, Divisor);
+  auto CacheI = Cache.find(Key);
+
+  if (CacheI == Cache.end()) {
+    // If previous instance does not exist, try to insert fast div.
+    Optional<QuotRemPair> OptResult = insertFastDivAndRem();
+    // Bail out if insertFastDivAndRem has failed.
+    if (!OptResult)
+      return nullptr;
+    CacheI = Cache.insert({Key, *OptResult}).first;
   }
-  // If the numerator is a constant, bail if it doesn't fit into BypassType.
-  if (ConstantInt *ConstDividend = dyn_cast<ConstantInt>(Dividend))
-    if (ConstDividend->getValue().getActiveBits() > BypassType->getBitWidth())
+  QuotRemPair &Value = CacheI->second;
+  return isDivisionOp() ? Value.Quotient : Value.Remainder;
+}
+
+/// \brief Check if a value looks like a hash.
+///
+/// The routine is expected to detect values computed using the most common hash
+/// algorithms. Typically, hash computations end with one of the following
+/// instructions:
+///
+/// 1) MUL with a constant wider than BypassType
+/// 2) XOR instruction
+///
+/// And even if we are wrong and the value is not a hash, it is still quite
+/// unlikely that such values will fit into BypassType.
+///
+/// To detect string hash algorithms like FNV we have to look through PHI-nodes.
+/// It is implemented as a depth-first search for values that look neither long
+/// nor hash-like.
+bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return false;
+
+  switch (I->getOpcode()) {
+  case Instruction::Xor:
+    return true;
+  case Instruction::Mul: {
+    // After Constant Hoisting pass, long constants may be represented as
+    // bitcast instructions. As a result, some constants may look like an
+    // instruction at first, and an additional check is necessary to find out if
+    // an operand is actually a constant.
+    Value *Op1 = I->getOperand(1);
+    ConstantInt *C = dyn_cast<ConstantInt>(Op1);
+    if (!C && isa<BitCastInst>(Op1))
+      C = dyn_cast<ConstantInt>(cast<BitCastInst>(Op1)->getOperand(0));
+    return C && C->getValue().getMinSignedBits() > BypassType->getBitWidth();
+  }
+  case Instruction::PHI: {
+    // Stop IR traversal in case of a crazy input code. This limits recursion
+    // depth.
+    if (Visited.size() >= 16)
       return false;
+    // Do not visit nodes that have been visited already. We return true because
+    // it means that we couldn't find any value that doesn't look hash-like.
+    if (Visited.find(I) != Visited.end())
+      return true;
+    Visited.insert(I);
+    return llvm::all_of(cast<PHINode>(I)->incoming_values(), [&](Value *V) {
+      // Ignore undef values as they probably don't affect the division
+      // operands.
+      return getValueRange(V, Visited) == VALRNG_LIKELY_LONG ||
+             isa<UndefValue>(V);
+    });
+  }
+  default:
+    return false;
+  }
+}
+
+/// Check if an integer value fits into our bypass type.
+ValueRange FastDivInsertionTask::getValueRange(Value *V,
+                                               VisitedSetTy &Visited) {
+  unsigned ShortLen = BypassType->getBitWidth();
+  unsigned LongLen = V->getType()->getIntegerBitWidth();
+
+  assert(LongLen > ShortLen && "Value type must be wider than BypassType");
+  unsigned HiBits = LongLen - ShortLen;
+
+  const DataLayout &DL = SlowDivOrRem->getModule()->getDataLayout();
+  KnownBits Known(LongLen);
-  // Basic Block is split before divide
-  BasicBlock *MainBB = &*I->getParent();
-  BasicBlock *SuccessorBB = MainBB->splitBasicBlock(I);
-
-  // Add new basic block for slow divide operation
-  BasicBlock *SlowBB =
-      BasicBlock::Create(F->getContext(), "", MainBB->getParent(), SuccessorBB);
-  SlowBB->moveBefore(SuccessorBB);
-  IRBuilder<> SlowBuilder(SlowBB, SlowBB->begin());
-  Value *SlowQuotientV;
-  Value *SlowRemainderV;
-  if (UseSignedOp) {
-    SlowQuotientV = SlowBuilder.CreateSDiv(Dividend, Divisor);
-    SlowRemainderV = SlowBuilder.CreateSRem(Dividend, Divisor);
+  computeKnownBits(V, Known, DL);
+
+  if (Known.countMinLeadingZeros() >= HiBits)
+    return VALRNG_KNOWN_SHORT;
+
+  if (Known.countMaxLeadingZeros() < HiBits)
+    return VALRNG_LIKELY_LONG;
+
+  // Long integer divisions are often used in hashtable implementations. It's
+  // not worth bypassing such divisions because hash values are extremely
+  // unlikely to have enough leading zeros. The call below tries to detect
+  // values that are unlikely to fit BypassType (including hashes).
+  if (isHashLikeValue(V, Visited))
+    return VALRNG_LIKELY_LONG;
+
+  return VALRNG_UNKNOWN;
+}
+
+/// Add new basic block for slow div and rem operations and put it before
+/// SuccessorBB.
+QuotRemWithBB FastDivInsertionTask::createSlowBB(BasicBlock *SuccessorBB) {
+  QuotRemWithBB DivRemPair;
+  DivRemPair.BB = BasicBlock::Create(MainBB->getParent()->getContext(), "",
+                                     MainBB->getParent(), SuccessorBB);
+  IRBuilder<> Builder(DivRemPair.BB, DivRemPair.BB->begin());
+
+  Value *Dividend = SlowDivOrRem->getOperand(0);
+  Value *Divisor = SlowDivOrRem->getOperand(1);
+
+  if (isSignedOp()) {
+    DivRemPair.Quotient = Builder.CreateSDiv(Dividend, Divisor);
+    DivRemPair.Remainder = Builder.CreateSRem(Dividend, Divisor);
   } else {
-    SlowQuotientV = SlowBuilder.CreateUDiv(Dividend, Divisor);
-    SlowRemainderV = SlowBuilder.CreateURem(Dividend, Divisor);
+    DivRemPair.Quotient = Builder.CreateUDiv(Dividend, Divisor);
+    DivRemPair.Remainder = Builder.CreateURem(Dividend, Divisor);
   }
-  SlowBuilder.CreateBr(SuccessorBB);
-
-  // Add new basic block for fast divide operation
-  BasicBlock *FastBB =
-      BasicBlock::Create(F->getContext(), "", MainBB->getParent(), SuccessorBB);
-  FastBB->moveBefore(SlowBB);
-  IRBuilder<> FastBuilder(FastBB, FastBB->begin());
-  Value *ShortDivisorV = FastBuilder.CreateCast(Instruction::Trunc, Divisor,
-                                                BypassType);
-  Value *ShortDividendV = FastBuilder.CreateCast(Instruction::Trunc, Dividend,
-                                                 BypassType);
-
-  // udiv/urem because optimization only handles positive numbers
-  Value *ShortQuotientV = FastBuilder.CreateUDiv(ShortDividendV, ShortDivisorV);
-  Value *ShortRemainderV = FastBuilder.CreateURem(ShortDividendV,
-                                                  ShortDivisorV);
-  Value *FastQuotientV = FastBuilder.CreateCast(Instruction::ZExt,
-                                                ShortQuotientV,
-                                                Dividend->getType());
-  Value *FastRemainderV = FastBuilder.CreateCast(Instruction::ZExt,
-                                                 ShortRemainderV,
-                                                 Dividend->getType());
-  FastBuilder.CreateBr(SuccessorBB);
-
-  // Phi nodes for result of div and rem
-  IRBuilder<> SuccessorBuilder(SuccessorBB, SuccessorBB->begin());
-  PHINode *QuoPhi = SuccessorBuilder.CreatePHI(I->getType(), 2);
-  QuoPhi->addIncoming(SlowQuotientV, SlowBB);
-  QuoPhi->addIncoming(FastQuotientV, FastBB);
-  PHINode *RemPhi = SuccessorBuilder.CreatePHI(I->getType(), 2);
-  RemPhi->addIncoming(SlowRemainderV, SlowBB);
-  RemPhi->addIncoming(FastRemainderV, FastBB);
-
-  // Replace I with appropriate phi node
-  if (UseDivOp)
-    I->replaceAllUsesWith(QuoPhi);
-  else
-    I->replaceAllUsesWith(RemPhi);
-  I->eraseFromParent();
-  // Combine operands into a single value with OR for value testing below
-  MainBB->getInstList().back().eraseFromParent();
-  IRBuilder<> MainBuilder(MainBB, MainBB->end());
+  Builder.CreateBr(SuccessorBB);
+  return DivRemPair;
+}
+
+/// Add new basic block for fast div and rem operations and put it before
+/// SuccessorBB.
+QuotRemWithBB FastDivInsertionTask::createFastBB(BasicBlock *SuccessorBB) {
+  QuotRemWithBB DivRemPair;
+  DivRemPair.BB = BasicBlock::Create(MainBB->getParent()->getContext(), "",
+                                     MainBB->getParent(), SuccessorBB);
+  IRBuilder<> Builder(DivRemPair.BB, DivRemPair.BB->begin());
+
+  Value *Dividend = SlowDivOrRem->getOperand(0);
+  Value *Divisor = SlowDivOrRem->getOperand(1);
+  Value *ShortDivisorV =
+      Builder.CreateCast(Instruction::Trunc, Divisor, BypassType);
+  Value *ShortDividendV =
+      Builder.CreateCast(Instruction::Trunc, Dividend, BypassType);
+
+  // udiv/urem because this optimization only handles positive numbers.
+ Value *ShortQV = Builder.CreateUDiv(ShortDividendV, ShortDivisorV); + Value *ShortRV = Builder.CreateURem(ShortDividendV, ShortDivisorV); + DivRemPair.Quotient = + Builder.CreateCast(Instruction::ZExt, ShortQV, getSlowType()); + DivRemPair.Remainder = + Builder.CreateCast(Instruction::ZExt, ShortRV, getSlowType()); + Builder.CreateBr(SuccessorBB); + + return DivRemPair; +} - // We should have bailed out above if the divisor is a constant, but the - // dividend may still be a constant. Set OrV to our non-constant operands - // OR'ed together. - assert(!isa<ConstantInt>(Divisor)); +/// Creates Phi nodes for result of Div and Rem. +QuotRemPair FastDivInsertionTask::createDivRemPhiNodes(QuotRemWithBB &LHS, + QuotRemWithBB &RHS, + BasicBlock *PhiBB) { + IRBuilder<> Builder(PhiBB, PhiBB->begin()); + PHINode *QuoPhi = Builder.CreatePHI(getSlowType(), 2); + QuoPhi->addIncoming(LHS.Quotient, LHS.BB); + QuoPhi->addIncoming(RHS.Quotient, RHS.BB); + PHINode *RemPhi = Builder.CreatePHI(getSlowType(), 2); + RemPhi->addIncoming(LHS.Remainder, LHS.BB); + RemPhi->addIncoming(RHS.Remainder, RHS.BB); + return QuotRemPair(QuoPhi, RemPhi); +} + +/// Creates a runtime check to test whether both the divisor and dividend fit +/// into BypassType. The check is inserted at the end of MainBB. True return +/// value means that the operands fit. Either of the operands may be NULL if it +/// doesn't need a runtime check. +Value *FastDivInsertionTask::insertOperandRuntimeCheck(Value *Op1, Value *Op2) { + assert((Op1 || Op2) && "Nothing to check"); + IRBuilder<> Builder(MainBB, MainBB->end()); Value *OrV; - if (!isa<ConstantInt>(Dividend)) - OrV = MainBuilder.CreateOr(Dividend, Divisor); + if (Op1 && Op2) + OrV = Builder.CreateOr(Op1, Op2); else - OrV = Divisor; + OrV = Op1 ? Op1 : Op2; // BitMask is inverted to check if the operands are // larger than the bypass type uint64_t BitMask = ~BypassType->getBitMask(); - Value *AndV = MainBuilder.CreateAnd(OrV, BitMask); - - // Compare operand values and branch - Value *ZeroV = ConstantInt::getSigned(Dividend->getType(), 0); - Value *CmpV = MainBuilder.CreateICmpEQ(AndV, ZeroV); - MainBuilder.CreateCondBr(CmpV, FastBB, SlowBB); - - // Cache phi nodes to be used later in place of other instances - // of div or rem with the same sign, dividend, and divisor - DivOpInfo Key(UseSignedOp, Dividend, Divisor); - DivPhiNodes Value(QuoPhi, RemPhi); - PerBBDivCache.insert(std::pair<DivOpInfo, DivPhiNodes>(Key, Value)); - return true; + Value *AndV = Builder.CreateAnd(OrV, BitMask); + + // Compare operand values + Value *ZeroV = ConstantInt::getSigned(getSlowType(), 0); + return Builder.CreateICmpEQ(AndV, ZeroV); } -// reuseOrInsertFastDiv - Reuses previously computed dividend or remainder from -// the current BB if operands and operation are identical. Otherwise calls -// insertFastDiv to perform the optimization and caches the resulting dividend -// and remainder. -static bool reuseOrInsertFastDiv(Instruction *I, IntegerType *BypassType, - bool UseDivOp, bool UseSignedOp, - DivCacheTy &PerBBDivCache) { - // Get instruction operands - DivOpInfo Key(UseSignedOp, I->getOperand(0), I->getOperand(1)); - DivCacheTy::iterator CacheI = PerBBDivCache.find(Key); - - if (CacheI == PerBBDivCache.end()) { - // If previous instance does not exist, insert fast div - return insertFastDiv(I, BypassType, UseDivOp, UseSignedOp, PerBBDivCache); +/// Substitutes the div/rem instruction with code that checks the value of the +/// operands and uses a shorter-faster div/rem instruction when possible. 
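Before the function body below, here is the net effect of the transformation in plain C++ for the common 64-to-32-bit bypass. This is an illustration of the emitted control flow, not code from the patch; the IR version performs the same mask test and lets PHI nodes play the role of the two returns:

    #include <cstdint>

    // Scalar model: if neither operand has bits above the low 32, a 32-bit
    // unsigned divide produces the same quotient as the 64-bit one.
    uint64_t bypassedUDiv(uint64_t Dividend, uint64_t Divisor) {
      if (((Dividend | Divisor) & ~UINT64_C(0xFFFFFFFF)) == 0)
        return uint64_t(uint32_t(Dividend) / uint32_t(Divisor)); // fast path
      return Dividend / Divisor;                                 // slow path
    }
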
+Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { + Value *Dividend = SlowDivOrRem->getOperand(0); + Value *Divisor = SlowDivOrRem->getOperand(1); + + if (isa<ConstantInt>(Divisor)) { + // Keep division by a constant for DAGCombiner. + return None; } - // Replace operation value with previously generated phi node - DivPhiNodes &Value = CacheI->second; - if (UseDivOp) { - // Replace all uses of div instruction with quotient phi node - I->replaceAllUsesWith(Value.Quotient); + VisitedSetTy SetL; + ValueRange DividendRange = getValueRange(Dividend, SetL); + if (DividendRange == VALRNG_LIKELY_LONG) + return None; + + VisitedSetTy SetR; + ValueRange DivisorRange = getValueRange(Divisor, SetR); + if (DivisorRange == VALRNG_LIKELY_LONG) + return None; + + bool DividendShort = (DividendRange == VALRNG_KNOWN_SHORT); + bool DivisorShort = (DivisorRange == VALRNG_KNOWN_SHORT); + + if (DividendShort && DivisorShort) { + // If both operands are known to be short then just replace the long + // division with a short one in-place. + + IRBuilder<> Builder(SlowDivOrRem); + Value *TruncDividend = Builder.CreateTrunc(Dividend, BypassType); + Value *TruncDivisor = Builder.CreateTrunc(Divisor, BypassType); + Value *TruncDiv = Builder.CreateUDiv(TruncDividend, TruncDivisor); + Value *TruncRem = Builder.CreateURem(TruncDividend, TruncDivisor); + Value *ExtDiv = Builder.CreateZExt(TruncDiv, getSlowType()); + Value *ExtRem = Builder.CreateZExt(TruncRem, getSlowType()); + return QuotRemPair(ExtDiv, ExtRem); + } else if (DividendShort && !isSignedOp()) { + // If the division is unsigned and Dividend is known to be short, then + // either + // 1) Divisor is less or equal to Dividend, and the result can be computed + // with a short division. + // 2) Divisor is greater than Dividend. In this case, no division is needed + // at all: The quotient is 0 and the remainder is equal to Dividend. + // + // So instead of checking at runtime whether Divisor fits into BypassType, + // we emit a runtime check to differentiate between these two cases. This + // lets us entirely avoid a long div. + + // Split the basic block before the div/rem. + BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem); + // Remove the unconditional branch from MainBB to SuccessorBB. + MainBB->getInstList().back().eraseFromParent(); + QuotRemWithBB Long; + Long.BB = MainBB; + Long.Quotient = ConstantInt::get(getSlowType(), 0); + Long.Remainder = Dividend; + QuotRemWithBB Fast = createFastBB(SuccessorBB); + QuotRemPair Result = createDivRemPhiNodes(Fast, Long, SuccessorBB); + IRBuilder<> Builder(MainBB, MainBB->end()); + Value *CmpV = Builder.CreateICmpUGE(Dividend, Divisor); + Builder.CreateCondBr(CmpV, Fast.BB, SuccessorBB); + return Result; } else { - // Replace all uses of rem instruction with remainder phi node - I->replaceAllUsesWith(Value.Remainder); + // General case. Create both slow and fast div/rem pairs and choose one of + // them at runtime. + + // Split the basic block before the div/rem. + BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem); + // Remove the unconditional branch from MainBB to SuccessorBB. + MainBB->getInstList().back().eraseFromParent(); + QuotRemWithBB Fast = createFastBB(SuccessorBB); + QuotRemWithBB Slow = createSlowBB(SuccessorBB); + QuotRemPair Result = createDivRemPhiNodes(Fast, Slow, SuccessorBB); + Value *CmpV = insertOperandRuntimeCheck(DividendShort ? nullptr : Dividend, + DivisorShort ? 
nullptr : Divisor); + IRBuilder<> Builder(MainBB, MainBB->end()); + Builder.CreateCondBr(CmpV, Fast.BB, Slow.BB); + return Result; } - - // Remove redundant operation - I->eraseFromParent(); - return true; } -// bypassSlowDivision - This optimization identifies DIV instructions in a BB -// that can be profitably bypassed and carried out with a shorter, faster -// divide. -bool llvm::bypassSlowDivision( - BasicBlock *BB, const DenseMap<unsigned int, unsigned int> &BypassWidths) { - DivCacheTy DivCache; +/// This optimization identifies DIV/REM instructions in a BB that can be +/// profitably bypassed and carried out with a shorter, faster divide. +bool llvm::bypassSlowDivision(BasicBlock *BB, + const BypassWidthsTy &BypassWidths) { + DivCacheTy PerBBDivCache; bool MadeChange = false; Instruction* Next = &*BB->begin(); @@ -231,42 +460,20 @@ bool llvm::bypassSlowDivision( Instruction* I = Next; Next = Next->getNextNode(); - // Get instruction details - unsigned Opcode = I->getOpcode(); - bool UseDivOp = Opcode == Instruction::SDiv || Opcode == Instruction::UDiv; - bool UseRemOp = Opcode == Instruction::SRem || Opcode == Instruction::URem; - bool UseSignedOp = Opcode == Instruction::SDiv || - Opcode == Instruction::SRem; - - // Only optimize div or rem ops - if (!UseDivOp && !UseRemOp) - continue; - - // Skip division on vector types, only optimize integer instructions - if (!I->getType()->isIntegerTy()) - continue; - - // Get bitwidth of div/rem instruction - IntegerType *T = cast<IntegerType>(I->getType()); - unsigned int bitwidth = T->getBitWidth(); - - // Continue if bitwidth is not bypassed - DenseMap<unsigned int, unsigned int>::const_iterator BI = BypassWidths.find(bitwidth); - if (BI == BypassWidths.end()) - continue; - - // Get type for div/rem instruction with bypass bitwidth - IntegerType *BT = IntegerType::get(I->getContext(), BI->second); - - MadeChange |= reuseOrInsertFastDiv(I, BT, UseDivOp, UseSignedOp, DivCache); + FastDivInsertionTask Task(I, BypassWidths); + if (Value *Replacement = Task.getReplacement(PerBBDivCache)) { + I->replaceAllUsesWith(Replacement); + I->eraseFromParent(); + MadeChange = true; + } } // Above we eagerly create divs and rems, as pairs, so that we can efficiently // create divrem machine instructions. Now erase any unused divs / rems so we // don't leave extra instructions sitting around. - for (auto &KV : DivCache) - for (Instruction *Phi : {KV.second.Quotient, KV.second.Remainder}) - RecursivelyDeleteTriviallyDeadInstructions(Phi); + for (auto &KV : PerBBDivCache) + for (Value *V : {KV.second.Quotient, KV.second.Remainder}) + RecursivelyDeleteTriviallyDeadInstructions(V); return MadeChange; } diff --git a/contrib/llvm/lib/Transforms/Utils/CloneFunction.cpp b/contrib/llvm/lib/Transforms/Utils/CloneFunction.cpp index 4d33e22..9c4e139 100644 --- a/contrib/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -13,7 +13,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" @@ -31,24 +30,38 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <map> using namespace llvm; /// See comments in Cloning.h. 
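The next hunk extends CloneBasicBlock with an optional DebugInfoFinder parameter, so debug intrinsics and locations can be registered as the block's instructions are copied. A hypothetical call site under the new signature (BB, F, and the ".copy" suffix are placeholders, not from the patch):

    ValueToValueMapTy VMap;
    ClonedCodeInfo CodeInfo;
    DebugInfoFinder DIFinder;
    BasicBlock *Copy =
        CloneBasicBlock(BB, VMap, ".copy", F, &CodeInfo, &DIFinder);
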
-BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, - ValueToValueMapTy &VMap, +BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, const Twine &NameSuffix, Function *F, - ClonedCodeInfo *CodeInfo) { + ClonedCodeInfo *CodeInfo, + DebugInfoFinder *DIFinder) { + DenseMap<const MDNode *, MDNode *> Cache; BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "", F); if (BB->hasName()) NewBB->setName(BB->getName()+NameSuffix); bool hasCalls = false, hasDynamicAllocas = false, hasStaticAllocas = false; - + Module *TheModule = F ? F->getParent() : nullptr; + // Loop over all instructions, and copy them over. for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end(); II != IE; ++II) { + + if (DIFinder && TheModule) { + if (auto *DDI = dyn_cast<DbgDeclareInst>(II)) + DIFinder->processDeclare(*TheModule, DDI); + else if (auto *DVI = dyn_cast<DbgValueInst>(II)) + DIFinder->processValue(*TheModule, DVI); + + if (auto DbgLoc = II->getDebugLoc()) + DIFinder->processLocation(*TheModule, DbgLoc.get()); + } + Instruction *NewInst = II->clone(); if (II->hasName()) NewInst->setName(II->getName()+NameSuffix); @@ -90,9 +103,9 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, assert(VMap.count(&I) && "No mapping from source argument specified!"); #endif - // Copy all attributes other than those stored in the AttributeSet. We need - // to remap the parameter indices of the AttributeSet. - AttributeSet NewAttrs = NewFunc->getAttributes(); + // Copy all attributes other than those stored in the AttributeList. We need + // to remap the parameter indices of the AttributeList. + AttributeList NewAttrs = NewFunc->getAttributes(); NewFunc->copyAttributesFrom(OldFunc); NewFunc->setAttributes(NewAttrs); @@ -103,31 +116,54 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, TypeMapper, Materializer)); - AttributeSet OldAttrs = OldFunc->getAttributes(); + SmallVector<AttributeSet, 4> NewArgAttrs(NewFunc->arg_size()); + AttributeList OldAttrs = OldFunc->getAttributes(); + // Clone any argument attributes that are present in the VMap. - for (const Argument &OldArg : OldFunc->args()) + for (const Argument &OldArg : OldFunc->args()) { if (Argument *NewArg = dyn_cast<Argument>(VMap[&OldArg])) { - AttributeSet attrs = - OldAttrs.getParamAttributes(OldArg.getArgNo() + 1); - if (attrs.getNumSlots() > 0) - NewArg->addAttr(attrs); + NewArgAttrs[NewArg->getArgNo()] = + OldAttrs.getParamAttributes(OldArg.getArgNo()); } + } NewFunc->setAttributes( - NewFunc->getAttributes() - .addAttributes(NewFunc->getContext(), AttributeSet::ReturnIndex, - OldAttrs.getRetAttributes()) - .addAttributes(NewFunc->getContext(), AttributeSet::FunctionIndex, - OldAttrs.getFnAttributes())); + AttributeList::get(NewFunc->getContext(), OldAttrs.getFnAttributes(), + OldAttrs.getRetAttributes(), NewArgAttrs)); + + bool MustCloneSP = + OldFunc->getParent() && OldFunc->getParent() == NewFunc->getParent(); + DISubprogram *SP = OldFunc->getSubprogram(); + if (SP) { + assert(!MustCloneSP || ModuleLevelChanges); + // Add mappings for some DebugInfo nodes that we don't want duplicated + // even if they're distinct. 
+ auto &MD = VMap.MD(); + MD[SP->getUnit()].reset(SP->getUnit()); + MD[SP->getType()].reset(SP->getType()); + MD[SP->getFile()].reset(SP->getFile()); + // If we're not cloning into the same module, no need to clone the + // subprogram + if (!MustCloneSP) + MD[SP].reset(SP); + } SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; OldFunc->getAllMetadata(MDs); - for (auto MD : MDs) + for (auto MD : MDs) { NewFunc->addMetadata( MD.first, *MapMetadata(MD.second, VMap, ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, TypeMapper, Materializer)); + } + + // When we remap instructions, we want to avoid duplicating inlined + // DISubprograms, so record all subprograms we find as we duplicate + // instructions and then freeze them in the MD map. + // We also record information about dbg.value and dbg.declare to avoid + // duplicating the types. + DebugInfoFinder DIFinder; // Loop over all of the basic blocks in the function, cloning them as // appropriate. Note that we save BE this way in order to handle cloning of @@ -138,7 +174,8 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, const BasicBlock &BB = *BI; // Create a new basic block and copy instructions into it! - BasicBlock *CBB = CloneBasicBlock(&BB, VMap, NameSuffix, NewFunc, CodeInfo); + BasicBlock *CBB = CloneBasicBlock(&BB, VMap, NameSuffix, NewFunc, CodeInfo, + SP ? &DIFinder : nullptr); // Add basic block mapping. VMap[&BB] = CBB; @@ -160,6 +197,16 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, Returns.push_back(RI); } + for (DISubprogram *ISP : DIFinder.subprograms()) { + if (ISP != SP) { + VMap.MD()[ISP].reset(ISP); + } + } + + for (auto *Type : DIFinder.types()) { + VMap.MD()[Type].reset(Type); + } + // Loop over all of the instructions in the function, fixing up operand // references as we go. This uses VMap to do all the hard work. for (Function::iterator BB = @@ -208,7 +255,7 @@ Function *llvm::CloneFunction(Function *F, ValueToValueMapTy &VMap, } SmallVector<ReturnInst*, 8> Returns; // Ignore returns cloned. - CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns, "", + CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "", CodeInfo); return NewF; @@ -247,7 +294,7 @@ namespace { void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, BasicBlock::const_iterator StartingInst, std::vector<const BasicBlock*> &ToClone){ - WeakVH &BBEntry = VMap[BB]; + WeakTrackingVH &BBEntry = VMap[BB]; // Have we already cloned this block? if (BBEntry) return; @@ -294,12 +341,13 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, SimplifyInstruction(NewInst, BB->getModule()->getDataLayout())) { // On the off-chance that this simplifies to an instruction in the old // function, map it back into the new function. - if (Value *MappedV = VMap.lookup(V)) - V = MappedV; + if (NewFunc != OldFunc) + if (Value *MappedV = VMap.lookup(V)) + V = MappedV; if (!NewInst->mayHaveSideEffects()) { VMap[&*II] = V; - delete NewInst; + NewInst->deleteValue(); continue; } } @@ -353,7 +401,7 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, Cond = dyn_cast_or_null<ConstantInt>(V); } if (Cond) { // Constant fold to uncond branch! 
- SwitchInst::ConstCaseIt Case = SI->findCaseValue(Cond); + SwitchInst::ConstCaseHandle Case = *SI->findCaseValue(Cond); BasicBlock *Dest = const_cast<BasicBlock*>(Case.getCaseSuccessor()); VMap[OldTI] = BranchInst::Create(Dest, NewBB); ToClone.push_back(Dest); @@ -549,7 +597,7 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, // Make a second pass over the PHINodes now that all of them have been // remapped into the new function, simplifying the PHINode and performing any // recursive simplifications exposed. This will transparently update the - // WeakVH in the VMap. Notably, we rely on that so that if we coalesce + // WeakTrackingVH in the VMap. Notably, we rely on that so that if we coalesce // two PHINodes, the iteration over the old PHIs remains valid, and the // mapping will just map us to the new node (which may not even be a PHI // node). @@ -747,3 +795,40 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, return NewLoop; } + +/// \brief Duplicate non-Phi instructions from the beginning of block up to +/// StopAt instruction into a split block between BB and its predecessor. +BasicBlock * +llvm::DuplicateInstructionsInSplitBetween(BasicBlock *BB, BasicBlock *PredBB, + Instruction *StopAt, + ValueToValueMapTy &ValueMapping) { + // We are going to have to map operands from the original BB block to the new + // copy of the block 'NewBB'. If there are PHI nodes in BB, evaluate them to + // account for entry from PredBB. + BasicBlock::iterator BI = BB->begin(); + for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) + ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); + + BasicBlock *NewBB = SplitEdge(PredBB, BB); + NewBB->setName(PredBB->getName() + ".split"); + Instruction *NewTerm = NewBB->getTerminator(); + + // Clone the non-phi instructions of BB into NewBB, keeping track of the + // mapping and using it to remap operands in the cloned instructions. + for (; StopAt != &*BI; ++BI) { + Instruction *New = BI->clone(); + New->setName(BI->getName()); + New->insertBefore(NewTerm); + ValueMapping[&*BI] = New; + + // Remap operands to patch up intra-block references. 
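The remap loop of this new helper continues in the next hunk. As context, a typical use of DuplicateInstructionsInSplitBetween, in the style of a jump-threading client, might look like this hypothetical sketch (BB, PredBB, and StopAt are assumed to be in scope):

    // Copy BB's leading non-PHI instructions, up to but not including
    // StopAt, into a new block spliced between PredBB and BB; consult
    // ValueMapping afterwards to rewrite uses of the duplicated values.
    ValueToValueMapTy ValueMapping;
    BasicBlock *Split =
        DuplicateInstructionsInSplitBetween(BB, PredBB, StopAt, ValueMapping);
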
+ for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) + if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) { + auto I = ValueMapping.find(Inst); + if (I != ValueMapping.end()) + New->setOperand(i, I->second); + } + } + + return NewBB; +} diff --git a/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp b/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp index 7ebeb61..e5392b5 100644 --- a/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp @@ -12,14 +12,23 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm-c/Core.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ValueMapper.h" -#include "llvm-c/Core.h" using namespace llvm; +static void copyComdat(GlobalObject *Dst, const GlobalObject *Src) { + const Comdat *SC = Src->getComdat(); + if (!SC) + return; + Comdat *DC = Dst->getParent()->getOrInsertComdat(SC->getName()); + DC->setSelectionKind(SC->getSelectionKind()); + Dst->setComdat(DC); +} + /// This is not as easy as it might seem because we have to worry about making /// copies of global variables and functions, and making their (initializers and /// references, respectively) refer to the right globals. @@ -87,7 +96,7 @@ std::unique_ptr<Module> llvm::CloneModule( else GV = new GlobalVariable( *New, I->getValueType(), false, GlobalValue::ExternalLinkage, - (Constant *)nullptr, I->getName(), (GlobalVariable *)nullptr, + nullptr, I->getName(), nullptr, I->getThreadLocalMode(), I->getType()->getAddressSpace()); VMap[&*I] = GV; // We do not copy attributes (mainly because copying between different @@ -123,7 +132,10 @@ std::unique_ptr<Module> llvm::CloneModule( SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; I->getAllMetadata(MDs); for (auto MD : MDs) - GV->addMetadata(MD.first, *MapMetadata(MD.second, VMap)); + GV->addMetadata(MD.first, + *MapMetadata(MD.second, VMap, RF_MoveDistinctMDs)); + + copyComdat(GV, &*I); } // Similarly, copy over function bodies now... @@ -153,6 +165,8 @@ std::unique_ptr<Module> llvm::CloneModule( if (I.hasPersonalityFn()) F->setPersonalityFn(MapValue(I.getPersonalityFn(), VMap)); + + copyComdat(F, &I); } // And aliases diff --git a/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp b/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp index 60ae374..d9294c4 100644 --- a/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp @@ -73,17 +73,17 @@ bool llvm::decomposeBitTestICmp(const ICmpInst *I, CmpInst::Predicate &Pred, default: return false; case ICmpInst::ICMP_SLT: - // X < 0 is equivalent to (X & SignBit) != 0. + // X < 0 is equivalent to (X & SignMask) != 0. if (!C->isZero()) return false; - Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth())); + Y = ConstantInt::get(I->getContext(), APInt::getSignMask(C->getBitWidth())); Pred = ICmpInst::ICMP_NE; break; case ICmpInst::ICMP_SGT: - // X > -1 is equivalent to (X & SignBit) == 0. - if (!C->isAllOnesValue()) + // X > -1 is equivalent to (X & SignMask) == 0. 
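+      // (Worked example on i8, for illustration: SignMask is 0x80, so
+      // "X >s -1" holds exactly when the sign bit is clear, i.e.
+      // (X & 0x80) == 0, just as "X <s 0" above is (X & 0x80) != 0.)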
+      if (!C->isMinusOne())
         return false;
-      Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth()));
+      Y = ConstantInt::get(I->getContext(), APInt::getSignMask(C->getBitWidth()));
       Pred = ICmpInst::ICMP_EQ;
       break;
     case ICmpInst::ICMP_ULT:
diff --git a/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index c514c9c..1189714 100644
--- a/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -27,6 +27,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
@@ -58,6 +59,33 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) {
   // Landing pads must be in the function where they were inserted for cleanup.
   if (BB.isEHPad())
     return false;
+  // Taking the address of a basic block moved to another function is illegal.
+  if (BB.hasAddressTaken())
+    return false;
+
+  // Don't hoist code that uses another basic block's address, as it's likely
+  // to lead to unexpected behavior, like cross-function jumps.
+  SmallPtrSet<User const *, 16> Visited;
+  SmallVector<User const *, 16> ToVisit;
+
+  for (Instruction const &Inst : BB)
+    ToVisit.push_back(&Inst);
+
+  while (!ToVisit.empty()) {
+    User const *Curr = ToVisit.pop_back_val();
+    if (!Visited.insert(Curr).second)
+      continue;
+    if (isa<BlockAddress const>(Curr))
+      return false; // even a reference to self is likely to be incompatible
+
+    if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
+      continue;
+
+    for (auto const &U : Curr->operands()) {
+      if (auto *UU = dyn_cast<User>(U))
+        ToVisit.push_back(UU);
+    }
+  }
 
   // Don't hoist code containing allocas, invokes, or vastarts.
   for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
@@ -73,24 +101,26 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) {
 }
 
 /// \brief Build a set of blocks to extract if the input blocks are viable.
-template <typename IteratorT>
-static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin,
-                                                       IteratorT BBEnd) {
+static SetVector<BasicBlock *>
+buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) {
+  assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
+
   SetVector<BasicBlock *> Result;
 
-  assert(BBBegin != BBEnd);
-
   // Loop over the blocks, adding them to our set-vector, and aborting with an
   // empty set if we encounter invalid blocks.
-  do {
-    if (!Result.insert(*BBBegin))
-      llvm_unreachable("Repeated basic blocks in extraction input");
+  for (BasicBlock *BB : BBs) {
 
-    if (!CodeExtractor::isBlockValidForExtraction(**BBBegin)) {
+    // If this block is dead, don't process it.
+    if (DT && !DT->isReachableFromEntry(BB))
+      continue;
+
+    if (!Result.insert(BB))
+      llvm_unreachable("Repeated basic blocks in extraction input");
+    if (!CodeExtractor::isBlockValidForExtraction(*BB)) {
       Result.clear();
       return Result;
     }
-  } while (++BBBegin != BBEnd);
+  }
 
 #ifndef NDEBUG
   for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()),
@@ -106,49 +136,19 @@ static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin,
   return Result;
 }
 
-/// \brief Helper to call buildExtractionBlockSet with an ArrayRef.
-static SetVector<BasicBlock *> -buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) { - return buildExtractionBlockSet(BBs.begin(), BBs.end()); -} - -/// \brief Helper to call buildExtractionBlockSet with a RegionNode. -static SetVector<BasicBlock *> -buildExtractionBlockSet(const RegionNode &RN) { - if (!RN.isSubRegion()) - // Just a single BasicBlock. - return buildExtractionBlockSet(RN.getNodeAs<BasicBlock>()); - - const Region &R = *RN.getNodeAs<Region>(); - - return buildExtractionBlockSet(R.block_begin(), R.block_end()); -} - -CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs, - BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI) - : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} - CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} + BPI(BPI), Blocks(buildExtractionBlockSet(BBs, DT)), NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())), + BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT)), NumExitBlocks(~0U) {} -CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN, - bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI) - : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} - /// definedInRegion - Return true if the specified value is defined in the /// extracted region. static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) { @@ -169,16 +169,255 @@ static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) { return false; } -void CodeExtractor::findInputsOutputs(ValueSet &Inputs, - ValueSet &Outputs) const { +static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) { + BasicBlock *CommonExitBlock = nullptr; + auto hasNonCommonExitSucc = [&](BasicBlock *Block) { + for (auto *Succ : successors(Block)) { + // Internal edges, ok. + if (Blocks.count(Succ)) + continue; + if (!CommonExitBlock) { + CommonExitBlock = Succ; + continue; + } + if (CommonExitBlock == Succ) + continue; + + return true; + } + return false; + }; + + if (any_of(Blocks, hasNonCommonExitSucc)) + return nullptr; + + return CommonExitBlock; +} + +bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( + Instruction *Addr) const { + AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets()); + Function *Func = (*Blocks.begin())->getParent(); + for (BasicBlock &BB : *Func) { + if (Blocks.count(&BB)) + continue; + for (Instruction &II : BB) { + + if (isa<DbgInfoIntrinsic>(II)) + continue; + + unsigned Opcode = II.getOpcode(); + Value *MemAddr = nullptr; + switch (Opcode) { + case Instruction::Store: + case Instruction::Load: { + if (Opcode == Instruction::Store) { + StoreInst *SI = cast<StoreInst>(&II); + MemAddr = SI->getPointerOperand(); + } else { + LoadInst *LI = cast<LoadInst>(&II); + MemAddr = LI->getPointerOperand(); + } + // Global variable can not be aliased with locals. 
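+        // (For instance, a load whose address strips to @g, a module-level
+        // GlobalVariable and hence a Constant, can never overlap the alloca
+        // being shrinkwrapped, so it is skipped below; an address that
+        // strips to a different AllocaInst is tolerated as well, while
+        // anything else, including AI itself, defeats the shrinkwrap.)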
+        if (dyn_cast<Constant>(MemAddr))
+          break;
+        Value *Base = MemAddr->stripInBoundsConstantOffsets();
+        if (!dyn_cast<AllocaInst>(Base) || Base == AI)
+          return false;
+        break;
+      }
+      default: {
+        IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
+        if (IntrInst) {
+          if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start ||
+              IntrInst->getIntrinsicID() == Intrinsic::lifetime_end)
+            break;
+          return false;
+        }
+        // Treat all the other cases conservatively if it has side effects.
+        if (II.mayHaveSideEffects())
+          return false;
+      }
+      }
+    }
+  }
+
+  return true;
+}
+
+BasicBlock *
+CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
+  BasicBlock *SinglePredFromOutlineRegion = nullptr;
+  assert(!Blocks.count(CommonExitBlock) &&
+         "Expect a block outside the region!");
+  for (auto *Pred : predecessors(CommonExitBlock)) {
+    if (!Blocks.count(Pred))
+      continue;
+    if (!SinglePredFromOutlineRegion) {
+      SinglePredFromOutlineRegion = Pred;
+    } else if (SinglePredFromOutlineRegion != Pred) {
+      SinglePredFromOutlineRegion = nullptr;
+      break;
+    }
+  }
+
+  if (SinglePredFromOutlineRegion)
+    return SinglePredFromOutlineRegion;
+
+#ifndef NDEBUG
+  auto getFirstPHI = [](BasicBlock *BB) {
+    BasicBlock::iterator I = BB->begin();
+    PHINode *FirstPhi = nullptr;
+    while (I != BB->end()) {
+      PHINode *Phi = dyn_cast<PHINode>(I);
+      if (!Phi)
+        break;
+      if (!FirstPhi) {
+        FirstPhi = Phi;
+        break;
+      }
+    }
+    return FirstPhi;
+  };
+  // If there are any phi nodes, the single pred either exists or has already
+  // been created before code extraction.
+  assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
+#endif
+
+  BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
+      CommonExitBlock->getFirstNonPHI()->getIterator());
+
+  for (auto *Pred : predecessors(CommonExitBlock)) {
+    if (Blocks.count(Pred))
+      continue;
+    Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
+  }
+  // Now add the old exit block to the outline region.
+  Blocks.insert(CommonExitBlock);
+  return CommonExitBlock;
+}
+
+void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
+                                BasicBlock *&ExitBlock) const {
+  Function *Func = (*Blocks.begin())->getParent();
+  ExitBlock = getCommonExitBlock(Blocks);
+
+  for (BasicBlock &BB : *Func) {
+    if (Blocks.count(&BB))
+      continue;
+    for (Instruction &II : BB) {
+      auto *AI = dyn_cast<AllocaInst>(&II);
+      if (!AI)
+        continue;
+
+      // Find the pair of lifetime markers for address 'Addr' that are either
+      // defined inside the outline region or can legally be shrinkwrapped into
+      // the outline region. If there are no other untracked uses of the
+      // address, return the pair of markers if found; otherwise return a pair
+      // of nullptr.
+      auto GetLifeTimeMarkers =
+          [&](Instruction *Addr, bool &SinkLifeStart,
+              bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
+        Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
+
+        for (User *U : Addr->users()) {
+          IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
+          if (IntrInst) {
+            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
+              // Do not handle the case where AI has multiple start markers.
+              if (LifeStart)
+                return std::make_pair<Instruction *>(nullptr, nullptr);
+              LifeStart = IntrInst;
+            }
+            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
+              if (LifeEnd)
+                return std::make_pair<Instruction *>(nullptr, nullptr);
+              LifeEnd = IntrInst;
+            }
+            continue;
+          }
+          // Any other use of the address is untracked; bail if it lies
+          // outside the region.
+          if (!definedInRegion(Blocks, U))
+            return std::make_pair<Instruction *>(nullptr, nullptr);
+        }
+
+        if (!LifeStart || !LifeEnd)
+          return std::make_pair<Instruction *>(nullptr, nullptr);
+
+        SinkLifeStart = !definedInRegion(Blocks, LifeStart);
+        HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
+        // Do the legality check.
+        if ((SinkLifeStart || HoistLifeEnd) &&
+            !isLegalToShrinkwrapLifetimeMarkers(Addr))
+          return std::make_pair<Instruction *>(nullptr, nullptr);
+
+        // Check to see if we have a place to do hoisting; if not, bail.
+        if (HoistLifeEnd && !ExitBlock)
+          return std::make_pair<Instruction *>(nullptr, nullptr);
+
+        return std::make_pair(LifeStart, LifeEnd);
+      };
+
+      bool SinkLifeStart = false, HoistLifeEnd = false;
+      auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
+
+      if (Markers.first) {
+        if (SinkLifeStart)
+          SinkCands.insert(Markers.first);
+        SinkCands.insert(AI);
+        if (HoistLifeEnd)
+          HoistCands.insert(Markers.second);
+        continue;
+      }
+
+      // Follow the bitcast.
+      Instruction *MarkerAddr = nullptr;
+      for (User *U : AI->users()) {
+
+        if (U->stripInBoundsConstantOffsets() == AI) {
+          SinkLifeStart = false;
+          HoistLifeEnd = false;
+          Instruction *Bitcast = cast<Instruction>(U);
+          Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
+          if (Markers.first) {
+            MarkerAddr = Bitcast;
+            continue;
+          }
+        }
+
+        // Found unknown use of AI.
+        if (!definedInRegion(Blocks, U)) {
+          MarkerAddr = nullptr;
+          break;
+        }
+      }
+
+      if (MarkerAddr) {
+        if (SinkLifeStart)
+          SinkCands.insert(Markers.first);
+        if (!definedInRegion(Blocks, MarkerAddr))
+          SinkCands.insert(MarkerAddr);
+        SinkCands.insert(AI);
+        if (HoistLifeEnd)
+          HoistCands.insert(Markers.second);
+      }
+    }
+  }
+}
+
+void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
+                                      const ValueSet &SinkCands) const {
+
   for (BasicBlock *BB : Blocks) {
     // If a used value is defined outside the region, it's an input.  If an
     // instruction is used outside the region, it's an output.
     for (Instruction &II : *BB) {
       for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
-           ++OI)
-        if (definedInCaller(Blocks, *OI))
-          Inputs.insert(*OI);
+           ++OI) {
+        Value *V = *OI;
+        if (!SinkCands.count(V) && definedInCaller(Blocks, V))
+          Inputs.insert(V);
+      }
 
       for (User *U : II.users())
         if (!definedInRegion(Blocks, U)) {
@@ -218,9 +457,7 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) {
   // containing PHI nodes merging values from outside of the region, and a
   // second that contains all of the code for the block and merges back any
   // incoming values from inside of the region.
-  BasicBlock::iterator AfterPHIs = Header->getFirstNonPHI()->getIterator();
-  BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs,
-                                              Header->getName()+".ce");
+  BasicBlock *NewBB = llvm::SplitBlock(Header, Header->getFirstNonPHI(), DT);
 
   // We only want to code extract the second block now, and it becomes the new
   // header of the region.
@@ -229,11 +466,6 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) {
   Blocks.insert(NewBB);
   Header = NewBB;
 
-  // Okay, update dominator sets. The blocks that dominate the new one are the
-  // blocks that dominate TIBB plus the new block itself.
-  if (DT)
-    DT->splitBlock(NewBB);
-
   // Okay, now we need to adjust the PHI nodes and any branches from within the
   // region to go to the new header block instead of the old header block.
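  // (Sketch, hypothetical IR: a header PHI such as
  //     %p = phi i32 [ %out, %entry ], [ %in, %latch ]
  // is severed so that the block left outside the region keeps the merge of
  // outside values, while NewBB receives a fresh "%p.ce" PHI that combines
  // %p, incoming from OldPred, with the region-internal values moved over in
  // the code below.)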
  if (NumPredsFromRegion) {
@@ -248,12 +480,14 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) {
 
   // Okay, everything within the region is now branching to the right block, we
   // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
+  BasicBlock::iterator AfterPHIs;
   for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
     PHINode *PN = cast<PHINode>(AfterPHIs);
     // Create a new PHI node in the new region, which has an incoming value
     // from OldPred of PN.
     PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
                                      PN->getName() + ".ce", &NewBB->front());
+    PN->replaceAllUsesWith(NewPN);
     NewPN->addIncoming(PN, OldPred);
 
     // Loop over all of the incoming values in PN, moving them to NewPN if they
@@ -362,9 +596,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   // "target-features" attribute allowing it to be lowered.
   // FIXME: This should be changed to check to see if a specific
   //           attribute cannot be inherited.
-  AttributeSet OldFnAttrs = oldFunction->getAttributes().getFnAttributes();
-  AttrBuilder AB(OldFnAttrs, AttributeSet::FunctionIndex);
-  for (auto Attr : AB.td_attrs())
+  AttrBuilder AB(oldFunction->getAttributes().getFnAttributes());
+  for (const auto &Attr : AB.td_attrs())
     newFunction->addFnAttr(Attr.first, Attr.second);
 
   newFunction->getBasicBlockList().push_back(newRootNode);
@@ -440,8 +673,10 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
   // Emit a call to the new function, passing in: *pointer to struct (if
   // aggregating parameters), or plain inputs and allocated memory for outputs
   std::vector<Value*> params, StructValues, ReloadOutputs, Reloads;
-
-  LLVMContext &Context = newFunction->getContext();
+
+  Module *M = newFunction->getParent();
+  LLVMContext &Context = M->getContext();
+  const DataLayout &DL = M->getDataLayout();
 
   // Add inputs as params, or to be filled into the struct
   for (Value *input : inputs)
@@ -456,8 +691,9 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
       StructValues.push_back(output);
     } else {
       AllocaInst *alloca =
-        new AllocaInst(output->getType(), nullptr, output->getName() + ".loc",
-                       &codeReplacer->getParent()->front().front());
+          new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
+                         nullptr, output->getName() + ".loc",
+                         &codeReplacer->getParent()->front().front());
       ReloadOutputs.push_back(alloca);
       params.push_back(alloca);
     }
@@ -473,7 +709,8 @@
 
     // Allocate a struct at the beginning of this function
     StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
-    Struct = new AllocaInst(StructArgTy, nullptr, "structArg",
+    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
+                            "structArg",
                             &codeReplacer->getParent()->front().front());
     params.push_back(Struct);
 
@@ -748,7 +985,8 @@ Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
     return nullptr;
 
-  ValueSet inputs, outputs;
+  ValueSet inputs, outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
 
   // Assumption: this is a single-entry code region, and the header is the first
   // block in the region.
@@ -787,8 +1025,23 @@ Function *CodeExtractor::extractCodeRegion() {
                                                "newFuncRoot");
   newFuncRoot->getInstList().push_back(BranchInst::Create(header));
 
+  findAllocas(SinkingCands, HoistingCands, CommonExit);
+  assert(HoistingCands.empty() || CommonExit);
+
   // Find inputs to, outputs from the code region.
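+  // (Note: allocas gathered in SinkingCands are about to be moved into the
+  // extracted function's entry block, so they are excluded from the input
+  // scan below; otherwise each shrinkwrapped lifetime alloca would be passed
+  // to the new function as a spurious argument.)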
- findInputsOutputs(inputs, outputs); + findInputsOutputs(inputs, outputs, SinkingCands); + + // Now sink all instructions which only have non-phi uses inside the region + for (auto *II : SinkingCands) + cast<Instruction>(II)->moveBefore(*newFuncRoot, + newFuncRoot->getFirstInsertionPt()); + + if (!HoistingCands.empty()) { + auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); + Instruction *TI = HoistToBlock->getTerminator(); + for (auto *II : HoistingCands) + cast<Instruction>(II)->moveBefore(TI); + } // Calculate the exit blocks for the extracted region and the total exit // weights for each of those blocks. @@ -863,12 +1116,6 @@ Function *CodeExtractor::extractCodeRegion() { } } - //cerr << "NEW FUNCTION: " << *newFunction; - // verifyFunction(*newFunction); - - // cerr << "OLD FUNCTION: " << *oldFunction; - // verifyFunction(*oldFunction); - DEBUG(if (verifyFunction(*newFunction)) report_fatal_error("verifyFunction failed!")); return newFunction; diff --git a/contrib/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp b/contrib/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp index 75a1dde..6d3d287 100644 --- a/contrib/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/contrib/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/CFG.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -28,15 +28,17 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, return nullptr; } + Function *F = I.getParent()->getParent(); + const DataLayout &DL = F->getParent()->getDataLayout(); + // Create a stack slot to hold the value. AllocaInst *Slot; if (AllocaPoint) { - Slot = new AllocaInst(I.getType(), nullptr, + Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr, I.getName()+".reg2mem", AllocaPoint); } else { - Function *F = I.getParent()->getParent(); - Slot = new AllocaInst(I.getType(), nullptr, I.getName() + ".reg2mem", - &F->getEntryBlock().front()); + Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr, + I.getName() + ".reg2mem", &F->getEntryBlock().front()); } // We cannot demote invoke instructions to the stack if their normal edge @@ -110,14 +112,17 @@ AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) { return nullptr; } + const DataLayout &DL = P->getModule()->getDataLayout(); + // Create a stack slot to hold the value. AllocaInst *Slot; if (AllocaPoint) { - Slot = new AllocaInst(P->getType(), nullptr, + Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr, P->getName()+".reg2mem", AllocaPoint); } else { Function *F = P->getParent()->getParent(); - Slot = new AllocaInst(P->getType(), nullptr, P->getName() + ".reg2mem", + Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr, + P->getName() + ".reg2mem", &F->getEntryBlock().front()); } diff --git a/contrib/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/contrib/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp index 8c23865..78d7474 100644 --- a/contrib/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp +++ b/contrib/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -67,8 +67,7 @@ IRBuilder<> *EscapeEnumerator::Next() { // Create a cleanup block. 
LLVMContext &C = F.getContext(); BasicBlock *CleanupBB = BasicBlock::Create(C, CleanupBBName, &F); - Type *ExnTy = - StructType::get(Type::getInt8PtrTy(C), Type::getInt32Ty(C), nullptr); + Type *ExnTy = StructType::get(Type::getInt8PtrTy(C), Type::getInt32Ty(C)); if (!F.hasPersonalityFn()) { Constant *PersFn = getDefaultPersonalityFn(F.getParent()); F.setPersonalityFn(PersFn); diff --git a/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp b/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp index 4adf175..1328f2f 100644 --- a/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp +++ b/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp @@ -16,11 +16,12 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Operator.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -401,7 +402,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, Value *Ptr = PtrArg->stripPointerCasts(); if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) { Type *ElemTy = GV->getValueType(); - if (!Size->isAllOnesValue() && + if (!Size->isMinusOne() && Size->getValue().getLimitedValue() >= DL.getTypeStoreSize(ElemTy)) { Invariants.insert(GV); @@ -438,7 +439,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, if (Callee->isDeclaration()) { // If this is a function we can constant fold, do it. - if (Constant *C = ConstantFoldCall(Callee, Formals, TLI)) { + if (Constant *C = ConstantFoldCall(CS, Callee, Formals, TLI)) { InstResult = C; DEBUG(dbgs() << "Constant folded function call. Result: " << *InstResult << "\n"); @@ -486,7 +487,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, ConstantInt *Val = dyn_cast<ConstantInt>(getVal(SI->getCondition())); if (!Val) return false; // Cannot determine. 
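      // (API note on the change below: SwitchInst::findCaseValue now returns
      // a case iterator whose element is a case handle, so the successor is
      // reached through operator-> rather than directly on the returned
      // value; the ConstCaseHandle change in the CloneFunction hunk above is
      // the same migration.)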
- NextBB = SI->findCaseValue(Val).getCaseSuccessor(); + NextBB = SI->findCaseValue(Val)->getCaseSuccessor(); } else if (IndirectBrInst *IBI = dyn_cast<IndirectBrInst>(CurInst)) { Value *Val = getVal(IBI->getAddress())->stripPointerCasts(); if (BlockAddress *BA = dyn_cast<BlockAddress>(Val)) diff --git a/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp b/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp index 7b96fbb..435eff3 100644 --- a/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp +++ b/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp @@ -11,7 +11,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/Local.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ValueTracking.h" @@ -19,6 +18,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "flattencfg" diff --git a/contrib/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/contrib/llvm/lib/Transforms/Utils/FunctionComparator.cpp index 81a7c4c..4a2be3a 100644 --- a/contrib/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/contrib/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -15,8 +15,8 @@ #include "llvm/Transforms/Utils/FunctionComparator.h" #include "llvm/ADT/SmallSet.h" #include "llvm/IR/CallSite.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -74,14 +74,16 @@ int FunctionComparator::cmpMem(StringRef L, StringRef R) const { return L.compare(R); } -int FunctionComparator::cmpAttrs(const AttributeSet L, - const AttributeSet R) const { - if (int Res = cmpNumbers(L.getNumSlots(), R.getNumSlots())) +int FunctionComparator::cmpAttrs(const AttributeList L, + const AttributeList R) const { + if (int Res = cmpNumbers(L.getNumAttrSets(), R.getNumAttrSets())) return Res; - for (unsigned i = 0, e = L.getNumSlots(); i != e; ++i) { - AttributeSet::iterator LI = L.begin(i), LE = L.end(i), RI = R.begin(i), - RE = R.end(i); + for (unsigned i = L.index_begin(), e = L.index_end(); i != e; ++i) { + AttributeSet LAS = L.getAttributes(i); + AttributeSet RAS = R.getAttributes(i); + AttributeSet::iterator LI = LAS.begin(), LE = LAS.end(); + AttributeSet::iterator RI = RAS.begin(), RE = RAS.end(); for (; LI != LE && RI != RE; ++LI, ++RI) { Attribute LA = *LI; Attribute RA = *RI; @@ -511,8 +513,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpOrderings(LI->getOrdering(), cast<LoadInst>(R)->getOrdering())) return Res; - if (int Res = - cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope())) + if (int Res = cmpNumbers(LI->getSyncScopeID(), + cast<LoadInst>(R)->getSyncScopeID())) return Res; return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range), cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)); @@ -527,7 +529,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpOrderings(SI->getOrdering(), cast<StoreInst>(R)->getOrdering())) return Res; - return cmpNumbers(SI->getSynchScope(), cast<StoreInst>(R)->getSynchScope()); + return cmpNumbers(SI->getSyncScopeID(), + cast<StoreInst>(R)->getSyncScopeID()); } if (const CmpInst *CI = dyn_cast<CmpInst>(L)) return cmpNumbers(CI->getPredicate(), cast<CmpInst>(R)->getPredicate()); @@ -582,7 +585,8 @@ int 
FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpOrderings(FI->getOrdering(), cast<FenceInst>(R)->getOrdering())) return Res; - return cmpNumbers(FI->getSynchScope(), cast<FenceInst>(R)->getSynchScope()); + return cmpNumbers(FI->getSyncScopeID(), + cast<FenceInst>(R)->getSyncScopeID()); } if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(L)) { if (int Res = cmpNumbers(CXI->isVolatile(), @@ -599,8 +603,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, cmpOrderings(CXI->getFailureOrdering(), cast<AtomicCmpXchgInst>(R)->getFailureOrdering())) return Res; - return cmpNumbers(CXI->getSynchScope(), - cast<AtomicCmpXchgInst>(R)->getSynchScope()); + return cmpNumbers(CXI->getSyncScopeID(), + cast<AtomicCmpXchgInst>(R)->getSyncScopeID()); } if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(L)) { if (int Res = cmpNumbers(RMWI->getOperation(), @@ -612,8 +616,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpOrderings(RMWI->getOrdering(), cast<AtomicRMWInst>(R)->getOrdering())) return Res; - return cmpNumbers(RMWI->getSynchScope(), - cast<AtomicRMWInst>(R)->getSynchScope()); + return cmpNumbers(RMWI->getSyncScopeID(), + cast<AtomicRMWInst>(R)->getSyncScopeID()); } if (const PHINode *PNL = dyn_cast<PHINode>(L)) { const PHINode *PNR = cast<PHINode>(R); diff --git a/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp b/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp index 9844190..a98d072 100644 --- a/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Analysis/ModuleSummaryAnalysis.h" #include "llvm/Transforms/Utils/FunctionImportUtils.h" +#include "llvm/Analysis/ModuleSummaryAnalysis.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" using namespace llvm; @@ -21,11 +21,11 @@ using namespace llvm; /// Checks if we should import SGV as a definition, otherwise import as a /// declaration. bool FunctionImportGlobalProcessing::doImportAsDefinition( - const GlobalValue *SGV, DenseSet<const GlobalValue *> *GlobalsToImport) { + const GlobalValue *SGV, SetVector<GlobalValue *> *GlobalsToImport) { // For alias, we tie the definition to the base object. Extract it and recurse if (auto *GA = dyn_cast<GlobalAlias>(SGV)) { - if (GA->hasWeakAnyLinkage()) + if (GA->isInterposable()) return false; const GlobalObject *GO = GA->getBaseObject(); if (!GO->hasLinkOnceODRLinkage()) @@ -34,7 +34,7 @@ bool FunctionImportGlobalProcessing::doImportAsDefinition( GO, GlobalsToImport); } // Only import the globals requested for importing. - if (GlobalsToImport->count(SGV)) + if (GlobalsToImport->count(const_cast<GlobalValue *>(SGV))) return true; // Otherwise no. 
return false; @@ -57,7 +57,8 @@ bool FunctionImportGlobalProcessing::shouldPromoteLocalToGlobal( return false; if (isPerformingImport()) { - assert((!GlobalsToImport->count(SGV) || !isNonRenamableLocal(*SGV)) && + assert((!GlobalsToImport->count(const_cast<GlobalValue *>(SGV)) || + !isNonRenamableLocal(*SGV)) && "Attempting to promote non-renamable local"); // We don't know for sure yet if we are importing this value (as either // a reference or a def), since we are simply walking all values in the @@ -254,9 +255,8 @@ bool FunctionImportGlobalProcessing::run() { return false; } -bool llvm::renameModuleForThinLTO( - Module &M, const ModuleSummaryIndex &Index, - DenseSet<const GlobalValue *> *GlobalsToImport) { +bool llvm::renameModuleForThinLTO(Module &M, const ModuleSummaryIndex &Index, + SetVector<GlobalValue *> *GlobalsToImport) { FunctionImportGlobalProcessing ThinLTOProcessing(M, Index, GlobalsToImport); return ThinLTOProcessing.run(); } diff --git a/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp b/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp index 74ebcda..245fefb 100644 --- a/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp +++ b/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp @@ -7,12 +7,25 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/GlobalStatus.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/Transforms/Utils/GlobalStatus.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include <algorithm> +#include <cassert> using namespace llvm; @@ -175,13 +188,9 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, return false; } +GlobalStatus::GlobalStatus() = default; + bool GlobalStatus::analyzeGlobal(const Value *V, GlobalStatus &GS) { SmallPtrSet<const PHINode *, 16> PhiUsers; return analyzeGlobalAux(V, GS, PhiUsers); } - -GlobalStatus::GlobalStatus() - : IsCompared(false), IsLoaded(false), StoredType(NotStored), - StoredOnceValue(nullptr), AccessingFunction(nullptr), - HasMultipleAccessingFunctions(false), HasNonInstructionUser(false), - Ordering(AtomicOrdering::NotAtomic) {} diff --git a/contrib/llvm/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/contrib/llvm/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp index ed018bb..b8c12ad 100644 --- a/contrib/llvm/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp +++ b/contrib/llvm/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp @@ -62,6 +62,8 @@ void ImportedFunctionsInliningStatistics::recordInline(const Function &Caller, void ImportedFunctionsInliningStatistics::setModuleInfo(const Module &M) { ModuleName = M.getName(); for (const auto &F : M.functions()) { + if (F.isDeclaration()) + continue; AllFunctions++; ImportedFunctions += int(F.getMetadata("thinlto_src_module") != nullptr); } diff --git a/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp b/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp index a40079c..2a18c14 100644 --- a/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp 
@@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" @@ -20,19 +19,21 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DIBuilder.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -40,8 +41,9 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" #include <algorithm> using namespace llvm; @@ -1107,26 +1109,23 @@ static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) { bool DTCalculated = false; Function *CalledFunc = CS.getCalledFunction(); - for (Function::arg_iterator I = CalledFunc->arg_begin(), - E = CalledFunc->arg_end(); - I != E; ++I) { - unsigned Align = I->getType()->isPointerTy() ? I->getParamAlignment() : 0; - if (Align && !I->hasByValOrInAllocaAttr() && !I->hasNUses(0)) { + for (Argument &Arg : CalledFunc->args()) { + unsigned Align = Arg.getType()->isPointerTy() ? Arg.getParamAlignment() : 0; + if (Align && !Arg.hasByValOrInAllocaAttr() && !Arg.hasNUses(0)) { if (!DTCalculated) { - DT.recalculate(const_cast<Function&>(*CS.getInstruction()->getParent() - ->getParent())); + DT.recalculate(*CS.getCaller()); DTCalculated = true; } // If we can already prove the asserted alignment in the context of the // caller, then don't bother inserting the assumption. 
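      // (Sketch, hypothetical values: for a callee parameter declared as
      // "i8* align 32 %p" whose alignment at the call site is only provable
      // to 16, the builder below emits an llvm.assume-based alignment
      // assumption so the caller keeps the 32-byte guarantee after inlining;
      // when 32 is already provable, the 'continue' skips the intrinsic.)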
- Value *Arg = CS.getArgument(I->getArgNo()); - if (getKnownAlignment(Arg, DL, CS.getInstruction(), AC, &DT) >= Align) + Value *ArgVal = CS.getArgument(Arg.getArgNo()); + if (getKnownAlignment(ArgVal, DL, CS.getInstruction(), AC, &DT) >= Align) continue; - CallInst *NewAssumption = IRBuilder<>(CS.getInstruction()) - .CreateAlignmentAssumption(DL, Arg, Align); - AC->registerAssumption(NewAssumption); + CallInst *NewAsmp = IRBuilder<>(CS.getInstruction()) + .CreateAlignmentAssumption(DL, ArgVal, Align); + AC->registerAssumption(NewAsmp); } } } @@ -1140,7 +1139,7 @@ static void UpdateCallGraphAfterInlining(CallSite CS, ValueToValueMapTy &VMap, InlineFunctionInfo &IFI) { CallGraph &CG = *IFI.CG; - const Function *Caller = CS.getInstruction()->getParent()->getParent(); + const Function *Caller = CS.getCaller(); const Function *Callee = CS.getCalledFunction(); CallGraphNode *CalleeNode = CG[Callee]; CallGraphNode *CallerNode = CG[Caller]; @@ -1225,7 +1224,8 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, PointerType *ArgTy = cast<PointerType>(Arg->getType()); Type *AggTy = ArgTy->getElementType(); - Function *Caller = TheCall->getParent()->getParent(); + Function *Caller = TheCall->getFunction(); + const DataLayout &DL = Caller->getParent()->getDataLayout(); // If the called function is readonly, then it could not mutate the caller's // copy of the byval'd memory. In this case, it is safe to elide the copy and @@ -1239,31 +1239,30 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, AssumptionCache *AC = IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr; - const DataLayout &DL = Caller->getParent()->getDataLayout(); // If the pointer is already known to be sufficiently aligned, or if we can // round it up to a larger alignment, then we don't need a temporary. if (getOrEnforceKnownAlignment(Arg, ByValAlignment, DL, TheCall, AC) >= ByValAlignment) return Arg; - + // Otherwise, we have to make a memcpy to get a safe alignment. This is bad // for code quality, but rarely happens and is required for correctness. } // Create the alloca. If we have DataLayout, use nice alignment. - unsigned Align = - Caller->getParent()->getDataLayout().getPrefTypeAlignment(AggTy); + unsigned Align = DL.getPrefTypeAlignment(AggTy); // If the byval had an alignment specified, we *must* use at least that // alignment, as it is required by the byval argument (and uses of the // pointer inside the callee). Align = std::max(Align, ByValAlignment); - - Value *NewAlloca = new AllocaInst(AggTy, nullptr, Align, Arg->getName(), + + Value *NewAlloca = new AllocaInst(AggTy, DL.getAllocaAddrSpace(), + nullptr, Align, Arg->getName(), &*Caller->begin()->begin()); IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca)); - + // Uses of the argument in the function should use our new alloca // instead. return NewAlloca; @@ -1303,41 +1302,6 @@ static bool hasLifetimeMarkers(AllocaInst *AI) { return false; } -/// Rebuild the entire inlined-at chain for this instruction so that the top of -/// the chain now is inlined-at the new call site. 
-static DebugLoc -updateInlinedAtInfo(const DebugLoc &DL, DILocation *InlinedAtNode, - LLVMContext &Ctx, - DenseMap<const DILocation *, DILocation *> &IANodes) { - SmallVector<DILocation *, 3> InlinedAtLocations; - DILocation *Last = InlinedAtNode; - DILocation *CurInlinedAt = DL; - - // Gather all the inlined-at nodes - while (DILocation *IA = CurInlinedAt->getInlinedAt()) { - // Skip any we've already built nodes for - if (DILocation *Found = IANodes[IA]) { - Last = Found; - break; - } - - InlinedAtLocations.push_back(IA); - CurInlinedAt = IA; - } - - // Starting from the top, rebuild the nodes to point to the new inlined-at - // location (then rebuilding the rest of the chain behind it) and update the - // map of already-constructed inlined-at nodes. - for (const DILocation *MD : reverse(InlinedAtLocations)) { - Last = IANodes[MD] = DILocation::getDistinct( - Ctx, MD->getLine(), MD->getColumn(), MD->getScope(), Last); - } - - // And finally create the normal location for this instruction, referring to - // the new inlined-at chain. - return DebugLoc::get(DL.getLine(), DL.getCol(), DL.getScope(), Last); -} - /// Return the result of AI->isStaticAlloca() if AI were moved to the entry /// block. Allocas used in inalloca calls and allocas of dynamic array size /// cannot be static. @@ -1365,14 +1329,16 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, // Cache the inlined-at nodes as they're built so they are reused, without // this every instruction's inlined-at chain would become distinct from each // other. - DenseMap<const DILocation *, DILocation *> IANodes; + DenseMap<const MDNode *, MDNode *> IANodes; for (; FI != Fn->end(); ++FI) { for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ++BI) { if (DebugLoc DL = BI->getDebugLoc()) { - BI->setDebugLoc( - updateInlinedAtInfo(DL, InlinedAtNode, BI->getContext(), IANodes)); + auto IA = DebugLoc::appendInlinedAt(DL, InlinedAtNode, BI->getContext(), + IANodes); + auto IDL = DebugLoc::get(DL.getLine(), DL.getCol(), DL.getScope(), IA); + BI->setDebugLoc(IDL); continue; } @@ -1393,6 +1359,91 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, } } } +/// Update the block frequencies of the caller after a callee has been inlined. +/// +/// Each block cloned into the caller has its block frequency scaled by the +/// ratio of CallSiteFreq/CalleeEntryFreq. This ensures that the cloned copy of +/// callee's entry block gets the same frequency as the callsite block and the +/// relative frequencies of all cloned blocks remain the same after cloning. +static void updateCallerBFI(BasicBlock *CallSiteBlock, + const ValueToValueMapTy &VMap, + BlockFrequencyInfo *CallerBFI, + BlockFrequencyInfo *CalleeBFI, + const BasicBlock &CalleeEntryBlock) { + SmallPtrSet<BasicBlock *, 16> ClonedBBs; + for (auto const &Entry : VMap) { + if (!isa<BasicBlock>(Entry.first) || !Entry.second) + continue; + auto *OrigBB = cast<BasicBlock>(Entry.first); + auto *ClonedBB = cast<BasicBlock>(Entry.second); + uint64_t Freq = CalleeBFI->getBlockFreq(OrigBB).getFrequency(); + if (!ClonedBBs.insert(ClonedBB).second) { + // Multiple blocks in the callee might get mapped to one cloned block in + // the caller since we prune the callee as we clone it. When that happens, + // we want to use the maximum among the original blocks' frequencies. 
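+        // (Worked example, hypothetical numbers: if callee blocks B1 with
+        // frequency 8 and B2 with frequency 3 are both mapped onto the same
+        // cloned block, the max computed below leaves the clone at 8.)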
+        uint64_t NewFreq = CallerBFI->getBlockFreq(ClonedBB).getFrequency();
+        if (NewFreq > Freq)
+          Freq = NewFreq;
+    }
+    CallerBFI->setBlockFreq(ClonedBB, Freq);
+  }
+  BasicBlock *EntryClone = cast<BasicBlock>(VMap.lookup(&CalleeEntryBlock));
+  CallerBFI->setBlockFreqAndScale(
+      EntryClone, CallerBFI->getBlockFreq(CallSiteBlock).getFrequency(),
+      ClonedBBs);
+}
+
+/// Update the branch metadata for cloned call instructions.
+static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap,
+                              const Optional<uint64_t> &CalleeEntryCount,
+                              const Instruction *TheCall,
+                              ProfileSummaryInfo *PSI,
+                              BlockFrequencyInfo *CallerBFI) {
+  if (!CalleeEntryCount.hasValue() || CalleeEntryCount.getValue() < 1)
+    return;
+  Optional<uint64_t> CallSiteCount =
+      PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None;
+  uint64_t CallCount =
+      std::min(CallSiteCount.hasValue() ? CallSiteCount.getValue() : 0,
+               CalleeEntryCount.getValue());
+
+  for (auto const &Entry : VMap)
+    if (isa<CallInst>(Entry.first))
+      if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second))
+        CI->updateProfWeight(CallCount, CalleeEntryCount.getValue());
+  for (BasicBlock &BB : *Callee)
+    // No need to update the callsite if it is pruned during inlining.
+    if (VMap.count(&BB))
+      for (Instruction &I : BB)
+        if (CallInst *CI = dyn_cast<CallInst>(&I))
+          CI->updateProfWeight(CalleeEntryCount.getValue() - CallCount,
+                               CalleeEntryCount.getValue());
+}
+
+/// Update the entry count of callee after inlining.
+///
+/// The callsite's block count is subtracted from the callee's function entry
+/// count.
+static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB,
+                              Instruction *CallInst, Function *Callee,
+                              ProfileSummaryInfo *PSI) {
+  // If the callee has an original count of N, and the estimated count of the
+  // callsite is M, the new callee count is set to N - M. M is estimated from
+  // the caller's entry count, its entry block frequency and the block frequency
+  // of the callsite.
+  Optional<uint64_t> CalleeCount = Callee->getEntryCount();
+  if (!CalleeCount.hasValue() || !PSI)
+    return;
+  Optional<uint64_t> CallCount = PSI->getProfileCount(CallInst, CallerBFI);
+  if (!CallCount.hasValue())
+    return;
+  // Since CallSiteCount is an estimate, it could exceed the original callee
+  // count; in that case the new callee count is clamped to 0.
+  if (CallCount.getValue() > CalleeCount.getValue())
+    Callee->setEntryCount(0);
+  else
+    Callee->setEntryCount(CalleeCount.getValue() - CallCount.getValue());
+}
 
 /// This function inlines the called function into the basic block of the
 /// caller.  This returns false if it is not possible to inline this call.
@@ -1405,13 +1456,13 @@
 bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
                           AAResults *CalleeAAR, bool InsertLifetime) {
   Instruction *TheCall = CS.getInstruction();
-  assert(TheCall->getParent() && TheCall->getParent()->getParent() &&
-         "Instruction not in function!");
+  assert(TheCall->getParent() && TheCall->getFunction()
+         && "Instruction not in function!");
 
   // If IFI has any state in it, zap it before we fill it in.
   IFI.reset();
-
-  const Function *CalledFunc = CS.getCalledFunction();
+
+  Function *CalledFunc = CS.getCalledFunction();
   if (!CalledFunc ||               // Can't inline external function or indirect
       CalledFunc->isDeclaration() || // call, or call to a vararg function!
CalledFunc->getFunctionType()->isVarArg()) return false; @@ -1548,7 +1599,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // matches up the formal to the actual argument values. CallSite::arg_iterator AI = CS.arg_begin(); unsigned ArgNo = 0; - for (Function::const_arg_iterator I = CalledFunc->arg_begin(), + for (Function::arg_iterator I = CalledFunc->arg_begin(), E = CalledFunc->arg_end(); I != E; ++I, ++AI, ++ArgNo) { Value *ActualArg = *AI; @@ -1558,7 +1609,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // modify the struct. if (CS.isByValArgument(ArgNo)) { ActualArg = HandleByValArgument(ActualArg, TheCall, CalledFunc, IFI, - CalledFunc->getParamAlignment(ArgNo+1)); + CalledFunc->getParamAlignment(ArgNo)); if (ActualArg != *AI) ByValInit.push_back(std::make_pair(ActualArg, (Value*) *AI)); } @@ -1578,10 +1629,19 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, CloneAndPruneFunctionInto(Caller, CalledFunc, VMap, /*ModuleLevelChanges=*/false, Returns, ".i", &InlinedFunctionInfo, TheCall); - // Remember the first block that is newly cloned over. FirstNewBlock = LastBlock; ++FirstNewBlock; + if (IFI.CallerBFI != nullptr && IFI.CalleeBFI != nullptr) + // Update the BFI of blocks cloned into the caller. + updateCallerBFI(OrigBB, VMap, IFI.CallerBFI, IFI.CalleeBFI, + CalledFunc->front()); + + updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), TheCall, + IFI.PSI, IFI.CallerBFI); + // Update the profile count of callee. + updateCalleeCount(IFI.CallerBFI, OrigBB, TheCall, CalledFunc, IFI.PSI); + // Inject byval arguments initialization. for (std::pair<Value*, Value*> &Init : ByValInit) HandleByValArgumentInit(Init.first, Init.second, Caller->getParent(), @@ -2087,6 +2147,12 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, CalledFunc->getName() + ".exit"); } + if (IFI.CallerBFI) { + // Copy original BB's block frequency to AfterCallBB + IFI.CallerBFI->setBlockFreq( + AfterCallBB, IFI.CallerBFI->getBlockFreq(OrigBB).getFrequency()); + } + // Change the branch that used to go to AfterCallBB to branch to the first // basic block of the inlined function. // @@ -2206,7 +2272,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, AssumptionCache *AC = IFI.GetAssumptionCache ? 
&(*IFI.GetAssumptionCache)(*Caller) : nullptr; auto &DL = Caller->getParent()->getDataLayout(); - if (Value *V = SimplifyInstruction(PHI, DL, nullptr, nullptr, AC)) { + if (Value *V = SimplifyInstruction(PHI, {DL, nullptr, nullptr, AC})) { PHI->replaceAllUsesWith(V); PHI->eraseFromParent(); } diff --git a/contrib/llvm/lib/Transforms/Utils/InstructionNamer.cpp b/contrib/llvm/lib/Transforms/Utils/InstructionNamer.cpp index 8a1973d..23ec45e 100644 --- a/contrib/llvm/lib/Transforms/Utils/InstructionNamer.cpp +++ b/contrib/llvm/lib/Transforms/Utils/InstructionNamer.cpp @@ -14,10 +14,10 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; namespace { @@ -26,16 +26,15 @@ namespace { InstNamer() : FunctionPass(ID) { initializeInstNamerPass(*PassRegistry::getPassRegistry()); } - + void getAnalysisUsage(AnalysisUsage &Info) const override { Info.setPreservesAll(); } bool runOnFunction(Function &F) override { - for (Function::arg_iterator AI = F.arg_begin(), AE = F.arg_end(); - AI != AE; ++AI) - if (!AI->hasName() && !AI->getType()->isVoidTy()) - AI->setName("arg"); + for (auto &Arg : F.args()) + if (!Arg.hasName()) + Arg.setName("arg"); for (BasicBlock &BB : F) { if (!BB.hasName()) @@ -48,11 +47,11 @@ namespace { return true; } }; - + char InstNamer::ID = 0; } -INITIALIZE_PASS(InstNamer, "instnamer", +INITIALIZE_PASS(InstNamer, "instnamer", "Assign names to anonymous instructions", false, false) char &llvm::InstructionNamerID = InstNamer::ID; //===----------------------------------------------------------------------===// diff --git a/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp b/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp index 68c6b74..089f2b5 100644 --- a/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -85,9 +85,11 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, UsesToRewrite.clear(); Instruction *I = Worklist.pop_back_val(); + assert(!I->getType()->isTokenTy() && "Tokens shouldn't be in the worklist"); BasicBlock *InstBB = I->getParent(); Loop *L = LI.getLoopFor(InstBB); - if (!LoopExitBlocks.count(L)) + assert(L && "Instruction belongs to a BB that's not part of a loop"); + if (!LoopExitBlocks.count(L)) L->getExitBlocks(LoopExitBlocks[L]); assert(LoopExitBlocks.count(L)); const SmallVectorImpl<BasicBlock *> &ExitBlocks = LoopExitBlocks[L]; @@ -95,17 +97,10 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, if (ExitBlocks.empty()) continue; - // Tokens cannot be used in PHI nodes, so we skip over them. - // We can run into tokens which are live out of a loop with catchswitch - // instructions in Windows EH if the catchswitch has one catchpad which - // is inside the loop and another which is not. - if (I->getType()->isTokenTy()) - continue; - for (Use &U : I->uses()) { Instruction *User = cast<Instruction>(U.getUser()); BasicBlock *UserBB = User->getParent(); - if (PHINode *PN = dyn_cast<PHINode>(User)) + if (auto *PN = dyn_cast<PHINode>(User)) UserBB = PN->getIncomingBlock(U); if (InstBB != UserBB && !L->contains(UserBB)) @@ -123,7 +118,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, // DomBB dominates the value, so adjust DomBB to the normal destination // block, which is effectively where the value is first usable. 
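    // (For instance, given "%v = invoke ... to label %normal unwind label
    // %lpad", %v is usable only along the %normal edge, so DomBB below is
    // switched from the defining block to the invoke's normal destination.)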
    BasicBlock *DomBB = InstBB;
-    if (InvokeInst *Inv = dyn_cast<InvokeInst>(I))
+    if (auto *Inv = dyn_cast<InvokeInst>(I))
       DomBB = Inv->getNormalDest();
 
     DomTreeNode *DomNode = DT.getNode(DomBB);
@@ -188,7 +183,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
       // block.
       Instruction *User = cast<Instruction>(UseToRewrite->getUser());
       BasicBlock *UserBB = User->getParent();
-      if (PHINode *PN = dyn_cast<PHINode>(User))
+      if (auto *PN = dyn_cast<PHINode>(User))
         UserBB = PN->getIncomingBlock(*UseToRewrite);
 
       if (isa<PHINode>(UserBB->begin()) && isExitBlock(UserBB, ExitBlocks)) {
@@ -213,13 +208,9 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
 
     // Post process PHI instructions that were inserted into another disjoint
     // loop and update their exits properly.
-    for (auto *PostProcessPN : PostProcessPHIs) {
-      if (PostProcessPN->use_empty())
-        continue;
-
-      // Reprocess each PHI instruction.
-      Worklist.push_back(PostProcessPN);
-    }
+    for (auto *PostProcessPN : PostProcessPHIs)
+      if (!PostProcessPN->use_empty())
+        Worklist.push_back(PostProcessPN);
 
     // Keep track of PHI nodes that we want to remove because they did not have
     // any uses rewritten.
@@ -237,40 +228,75 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist,
   return Changed;
 }
 
-/// Return true if the specified block dominates at least
-/// one of the blocks in the specified list.
-static bool
-blockDominatesAnExit(BasicBlock *BB,
-                     DominatorTree &DT,
-                     const SmallVectorImpl<BasicBlock *> &ExitBlocks) {
-  DomTreeNode *DomNode = DT.getNode(BB);
-  return any_of(ExitBlocks, [&](BasicBlock *EB) {
-    return DT.dominates(DomNode, DT.getNode(EB));
-  });
+// Compute the set of BasicBlocks in the loop `L` dominating at least one exit.
+static void computeBlocksDominatingExits(
+    Loop &L, DominatorTree &DT, SmallVector<BasicBlock *, 8> &ExitBlocks,
+    SmallSetVector<BasicBlock *, 8> &BlocksDominatingExits) {
+  SmallVector<BasicBlock *, 8> BBWorklist;
+
+  // We start from the exit blocks, as every block trivially dominates itself
+  // (not strictly).
+  for (BasicBlock *BB : ExitBlocks)
+    BBWorklist.push_back(BB);
+
+  while (!BBWorklist.empty()) {
+    BasicBlock *BB = BBWorklist.pop_back_val();
+
+    // Check if this is a loop header. If this is the case, we're done.
+    if (L.getHeader() == BB)
+      continue;
+
+    // Otherwise, add its immediate predecessor in the dominator tree to the
+    // worklist, unless we visited it already.
+    BasicBlock *IDomBB = DT.getNode(BB)->getIDom()->getBlock();
+
+    // Exit blocks can have an immediate dominator not belonging to the
+    // loop. For an exit block to be immediately dominated by another block
+    // outside the loop, it implies that not all paths from that dominator to
+    // the exit block go through the loop.
+    // Example:
+    //
+    // |---- A
+    // |    |
+    // |    B<--
+    // |    |  |
+    // |---> C --
+    //       |
+    //       D
+    //
+    // C is the exit block of the loop and it's immediately dominated by A,
+    // which doesn't belong to the loop.
+    if (!L.contains(IDomBB))
+      continue;
+
+    if (BlocksDominatingExits.insert(IDomBB))
+      BBWorklist.push_back(IDomBB);
+  }
 }
 
 bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI,
                      ScalarEvolution *SE) {
   bool Changed = false;
 
-  // Get the set of exiting blocks.
   SmallVector<BasicBlock *, 8> ExitBlocks;
   L.getExitBlocks(ExitBlocks);
-
   if (ExitBlocks.empty())
     return false;
 
+  SmallSetVector<BasicBlock *, 8> BlocksDominatingExits;
+
+  // We want to avoid use-scanning by leveraging dominance information.
+ // If a block doesn't dominate any of the loop exits, the none of the values + // defined in the loop can be used outside. + // We compute the set of blocks fullfilling the conditions in advance + // walking the dominator tree upwards until we hit a loop header. + computeBlocksDominatingExits(L, DT, ExitBlocks, BlocksDominatingExits); + SmallVector<Instruction *, 8> Worklist; // Look at all the instructions in the loop, checking to see if they have uses // outside the loop. If so, put them into the worklist to rewrite those uses. - for (BasicBlock *BB : L.blocks()) { - // For large loops, avoid use-scanning by using dominance information: In - // particular, if a block does not dominate any of the loop exits, then none - // of the values defined in the block could be used outside the loop. - if (!blockDominatesAnExit(BB, DT, ExitBlocks)) - continue; - + for (BasicBlock *BB : BlocksDominatingExits) { for (Instruction &I : *BB) { // Reject two common cases fast: instructions with no uses (like stores) // and instructions with one use that is in the same block as this. @@ -279,6 +305,13 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI, !isa<PHINode>(I.user_back()))) continue; + // Tokens cannot be used in PHI nodes, so we skip over them. + // We can run into tokens which are live out of a loop with catchswitch + // instructions in Windows EH if the catchswitch has one catchpad which + // is inside the loop and another which is not. + if (I.getType()->isTokenTy()) + continue; + Worklist.push_back(&I); } } @@ -395,8 +428,8 @@ PreservedAnalyses LCSSAPass::run(Function &F, FunctionAnalysisManager &AM) { if (!formLCSSAOnAllLoops(&LI, DT, SE)) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'. PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); PA.preserve<BasicAA>(); PA.preserve<GlobalsAA>(); PA.preserve<SCEVAA>(); diff --git a/contrib/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/contrib/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index d97cd75..42aca75 100644 --- a/contrib/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" @@ -48,16 +49,6 @@ using namespace llvm; STATISTIC(NumWrappedOneCond, "Number of One-Condition Wrappers Inserted"); STATISTIC(NumWrappedTwoCond, "Number of Two-Condition Wrappers Inserted"); -static cl::opt<bool> LibCallsShrinkWrapDoDomainError( - "libcalls-shrinkwrap-domain-error", cl::init(true), cl::Hidden, - cl::desc("Perform shrink-wrap on lib calls with domain errors")); -static cl::opt<bool> LibCallsShrinkWrapDoRangeError( - "libcalls-shrinkwrap-range-error", cl::init(true), cl::Hidden, - cl::desc("Perform shrink-wrap on lib calls with range errors")); -static cl::opt<bool> LibCallsShrinkWrapDoPoleError( - "libcalls-shrinkwrap-pole-error", cl::init(true), cl::Hidden, - cl::desc("Perform shrink-wrap on lib calls with pole errors")); - namespace { class LibCallsShrinkWrapLegacyPass : public FunctionPass { public: @@ -82,10 +73,11 @@ INITIALIZE_PASS_END(LibCallsShrinkWrapLegacyPass, "libcalls-shrinkwrap", namespace { class LibCallsShrinkWrap : public InstVisitor<LibCallsShrinkWrap> { public: - LibCallsShrinkWrap(const TargetLibraryInfo &TLI) : TLI(TLI), Changed(false){}; - bool isChanged() const { return 
Changed; } + LibCallsShrinkWrap(const TargetLibraryInfo &TLI, DominatorTree *DT) + : TLI(TLI), DT(DT){}; void visitCallInst(CallInst &CI) { checkCandidate(CI); } - void perform() { + bool perform() { + bool Changed = false; for (auto &CI : WorkList) { DEBUG(dbgs() << "CDCE calls: " << CI->getCalledFunction()->getName() << "\n"); @@ -94,18 +86,19 @@ public: DEBUG(dbgs() << "Transformed\n"); } } + return Changed; } private: bool perform(CallInst *CI); void checkCandidate(CallInst &CI); void shrinkWrapCI(CallInst *CI, Value *Cond); - bool performCallDomainErrorOnly(CallInst *CI, const LibFunc::Func &Func); - bool performCallErrors(CallInst *CI, const LibFunc::Func &Func); - bool performCallRangeErrorOnly(CallInst *CI, const LibFunc::Func &Func); - Value *generateOneRangeCond(CallInst *CI, const LibFunc::Func &Func); - Value *generateTwoRangeCond(CallInst *CI, const LibFunc::Func &Func); - Value *generateCondForPow(CallInst *CI, const LibFunc::Func &Func); + bool performCallDomainErrorOnly(CallInst *CI, const LibFunc &Func); + bool performCallErrors(CallInst *CI, const LibFunc &Func); + bool performCallRangeErrorOnly(CallInst *CI, const LibFunc &Func); + Value *generateOneRangeCond(CallInst *CI, const LibFunc &Func); + Value *generateTwoRangeCond(CallInst *CI, const LibFunc &Func); + Value *generateCondForPow(CallInst *CI, const LibFunc &Func); // Create an OR of two conditions. Value *createOrCond(CallInst *CI, CmpInst::Predicate Cmp, float Val, @@ -134,51 +127,51 @@ private: } const TargetLibraryInfo &TLI; + DominatorTree *DT; SmallVector<CallInst *, 16> WorkList; - bool Changed; }; } // end anonymous namespace // Perform the transformation to calls with errno set by domain error. bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI, - const LibFunc::Func &Func) { + const LibFunc &Func) { Value *Cond = nullptr; switch (Func) { - case LibFunc::acos: // DomainError: (x < -1 || x > 1) - case LibFunc::acosf: // Same as acos - case LibFunc::acosl: // Same as acos - case LibFunc::asin: // DomainError: (x < -1 || x > 1) - case LibFunc::asinf: // Same as asin - case LibFunc::asinl: // Same as asin + case LibFunc_acos: // DomainError: (x < -1 || x > 1) + case LibFunc_acosf: // Same as acos + case LibFunc_acosl: // Same as acos + case LibFunc_asin: // DomainError: (x < -1 || x > 1) + case LibFunc_asinf: // Same as asin + case LibFunc_asinl: // Same as asin { ++NumWrappedTwoCond; Cond = createOrCond(CI, CmpInst::FCMP_OLT, -1.0f, CmpInst::FCMP_OGT, 1.0f); break; } - case LibFunc::cos: // DomainError: (x == +inf || x == -inf) - case LibFunc::cosf: // Same as cos - case LibFunc::cosl: // Same as cos - case LibFunc::sin: // DomainError: (x == +inf || x == -inf) - case LibFunc::sinf: // Same as sin - case LibFunc::sinl: // Same as sin + case LibFunc_cos: // DomainError: (x == +inf || x == -inf) + case LibFunc_cosf: // Same as cos + case LibFunc_cosl: // Same as cos + case LibFunc_sin: // DomainError: (x == +inf || x == -inf) + case LibFunc_sinf: // Same as sin + case LibFunc_sinl: // Same as sin { ++NumWrappedTwoCond; Cond = createOrCond(CI, CmpInst::FCMP_OEQ, INFINITY, CmpInst::FCMP_OEQ, -INFINITY); break; } - case LibFunc::acosh: // DomainError: (x < 1) - case LibFunc::acoshf: // Same as acosh - case LibFunc::acoshl: // Same as acosh + case LibFunc_acosh: // DomainError: (x < 1) + case LibFunc_acoshf: // Same as acosh + case LibFunc_acoshl: // Same as acosh { ++NumWrappedOneCond; Cond = createCond(CI, CmpInst::FCMP_OLT, 1.0f); break; } - case LibFunc::sqrt: // DomainError: (x < 0) - case 
LibFunc::sqrtf: // Same as sqrt - case LibFunc::sqrtl: // Same as sqrt + case LibFunc_sqrt: // DomainError: (x < 0) + case LibFunc_sqrtf: // Same as sqrt + case LibFunc_sqrtl: // Same as sqrt { ++NumWrappedOneCond; Cond = createCond(CI, CmpInst::FCMP_OLT, 0.0f); @@ -193,31 +186,31 @@ bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI, // Perform the transformation to calls with errno set by range error. bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI, - const LibFunc::Func &Func) { + const LibFunc &Func) { Value *Cond = nullptr; switch (Func) { - case LibFunc::cosh: - case LibFunc::coshf: - case LibFunc::coshl: - case LibFunc::exp: - case LibFunc::expf: - case LibFunc::expl: - case LibFunc::exp10: - case LibFunc::exp10f: - case LibFunc::exp10l: - case LibFunc::exp2: - case LibFunc::exp2f: - case LibFunc::exp2l: - case LibFunc::sinh: - case LibFunc::sinhf: - case LibFunc::sinhl: { + case LibFunc_cosh: + case LibFunc_coshf: + case LibFunc_coshl: + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_expl: + case LibFunc_exp10: + case LibFunc_exp10f: + case LibFunc_exp10l: + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2l: + case LibFunc_sinh: + case LibFunc_sinhf: + case LibFunc_sinhl: { Cond = generateTwoRangeCond(CI, Func); break; } - case LibFunc::expm1: // RangeError: (709, inf) - case LibFunc::expm1f: // RangeError: (88, inf) - case LibFunc::expm1l: // RangeError: (11356, inf) + case LibFunc_expm1: // RangeError: (709, inf) + case LibFunc_expm1f: // RangeError: (88, inf) + case LibFunc_expm1l: // RangeError: (11356, inf) { Cond = generateOneRangeCond(CI, Func); break; @@ -231,63 +224,54 @@ bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI, // Perform the transformation to calls with errno set by combination of errors. 
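For illustration, a minimal hypothetical C++ analogue (not part of the patch) of what the shrink-wrapped code computes: the pass guards a dead, errno-setting libcall with the exact error condition, here acos's DomainError test (x < -1 || x > 1) from performCallDomainErrorOnly above. The function name is made up.

#include <cmath>
// The call's result is unused, so it matters only for its errno side effect
// and needs to execute only when the domain error can actually occur.
void acosForErrnoOnly(double X) {
  if (X < -1.0 || X > 1.0)  // becomes the "cdce.call" block
    (void)std::acos(X);     // may set errno (EDOM) on the guarded path
  // the common case falls through to "cdce.end" without calling acos at all
}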
bool LibCallsShrinkWrap::performCallErrors(CallInst *CI, - const LibFunc::Func &Func) { + const LibFunc &Func) { Value *Cond = nullptr; switch (Func) { - case LibFunc::atanh: // DomainError: (x < -1 || x > 1) + case LibFunc_atanh: // DomainError: (x < -1 || x > 1) // PoleError: (x == -1 || x == 1) // Overall Cond: (x <= -1 || x >= 1) - case LibFunc::atanhf: // Same as atanh - case LibFunc::atanhl: // Same as atanh + case LibFunc_atanhf: // Same as atanh + case LibFunc_atanhl: // Same as atanh { - if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError) - return false; ++NumWrappedTwoCond; Cond = createOrCond(CI, CmpInst::FCMP_OLE, -1.0f, CmpInst::FCMP_OGE, 1.0f); break; } - case LibFunc::log: // DomainError: (x < 0) + case LibFunc_log: // DomainError: (x < 0) // PoleError: (x == 0) // Overall Cond: (x <= 0) - case LibFunc::logf: // Same as log - case LibFunc::logl: // Same as log - case LibFunc::log10: // Same as log - case LibFunc::log10f: // Same as log - case LibFunc::log10l: // Same as log - case LibFunc::log2: // Same as log - case LibFunc::log2f: // Same as log - case LibFunc::log2l: // Same as log - case LibFunc::logb: // Same as log - case LibFunc::logbf: // Same as log - case LibFunc::logbl: // Same as log + case LibFunc_logf: // Same as log + case LibFunc_logl: // Same as log + case LibFunc_log10: // Same as log + case LibFunc_log10f: // Same as log + case LibFunc_log10l: // Same as log + case LibFunc_log2: // Same as log + case LibFunc_log2f: // Same as log + case LibFunc_log2l: // Same as log + case LibFunc_logb: // Same as log + case LibFunc_logbf: // Same as log + case LibFunc_logbl: // Same as log { - if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError) - return false; ++NumWrappedOneCond; Cond = createCond(CI, CmpInst::FCMP_OLE, 0.0f); break; } - case LibFunc::log1p: // DomainError: (x < -1) + case LibFunc_log1p: // DomainError: (x < -1) // PoleError: (x == -1) // Overall Cond: (x <= -1) - case LibFunc::log1pf: // Same as log1p - case LibFunc::log1pl: // Same as log1p + case LibFunc_log1pf: // Same as log1p + case LibFunc_log1pl: // Same as log1p { - if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError) - return false; ++NumWrappedOneCond; Cond = createCond(CI, CmpInst::FCMP_OLE, -1.0f); break; } - case LibFunc::pow: // DomainError: x < 0 and y is noninteger + case LibFunc_pow: // DomainError: x < 0 and y is noninteger // PoleError: x == 0 and y < 0 // RangeError: overflow or underflow - case LibFunc::powf: - case LibFunc::powl: { - if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError || - !LibCallsShrinkWrapDoRangeError) - return false; + case LibFunc_powf: + case LibFunc_powl: { Cond = generateCondForPow(CI, Func); if (Cond == nullptr) return false; @@ -313,7 +297,7 @@ void LibCallsShrinkWrap::checkCandidate(CallInst &CI) { if (!CI.use_empty()) return; - LibFunc::Func Func; + LibFunc Func; Function *Callee = CI.getCalledFunction(); if (!Callee) return; @@ -333,20 +317,20 @@ void LibCallsShrinkWrap::checkCandidate(CallInst &CI) { // Generate the upper bound condition for RangeError. 
Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI, - const LibFunc::Func &Func) { + const LibFunc &Func) { float UpperBound; switch (Func) { - case LibFunc::expm1: // RangeError: (709, inf) + case LibFunc_expm1: // RangeError: (709, inf) UpperBound = 709.0f; break; - case LibFunc::expm1f: // RangeError: (88, inf) + case LibFunc_expm1f: // RangeError: (88, inf) UpperBound = 88.0f; break; - case LibFunc::expm1l: // RangeError: (11356, inf) + case LibFunc_expm1l: // RangeError: (11356, inf) UpperBound = 11356.0f; break; default: - llvm_unreachable("Should be reach here"); + llvm_unreachable("Unhandled library call!"); } ++NumWrappedOneCond; @@ -355,62 +339,62 @@ Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI, // Generate the lower and upper bound condition for RangeError. Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI, - const LibFunc::Func &Func) { + const LibFunc &Func) { float UpperBound, LowerBound; switch (Func) { - case LibFunc::cosh: // RangeError: (x < -710 || x > 710) - case LibFunc::sinh: // Same as cosh + case LibFunc_cosh: // RangeError: (x < -710 || x > 710) + case LibFunc_sinh: // Same as cosh LowerBound = -710.0f; UpperBound = 710.0f; break; - case LibFunc::coshf: // RangeError: (x < -89 || x > 89) - case LibFunc::sinhf: // Same as coshf + case LibFunc_coshf: // RangeError: (x < -89 || x > 89) + case LibFunc_sinhf: // Same as coshf LowerBound = -89.0f; UpperBound = 89.0f; break; - case LibFunc::coshl: // RangeError: (x < -11357 || x > 11357) - case LibFunc::sinhl: // Same as coshl + case LibFunc_coshl: // RangeError: (x < -11357 || x > 11357) + case LibFunc_sinhl: // Same as coshl LowerBound = -11357.0f; UpperBound = 11357.0f; break; - case LibFunc::exp: // RangeError: (x < -745 || x > 709) + case LibFunc_exp: // RangeError: (x < -745 || x > 709) LowerBound = -745.0f; UpperBound = 709.0f; break; - case LibFunc::expf: // RangeError: (x < -103 || x > 88) + case LibFunc_expf: // RangeError: (x < -103 || x > 88) LowerBound = -103.0f; UpperBound = 88.0f; break; - case LibFunc::expl: // RangeError: (x < -11399 || x > 11356) + case LibFunc_expl: // RangeError: (x < -11399 || x > 11356) LowerBound = -11399.0f; UpperBound = 11356.0f; break; - case LibFunc::exp10: // RangeError: (x < -323 || x > 308) + case LibFunc_exp10: // RangeError: (x < -323 || x > 308) LowerBound = -323.0f; UpperBound = 308.0f; break; - case LibFunc::exp10f: // RangeError: (x < -45 || x > 38) + case LibFunc_exp10f: // RangeError: (x < -45 || x > 38) LowerBound = -45.0f; UpperBound = 38.0f; break; - case LibFunc::exp10l: // RangeError: (x < -4950 || x > 4932) + case LibFunc_exp10l: // RangeError: (x < -4950 || x > 4932) LowerBound = -4950.0f; UpperBound = 4932.0f; break; - case LibFunc::exp2: // RangeError: (x < -1074 || x > 1023) + case LibFunc_exp2: // RangeError: (x < -1074 || x > 1023) LowerBound = -1074.0f; UpperBound = 1023.0f; break; - case LibFunc::exp2f: // RangeError: (x < -149 || x > 127) + case LibFunc_exp2f: // RangeError: (x < -149 || x > 127) LowerBound = -149.0f; UpperBound = 127.0f; break; - case LibFunc::exp2l: // RangeError: (x < -16445 || x > 11383) + case LibFunc_exp2l: // RangeError: (x < -16445 || x > 11383) LowerBound = -16445.0f; UpperBound = 11383.0f; break; default: - llvm_unreachable("Should be reach here"); + llvm_unreachable("Unhandled library call!"); } ++NumWrappedTwoCond; @@ -434,9 +418,9 @@ Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI, // (i.e. we might invoke the calls that will not set the errno.). 
// Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, - const LibFunc::Func &Func) { - // FIXME: LibFunc::powf and powl TBD. - if (Func != LibFunc::pow) { + const LibFunc &Func) { + // FIXME: LibFunc_powf and powl TBD. + if (Func != LibFunc_pow) { DEBUG(dbgs() << "Not handled powf() and powl()\n"); return nullptr; } @@ -499,14 +483,17 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, // Wrap conditions that can potentially generate errno to the library call. void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { - assert(Cond != nullptr && "hrinkWrapCI is not expecting an empty call inst"); + assert(Cond != nullptr && "ShrinkWrapCI is not expecting an empty call inst"); MDNode *BranchWeights = MDBuilder(CI->getContext()).createBranchWeights(1, 2000); + TerminatorInst *NewInst = - SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights); + SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, DT); BasicBlock *CallBB = NewInst->getParent(); CallBB->setName("cdce.call"); - CallBB->getSingleSuccessor()->setName("cdce.end"); + BasicBlock *SuccBB = CallBB->getSingleSuccessor(); + assert(SuccBB && "The split block should have a single successor"); + SuccBB->setName("cdce.end"); CI->removeFromParent(); CallBB->getInstList().insert(CallBB->getFirstInsertionPt(), CI); DEBUG(dbgs() << "== Basic Block After =="); @@ -516,38 +503,44 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { // Perform the transformation to a single candidate. bool LibCallsShrinkWrap::perform(CallInst *CI) { - LibFunc::Func Func; + LibFunc Func; Function *Callee = CI->getCalledFunction(); assert(Callee && "perform() should apply to a non-empty callee"); TLI.getLibFunc(*Callee, Func); assert(Func && "perform() is not expecting an empty function"); - if (LibCallsShrinkWrapDoDomainError && performCallDomainErrorOnly(CI, Func)) - return true; - - if (LibCallsShrinkWrapDoRangeError && performCallRangeErrorOnly(CI, Func)) + if (performCallDomainErrorOnly(CI, Func) || performCallRangeErrorOnly(CI, Func)) return true; - return performCallErrors(CI, Func); } void LibCallsShrinkWrapLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } -static bool runImpl(Function &F, const TargetLibraryInfo &TLI) { +static bool runImpl(Function &F, const TargetLibraryInfo &TLI, + DominatorTree *DT) { if (F.hasFnAttribute(Attribute::OptimizeForSize)) return false; - LibCallsShrinkWrap CCDCE(TLI); + LibCallsShrinkWrap CCDCE(TLI, DT); CCDCE.visit(F); - CCDCE.perform(); - return CCDCE.isChanged(); + bool Changed = CCDCE.perform(); + +// Verify the dominator after we've updated it locally. +#ifndef NDEBUG + if (DT) + DT->verifyDomTree(); +#endif + return Changed; } bool LibCallsShrinkWrapLegacyPass::runOnFunction(Function &F) { auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - return runImpl(F, TLI); + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; + return runImpl(F, TLI, DT); } namespace llvm { @@ -561,11 +554,12 @@ FunctionPass *createLibCallsShrinkWrapPass() { PreservedAnalyses LibCallsShrinkWrapPass::run(Function &F, FunctionAnalysisManager &FAM) { auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); - bool Changed = runImpl(F, TLI); - if (!Changed) + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); + if (!runImpl(F, TLI, DT)) return PreservedAnalyses::all(); auto PA = PreservedAnalyses(); PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); return PA; } } diff --git a/contrib/llvm/lib/Transforms/Utils/Local.cpp b/contrib/llvm/lib/Transforms/Utils/Local.cpp index 6e4174a..7461061 100644 --- a/contrib/llvm/lib/Transforms/Utils/Local.cpp +++ b/contrib/llvm/lib/Transforms/Utils/Local.cpp @@ -22,10 +22,11 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" @@ -45,6 +46,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -126,21 +128,20 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // If the default is unreachable, ignore it when searching for TheOnlyDest. if (isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()) && SI->getNumCases() > 0) { - TheOnlyDest = SI->case_begin().getCaseSuccessor(); + TheOnlyDest = SI->case_begin()->getCaseSuccessor(); } // Figure out which case it goes to. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { + for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) { // Found case matching a constant operand? - if (i.getCaseValue() == CI) { - TheOnlyDest = i.getCaseSuccessor(); + if (i->getCaseValue() == CI) { + TheOnlyDest = i->getCaseSuccessor(); break; } // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. - if (i.getCaseSuccessor() == DefaultDest) { + if (i->getCaseSuccessor() == DefaultDest) { MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches @@ -154,7 +155,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, Weights.push_back(CI->getValue().getZExtValue()); } // Merge weight of this case to the default weight. - unsigned idx = i.getCaseIndex(); + unsigned idx = i->getCaseIndex(); Weights[0] += Weights[idx+1]; // Remove weight for this case. std::swap(Weights[idx+1], Weights.back()); @@ -165,15 +166,19 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, } // Remove this entry. DefaultDest->removePredecessor(SI->getParent()); - SI->removeCase(i); - --i; --e; + i = SI->removeCase(i); + e = SI->case_end(); continue; } // Otherwise, check to see if the switch only branches to one destination. // We do this by reseting "TheOnlyDest" to null when we find two non-equal // destinations. 
- if (i.getCaseSuccessor() != TheOnlyDest) TheOnlyDest = nullptr; + if (i->getCaseSuccessor() != TheOnlyDest) + TheOnlyDest = nullptr; + + // Increment this iterator as we haven't removed the case. + ++i; } if (CI && !TheOnlyDest) { @@ -209,7 +214,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, if (SI->getNumCases() == 1) { // Otherwise, we can fold this switch into a conditional branch // instruction if it has only one non-default destination. - SwitchInst::CaseIt FirstCase = SI->case_begin(); + auto FirstCase = *SI->case_begin(); Value *Cond = Builder.CreateICmpEQ(SI->getCondition(), FirstCase.getCaseValue(), "cond"); @@ -287,7 +292,15 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, /// bool llvm::isInstructionTriviallyDead(Instruction *I, const TargetLibraryInfo *TLI) { - if (!I->use_empty() || isa<TerminatorInst>(I)) return false; + if (!I->use_empty()) + return false; + return wouldInstructionBeTriviallyDead(I, TLI); +} + +bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, + const TargetLibraryInfo *TLI) { + if (isa<TerminatorInst>(I)) + return false; // We don't want the landingpad-like instructions removed by anything this // general. @@ -307,7 +320,8 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, return true; } - if (!I->mayHaveSideEffects()) return true; + if (!I->mayHaveSideEffects()) + return true; // Special case intrinsics that "may have side effects" but can be deleted // when dead. @@ -334,7 +348,8 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, } } - if (isAllocLikeFn(I, TLI)) return true; + if (isAllocLikeFn(I, TLI)) + return true; if (CallInst *CI = isFreeCall(I, TLI)) if (Constant *C = dyn_cast<Constant>(CI->getArgOperand(0))) @@ -548,7 +563,7 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred) { // that can be removed. BB->removePredecessor(Pred, true); - WeakVH PhiIt = &BB->front(); + WeakTrackingVH PhiIt = &BB->front(); while (PHINode *PN = dyn_cast<PHINode>(PhiIt)) { PhiIt = &*++BasicBlock::iterator(cast<Instruction>(PhiIt)); Value *OldPhiIt = PhiIt; @@ -1023,17 +1038,15 @@ unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, const DominatorTree *DT) { assert(V->getType()->isPointerTy() && "getOrEnforceKnownAlignment expects a pointer!"); - unsigned BitWidth = DL.getPointerTypeSizeInBits(V->getType()); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL, 0, AC, CxtI, DT); - unsigned TrailZ = KnownZero.countTrailingOnes(); + KnownBits Known = computeKnownBits(V, DL, 0, AC, CxtI, DT); + unsigned TrailZ = Known.countMinTrailingZeros(); // Avoid trouble with ridiculously large TrailZ values, such as // those computed from a null pointer. TrailZ = std::min(TrailZ, unsigned(sizeof(unsigned) * CHAR_BIT - 1)); - unsigned Align = 1u << std::min(BitWidth - 1, TrailZ); + unsigned Align = 1u << std::min(Known.getBitWidth() - 1, TrailZ); // LLVM doesn't support alignments larger than this currently. Align = std::min(Align, +Value::MaximumAlignment); @@ -1069,17 +1082,17 @@ static bool LdStHasDebugValue(DILocalVariable *DIVar, DIExpression *DIExpr, } /// See if there is a dbg.value intrinsic for DIVar for the PHI node. 
-static bool PhiHasDebugValue(DILocalVariable *DIVar, +static bool PhiHasDebugValue(DILocalVariable *DIVar, DIExpression *DIExpr, PHINode *APN) { // Since we can't guarantee that the original dbg.declare instrinsic // is removed by LowerDbgDeclare(), we need to make sure that we are // not inserting the same dbg.value intrinsic over and over. - DbgValueList DbgValues; - FindAllocaDbgValues(DbgValues, APN); - for (auto DVI : DbgValues) { - assert (DVI->getValue() == APN); - assert (DVI->getOffset() == 0); + SmallVector<DbgValueInst *, 1> DbgValues; + findDbgValues(DbgValues, APN); + for (auto *DVI : DbgValues) { + assert(DVI->getValue() == APN); + assert(DVI->getOffset() == 0); if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr)) return true; } @@ -1091,8 +1104,9 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, StoreInst *SI, DIBuilder &Builder) { auto *DIVar = DDI->getVariable(); - auto *DIExpr = DDI->getExpression(); assert(DIVar && "Missing variable"); + auto *DIExpr = DDI->getExpression(); + Value *DV = SI->getOperand(0); // If an argument is zero extended then use argument directly. The ZExt // may be zapped by an optimization pass in future. @@ -1102,34 +1116,28 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) ExtendedArg = dyn_cast<Argument>(SExt->getOperand(0)); if (ExtendedArg) { - // We're now only describing a subset of the variable. The fragment we're - // describing will always be smaller than the variable size, because - // VariableSize == Size of Alloca described by DDI. Since SI stores - // to the alloca described by DDI, if it's first operand is an extend, - // we're guaranteed that before extension, the value was narrower than - // the size of the alloca, hence the size of the described variable. - SmallVector<uint64_t, 3> Ops; - unsigned FragmentOffset = 0; - // If this already is a bit fragment, we drop the bit fragment from the - // expression and record the offset. - auto Fragment = DIExpr->getFragmentInfo(); - if (Fragment) { - Ops.append(DIExpr->elements_begin(), DIExpr->elements_end()-3); - FragmentOffset = Fragment->OffsetInBits; - } else { - Ops.append(DIExpr->elements_begin(), DIExpr->elements_end()); + // If this DDI was already describing only a fragment of a variable, ensure + // that fragment is appropriately narrowed here. + // But if a fragment wasn't used, describe the value as the original + // argument (rather than the zext or sext) so that it remains described even + // if the sext/zext is optimized away. This widens the variable description, + // leaving it up to the consumer to know how the smaller value may be + // represented in a larger register. 
+ if (auto Fragment = DIExpr->getFragmentInfo()) { + unsigned FragmentOffset = Fragment->OffsetInBits; + SmallVector<uint64_t, 3> Ops(DIExpr->elements_begin(), + DIExpr->elements_end() - 3); + Ops.push_back(dwarf::DW_OP_LLVM_fragment); + Ops.push_back(FragmentOffset); + const DataLayout &DL = DDI->getModule()->getDataLayout(); + Ops.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); + DIExpr = Builder.createExpression(Ops); } - Ops.push_back(dwarf::DW_OP_LLVM_fragment); - Ops.push_back(FragmentOffset); - const DataLayout &DL = DDI->getModule()->getDataLayout(); - Ops.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); - auto NewDIExpr = Builder.createExpression(Ops); - if (!LdStHasDebugValue(DIVar, NewDIExpr, SI)) - Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, NewDIExpr, - DDI->getDebugLoc(), SI); - } else if (!LdStHasDebugValue(DIVar, DIExpr, SI)) - Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, DIExpr, - DDI->getDebugLoc(), SI); + DV = ExtendedArg; + } + if (!LdStHasDebugValue(DIVar, DIExpr, SI)) + Builder.insertDbgValueIntrinsic(DV, 0, DIVar, DIExpr, DDI->getDebugLoc(), + SI); } /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value @@ -1152,7 +1160,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, DbgValue->insertAfter(LI); } -/// Inserts a llvm.dbg.value intrinsic after a phi +/// Inserts a llvm.dbg.value intrinsic after a phi /// that has an associated llvm.dbg.decl intrinsic. void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, PHINode *APN, DIBuilder &Builder) { @@ -1214,13 +1222,9 @@ bool llvm::LowerDbgDeclare(Function &F) { // This is a call by-value or some other instruction that // takes a pointer to the variable. Insert a *value* // intrinsic that describes the alloca. - SmallVector<uint64_t, 1> NewDIExpr; - auto *DIExpr = DDI->getExpression(); - NewDIExpr.push_back(dwarf::DW_OP_deref); - NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end()); DIB.insertDbgValueIntrinsic(AI, 0, DDI->getVariable(), - DIB.createExpression(NewDIExpr), - DDI->getDebugLoc(), CI); + DDI->getExpression(), DDI->getDebugLoc(), + CI); } } DDI->eraseFromParent(); @@ -1241,9 +1245,7 @@ DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) { return nullptr; } -/// FindAllocaDbgValues - Finds the llvm.dbg.value intrinsics describing the -/// alloca 'V', if any. 
-void llvm::FindAllocaDbgValues(DbgValueList &DbgValues, Value *V) { +void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) { if (auto *L = LocalAsMetadata::getIfExists(V)) if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) for (User *U : MDV->users()) @@ -1251,37 +1253,6 @@ void llvm::FindAllocaDbgValues(DbgValueList &DbgValues, Value *V) { DbgValues.push_back(DVI); } -static void DIExprAddDeref(SmallVectorImpl<uint64_t> &Expr) { - Expr.push_back(dwarf::DW_OP_deref); -} - -static void DIExprAddOffset(SmallVectorImpl<uint64_t> &Expr, int Offset) { - if (Offset > 0) { - Expr.push_back(dwarf::DW_OP_plus); - Expr.push_back(Offset); - } else if (Offset < 0) { - Expr.push_back(dwarf::DW_OP_minus); - Expr.push_back(-Offset); - } -} - -static DIExpression *BuildReplacementDIExpr(DIBuilder &Builder, - DIExpression *DIExpr, bool Deref, - int Offset) { - if (!Deref && !Offset) - return DIExpr; - // Create a copy of the original DIDescriptor for user variable, prepending - // "deref" operation to a list of address elements, as new llvm.dbg.declare - // will take a value storing address of the memory for variable, not - // alloca itself. - SmallVector<uint64_t, 4> NewDIExpr; - if (Deref) - DIExprAddDeref(NewDIExpr); - DIExprAddOffset(NewDIExpr, Offset); - if (DIExpr) - NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end()); - return Builder.createExpression(NewDIExpr); -} bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, Instruction *InsertBefore, DIBuilder &Builder, @@ -1293,9 +1264,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, auto *DIVar = DDI->getVariable(); auto *DIExpr = DDI->getExpression(); assert(DIVar && "Missing variable"); - - DIExpr = BuildReplacementDIExpr(Builder, DIExpr, Deref, Offset); - + DIExpr = DIExpression::prepend(DIExpr, Deref, Offset); // Insert llvm.dbg.declare immediately after the original alloca, and remove // old llvm.dbg.declare. Builder.insertDeclare(NewAddress, DIVar, DIExpr, Loc, InsertBefore); @@ -1326,11 +1295,11 @@ static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, // Insert the offset immediately after the first deref. // We could just change the offset argument of dbg.value, but it's unsigned... if (Offset) { - SmallVector<uint64_t, 4> NewDIExpr; - DIExprAddDeref(NewDIExpr); - DIExprAddOffset(NewDIExpr, Offset); - NewDIExpr.append(DIExpr->elements_begin() + 1, DIExpr->elements_end()); - DIExpr = Builder.createExpression(NewDIExpr); + SmallVector<uint64_t, 4> Ops; + Ops.push_back(dwarf::DW_OP_deref); + DIExpression::appendOffset(Ops, Offset); + Ops.append(DIExpr->elements_begin() + 1, DIExpr->elements_end()); + DIExpr = Builder.createExpression(Ops); } Builder.insertDbgValueIntrinsic(NewAddress, DVI->getOffset(), DIVar, DIExpr, @@ -1349,6 +1318,57 @@ void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress, } } +void llvm::salvageDebugInfo(Instruction &I) { + SmallVector<DbgValueInst *, 1> DbgValues; + auto &M = *I.getModule(); + + auto MDWrap = [&](Value *V) { + return MetadataAsValue::get(I.getContext(), ValueAsMetadata::get(V)); + }; + + if (isa<BitCastInst>(&I)) { + findDbgValues(DbgValues, &I); + for (auto *DVI : DbgValues) { + // Bitcasts are entirely irrelevant for debug info. Rewrite the dbg.value + // to use the cast's source. 
+ DVI->setOperand(0, MDWrap(I.getOperand(0))); + DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); + } + } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { + findDbgValues(DbgValues, &I); + for (auto *DVI : DbgValues) { + unsigned BitWidth = + M.getDataLayout().getPointerSizeInBits(GEP->getPointerAddressSpace()); + APInt Offset(BitWidth, 0); + // Rewrite a constant GEP into a DIExpression. Since we are performing + // arithmetic to compute the variable's *value* in the DIExpression, we + // need to mark the expression with a DW_OP_stack_value. + if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) { + auto *DIExpr = DVI->getExpression(); + DIBuilder DIB(M, /*AllowUnresolved*/ false); + // GEP offsets are i32 and thus always fit into an int64_t. + DIExpr = DIExpression::prepend(DIExpr, DIExpression::NoDeref, + Offset.getSExtValue(), + DIExpression::WithStackValue); + DVI->setOperand(0, MDWrap(I.getOperand(0))); + DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr)); + DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); + } + } + } else if (isa<LoadInst>(&I)) { + findDbgValues(DbgValues, &I); + for (auto *DVI : DbgValues) { + // Rewrite the load into DW_OP_deref. + auto *DIExpr = DVI->getExpression(); + DIBuilder DIB(M, /*AllowUnresolved*/ false); + DIExpr = DIExpression::prepend(DIExpr, DIExpression::WithDeref); + DVI->setOperand(0, MDWrap(I.getOperand(0))); + DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr)); + DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); + } + } +} + unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { unsigned NumDeadInst = 0; // Delete the instructions backwards, as it has a reduced likelihood of @@ -1450,7 +1470,7 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, II->setAttributes(CI->getAttributes()); // Make sure that anything using the call now uses the invoke! This also - // updates the CallGraph if present, because it uses a WeakVH. + // updates the CallGraph if present, because it uses a WeakTrackingVH. CI->replaceAllUsesWith(II); // Delete the original call @@ -1642,9 +1662,10 @@ void llvm::removeUnwindEdge(BasicBlock *BB) { TI->eraseFromParent(); } -/// removeUnreachableBlocksFromFn - Remove blocks that are not reachable, even +/// removeUnreachableBlocks - Remove blocks that are not reachable, even /// if they are in a dead cycle. Return true if a change was made, false -/// otherwise. +/// otherwise. If `LVI` is passed, this function preserves LazyValueInfo +/// after modifying the CFG. bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI) { SmallPtrSet<BasicBlock*, 16> Reachable; bool Changed = markAliveBlocks(F, Reachable); @@ -1723,12 +1744,12 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, // Preserve !invariant.group in K. 
break; case LLVMContext::MD_align: - K->setMetadata(Kind, + K->setMetadata(Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); break; case LLVMContext::MD_dereferenceable: case LLVMContext::MD_dereferenceable_or_null: - K->setMetadata(Kind, + K->setMetadata(Kind, MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD)); break; } @@ -1755,46 +1776,62 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J) { combineMetadata(K, J, KnownIDs); } -unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, - DominatorTree &DT, - const BasicBlockEdge &Root) { +template <typename RootType, typename DominatesFn> +static unsigned replaceDominatedUsesWith(Value *From, Value *To, + const RootType &Root, + const DominatesFn &Dominates) { assert(From->getType() == To->getType()); - + unsigned Count = 0; for (Value::use_iterator UI = From->use_begin(), UE = From->use_end(); - UI != UE; ) { + UI != UE;) { Use &U = *UI++; - if (DT.dominates(Root, U)) { - U.set(To); - DEBUG(dbgs() << "Replace dominated use of '" - << From->getName() << "' as " - << *To << " in " << *U << "\n"); - ++Count; - } + if (!Dominates(Root, U)) + continue; + U.set(To); + DEBUG(dbgs() << "Replace dominated use of '" << From->getName() << "' as " + << *To << " in " << *U << "\n"); + ++Count; } return Count; } -unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, - DominatorTree &DT, - const BasicBlock *BB) { - assert(From->getType() == To->getType()); +unsigned llvm::replaceNonLocalUsesWith(Instruction *From, Value *To) { + assert(From->getType() == To->getType()); + auto *BB = From->getParent(); + unsigned Count = 0; - unsigned Count = 0; for (Value::use_iterator UI = From->use_begin(), UE = From->use_end(); UI != UE;) { Use &U = *UI++; auto *I = cast<Instruction>(U.getUser()); - if (DT.properlyDominates(BB, I->getParent())) { - U.set(To); - DEBUG(dbgs() << "Replace dominated use of '" << From->getName() << "' as " - << *To << " in " << *U << "\n"); - ++Count; - } + if (I->getParent() == BB) + continue; + U.set(To); + ++Count; } return Count; } +unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, + DominatorTree &DT, + const BasicBlockEdge &Root) { + auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) { + return DT.dominates(Root, U); + }; + return ::replaceDominatedUsesWith(From, To, Root, Dominates); +} + +unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, + DominatorTree &DT, + const BasicBlock *BB) { + auto ProperlyDominates = [&DT](const BasicBlock *BB, const Use &U) { + auto *I = cast<Instruction>(U.getUser())->getParent(); + return DT.properlyDominates(BB, I); + }; + return ::replaceDominatedUsesWith(From, To, BB, ProperlyDominates); +} + bool llvm::callsGCLeafFunction(ImmutableCallSite CS) { // Check if the function is specifically marked as a gc leaf function. if (CS.hasFnAttr("gc-leaf-function")) @@ -1812,6 +1849,49 @@ bool llvm::callsGCLeafFunction(ImmutableCallSite CS) { return false; } +void llvm::copyNonnullMetadata(const LoadInst &OldLI, MDNode *N, + LoadInst &NewLI) { + auto *NewTy = NewLI.getType(); + + // This only directly applies if the new type is also a pointer. + if (NewTy->isPointerTy()) { + NewLI.setMetadata(LLVMContext::MD_nonnull, N); + return; + } + + // The only other translation we can do is to integral loads with !range + // metadata. 
+ if (!NewTy->isIntegerTy()) + return; + + MDBuilder MDB(NewLI.getContext()); + const Value *Ptr = OldLI.getPointerOperand(); + auto *ITy = cast<IntegerType>(NewTy); + auto *NullInt = ConstantExpr::getPtrToInt( + ConstantPointerNull::get(cast<PointerType>(Ptr->getType())), ITy); + auto *NonNullInt = ConstantExpr::getAdd(NullInt, ConstantInt::get(ITy, 1)); + NewLI.setMetadata(LLVMContext::MD_range, + MDB.createRange(NonNullInt, NullInt)); +} + +void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, + MDNode *N, LoadInst &NewLI) { + auto *NewTy = NewLI.getType(); + + // Give up unless it is converted to a pointer where there is a single very + // valuable mapping we can do reliably. + // FIXME: It would be nice to propagate this in more ways, but the type + // conversions make it hard. + if (!NewTy->isPointerTy()) + return; + + unsigned BitWidth = DL.getTypeSizeInBits(NewTy); + if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { + MDNode *NN = MDNode::get(OldLI.getContext(), None); + NewLI.setMetadata(LLVMContext::MD_nonnull, NN); + } +} + namespace { /// A potential constituent of a bitreverse or bswap expression. See /// collectBitParts for a fuller explanation. @@ -1933,7 +2013,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, unsigned NumMaskedBits = AndMask.countPopulation(); if (!MatchBitReversals && NumMaskedBits % 8 != 0) return Result; - + auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, MatchBitReversals, BPS); if (!Res) @@ -2068,9 +2148,63 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( void llvm::maybeMarkSanitizerLibraryCallNoBuiltin( CallInst *CI, const TargetLibraryInfo *TLI) { Function *F = CI->getCalledFunction(); - LibFunc::Func Func; + LibFunc Func; if (F && !F->hasLocalLinkage() && F->hasName() && TLI->getLibFunc(F->getName(), Func) && TLI->hasOptimizedCodeGen(Func) && !F->doesNotAccessMemory()) - CI->addAttribute(AttributeSet::FunctionIndex, Attribute::NoBuiltin); + CI->addAttribute(AttributeList::FunctionIndex, Attribute::NoBuiltin); +} + +bool llvm::canReplaceOperandWithVariable(const Instruction *I, unsigned OpIdx) { + // We can't have a PHI with a metadata type. + if (I->getOperand(OpIdx)->getType()->isMetadataTy()) + return false; + + // Early exit. + if (!isa<Constant>(I->getOperand(OpIdx))) + return true; + + switch (I->getOpcode()) { + default: + return true; + case Instruction::Call: + case Instruction::Invoke: + // Can't handle inline asm. Skip it. + if (isa<InlineAsm>(ImmutableCallSite(I).getCalledValue())) + return false; + // Many arithmetic intrinsics have no issue taking a + // variable; however, it's hard to distinguish these from + // specials such as @llvm.frameaddress that require a constant. + if (isa<IntrinsicInst>(I)) + return false; + + // Constant bundle operands may need to retain their constant-ness for + // correctness. + if (ImmutableCallSite(I).isBundleOperand(OpIdx)) + return false; + return true; + case Instruction::ShuffleVector: + // Shufflevector masks are constant. + return OpIdx != 2; + case Instruction::Switch: + case Instruction::ExtractValue: + // All operands apart from the first are constant. + return OpIdx == 0; + case Instruction::InsertValue: + // All operands apart from the first and the second are constant. + return OpIdx < 2; + case Instruction::Alloca: + // Static allocas (constant size in the entry block) are handled by + // prologue/epilogue insertion so they're free anyway. We definitely don't + // want to make them non-constant.
+ return !dyn_cast<AllocaInst>(I)->isStaticAlloca(); + case Instruction::GetElementPtr: + if (OpIdx == 0) + return true; + gep_type_iterator It = gep_type_begin(I); + for (auto E = std::next(It, OpIdx); It != E; ++It) + if (It.isStruct()) + return false; + return true; + } } diff --git a/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 00cda2a..e21e34d 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -38,15 +38,14 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/LoopSimplify.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -65,6 +64,7 @@ #include "llvm/IR/Type.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -72,7 +72,6 @@ using namespace llvm; #define DEBUG_TYPE "loop-simplify" -STATISTIC(NumInserted, "Number of pre-header or exit blocks inserted"); STATISTIC(NumNested , "Number of nested loops split out"); // If the block isn't already, move the new block to right after some 'outside @@ -152,37 +151,6 @@ BasicBlock *llvm::InsertPreheaderForLoop(Loop *L, DominatorTree *DT, return PreheaderBB; } -/// \brief Ensure that the loop preheader dominates all exit blocks. -/// -/// This method is used to split exit blocks that have predecessors outside of -/// the loop. -static BasicBlock *rewriteLoopExitBlock(Loop *L, BasicBlock *Exit, - DominatorTree *DT, LoopInfo *LI, - bool PreserveLCSSA) { - SmallVector<BasicBlock*, 8> LoopBlocks; - for (pred_iterator I = pred_begin(Exit), E = pred_end(Exit); I != E; ++I) { - BasicBlock *P = *I; - if (L->contains(P)) { - // Don't do this if the loop is exited via an indirect branch. - if (isa<IndirectBrInst>(P->getTerminator())) return nullptr; - - LoopBlocks.push_back(P); - } - } - - assert(!LoopBlocks.empty() && "No edges coming in from outside the loop?"); - BasicBlock *NewExitBB = nullptr; - - NewExitBB = SplitBlockPredecessors(Exit, LoopBlocks, ".loopexit", DT, LI, - PreserveLCSSA); - if (!NewExitBB) - return nullptr; - - DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block " - << NewExitBB->getName() << "\n"); - return NewExitBB; -} - /// Add the specified block, and all of its predecessors, to the specified set, /// if it's not already in there. Stop predecessor traversal when we reach /// StopBlock. @@ -210,7 +178,7 @@ static PHINode *findPHIToPartitionLoops(Loop *L, DominatorTree *DT, for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ) { PHINode *PN = cast<PHINode>(I); ++I; - if (Value *V = SimplifyInstruction(PN, DL, nullptr, DT, AC)) { + if (Value *V = SimplifyInstruction(PN, {DL, nullptr, DT, AC})) { // This is a degenerate PHI already, don't modify it! 
PN->replaceAllUsesWith(V); PN->eraseFromParent(); @@ -346,16 +314,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, // Split edges to exit blocks from the inner loop, if they emerged in the // process of separating the outer one. - SmallVector<BasicBlock *, 8> ExitBlocks; - L->getExitBlocks(ExitBlocks); - SmallSetVector<BasicBlock *, 8> ExitBlockSet(ExitBlocks.begin(), - ExitBlocks.end()); - for (BasicBlock *ExitBlock : ExitBlockSet) { - if (any_of(predecessors(ExitBlock), - [L](BasicBlock *BB) { return !L->contains(BB); })) { - rewriteLoopExitBlock(L, ExitBlock, DT, LI, PreserveLCSSA); - } - } + formDedicatedExitBlocks(L, DT, LI, PreserveLCSSA); if (PreserveLCSSA) { // Fix LCSSA form for L. Some values, which previously were only used inside @@ -563,29 +522,16 @@ ReprocessLoop: BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { Preheader = InsertPreheaderForLoop(L, DT, LI, PreserveLCSSA); - if (Preheader) { - ++NumInserted; + if (Preheader) Changed = true; - } } // Next, check to make sure that all exit nodes of the loop only have // predecessors that are inside of the loop. This check guarantees that the // loop preheader/header will dominate the exit blocks. If the exit block has // predecessors from outside of the loop, split the edge now. - SmallVector<BasicBlock*, 8> ExitBlocks; - L->getExitBlocks(ExitBlocks); - - SmallSetVector<BasicBlock *, 8> ExitBlockSet(ExitBlocks.begin(), - ExitBlocks.end()); - for (BasicBlock *ExitBlock : ExitBlockSet) { - if (any_of(predecessors(ExitBlock), - [L](BasicBlock *BB) { return !L->contains(BB); })) { - rewriteLoopExitBlock(L, ExitBlock, DT, LI, PreserveLCSSA); - ++NumInserted; - Changed = true; - } - } + if (formDedicatedExitBlocks(L, DT, LI, PreserveLCSSA)) + Changed = true; // If the header has more than two predecessors at this point (from the // preheader and from multiple backedges), we must adjust the loop. @@ -614,10 +560,8 @@ ReprocessLoop: // insert a new block that all backedges target, then make it jump to the // loop header. LoopLatch = insertUniqueBackedgeBlock(L, Preheader, DT, LI); - if (LoopLatch) { - ++NumInserted; + if (LoopLatch) Changed = true; - } } const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); @@ -628,7 +572,7 @@ ReprocessLoop: PHINode *PN; for (BasicBlock::iterator I = L->getHeader()->begin(); (PN = dyn_cast<PHINode>(I++)); ) - if (Value *V = SimplifyInstruction(PN, DL, nullptr, DT, AC)) { + if (Value *V = SimplifyInstruction(PN, {DL, nullptr, DT, AC})) { if (SE) SE->forgetValue(PN); if (!PreserveLCSSA || LI->replacementPreservesLCSSAForm(PN, V)) { PN->replaceAllUsesWith(V); @@ -645,14 +589,22 @@ ReprocessLoop: // loop-invariant instructions out of the way to open up more // opportunities, and the disadvantage of having the responsibility // to preserve dominator information. 
- bool UniqueExit = true; - if (!ExitBlocks.empty()) - for (unsigned i = 1, e = ExitBlocks.size(); i != e; ++i) - if (ExitBlocks[i] != ExitBlocks[0]) { - UniqueExit = false; - break; + auto HasUniqueExitBlock = [&]() { + BasicBlock *UniqueExit = nullptr; + for (auto *ExitingBB : ExitingBlocks) + for (auto *SuccBB : successors(ExitingBB)) { + if (L->contains(SuccBB)) + continue; + + if (!UniqueExit) + UniqueExit = SuccBB; + else if (UniqueExit != SuccBB) + return false; } - if (UniqueExit) { + + return true; + }; + if (HasUniqueExitBlock()) { for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { BasicBlock *ExitingBlock = ExitingBlocks[i]; if (!ExitingBlock->getSinglePredecessor()) continue; @@ -735,6 +687,17 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA) { bool Changed = false; +#ifndef NDEBUG + // If we're asked to preserve LCSSA, the loop nest needs to start in LCSSA + // form. + if (PreserveLCSSA) { + assert(DT && "DT not available."); + assert(LI && "LI not available."); + assert(L->isRecursivelyLCSSAForm(*DT, *LI) && + "Requested to preserve LCSSA, but it's already broken."); + } +#endif + // Worklist maintains our depth-first queue of loops in this nest to process. SmallVector<Loop *, 4> Worklist; Worklist.push_back(L); @@ -814,15 +777,6 @@ bool LoopSimplify::runOnFunction(Function &F) { &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); -#ifndef NDEBUG - if (PreserveLCSSA) { - assert(DT && "DT not available."); - assert(LI && "LI not available."); - bool InLCSSA = all_of( - *LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT, *LI); }); - assert(InLCSSA && "Requested to preserve LCSSA, but it's already broken."); - } -#endif // Simplify each loop nest in the function. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) @@ -846,17 +800,14 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F, ScalarEvolution *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F); AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); - // FIXME: This pass should verify that the loops on which it's operating - // are in canonical SSA form, and that the pass itself preserves this form. + // Note that we don't preserve LCSSA in the new PM; if you need it, run LCSSA + // after simplifying the loops. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= simplifyLoop(*I, DT, LI, SE, AC, true /* PreserveLCSSA */); - - // FIXME: We need to invalidate this to avoid PR28400. Is there a better - // solution?
- AM.invalidate<ScalarEvolutionAnalysis>(F); + Changed |= simplifyLoop(*I, DT, LI, SE, AC, /*PreserveLCSSA*/ false); if (!Changed) return PreservedAnalyses::all(); + PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); PA.preserve<LoopAnalysis>(); diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp index e346ebd..f2527f8 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -16,7 +16,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/UnrollLoop.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" @@ -27,6 +26,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" @@ -38,6 +38,7 @@ #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" using namespace llvm; #define DEBUG_TYPE "loop-unroll" @@ -51,6 +52,16 @@ UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden, cl::desc("Allow runtime unrolled loops to be unrolled " "with epilog instead of prolog.")); +static cl::opt<bool> +UnrollVerifyDomtree("unroll-verify-domtree", cl::Hidden, + cl::desc("Verify domtree after unrolling"), +#ifdef NDEBUG + cl::init(false) +#else + cl::init(true) +#endif + ); + /// Convert the instruction operands from referencing the current values into /// those specified by VMap. static inline void remapInstruction(Instruction *I, @@ -205,6 +216,45 @@ const Loop* llvm::addClonedBlockToLoopInfo(BasicBlock *OriginalBB, } } +/// The function chooses which type of unroll (epilog or prolog) is more +/// profitable. +/// Epilog unroll is more profitable when there is a PHI that starts from a +/// constant. In this case epilog unrolling will leave the PHI starting from a +/// constant, but prolog unrolling will convert it to a non-constant. +/// +/// loop: +/// PN = PHI [I, Latch], [CI, PreHeader] +/// I = foo(PN) +/// ... +/// +/// Epilog unroll case. +/// loop: +/// PN = PHI [I2, Latch], [CI, PreHeader] +/// I1 = foo(PN) +/// I2 = foo(I1) +/// ... +/// Prolog unroll case. +/// NewPN = PHI [PrologI, Prolog], [CI, PreHeader] +/// loop: +/// PN = PHI [I2, Latch], [NewPN, PreHeader] +/// I1 = foo(PN) +/// I2 = foo(I1) +/// ... +/// +static bool isEpilogProfitable(Loop *L) { + BasicBlock *PreHeader = L->getLoopPreheader(); + BasicBlock *Header = L->getHeader(); + assert(PreHeader && Header); + for (Instruction &BBI : *Header) { + PHINode *PN = dyn_cast<PHINode>(&BBI); + if (!PN) + break; + if (isa<ConstantInt>(PN->getIncomingValueForBlock(PreHeader))) + return true; + } + return false; +} + /// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true /// if unrolling was successful, or false if the loop was unmodified. Unrolling /// can only fail when the loop's latch block is not terminated by a conditional @@ -268,6 +318,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, return false; } + // The current loop unroll pass can only unroll loops with a single latch + // that's a conditional branch exiting the loop. + // FIXME: The implementation can be extended to work with more complicated + // cases, e.g.
loops with multiple latches. BasicBlock *Header = L->getHeader(); BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator()); @@ -278,6 +332,16 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, return false; } + auto CheckSuccessors = [&](unsigned S1, unsigned S2) { + return BI->getSuccessor(S1) == Header && !L->contains(BI->getSuccessor(S2)); + }; + + if (!CheckSuccessors(0, 1) && !CheckSuccessors(1, 0)) { + DEBUG(dbgs() << "Can't unroll; only loops with one conditional latch" " exiting the loop can be unrolled\n"); + return false; + } + if (Header->hasAddressTaken()) { // The loop-rotate pass can be helpful to avoid this in many cases. DEBUG(dbgs() << @@ -296,8 +360,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, Count = TripCount; // Don't enter the unroll code if there is nothing to do. - if (TripCount == 0 && Count < 2 && PeelCount == 0) + if (TripCount == 0 && Count < 2 && PeelCount == 0) { + DEBUG(dbgs() << "Won't unroll; almost nothing to do\n"); return false; + } assert(Count > 0); assert(TripMultiple > 0); @@ -330,7 +396,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, "and peeling for the same loop"); if (PeelCount) - peelLoop(L, PeelCount, LI, SE, DT, PreserveLCSSA); + peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA); // Loops containing convergent instructions must have a count that divides // their TripMultiple. @@ -346,14 +412,22 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, "convergent operation."); }); + bool EpilogProfitability = + UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog + : isEpilogProfitable(L); + if (RuntimeTripCount && TripMultiple % Count != 0 && !UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount, - UnrollRuntimeEpilog, LI, SE, DT, + EpilogProfitability, LI, SE, DT, PreserveLCSSA)) { if (Force) RuntimeTripCount = false; - else + else { + DEBUG( + dbgs() << "Won't unroll; remainder loop could not be generated " + "when assuming runtime trip count\n"); return false; + } } // Notify ScalarEvolution that the loop will be substantially changed, @@ -446,6 +520,12 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, for (Loop *SubLoop : *L) LoopsToSimplify.insert(SubLoop); + if (Header->getParent()->isDebugInfoForProfiling()) + for (BasicBlock *BB : L->getBlocks()) + for (Instruction &I : *BB) + if (const DILocation *DIL = I.getDebugLoc()) + I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count)); + for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; SmallDenseMap<const Loop *, Loop *, 4> NewLoops; @@ -456,19 +536,16 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It)); Header->getParent()->getBasicBlockList().push_back(New); + assert((*BB != Header || LI->getLoopFor(*BB) == L) && + "Header should not be in a sub-loop"); // Tell LI about New. - if (*BB == Header) { - assert(LI->getLoopFor(*BB) == L && "Header should not be in a sub-loop"); - L->addBasicBlockToLoop(New, *LI); - } else { - const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); - if (OldLoop) { - LoopsToSimplify.insert(NewLoops[OldLoop]); + const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); + if (OldLoop) { + LoopsToSimplify.insert(NewLoops[OldLoop]); - // Forget the old loop, since its inputs may have changed.
- if (SE) - SE->forgetLoop(OldLoop); - } + // Forget the old loop, since its inputs may have changed. + if (SE) + SE->forgetLoop(OldLoop); } if (*BB == Header) @@ -615,14 +692,11 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, Term->eraseFromParent(); } } + // Update dominators of blocks we might reach through exits. // Immediate dominator of such block might change, because we add more // routes which can lead to the exit: we can now reach it from the copied - // iterations too. Thus, the new idom of the block will be the nearest - // common dominator of the previous idom and common dominator of all copies of - // the previous idom. This is equivalent to the nearest common dominator of - // the previous idom and the first latch, which dominates all copies of the - // previous idom. + // iterations too. if (DT && Count > 1) { for (auto *BB : OriginalLoopBlocks) { auto *BBDomNode = DT->getNode(BB); @@ -632,12 +706,38 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (!L->contains(ChildBB)) ChildrenToUpdate.push_back(ChildBB); } - BasicBlock *NewIDom = DT->findNearestCommonDominator(BB, Latches[0]); + BasicBlock *NewIDom; + if (BB == LatchBlock) { + // The latch is special because we emit unconditional branches in + // some cases where the original loop contained a conditional branch. + // Since the latch is always at the bottom of the loop, if the latch + // dominated an exit before unrolling, the new dominator of that exit + // must also be a latch. Specifically, the dominator is the first + // latch which ends in a conditional branch, or the last latch if + // there is no such latch. + NewIDom = Latches.back(); + for (BasicBlock *IterLatch : Latches) { + TerminatorInst *Term = IterLatch->getTerminator(); + if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) { + NewIDom = IterLatch; + break; + } + } + } else { + // The new idom of the block will be the nearest common dominator + // of all copies of the previous idom. This is equivalent to the + // nearest common dominator of the previous idom and the first latch, + // which dominates all copies of the previous idom. + NewIDom = DT->findNearestCommonDominator(BB, LatchBlock); + } for (auto *ChildBB : ChildrenToUpdate) DT->changeImmediateDominator(ChildBB, NewIDom); } } + if (DT && UnrollVerifyDomtree) + DT->verifyDomTree(); + // Merge adjacent basic blocks, if possible. SmallPtrSet<Loop *, 4> ForgottenLoops; for (BasicBlock *Latch : Latches) { @@ -655,16 +755,9 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, } } - // FIXME: We only preserve DT info for complete unrolling now. Incrementally - // updating domtree after partial loop unrolling should also be easy. - if (DT && !CompletelyUnroll) - DT->recalculate(*L->getHeader()->getParent()); - else if (DT) - DEBUG(DT->verifyDomTree()); - // Simplify any new induction variables in the partially unrolled loop. 
if (SE && !CompletelyUnroll && Count > 1) { - SmallVector<WeakVH, 16> DeadInsts; + SmallVector<WeakTrackingVH, 16> DeadInsts; simplifyLoopIVs(L, SE, DT, LI, DeadInsts); // Aggressively clean up dead instructions that simplifyLoopIVs already @@ -684,7 +777,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { Instruction *Inst = &*I++; - if (Value *V = SimplifyInstruction(Inst, DL)) + if (Value *V = SimplifyInstruction(Inst, {DL, nullptr, DT, AC})) if (LI->replacementPreservesLCSSAForm(Inst, V)) Inst->replaceAllUsesWith(V); if (isInstructionTriviallyDead(Inst)) @@ -721,29 +814,29 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, // at least one layer outside of the loop that was unrolled so that any // changes to the parent loop exposed by the unrolling are considered. if (DT) { - if (!OuterL && !CompletelyUnroll) - OuterL = L; if (OuterL) { // OuterL includes all loops for which we can break loop-simplify, so // it's sufficient to simplify only it (it'll recursively simplify inner // loops too). + if (NeedToFixLCSSA) { + // LCSSA must be performed on the outermost affected loop. The unrolled + // loop's last loop latch is guaranteed to be in the outermost loop + // after LoopInfo's been updated by markAsRemoved. + Loop *LatchLoop = LI->getLoopFor(Latches.back()); + Loop *FixLCSSALoop = OuterL; + if (!FixLCSSALoop->contains(LatchLoop)) + while (FixLCSSALoop->getParentLoop() != LatchLoop) + FixLCSSALoop = FixLCSSALoop->getParentLoop(); + + formLCSSARecursively(*FixLCSSALoop, *DT, LI, SE); + } else if (PreserveLCSSA) { + assert(OuterL->isLCSSAForm(*DT) && + "Loops should be in LCSSA form after loop-unroll."); + } + // TODO: That potentially might be compile-time expensive. We should try // to fix the loop-simplified form incrementally. simplifyLoop(OuterL, DT, LI, SE, AC, PreserveLCSSA); - - // LCSSA must be performed on the outermost affected loop. The unrolled - // loop's last loop latch is guaranteed to be in the outermost loop after - // LoopInfo's been updated by markAsRemoved. - Loop *LatchLoop = LI->getLoopFor(Latches.back()); - if (!OuterL->contains(LatchLoop)) - while (OuterL->getParentLoop() != LatchLoop) - OuterL = OuterL->getParentLoop(); - - if (NeedToFixLCSSA) - formLCSSARecursively(*OuterL, *DT, LI, SE); - else - assert(OuterL->isLCSSAForm(*DT) && - "Loops should be in LCSSA form after loop-unroll."); } else { // Simplify loops for which we might've broken loop-simplify form. for (Loop *SubLoop : LoopsToSimplify) diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUnrollPeel.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUnrollPeel.cpp index 842cf31..5c21490 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopUnrollPeel.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopUnrollPeel.cpp @@ -28,6 +28,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" #include <algorithm> @@ -45,6 +46,11 @@ static cl::opt<unsigned> UnrollForcePeelCount( "unroll-force-peel-count", cl::init(0), cl::Hidden, cl::desc("Force a peel count regardless of profiling information.")); +// Designates that a Phi is estimated to become invariant after an "infinite" +// number of loop iterations (i.e. only may become an invariant if the loop is +// fully unrolled). 
+static const unsigned InfiniteIterationsToInvariance = UINT_MAX; + // Check whether we are capable of peeling this loop. static bool canPeel(Loop *L) { // Make sure the loop is in simplified form @@ -55,12 +61,72 @@ static bool canPeel(Loop *L) { if (!L->getExitingBlock() || !L->getUniqueExitBlock()) return false; + // Don't try to peel loops where the latch is not the exiting block. + // This can be an indication of two different things: + // 1) The loop is not rotated. + // 2) The loop contains irreducible control flow that involves the latch. + if (L->getLoopLatch() != L->getExitingBlock()) + return false; + return true; } +// This function calculates the number of iterations after which the given Phi +// becomes an invariant. The pre-calculated values are memoized in the map. The +// function (denoted I) is calculated according to the following definition: +// Given %x = phi <Inputs from above the loop>, ..., [%y, %back.edge]. +// If %y is a loop invariant, then I(%x) = 1. +// If %y is a Phi from the loop header, I(%x) = I(%y) + 1. +// Otherwise, I(%x) is infinite. +// TODO: Actually if %y is an expression that depends only on Phi %z and some +// loop invariants, we can estimate I(%x) = I(%z) + 1. The example +// looks like: +// %x = phi(0, %a), <-- becomes invariant starting from 3rd iteration. +// %y = phi(0, 5), +// %a = %y + 1. +static unsigned calculateIterationsToInvariance( + PHINode *Phi, Loop *L, BasicBlock *BackEdge, + SmallDenseMap<PHINode *, unsigned> &IterationsToInvariance) { + assert(Phi->getParent() == L->getHeader() && + "Non-loop Phi should not be checked for turning into invariant."); + assert(BackEdge == L->getLoopLatch() && "Wrong latch?"); + // If we already know the answer, take it from the map. + auto I = IterationsToInvariance.find(Phi); + if (I != IterationsToInvariance.end()) + return I->second; + + // Otherwise we need to analyze the input from the back edge. + Value *Input = Phi->getIncomingValueForBlock(BackEdge); + // Place infinity into the map to avoid infinite recursion for cyclic Phis. + // Such cycles can never settle on an invariant. + IterationsToInvariance[Phi] = InfiniteIterationsToInvariance; + unsigned ToInvariance = InfiniteIterationsToInvariance; + + if (L->isLoopInvariant(Input)) + ToInvariance = 1u; + else if (PHINode *IncPhi = dyn_cast<PHINode>(Input)) { + // Only consider Phis in the header block. + if (IncPhi->getParent() != L->getHeader()) + return InfiniteIterationsToInvariance; + // If the input becomes an invariant after X iterations, then our Phi + // becomes an invariant after X + 1 iterations. + unsigned InputToInvariance = calculateIterationsToInvariance( + IncPhi, L, BackEdge, IterationsToInvariance); + if (InputToInvariance != InfiniteIterationsToInvariance) + ToInvariance = InputToInvariance + 1u; + } + + // If we found that this Phi lies in an invariant chain, update the map. + if (ToInvariance != InfiniteIterationsToInvariance) + IterationsToInvariance[Phi] = ToInvariance; + return ToInvariance; +} + // Return the number of iterations we want to peel off.
void llvm::computePeelCount(Loop *L, unsigned LoopSize, - TargetTransformInfo::UnrollingPreferences &UP) { + TargetTransformInfo::UnrollingPreferences &UP, + unsigned &TripCount) { + assert(LoopSize > 0 && "Zero loop size is not allowed!"); UP.PeelCount = 0; if (!canPeel(L)) return; @@ -69,6 +135,46 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, if (!L->empty()) return; + // Here we try to get rid of Phis which become invariants after 1, 2, ..., N + // iterations of the loop. For this we compute the number of iterations after + // which every Phi is guaranteed to become an invariant, and try to peel the + // maximum number of iterations among these values, thus turning all those + // Phis into invariants. + // First, check that we can peel at least one iteration. + if (2 * LoopSize <= UP.Threshold && UnrollPeelMaxCount > 0) { + // Store the pre-calculated values here. + SmallDenseMap<PHINode *, unsigned> IterationsToInvariance; + // Now go through all Phis to calculate the number of iterations they + // need to become invariants. + unsigned DesiredPeelCount = 0; + BasicBlock *BackEdge = L->getLoopLatch(); + assert(BackEdge && "Loop is not in simplified form?"); + for (auto BI = L->getHeader()->begin(); isa<PHINode>(&*BI); ++BI) { + PHINode *Phi = cast<PHINode>(&*BI); + unsigned ToInvariance = calculateIterationsToInvariance( + Phi, L, BackEdge, IterationsToInvariance); + if (ToInvariance != InfiniteIterationsToInvariance) + DesiredPeelCount = std::max(DesiredPeelCount, ToInvariance); + } + if (DesiredPeelCount > 0) { + // Respect the limitations implied by loop size and the max peel count. + unsigned MaxPeelCount = UnrollPeelMaxCount; + MaxPeelCount = std::min(MaxPeelCount, UP.Threshold / LoopSize - 1); + DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount); + // Consider max peel count limitation. + assert(DesiredPeelCount > 0 && "Wrong loop size estimation?"); + DEBUG(dbgs() << "Peel " << DesiredPeelCount << " iteration(s) to turn" << " some Phis into invariants.\n"); + UP.PeelCount = DesiredPeelCount; + return; + } + } + + // Bail if we know the statically calculated trip count. + // In this case we would rather do partial unrolling. + if (TripCount) + return; + // If the user provided a peel count, use that. bool UserPeelCount = UnrollForcePeelCount.getNumOccurrences() > 0; if (UserPeelCount) { @@ -164,7 +270,8 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, BasicBlock *InsertBot, BasicBlock *Exit, SmallVectorImpl<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, - ValueToValueMapTy &LVMap, LoopInfo *LI) { + ValueToValueMapTy &LVMap, DominatorTree *DT, + LoopInfo *LI) { BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); @@ -185,6 +292,17 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, ParentLoop->addBasicBlockToLoop(NewBB, *LI); VMap[*BB] = NewBB; + + // If the dominator tree is available, insert nodes to represent cloned blocks. + if (DT) { + if (Header == *BB) + DT->addNewBlock(NewBB, InsertTop); + else { + DomTreeNode *IDom = DT->getNode(*BB)->getIDom(); + // VMap must contain an entry for IDom, as the iteration order is RPO. + DT->addNewBlock(NewBB, cast<BasicBlock>(VMap[IDom->getBlock()])); + } + } + } // Hook-up the control flow for the newly inserted blocks.
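// A short sketch of the size clamp applied in computePeelCount above, with
// the arithmetic made explicit (the helper name and the numbers in the
// trailing comment are illustrative assumptions, not part of the patch):
#include <algorithm>

static unsigned clampPeelCount(unsigned DesiredPeelCount, unsigned MaxPeelCount,
                               unsigned Threshold, unsigned LoopSize) {
  // Peeling N iterations keeps roughly (N + 1) * LoopSize instructions
  // around, so N may be at most Threshold / LoopSize - 1.
  MaxPeelCount = std::min(MaxPeelCount, Threshold / LoopSize - 1);
  return std::min(DesiredPeelCount, MaxPeelCount);
}
// For example, Threshold = 150, LoopSize = 30, UnrollPeelMaxCount = 7 and a
// PHI chain needing 6 iterations give min(6, min(7, 150/30 - 1)) = 4 peeled
// iterations, so the last two PHIs of the chain remain variant.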
@@ -198,11 +316,13 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, // The backedge now goes to the "bottom", which is either the loop's real // header (for the last peeled iteration) or the copied header of the next // iteration (for every other iteration) - BranchInst *LatchBR = - cast<BranchInst>(cast<BasicBlock>(VMap[Latch])->getTerminator()); + BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); + BranchInst *LatchBR = cast<BranchInst>(NewLatch->getTerminator()); unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); LatchBR->setSuccessor(HeaderIdx, InsertBot); LatchBR->setSuccessor(1 - HeaderIdx, Exit); + if (DT) + DT->changeImmediateDominator(InsertBot, NewLatch); // The new copy of the loop body starts with a bunch of PHI nodes // that pick an incoming value from either the preheader, or the previous @@ -257,7 +377,7 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, /// optimizations. bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, - bool PreserveLCSSA) { + AssumptionCache *AC, bool PreserveLCSSA) { if (!canPeel(L)) return false; @@ -358,7 +478,24 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, CurHeaderWeight = 1; cloneLoopBlocks(L, Iter, InsertTop, InsertBot, Exit, - NewBlocks, LoopBlocks, VMap, LVMap, LI); + NewBlocks, LoopBlocks, VMap, LVMap, DT, LI); + + // Remap to use values from the current iteration instead of the + // previous one. + remapInstructionsInBlocks(NewBlocks, VMap); + + if (DT) { + // Latches of the cloned loops dominate over the loop exit, so idom of the + // latter is the first cloned loop body, as original PreHeader dominates + // the original loop body. + if (Iter == 0) + DT->changeImmediateDominator(Exit, cast<BasicBlock>(LVMap[Latch])); +#ifndef NDEBUG + if (VerifyDomInfo) + DT->verifyDomTree(); +#endif + } + updateBranchWeights(InsertBot, cast<BranchInst>(VMap[LatchBR]), Iter, PeelCount, ExitWeight); @@ -369,10 +506,6 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, F->getBasicBlockList().splice(InsertTop->getIterator(), F->getBasicBlockList(), NewBlocks[0]->getIterator(), F->end()); - - // Remap to use values from the current iteration instead of the - // previous one. - remapInstructionsInBlocks(NewBlocks, VMap); } // Now adjust the phi nodes in the loop header to get their initial values @@ -405,9 +538,16 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, } // If the loop is nested, we changed the parent loop, update SE. 
- if (Loop *ParentLoop = L->getParentLoop()) + if (Loop *ParentLoop = L->getParentLoop()) { SE->forgetLoop(ParentLoop); + // FIXME: Incrementally update loop-simplify + simplifyLoop(ParentLoop, DT, LI, SE, AC, PreserveLCSSA); + } else { + // FIXME: Incrementally update loop-simplify + simplifyLoop(L, DT, LI, SE, AC, PreserveLCSSA); + } + NumPeeled++; return true; diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index d3ea156..d43ce7a 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -21,8 +21,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/UnrollLoop.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" @@ -37,6 +37,8 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" #include <algorithm> using namespace llvm; @@ -45,6 +47,10 @@ using namespace llvm; STATISTIC(NumRuntimeUnrolled, "Number of loops unrolled with run-time trip counts"); +static cl::opt<bool> UnrollRuntimeMultiExit( + "unroll-runtime-multi-exit", cl::init(false), cl::Hidden, + cl::desc("Allow runtime unrolling for loops with multiple exits, when " + "epilog is generated")); /// Connect the unrolling prolog code to the original loop. /// The unrolling prolog code contains code to execute the @@ -60,9 +66,11 @@ STATISTIC(NumRuntimeUnrolled, /// than the unroll factor. /// static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, - BasicBlock *PrologExit, BasicBlock *PreHeader, - BasicBlock *NewPreHeader, ValueToValueMapTy &VMap, - DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA) { + BasicBlock *PrologExit, + BasicBlock *OriginalLoopLatchExit, + BasicBlock *PreHeader, BasicBlock *NewPreHeader, + ValueToValueMapTy &VMap, DominatorTree *DT, + LoopInfo *LI, bool PreserveLCSSA) { BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Loop must have a latch"); BasicBlock *PrologLatch = cast<BasicBlock>(VMap[Latch]); @@ -137,15 +145,15 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, // then (BECount + 1) cannot unsigned-overflow. Value *BrLoopExit = B.CreateICmpULT(BECount, ConstantInt::get(BECount->getType(), Count - 1)); - BasicBlock *Exit = L->getUniqueExitBlock(); - assert(Exit && "Loop must have a single exit block only"); // Split the exit to maintain loop canonicalization guarantees - SmallVector<BasicBlock*, 4> Preds(predecessors(Exit)); - SplitBlockPredecessors(Exit, Preds, ".unr-lcssa", DT, LI, + SmallVector<BasicBlock *, 4> Preds(predecessors(OriginalLoopLatchExit)); + SplitBlockPredecessors(OriginalLoopLatchExit, Preds, ".unr-lcssa", DT, LI, PreserveLCSSA); // Add the branch to the exit block (around the unrolled loop) - B.CreateCondBr(BrLoopExit, Exit, NewPreHeader); + B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader); InsertPt->eraseFromParent(); + if (DT) + DT->changeImmediateDominator(OriginalLoopLatchExit, PrologExit); } /// Connect the unrolling epilog code to the original loop. 
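// A sketch of the guard math that ConnectProlog above emits, assuming
// unsigned arithmetic throughout (the helper is illustrative only):
#include <cstdint>

// The unrolled body must run only when at least Count iterations remain
// after the prolog. With TripCount = BECount + 1, "TripCount < Count" is
// exactly "BECount < Count - 1", which cannot overflow for Count >= 2.
static bool branchAroundUnrolledLoop(uint64_t BECount, uint64_t Count) {
  return BECount < Count - 1;
}
// E.g. TripCount = 3 and Count = 4 give BECount = 2 < 3: control skips the
// unrolled loop entirely and the prolog performs all three iterations.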
@@ -260,13 +268,20 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, IRBuilder<> B(InsertPt); Value *BrLoopExit = B.CreateIsNotNull(ModVal, "lcmp.mod"); assert(Exit && "Loop must have a single exit block only"); - // Split the exit to maintain loop canonicalization guarantees + // Split the epilogue exit to maintain loop canonicalization guarantees SmallVector<BasicBlock*, 4> Preds(predecessors(Exit)); SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, PreserveLCSSA); // Add the branch to the exit block (around the unrolling loop) B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit); InsertPt->eraseFromParent(); + if (DT) + DT->changeImmediateDominator(Exit, NewExit); + + // Split the main loop exit to maintain canonicalization guarantees. + SmallVector<BasicBlock*, 4> NewExitPreds{Latch}; + SplitBlockPredecessors(NewExit, NewExitPreds, ".loopexit", DT, LI, + PreserveLCSSA); } /// Create a clone of the blocks in a loop and connect them together. @@ -276,35 +291,23 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, /// The cloned blocks should be inserted between InsertTop and InsertBot. /// If loop structure is cloned InsertTop should be new preheader, InsertBot /// new loop exit. -/// -static void CloneLoopBlocks(Loop *L, Value *NewIter, - const bool CreateRemainderLoop, - const bool UseEpilogRemainder, - BasicBlock *InsertTop, BasicBlock *InsertBot, - BasicBlock *Preheader, - std::vector<BasicBlock *> &NewBlocks, - LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, - LoopInfo *LI) { +/// Return the new cloned loop that is created when CreateRemainderLoop is true. +static Loop * +CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, + const bool UseEpilogRemainder, BasicBlock *InsertTop, + BasicBlock *InsertBot, BasicBlock *Preheader, + std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, + ValueToValueMapTy &VMap, DominatorTree *DT, LoopInfo *LI) { StringRef suffix = UseEpilogRemainder ? "epil" : "prol"; BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); Function *F = Header->getParent(); LoopBlocksDFS::RPOIterator BlockBegin = LoopBlocks.beginRPO(); LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO(); - Loop *NewLoop = nullptr; Loop *ParentLoop = L->getParentLoop(); - if (CreateRemainderLoop) { - NewLoop = new Loop(); - if (ParentLoop) - ParentLoop->addChildLoop(NewLoop); - else - LI->addTopLevelLoop(NewLoop); - } - NewLoopsMap NewLoops; - if (NewLoop) - NewLoops[L] = NewLoop; - else if (ParentLoop) + NewLoops[ParentLoop] = ParentLoop; + if (!CreateRemainderLoop) NewLoops[L] = ParentLoop; // For each block in the original loop, create a new copy, @@ -312,7 +315,7 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." + suffix, F); NewBlocks.push_back(NewBB); - + // If we're unrolling the outermost loop, there's no remainder loop, // and this block isn't in a nested loop, then the new block is not // in any loop. Otherwise, add it to loopinfo. @@ -326,6 +329,17 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, InsertTop->getTerminator()->setSuccessor(0, NewBB); } + if (DT) { + if (Header == *BB) { + // The header is dominated by the preheader. + DT->addNewBlock(NewBB, InsertTop); + } else { + // Copy information from original loop to unrolled loop. 
+ BasicBlock *IDomBB = DT->getNode(*BB)->getIDom()->getBlock(); + DT->addNewBlock(NewBB, cast<BasicBlock>(VMap[IDomBB])); + } + } + if (Latch == *BB) { // For the last block, if CreateRemainderLoop is false, create a direct // jump to InsertBot. If not, create a loop back to cloned head. @@ -376,7 +390,9 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, NewPHI->setIncomingValue(idx, V); } } - if (NewLoop) { + if (CreateRemainderLoop) { + Loop *NewLoop = NewLoops[L]; + assert(NewLoop && "L should have been cloned"); // Add unroll disable metadata to disable future unrolling for this loop. SmallVector<Metadata *, 4> MDs; // Reserve first location for self reference to the LoopID metadata node. @@ -406,9 +422,56 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, // Set operand 0 to refer to the loop id itself. NewLoopID->replaceOperandWith(0, NewLoopID); NewLoop->setLoopID(NewLoopID); + return NewLoop; } + else + return nullptr; +} + +/// Returns true if we can safely unroll a multi-exit/exiting loop. OtherExits +/// is populated with all the loop exit blocks other than the LatchExit block. +static bool +canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, + BasicBlock *LatchExit, bool PreserveLCSSA, + bool UseEpilogRemainder) { + + // Support runtime unrolling for multiple exit blocks and multiple exiting + // blocks. + if (!UnrollRuntimeMultiExit) + return false; + // Even if runtime multi exit is enabled, we currently have some correctness + // constraints on unrolling a multi-exit loop. + // We rely on LCSSA form being preserved when the exit blocks are transformed. + if (!PreserveLCSSA) + return false; + SmallVector<BasicBlock *, 4> Exits; + L->getUniqueExitBlocks(Exits); + for (auto *BB : Exits) + if (BB != LatchExit) + OtherExits.push_back(BB); + + // TODO: Support multiple exiting blocks jumping to the `LatchExit` when + // UnrollRuntimeMultiExit is true. This will require updating the logic in + // connectEpilog/connectProlog. + if (!LatchExit->getSinglePredecessor()) { + DEBUG(dbgs() << "Bailout for multi-exit handling when latch exit has >1 " "predecessor.\n"); + return false; + } + // FIXME: We bail out of multi-exit unrolling when an epilog loop is generated + // and L is an inner loop. This is because in the presence of multiple exits, the + // outer loop is incorrect: we do not add the EpilogPreheader and exit to the + // outer loop. This is automatically handled in the prolog case, so we do not + // have that bug in prolog generation. + if (UseEpilogRemainder && L->getParentLoop()) + return false; + + // All constraints have been satisfied. + return true; } + + /// Insert code in the prolog/epilog code when unrolling a loop with a /// run-time trip-count. /// @@ -452,18 +515,40 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, bool UseEpilogRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, bool PreserveLCSSA) { - // for now, only unroll loops that contain a single exit - if (!L->getExitingBlock()) - return false; + DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n"); + DEBUG(L->dump()); - // Make sure the loop is in canonical form, and there is a single - // exit block only. - if (!L->isLoopSimplifyForm()) - return false; - BasicBlock *Exit = L->getUniqueExitBlock(); // successor out of loop - if (!Exit) + // Make sure the loop is in canonical form. + if (!L->isLoopSimplifyForm()) { + DEBUG(dbgs() << "Not in simplify form!\n"); return false; + } + // Guaranteed by LoopSimplifyForm.
+ BasicBlock *Latch = L->getLoopLatch(); + BasicBlock *Header = L->getHeader(); + + BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); + unsigned ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0; + BasicBlock *LatchExit = LatchBR->getSuccessor(ExitIndex); + // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the + // targets of the Latch be an exit block out of the loop. This needs + // to be guaranteed by the callers of UnrollRuntimeLoopRemainder. + assert(!L->contains(LatchExit) && + "one of the loop latch successors should be the exit block!"); + // These are exit blocks other than the target of the latch exiting block. + SmallVector<BasicBlock *, 4> OtherExits; + bool isMultiExitUnrollingEnabled = canSafelyUnrollMultiExitLoop( + L, OtherExits, LatchExit, PreserveLCSSA, UseEpilogRemainder); + // Support only a single exit and a single exiting block unless multi-exit loop unrolling is enabled. + if (!isMultiExitUnrollingEnabled && + (!L->getExitingBlock() || OtherExits.size())) { + DEBUG( + dbgs() + << "Multiple exit/exiting blocks in loop and multi-exit unrolling not " "enabled!\n"); + return false; + } // Use Scalar Evolution to compute the trip count. This allows more loops to // be unrolled than relying on induction var simplification. if (!SE) @@ -471,34 +556,44 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // Only unroll loops with a computable trip count, and the trip count needs // to be an int value (allowing a pointer type is a TODO item). - const SCEV *BECountSC = SE->getBackedgeTakenCount(L); + // We calculate the backedge count by using getExitCount on the Latch block, + // which is proven to be the only exiting block in this loop. This is the same as + // calculating getBackedgeTakenCount on the loop (which computes SCEV for all + // exiting blocks). + const SCEV *BECountSC = SE->getExitCount(L, Latch); if (isa<SCEVCouldNotCompute>(BECountSC) || - !BECountSC->getType()->isIntegerTy()) + !BECountSC->getType()->isIntegerTy()) { + DEBUG(dbgs() << "Could not compute exit block SCEV\n"); return false; + } unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth(); // Add 1 since the backedge count doesn't include the first loop iteration. const SCEV *TripCountSC = SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1)); - if (isa<SCEVCouldNotCompute>(TripCountSC)) + if (isa<SCEVCouldNotCompute>(TripCountSC)) { + DEBUG(dbgs() << "Could not compute trip count SCEV.\n"); return false; + } - BasicBlock *Header = L->getHeader(); BasicBlock *PreHeader = L->getLoopPreheader(); BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator()); const DataLayout &DL = Header->getModule()->getDataLayout(); SCEVExpander Expander(*SE, DL, "loop-unroll"); if (!AllowExpensiveTripCount && - Expander.isHighCostExpansion(TripCountSC, L, PreHeaderBR)) + Expander.isHighCostExpansion(TripCountSC, L, PreHeaderBR)) { + DEBUG(dbgs() << "High cost for expanding trip count scev!\n"); return false; + } // This constraint lets us deal with an overflowing trip count easily; see the // comment on ModVal below. - if (Log2_32(Count) > BEWidth) + if (Log2_32(Count) > BEWidth) { + DEBUG(dbgs() + << "Count failed constraint on overflow trip count calculation.\n"); return false; - - BasicBlock *Latch = L->getLoopLatch(); + } // Loop structure is the following: // // PreHeader // Header // ...
// Latch - // Exit + // LatchExit BasicBlock *NewPreHeader; BasicBlock *NewExit = nullptr; @@ -519,9 +614,9 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // Split PreHeader to insert a branch around loop for unrolling. NewPreHeader = SplitBlock(PreHeader, PreHeader->getTerminator(), DT, LI); NewPreHeader->setName(PreHeader->getName() + ".new"); - // Split Exit to create phi nodes from branch above. - SmallVector<BasicBlock*, 4> Preds(predecessors(Exit)); - NewExit = SplitBlockPredecessors(Exit, Preds, ".unr-lcssa", + // Split LatchExit to create phi nodes from branch above. + SmallVector<BasicBlock*, 4> Preds(predecessors(LatchExit)); + NewExit = SplitBlockPredecessors(LatchExit, Preds, ".unr-lcssa", DT, LI, PreserveLCSSA); // Split NewExit to insert epilog remainder loop. EpilogPreHeader = SplitBlock(NewExit, NewExit->getTerminator(), DT, LI); @@ -548,7 +643,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // Latch Header // *NewExit ... // *EpilogPreHeader Latch - // Exit Exit + // LatchExit LatchExit // Calculate conditions for branch around loop for unrolling // in epilog case and around prolog remainder loop in prolog case. @@ -599,6 +694,12 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // Branch to either remainder (extra iterations) loop or unrolling loop. B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop); PreHeaderBR->eraseFromParent(); + if (DT) { + if (UseEpilogRemainder) + DT->changeImmediateDominator(NewExit, PreHeader); + else + DT->changeImmediateDominator(PrologExit, PreHeader); + } Function *F = Header->getParent(); // Get an ordered list of blocks in the loop to help with the ordering of the // cloned blocks in the prolog/epilog code @@ -620,10 +721,11 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // Clone all the basic blocks in the loop. If Count is 2, we don't clone // the loop, otherwise we create a cloned loop to execute the extra // iterations. This function adds the appropriate CFG connections. - BasicBlock *InsertBot = UseEpilogRemainder ? Exit : PrologExit; + BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit; BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader; - CloneLoopBlocks(L, ModVal, CreateRemainderLoop, UseEpilogRemainder, InsertTop, - InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, LI); + Loop *remainderLoop = CloneLoopBlocks( + L, ModVal, CreateRemainderLoop, UseEpilogRemainder, InsertTop, InsertBot, + NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI); // Insert the cloned blocks into the function. F->getBasicBlockList().splice(InsertBot->getIterator(), @@ -631,6 +733,66 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, NewBlocks[0]->getIterator(), F->end()); + // Now the loop blocks are cloned and the other exiting blocks from the + // remainder are connected to the original Loop's exit blocks. The remaining + // work is to update the phi nodes in the original loop, and take in the + // values from the cloned region. Also update the dominator info for + // OtherExits and their immediate successors, since we have new edges into + // OtherExits. + SmallSet<BasicBlock*, 8> ImmediateSuccessorsOfExitBlocks; + for (auto *BB : OtherExits) { + for (auto &II : *BB) { + + // Given we preserve LCSSA form, we know that the values used outside the + // loop will be used through these phi nodes at the exit blocks that are + // transformed below. 
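// A worked example of the rewiring below, on hypothetical IR: if an extra
// exit %other is reached from %body, LCSSA form gives
//   %other: %lcssa = phi i32 [ %v, %body ]
// After the remainder blocks are cloned (suffix ".epil"), %other gains an
// edge from %body.epil, so the phi is widened to
//   %other: %lcssa = phi i32 [ %v, %body ], [ %v.epil, %body.epil ]
// where %v.epil is VMap[%v] when %v was cloned, and %v itself when the
// incoming value was a constant or defined outside the loop.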
+ if (!isa<PHINode>(II)) + break; + PHINode *Phi = cast<PHINode>(&II); + unsigned oldNumOperands = Phi->getNumIncomingValues(); + // Add the incoming values from the remainder code to the end of the phi + // node. + for (unsigned i =0; i < oldNumOperands; i++){ + Value *newVal = VMap[Phi->getIncomingValue(i)]; + // newVal can be a constant or derived from values outside the loop, and + // hence need not have a VMap value. + if (!newVal) + newVal = Phi->getIncomingValue(i); + Phi->addIncoming(newVal, + cast<BasicBlock>(VMap[Phi->getIncomingBlock(i)])); + } + } +#if defined(EXPENSIVE_CHECKS) && !defined(NDEBUG) + for (BasicBlock *SuccBB : successors(BB)) { + assert(!(any_of(OtherExits, + [SuccBB](BasicBlock *EB) { return EB == SuccBB; }) || + SuccBB == LatchExit) && + "Breaks the definition of dedicated exits!"); + } +#endif + // Update the dominator info because the immediate dominator is no longer the + // header of the original Loop. BB has edges both from L and remainder code. + // Since the preheader determines which loop is run (L or directly jump to + // the remainder code), we set the immediate dominator as the preheader. + if (DT) { + DT->changeImmediateDominator(BB, PreHeader); + // Also update the IDom for immediate successors of BB. If the current + // IDom is the header, update the IDom to be the preheader because that is + // the nearest common dominator of all predecessors of SuccBB. We need to + // check for IDom being the header because successors of exit blocks can + // have edges from outside the loop, and we should not incorrectly update + // the IDom in that case. + for (BasicBlock *SuccBB: successors(BB)) + if (ImmediateSuccessorsOfExitBlocks.insert(SuccBB).second) { + if (DT->getNode(SuccBB)->getIDom()->getBlock() == Header) { + assert(!SuccBB->getSinglePredecessor() && + "BB should be the IDom then!"); + DT->changeImmediateDominator(SuccBB, PreHeader); + } + } + } + } + // Loop structure should be the following: // Epilog Prolog // @@ -644,7 +806,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // EpilogHeader Header // ... ... // EpilogLatch Latch - // Exit Exit + // LatchExit LatchExit // Rewrite the cloned instruction operands to use the values created when the // clone is created. @@ -658,7 +820,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, if (UseEpilogRemainder) { // Connect the epilog code to the original loop and update the // PHI functions. - ConnectEpilog(L, ModVal, NewExit, Exit, PreHeader, + ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader, NewPreHeader, VMap, DT, LI, PreserveLCSSA); @@ -684,8 +846,8 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, } else { // Connect the prolog code to the original loop and update the // PHI functions. - ConnectProlog(L, BECount, Count, PrologExit, PreHeader, NewPreHeader, - VMap, DT, LI, PreserveLCSSA); + ConnectProlog(L, BECount, Count, PrologExit, LatchExit, PreHeader, + NewPreHeader, VMap, DT, LI, PreserveLCSSA); } // If this loop is nested, then the loop unroller changes the code in the @@ -693,6 +855,19 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, if (Loop *ParentLoop = L->getParentLoop()) SE->forgetLoop(ParentLoop); + // Canonicalize to LoopSimplifyForm both original and remainder loops. We + // cannot rely on the LoopUnrollPass to do this because it only does + // canonicalization for parent/subloops and not the sibling loops. 
+ if (OtherExits.size() > 0) { + // Generate dedicated exit blocks for the original loop, to preserve + // LoopSimplifyForm. + formDedicatedExitBlocks(L, DT, LI, PreserveLCSSA); + // Generate dedicated exit blocks for the remainder loop if one exists, to + // preserve LoopSimplifyForm. + if (remainderLoop) + formDedicatedExitBlocks(remainderLoop, DT, LI, PreserveLCSSA); + } + NumRuntimeUnrolled++; return true; } diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp index c8efa9e..3c52278 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -12,16 +12,17 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -29,6 +30,7 @@ #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; using namespace llvm::PatternMatch; @@ -87,8 +89,7 @@ RecurrenceDescriptor::lookThroughAnd(PHINode *Phi, Type *&RT, // Matches either I & 2^x-1 or 2^x-1 & I. If we find a match, we update RT // with a new integer type of the corresponding bit width. - if (match(J, m_CombineOr(m_And(m_Instruction(I), m_APInt(M)), - m_And(m_APInt(M), m_Instruction(I))))) { + if (match(J, m_c_And(m_Instruction(I), m_APInt(M)))) { int32_t Bits = (*M + 1).exactLogBase2(); if (Bits > 0) { RT = IntegerType::get(Phi->getContext(), Bits); @@ -230,8 +231,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind, // - PHI: // - All uses of the PHI must be the reduction (safe). // - Otherwise, not safe. - // - By one instruction outside of the loop (safe). - // - By further instructions outside of the loop (not safe). + // - By instructions outside of the loop (safe). + // * One value may have several outside users, but all outside + // uses must be of the same value. // - By an instruction that is not part of the reduction (not safe). // This is either: // * An instruction type other than PHI or the reduction operation. @@ -297,10 +299,15 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind, // Check if we found the exit user. BasicBlock *Parent = UI->getParent(); if (!TheLoop->contains(Parent)) { - // Exit if you find multiple outside users or if the header phi node is - // being used. In this case the user uses the value of the previous - // iteration, in which case we would loose "VF-1" iterations of the - // reduction operation if we vectorize. + // If we already know this instruction is used externally, move on to + // the next user. + if (ExitInstruction == Cur) + continue; + + // Exit if you find multiple values used outside or if the header phi + // node is being used. 
In this case the user uses the value of the + // previous iteration, in which case we would lose "VF-1" iterations of + // the reduction operation if we vectorize. if (ExitInstruction != nullptr || Cur == Phi) return false; @@ -521,8 +528,9 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, return false; } -bool RecurrenceDescriptor::isFirstOrderRecurrence(PHINode *Phi, Loop *TheLoop, - DominatorTree *DT) { +bool RecurrenceDescriptor::isFirstOrderRecurrence( + PHINode *Phi, Loop *TheLoop, + DenseMap<Instruction *, Instruction *> &SinkAfter, DominatorTree *DT) { // Ensure the phi node is in the loop header and has two incoming values. if (Phi->getParent() != TheLoop->getHeader() || @@ -544,16 +552,29 @@ bool RecurrenceDescriptor::isFirstOrderRecurrence(PHINode *Phi, Loop *TheLoop, // Get the previous value. The previous value comes from the latch edge while // the initial value comes from the preheader edge. auto *Previous = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch)); - if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous)) + if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous) || + SinkAfter.count(Previous)) // Cannot rely on dominance due to motion. return false; - // Ensure every user of the phi node is dominated by the previous value. The - // dominance requirement ensures the loop vectorizer will not need to + // Ensure every user of the phi node is dominated by the previous value. + // The dominance requirement ensures the loop vectorizer will not need to // vectorize the initial value prior to the first iteration of the loop. + // TODO: Consider extending this sinking to handle other kinds of instructions + // and expressions, beyond sinking a single cast past Previous. + if (Phi->hasOneUse()) { + auto *I = Phi->user_back(); + if (I->isCast() && (I->getParent() == Phi->getParent()) && I->hasOneUse() && + DT->dominates(Previous, I->user_back())) { + SinkAfter[I] = Previous; + return true; + } + } + for (User *U : Phi->users()) - if (auto *I = dyn_cast<Instruction>(U)) + if (auto *I = dyn_cast<Instruction>(U)) { if (!DT->dominates(Previous, I)) return false; + } return true; } @@ -916,6 +937,69 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, return true; } +bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, + bool PreserveLCSSA) { + bool Changed = false; + + // We re-use a vector for the in-loop predecessors. + SmallVector<BasicBlock *, 4> InLoopPredecessors; + + auto RewriteExit = [&](BasicBlock *BB) { + assert(InLoopPredecessors.empty() && + "Must start with an empty predecessors list!"); + auto Cleanup = make_scope_exit([&] { InLoopPredecessors.clear(); }); + + // See if there are any non-loop predecessors of this exit block and + // keep track of the in-loop predecessors. + bool IsDedicatedExit = true; + for (auto *PredBB : predecessors(BB)) + if (L->contains(PredBB)) { + if (isa<IndirectBrInst>(PredBB->getTerminator())) + // We cannot rewrite exiting edges from an indirectbr. + return false; + + InLoopPredecessors.push_back(PredBB); + } else { + IsDedicatedExit = false; + } + + assert(!InLoopPredecessors.empty() && "Must have *some* loop predecessor!"); + + // Nothing to do if this is already a dedicated exit.
+ if (IsDedicatedExit) + return false; + + auto *NewExitBB = SplitBlockPredecessors( + BB, InLoopPredecessors, ".loopexit", DT, LI, PreserveLCSSA); + + if (!NewExitBB) + DEBUG(dbgs() << "WARNING: Can't create a dedicated exit block for loop: " + << *L << "\n"); + else + DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block " + << NewExitBB->getName() << "\n"); + return true; + }; + + // Walk the exit blocks directly rather than building up a data structure for + // them, but only visit each one once. + SmallPtrSet<BasicBlock *, 4> Visited; + for (auto *BB : L->blocks()) + for (auto *SuccBB : successors(BB)) { + // We're looking for exit blocks so skip in-loop successors. + if (L->contains(SuccBB)) + continue; + + // Visit each exit block exactly once. + if (!Visited.insert(SuccBB).second) + continue; + + Changed |= RewriteExit(SuccBB); + } + + return Changed; +} + /// \brief Returns the instructions that use values defined in the loop. SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) { SmallVector<Instruction *, 8> UsedOutside; @@ -1105,3 +1189,208 @@ Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) { else return (FalseVal + (TrueVal / 2)) / TrueVal; } + +/// \brief Adds a 'fast' flag to floating point operations. +static Value *addFastMathFlag(Value *V) { + if (isa<FPMathOperator>(V)) { + FastMathFlags Flags; + Flags.setUnsafeAlgebra(); + cast<Instruction>(V)->setFastMathFlags(Flags); + } + return V; +} + +// Helper to generate a log2 shuffle reduction. +Value * +llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op, + RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind, + ArrayRef<Value *> RedOps) { + unsigned VF = Src->getType()->getVectorNumElements(); + // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles + // and vector ops, reducing the set of values being computed by half each + // round. + assert(isPowerOf2_32(VF) && + "Reduction emission only supported for pow2 vectors!"); + Value *TmpVec = Src; + SmallVector<Constant *, 32> ShuffleMask(VF, nullptr); + for (unsigned i = VF; i != 1; i >>= 1) { + // Move the upper half of the vector to the lower half. + for (unsigned j = 0; j != i / 2; ++j) + ShuffleMask[j] = Builder.getInt32(i / 2 + j); + + // Fill the rest of the mask with undef. + std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), + UndefValue::get(Builder.getInt32Ty())); + + Value *Shuf = Builder.CreateShuffleVector( + TmpVec, UndefValue::get(TmpVec->getType()), + ConstantVector::get(ShuffleMask), "rdx.shuf"); + + if (Op != Instruction::ICmp && Op != Instruction::FCmp) { + // Floating point operations had to be 'fast' to enable the reduction. + TmpVec = addFastMathFlag(Builder.CreateBinOp((Instruction::BinaryOps)Op, + TmpVec, Shuf, "bin.rdx")); + } else { + assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid && + "Invalid min/max"); + TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, TmpVec, + Shuf); + } + if (!RedOps.empty()) + propagateIRFlags(TmpVec, RedOps); + } + // The result is in the first element of the vector. + return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); +} + +/// Create a simple vector reduction specified by an opcode and some +/// flags (if generating min/max reductions). 
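// For reference, the shuffle fallback used below halves the vector each
// round. On a hypothetical <8 x i32> add reduction the three rounds use
// the masks
//   <4, 5, 6, 7, undef, undef, undef, undef>
//   <2, 3, undef, undef, undef, undef, undef, undef>
//   <1, undef, undef, undef, undef, undef, undef, undef>
// each followed by a vector add, so lane 0 accumulates all eight elements
// in log2(8) = 3 "bin.rdx" steps before the final extractelement.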
+Value *llvm::createSimpleTargetReduction( + IRBuilder<> &Builder, const TargetTransformInfo *TTI, unsigned Opcode, + Value *Src, TargetTransformInfo::ReductionFlags Flags, + ArrayRef<Value *> RedOps) { + assert(isa<VectorType>(Src->getType()) && "Type must be a vector"); + + Value *ScalarUdf = UndefValue::get(Src->getType()->getVectorElementType()); + std::function<Value*()> BuildFunc; + using RD = RecurrenceDescriptor; + RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid; + // TODO: Support creating ordered reductions. + FastMathFlags FMFUnsafe; + FMFUnsafe.setUnsafeAlgebra(); + + switch (Opcode) { + case Instruction::Add: + BuildFunc = [&]() { return Builder.CreateAddReduce(Src); }; + break; + case Instruction::Mul: + BuildFunc = [&]() { return Builder.CreateMulReduce(Src); }; + break; + case Instruction::And: + BuildFunc = [&]() { return Builder.CreateAndReduce(Src); }; + break; + case Instruction::Or: + BuildFunc = [&]() { return Builder.CreateOrReduce(Src); }; + break; + case Instruction::Xor: + BuildFunc = [&]() { return Builder.CreateXorReduce(Src); }; + break; + case Instruction::FAdd: + BuildFunc = [&]() { + auto Rdx = Builder.CreateFAddReduce(ScalarUdf, Src); + cast<CallInst>(Rdx)->setFastMathFlags(FMFUnsafe); + return Rdx; + }; + break; + case Instruction::FMul: + BuildFunc = [&]() { + auto Rdx = Builder.CreateFMulReduce(ScalarUdf, Src); + cast<CallInst>(Rdx)->setFastMathFlags(FMFUnsafe); + return Rdx; + }; + break; + case Instruction::ICmp: + if (Flags.IsMaxOp) { + MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMax : RD::MRK_UIntMax; + BuildFunc = [&]() { + return Builder.CreateIntMaxReduce(Src, Flags.IsSigned); + }; + } else { + MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMin : RD::MRK_UIntMin; + BuildFunc = [&]() { + return Builder.CreateIntMinReduce(Src, Flags.IsSigned); + }; + } + break; + case Instruction::FCmp: + if (Flags.IsMaxOp) { + MinMaxKind = RD::MRK_FloatMax; + BuildFunc = [&]() { return Builder.CreateFPMaxReduce(Src, Flags.NoNaN); }; + } else { + MinMaxKind = RD::MRK_FloatMin; + BuildFunc = [&]() { return Builder.CreateFPMinReduce(Src, Flags.NoNaN); }; + } + break; + default: + llvm_unreachable("Unhandled opcode"); + break; + } + if (TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags)) + return BuildFunc(); + return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, RedOps); +} + +/// Create a vector reduction using a given recurrence descriptor. +Value *llvm::createTargetReduction(IRBuilder<> &Builder, + const TargetTransformInfo *TTI, + RecurrenceDescriptor &Desc, Value *Src, + bool NoNaN) { + // TODO: Support in-order reductions based on the recurrence descriptor. 
+ RecurrenceDescriptor::RecurrenceKind RecKind = Desc.getRecurrenceKind(); + TargetTransformInfo::ReductionFlags Flags; + Flags.NoNaN = NoNaN; + auto getSimpleRdx = [&](unsigned Opc) { + return createSimpleTargetReduction(Builder, TTI, Opc, Src, Flags); + }; + switch (RecKind) { + case RecurrenceDescriptor::RK_FloatAdd: + return getSimpleRdx(Instruction::FAdd); + case RecurrenceDescriptor::RK_FloatMult: + return getSimpleRdx(Instruction::FMul); + case RecurrenceDescriptor::RK_IntegerAdd: + return getSimpleRdx(Instruction::Add); + case RecurrenceDescriptor::RK_IntegerMult: + return getSimpleRdx(Instruction::Mul); + case RecurrenceDescriptor::RK_IntegerAnd: + return getSimpleRdx(Instruction::And); + case RecurrenceDescriptor::RK_IntegerOr: + return getSimpleRdx(Instruction::Or); + case RecurrenceDescriptor::RK_IntegerXor: + return getSimpleRdx(Instruction::Xor); + case RecurrenceDescriptor::RK_IntegerMinMax: { + switch (Desc.getMinMaxRecurrenceKind()) { + case RecurrenceDescriptor::MRK_SIntMax: + Flags.IsSigned = true; + Flags.IsMaxOp = true; + break; + case RecurrenceDescriptor::MRK_UIntMax: + Flags.IsMaxOp = true; + break; + case RecurrenceDescriptor::MRK_SIntMin: + Flags.IsSigned = true; + break; + case RecurrenceDescriptor::MRK_UIntMin: + break; + default: + llvm_unreachable("Unhandled MRK"); + } + return getSimpleRdx(Instruction::ICmp); + } + case RecurrenceDescriptor::RK_FloatMinMax: { + Flags.IsMaxOp = + Desc.getMinMaxRecurrenceKind() == RecurrenceDescriptor::MRK_FloatMax; + return getSimpleRdx(Instruction::FCmp); + } + default: + llvm_unreachable("Unhandled RecKind"); + } +} + +void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) { + auto *VecOp = dyn_cast<Instruction>(I); + if (!VecOp) + return; + auto *Intersection = (OpValue == nullptr) ? dyn_cast<Instruction>(VL[0]) + : dyn_cast<Instruction>(OpValue); + if (!Intersection) + return; + const unsigned Opcode = Intersection->getOpcode(); + VecOp->copyIRFlags(Intersection); + for (auto *V : VL) { + auto *Instr = dyn_cast<Instruction>(V); + if (!Instr) + continue; + if (OpValue == nullptr || Opcode == Instr->getOpcode()) + VecOp->andIRFlags(V); + } +} diff --git a/contrib/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/contrib/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp new file mode 100644 index 0000000..900450b --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -0,0 +1,510 @@ +//===- LowerMemIntrinsics.cpp ----------------------------------*- C++ -*--===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/LowerMemIntrinsics.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +static unsigned getLoopOperandSizeInBytes(Type *Type) { + if (VectorType *VTy = dyn_cast<VectorType>(Type)) { + return VTy->getBitWidth() / 8; + } + + return Type->getPrimitiveSizeInBits() / 8; +} + +void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, + Value *DstAddr, ConstantInt *CopyLen, + unsigned SrcAlign, unsigned DestAlign, + bool SrcIsVolatile, bool DstIsVolatile, + const TargetTransformInfo &TTI) { + // No need to expand zero length copies. 
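// A worked example of the expansion below, under assumed parameters: for
// CopyLen = 23 bytes and a TTI-chosen loop operand type of i64
// (LoopOpSize = 8),
//   LoopEndCount   = 23 / 8 = 2 main-loop iterations covering 16 bytes,
//   RemainingBytes = 23 - 16 = 7,
// and the residual is emitted as straight-line loads and stores whose types
// come from getMemcpyLoopResidualLoweringType (e.g. one i32, one i16 and
// one i8, which together cover the remaining 7 bytes).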
+ if (CopyLen->isZero()) + return; + + BasicBlock *PreLoopBB = InsertBefore->getParent(); + BasicBlock *PostLoopBB = nullptr; + Function *ParentFunc = PreLoopBB->getParent(); + LLVMContext &Ctx = PreLoopBB->getContext(); + + Type *TypeOfCopyLen = CopyLen->getType(); + Type *LoopOpType = + TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign); + + unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType); + uint64_t LoopEndCount = CopyLen->getZExtValue() / LoopOpSize; + + unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); + unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); + + if (LoopEndCount != 0) { + // Split the basic block at the insertion point. + PostLoopBB = PreLoopBB->splitBasicBlock(InsertBefore, "memcpy-split"); + BasicBlock *LoopBB = + BasicBlock::Create(Ctx, "load-store-loop", ParentFunc, PostLoopBB); + PreLoopBB->getTerminator()->setSuccessor(0, LoopBB); + + IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); + + // Cast the Src and Dst pointers to pointers to the loop operand type (if + // needed). + PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS); + PointerType *DstOpType = PointerType::get(LoopOpType, DstAS); + if (SrcAddr->getType() != SrcOpType) { + SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType); + } + if (DstAddr->getType() != DstOpType) { + DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType); + } + + IRBuilder<> LoopBuilder(LoopBB); + PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 2, "loop-index"); + LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0U), PreLoopBB); + // Loop Body + Value *SrcGEP = + LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex); + Value *Load = LoopBuilder.CreateLoad(SrcGEP, SrcIsVolatile); + Value *DstGEP = + LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex); + LoopBuilder.CreateStore(Load, DstGEP, DstIsVolatile); + + Value *NewIndex = + LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1U)); + LoopIndex->addIncoming(NewIndex, LoopBB); + + // Create the loop branch condition. + Constant *LoopEndCI = ConstantInt::get(TypeOfCopyLen, LoopEndCount); + LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, LoopEndCI), + LoopBB, PostLoopBB); + } + + uint64_t BytesCopied = LoopEndCount * LoopOpSize; + uint64_t RemainingBytes = CopyLen->getZExtValue() - BytesCopied; + if (RemainingBytes) { + IRBuilder<> RBuilder(PostLoopBB ? PostLoopBB->getFirstNonPHI() : InsertBefore); + + // Update the alignment based on the copy size used in the loop body. + SrcAlign = std::min(SrcAlign, LoopOpSize); + DestAlign = std::min(DestAlign, LoopOpSize); + + SmallVector<Type *, 5> RemainingOps; + TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes, + SrcAlign, DestAlign); + + for (auto OpTy : RemainingOps) { + // Calculate the new index. + unsigned OperandSize = getLoopOperandSizeInBytes(OpTy); + uint64_t GepIndex = BytesCopied / OperandSize; + assert(GepIndex * OperandSize == BytesCopied && + "Division should have no Remainder!"); + // Cast source to operand type and load + PointerType *SrcPtrType = PointerType::get(OpTy, SrcAS); + Value *CastedSrc = SrcAddr->getType() == SrcPtrType ? SrcAddr : RBuilder.CreateBitCast(SrcAddr, SrcPtrType); + Value *SrcGEP = RBuilder.CreateInBoundsGEP( + OpTy, CastedSrc, ConstantInt::get(TypeOfCopyLen, GepIndex)); + Value *Load = RBuilder.CreateLoad(SrcGEP, SrcIsVolatile); + + // Cast destination to operand type and store.
+ PointerType *DstPtrType = PointerType::get(OpTy, DstAS); + Value *CastedDst = DstAddr->getType() == DstPtrType + ? DstAddr + : RBuilder.CreateBitCast(DstAddr, DstPtrType); + Value *DstGEP = RBuilder.CreateInBoundsGEP( + OpTy, CastedDst, ConstantInt::get(TypeOfCopyLen, GepIndex)); + RBuilder.CreateStore(Load, DstGEP, DstIsVolatile); + + BytesCopied += OperandSize; + } + } + assert(BytesCopied == CopyLen->getZExtValue() && + "Bytes copied should match size in the call!"); +} + +void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, + Value *SrcAddr, Value *DstAddr, + Value *CopyLen, unsigned SrcAlign, + unsigned DestAlign, bool SrcIsVolatile, + bool DstIsVolatile, + const TargetTransformInfo &TTI) { + BasicBlock *PreLoopBB = InsertBefore->getParent(); + BasicBlock *PostLoopBB = + PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion"); + + Function *ParentFunc = PreLoopBB->getParent(); + LLVMContext &Ctx = PreLoopBB->getContext(); + + Type *LoopOpType = + TTI.getMemcpyLoopLoweringType(Ctx, CopyLen, SrcAlign, DestAlign); + unsigned LoopOpSize = getLoopOperandSizeInBytes(LoopOpType); + + IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); + + unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); + unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); + PointerType *SrcOpType = PointerType::get(LoopOpType, SrcAS); + PointerType *DstOpType = PointerType::get(LoopOpType, DstAS); + if (SrcAddr->getType() != SrcOpType) { + SrcAddr = PLBuilder.CreateBitCast(SrcAddr, SrcOpType); + } + if (DstAddr->getType() != DstOpType) { + DstAddr = PLBuilder.CreateBitCast(DstAddr, DstOpType); + } + + // Calculate the loop trip count, and remaining bytes to copy after the loop. + Type *CopyLenType = CopyLen->getType(); + IntegerType *ILengthType = dyn_cast<IntegerType>(CopyLenType); + assert(ILengthType && + "expected size argument to memcpy to be an integer type!"); + ConstantInt *CILoopOpSize = ConstantInt::get(ILengthType, LoopOpSize); + Value *RuntimeLoopCount = PLBuilder.CreateUDiv(CopyLen, CILoopOpSize); + Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize); + Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual); + + BasicBlock *LoopBB = + BasicBlock::Create(Ctx, "loop-memcpy-expansion", ParentFunc, nullptr); + IRBuilder<> LoopBuilder(LoopBB); + + PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLenType, 2, "loop-index"); + LoopIndex->addIncoming(ConstantInt::get(CopyLenType, 0U), PreLoopBB); + + Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex); + Value *Load = LoopBuilder.CreateLoad(SrcGEP, SrcIsVolatile); + Value *DstGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex); + LoopBuilder.CreateStore(Load, DstGEP, DstIsVolatile); + + Value *NewIndex = + LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLenType, 1U)); + LoopIndex->addIncoming(NewIndex, LoopBB); + + Type *Int8Type = Type::getInt8Ty(Ctx); + if (LoopOpType != Int8Type) { + // Loop body for the residual copy. + BasicBlock *ResLoopBB = BasicBlock::Create(Ctx, "loop-memcpy-residual", + PreLoopBB->getParent(), nullptr); + // Residual loop header. + BasicBlock *ResHeaderBB = BasicBlock::Create( + Ctx, "loop-memcpy-residual-header", PreLoopBB->getParent(), nullptr); + + // Need to update the pre-loop basic block to branch to the correct place. 
+    // branch to the main loop if the count is non-zero, to the residual loop
+    // if the copy size is smaller than one iteration of the main loop but
+    // non-zero, and past the residual loop if the memcpy size is zero.
+    ConstantInt *Zero = ConstantInt::get(ILengthType, 0U);
+    PLBuilder.CreateCondBr(PLBuilder.CreateICmpNE(RuntimeLoopCount, Zero),
+                           LoopBB, ResHeaderBB);
+    PreLoopBB->getTerminator()->eraseFromParent();
+
+    LoopBuilder.CreateCondBr(
+        LoopBuilder.CreateICmpULT(NewIndex, RuntimeLoopCount), LoopBB,
+        ResHeaderBB);
+
+    // Determine if we need to branch to the residual loop or bypass it.
+    IRBuilder<> RHBuilder(ResHeaderBB);
+    RHBuilder.CreateCondBr(RHBuilder.CreateICmpNE(RuntimeResidual, Zero),
+                           ResLoopBB, PostLoopBB);
+
+    // Copy the residual with a single-byte load/store loop.
+    IRBuilder<> ResBuilder(ResLoopBB);
+    PHINode *ResidualIndex =
+        ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index");
+    ResidualIndex->addIncoming(Zero, ResHeaderBB);
+
+    Value *SrcAsInt8 =
+        ResBuilder.CreateBitCast(SrcAddr, PointerType::get(Int8Type, SrcAS));
+    Value *DstAsInt8 =
+        ResBuilder.CreateBitCast(DstAddr, PointerType::get(Int8Type, DstAS));
+    Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex);
+    Value *SrcGEP =
+        ResBuilder.CreateInBoundsGEP(Int8Type, SrcAsInt8, FullOffset);
+    Value *Load = ResBuilder.CreateLoad(SrcGEP, SrcIsVolatile);
+    Value *DstGEP =
+        ResBuilder.CreateInBoundsGEP(Int8Type, DstAsInt8, FullOffset);
+    ResBuilder.CreateStore(Load, DstGEP, DstIsVolatile);
+
+    Value *ResNewIndex =
+        ResBuilder.CreateAdd(ResidualIndex, ConstantInt::get(CopyLenType, 1U));
+    ResidualIndex->addIncoming(ResNewIndex, ResLoopBB);
+
+    // Create the loop branch condition.
+    ResBuilder.CreateCondBr(
+        ResBuilder.CreateICmpULT(ResNewIndex, RuntimeResidual), ResLoopBB,
+        PostLoopBB);
+  } else {
+    // In this case the loop operand type was a byte, and there is no need for
+    // a residual loop to copy the remaining memory after the main loop.
+    // We do, however, need to patch up the control flow by creating the
+    // terminators for the preloop block and the memcpy loop.
+    ConstantInt *Zero = ConstantInt::get(ILengthType, 0U);
+    PLBuilder.CreateCondBr(PLBuilder.CreateICmpNE(RuntimeLoopCount, Zero),
+                           LoopBB, PostLoopBB);
+    PreLoopBB->getTerminator()->eraseFromParent();
+    LoopBuilder.CreateCondBr(
+        LoopBuilder.CreateICmpULT(NewIndex, RuntimeLoopCount), LoopBB,
+        PostLoopBB);
+  }
+}
+
+void llvm::createMemCpyLoop(Instruction *InsertBefore,
+                            Value *SrcAddr, Value *DstAddr, Value *CopyLen,
+                            unsigned SrcAlign, unsigned DestAlign,
+                            bool SrcIsVolatile, bool DstIsVolatile) {
+  Type *TypeOfCopyLen = CopyLen->getType();
+
+  BasicBlock *OrigBB = InsertBefore->getParent();
+  Function *F = OrigBB->getParent();
+  BasicBlock *NewBB =
+      InsertBefore->getParent()->splitBasicBlock(InsertBefore, "split");
+  BasicBlock *LoopBB = BasicBlock::Create(F->getContext(), "loadstoreloop",
+                                          F, NewBB);
+
+  IRBuilder<> Builder(OrigBB->getTerminator());
+
+  // SrcAddr and DstAddr are expected to be pointer types,
+  // so no check is made here.
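+  // The IR built below is roughly of this shape (block names as created
+  // here; value names illustrative):
+  //
+  //   OrigBB:  br i1 (CopyLen == 0), %split, %loadstoreloop
+  //   loadstoreloop:
+  //     %index = phi [ 0, %OrigBB ], [ %index.next, %loadstoreloop ]
+  //     copy one byte from SrcAddr + %index to DstAddr + %index
+  //     %index.next = add %index, 1
+  //     br i1 (%index.next < CopyLen), %loadstoreloop, %split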
+  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
+  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
+
+  // Cast pointers to (char *).
+  SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS));
+  DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS));
+
+  Builder.CreateCondBr(
+      Builder.CreateICmpEQ(ConstantInt::get(TypeOfCopyLen, 0), CopyLen), NewBB,
+      LoopBB);
+  OrigBB->getTerminator()->eraseFromParent();
+
+  IRBuilder<> LoopBuilder(LoopBB);
+  PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
+  LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB);
+
+  // Load from SrcAddr + LoopIndex.
+  // TODO: we can leverage the align parameter of llvm.memcpy for more
+  // efficient word-sized loads and stores.
+  Value *Element =
+      LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP(
+                                 LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex),
+                             SrcIsVolatile);
+  // Store at DstAddr + LoopIndex.
+  LoopBuilder.CreateStore(Element,
+                          LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(),
+                                                        DstAddr, LoopIndex),
+                          DstIsVolatile);
+
+  // The value for LoopIndex coming from the backedge is (LoopIndex + 1).
+  Value *NewIndex =
+      LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1));
+  LoopIndex->addIncoming(NewIndex, LoopBB);
+
+  LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
+                           NewBB);
+}
+
+// Lower memmove to IR. memmove is required to correctly copy overlapping
+// memory regions; therefore, it has to check the relative positions of the
+// source and destination pointers and choose the copy direction accordingly.
+//
+// The code below is an IR rendition of this C function:
+//
+// void* memmove(void* dst, const void* src, size_t n) {
+//   unsigned char* d = dst;
+//   const unsigned char* s = src;
+//   if (s < d) {
+//     // copy backwards
+//     while (n--) {
+//       d[n] = s[n];
+//     }
+//   } else {
+//     // copy forward
+//     for (size_t i = 0; i < n; ++i) {
+//       d[i] = s[i];
+//     }
+//   }
+//   return dst;
+// }
+static void createMemMoveLoop(Instruction *InsertBefore,
+                              Value *SrcAddr, Value *DstAddr, Value *CopyLen,
+                              unsigned SrcAlign, unsigned DestAlign,
+                              bool SrcIsVolatile, bool DstIsVolatile) {
+  Type *TypeOfCopyLen = CopyLen->getType();
+  BasicBlock *OrigBB = InsertBefore->getParent();
+  Function *F = OrigBB->getParent();
+
+  // Create a comparison of src and dst, based on which we jump to either
+  // the forward-copy part of the function (if src >= dst) or the
+  // backwards-copy part (if src < dst).
+  // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
+  // structure. Its block terminators (unconditional branches) are replaced by
+  // the appropriate conditional branches when the loop is built.
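+  //
+  // The overall CFG built below is therefore roughly: OrigBB branches on
+  // (src < dst) to copy_backwards or copy_forward; each of those skips its
+  // loop when n == 0 and otherwise enters copy_backwards_loop or
+  // copy_forward_loop; all paths rejoin in memmove_done.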
+ ICmpInst *PtrCompare = new ICmpInst(InsertBefore, ICmpInst::ICMP_ULT, + SrcAddr, DstAddr, "compare_src_dst"); + TerminatorInst *ThenTerm, *ElseTerm; + SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore, &ThenTerm, + &ElseTerm); + + // Each part of the function consists of two blocks: + // copy_backwards: used to skip the loop when n == 0 + // copy_backwards_loop: the actual backwards loop BB + // copy_forward: used to skip the loop when n == 0 + // copy_forward_loop: the actual forward loop BB + BasicBlock *CopyBackwardsBB = ThenTerm->getParent(); + CopyBackwardsBB->setName("copy_backwards"); + BasicBlock *CopyForwardBB = ElseTerm->getParent(); + CopyForwardBB->setName("copy_forward"); + BasicBlock *ExitBB = InsertBefore->getParent(); + ExitBB->setName("memmove_done"); + + // Initial comparison of n == 0 that lets us skip the loops altogether. Shared + // between both backwards and forward copy clauses. + ICmpInst *CompareN = + new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen, + ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0"); + + // Copying backwards. + BasicBlock *LoopBB = + BasicBlock::Create(F->getContext(), "copy_backwards_loop", F, CopyForwardBB); + IRBuilder<> LoopBuilder(LoopBB); + PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); + Value *IndexPtr = LoopBuilder.CreateSub( + LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr"); + Value *Element = LoopBuilder.CreateLoad( + LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element"); + LoopBuilder.CreateStore(Element, + LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr)); + LoopBuilder.CreateCondBr( + LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)), + ExitBB, LoopBB); + LoopPhi->addIncoming(IndexPtr, LoopBB); + LoopPhi->addIncoming(CopyLen, CopyBackwardsBB); + BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm); + ThenTerm->eraseFromParent(); + + // Copying forward. 
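+  // This mirrors the backwards loop above, but the index counts up from zero
+  // and the loop exits once the incremented index reaches CopyLen.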
+ BasicBlock *FwdLoopBB = + BasicBlock::Create(F->getContext(), "copy_forward_loop", F, ExitBB); + IRBuilder<> FwdLoopBuilder(FwdLoopBB); + PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr"); + Value *FwdElement = FwdLoopBuilder.CreateLoad( + FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element"); + FwdLoopBuilder.CreateStore( + FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi)); + Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd( + FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment"); + FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen), + ExitBB, FwdLoopBB); + FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB); + FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB); + + BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm); + ElseTerm->eraseFromParent(); +} + +static void createMemSetLoop(Instruction *InsertBefore, + Value *DstAddr, Value *CopyLen, Value *SetValue, + unsigned Align, bool IsVolatile) { + Type *TypeOfCopyLen = CopyLen->getType(); + BasicBlock *OrigBB = InsertBefore->getParent(); + Function *F = OrigBB->getParent(); + BasicBlock *NewBB = + OrigBB->splitBasicBlock(InsertBefore, "split"); + BasicBlock *LoopBB + = BasicBlock::Create(F->getContext(), "loadstoreloop", F, NewBB); + + IRBuilder<> Builder(OrigBB->getTerminator()); + + // Cast pointer to the type of value getting stored + unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); + DstAddr = Builder.CreateBitCast(DstAddr, + PointerType::get(SetValue->getType(), dstAS)); + + Builder.CreateCondBr( + Builder.CreateICmpEQ(ConstantInt::get(TypeOfCopyLen, 0), CopyLen), NewBB, + LoopBB); + OrigBB->getTerminator()->eraseFromParent(); + + IRBuilder<> LoopBuilder(LoopBB); + PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); + LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB); + + LoopBuilder.CreateStore( + SetValue, + LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex), + IsVolatile); + + Value *NewIndex = + LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1)); + LoopIndex->addIncoming(NewIndex, LoopBB); + + LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB, + NewBB); +} + +void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, + const TargetTransformInfo &TTI) { + // Original implementation + if (!TTI.useWideIRMemcpyLoopLowering()) { + createMemCpyLoop(/* InsertBefore */ Memcpy, + /* SrcAddr */ Memcpy->getRawSource(), + /* DstAddr */ Memcpy->getRawDest(), + /* CopyLen */ Memcpy->getLength(), + /* SrcAlign */ Memcpy->getAlignment(), + /* DestAlign */ Memcpy->getAlignment(), + /* SrcIsVolatile */ Memcpy->isVolatile(), + /* DstIsVolatile */ Memcpy->isVolatile()); + } else { + if (ConstantInt *CI = dyn_cast<ConstantInt>(Memcpy->getLength())) { + createMemCpyLoopKnownSize(/* InsertBefore */ Memcpy, + /* SrcAddr */ Memcpy->getRawSource(), + /* DstAddr */ Memcpy->getRawDest(), + /* CopyLen */ CI, + /* SrcAlign */ Memcpy->getAlignment(), + /* DestAlign */ Memcpy->getAlignment(), + /* SrcIsVolatile */ Memcpy->isVolatile(), + /* DstIsVolatile */ Memcpy->isVolatile(), + /* TargetTransformInfo */ TTI); + } else { + createMemCpyLoopUnknownSize(/* InsertBefore */ Memcpy, + /* SrcAddr */ Memcpy->getRawSource(), + /* DstAddr */ Memcpy->getRawDest(), + /* CopyLen */ Memcpy->getLength(), + /* SrcAlign */ Memcpy->getAlignment(), + /* DestAlign */ Memcpy->getAlignment(), + /* SrcIsVolatile */ Memcpy->isVolatile(), + /* 
DstIsVolatile */ Memcpy->isVolatile(),
+                                  /* TargetTransformInfo */ TTI);
+    }
+  }
+}
+
+void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) {
+  createMemMoveLoop(/* InsertBefore */ Memmove,
+                    /* SrcAddr */ Memmove->getRawSource(),
+                    /* DstAddr */ Memmove->getRawDest(),
+                    /* CopyLen */ Memmove->getLength(),
+                    /* SrcAlign */ Memmove->getAlignment(),
+                    /* DestAlign */ Memmove->getAlignment(),
+                    /* SrcIsVolatile */ Memmove->isVolatile(),
+                    /* DstIsVolatile */ Memmove->isVolatile());
+}
+
+void llvm::expandMemSetAsLoop(MemSetInst *Memset) {
+  createMemSetLoop(/* InsertBefore */ Memset,
+                   /* DstAddr */ Memset->getRawDest(),
+                   /* CopyLen */ Memset->getLength(),
+                   /* SetValue */ Memset->getValue(),
+                   /* Alignment */ Memset->getAlignment(),
+                   Memset->isVolatile());
+}
diff --git a/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp
index 75cd3bc..890afbc 100644
--- a/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp
@@ -13,7 +13,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/Transforms/Scalar.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
@@ -24,6 +23,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
 #include <algorithm>
@@ -356,10 +356,10 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
   unsigned numCmps = 0;
 
   // Start with "simple" cases
-  for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i)
-    Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(),
-                              i.getCaseSuccessor()));
-
+  for (auto Case : SI->cases())
+    Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(),
+                              Case.getCaseSuccessor()));
+
   std::sort(Cases.begin(), Cases.end(), CaseCmp());
 
   // Merge case into clusters
@@ -403,6 +403,14 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI,
   Value *Val = SI->getCondition();  // The value we are switching on...
   BasicBlock* Default = SI->getDefaultDest();
 
+  // Don't handle unreachable blocks. If there are successors with phis, this
+  // would leave them behind with missing predecessors.
+  if ((CurBlock != &F->getEntryBlock() && pred_empty(CurBlock)) ||
+      CurBlock->getSinglePredecessor() == CurBlock) {
+    DeleteList.insert(CurBlock);
+    return;
+  }
+
   // If there is only the default destination, just branch.
   if (!SI->getNumCases()) {
     BranchInst::Create(Default, CurBlock);
diff --git a/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp b/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp
index 24b3b12..b659a2e 100644
--- a/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp
+++ b/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp
@@ -46,7 +46,7 @@ static bool promoteMemoryToRegister(Function &F, DominatorTree &DT,
     if (Allocas.empty())
       break;
 
-    PromoteMemToReg(Allocas, DT, nullptr, &AC);
+    PromoteMemToReg(Allocas, DT, &AC);
     NumPromoted += Allocas.size();
     Changed = true;
   }
@@ -59,8 +59,9 @@ PreservedAnalyses PromotePass::run(Function &F, FunctionAnalysisManager &AM) {
   if (!promoteMemoryToRegister(F, DT, AC))
     return PreservedAnalyses::all();
 
-  // FIXME: This should also 'preserve the CFG'.
- return PreservedAnalyses::none(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } namespace { diff --git a/contrib/llvm/lib/Transforms/Utils/MemorySSA.cpp b/contrib/llvm/lib/Transforms/Utils/MemorySSA.cpp deleted file mode 100644 index 1ce4225..0000000 --- a/contrib/llvm/lib/Transforms/Utils/MemorySSA.cpp +++ /dev/null @@ -1,2305 +0,0 @@ -//===-- MemorySSA.cpp - Memory SSA Builder---------------------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------===// -// -// This file implements the MemorySSA class. -// -//===----------------------------------------------------------------===// -#include "llvm/Transforms/Utils/MemorySSA.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" -#include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/IteratedDominanceFrontier.h" -#include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/PHITransAddr.h" -#include "llvm/IR/AssemblyAnnotationWriter.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Metadata.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/PatternMatch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormattedStream.h" -#include "llvm/Transforms/Scalar.h" -#include <algorithm> - -#define DEBUG_TYPE "memoryssa" -using namespace llvm; -STATISTIC(NumClobberCacheLookups, "Number of Memory SSA version cache lookups"); -STATISTIC(NumClobberCacheHits, "Number of Memory SSA version cache hits"); -STATISTIC(NumClobberCacheInserts, "Number of MemorySSA version cache inserts"); - -INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, - true) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, - true) - -INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa", - "Memory SSA Printer", false, false) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa", - "Memory SSA Printer", false, false) - -static cl::opt<unsigned> MaxCheckLimit( - "memssa-check-limit", cl::Hidden, cl::init(100), - cl::desc("The maximum number of stores/phis MemorySSA" - "will consider trying to walk past (default = 100)")); - -static cl::opt<bool> - VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden, - cl::desc("Verify MemorySSA in legacy printer pass.")); - -namespace llvm { -/// \brief An assembly annotator class to print Memory SSA information in -/// comments. 
-class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter { - friend class MemorySSA; - const MemorySSA *MSSA; - -public: - MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {} - - virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, - formatted_raw_ostream &OS) { - if (MemoryAccess *MA = MSSA->getMemoryAccess(BB)) - OS << "; " << *MA << "\n"; - } - - virtual void emitInstructionAnnot(const Instruction *I, - formatted_raw_ostream &OS) { - if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) - OS << "; " << *MA << "\n"; - } -}; -} - -namespace { -/// Our current alias analysis API differentiates heavily between calls and -/// non-calls, and functions called on one usually assert on the other. -/// This class encapsulates the distinction to simplify other code that wants -/// "Memory affecting instructions and related data" to use as a key. -/// For example, this class is used as a densemap key in the use optimizer. -class MemoryLocOrCall { -public: - MemoryLocOrCall() : IsCall(false) {} - MemoryLocOrCall(MemoryUseOrDef *MUD) - : MemoryLocOrCall(MUD->getMemoryInst()) {} - MemoryLocOrCall(const MemoryUseOrDef *MUD) - : MemoryLocOrCall(MUD->getMemoryInst()) {} - - MemoryLocOrCall(Instruction *Inst) { - if (ImmutableCallSite(Inst)) { - IsCall = true; - CS = ImmutableCallSite(Inst); - } else { - IsCall = false; - // There is no such thing as a memorylocation for a fence inst, and it is - // unique in that regard. - if (!isa<FenceInst>(Inst)) - Loc = MemoryLocation::get(Inst); - } - } - - explicit MemoryLocOrCall(const MemoryLocation &Loc) - : IsCall(false), Loc(Loc) {} - - bool IsCall; - ImmutableCallSite getCS() const { - assert(IsCall); - return CS; - } - MemoryLocation getLoc() const { - assert(!IsCall); - return Loc; - } - - bool operator==(const MemoryLocOrCall &Other) const { - if (IsCall != Other.IsCall) - return false; - - if (IsCall) - return CS.getCalledValue() == Other.CS.getCalledValue(); - return Loc == Other.Loc; - } - -private: - union { - ImmutableCallSite CS; - MemoryLocation Loc; - }; -}; -} - -namespace llvm { -template <> struct DenseMapInfo<MemoryLocOrCall> { - static inline MemoryLocOrCall getEmptyKey() { - return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey()); - } - static inline MemoryLocOrCall getTombstoneKey() { - return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey()); - } - static unsigned getHashValue(const MemoryLocOrCall &MLOC) { - if (MLOC.IsCall) - return hash_combine(MLOC.IsCall, - DenseMapInfo<const Value *>::getHashValue( - MLOC.getCS().getCalledValue())); - return hash_combine( - MLOC.IsCall, DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc())); - } - static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) { - return LHS == RHS; - } -}; - -enum class Reorderability { Always, IfNoAlias, Never }; - -/// This does one-way checks to see if Use could theoretically be hoisted above -/// MayClobber. This will not check the other way around. -/// -/// This assumes that, for the purposes of MemorySSA, Use comes directly after -/// MayClobber, with no potentially clobbering operations in between them. -/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.) -static Reorderability getLoadReorderability(const LoadInst *Use, - const LoadInst *MayClobber) { - bool VolatileUse = Use->isVolatile(); - bool VolatileClobber = MayClobber->isVolatile(); - // Volatile operations may never be reordered with other volatile operations. 
- if (VolatileUse && VolatileClobber) - return Reorderability::Never; - - // The lang ref allows reordering of volatile and non-volatile operations. - // Whether an aliasing nonvolatile load and volatile load can be reordered, - // though, is ambiguous. Because it may not be best to exploit this ambiguity, - // we only allow volatile/non-volatile reordering if the volatile and - // non-volatile operations don't alias. - Reorderability Result = VolatileUse || VolatileClobber - ? Reorderability::IfNoAlias - : Reorderability::Always; - - // If a load is seq_cst, it cannot be moved above other loads. If its ordering - // is weaker, it can be moved above other loads. We just need to be sure that - // MayClobber isn't an acquire load, because loads can't be moved above - // acquire loads. - // - // Note that this explicitly *does* allow the free reordering of monotonic (or - // weaker) loads of the same address. - bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent; - bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(), - AtomicOrdering::Acquire); - if (SeqCstUse || MayClobberIsAcquire) - return Reorderability::Never; - return Result; -} - -static bool instructionClobbersQuery(MemoryDef *MD, - const MemoryLocation &UseLoc, - const Instruction *UseInst, - AliasAnalysis &AA) { - Instruction *DefInst = MD->getMemoryInst(); - assert(DefInst && "Defining instruction not actually an instruction"); - - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) { - // These intrinsics will show up as affecting memory, but they are just - // markers. - switch (II->getIntrinsicID()) { - case Intrinsic::lifetime_start: - case Intrinsic::lifetime_end: - case Intrinsic::invariant_start: - case Intrinsic::invariant_end: - case Intrinsic::assume: - return false; - default: - break; - } - } - - ImmutableCallSite UseCS(UseInst); - if (UseCS) { - ModRefInfo I = AA.getModRefInfo(DefInst, UseCS); - return I != MRI_NoModRef; - } - - if (auto *DefLoad = dyn_cast<LoadInst>(DefInst)) { - if (auto *UseLoad = dyn_cast<LoadInst>(UseInst)) { - switch (getLoadReorderability(UseLoad, DefLoad)) { - case Reorderability::Always: - return false; - case Reorderability::Never: - return true; - case Reorderability::IfNoAlias: - return !AA.isNoAlias(UseLoc, MemoryLocation::get(DefLoad)); - } - } - } - - return AA.getModRefInfo(DefInst, UseLoc) & MRI_Mod; -} - -static bool instructionClobbersQuery(MemoryDef *MD, const MemoryUseOrDef *MU, - const MemoryLocOrCall &UseMLOC, - AliasAnalysis &AA) { - // FIXME: This is a temporary hack to allow a single instructionClobbersQuery - // to exist while MemoryLocOrCall is pushed through places. - if (UseMLOC.IsCall) - return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(), - AA); - return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(), - AA); -} - -// Return true when MD may alias MU, return false otherwise. -bool defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU, - AliasAnalysis &AA) { - return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA); -} -} - -namespace { -struct UpwardsMemoryQuery { - // True if our original query started off as a call - bool IsCall; - // The pointer location we started the query with. This will be empty if - // IsCall is true. - MemoryLocation StartingLoc; - // This is the instruction we were querying about. 
- const Instruction *Inst; - // The MemoryAccess we actually got called with, used to test local domination - const MemoryAccess *OriginalAccess; - - UpwardsMemoryQuery() - : IsCall(false), Inst(nullptr), OriginalAccess(nullptr) {} - - UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) - : IsCall(ImmutableCallSite(Inst)), Inst(Inst), OriginalAccess(Access) { - if (!IsCall) - StartingLoc = MemoryLocation::get(Inst); - } -}; - -static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc, - AliasAnalysis &AA) { - Instruction *Inst = MD->getMemoryInst(); - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { - switch (II->getIntrinsicID()) { - case Intrinsic::lifetime_start: - case Intrinsic::lifetime_end: - return AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), Loc); - default: - return false; - } - } - return false; -} - -static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysis &AA, - const Instruction *I) { - // If the memory can't be changed, then loads of the memory can't be - // clobbered. - // - // FIXME: We should handle invariant groups, as well. It's a bit harder, - // because we need to pay close attention to invariant group barriers. - return isa<LoadInst>(I) && (I->getMetadata(LLVMContext::MD_invariant_load) || - AA.pointsToConstantMemory(I)); -} - -/// Cache for our caching MemorySSA walker. -class WalkerCache { - DenseMap<ConstMemoryAccessPair, MemoryAccess *> Accesses; - DenseMap<const MemoryAccess *, MemoryAccess *> Calls; - -public: - MemoryAccess *lookup(const MemoryAccess *MA, const MemoryLocation &Loc, - bool IsCall) const { - ++NumClobberCacheLookups; - MemoryAccess *R = IsCall ? Calls.lookup(MA) : Accesses.lookup({MA, Loc}); - if (R) - ++NumClobberCacheHits; - return R; - } - - bool insert(const MemoryAccess *MA, MemoryAccess *To, - const MemoryLocation &Loc, bool IsCall) { - // This is fine for Phis, since there are times where we can't optimize - // them. Making a def its own clobber is never correct, though. - assert((MA != To || isa<MemoryPhi>(MA)) && - "Something can't clobber itself!"); - - ++NumClobberCacheInserts; - bool Inserted; - if (IsCall) - Inserted = Calls.insert({MA, To}).second; - else - Inserted = Accesses.insert({{MA, Loc}, To}).second; - - return Inserted; - } - - bool remove(const MemoryAccess *MA, const MemoryLocation &Loc, bool IsCall) { - return IsCall ? Calls.erase(MA) : Accesses.erase({MA, Loc}); - } - - void clear() { - Accesses.clear(); - Calls.clear(); - } - - bool contains(const MemoryAccess *MA) const { - for (auto &P : Accesses) - if (P.first.first == MA || P.second == MA) - return true; - for (auto &P : Calls) - if (P.first == MA || P.second == MA) - return true; - return false; - } -}; - -/// Walks the defining uses of MemoryDefs. Stops after we hit something that has -/// no defining use (e.g. a MemoryPhi or liveOnEntry). Note that, when comparing -/// against a null def_chain_iterator, this will compare equal only after -/// walking said Phi/liveOnEntry. -struct def_chain_iterator - : public iterator_facade_base<def_chain_iterator, std::forward_iterator_tag, - MemoryAccess *> { - def_chain_iterator() : MA(nullptr) {} - def_chain_iterator(MemoryAccess *MA) : MA(MA) {} - - MemoryAccess *operator*() const { return MA; } - - def_chain_iterator &operator++() { - // N.B. liveOnEntry has a null defining access. 
-    if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA))
-      MA = MUD->getDefiningAccess();
-    else
-      MA = nullptr;
-    return *this;
-  }
-
-  bool operator==(const def_chain_iterator &O) const { return MA == O.MA; }
-
-private:
-  MemoryAccess *MA;
-};
-
-static iterator_range<def_chain_iterator>
-def_chain(MemoryAccess *MA, MemoryAccess *UpTo = nullptr) {
-#ifdef EXPENSIVE_CHECKS
-  assert((!UpTo || find(def_chain(MA), UpTo) != def_chain_iterator()) &&
-         "UpTo isn't in the def chain!");
-#endif
-  return make_range(def_chain_iterator(MA), def_chain_iterator(UpTo));
-}
-
-/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing in
-/// between `Start` and `ClobberAt` can clobber `Start`.
-///
-/// This is meant to be as simple and self-contained as possible. Because it
-/// uses no cache, etc., it can be relatively expensive.
-///
-/// \param Start The MemoryAccess that we want to walk from.
-/// \param ClobberAt A clobber for Start.
-/// \param StartLoc The MemoryLocation for Start.
-/// \param MSSA The MemorySSA instance that Start and ClobberAt belong to.
-/// \param Query The UpwardsMemoryQuery we used for our search.
-/// \param AA The AliasAnalysis we used for our search.
-static void LLVM_ATTRIBUTE_UNUSED
-checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt,
-                   const MemoryLocation &StartLoc, const MemorySSA &MSSA,
-                   const UpwardsMemoryQuery &Query, AliasAnalysis &AA) {
-  assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?");
-
-  if (MSSA.isLiveOnEntryDef(Start)) {
-    assert(MSSA.isLiveOnEntryDef(ClobberAt) &&
-           "liveOnEntry must clobber itself");
-    return;
-  }
-
-  bool FoundClobber = false;
-  DenseSet<MemoryAccessPair> VisitedPhis;
-  SmallVector<MemoryAccessPair, 8> Worklist;
-  Worklist.emplace_back(Start, StartLoc);
-  // Walk all paths from Start to ClobberAt, while looking for clobbers. If one
-  // is found, complain.
-  while (!Worklist.empty()) {
-    MemoryAccessPair MAP = Worklist.pop_back_val();
-    // All we care about is that nothing from Start to ClobberAt clobbers
-    // Start. We learn nothing from revisiting nodes.
-    if (!VisitedPhis.insert(MAP).second)
-      continue;
-
-    for (MemoryAccess *MA : def_chain(MAP.first)) {
-      if (MA == ClobberAt) {
-        if (auto *MD = dyn_cast<MemoryDef>(MA)) {
-          // instructionClobbersQuery isn't essentially free, so don't use `|=`,
-          // since it won't let us short-circuit.
-          //
-          // Also, note that this can't be hoisted out of the `Worklist` loop,
-          // since MD may only act as a clobber for 1 of N MemoryLocations.
-          FoundClobber =
-              FoundClobber || MSSA.isLiveOnEntryDef(MD) ||
-              instructionClobbersQuery(MD, MAP.second, Query.Inst, AA);
-        }
-        break;
-      }
-
-      // We should never hit liveOnEntry, unless it's the clobber.
-      assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?");
-
-      if (auto *MD = dyn_cast<MemoryDef>(MA)) {
-        (void)MD;
-        assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) &&
-               "Found clobber before reaching ClobberAt!");
-        continue;
-      }
-
-      assert(isa<MemoryPhi>(MA));
-      Worklist.append(upward_defs_begin({MA, MAP.second}), upward_defs_end());
-    }
-  }
-
-  // If ClobberAt is a MemoryPhi, we can assume something above it acted as a
-  // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point.
-  assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) &&
-         "ClobberAt never acted as a clobber");
-}
-
-/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up
-/// in one class.
-class ClobberWalker { - /// Save a few bytes by using unsigned instead of size_t. - using ListIndex = unsigned; - - /// Represents a span of contiguous MemoryDefs, potentially ending in a - /// MemoryPhi. - struct DefPath { - MemoryLocation Loc; - // Note that, because we always walk in reverse, Last will always dominate - // First. Also note that First and Last are inclusive. - MemoryAccess *First; - MemoryAccess *Last; - Optional<ListIndex> Previous; - - DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last, - Optional<ListIndex> Previous) - : Loc(Loc), First(First), Last(Last), Previous(Previous) {} - - DefPath(const MemoryLocation &Loc, MemoryAccess *Init, - Optional<ListIndex> Previous) - : DefPath(Loc, Init, Init, Previous) {} - }; - - const MemorySSA &MSSA; - AliasAnalysis &AA; - DominatorTree &DT; - WalkerCache &WC; - UpwardsMemoryQuery *Query; - bool UseCache; - - // Phi optimization bookkeeping - SmallVector<DefPath, 32> Paths; - DenseSet<ConstMemoryAccessPair> VisitedPhis; - DenseMap<const BasicBlock *, MemoryAccess *> WalkTargetCache; - - void setUseCache(bool Use) { UseCache = Use; } - bool shouldIgnoreCache() const { - // UseCache will only be false when we're debugging, or when expensive - // checks are enabled. In either case, we don't care deeply about speed. - return LLVM_UNLIKELY(!UseCache); - } - - void addCacheEntry(const MemoryAccess *What, MemoryAccess *To, - const MemoryLocation &Loc) const { -// EXPENSIVE_CHECKS because most of these queries are redundant. -#ifdef EXPENSIVE_CHECKS - assert(MSSA.dominates(To, What)); -#endif - if (shouldIgnoreCache()) - return; - WC.insert(What, To, Loc, Query->IsCall); - } - - MemoryAccess *lookupCache(const MemoryAccess *MA, const MemoryLocation &Loc) { - return shouldIgnoreCache() ? nullptr : WC.lookup(MA, Loc, Query->IsCall); - } - - void cacheDefPath(const DefPath &DN, MemoryAccess *Target) const { - if (shouldIgnoreCache()) - return; - - for (MemoryAccess *MA : def_chain(DN.First, DN.Last)) - addCacheEntry(MA, Target, DN.Loc); - - // DefPaths only express the path we walked. So, DN.Last could either be a - // thing we want to cache, or not. - if (DN.Last != Target) - addCacheEntry(DN.Last, Target, DN.Loc); - } - - /// Find the nearest def or phi that `From` can legally be optimized to. - /// - /// FIXME: Deduplicate this with MSSA::findDominatingDef. Ideally, MSSA should - /// keep track of this information for us, and allow us O(1) lookups of this - /// info. - MemoryAccess *getWalkTarget(const MemoryPhi *From) { - assert(From->getNumOperands() && "Phi with no operands?"); - - BasicBlock *BB = From->getBlock(); - auto At = WalkTargetCache.find(BB); - if (At != WalkTargetCache.end()) - return At->second; - - SmallVector<const BasicBlock *, 8> ToCache; - ToCache.push_back(BB); - - MemoryAccess *Result = MSSA.getLiveOnEntryDef(); - DomTreeNode *Node = DT.getNode(BB); - while ((Node = Node->getIDom())) { - auto At = WalkTargetCache.find(BB); - if (At != WalkTargetCache.end()) { - Result = At->second; - break; - } - - auto *Accesses = MSSA.getBlockAccesses(Node->getBlock()); - if (Accesses) { - auto Iter = find_if(reverse(*Accesses), [](const MemoryAccess &MA) { - return !isa<MemoryUse>(MA); - }); - if (Iter != Accesses->rend()) { - Result = const_cast<MemoryAccess *>(&*Iter); - break; - } - } - - ToCache.push_back(Node->getBlock()); - } - - for (const BasicBlock *BB : ToCache) - WalkTargetCache.insert({BB, Result}); - return Result; - } - - /// Result of calling walkToPhiOrClobber. 
- struct UpwardsWalkResult { - /// The "Result" of the walk. Either a clobber, the last thing we walked, or - /// both. - MemoryAccess *Result; - bool IsKnownClobber; - bool FromCache; - }; - - /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last. - /// This will update Desc.Last as it walks. It will (optionally) also stop at - /// StopAt. - /// - /// This does not test for whether StopAt is a clobber - UpwardsWalkResult walkToPhiOrClobber(DefPath &Desc, - MemoryAccess *StopAt = nullptr) { - assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world"); - - for (MemoryAccess *Current : def_chain(Desc.Last)) { - Desc.Last = Current; - if (Current == StopAt) - return {Current, false, false}; - - if (auto *MD = dyn_cast<MemoryDef>(Current)) - if (MSSA.isLiveOnEntryDef(MD) || - instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA)) - return {MD, true, false}; - - // Cache checks must be done last, because if Current is a clobber, the - // cache will contain the clobber for Current. - if (MemoryAccess *MA = lookupCache(Current, Desc.Loc)) - return {MA, true, true}; - } - - assert(isa<MemoryPhi>(Desc.Last) && - "Ended at a non-clobber that's not a phi?"); - return {Desc.Last, false, false}; - } - - void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches, - ListIndex PriorNode) { - auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}), - upward_defs_end()); - for (const MemoryAccessPair &P : UpwardDefs) { - PausedSearches.push_back(Paths.size()); - Paths.emplace_back(P.second, P.first, PriorNode); - } - } - - /// Represents a search that terminated after finding a clobber. This clobber - /// may or may not be present in the path of defs from LastNode..SearchStart, - /// since it may have been retrieved from cache. - struct TerminatedPath { - MemoryAccess *Clobber; - ListIndex LastNode; - }; - - /// Get an access that keeps us from optimizing to the given phi. - /// - /// PausedSearches is an array of indices into the Paths array. Its incoming - /// value is the indices of searches that stopped at the last phi optimization - /// target. It's left in an unspecified state. - /// - /// If this returns None, NewPaused is a vector of searches that terminated - /// at StopWhere. Otherwise, NewPaused is left in an unspecified state. - Optional<TerminatedPath> - getBlockingAccess(MemoryAccess *StopWhere, - SmallVectorImpl<ListIndex> &PausedSearches, - SmallVectorImpl<ListIndex> &NewPaused, - SmallVectorImpl<TerminatedPath> &Terminated) { - assert(!PausedSearches.empty() && "No searches to continue?"); - - // BFS vs DFS really doesn't make a difference here, so just do a DFS with - // PausedSearches as our stack. - while (!PausedSearches.empty()) { - ListIndex PathIndex = PausedSearches.pop_back_val(); - DefPath &Node = Paths[PathIndex]; - - // If we've already visited this path with this MemoryLocation, we don't - // need to do so again. - // - // NOTE: That we just drop these paths on the ground makes caching - // behavior sporadic. e.g. given a diamond: - // A - // B C - // D - // - // ...If we walk D, B, A, C, we'll only cache the result of phi - // optimization for A, B, and D; C will be skipped because it dies here. - // This arguably isn't the worst thing ever, since: - // - We generally query things in a top-down order, so if we got below D - // without needing cache entries for {C, MemLoc}, then chances are - // that those cache entries would end up ultimately unused. 
- // - We still cache things for A, so C only needs to walk up a bit. - // If this behavior becomes problematic, we can fix without a ton of extra - // work. - if (!VisitedPhis.insert({Node.Last, Node.Loc}).second) - continue; - - UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere); - if (Res.IsKnownClobber) { - assert(Res.Result != StopWhere || Res.FromCache); - // If this wasn't a cache hit, we hit a clobber when walking. That's a - // failure. - TerminatedPath Term{Res.Result, PathIndex}; - if (!Res.FromCache || !MSSA.dominates(Res.Result, StopWhere)) - return Term; - - // Otherwise, it's a valid thing to potentially optimize to. - Terminated.push_back(Term); - continue; - } - - if (Res.Result == StopWhere) { - // We've hit our target. Save this path off for if we want to continue - // walking. - NewPaused.push_back(PathIndex); - continue; - } - - assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber"); - addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex); - } - - return None; - } - - template <typename T, typename Walker> - struct generic_def_path_iterator - : public iterator_facade_base<generic_def_path_iterator<T, Walker>, - std::forward_iterator_tag, T *> { - generic_def_path_iterator() : W(nullptr), N(None) {} - generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {} - - T &operator*() const { return curNode(); } - - generic_def_path_iterator &operator++() { - N = curNode().Previous; - return *this; - } - - bool operator==(const generic_def_path_iterator &O) const { - if (N.hasValue() != O.N.hasValue()) - return false; - return !N.hasValue() || *N == *O.N; - } - - private: - T &curNode() const { return W->Paths[*N]; } - - Walker *W; - Optional<ListIndex> N; - }; - - using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>; - using const_def_path_iterator = - generic_def_path_iterator<const DefPath, const ClobberWalker>; - - iterator_range<def_path_iterator> def_path(ListIndex From) { - return make_range(def_path_iterator(this, From), def_path_iterator()); - } - - iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const { - return make_range(const_def_path_iterator(this, From), - const_def_path_iterator()); - } - - struct OptznResult { - /// The path that contains our result. - TerminatedPath PrimaryClobber; - /// The paths that we can legally cache back from, but that aren't - /// necessarily the result of the Phi optimization. - SmallVector<TerminatedPath, 4> OtherClobbers; - }; - - ListIndex defPathIndex(const DefPath &N) const { - // The assert looks nicer if we don't need to do &N - const DefPath *NP = &N; - assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() && - "Out of bounds DefPath!"); - return NP - &Paths.front(); - } - - /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths - /// that act as legal clobbers. Note that this won't return *all* clobbers. - /// - /// Phi optimization algorithm tl;dr: - /// - Find the earliest def/phi, A, we can optimize to - /// - Find if all paths from the starting memory access ultimately reach A - /// - If not, optimization isn't possible. - /// - Otherwise, walk from A to another clobber or phi, A'. - /// - If A' is a def, we're done. - /// - If A' is a phi, try to optimize it. - /// - /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path - /// terminates when a MemoryAccess that clobbers said MemoryLocation is found. 
- OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start, - const MemoryLocation &Loc) { - assert(Paths.empty() && VisitedPhis.empty() && - "Reset the optimization state."); - - Paths.emplace_back(Loc, Start, Phi, None); - // Stores how many "valid" optimization nodes we had prior to calling - // addSearches/getBlockingAccess. Necessary for caching if we had a blocker. - auto PriorPathsSize = Paths.size(); - - SmallVector<ListIndex, 16> PausedSearches; - SmallVector<ListIndex, 8> NewPaused; - SmallVector<TerminatedPath, 4> TerminatedPaths; - - addSearches(Phi, PausedSearches, 0); - - // Moves the TerminatedPath with the "most dominated" Clobber to the end of - // Paths. - auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) { - assert(!Paths.empty() && "Need a path to move"); - auto Dom = Paths.begin(); - for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I) - if (!MSSA.dominates(I->Clobber, Dom->Clobber)) - Dom = I; - auto Last = Paths.end() - 1; - if (Last != Dom) - std::iter_swap(Last, Dom); - }; - - MemoryPhi *Current = Phi; - while (1) { - assert(!MSSA.isLiveOnEntryDef(Current) && - "liveOnEntry wasn't treated as a clobber?"); - - MemoryAccess *Target = getWalkTarget(Current); - // If a TerminatedPath doesn't dominate Target, then it wasn't a legal - // optimization for the prior phi. - assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) { - return MSSA.dominates(P.Clobber, Target); - })); - - // FIXME: This is broken, because the Blocker may be reported to be - // liveOnEntry, and we'll happily wait for that to disappear (read: never) - // For the moment, this is fine, since we do nothing with blocker info. - if (Optional<TerminatedPath> Blocker = getBlockingAccess( - Target, PausedSearches, NewPaused, TerminatedPaths)) { - // Cache our work on the blocking node, since we know that's correct. - cacheDefPath(Paths[Blocker->LastNode], Blocker->Clobber); - - // Find the node we started at. We can't search based on N->Last, since - // we may have gone around a loop with a different MemoryLocation. - auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) { - return defPathIndex(N) < PriorPathsSize; - }); - assert(Iter != def_path_iterator()); - - DefPath &CurNode = *Iter; - assert(CurNode.Last == Current); - - // Two things: - // A. We can't reliably cache all of NewPaused back. Consider a case - // where we have two paths in NewPaused; one of which can't optimize - // above this phi, whereas the other can. If we cache the second path - // back, we'll end up with suboptimal cache entries. We can handle - // cases like this a bit better when we either try to find all - // clobbers that block phi optimization, or when our cache starts - // supporting unfinished searches. - // B. We can't reliably cache TerminatedPaths back here without doing - // extra checks; consider a case like: - // T - // / \ - // D C - // \ / - // S - // Where T is our target, C is a node with a clobber on it, D is a - // diamond (with a clobber *only* on the left or right node, N), and - // S is our start. Say we walk to D, through the node opposite N - // (read: ignoring the clobber), and see a cache entry in the top - // node of D. That cache entry gets put into TerminatedPaths. We then - // walk up to C (N is later in our worklist), find the clobber, and - // quit. If we append TerminatedPaths to OtherClobbers, we'll cache - // the bottom part of D to the cached clobber, ignoring the clobber - // in N. 
Again, this problem goes away if we start tracking all - // blockers for a given phi optimization. - TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)}; - return {Result, {}}; - } - - // If there's nothing left to search, then all paths led to valid clobbers - // that we got from our cache; pick the nearest to the start, and allow - // the rest to be cached back. - if (NewPaused.empty()) { - MoveDominatedPathToEnd(TerminatedPaths); - TerminatedPath Result = TerminatedPaths.pop_back_val(); - return {Result, std::move(TerminatedPaths)}; - } - - MemoryAccess *DefChainEnd = nullptr; - SmallVector<TerminatedPath, 4> Clobbers; - for (ListIndex Paused : NewPaused) { - UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]); - if (WR.IsKnownClobber) - Clobbers.push_back({WR.Result, Paused}); - else - // Micro-opt: If we hit the end of the chain, save it. - DefChainEnd = WR.Result; - } - - if (!TerminatedPaths.empty()) { - // If we couldn't find the dominating phi/liveOnEntry in the above loop, - // do it now. - if (!DefChainEnd) - for (MemoryAccess *MA : def_chain(Target)) - DefChainEnd = MA; - - // If any of the terminated paths don't dominate the phi we'll try to - // optimize, we need to figure out what they are and quit. - const BasicBlock *ChainBB = DefChainEnd->getBlock(); - for (const TerminatedPath &TP : TerminatedPaths) { - // Because we know that DefChainEnd is as "high" as we can go, we - // don't need local dominance checks; BB dominance is sufficient. - if (DT.dominates(ChainBB, TP.Clobber->getBlock())) - Clobbers.push_back(TP); - } - } - - // If we have clobbers in the def chain, find the one closest to Current - // and quit. - if (!Clobbers.empty()) { - MoveDominatedPathToEnd(Clobbers); - TerminatedPath Result = Clobbers.pop_back_val(); - return {Result, std::move(Clobbers)}; - } - - assert(all_of(NewPaused, - [&](ListIndex I) { return Paths[I].Last == DefChainEnd; })); - - // Because liveOnEntry is a clobber, this must be a phi. - auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd); - - PriorPathsSize = Paths.size(); - PausedSearches.clear(); - for (ListIndex I : NewPaused) - addSearches(DefChainPhi, PausedSearches, I); - NewPaused.clear(); - - Current = DefChainPhi; - } - } - - /// Caches everything in an OptznResult. - void cacheOptResult(const OptznResult &R) { - if (R.OtherClobbers.empty()) { - // If we're not going to be caching OtherClobbers, don't bother with - // marking visited/etc. - for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode)) - cacheDefPath(N, R.PrimaryClobber.Clobber); - return; - } - - // PrimaryClobber is our answer. If we can cache anything back, we need to - // stop caching when we visit PrimaryClobber. 
- SmallBitVector Visited(Paths.size()); - for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode)) { - Visited[defPathIndex(N)] = true; - cacheDefPath(N, R.PrimaryClobber.Clobber); - } - - for (const TerminatedPath &P : R.OtherClobbers) { - for (const DefPath &N : const_def_path(P.LastNode)) { - ListIndex NIndex = defPathIndex(N); - if (Visited[NIndex]) - break; - Visited[NIndex] = true; - cacheDefPath(N, P.Clobber); - } - } - } - - void verifyOptResult(const OptznResult &R) const { - assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) { - return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber); - })); - } - - void resetPhiOptznState() { - Paths.clear(); - VisitedPhis.clear(); - } - -public: - ClobberWalker(const MemorySSA &MSSA, AliasAnalysis &AA, DominatorTree &DT, - WalkerCache &WC) - : MSSA(MSSA), AA(AA), DT(DT), WC(WC), UseCache(true) {} - - void reset() { WalkTargetCache.clear(); } - - /// Finds the nearest clobber for the given query, optimizing phis if - /// possible. - MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q, - bool UseWalkerCache = true) { - setUseCache(UseWalkerCache); - Query = &Q; - - MemoryAccess *Current = Start; - // This walker pretends uses don't exist. If we're handed one, silently grab - // its def. (This has the nice side-effect of ensuring we never cache uses) - if (auto *MU = dyn_cast<MemoryUse>(Start)) - Current = MU->getDefiningAccess(); - - DefPath FirstDesc(Q.StartingLoc, Current, Current, None); - // Fast path for the overly-common case (no crazy phi optimization - // necessary) - UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc); - MemoryAccess *Result; - if (WalkResult.IsKnownClobber) { - cacheDefPath(FirstDesc, WalkResult.Result); - Result = WalkResult.Result; - } else { - OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last), - Current, Q.StartingLoc); - verifyOptResult(OptRes); - cacheOptResult(OptRes); - resetPhiOptznState(); - Result = OptRes.PrimaryClobber.Clobber; - } - -#ifdef EXPENSIVE_CHECKS - checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); -#endif - return Result; - } - - void verify(const MemorySSA *MSSA) { assert(MSSA == &this->MSSA); } -}; - -struct RenamePassData { - DomTreeNode *DTN; - DomTreeNode::const_iterator ChildIt; - MemoryAccess *IncomingVal; - - RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, - MemoryAccess *M) - : DTN(D), ChildIt(It), IncomingVal(M) {} - void swap(RenamePassData &RHS) { - std::swap(DTN, RHS.DTN); - std::swap(ChildIt, RHS.ChildIt); - std::swap(IncomingVal, RHS.IncomingVal); - } -}; -} // anonymous namespace - -namespace llvm { -/// \brief A MemorySSAWalker that does AA walks and caching of lookups to -/// disambiguate accesses. -/// -/// FIXME: The current implementation of this can take quadratic space in rare -/// cases. This can be fixed, but it is something to note until it is fixed. -/// -/// In order to trigger this behavior, you need to store to N distinct locations -/// (that AA can prove don't alias), perform M stores to other memory -/// locations that AA can prove don't alias any of the initial N locations, and -/// then load from all of the N locations. In this case, we insert M cache -/// entries for each of the N loads. 
-/// -/// For example: -/// define i32 @foo() { -/// %a = alloca i32, align 4 -/// %b = alloca i32, align 4 -/// store i32 0, i32* %a, align 4 -/// store i32 0, i32* %b, align 4 -/// -/// ; Insert M stores to other memory that doesn't alias %a or %b here -/// -/// %c = load i32, i32* %a, align 4 ; Caches M entries in -/// ; CachedUpwardsClobberingAccess for the -/// ; MemoryLocation %a -/// %d = load i32, i32* %b, align 4 ; Caches M entries in -/// ; CachedUpwardsClobberingAccess for the -/// ; MemoryLocation %b -/// -/// ; For completeness' sake, loading %a or %b again would not cache *another* -/// ; M entries. -/// %r = add i32 %c, %d -/// ret i32 %r -/// } -class MemorySSA::CachingWalker final : public MemorySSAWalker { - WalkerCache Cache; - ClobberWalker Walker; - bool AutoResetWalker; - - MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); - void verifyRemoved(MemoryAccess *); - -public: - CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *); - ~CachingWalker() override; - - using MemorySSAWalker::getClobberingMemoryAccess; - MemoryAccess *getClobberingMemoryAccess(MemoryAccess *) override; - MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, - const MemoryLocation &) override; - void invalidateInfo(MemoryAccess *) override; - - /// Whether we call resetClobberWalker() after each time we *actually* walk to - /// answer a clobber query. - void setAutoResetWalker(bool AutoReset) { AutoResetWalker = AutoReset; } - - /// Drop the walker's persistent data structures. At the moment, this means - /// "drop the walker's cache of BasicBlocks -> - /// earliest-MemoryAccess-we-can-optimize-to". This is necessary if we're - /// going to have DT updates, if we remove MemoryAccesses, etc. - void resetClobberWalker() { Walker.reset(); } - - void verify(const MemorySSA *MSSA) override { - MemorySSAWalker::verify(MSSA); - Walker.verify(MSSA); - } -}; - -/// \brief Rename a single basic block into MemorySSA form. -/// Uses the standard SSA renaming algorithm. -/// \returns The new incoming value. -MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, - MemoryAccess *IncomingVal) { - auto It = PerBlockAccesses.find(BB); - // Skip most processing if the list is empty. - if (It != PerBlockAccesses.end()) { - AccessList *Accesses = It->second.get(); - for (MemoryAccess &L : *Accesses) { - if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) { - if (MUD->getDefiningAccess() == nullptr) - MUD->setDefiningAccess(IncomingVal); - if (isa<MemoryDef>(&L)) - IncomingVal = &L; - } else { - IncomingVal = &L; - } - } - } - - // Pass through values to our successors - for (const BasicBlock *S : successors(BB)) { - auto It = PerBlockAccesses.find(S); - // Rename the phi nodes in our successor block - if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) - continue; - AccessList *Accesses = It->second.get(); - auto *Phi = cast<MemoryPhi>(&Accesses->front()); - Phi->addIncoming(IncomingVal, BB); - } - - return IncomingVal; -} - -/// \brief This is the standard SSA renaming algorithm. -/// -/// We walk the dominator tree in preorder, renaming accesses, and then filling -/// in phi nodes in our successors. 
-void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal, - SmallPtrSet<BasicBlock *, 16> &Visited) { - SmallVector<RenamePassData, 32> WorkStack; - IncomingVal = renameBlock(Root->getBlock(), IncomingVal); - WorkStack.push_back({Root, Root->begin(), IncomingVal}); - Visited.insert(Root->getBlock()); - - while (!WorkStack.empty()) { - DomTreeNode *Node = WorkStack.back().DTN; - DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt; - IncomingVal = WorkStack.back().IncomingVal; - - if (ChildIt == Node->end()) { - WorkStack.pop_back(); - } else { - DomTreeNode *Child = *ChildIt; - ++WorkStack.back().ChildIt; - BasicBlock *BB = Child->getBlock(); - Visited.insert(BB); - IncomingVal = renameBlock(BB, IncomingVal); - WorkStack.push_back({Child, Child->begin(), IncomingVal}); - } - } -} - -/// \brief Compute dominator levels, used by the phi insertion algorithm above. -void MemorySSA::computeDomLevels(DenseMap<DomTreeNode *, unsigned> &DomLevels) { - for (auto DFI = df_begin(DT->getRootNode()), DFE = df_end(DT->getRootNode()); - DFI != DFE; ++DFI) - DomLevels[*DFI] = DFI.getPathLength() - 1; -} - -/// \brief This handles unreachable block accesses by deleting phi nodes in -/// unreachable blocks, and marking all other unreachable MemoryAccess's as -/// being uses of the live on entry definition. -void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { - assert(!DT->isReachableFromEntry(BB) && - "Reachable block found while handling unreachable blocks"); - - // Make sure phi nodes in our reachable successors end up with a - // LiveOnEntryDef for our incoming edge, even though our block is forward - // unreachable. We could just disconnect these blocks from the CFG fully, - // but we do not right now. - for (const BasicBlock *S : successors(BB)) { - if (!DT->isReachableFromEntry(S)) - continue; - auto It = PerBlockAccesses.find(S); - // Rename the phi nodes in our successor block - if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) - continue; - AccessList *Accesses = It->second.get(); - auto *Phi = cast<MemoryPhi>(&Accesses->front()); - Phi->addIncoming(LiveOnEntryDef.get(), BB); - } - - auto It = PerBlockAccesses.find(BB); - if (It == PerBlockAccesses.end()) - return; - - auto &Accesses = It->second; - for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) { - auto Next = std::next(AI); - // If we have a phi, just remove it. We are going to replace all - // users with live on entry. - if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI)) - UseOrDef->setDefiningAccess(LiveOnEntryDef.get()); - else - Accesses->erase(AI); - AI = Next; - } -} - -MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT) - : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr), - NextID(INVALID_MEMORYACCESS_ID) { - buildMemorySSA(); -} - -MemorySSA::~MemorySSA() { - // Drop all our references - for (const auto &Pair : PerBlockAccesses) - for (MemoryAccess &MA : *Pair.second) - MA.dropAllReferences(); -} - -MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) { - auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr)); - - if (Res.second) - Res.first->second = make_unique<AccessList>(); - return Res.first->second.get(); -} - -/// This class is a batch walker of all MemoryUse's in the program, and points -/// their defining access at the thing that actually clobbers them. Because it -/// is a batch walker that touches everything, it does not operate like the -/// other walkers. 
This walker is basically performing a top-down SSA renaming -/// pass, where the version stack is used as the cache. This enables it to be -/// significantly more time and memory efficient than using the regular walker, -/// which is walking bottom-up. -class MemorySSA::OptimizeUses { -public: - OptimizeUses(MemorySSA *MSSA, MemorySSAWalker *Walker, AliasAnalysis *AA, - DominatorTree *DT) - : MSSA(MSSA), Walker(Walker), AA(AA), DT(DT) { - Walker = MSSA->getWalker(); - } - - void optimizeUses(); - -private: - /// This represents where a given memorylocation is in the stack. - struct MemlocStackInfo { - // This essentially is keeping track of versions of the stack. Whenever - // the stack changes due to pushes or pops, these versions increase. - unsigned long StackEpoch; - unsigned long PopEpoch; - // This is the lower bound of places on the stack to check. It is equal to - // the place the last stack walk ended. - // Note: Correctness depends on this being initialized to 0, which densemap - // does - unsigned long LowerBound; - const BasicBlock *LowerBoundBlock; - // This is where the last walk for this memory location ended. - unsigned long LastKill; - bool LastKillValid; - }; - void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &, - SmallVectorImpl<MemoryAccess *> &, - DenseMap<MemoryLocOrCall, MemlocStackInfo> &); - MemorySSA *MSSA; - MemorySSAWalker *Walker; - AliasAnalysis *AA; - DominatorTree *DT; -}; - -/// Optimize the uses in a given block This is basically the SSA renaming -/// algorithm, with one caveat: We are able to use a single stack for all -/// MemoryUses. This is because the set of *possible* reaching MemoryDefs is -/// the same for every MemoryUse. The *actual* clobbering MemoryDef is just -/// going to be some position in that stack of possible ones. -/// -/// We track the stack positions that each MemoryLocation needs -/// to check, and last ended at. This is because we only want to check the -/// things that changed since last time. The same MemoryLocation should -/// get clobbered by the same store (getModRefInfo does not use invariantness or -/// things like this, and if they start, we can modify MemoryLocOrCall to -/// include relevant data) -void MemorySSA::OptimizeUses::optimizeUsesInBlock( - const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch, - SmallVectorImpl<MemoryAccess *> &VersionStack, - DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) { - - /// If no accesses, nothing to do. - MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB); - if (Accesses == nullptr) - return; - - // Pop everything that doesn't dominate the current block off the stack, - // increment the PopEpoch to account for this. - while (!VersionStack.empty()) { - BasicBlock *BackBlock = VersionStack.back()->getBlock(); - if (DT->dominates(BackBlock, BB)) - break; - while (VersionStack.back()->getBlock() == BackBlock) - VersionStack.pop_back(); - ++PopEpoch; - } - for (MemoryAccess &MA : *Accesses) { - auto *MU = dyn_cast<MemoryUse>(&MA); - if (!MU) { - VersionStack.push_back(&MA); - ++StackEpoch; - continue; - } - - if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) { - MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true); - continue; - } - - MemoryLocOrCall UseMLOC(MU); - auto &LocInfo = LocStackInfo[UseMLOC]; - // If the pop epoch changed, it means we've removed stuff from top of - // stack due to changing blocks. We may have to reset the lower bound or - // last kill info. 
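    // A compact restatement of the epoch scheme (illustrative only, not extra
    // state): StackEpoch advances on every push of a Def or Phi, and PopEpoch
    // advances whenever entering a new block forces pops. For a given
    // location:
    //   LocInfo.PopEpoch   != PopEpoch   -> the stack shrank since the last
    //                                       query; LowerBound and LastKill
    //                                       may be stale and need resetting.
    //   LocInfo.StackEpoch != StackEpoch -> only pushes happened; just check
    //                                       the new entries above LowerBound.
    // If both match, nothing relevant changed and typically no rescan is
    // needed.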
- if (LocInfo.PopEpoch != PopEpoch) { - LocInfo.PopEpoch = PopEpoch; - LocInfo.StackEpoch = StackEpoch; - // If the lower bound was in something that no longer dominates us, we - // have to reset it. - // We can't simply track stack size, because the stack may have had - // pushes/pops in the meantime. - // XXX: This is non-optimal, but only is slower cases with heavily - // branching dominator trees. To get the optimal number of queries would - // be to make lowerbound and lastkill a per-loc stack, and pop it until - // the top of that stack dominates us. This does not seem worth it ATM. - // A much cheaper optimization would be to always explore the deepest - // branch of the dominator tree first. This will guarantee this resets on - // the smallest set of blocks. - if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB && - !DT->dominates(LocInfo.LowerBoundBlock, BB)) { - // Reset the lower bound of things to check. - // TODO: Some day we should be able to reset to last kill, rather than - // 0. - LocInfo.LowerBound = 0; - LocInfo.LowerBoundBlock = VersionStack[0]->getBlock(); - LocInfo.LastKillValid = false; - } - } else if (LocInfo.StackEpoch != StackEpoch) { - // If all that has changed is the StackEpoch, we only have to check the - // new things on the stack, because we've checked everything before. In - // this case, the lower bound of things to check remains the same. - LocInfo.PopEpoch = PopEpoch; - LocInfo.StackEpoch = StackEpoch; - } - if (!LocInfo.LastKillValid) { - LocInfo.LastKill = VersionStack.size() - 1; - LocInfo.LastKillValid = true; - } - - // At this point, we should have corrected last kill and LowerBound to be - // in bounds. - assert(LocInfo.LowerBound < VersionStack.size() && - "Lower bound out of range"); - assert(LocInfo.LastKill < VersionStack.size() && - "Last kill info out of range"); - // In any case, the new upper bound is the top of the stack. - unsigned long UpperBound = VersionStack.size() - 1; - - if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) { - DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " (" - << *(MU->getMemoryInst()) << ")" - << " because there are " << UpperBound - LocInfo.LowerBound - << " stores to disambiguate\n"); - // Because we did not walk, LastKill is no longer valid, as this may - // have been a kill. - LocInfo.LastKillValid = false; - continue; - } - bool FoundClobberResult = false; - while (UpperBound > LocInfo.LowerBound) { - if (isa<MemoryPhi>(VersionStack[UpperBound])) { - // For phis, use the walker, see where we ended up, go there - Instruction *UseInst = MU->getMemoryInst(); - MemoryAccess *Result = Walker->getClobberingMemoryAccess(UseInst); - // We are guaranteed to find it or something is wrong - while (VersionStack[UpperBound] != Result) { - assert(UpperBound != 0); - --UpperBound; - } - FoundClobberResult = true; - break; - } - - MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]); - // If the lifetime of the pointer ends at this instruction, it's live on - // entry. - if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) { - // Reset UpperBound to liveOnEntryDef's place in the stack - UpperBound = 0; - FoundClobberResult = true; - break; - } - if (instructionClobbersQuery(MD, MU, UseMLOC, *AA)) { - FoundClobberResult = true; - break; - } - --UpperBound; - } - // At the end of this loop, UpperBound is either a clobber, or lower bound - // PHI walking may cause it to be < LowerBound, and in fact, < LastKill. 
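    // Micro-example (illustrative numbers): with VersionStack =
    // [LiveOnEntry, Def1, Def2], LowerBound = 0 and UpperBound = 2, if Def2
    // does not clobber the use's location but Def1 does, the loop above stops
    // at index 1 with FoundClobberResult = true, and the use is rewritten to
    // Def1 by the code below.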
- if (FoundClobberResult || UpperBound < LocInfo.LastKill) { - MU->setDefiningAccess(VersionStack[UpperBound], true); - // We were last killed now by where we got to - LocInfo.LastKill = UpperBound; - } else { - // Otherwise, we checked all the new ones, and now we know we can get to - // LastKill. - MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true); - } - LocInfo.LowerBound = VersionStack.size() - 1; - LocInfo.LowerBoundBlock = BB; - } -} - -/// Optimize uses to point to their actual clobbering definitions. -void MemorySSA::OptimizeUses::optimizeUses() { - - // We perform a non-recursive top-down dominator tree walk - struct StackInfo { - const DomTreeNode *Node; - DomTreeNode::const_iterator Iter; - }; - - SmallVector<MemoryAccess *, 16> VersionStack; - SmallVector<StackInfo, 16> DomTreeWorklist; - DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo; - VersionStack.push_back(MSSA->getLiveOnEntryDef()); - - unsigned long StackEpoch = 1; - unsigned long PopEpoch = 1; - for (const auto *DomNode : depth_first(DT->getRootNode())) - optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack, - LocStackInfo); -} - -void MemorySSA::placePHINodes( - const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks, - const DenseMap<const BasicBlock *, unsigned int> &BBNumbers) { - // Determine where our MemoryPhi's should go - ForwardIDFCalculator IDFs(*DT); - IDFs.setDefiningBlocks(DefiningBlocks); - SmallVector<BasicBlock *, 32> IDFBlocks; - IDFs.calculate(IDFBlocks); - - std::sort(IDFBlocks.begin(), IDFBlocks.end(), - [&BBNumbers](const BasicBlock *A, const BasicBlock *B) { - return BBNumbers.lookup(A) < BBNumbers.lookup(B); - }); - - // Now place MemoryPhi nodes. - for (auto &BB : IDFBlocks) { - // Insert phi node - AccessList *Accesses = getOrCreateAccessList(BB); - MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); - ValueToMemoryAccess[BB] = Phi; - // Phi's always are placed at the front of the block. - Accesses->push_front(Phi); - } -} - -void MemorySSA::buildMemorySSA() { - // We create an access to represent "live on entry", for things like - // arguments or users of globals, where the memory they use is defined before - // the beginning of the function. We do not actually insert it into the IR. - // We do not define a live on exit for the immediate uses, and thus our - // semantics do *not* imply that something with no immediate uses can simply - // be removed. - BasicBlock &StartingPoint = F.getEntryBlock(); - LiveOnEntryDef = make_unique<MemoryDef>(F.getContext(), nullptr, nullptr, - &StartingPoint, NextID++); - DenseMap<const BasicBlock *, unsigned int> BBNumbers; - unsigned NextBBNum = 0; - - // We maintain lists of memory accesses per-block, trading memory for time. We - // could just look up the memory access for every possible instruction in the - // stream. - SmallPtrSet<BasicBlock *, 32> DefiningBlocks; - SmallPtrSet<BasicBlock *, 32> DefUseBlocks; - // Go through each block, figure out where defs occur, and chain together all - // the accesses. 
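  // (Illustration of placePHINodes above, block names made up: if
  // DefiningBlocks = {if.then, if.else} and the two branches rejoin at
  // if.end, the forward iterated dominance frontier is {if.end}, so exactly
  // one MemoryPhi is inserted there, at the front of that block's access
  // list.)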
- for (BasicBlock &B : F) { - BBNumbers[&B] = NextBBNum++; - bool InsertIntoDef = false; - AccessList *Accesses = nullptr; - for (Instruction &I : B) { - MemoryUseOrDef *MUD = createNewAccess(&I); - if (!MUD) - continue; - InsertIntoDef |= isa<MemoryDef>(MUD); - - if (!Accesses) - Accesses = getOrCreateAccessList(&B); - Accesses->push_back(MUD); - } - if (InsertIntoDef) - DefiningBlocks.insert(&B); - if (Accesses) - DefUseBlocks.insert(&B); - } - placePHINodes(DefiningBlocks, BBNumbers); - - // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get - // filled in with all blocks. - SmallPtrSet<BasicBlock *, 16> Visited; - renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited); - - CachingWalker *Walker = getWalkerImpl(); - - // We're doing a batch of updates; don't drop useful caches between them. - Walker->setAutoResetWalker(false); - OptimizeUses(this, Walker, AA, DT).optimizeUses(); - Walker->setAutoResetWalker(true); - Walker->resetClobberWalker(); - - // Mark the uses in unreachable blocks as live on entry, so that they go - // somewhere. - for (auto &BB : F) - if (!Visited.count(&BB)) - markUnreachableAsLiveOnEntry(&BB); -} - -MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); } - -MemorySSA::CachingWalker *MemorySSA::getWalkerImpl() { - if (Walker) - return Walker.get(); - - Walker = make_unique<CachingWalker>(this, AA, DT); - return Walker.get(); -} - -MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) { - assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB"); - AccessList *Accesses = getOrCreateAccessList(BB); - MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); - ValueToMemoryAccess[BB] = Phi; - // Phi's always are placed at the front of the block. - Accesses->push_front(Phi); - BlockNumberingValid.erase(BB); - return Phi; -} - -MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I, - MemoryAccess *Definition) { - assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI"); - MemoryUseOrDef *NewAccess = createNewAccess(I); - assert( - NewAccess != nullptr && - "Tried to create a memory access for a non-memory touching instruction"); - NewAccess->setDefiningAccess(Definition); - return NewAccess; -} - -MemoryAccess *MemorySSA::createMemoryAccessInBB(Instruction *I, - MemoryAccess *Definition, - const BasicBlock *BB, - InsertionPlace Point) { - MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); - auto *Accesses = getOrCreateAccessList(BB); - if (Point == Beginning) { - // It goes after any phi nodes - auto AI = find_if( - *Accesses, [](const MemoryAccess &MA) { return !isa<MemoryPhi>(MA); }); - - Accesses->insert(AI, NewAccess); - } else { - Accesses->push_back(NewAccess); - } - BlockNumberingValid.erase(BB); - return NewAccess; -} - -MemoryUseOrDef *MemorySSA::createMemoryAccessBefore(Instruction *I, - MemoryAccess *Definition, - MemoryUseOrDef *InsertPt) { - assert(I->getParent() == InsertPt->getBlock() && - "New and old access must be in the same block"); - MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); - auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); - Accesses->insert(AccessList::iterator(InsertPt), NewAccess); - BlockNumberingValid.erase(InsertPt->getBlock()); - return NewAccess; -} - -MemoryUseOrDef *MemorySSA::createMemoryAccessAfter(Instruction *I, - MemoryAccess *Definition, - MemoryAccess *InsertPt) { - assert(I->getParent() == InsertPt->getBlock() && - "New and old access must be in the same block"); - MemoryUseOrDef *NewAccess = 
createDefinedAccess(I, Definition); - auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); - Accesses->insertAfter(AccessList::iterator(InsertPt), NewAccess); - BlockNumberingValid.erase(InsertPt->getBlock()); - return NewAccess; -} - -void MemorySSA::spliceMemoryAccessAbove(MemoryDef *Where, - MemoryUseOrDef *What) { - assert(What != getLiveOnEntryDef() && - Where != getLiveOnEntryDef() && "Can't splice (above) LOE."); - assert(dominates(Where, What) && "Only upwards splices are permitted."); - - if (Where == What) - return; - if (isa<MemoryDef>(What)) { - // TODO: possibly use removeMemoryAccess' more efficient RAUW - What->replaceAllUsesWith(What->getDefiningAccess()); - What->setDefiningAccess(Where->getDefiningAccess()); - Where->setDefiningAccess(What); - } - AccessList *Src = getWritableBlockAccesses(What->getBlock()); - AccessList *Dest = getWritableBlockAccesses(Where->getBlock()); - Dest->splice(AccessList::iterator(Where), *Src, What); - - BlockNumberingValid.erase(What->getBlock()); - if (What->getBlock() != Where->getBlock()) - BlockNumberingValid.erase(Where->getBlock()); -} - -/// \brief Helper function to create new memory accesses -MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { - // The assume intrinsic has a control dependency which we model by claiming - // that it writes arbitrarily. Ignore that fake memory dependency here. - // FIXME: Replace this special casing with a more accurate modelling of - // assume's control dependency. - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) - if (II->getIntrinsicID() == Intrinsic::assume) - return nullptr; - - // Find out what affect this instruction has on memory. - ModRefInfo ModRef = AA->getModRefInfo(I); - bool Def = bool(ModRef & MRI_Mod); - bool Use = bool(ModRef & MRI_Ref); - - // It's possible for an instruction to not modify memory at all. During - // construction, we ignore them. - if (!Def && !Use) - return nullptr; - - assert((Def || Use) && - "Trying to create a memory access with a non-memory instruction"); - - MemoryUseOrDef *MUD; - if (Def) - MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++); - else - MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent()); - ValueToMemoryAccess[I] = MUD; - return MUD; -} - -MemoryAccess *MemorySSA::findDominatingDef(BasicBlock *UseBlock, - enum InsertionPlace Where) { - // Handle the initial case - if (Where == Beginning) - // The only thing that could define us at the beginning is a phi node - if (MemoryPhi *Phi = getMemoryAccess(UseBlock)) - return Phi; - - DomTreeNode *CurrNode = DT->getNode(UseBlock); - // Need to be defined by our dominator - if (Where == Beginning) - CurrNode = CurrNode->getIDom(); - Where = End; - while (CurrNode) { - auto It = PerBlockAccesses.find(CurrNode->getBlock()); - if (It != PerBlockAccesses.end()) { - auto &Accesses = It->second; - for (MemoryAccess &RA : reverse(*Accesses)) { - if (isa<MemoryDef>(RA) || isa<MemoryPhi>(RA)) - return &RA; - } - } - CurrNode = CurrNode->getIDom(); - } - return LiveOnEntryDef.get(); -} - -/// \brief Returns true if \p Replacer dominates \p Replacee . -bool MemorySSA::dominatesUse(const MemoryAccess *Replacer, - const MemoryAccess *Replacee) const { - if (isa<MemoryUseOrDef>(Replacee)) - return DT->dominates(Replacer->getBlock(), Replacee->getBlock()); - const auto *MP = cast<MemoryPhi>(Replacee); - // For a phi node, the use occurs in the predecessor block of the phi node. 
- // Since we may occur multiple times in the phi node, we have to check each - // operand to ensure Replacer dominates each operand where Replacee occurs. - for (const Use &Arg : MP->operands()) { - if (Arg.get() != Replacee && - !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg))) - return false; - } - return true; -} - -/// \brief If all arguments of a MemoryPHI are defined by the same incoming -/// argument, return that argument. -static MemoryAccess *onlySingleValue(MemoryPhi *MP) { - MemoryAccess *MA = nullptr; - - for (auto &Arg : MP->operands()) { - if (!MA) - MA = cast<MemoryAccess>(Arg); - else if (MA != Arg) - return nullptr; - } - return MA; -} - -/// \brief Properly remove \p MA from all of MemorySSA's lookup tables. -/// -/// Because of the way the intrusive list and use lists work, it is important to -/// do removal in the right order. -void MemorySSA::removeFromLookups(MemoryAccess *MA) { - assert(MA->use_empty() && - "Trying to remove memory access that still has uses"); - BlockNumbering.erase(MA); - if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) - MUD->setDefiningAccess(nullptr); - // Invalidate our walker's cache if necessary - if (!isa<MemoryUse>(MA)) - Walker->invalidateInfo(MA); - // The call below to erase will destroy MA, so we can't change the order we - // are doing things here - Value *MemoryInst; - if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) { - MemoryInst = MUD->getMemoryInst(); - } else { - MemoryInst = MA->getBlock(); - } - auto VMA = ValueToMemoryAccess.find(MemoryInst); - if (VMA->second == MA) - ValueToMemoryAccess.erase(VMA); - - auto AccessIt = PerBlockAccesses.find(MA->getBlock()); - std::unique_ptr<AccessList> &Accesses = AccessIt->second; - Accesses->erase(MA); - if (Accesses->empty()) - PerBlockAccesses.erase(AccessIt); -} - -void MemorySSA::removeMemoryAccess(MemoryAccess *MA) { - assert(!isLiveOnEntryDef(MA) && "Trying to remove the live on entry def"); - // We can only delete phi nodes if they have no uses, or we can replace all - // uses with a single definition. - MemoryAccess *NewDefTarget = nullptr; - if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) { - // Note that it is sufficient to know that all edges of the phi node have - // the same argument. If they do, by the definition of dominance frontiers - // (which we used to place this phi), that argument must dominate this phi, - // and thus, must dominate the phi's uses, and so we will not hit the assert - // below. - NewDefTarget = onlySingleValue(MP); - assert((NewDefTarget || MP->use_empty()) && - "We can't delete this memory phi"); - } else { - NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess(); - } - - // Re-point the uses at our defining access - if (!MA->use_empty()) { - // Reset optimized on users of this store, and reset the uses. - // A few notes: - // 1. This is a slightly modified version of RAUW to avoid walking the - // uses twice here. - // 2. If we wanted to be complete, we would have to reset the optimized - // flags on users of phi nodes if doing the below makes a phi node have all - // the same arguments. Instead, we prefer users to removeMemoryAccess those - // phi nodes, because doing it here would be N^3. - if (MA->hasValueHandle()) - ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget); - // Note: We assume MemorySSA is not used in metadata since it's not really - // part of the IR. 
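  // Caller-side pattern (illustrative sketch, not from this file): a pass
  // keeping MemorySSA up to date removes the access *before* erasing the
  // instruction it shadows:
  //
  //   if (MemoryUseOrDef *MA = MSSA->getMemoryAccess(&I))
  //     MSSA->removeMemoryAccess(MA);
  //   I.eraseFromParent();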
-
-    while (!MA->use_empty()) {
-      Use &U = *MA->use_begin();
-      if (MemoryUse *MU = dyn_cast<MemoryUse>(U.getUser()))
-        MU->resetOptimized();
-      U.set(NewDefTarget);
-    }
-  }
-
-  // The call below to erase will destroy MA, so we can't change the order we
-  // are doing things here.
-  removeFromLookups(MA);
-}
-
-void MemorySSA::print(raw_ostream &OS) const {
-  MemorySSAAnnotatedWriter Writer(this);
-  F.print(OS, &Writer);
-}
-
-void MemorySSA::dump() const {
-  MemorySSAAnnotatedWriter Writer(this);
-  F.print(dbgs(), &Writer);
-}
-
-void MemorySSA::verifyMemorySSA() const {
-  verifyDefUses(F);
-  verifyDomination(F);
-  verifyOrdering(F);
-  Walker->verify(this);
-}
-
-/// \brief Verify that the order and existence of MemoryAccesses matches the
-/// order and existence of memory affecting instructions.
-void MemorySSA::verifyOrdering(Function &F) const {
-  // Walk all the blocks, comparing what the lookups think and what the access
-  // lists think, as well as the order in the blocks vs the order in the access
-  // lists.
-  SmallVector<MemoryAccess *, 32> ActualAccesses;
-  for (BasicBlock &B : F) {
-    const AccessList *AL = getBlockAccesses(&B);
-    MemoryAccess *Phi = getMemoryAccess(&B);
-    if (Phi)
-      ActualAccesses.push_back(Phi);
-    for (Instruction &I : B) {
-      MemoryAccess *MA = getMemoryAccess(&I);
-      assert((!MA || AL) && "We have memory affecting instructions "
-                            "in this block but they are not in the "
-                            "access list");
-      if (MA)
-        ActualAccesses.push_back(MA);
-    }
-    // Either we hit the assert, really have no accesses, or we have both
-    // accesses and an access list.
-    if (!AL)
-      continue;
-    assert(AL->size() == ActualAccesses.size() &&
-           "We don't have the same number of accesses in the block as on the "
-           "access list");
-    auto ALI = AL->begin();
-    auto AAI = ActualAccesses.begin();
-    while (ALI != AL->end() && AAI != ActualAccesses.end()) {
-      assert(&*ALI == *AAI && "Not the same accesses in the same order");
-      ++ALI;
-      ++AAI;
-    }
-    ActualAccesses.clear();
-  }
-}
-
-/// \brief Verify the domination properties of MemorySSA by checking that each
-/// definition dominates all of its uses.
-void MemorySSA::verifyDomination(Function &F) const {
-#ifndef NDEBUG
-  for (BasicBlock &B : F) {
-    // Phi nodes are attached to basic blocks.
-    if (MemoryPhi *MP = getMemoryAccess(&B))
-      for (const Use &U : MP->uses())
-        assert(dominates(MP, U) && "Memory PHI does not dominate its uses");
-
-    for (Instruction &I : B) {
-      MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I));
-      if (!MD)
-        continue;
-
-      for (const Use &U : MD->uses())
-        assert(dominates(MD, U) && "Memory Def does not dominate its uses");
-    }
-  }
-#endif
-}
-
-/// \brief Verify the def-use lists in MemorySSA, by verifying that \p Use
-/// appears in the use list of \p Def.
- -void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const { -#ifndef NDEBUG - // The live on entry use may cause us to get a NULL def here - if (!Def) - assert(isLiveOnEntryDef(Use) && - "Null def but use not point to live on entry def"); - else - assert(is_contained(Def->users(), Use) && - "Did not find use in def's use list"); -#endif -} - -/// \brief Verify the immediate use information, by walking all the memory -/// accesses and verifying that, for each use, it appears in the -/// appropriate def's use list -void MemorySSA::verifyDefUses(Function &F) const { - for (BasicBlock &B : F) { - // Phi nodes are attached to basic blocks - if (MemoryPhi *Phi = getMemoryAccess(&B)) { - assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance( - pred_begin(&B), pred_end(&B))) && - "Incomplete MemoryPhi Node"); - for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) - verifyUseInDefs(Phi->getIncomingValue(I), Phi); - } - - for (Instruction &I : B) { - if (MemoryUseOrDef *MA = getMemoryAccess(&I)) { - verifyUseInDefs(MA->getDefiningAccess(), MA); - } - } - } -} - -MemoryUseOrDef *MemorySSA::getMemoryAccess(const Instruction *I) const { - return cast_or_null<MemoryUseOrDef>(ValueToMemoryAccess.lookup(I)); -} - -MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { - return cast_or_null<MemoryPhi>(ValueToMemoryAccess.lookup(cast<Value>(BB))); -} - -/// Perform a local numbering on blocks so that instruction ordering can be -/// determined in constant time. -/// TODO: We currently just number in order. If we numbered by N, we could -/// allow at least N-1 sequences of insertBefore or insertAfter (and at least -/// log2(N) sequences of mixed before and after) without needing to invalidate -/// the numbering. -void MemorySSA::renumberBlock(const BasicBlock *B) const { - // The pre-increment ensures the numbers really start at 1. - unsigned long CurrentNumber = 0; - const AccessList *AL = getBlockAccesses(B); - assert(AL != nullptr && "Asking to renumber an empty block"); - for (const auto &I : *AL) - BlockNumbering[&I] = ++CurrentNumber; - BlockNumberingValid.insert(B); -} - -/// \brief Determine, for two memory accesses in the same block, -/// whether \p Dominator dominates \p Dominatee. -/// \returns True if \p Dominator dominates \p Dominatee. -bool MemorySSA::locallyDominates(const MemoryAccess *Dominator, - const MemoryAccess *Dominatee) const { - - const BasicBlock *DominatorBlock = Dominator->getBlock(); - - assert((DominatorBlock == Dominatee->getBlock()) && - "Asking for local domination when accesses are in different blocks!"); - // A node dominates itself. - if (Dominatee == Dominator) - return true; - - // When Dominatee is defined on function entry, it is not dominated by another - // memory access. - if (isLiveOnEntryDef(Dominatee)) - return false; - - // When Dominator is defined on function entry, it dominates the other memory - // access. 
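  // Worked example (illustrative): for a block whose access list is
  // [MemoryPhi, MemoryDef, MemoryUse], renumberBlock (above) assigns the
  // local numbers 1, 2 and 3, so locallyDominates(Def, Use) reduces to the
  // constant-time comparison 2 < 3 once the numbering is valid.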
- if (isLiveOnEntryDef(Dominator)) - return true; - - if (!BlockNumberingValid.count(DominatorBlock)) - renumberBlock(DominatorBlock); - - unsigned long DominatorNum = BlockNumbering.lookup(Dominator); - // All numbers start with 1 - assert(DominatorNum != 0 && "Block was not numbered properly"); - unsigned long DominateeNum = BlockNumbering.lookup(Dominatee); - assert(DominateeNum != 0 && "Block was not numbered properly"); - return DominatorNum < DominateeNum; -} - -bool MemorySSA::dominates(const MemoryAccess *Dominator, - const MemoryAccess *Dominatee) const { - if (Dominator == Dominatee) - return true; - - if (isLiveOnEntryDef(Dominatee)) - return false; - - if (Dominator->getBlock() != Dominatee->getBlock()) - return DT->dominates(Dominator->getBlock(), Dominatee->getBlock()); - return locallyDominates(Dominator, Dominatee); -} - -bool MemorySSA::dominates(const MemoryAccess *Dominator, - const Use &Dominatee) const { - if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) { - BasicBlock *UseBB = MP->getIncomingBlock(Dominatee); - // The def must dominate the incoming block of the phi. - if (UseBB != Dominator->getBlock()) - return DT->dominates(Dominator->getBlock(), UseBB); - // If the UseBB and the DefBB are the same, compare locally. - return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee)); - } - // If it's not a PHI node use, the normal dominates can already handle it. - return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser())); -} - -const static char LiveOnEntryStr[] = "liveOnEntry"; - -void MemoryDef::print(raw_ostream &OS) const { - MemoryAccess *UO = getDefiningAccess(); - - OS << getID() << " = MemoryDef("; - if (UO && UO->getID()) - OS << UO->getID(); - else - OS << LiveOnEntryStr; - OS << ')'; -} - -void MemoryPhi::print(raw_ostream &OS) const { - bool First = true; - OS << getID() << " = MemoryPhi("; - for (const auto &Op : operands()) { - BasicBlock *BB = getIncomingBlock(Op); - MemoryAccess *MA = cast<MemoryAccess>(Op); - if (!First) - OS << ','; - else - First = false; - - OS << '{'; - if (BB->hasName()) - OS << BB->getName(); - else - BB->printAsOperand(OS, false); - OS << ','; - if (unsigned ID = MA->getID()) - OS << ID; - else - OS << LiveOnEntryStr; - OS << '}'; - } - OS << ')'; -} - -MemoryAccess::~MemoryAccess() {} - -void MemoryUse::print(raw_ostream &OS) const { - MemoryAccess *UO = getDefiningAccess(); - OS << "MemoryUse("; - if (UO && UO->getID()) - OS << UO->getID(); - else - OS << LiveOnEntryStr; - OS << ')'; -} - -void MemoryAccess::dump() const { - print(dbgs()); - dbgs() << "\n"; -} - -char MemorySSAPrinterLegacyPass::ID = 0; - -MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) { - initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); -} - -void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesAll(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); -} - -bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { - auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSA.print(dbgs()); - if (VerifyMemorySSA) - MSSA.verifyMemorySSA(); - return false; -} - -AnalysisKey MemorySSAAnalysis::Key; - -MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F, - FunctionAnalysisManager &AM) { - auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &AA = AM.getResult<AAManager>(F); - return MemorySSAAnalysis::Result(make_unique<MemorySSA>(F, &AA, &DT)); -} - -PreservedAnalyses 
MemorySSAPrinterPass::run(Function &F, - FunctionAnalysisManager &AM) { - OS << "MemorySSA for function: " << F.getName() << "\n"; - AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS); - - return PreservedAnalyses::all(); -} - -PreservedAnalyses MemorySSAVerifierPass::run(Function &F, - FunctionAnalysisManager &AM) { - AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA(); - - return PreservedAnalyses::all(); -} - -char MemorySSAWrapperPass::ID = 0; - -MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) { - initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry()); -} - -void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); } - -void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesAll(); - AU.addRequiredTransitive<DominatorTreeWrapperPass>(); - AU.addRequiredTransitive<AAResultsWrapperPass>(); -} - -bool MemorySSAWrapperPass::runOnFunction(Function &F) { - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - MSSA.reset(new MemorySSA(F, &AA, &DT)); - return false; -} - -void MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); } - -void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const { - MSSA->print(OS); -} - -MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} - -MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A, - DominatorTree *D) - : MemorySSAWalker(M), Walker(*M, *A, *D, Cache), AutoResetWalker(true) {} - -MemorySSA::CachingWalker::~CachingWalker() {} - -void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { - // TODO: We can do much better cache invalidation with differently stored - // caches. For now, for MemoryUses, we simply remove them - // from the cache, and kill the entire call/non-call cache for everything - // else. The problem is for phis or defs, currently we'd need to follow use - // chains down and invalidate anything below us in the chain that currently - // terminates at this access. - - // See if this is a MemoryUse, if so, just remove the cached info. MemoryUse - // is by definition never a barrier, so nothing in the cache could point to - // this use. In that case, we only need invalidate the info for the use - // itself. - - if (MemoryUse *MU = dyn_cast<MemoryUse>(MA)) { - UpwardsMemoryQuery Q(MU->getMemoryInst(), MU); - Cache.remove(MU, Q.StartingLoc, Q.IsCall); - MU->resetOptimized(); - } else { - // If it is not a use, the best we can do right now is destroy the cache. - Cache.clear(); - } - -#ifdef EXPENSIVE_CHECKS - verifyRemoved(MA); -#endif -} - -/// \brief Walk the use-def chains starting at \p MA and find -/// the MemoryAccess that actually clobbers Loc. 
-/// -/// \returns our clobbering memory access -MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( - MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) { - MemoryAccess *New = Walker.findClobber(StartingAccess, Q); -#ifdef EXPENSIVE_CHECKS - MemoryAccess *NewNoCache = - Walker.findClobber(StartingAccess, Q, /*UseWalkerCache=*/false); - assert(NewNoCache == New && "Cache made us hand back a different result?"); -#endif - if (AutoResetWalker) - resetClobberWalker(); - return New; -} - -MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( - MemoryAccess *StartingAccess, const MemoryLocation &Loc) { - if (isa<MemoryPhi>(StartingAccess)) - return StartingAccess; - - auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess); - if (MSSA->isLiveOnEntryDef(StartingUseOrDef)) - return StartingUseOrDef; - - Instruction *I = StartingUseOrDef->getMemoryInst(); - - // Conservatively, fences are always clobbers, so don't perform the walk if we - // hit a fence. - if (!ImmutableCallSite(I) && I->isFenceLike()) - return StartingUseOrDef; - - UpwardsMemoryQuery Q; - Q.OriginalAccess = StartingUseOrDef; - Q.StartingLoc = Loc; - Q.Inst = I; - Q.IsCall = false; - - if (auto *CacheResult = Cache.lookup(StartingUseOrDef, Loc, Q.IsCall)) - return CacheResult; - - // Unlike the other function, do not walk to the def of a def, because we are - // handed something we already believe is the clobbering access. - MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef) - ? StartingUseOrDef->getDefiningAccess() - : StartingUseOrDef; - - MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q); - DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); - DEBUG(dbgs() << *StartingUseOrDef << "\n"); - DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); - DEBUG(dbgs() << *Clobber << "\n"); - return Clobber; -} - -MemoryAccess * -MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { - auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA); - // If this is a MemoryPhi, we can't do anything. - if (!StartingAccess) - return MA; - - // If this is an already optimized use or def, return the optimized result. - // Note: Currently, we do not store the optimized def result because we'd need - // a separate field, since we can't use it as the defining access. - if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) - if (MU->isOptimized()) - return MU->getDefiningAccess(); - - const Instruction *I = StartingAccess->getMemoryInst(); - UpwardsMemoryQuery Q(I, StartingAccess); - // We can't sanely do anything with a fences, they conservatively - // clobber all memory, and have no locations to get pointers from to - // try to disambiguate. - if (!Q.IsCall && I->isFenceLike()) - return StartingAccess; - - if (auto *CacheResult = Cache.lookup(StartingAccess, Q.StartingLoc, Q.IsCall)) - return CacheResult; - - if (isUseTriviallyOptimizableToLiveOnEntry(*MSSA->AA, I)) { - MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef(); - Cache.insert(StartingAccess, LiveOnEntry, Q.StartingLoc, Q.IsCall); - if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) - MU->setDefiningAccess(LiveOnEntry, true); - return LiveOnEntry; - } - - // Start with the thing we already think clobbers this location - MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); - - // At this point, DefiningAccess may be the live on entry def. - // If it is, we will not get a better result. 
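  // Caller-side view (illustrative sketch): most passes reach this code
  // through the generic walker interface, e.g.
  //
  //   MemorySSAWalker *W = MSSA.getWalker();
  //   MemoryAccess *Clobber =
  //       W->getClobberingMemoryAccess(MSSA.getMemoryAccess(&SomeLoad));
  //
  // and, per the early return above, an already-optimized MemoryUse answers
  // from its cached defining access without walking at all.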
- if (MSSA->isLiveOnEntryDef(DefiningAccess)) - return DefiningAccess; - - MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q); - DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); - DEBUG(dbgs() << *DefiningAccess << "\n"); - DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); - DEBUG(dbgs() << *Result << "\n"); - if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) - MU->setDefiningAccess(Result, true); - - return Result; -} - -// Verify that MA doesn't exist in any of the caches. -void MemorySSA::CachingWalker::verifyRemoved(MemoryAccess *MA) { - assert(!Cache.contains(MA) && "Found removed MemoryAccess in cache."); -} - -MemoryAccess * -DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) { - if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) - return Use->getDefiningAccess(); - return MA; -} - -MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess( - MemoryAccess *StartingAccess, const MemoryLocation &) { - if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess)) - return Use->getDefiningAccess(); - return StartingAccess; -} -} // namespace llvm diff --git a/contrib/llvm/lib/Transforms/Utils/MetaRenamer.cpp b/contrib/llvm/lib/Transforms/Utils/MetaRenamer.cpp index c999bd0..9f2ad54 100644 --- a/contrib/llvm/lib/Transforms/Utils/MetaRenamer.cpp +++ b/contrib/llvm/lib/Transforms/Utils/MetaRenamer.cpp @@ -13,15 +13,16 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/TypeFinder.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" using namespace llvm; namespace { @@ -67,6 +68,7 @@ namespace { } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.setPreservesAll(); } @@ -110,9 +112,15 @@ namespace { } // Rename all functions + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); for (auto &F : M) { StringRef Name = F.getName(); - if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1)) + LibFunc Tmp; + // Leave library functions alone because their presence or absence could + // affect the behavior of other passes. + if (Name.startswith("llvm.") || (!Name.empty() && Name[0] == 1) || + TLI.getLibFunc(F, Tmp)) continue; F.setName(renamer.newName()); @@ -139,8 +147,11 @@ namespace { } char MetaRenamer::ID = 0; -INITIALIZE_PASS(MetaRenamer, "metarenamer", - "Assign new names to everything", false, false) +INITIALIZE_PASS_BEGIN(MetaRenamer, "metarenamer", + "Assign new names to everything", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(MetaRenamer, "metarenamer", + "Assign new names to everything", false, false) //===----------------------------------------------------------------------===// // // MetaRenamer - Rename everything with metasyntactic names. 
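The MetaRenamer hunk above consults TargetLibraryInfo so that recognized
library routines keep their names: renaming, say, printf or malloc would make
later libcall-aware passes treat them as opaque calls. A minimal sketch of the
same test, with a hypothetical helper name (only TLI and LibFunc come from the
hunk itself):

    // Returns true for functions TargetLibraryInfo recognizes as library
    // routines (e.g. malloc, printf); MetaRenamer now skips these.
    static bool isRecognizedLibFunc(const Function &F,
                                    const TargetLibraryInfo &TLI) {
      LibFunc LF;
      return TLI.getLibFunc(F, LF);
    }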
diff --git a/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 0d623df..2ef3d63 100644 --- a/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -35,7 +35,7 @@ static void appendToGlobalArray(const char *Array, Module &M, Function *F, // Upgrade a 2-field global array type to the new 3-field format if needed. if (Data && OldEltTy->getNumElements() < 3) EltTy = StructType::get(IRB.getInt32Ty(), PointerType::getUnqual(FnTy), - IRB.getInt8PtrTy(), nullptr); + IRB.getInt8PtrTy()); else EltTy = OldEltTy; if (Constant *Init = GVCtor->getInitializer()) { @@ -44,10 +44,10 @@ static void appendToGlobalArray(const char *Array, Module &M, Function *F, for (unsigned i = 0; i != n; ++i) { auto Ctor = cast<Constant>(Init->getOperand(i)); if (EltTy != OldEltTy) - Ctor = ConstantStruct::get( - EltTy, Ctor->getAggregateElement((unsigned)0), - Ctor->getAggregateElement(1), - Constant::getNullValue(IRB.getInt8PtrTy()), nullptr); + Ctor = + ConstantStruct::get(EltTy, Ctor->getAggregateElement((unsigned)0), + Ctor->getAggregateElement(1), + Constant::getNullValue(IRB.getInt8PtrTy())); CurrentCtors.push_back(Ctor); } } @@ -55,7 +55,7 @@ static void appendToGlobalArray(const char *Array, Module &M, Function *F, } else { // Use the new three-field struct if there isn't one already. EltTy = StructType::get(IRB.getInt32Ty(), PointerType::getUnqual(FnTy), - IRB.getInt8PtrTy(), nullptr); + IRB.getInt8PtrTy()); } // Build a 2 or 3 field global_ctor entry. We don't take a comdat key. @@ -130,13 +130,25 @@ void llvm::appendToCompilerUsed(Module &M, ArrayRef<GlobalValue *> Values) { Function *llvm::checkSanitizerInterfaceFunction(Constant *FuncOrBitcast) { if (isa<Function>(FuncOrBitcast)) return cast<Function>(FuncOrBitcast); - FuncOrBitcast->dump(); + FuncOrBitcast->print(errs()); + errs() << '\n'; std::string Err; raw_string_ostream Stream(Err); Stream << "Sanitizer interface function redefined: " << *FuncOrBitcast; report_fatal_error(Err); } +Function *llvm::declareSanitizerInitFunction(Module &M, StringRef InitName, + ArrayRef<Type *> InitArgTypes) { + assert(!InitName.empty() && "Expected init function name"); + Function *F = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + InitName, + FunctionType::get(Type::getVoidTy(M.getContext()), InitArgTypes, false), + AttributeList())); + F->setLinkage(Function::ExternalLinkage); + return F; +} + std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( Module &M, StringRef CtorName, StringRef InitName, ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, @@ -144,22 +156,19 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( assert(!InitName.empty() && "Expected init function name"); assert(InitArgs.size() == InitArgTypes.size() && "Sanitizer's init function expects different number of arguments"); + Function *InitFunction = + declareSanitizerInitFunction(M, InitName, InitArgTypes); Function *Ctor = Function::Create( FunctionType::get(Type::getVoidTy(M.getContext()), false), GlobalValue::InternalLinkage, CtorName, &M); BasicBlock *CtorBB = BasicBlock::Create(M.getContext(), "", Ctor); IRBuilder<> IRB(ReturnInst::Create(M.getContext(), CtorBB)); - Function *InitFunction = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - InitName, FunctionType::get(IRB.getVoidTy(), InitArgTypes, false), - AttributeSet())); - InitFunction->setLinkage(Function::ExternalLinkage); IRB.CreateCall(InitFunction, 
InitArgs);
   if (!VersionCheckName.empty()) {
     Function *VersionCheckFunction =
         checkSanitizerInterfaceFunction(M.getOrInsertFunction(
             VersionCheckName, FunctionType::get(IRB.getVoidTy(), {}, false),
-            AttributeSet()));
+            AttributeList()));
     IRB.CreateCall(VersionCheckFunction, {});
   }
   return std::make_pair(Ctor, InitFunction);
@@ -228,3 +237,35 @@ void llvm::filterDeadComdatFunctions(
            ComdatEntriesCovered.end();
   });
 }
+
+std::string llvm::getUniqueModuleId(Module *M) {
+  MD5 Md5;
+  bool ExportsSymbols = false;
+  auto AddGlobal = [&](GlobalValue &GV) {
+    if (GV.isDeclaration() || GV.getName().startswith("llvm.") ||
+        !GV.hasExternalLinkage())
+      return;
+    ExportsSymbols = true;
+    Md5.update(GV.getName());
+    Md5.update(ArrayRef<uint8_t>{0});
+  };
+
+  for (auto &F : *M)
+    AddGlobal(F);
+  for (auto &GV : M->globals())
+    AddGlobal(GV);
+  for (auto &GA : M->aliases())
+    AddGlobal(GA);
+  for (auto &IF : M->ifuncs())
+    AddGlobal(IF);
+
+  if (!ExportsSymbols)
+    return "";
+
+  MD5::MD5Result R;
+  Md5.final(R);
+
+  SmallString<32> Str;
+  MD5::stringifyResult(R, Str);
+  return ("$" + Str).str();
+}
diff --git a/contrib/llvm/lib/Transforms/Utils/OrderedInstructions.cpp b/contrib/llvm/lib/Transforms/Utils/OrderedInstructions.cpp
new file mode 100644
index 0000000..dc78054
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Utils/OrderedInstructions.cpp
@@ -0,0 +1,32 @@
+//===-- OrderedInstructions.cpp - Instruction dominance function ---------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a utility to check the dominance relation of two instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/OrderedInstructions.h"
+using namespace llvm;
+
+/// Given two instructions, use OrderedBasicBlock to check the dominance
+/// relation if the instructions are in the same basic block; otherwise, use
+/// the dominator tree.
+bool OrderedInstructions::dominates(const Instruction *InstA,
+                                    const Instruction *InstB) const {
+  const BasicBlock *IBB = InstA->getParent();
+  // Use the ordered basic block to do the dominance check in case the two
+  // instructions are in the same basic block.
+  if (IBB == InstB->getParent()) {
+    auto OBB = OBBMap.find(IBB);
+    if (OBB == OBBMap.end())
+      OBB = OBBMap.insert({IBB, make_unique<OrderedBasicBlock>(IBB)}).first;
+    return OBB->second->dominates(InstA, InstB);
+  }
+  return DT->dominates(InstA->getParent(), InstB->getParent());
+}
diff --git a/contrib/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/contrib/llvm/lib/Transforms/Utils/PredicateInfo.cpp
new file mode 100644
index 0000000..d4cdaed
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -0,0 +1,793 @@
+//===-- PredicateInfo.cpp - PredicateInfo Builder--------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------===//
+//
+// This file implements the PredicateInfo class.
+// +//===----------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/PredicateInfo.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/OrderedInstructions.h" +#include <algorithm> +#define DEBUG_TYPE "predicateinfo" +using namespace llvm; +using namespace PatternMatch; +using namespace llvm::PredicateInfoClasses; + +INITIALIZE_PASS_BEGIN(PredicateInfoPrinterLegacyPass, "print-predicateinfo", + "PredicateInfo Printer", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_END(PredicateInfoPrinterLegacyPass, "print-predicateinfo", + "PredicateInfo Printer", false, false) +static cl::opt<bool> VerifyPredicateInfo( + "verify-predicateinfo", cl::init(false), cl::Hidden, + cl::desc("Verify PredicateInfo in legacy printer pass.")); +namespace { +DEBUG_COUNTER(RenameCounter, "predicateinfo-rename", + "Controls which variables are renamed with predicateinfo") +// Given a predicate info that is a type of branching terminator, get the +// branching block. +const BasicBlock *getBranchBlock(const PredicateBase *PB) { + assert(isa<PredicateWithEdge>(PB) && + "Only branches and switches should have PHIOnly defs that " + "require branch blocks."); + return cast<PredicateWithEdge>(PB)->From; +} + +// Given a predicate info that is a type of branching terminator, get the +// branching terminator. +static Instruction *getBranchTerminator(const PredicateBase *PB) { + assert(isa<PredicateWithEdge>(PB) && + "Not a predicate info type we know how to get a terminator from."); + return cast<PredicateWithEdge>(PB)->From->getTerminator(); +} + +// Given a predicate info that is a type of branching terminator, get the +// edge this predicate info represents +const std::pair<BasicBlock *, BasicBlock *> +getBlockEdge(const PredicateBase *PB) { + assert(isa<PredicateWithEdge>(PB) && + "Not a predicate info type we know how to get an edge from."); + const auto *PEdge = cast<PredicateWithEdge>(PB); + return std::make_pair(PEdge->From, PEdge->To); +} +} + +namespace llvm { +namespace PredicateInfoClasses { +enum LocalNum { + // Operations that must appear first in the block. + LN_First, + // Operations that are somewhere in the middle of the block, and are sorted on + // demand. + LN_Middle, + // Operations that must appear last in a block, like successor phi node uses. + LN_Last +}; + +// Associate global and local DFS info with defs and uses, so we can sort them +// into a global domination ordering. +struct ValueDFS { + int DFSIn = 0; + int DFSOut = 0; + unsigned int LocalNum = LN_Middle; + // Only one of Def or Use will be set. 
+  Value *Def = nullptr;
+  Use *U = nullptr;
+  // Neither PInfo nor EdgeOnly participates in the ordering.
+  PredicateBase *PInfo = nullptr;
+  bool EdgeOnly = false;
+};
+
+// Perform a strict weak ordering on instructions and arguments.
+static bool valueComesBefore(OrderedInstructions &OI, const Value *A,
+                             const Value *B) {
+  auto *ArgA = dyn_cast_or_null<Argument>(A);
+  auto *ArgB = dyn_cast_or_null<Argument>(B);
+  if (ArgA && !ArgB)
+    return true;
+  if (ArgB && !ArgA)
+    return false;
+  if (ArgA && ArgB)
+    return ArgA->getArgNo() < ArgB->getArgNo();
+  return OI.dominates(cast<Instruction>(A), cast<Instruction>(B));
+}
+
+// This compares ValueDFS structures, creating OrderedBasicBlocks where
+// necessary to compare uses/defs in the same block. Doing so allows us to walk
+// the minimum number of instructions necessary to compute our def/use ordering.
+struct ValueDFS_Compare {
+  OrderedInstructions &OI;
+  ValueDFS_Compare(OrderedInstructions &OI) : OI(OI) {}
+
+  bool operator()(const ValueDFS &A, const ValueDFS &B) const {
+    if (&A == &B)
+      return false;
+    // The only case where we can't directly compare them is when they are in
+    // the same block and both have LocalNum == LN_Middle. In that case, we
+    // have to use localComesBefore to see what the real ordering is, because
+    // they are in the same basic block.
+
+    bool SameBlock = std::tie(A.DFSIn, A.DFSOut) == std::tie(B.DFSIn, B.DFSOut);
+
+    // We want to put the def that will get used for a given set of phi uses
+    // before those phi uses, so we sort by edge, then by def.
+    // Note that only phi node uses and defs can come last.
+    if (SameBlock && A.LocalNum == LN_Last && B.LocalNum == LN_Last)
+      return comparePHIRelated(A, B);
+
+    if (!SameBlock || A.LocalNum != LN_Middle || B.LocalNum != LN_Middle)
+      return std::tie(A.DFSIn, A.DFSOut, A.LocalNum, A.Def, A.U) <
+             std::tie(B.DFSIn, B.DFSOut, B.LocalNum, B.Def, B.U);
+    return localComesBefore(A, B);
+  }
+
+  // For a phi use, or a non-materialized def, return the edge it represents.
+  const std::pair<BasicBlock *, BasicBlock *>
+  getBlockEdge(const ValueDFS &VD) const {
+    if (!VD.Def && VD.U) {
+      auto *PHI = cast<PHINode>(VD.U->getUser());
+      return std::make_pair(PHI->getIncomingBlock(*VD.U), PHI->getParent());
+    }
+    // This is really a non-materialized def.
+    return ::getBlockEdge(VD.PInfo);
+  }
+
+  // For two phi related values, return the ordering.
+  bool comparePHIRelated(const ValueDFS &A, const ValueDFS &B) const {
+    auto &ABlockEdge = getBlockEdge(A);
+    auto &BBlockEdge = getBlockEdge(B);
+    // Now sort by block edge and then defs before uses.
+    return std::tie(ABlockEdge, A.Def, A.U) < std::tie(BBlockEdge, B.Def, B.U);
+  }
+
+  // Get the definition of an instruction that occurs in the middle of a block.
+  Value *getMiddleDef(const ValueDFS &VD) const {
+    if (VD.Def)
+      return VD.Def;
+    // It's possible for the defs and uses to be null. For branches, the local
+    // numbering will say the placed predicateinfos should go first (i.e.,
+    // LN_First), so we won't be in this function. For assumes, we will end up
+    // here, because we need to order the def we will place relative to the
+    // assume. So for the purpose of ordering, we pretend the def is the assume
+    // because that is where we will insert the info.
+ if (!VD.U) { + assert(VD.PInfo && + "No def, no use, and no predicateinfo should not occur"); + assert(isa<PredicateAssume>(VD.PInfo) && + "Middle of block should only occur for assumes"); + return cast<PredicateAssume>(VD.PInfo)->AssumeInst; + } + return nullptr; + } + + // Return either the Def, if it's not null, or the user of the Use, if the def + // is null. + const Instruction *getDefOrUser(const Value *Def, const Use *U) const { + if (Def) + return cast<Instruction>(Def); + return cast<Instruction>(U->getUser()); + } + + // This performs the necessary local basic block ordering checks to tell + // whether A comes before B, where both are in the same basic block. + bool localComesBefore(const ValueDFS &A, const ValueDFS &B) const { + auto *ADef = getMiddleDef(A); + auto *BDef = getMiddleDef(B); + + // See if we have real values or uses. If we have real values, we are + // guaranteed they are instructions or arguments. No matter what, we are + // guaranteed they are in the same block if they are instructions. + auto *ArgA = dyn_cast_or_null<Argument>(ADef); + auto *ArgB = dyn_cast_or_null<Argument>(BDef); + + if (ArgA || ArgB) + return valueComesBefore(OI, ArgA, ArgB); + + auto *AInst = getDefOrUser(ADef, A.U); + auto *BInst = getDefOrUser(BDef, B.U); + return valueComesBefore(OI, AInst, BInst); + } +}; + +} // namespace PredicateInfoClasses + +bool PredicateInfo::stackIsInScope(const ValueDFSStack &Stack, + const ValueDFS &VDUse) const { + if (Stack.empty()) + return false; + // If it's a phi only use, make sure it's for this phi node edge, and that the + // use is in a phi node. If it's anything else, and the top of the stack is + // EdgeOnly, we need to pop the stack. We deliberately sort phi uses next to + // the defs they must go with so that we can know it's time to pop the stack + // when we hit the end of the phi uses for a given def. + if (Stack.back().EdgeOnly) { + if (!VDUse.U) + return false; + auto *PHI = dyn_cast<PHINode>(VDUse.U->getUser()); + if (!PHI) + return false; + // Check edge + BasicBlock *EdgePred = PHI->getIncomingBlock(*VDUse.U); + if (EdgePred != getBranchBlock(Stack.back().PInfo)) + return false; + + // Use dominates, which knows how to handle edge dominance. + return DT.dominates(getBlockEdge(Stack.back().PInfo), *VDUse.U); + } + + return (VDUse.DFSIn >= Stack.back().DFSIn && + VDUse.DFSOut <= Stack.back().DFSOut); +} + +void PredicateInfo::popStackUntilDFSScope(ValueDFSStack &Stack, + const ValueDFS &VD) { + while (!Stack.empty() && !stackIsInScope(Stack, VD)) + Stack.pop_back(); +} + +// Convert the uses of Op into a vector of uses, associating global and local +// DFS info with each one. +void PredicateInfo::convertUsesToDFSOrdered( + Value *Op, SmallVectorImpl<ValueDFS> &DFSOrderedSet) { + for (auto &U : Op->uses()) { + if (auto *I = dyn_cast<Instruction>(U.getUser())) { + ValueDFS VD; + // Put the phi node uses in the incoming block. + BasicBlock *IBlock; + if (auto *PN = dyn_cast<PHINode>(I)) { + IBlock = PN->getIncomingBlock(U); + // Make phi node users appear last in the incoming block + // they are from. + VD.LocalNum = LN_Last; + } else { + // If it's not a phi node use, it is somewhere in the middle of the + // block. + IBlock = I->getParent(); + VD.LocalNum = LN_Middle; + } + DomTreeNode *DomNode = DT.getNode(IBlock); + // It's possible our use is in an unreachable block. Skip it if so. 
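      // Illustrative numbers for the DFSIn/DFSOut tagging done just below:
      // with dominator-tree intervals entry[1,8], then[2,3], else[4,5] and
      // end[6,7], a def tagged [1,8] scopes over a use tagged [6,7], because
      // the containment test in stackIsInScope is DFSIn >= 1 && DFSOut <= 8.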
+ if (!DomNode) + continue; + VD.DFSIn = DomNode->getDFSNumIn(); + VD.DFSOut = DomNode->getDFSNumOut(); + VD.U = &U; + DFSOrderedSet.push_back(VD); + } + } +} + +// Collect relevant operations from Comparison that we may want to insert copies +// for. +void collectCmpOps(CmpInst *Comparison, SmallVectorImpl<Value *> &CmpOperands) { + auto *Op0 = Comparison->getOperand(0); + auto *Op1 = Comparison->getOperand(1); + if (Op0 == Op1) + return; + CmpOperands.push_back(Comparison); + // Only want real values, not constants. Additionally, operands with one use + // are only being used in the comparison, which means they will not be useful + // for us to consider for predicateinfo. + // + if ((isa<Instruction>(Op0) || isa<Argument>(Op0)) && !Op0->hasOneUse()) + CmpOperands.push_back(Op0); + if ((isa<Instruction>(Op1) || isa<Argument>(Op1)) && !Op1->hasOneUse()) + CmpOperands.push_back(Op1); +} + +// Add Op, PB to the list of value infos for Op, and mark Op to be renamed. +void PredicateInfo::addInfoFor(SmallPtrSetImpl<Value *> &OpsToRename, Value *Op, + PredicateBase *PB) { + OpsToRename.insert(Op); + auto &OperandInfo = getOrCreateValueInfo(Op); + AllInfos.push_back(PB); + OperandInfo.Infos.push_back(PB); +} + +// Process an assume instruction and place relevant operations we want to rename +// into OpsToRename. +void PredicateInfo::processAssume(IntrinsicInst *II, BasicBlock *AssumeBB, + SmallPtrSetImpl<Value *> &OpsToRename) { + // See if we have a comparison we support + SmallVector<Value *, 8> CmpOperands; + SmallVector<Value *, 2> ConditionsToProcess; + CmpInst::Predicate Pred; + Value *Operand = II->getOperand(0); + if (m_c_And(m_Cmp(Pred, m_Value(), m_Value()), + m_Cmp(Pred, m_Value(), m_Value())) + .match(II->getOperand(0))) { + ConditionsToProcess.push_back(cast<BinaryOperator>(Operand)->getOperand(0)); + ConditionsToProcess.push_back(cast<BinaryOperator>(Operand)->getOperand(1)); + ConditionsToProcess.push_back(Operand); + } else if (isa<CmpInst>(Operand)) { + + ConditionsToProcess.push_back(Operand); + } + for (auto Cond : ConditionsToProcess) { + if (auto *Cmp = dyn_cast<CmpInst>(Cond)) { + collectCmpOps(Cmp, CmpOperands); + // Now add our copy infos for our operands + for (auto *Op : CmpOperands) { + auto *PA = new PredicateAssume(Op, II, Cmp); + addInfoFor(OpsToRename, Op, PA); + } + CmpOperands.clear(); + } else if (auto *BinOp = dyn_cast<BinaryOperator>(Cond)) { + // Otherwise, it should be an AND. + assert(BinOp->getOpcode() == Instruction::And && + "Should have been an AND"); + auto *PA = new PredicateAssume(BinOp, II, BinOp); + addInfoFor(OpsToRename, BinOp, PA); + } else { + llvm_unreachable("Unknown type of condition"); + } + } +} + +// Process a block terminating branch, and place relevant operations to be +// renamed into OpsToRename. +void PredicateInfo::processBranch(BranchInst *BI, BasicBlock *BranchBB, + SmallPtrSetImpl<Value *> &OpsToRename) { + BasicBlock *FirstBB = BI->getSuccessor(0); + BasicBlock *SecondBB = BI->getSuccessor(1); + SmallVector<BasicBlock *, 2> SuccsToProcess; + SuccsToProcess.push_back(FirstBB); + SuccsToProcess.push_back(SecondBB); + SmallVector<Value *, 2> ConditionsToProcess; + + auto InsertHelper = [&](Value *Op, bool isAnd, bool isOr, Value *Cond) { + for (auto *Succ : SuccsToProcess) { + // Don't try to insert on a self-edge. This is mainly because we will + // eliminate during renaming anyway. 
+ if (Succ == BranchBB) + continue; + bool TakenEdge = (Succ == FirstBB); + // For and, only insert on the true edge + // For or, only insert on the false edge + if ((isAnd && !TakenEdge) || (isOr && TakenEdge)) + continue; + PredicateBase *PB = + new PredicateBranch(Op, BranchBB, Succ, Cond, TakenEdge); + addInfoFor(OpsToRename, Op, PB); + if (!Succ->getSinglePredecessor()) + EdgeUsesOnly.insert({BranchBB, Succ}); + } + }; + + // Match combinations of conditions. + CmpInst::Predicate Pred; + bool isAnd = false; + bool isOr = false; + SmallVector<Value *, 8> CmpOperands; + if (match(BI->getCondition(), m_And(m_Cmp(Pred, m_Value(), m_Value()), + m_Cmp(Pred, m_Value(), m_Value()))) || + match(BI->getCondition(), m_Or(m_Cmp(Pred, m_Value(), m_Value()), + m_Cmp(Pred, m_Value(), m_Value())))) { + auto *BinOp = cast<BinaryOperator>(BI->getCondition()); + if (BinOp->getOpcode() == Instruction::And) + isAnd = true; + else if (BinOp->getOpcode() == Instruction::Or) + isOr = true; + ConditionsToProcess.push_back(BinOp->getOperand(0)); + ConditionsToProcess.push_back(BinOp->getOperand(1)); + ConditionsToProcess.push_back(BI->getCondition()); + } else if (isa<CmpInst>(BI->getCondition())) { + ConditionsToProcess.push_back(BI->getCondition()); + } + for (auto Cond : ConditionsToProcess) { + if (auto *Cmp = dyn_cast<CmpInst>(Cond)) { + collectCmpOps(Cmp, CmpOperands); + // Now add our copy infos for our operands + for (auto *Op : CmpOperands) + InsertHelper(Op, isAnd, isOr, Cmp); + } else if (auto *BinOp = dyn_cast<BinaryOperator>(Cond)) { + // This must be an AND or an OR. + assert((BinOp->getOpcode() == Instruction::And || + BinOp->getOpcode() == Instruction::Or) && + "Should have been an AND or an OR"); + // The actual value of the binop is not subject to the same restrictions + // as the comparison. It's either true or false on the true/false branch. + InsertHelper(BinOp, false, false, BinOp); + } else { + llvm_unreachable("Unknown type of condition"); + } + CmpOperands.clear(); + } +} +// Process a block terminating switch, and place relevant operations to be +// renamed into OpsToRename. +void PredicateInfo::processSwitch(SwitchInst *SI, BasicBlock *BranchBB, + SmallPtrSetImpl<Value *> &OpsToRename) { + Value *Op = SI->getCondition(); + if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse()) + return; + + // Remember how many outgoing edges there are to every successor. + SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; + for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { + BasicBlock *TargetBlock = SI->getSuccessor(i); + ++SwitchEdges[TargetBlock]; + } + + // Now propagate info for each case value + for (auto C : SI->cases()) { + BasicBlock *TargetBlock = C.getCaseSuccessor(); + if (SwitchEdges.lookup(TargetBlock) == 1) { + PredicateSwitch *PS = new PredicateSwitch( + Op, SI->getParent(), TargetBlock, C.getCaseValue(), SI); + addInfoFor(OpsToRename, Op, PS); + if (!TargetBlock->getSinglePredecessor()) + EdgeUsesOnly.insert({BranchBB, TargetBlock}); + } + } +} + +// Build predicate info for our function +void PredicateInfo::buildPredicateInfo() { + DT.updateDFSNumbers(); + // Collect operands to rename from all conditional branch terminators, as well + // as assume statements. 
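+ // For example (illustrative IR; the ".0" suffix comes from the
+ // materialization counter):
+ //
+ //   %cmp = icmp ult i32 %x, %n
+ //   %x.0 = call i32 @llvm.ssa.copy.i32(i32 %x)
+ //   br i1 %cmp, label %then, label %else
+ //
+ // A candidate copy of %x is recorded for both successors here, but the
+ // llvm.ssa.copy call shown is only materialized (before the terminator)
+ // once renaming finds a use of %x that the chosen edge dominates.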
+ SmallPtrSet<Value *, 8> OpsToRename;
+ for (auto DTN : depth_first(DT.getRootNode())) {
+ BasicBlock *BranchBB = DTN->getBlock();
+ if (auto *BI = dyn_cast<BranchInst>(BranchBB->getTerminator())) {
+ if (!BI->isConditional())
+ continue;
+ // Can't insert conditional information if they all go to the same place.
+ if (BI->getSuccessor(0) == BI->getSuccessor(1))
+ continue;
+ processBranch(BI, BranchBB, OpsToRename);
+ } else if (auto *SI = dyn_cast<SwitchInst>(BranchBB->getTerminator())) {
+ processSwitch(SI, BranchBB, OpsToRename);
+ }
+ }
+ for (auto &Assume : AC.assumptions()) {
+ if (auto *II = dyn_cast_or_null<IntrinsicInst>(Assume))
+ processAssume(II, II->getParent(), OpsToRename);
+ }
+ // Now rename all our operations.
+ renameUses(OpsToRename);
+}
+
+// Given the renaming stack, make all the operands currently on the stack real
+// by inserting them into the IR. Return the last operation's value.
+Value *PredicateInfo::materializeStack(unsigned int &Counter,
+ ValueDFSStack &RenameStack,
+ Value *OrigOp) {
+ // Find the first thing we have to materialize.
+ auto RevIter = RenameStack.rbegin();
+ for (; RevIter != RenameStack.rend(); ++RevIter)
+ if (RevIter->Def)
+ break;
+
+ size_t Start = RevIter - RenameStack.rbegin();
+ // The maximum number of things we should be trying to materialize at once
+ // right now is 4, depending on whether we had an assume, a branch, and
+ // whether both used an AND of conditions.
+ for (auto RenameIter = RenameStack.end() - Start;
+ RenameIter != RenameStack.end(); ++RenameIter) {
+ auto *Op =
+ RenameIter == RenameStack.begin() ? OrigOp : (RenameIter - 1)->Def;
+ ValueDFS &Result = *RenameIter;
+ auto *ValInfo = Result.PInfo;
+ // For edge predicates, we can just place the operand in the block before
+ // the terminator. For assume, we have to place it right before the assume
+ // to ensure we dominate all of our uses. Always insert right before the
+ // relevant instruction (terminator, assume), so that we insert in proper
+ // order in the case of multiple predicateinfo in the same block.
+ if (isa<PredicateWithEdge>(ValInfo)) {
+ IRBuilder<> B(getBranchTerminator(ValInfo));
+ Function *IF = Intrinsic::getDeclaration(
+ F.getParent(), Intrinsic::ssa_copy, Op->getType());
+ CallInst *PIC =
+ B.CreateCall(IF, Op, Op->getName() + "." + Twine(Counter++));
+ PredicateMap.insert({PIC, ValInfo});
+ Result.Def = PIC;
+ } else {
+ auto *PAssume = dyn_cast<PredicateAssume>(ValInfo);
+ assert(PAssume &&
+ "Should not have gotten here without it being an assume");
+ IRBuilder<> B(PAssume->AssumeInst);
+ Function *IF = Intrinsic::getDeclaration(
+ F.getParent(), Intrinsic::ssa_copy, Op->getType());
+ CallInst *PIC = B.CreateCall(IF, Op);
+ PredicateMap.insert({PIC, ValInfo});
+ Result.Def = PIC;
+ }
+ }
+ return RenameStack.back().Def;
+}
+
+// Instead of the standard SSA renaming algorithm, which is O(number of
+// instructions) and walks the entire dominator tree, we walk only the defs +
+// uses. The standard SSA renaming algorithm does not really rely on the
+// dominator tree except to order the stack push/pops of the renaming stacks, so
+// that defs end up getting pushed before hitting the correct uses. This does
+// not require the dominator tree, only the *order* of the dominator tree. The
+// complete and correct ordering of the defs and uses in the dominator tree is
+// contained in the DFS numbering of the dominator tree.
So we sort the defs and
+// uses into the DFS ordering, and then just use the renaming stack as per
+// normal, pushing when we hit a def (which is a predicateinfo instruction),
+// popping when we are out of the DFS scope for that def, and replacing any
+// uses with the top of stack if it exists. In order to handle liveness
+// without propagating liveness info, we don't actually insert the
+// predicateinfo instruction def until we see a use that it would dominate.
+// Once we see such a use, we materialize the predicateinfo instruction in the
+// right place and use it.
+//
+// TODO: Use this algorithm to perform fast single-variable renaming in
+// promotememtoreg and memoryssa.
+void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) {
+ // Sort OpsToRename since we are going to iterate it.
+ SmallVector<Value *, 8> OpsToRename(OpSet.begin(), OpSet.end());
+ auto Comparator = [&](const Value *A, const Value *B) {
+ return valueComesBefore(OI, A, B);
+ };
+ std::sort(OpsToRename.begin(), OpsToRename.end(), Comparator);
+ ValueDFS_Compare Compare(OI);
+ // Compute liveness, and rename in O(uses) per Op.
+ for (auto *Op : OpsToRename) {
+ unsigned Counter = 0;
+ SmallVector<ValueDFS, 16> OrderedUses;
+ const auto &ValueInfo = getValueInfo(Op);
+ // Insert the possible copies into the def/use list. They will become real
+ // copies if we find a real use for them, and are never created otherwise.
+ for (auto &PossibleCopy : ValueInfo.Infos) {
+ ValueDFS VD;
+ // Determine where we are going to place the copy by the copy type.
+ // The predicate info for branches always comes first; it will get
+ // materialized in the split block at the top of the block.
+ // The predicate info for assumes will be somewhere in the middle;
+ // it will get materialized in front of the assume.
+ if (const auto *PAssume = dyn_cast<PredicateAssume>(PossibleCopy)) {
+ VD.LocalNum = LN_Middle;
+ DomTreeNode *DomNode = DT.getNode(PAssume->AssumeInst->getParent());
+ if (!DomNode)
+ continue;
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.PInfo = PossibleCopy;
+ OrderedUses.push_back(VD);
+ } else if (isa<PredicateWithEdge>(PossibleCopy)) {
+ // If we can only do phi uses, we treat it like it's in the branch
+ // block, and handle it specially. We know that it goes last and
+ // dominates only phi uses.
+ auto BlockEdge = getBlockEdge(PossibleCopy);
+ if (EdgeUsesOnly.count(BlockEdge)) {
+ VD.LocalNum = LN_Last;
+ auto *DomNode = DT.getNode(BlockEdge.first);
+ if (DomNode) {
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.PInfo = PossibleCopy;
+ VD.EdgeOnly = true;
+ OrderedUses.push_back(VD);
+ }
+ } else {
+ // Otherwise, we are in the split block (even though we perform
+ // insertion in the branch block).
+ // Insert a possible copy at the split block and before the branch.
+ VD.LocalNum = LN_First;
+ auto *DomNode = DT.getNode(BlockEdge.second);
+ if (DomNode) {
+ VD.DFSIn = DomNode->getDFSNumIn();
+ VD.DFSOut = DomNode->getDFSNumOut();
+ VD.PInfo = PossibleCopy;
+ OrderedUses.push_back(VD);
+ }
+ }
+ }
+ }
+
+ convertUsesToDFSOrdered(Op, OrderedUses);
+ std::sort(OrderedUses.begin(), OrderedUses.end(), Compare);
+ SmallVector<ValueDFS, 8> RenameStack;
+ // For each use, sorted into DFS order, push values and replace uses with
+ // the top of stack, which will represent the reaching def.
+ for (auto &VD : OrderedUses) {
+ // We currently do not materialize copy over copy, but we should decide if
+ // we want to.
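+ // (Candidate copies seeded above carry a PInfo and no use; real uses
+ // carry a use and no PInfo. The Def field is only filled in once a copy
+ // has been materialized.)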
+ bool PossibleCopy = VD.PInfo != nullptr;
+ if (RenameStack.empty()) {
+ DEBUG(dbgs() << "Rename Stack is empty\n");
+ } else {
+ DEBUG(dbgs() << "Rename Stack Top DFS numbers are ("
+ << RenameStack.back().DFSIn << ","
+ << RenameStack.back().DFSOut << ")\n");
+ }
+
+ DEBUG(dbgs() << "Current DFS numbers are (" << VD.DFSIn << ","
+ << VD.DFSOut << ")\n");
+
+ bool ShouldPush = (VD.Def || PossibleCopy);
+ bool OutOfScope = !stackIsInScope(RenameStack, VD);
+ if (OutOfScope || ShouldPush) {
+ // Sync to our current scope.
+ popStackUntilDFSScope(RenameStack, VD);
+ if (ShouldPush) {
+ RenameStack.push_back(VD);
+ }
+ }
+ // If we get to this point and the stack is empty, we must have a use
+ // with no renaming needed; just skip it.
+ if (RenameStack.empty())
+ continue;
+ // Skip values; we only want to rename the uses.
+ if (VD.Def || PossibleCopy)
+ continue;
+ if (!DebugCounter::shouldExecute(RenameCounter)) {
+ DEBUG(dbgs() << "Skipping execution due to debug counter\n");
+ continue;
+ }
+ ValueDFS &Result = RenameStack.back();
+
+ // If the possible copy dominates something, materialize our stack up to
+ // this point. This ensures every comparison that affects our operation
+ // ends up with predicateinfo.
+ if (!Result.Def)
+ Result.Def = materializeStack(Counter, RenameStack, Op);
+
+ DEBUG(dbgs() << "Found replacement " << *Result.Def << " for "
+ << *VD.U->get() << " in " << *(VD.U->getUser()) << "\n");
+ assert(DT.dominates(cast<Instruction>(Result.Def), *VD.U) &&
+ "Predicateinfo def should have dominated this use");
+ VD.U->set(Result.Def);
+ }
+ }
+}
+
+PredicateInfo::ValueInfo &PredicateInfo::getOrCreateValueInfo(Value *Operand) {
+ auto OIN = ValueInfoNums.find(Operand);
+ if (OIN == ValueInfoNums.end()) {
+ // This will grow it.
+ ValueInfos.resize(ValueInfos.size() + 1);
+ // This will use the new size and give us a 0-based number for the info.
+ auto InsertResult = ValueInfoNums.insert({Operand, ValueInfos.size() - 1});
+ assert(InsertResult.second && "Value info number already existed?");
+ return ValueInfos[InsertResult.first->second];
+ }
+ return ValueInfos[OIN->second];
+}
+
+const PredicateInfo::ValueInfo &
+PredicateInfo::getValueInfo(Value *Operand) const {
+ auto OINI = ValueInfoNums.lookup(Operand);
+ assert(OINI != 0 && "Operand was not really in the Value Info Numbers");
+ assert(OINI < ValueInfos.size() &&
+ "Value Info Number greater than size of Value Info Table");
+ return ValueInfos[OINI];
+}
+
+PredicateInfo::PredicateInfo(Function &F, DominatorTree &DT,
+ AssumptionCache &AC)
+ : F(F), DT(DT), AC(AC), OI(&DT) {
+ // Push an empty operand info so that we can detect 0 as not finding one.
+ ValueInfos.resize(1);
+ buildPredicateInfo();
+}
+
+PredicateInfo::~PredicateInfo() {}
+
+void PredicateInfo::verifyPredicateInfo() const {}
+
+char PredicateInfoPrinterLegacyPass::ID = 0;
+
+PredicateInfoPrinterLegacyPass::PredicateInfoPrinterLegacyPass()
+ : FunctionPass(ID) {
+ initializePredicateInfoPrinterLegacyPassPass(
+ *PassRegistry::getPassRegistry());
+}
+
+void PredicateInfoPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<DominatorTreeWrapperPass>();
+ AU.addRequired<AssumptionCacheTracker>();
+}
+
+bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) {
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ auto PredInfo = make_unique<PredicateInfo>(F, DT, AC);
+ PredInfo->print(dbgs());
+ if
(VerifyPredicateInfo) + PredInfo->verifyPredicateInfo(); + return false; +} + +PreservedAnalyses PredicateInfoPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + OS << "PredicateInfo for function: " << F.getName() << "\n"; + make_unique<PredicateInfo>(F, DT, AC)->print(OS); + + return PreservedAnalyses::all(); +} + +/// \brief An assembly annotator class to print PredicateInfo information in +/// comments. +class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter { + friend class PredicateInfo; + const PredicateInfo *PredInfo; + +public: + PredicateInfoAnnotatedWriter(const PredicateInfo *M) : PredInfo(M) {} + + virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, + formatted_raw_ostream &OS) {} + + virtual void emitInstructionAnnot(const Instruction *I, + formatted_raw_ostream &OS) { + if (const auto *PI = PredInfo->getPredicateInfoFor(I)) { + OS << "; Has predicate info\n"; + if (const auto *PB = dyn_cast<PredicateBranch>(PI)) { + OS << "; branch predicate info { TrueEdge: " << PB->TrueEdge + << " Comparison:" << *PB->Condition << " Edge: ["; + PB->From->printAsOperand(OS); + OS << ","; + PB->To->printAsOperand(OS); + OS << "] }\n"; + } else if (const auto *PS = dyn_cast<PredicateSwitch>(PI)) { + OS << "; switch predicate info { CaseValue: " << *PS->CaseValue + << " Switch:" << *PS->Switch << " Edge: ["; + PS->From->printAsOperand(OS); + OS << ","; + PS->To->printAsOperand(OS); + OS << "] }\n"; + } else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) { + OS << "; assume predicate info {" + << " Comparison:" << *PA->Condition << " }\n"; + } + } + } +}; + +void PredicateInfo::print(raw_ostream &OS) const { + PredicateInfoAnnotatedWriter Writer(this); + F.print(OS, &Writer); +} + +void PredicateInfo::dump() const { + PredicateInfoAnnotatedWriter Writer(this); + F.print(dbgs(), &Writer); +} + +PreservedAnalyses PredicateInfoVerifierPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + make_unique<PredicateInfo>(F, DT, AC)->verifyPredicateInfo(); + + return PreservedAnalyses::all(); +} +} diff --git a/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 35faa6f..cdba982 100644 --- a/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -15,7 +15,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/PromoteMemToReg.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -23,6 +22,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/IteratedDominanceFrontier.h" #include "llvm/Analysis/ValueTracking.h" @@ -38,6 +38,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> using namespace llvm; @@ -224,13 +225,10 @@ struct PromoteMem2Reg { std::vector<AllocaInst *> Allocas; DominatorTree &DT; DIBuilder DIB; - - /// An AliasSetTracker object to update. If null, don't update it. 
- AliasSetTracker *AST; - /// A cache of @llvm.assume intrinsics used by SimplifyInstruction. AssumptionCache *AC; + const SimplifyQuery SQ; /// Reverse mapping of Allocas. DenseMap<AllocaInst *, unsigned> AllocaLookup; @@ -269,10 +267,11 @@ struct PromoteMem2Reg { public: PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST, AssumptionCache *AC) + AssumptionCache *AC) : Allocas(Allocas.begin(), Allocas.end()), DT(DT), DIB(*DT.getRoot()->getParent()->getParent(), /*AllowUnresolved*/ false), - AST(AST), AC(AC) {} + AC(AC), SQ(DT.getRoot()->getParent()->getParent()->getDataLayout(), + nullptr, &DT, AC) {} void run(); @@ -301,6 +300,18 @@ private: } // end of anonymous namespace +/// Given a LoadInst LI this adds assume(LI != null) after it. +static void addAssumeNonNull(AssumptionCache *AC, LoadInst *LI) { + Function *AssumeIntrinsic = + Intrinsic::getDeclaration(LI->getModule(), Intrinsic::assume); + ICmpInst *LoadNotNull = new ICmpInst(ICmpInst::ICMP_NE, LI, + Constant::getNullValue(LI->getType())); + LoadNotNull->insertAfter(LI); + CallInst *CI = CallInst::Create(AssumeIntrinsic, {LoadNotNull}); + CI->insertAfter(LoadNotNull); + AC->registerAssumption(CI); +} + static void removeLifetimeIntrinsicUsers(AllocaInst *AI) { // Knowing that this alloca is promotable, we know that it's safe to kill all // instructions except for load and store. @@ -334,9 +345,8 @@ static void removeLifetimeIntrinsicUsers(AllocaInst *AI) { /// and thus must be phi-ed with undef. We fall back to the standard alloca /// promotion algorithm in that case. static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, - LargeBlockInfo &LBI, - DominatorTree &DT, - AliasSetTracker *AST) { + LargeBlockInfo &LBI, DominatorTree &DT, + AssumptionCache *AC) { StoreInst *OnlyStore = Info.OnlyStore; bool StoringGlobalVal = !isa<Instruction>(OnlyStore->getOperand(0)); BasicBlock *StoreBB = OnlyStore->getParent(); @@ -387,9 +397,15 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, // code. if (ReplVal == LI) ReplVal = UndefValue::get(LI->getType()); + + // If the load was marked as nonnull we don't want to lose + // that information when we erase this Load. So we preserve + // it with an assume. + if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && + !llvm::isKnownNonNullAt(ReplVal, LI, &DT)) + addAssumeNonNull(AC, LI); + LI->replaceAllUsesWith(ReplVal); - if (AST && LI->getType()->isPointerTy()) - AST->deleteValue(LI); LI->eraseFromParent(); LBI.deleteValue(LI); } @@ -410,8 +426,6 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, Info.OnlyStore->eraseFromParent(); LBI.deleteValue(Info.OnlyStore); - if (AST) - AST->deleteValue(AI); AI->eraseFromParent(); LBI.deleteValue(AI); return true; @@ -435,7 +449,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, /// } static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, LargeBlockInfo &LBI, - AliasSetTracker *AST) { + DominatorTree &DT, + AssumptionCache *AC) { // The trickiest case to handle is when we have large blocks. Because of this, // this code is optimized assuming that large blocks happen. This does not // significantly pessimize the small block case. This uses LargeBlockInfo to @@ -476,13 +491,18 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, // There is no store before this load, bail out (load may be affected // by the following stores - see main comment). 
return false; - } - else + } else { // Otherwise, there was a store before this load, the load takes its value. - LI->replaceAllUsesWith(std::prev(I)->second->getOperand(0)); + // Note, if the load was marked as nonnull we don't want to lose that + // information when we erase it. So we preserve it with an assume. + Value *ReplVal = std::prev(I)->second->getOperand(0); + if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && + !llvm::isKnownNonNullAt(ReplVal, LI, &DT)) + addAssumeNonNull(AC, LI); + + LI->replaceAllUsesWith(ReplVal); + } - if (AST && LI->getType()->isPointerTy()) - AST->deleteValue(LI); LI->eraseFromParent(); LBI.deleteValue(LI); } @@ -499,8 +519,6 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, LBI.deleteValue(SI); } - if (AST) - AST->deleteValue(AI); AI->eraseFromParent(); LBI.deleteValue(AI); @@ -517,8 +535,6 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, void PromoteMem2Reg::run() { Function &F = *DT.getRoot()->getParent(); - if (AST) - PointerAllocaValues.resize(Allocas.size()); AllocaDbgDeclares.resize(Allocas.size()); AllocaInfo Info; @@ -536,8 +552,6 @@ void PromoteMem2Reg::run() { if (AI->use_empty()) { // If there are no uses of the alloca, just delete it now. - if (AST) - AST->deleteValue(AI); AI->eraseFromParent(); // Remove the alloca from the Allocas list, since it has been processed @@ -553,7 +567,7 @@ void PromoteMem2Reg::run() { // If there is only a single store to this value, replace any loads of // it that are directly dominated by the definition with the value stored. if (Info.DefiningBlocks.size() == 1) { - if (rewriteSingleStoreAlloca(AI, Info, LBI, DT, AST)) { + if (rewriteSingleStoreAlloca(AI, Info, LBI, DT, AC)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); ++NumSingleStore; @@ -564,7 +578,7 @@ void PromoteMem2Reg::run() { // If the alloca is only read and written in one basic block, just perform a // linear sweep over the block to eliminate it. if (Info.OnlyUsedInOneBlock && - promoteSingleBlockAlloca(AI, Info, LBI, AST)) { + promoteSingleBlockAlloca(AI, Info, LBI, DT, AC)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); continue; @@ -578,11 +592,6 @@ void PromoteMem2Reg::run() { BBNumbers[&BB] = ID++; } - // If we have an AST to keep updated, remember some pointer value that is - // stored into the alloca. - if (AST) - PointerAllocaValues[AllocaNum] = Info.AllocaPointerVal; - // Remember the dbg.declare intrinsic describing this alloca, if any. if (Info.DbgDeclare) AllocaDbgDeclares[AllocaNum] = Info.DbgDeclare; @@ -662,13 +671,9 @@ void PromoteMem2Reg::run() { // tree. Just delete the users now. if (!A->use_empty()) A->replaceAllUsesWith(UndefValue::get(A->getType())); - if (AST) - AST->deleteValue(A); A->eraseFromParent(); } - const DataLayout &DL = F.getParent()->getDataLayout(); - // Remove alloca's dbg.declare instrinsics from the function. for (unsigned i = 0, e = AllocaDbgDeclares.size(); i != e; ++i) if (DbgDeclareInst *DDI = AllocaDbgDeclares[i]) @@ -693,9 +698,7 @@ void PromoteMem2Reg::run() { PHINode *PN = I->second; // If this PHI node merges one value and/or undefs, get the value. 
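+ // (SQ, built once in the constructor, carries the DataLayout, dominator
+ // tree and assumption cache that the call below previously threaded
+ // through individually.)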
- if (Value *V = SimplifyInstruction(PN, DL, nullptr, &DT, AC)) { - if (AST && PN->getType()->isPointerTy()) - AST->deleteValue(PN); + if (Value *V = SimplifyInstruction(PN, SQ)) { PN->replaceAllUsesWith(V); PN->eraseFromParent(); NewPhiNodes.erase(I++); @@ -863,10 +866,6 @@ bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo, &BB->front()); ++NumPHIInsert; PhiToAllocaMap[PN] = AllocaNo; - - if (AST && PN->getType()->isPointerTy()) - AST->copyValue(PointerAllocaValues[AllocaNo], PN); - return true; } @@ -940,10 +939,15 @@ NextIteration: Value *V = IncomingVals[AI->second]; + // If the load was marked as nonnull we don't want to lose + // that information when we erase this Load. So we preserve + // it with an assume. + if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && + !llvm::isKnownNonNullAt(V, LI, &DT)) + addAssumeNonNull(AC, LI); + // Anything using the load now uses the current value. LI->replaceAllUsesWith(V); - if (AST && LI->getType()->isPointerTy()) - AST->deleteValue(LI); BB->getInstList().erase(LI); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { // Delete this instruction and mark the name as the current holder of the @@ -987,10 +991,10 @@ NextIteration: } void llvm::PromoteMemToReg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST, AssumptionCache *AC) { + AssumptionCache *AC) { // If there is nothing to do, bail out... if (Allocas.empty()) return; - PromoteMem2Reg(Allocas, DT, AST, AC).run(); + PromoteMem2Reg(Allocas, DT, AC).run(); } diff --git a/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp index 8e93ee7..6ccf54e 100644 --- a/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp @@ -13,18 +13,27 @@ #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdaterImpl.h" +#include <cassert> +#include <utility> using namespace llvm; @@ -36,7 +45,7 @@ static AvailableValsTy &getAvailableVals(void *AV) { } SSAUpdater::SSAUpdater(SmallVectorImpl<PHINode*> *NewPHI) - : AV(nullptr), ProtoType(nullptr), ProtoName(), InsertedPHIs(NewPHI) {} + : InsertedPHIs(NewPHI) {} SSAUpdater::~SSAUpdater() { delete static_cast<AvailableValsTy*>(AV); @@ -205,6 +214,7 @@ void SSAUpdater::RewriteUseAfterInsertions(Use &U) { } namespace llvm { + template<> class SSAUpdaterTraits<SSAUpdater> { public: @@ -230,6 +240,7 @@ public: PHI_iterator &operator++() { ++idx; return *this; } bool operator==(const PHI_iterator& x) const { return idx == x.idx; } bool operator!=(const PHI_iterator& x) const { return !operator==(x); } + Value *getIncomingValue() { return PHI->getIncomingValue(idx); } BasicBlock *getIncomingBlock() { return PHI->getIncomingBlock(idx); } }; @@ -303,7 +314,7 @@ public: } }; 
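+
+// A minimal usage sketch of the updater these traits drive (illustrative
+// only; V, DefBB and U stand for a value, its defining block, and a use to
+// rewrite):
+//
+//   SmallVector<PHINode *, 8> NewPHIs;
+//   SSAUpdater SSA(&NewPHIs);
+//   SSA.Initialize(V->getType(), V->getName());
+//   SSA.AddAvailableValue(DefBB, V); // V reaches the end of DefBB.
+//   SSA.RewriteUse(U);               // Rewrites U, inserting PHIs at join
+//                                    // points as needed.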
-} // End llvm namespace +} // end namespace llvm /// Check to see if AvailableVals has an entry for the specified BB and if so, /// return it. If not, construct SSA form by first calculating the required @@ -337,14 +348,12 @@ LoadAndStorePromoter(ArrayRef<const Instruction*> Insts, SSA.Initialize(SomeVal->getType(), BaseName); } - void LoadAndStorePromoter:: run(const SmallVectorImpl<Instruction*> &Insts) const { - // First step: bucket up uses of the alloca by the block they occur in. // This is important because we have to handle multiple defs/uses in a block // ourselves: SSAUpdater is purely for cross-block references. - DenseMap<BasicBlock*, TinyPtrVector<Instruction*> > UsesByBlock; + DenseMap<BasicBlock*, TinyPtrVector<Instruction*>> UsesByBlock; for (Instruction *User : Insts) UsesByBlock[User->getParent()].push_back(User); diff --git a/contrib/llvm/lib/Transforms/Utils/SanitizerStats.cpp b/contrib/llvm/lib/Transforms/Utils/SanitizerStats.cpp index 9afd175..8c23957 100644 --- a/contrib/llvm/lib/Transforms/Utils/SanitizerStats.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SanitizerStats.cpp @@ -12,13 +12,13 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/SanitizerStats.h" -#include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; diff --git a/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 7b0bddb..8784b97 100644 --- a/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -15,21 +15,22 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -54,11 +55,11 @@ #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -169,6 +170,8 @@ class SimplifyCFGOpt { unsigned BonusInstThreshold; AssumptionCache *AC; SmallPtrSetImpl<BasicBlock *> *LoopHeaders; + // See comments in SimplifyCFGOpt::SimplifySwitch. 
+ bool LateSimplifyCFG; Value *isValueEqualityComparison(TerminatorInst *TI); BasicBlock *GetValueEqualityComparisonCases( TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); @@ -192,9 +195,10 @@ class SimplifyCFGOpt { public: SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout &DL, unsigned BonusInstThreshold, AssumptionCache *AC, - SmallPtrSetImpl<BasicBlock *> *LoopHeaders) + SmallPtrSetImpl<BasicBlock *> *LoopHeaders, + bool LateSimplifyCFG) : TTI(TTI), DL(DL), BonusInstThreshold(BonusInstThreshold), AC(AC), - LoopHeaders(LoopHeaders) {} + LoopHeaders(LoopHeaders), LateSimplifyCFG(LateSimplifyCFG) {} bool run(BasicBlock *BB); }; @@ -591,7 +595,7 @@ private: Span = Span.inverse(); // If there are a ton of values, we don't want to make a ginormous switch. - if (Span.getSetSize().ugt(8) || Span.isEmptySet()) { + if (Span.isSizeLargerThan(8) || Span.isEmptySet()) { return false; } @@ -710,10 +714,9 @@ BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases( TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases) { if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { Cases.reserve(SI->getNumCases()); - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; - ++i) - Cases.push_back( - ValueEqualityComparisonCase(i.getCaseValue(), i.getCaseSuccessor())); + for (auto Case : SI->cases()) + Cases.push_back(ValueEqualityComparisonCase(Case.getCaseValue(), + Case.getCaseSuccessor())); return SI->getDefaultDest(); } @@ -846,12 +849,12 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( } for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) { --i; - if (DeadCases.count(i.getCaseValue())) { + if (DeadCases.count(i->getCaseValue())) { if (HasWeight) { - std::swap(Weights[i.getCaseIndex() + 1], Weights.back()); + std::swap(Weights[i->getCaseIndex() + 1], Weights.back()); Weights.pop_back(); } - i.getCaseSuccessor()->removePredecessor(TI->getParent()); + i->getCaseSuccessor()->removePredecessor(TI->getParent()); SI->removeCase(i); } } @@ -996,8 +999,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, SmallSetVector<BasicBlock*, 4> FailBlocks; if (!SafeToMergeTerminators(TI, PTI, &FailBlocks)) { for (auto *Succ : FailBlocks) { - std::vector<BasicBlock*> Blocks = { TI->getParent() }; - if (!SplitBlockPredecessors(Succ, Blocks, ".fold.split")) + if (!SplitBlockPredecessors(Succ, TI->getParent(), ".fold.split")) return false; } } @@ -1280,7 +1282,7 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, if (!isa<CallInst>(I1)) I1->setDebugLoc( DILocation::getMergedLocation(I1->getDebugLoc(), I2->getDebugLoc())); - + I2->eraseFromParent(); Changed = true; @@ -1373,53 +1375,6 @@ HoistTerminator: return true; } -// Is it legal to place a variable in operand \c OpIdx of \c I? -// FIXME: This should be promoted to Instruction. -static bool canReplaceOperandWithVariable(const Instruction *I, - unsigned OpIdx) { - // We can't have a PHI with a metadata type. - if (I->getOperand(OpIdx)->getType()->isMetadataTy()) - return false; - - // Early exit. - if (!isa<Constant>(I->getOperand(OpIdx))) - return true; - - switch (I->getOpcode()) { - default: - return true; - case Instruction::Call: - case Instruction::Invoke: - // FIXME: many arithmetic intrinsics have no issue taking a - // variable, however it's hard to distingish these from - // specials such as @llvm.frameaddress that require a constant. 
- if (isa<IntrinsicInst>(I)) - return false; - - // Constant bundle operands may need to retain their constant-ness for - // correctness. - if (ImmutableCallSite(I).isBundleOperand(OpIdx)) - return false; - - return true; - - case Instruction::ShuffleVector: - // Shufflevector masks are constant. - return OpIdx != 2; - case Instruction::ExtractValue: - case Instruction::InsertValue: - // All operands apart from the first are constant. - return OpIdx == 0; - case Instruction::Alloca: - return false; - case Instruction::GetElementPtr: - if (OpIdx == 0) - return true; - gep_type_iterator It = std::next(gep_type_begin(I), OpIdx - 1); - return It.isSequential(); - } -} - // All instructions in Insts belong to different blocks that all unconditionally // branch to a common successor. Analyze each instruction and return true if it // would be possible to sink them into their successor, creating one common @@ -1472,29 +1427,28 @@ static bool canSinkInstructions( return false; } + // Because SROA can't handle speculating stores of selects, try not + // to sink loads or stores of allocas when we'd have to create a PHI for + // the address operand. Also, because it is likely that loads or stores + // of allocas will disappear when Mem2Reg/SROA is run, don't sink them. + // This can cause code churn which can have unintended consequences down + // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244. + // FIXME: This is a workaround for a deficiency in SROA - see + // https://llvm.org/bugs/show_bug.cgi?id=30188 + if (isa<StoreInst>(I0) && any_of(Insts, [](const Instruction *I) { + return isa<AllocaInst>(I->getOperand(1)); + })) + return false; + if (isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) { + return isa<AllocaInst>(I->getOperand(0)); + })) + return false; + for (unsigned OI = 0, OE = I0->getNumOperands(); OI != OE; ++OI) { if (I0->getOperand(OI)->getType()->isTokenTy()) // Don't touch any operand of token type. return false; - // Because SROA can't handle speculating stores of selects, try not - // to sink loads or stores of allocas when we'd have to create a PHI for - // the address operand. Also, because it is likely that loads or stores - // of allocas will disappear when Mem2Reg/SROA is run, don't sink them. - // This can cause code churn which can have unintended consequences down - // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244. - // FIXME: This is a workaround for a deficiency in SROA - see - // https://llvm.org/bugs/show_bug.cgi?id=30188 - if (OI == 1 && isa<StoreInst>(I0) && - any_of(Insts, [](const Instruction *I) { - return isa<AllocaInst>(I->getOperand(1)); - })) - return false; - if (OI == 0 && isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) { - return isa<AllocaInst>(I->getOperand(0)); - })) - return false; - auto SameAsI0 = [&I0, OI](const Instruction *I) { assert(I->getNumOperands() == I0->getNumOperands()); return I->getOperand(OI) == I0->getOperand(OI); @@ -1546,7 +1500,7 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { })) return false; } - + // We don't need to do any more checking here; canSinkLastInstruction should // have done it all for us. 
SmallVector<Value*, 4> NewOperands; @@ -1653,7 +1607,7 @@ namespace { bool isValid() const { return !Fail; } - + void operator -- () { if (Fail) return; @@ -1699,7 +1653,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { // / \ // [f(1)] [if] // | | \ - // | | \ + // | | | // | [f(2)]| // \ | / // [ end ] @@ -1737,7 +1691,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { } if (UnconditionalPreds.size() < 2) return false; - + bool Changed = false; // We take a two-step approach to tail sinking. First we scan from the end of // each block upwards in lockstep. If the n'th instruction from the end of each @@ -1767,7 +1721,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { unsigned NumPHIInsts = NumPHIdValues / UnconditionalPreds.size(); if ((NumPHIdValues % UnconditionalPreds.size()) != 0) NumPHIInsts++; - + return NumPHIInsts <= 1; }; @@ -1790,7 +1744,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { } if (!Profitable) return false; - + DEBUG(dbgs() << "SINK: Splitting edge\n"); // We have a conditional edge and we're going to sink some instructions. // Insert a new block postdominating all blocks we're going to sink from. @@ -1800,7 +1754,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { return false; Changed = true; } - + // Now that we've analyzed all potential sinking candidates, perform the // actual sink. We iteratively sink the last non-terminator of the source // blocks into their common successor unless doing so would require too @@ -1826,7 +1780,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { DEBUG(dbgs() << "SINK: stopping here, too many PHIs would be created!\n"); break; } - + if (!sinkLastInstruction(UnconditionalPreds)) return Changed; NumSinkCommons++; @@ -2078,6 +2032,9 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, Value *S = Builder.CreateSelect( BrCond, TrueV, FalseV, TrueV->getName() + "." + FalseV->getName(), BI); SpeculatedStore->setOperand(0, S); + SpeculatedStore->setDebugLoc( + DILocation::getMergedLocation( + BI->getDebugLoc(), SpeculatedStore->getDebugLoc())); } // Metadata can be dependent on the condition we are hoisting above. @@ -2147,7 +2104,8 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { /// If we have a conditional branch on a PHI node value that is defined in the /// same block as the branch and if any PHI entries are constants, thread edges /// corresponding to that entry to be branches to their ultimate destination. -static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { +static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL, + AssumptionCache *AC) { BasicBlock *BB = BI->getParent(); PHINode *PN = dyn_cast<PHINode>(BI->getCondition()); // NOTE: we currently cannot transform this case if the PHI node is used @@ -2225,11 +2183,11 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { } // Check for trivial simplification. - if (Value *V = SimplifyInstruction(N, DL)) { + if (Value *V = SimplifyInstruction(N, {DL, nullptr, nullptr, AC})) { if (!BBI->use_empty()) TranslateMap[&*BBI] = V; if (!N->mayHaveSideEffects()) { - delete N; // Instruction folded away, don't need actual inst + N->deleteValue(); // Instruction folded away, don't need actual inst N = nullptr; } } else { @@ -2239,6 +2197,11 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { // Insert the new instruction into its new home. 
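+ // (N is null at this point if the instruction simplified away above and
+ // had no side effects.)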
if (N) EdgeBB->getInstList().insert(InsertPt, N); + + // Register the new instruction with the assumption cache if necessary. + if (auto *II = dyn_cast_or_null<IntrinsicInst>(N)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); } // Loop over all of the edges from PredBB to BB, changing them to branch @@ -2251,7 +2214,7 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { } // Recurse, simplifying any other constants. - return FoldCondBranchOnPHI(BI, DL) | true; + return FoldCondBranchOnPHI(BI, DL, AC) | true; } return false; @@ -2296,7 +2259,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, for (BasicBlock::iterator II = BB->begin(); isa<PHINode>(II);) { PHINode *PN = cast<PHINode>(II++); - if (Value *V = SimplifyInstruction(PN, DL)) { + if (Value *V = SimplifyInstruction(PN, {DL, PN})) { PN->replaceAllUsesWith(V); PN->eraseFromParent(); continue; @@ -3045,6 +3008,15 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { BasicBlock *QFB = QBI->getSuccessor(1); BasicBlock *PostBB = QFB->getSingleSuccessor(); + // Make sure we have a good guess for PostBB. If QTB's only successor is + // QFB, then QFB is a better PostBB. + if (QTB->getSingleSuccessor() == QFB) + PostBB = QFB; + + // If we couldn't find a good PostBB, stop. + if (!PostBB) + return false; + bool InvertPCond = false, InvertQCond = false; // Canonicalize fallthroughs to the true branches. if (PFB == QBI->getParent()) { @@ -3069,14 +3041,13 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { auto HasOnePredAndOneSucc = [](BasicBlock *BB, BasicBlock *P, BasicBlock *S) { return BB->getSinglePredecessor() == P && BB->getSingleSuccessor() == S; }; - if (!PostBB || - !HasOnePredAndOneSucc(PFB, PBI->getParent(), QBI->getParent()) || + if (!HasOnePredAndOneSucc(PFB, PBI->getParent(), QBI->getParent()) || !HasOnePredAndOneSucc(QFB, QBI->getParent(), PostBB)) return false; if ((PTB && !HasOnePredAndOneSucc(PTB, PBI->getParent(), QBI->getParent())) || (QTB && !HasOnePredAndOneSucc(QTB, QBI->getParent(), PostBB))) return false; - if (PostBB->getNumUses() != 2 || QBI->getParent()->getNumUses() != 2) + if (!PostBB->hasNUses(2) || !QBI->getParent()->hasNUses(2)) return false; // OK, this is a sequence of two diamonds or triangles. @@ -3433,8 +3404,8 @@ static bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select) { // Find the relevant condition and destinations. Value *Condition = Select->getCondition(); - BasicBlock *TrueBB = SI->findCaseValue(TrueVal).getCaseSuccessor(); - BasicBlock *FalseBB = SI->findCaseValue(FalseVal).getCaseSuccessor(); + BasicBlock *TrueBB = SI->findCaseValue(TrueVal)->getCaseSuccessor(); + BasicBlock *FalseBB = SI->findCaseValue(FalseVal)->getCaseSuccessor(); // Get weight for TrueBB and FalseBB. 
uint32_t TrueWeight = 0, FalseWeight = 0; @@ -3444,9 +3415,9 @@ static bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select) { GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) { TrueWeight = - (uint32_t)Weights[SI->findCaseValue(TrueVal).getSuccessorIndex()]; + (uint32_t)Weights[SI->findCaseValue(TrueVal)->getSuccessorIndex()]; FalseWeight = - (uint32_t)Weights[SI->findCaseValue(FalseVal).getSuccessorIndex()]; + (uint32_t)Weights[SI->findCaseValue(FalseVal)->getSuccessorIndex()]; } } @@ -3526,7 +3497,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( assert(VVal && "Should have a unique destination value"); ICI->setOperand(0, VVal); - if (Value *V = SimplifyInstruction(ICI, DL)) { + if (Value *V = SimplifyInstruction(ICI, {DL, ICI})) { ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); } @@ -3736,7 +3707,7 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) { if (!isa<DbgInfoIntrinsic>(I)) return false; - SmallSet<BasicBlock *, 4> TrivialUnwindBlocks; + SmallSetVector<BasicBlock *, 4> TrivialUnwindBlocks; auto *PhiLPInst = cast<PHINode>(RI->getValue()); // Check incoming blocks to see if any of them are trivial. @@ -4148,15 +4119,16 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { } } } else if (auto *SI = dyn_cast<SwitchInst>(TI)) { - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; - ++i) - if (i.getCaseSuccessor() == BB) { - BB->removePredecessor(SI->getParent()); - SI->removeCase(i); - --i; - --e; - Changed = true; + for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) { + if (i->getCaseSuccessor() != BB) { + ++i; + continue; } + BB->removePredecessor(SI->getParent()); + i = SI->removeCase(i); + e = SI->case_end(); + Changed = true; + } } else if (auto *II = dyn_cast<InvokeInst>(TI)) { if (II->getUnwindDest() == BB) { removeUnwindEdge(TI->getParent()); @@ -4239,18 +4211,18 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { SmallVector<ConstantInt *, 16> CasesA; SmallVector<ConstantInt *, 16> CasesB; - for (SwitchInst::CaseIt I : SI->cases()) { - BasicBlock *Dest = I.getCaseSuccessor(); + for (auto Case : SI->cases()) { + BasicBlock *Dest = Case.getCaseSuccessor(); if (!DestA) DestA = Dest; if (Dest == DestA) { - CasesA.push_back(I.getCaseValue()); + CasesA.push_back(Case.getCaseValue()); continue; } if (!DestB) DestB = Dest; if (Dest == DestB) { - CasesB.push_back(I.getCaseValue()); + CasesB.push_back(Case.getCaseValue()); continue; } return false; // More than two destinations. @@ -4348,8 +4320,7 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, const DataLayout &DL) { Value *Cond = SI->getCondition(); unsigned Bits = Cond->getType()->getIntegerBitWidth(); - APInt KnownZero(Bits, 0), KnownOne(Bits, 0); - computeKnownBits(Cond, KnownZero, KnownOne, DL, 0, AC, SI); + KnownBits Known = computeKnownBits(Cond, DL, 0, AC, SI); // We can also eliminate cases by determining that their values are outside of // the limited range of the condition based on how many significant (non-sign) @@ -4360,8 +4331,8 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, // Gather dead cases. 
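+ // For example (illustrative): if computeKnownBits proves the low bit of
+ // the condition is always one, every even case value fails the
+ // Known.One.isSubsetOf(CaseVal) test below and is dead.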
SmallVector<ConstantInt *, 8> DeadCases; for (auto &Case : SI->cases()) { - APInt CaseVal = Case.getCaseValue()->getValue(); - if ((CaseVal & KnownZero) != 0 || (CaseVal & KnownOne) != KnownOne || + const APInt &CaseVal = Case.getCaseValue()->getValue(); + if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) || (CaseVal.getMinSignedBits() > MaxSignificantBitsInCond)) { DeadCases.push_back(Case.getCaseValue()); DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal << " is dead.\n"); @@ -4375,7 +4346,7 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, bool HasDefault = !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); const unsigned NumUnknownBits = - Bits - (KnownZero.Or(KnownOne)).countPopulation(); + Bits - (Known.Zero | Known.One).countPopulation(); assert(NumUnknownBits <= Bits); if (HasDefault && DeadCases.empty() && NumUnknownBits < 64 /* avoid overflow */ && @@ -4400,17 +4371,17 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, // Remove dead cases from the switch. for (ConstantInt *DeadCase : DeadCases) { - SwitchInst::CaseIt Case = SI->findCaseValue(DeadCase); - assert(Case != SI->case_default() && + SwitchInst::CaseIt CaseI = SI->findCaseValue(DeadCase); + assert(CaseI != SI->case_default() && "Case was not found. Probably mistake in DeadCases forming."); if (HasWeight) { - std::swap(Weights[Case.getCaseIndex() + 1], Weights.back()); + std::swap(Weights[CaseI->getCaseIndex() + 1], Weights.back()); Weights.pop_back(); } // Prune unused values from PHI nodes. - Case.getCaseSuccessor()->removePredecessor(SI->getParent()); - SI->removeCase(Case); + CaseI->getCaseSuccessor()->removePredecessor(SI->getParent()); + SI->removeCase(CaseI); } if (HasWeight && Weights.size() >= 2) { SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); @@ -4464,10 +4435,9 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { typedef DenseMap<PHINode *, SmallVector<int, 4>> ForwardingNodesMap; ForwardingNodesMap ForwardingNodes; - for (SwitchInst::CaseIt I = SI->case_begin(), E = SI->case_end(); I != E; - ++I) { - ConstantInt *CaseValue = I.getCaseValue(); - BasicBlock *CaseDest = I.getCaseSuccessor(); + for (auto Case : SI->cases()) { + ConstantInt *CaseValue = Case.getCaseValue(); + BasicBlock *CaseDest = Case.getCaseSuccessor(); int PhiIndex; PHINode *PHI = @@ -4811,7 +4781,7 @@ public: SwitchLookupTable( Module &M, uint64_t TableSize, ConstantInt *Offset, const SmallVectorImpl<std::pair<ConstantInt *, Constant *>> &Values, - Constant *DefaultValue, const DataLayout &DL); + Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName); /// Build instructions with Builder to retrieve the value at /// the position given by Index in the lookup table. 
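+ ///
+ /// Depending on the table contents, the lookup is emitted as a single
+ /// constant (all results equal), a linear function of the index
+ /// (offset + index * multiplier), a bitmap packed into one integer, or,
+ /// failing those, a load from a private constant array (ArrayKind).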
@@ -4865,7 +4835,7 @@ private: SwitchLookupTable::SwitchLookupTable( Module &M, uint64_t TableSize, ConstantInt *Offset, const SmallVectorImpl<std::pair<ConstantInt *, Constant *>> &Values, - Constant *DefaultValue, const DataLayout &DL) + Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName) : SingleValue(nullptr), BitMap(nullptr), BitMapElementTy(nullptr), LinearOffset(nullptr), LinearMultiplier(nullptr), Array(nullptr) { assert(Values.size() && "Can't build lookup table without values!"); @@ -4927,7 +4897,7 @@ SwitchLookupTable::SwitchLookupTable( LinearMappingPossible = false; break; } - APInt Val = ConstVal->getValue(); + const APInt &Val = ConstVal->getValue(); if (I != 0) { APInt Dist = Val - PrevVal; if (I == 1) { @@ -4973,7 +4943,7 @@ SwitchLookupTable::SwitchLookupTable( Array = new GlobalVariable(M, ArrayTy, /*constant=*/true, GlobalVariable::PrivateLinkage, Initializer, - "switch.table"); + "switch.table." + FuncName); Array->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); Kind = ArrayKind; } @@ -5202,8 +5172,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // common destination, as well as the min and max case values. assert(SI->case_begin() != SI->case_end()); SwitchInst::CaseIt CI = SI->case_begin(); - ConstantInt *MinCaseVal = CI.getCaseValue(); - ConstantInt *MaxCaseVal = CI.getCaseValue(); + ConstantInt *MinCaseVal = CI->getCaseValue(); + ConstantInt *MaxCaseVal = CI->getCaseValue(); BasicBlock *CommonDest = nullptr; typedef SmallVector<std::pair<ConstantInt *, Constant *>, 4> ResultListTy; @@ -5213,7 +5183,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, SmallVector<PHINode *, 4> PHIs; for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) { - ConstantInt *CaseVal = CI.getCaseValue(); + ConstantInt *CaseVal = CI->getCaseValue(); if (CaseVal->getValue().slt(MinCaseVal->getValue())) MinCaseVal = CaseVal; if (CaseVal->getValue().sgt(MaxCaseVal->getValue())) @@ -5222,7 +5192,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Resulting value at phi nodes for this case value. typedef SmallVector<std::pair<PHINode *, Constant *>, 4> ResultsTy; ResultsTy Results; - if (!GetCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest, + if (!GetCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest, Results, DL, TTI)) return false; @@ -5363,7 +5333,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // If using a bitmask, use any value to fill the lookup table holes. Constant *DV = NeedMask ? 
ResultLists[PHI][0].second : DefaultResults[PHI];
- SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL);
+ StringRef FuncName = SI->getParent()->getParent()->getName();
+ SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL,
+ FuncName);
 Value *Result = Table.BuildLookup(TableIndex, Builder);
 @@ -5503,11 +5475,10 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
 auto *Rot = Builder.CreateOr(LShr, Shl);
 SI->replaceUsesOfWith(SI->getCondition(), Rot);
- for (SwitchInst::CaseIt C = SI->case_begin(), E = SI->case_end(); C != E;
- ++C) {
- auto *Orig = C.getCaseValue();
+ for (auto Case : SI->cases()) {
+ auto *Orig = Case.getCaseValue();
 auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base);
- C.setValue(
+ Case.setValue(
 cast<ConstantInt>(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue()))));
 }
 return true;
 @@ -5553,7 +5524,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
 if (ForwardSwitchConditionToPHI(SI))
 return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;
- if (SwitchToLookupTable(SI, Builder, DL, TTI))
+ // The conversion from switch to lookup tables results in
+ // difficult-to-analyze code and makes pruning branches much harder.
+ // This is a problem if the switch expression itself can still be
+ // restricted as a result of inlining or CVP. Therefore only apply this
+ // transformation during late steps of the optimization chain.
+ if (LateSimplifyCFG && SwitchToLookupTable(SI, Builder, DL, TTI))
 return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true;
 if (ReduceSwitchRange(SI, Builder, DL, TTI))
 @@ -5680,20 +5656,22 @@ static bool TryToMergeLandingPad(LandingPadInst *LPad, BranchInst *BI,
 bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI,
 IRBuilder<> &Builder) {
 BasicBlock *BB = BI->getParent();
+ BasicBlock *Succ = BI->getSuccessor(0);
 if (SinkCommon && SinkThenElseCodeToEnd(BI))
 return true;
 // If the Terminator is the only non-phi instruction, simplify the block.
- // if LoopHeader is provided, check if the block is a loop header
- // (This is for early invocations before loop simplify and vectorization
- // to keep canonical loop forms for nested loops.
- // These blocks can be eliminated when the pass is invoked later
- // in the back-end.)
+ // If LoopHeaders is provided, check if the block or its successor is a loop
+ // header. (This is for early invocations before loop simplify and
+ // vectorization, to keep canonical loop forms for nested loops. These blocks
+ // can be eliminated when the pass is invoked later in the back-end.)
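+ // (For instance, folding away an empty block that is the single
+ // predecessor of a loop header would delete the preheader that
+ // LoopSimplify expects.)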
+ bool NeedCanonicalLoop = + !LateSimplifyCFG && + (LoopHeaders && (LoopHeaders->count(BB) || LoopHeaders->count(Succ))); BasicBlock::iterator I = BB->getFirstNonPHIOrDbg()->getIterator(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && - (!LoopHeaders || !LoopHeaders->count(BB)) && - TryToSimplifyUncondBranchFromEmptyBlock(BB)) + !NeedCanonicalLoop && TryToSimplifyUncondBranchFromEmptyBlock(BB)) return true; // If the only instruction in the block is a seteq/setne comparison @@ -5778,8 +5756,8 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BasicBlock *Dom = BB->getSinglePredecessor()) { auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); if (PBI && PBI->isConditional() && - PBI->getSuccessor(0) != PBI->getSuccessor(1) && - (PBI->getSuccessor(0) == BB || PBI->getSuccessor(1) == BB)) { + PBI->getSuccessor(0) != PBI->getSuccessor(1)) { + assert(PBI->getSuccessor(0) == BB || PBI->getSuccessor(1) == BB); bool CondIsFalse = PBI->getSuccessor(1) == BB; Optional<bool> Implication = isImpliedCondition( PBI->getCondition(), BI->getCondition(), DL, CondIsFalse); @@ -5833,7 +5811,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // through this block if any PHI node entries are constants. if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) if (PN->getParent() == BI->getParent()) - if (FoldCondBranchOnPHI(BI, DL)) + if (FoldCondBranchOnPHI(BI, DL, AC)) return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; // Scan predecessor blocks for conditional branches. @@ -6012,8 +5990,9 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { /// bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, unsigned BonusInstThreshold, AssumptionCache *AC, - SmallPtrSetImpl<BasicBlock *> *LoopHeaders) { + SmallPtrSetImpl<BasicBlock *> *LoopHeaders, + bool LateSimplifyCFG) { return SimplifyCFGOpt(TTI, BB->getModule()->getDataLayout(), - BonusInstThreshold, AC, LoopHeaders) + BonusInstThreshold, AC, LoopHeaders, LateSimplifyCFG) .run(BB); } diff --git a/contrib/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/contrib/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index 6b1d3dc..6d90e6b 100644 --- a/contrib/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -25,6 +25,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -35,6 +36,9 @@ using namespace llvm; STATISTIC(NumElimIdentity, "Number of IV identities eliminated"); STATISTIC(NumElimOperand, "Number of IV operands folded into a use"); STATISTIC(NumElimRem , "Number of IV remainder operations eliminated"); +STATISTIC( + NumSimplifiedSDiv, + "Number of IV signed division operations converted to unsigned division"); STATISTIC(NumElimCmp , "Number of IV comparisons eliminated"); namespace { @@ -48,13 +52,13 @@ namespace { ScalarEvolution *SE; DominatorTree *DT; - SmallVectorImpl<WeakVH> &DeadInsts; + SmallVectorImpl<WeakTrackingVH> &DeadInsts; bool Changed; public: SimplifyIndvar(Loop *Loop, ScalarEvolution *SE, DominatorTree *DT, - LoopInfo *LI,SmallVectorImpl<WeakVH> &Dead) + LoopInfo *LI, SmallVectorImpl<WeakTrackingVH> &Dead) : L(Loop), LI(LI), SE(SE), DT(DT), DeadInsts(Dead), Changed(false) { assert(LI && "IV simplification requires LoopInfo"); } @@ -75,7 +79,9 @@ namespace { void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand); 
    void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand,
                              bool IsSigned);
+    bool eliminateSDiv(BinaryOperator *SDiv);
     bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand);
+    bool strengthenRightShift(BinaryOperator *BO, Value *IVOperand);
   };
 }
 
@@ -150,6 +156,7 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand)
 void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) {
   unsigned IVOperIdx = 0;
   ICmpInst::Predicate Pred = ICmp->getPredicate();
+  ICmpInst::Predicate OriginalPred = Pred;
   if (IVOperand != ICmp->getOperand(0)) {
     // Swapped
     assert(IVOperand == ICmp->getOperand(1) && "Can't find IVOperand");
@@ -258,6 +265,16 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) {
     ICmp->setPredicate(InvariantPredicate);
     ICmp->setOperand(0, NewLHS);
     ICmp->setOperand(1, NewRHS);
+  } else if (ICmpInst::isSigned(OriginalPred) &&
+             SE->isKnownNonNegative(S) && SE->isKnownNonNegative(X)) {
+    // If we were unable to make anything above, all we can do is canonicalize
+    // the comparison, hoping that it will open the doors for other
+    // optimizations. If we find out that we compare two non-negative values,
+    // we change the instruction's predicate to its unsigned version. Note
+    // that we cannot rely on Pred here unless we check if we have swapped it.
+    assert(ICmp->getPredicate() == OriginalPred && "Predicate changed?");
+    DEBUG(dbgs() << "INDVARS: Turn to unsigned comparison: " << *ICmp << '\n');
+    ICmp->setPredicate(ICmpInst::getUnsignedPredicate(OriginalPred));
   } else
     return;
 
@@ -265,6 +282,33 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) {
   Changed = true;
 }
 
+bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) {
+  // Get the SCEVs for the SDiv operands.
+  auto *N = SE->getSCEV(SDiv->getOperand(0));
+  auto *D = SE->getSCEV(SDiv->getOperand(1));
+
+  // Simplify unnecessary loops away.
+  const Loop *L = LI->getLoopFor(SDiv->getParent());
+  N = SE->getSCEVAtScope(N, L);
+  D = SE->getSCEVAtScope(D, L);
+
+  // Replace sdiv by udiv if both of the operands are non-negative.
+  if (SE->isKnownNonNegative(N) && SE->isKnownNonNegative(D)) {
+    auto *UDiv = BinaryOperator::Create(
+        BinaryOperator::UDiv, SDiv->getOperand(0), SDiv->getOperand(1),
+        SDiv->getName() + ".udiv", SDiv);
+    UDiv->setIsExact(SDiv->isExact());
+    SDiv->replaceAllUsesWith(UDiv);
+    DEBUG(dbgs() << "INDVARS: Simplified sdiv: " << *SDiv << '\n');
+    ++NumSimplifiedSDiv;
+    Changed = true;
+    DeadInsts.push_back(SDiv);
+    return true;
+  }
+
+  return false;
+}
+
 /// SimplifyIVUsers helper for eliminating useless
 /// remainder operations operating on an induction variable.
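/// For example (illustrative; i is the loop's canonical IV):
///   i % n        -->  i                          when SE proves i is in [0, n)
///   (i + 1) % n  -->  (i + 1 == n) ? 0 : i + 1   when i is in [0, n)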
void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, @@ -321,9 +365,9 @@ bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { return false; typedef const SCEV *(ScalarEvolution::*OperationFunctionTy)( - const SCEV *, const SCEV *, SCEV::NoWrapFlags); + const SCEV *, const SCEV *, SCEV::NoWrapFlags, unsigned); typedef const SCEV *(ScalarEvolution::*ExtensionFunctionTy)( - const SCEV *, Type *); + const SCEV *, Type *, unsigned); OperationFunctionTy Operation; ExtensionFunctionTy Extension; @@ -375,10 +419,11 @@ bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2); const SCEV *A = - (SE->*Extension)((SE->*Operation)(LHS, RHS, SCEV::FlagAnyWrap), WideTy); + (SE->*Extension)((SE->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), + WideTy, 0); const SCEV *B = - (SE->*Operation)((SE->*Extension)(LHS, WideTy), - (SE->*Extension)(RHS, WideTy), SCEV::FlagAnyWrap); + (SE->*Operation)((SE->*Extension)(LHS, WideTy, 0), + (SE->*Extension)(RHS, WideTy, 0), SCEV::FlagAnyWrap, 0); if (A != B) return false; @@ -426,12 +471,15 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, eliminateIVComparison(ICmp, IVOperand); return true; } - if (BinaryOperator *Rem = dyn_cast<BinaryOperator>(UseInst)) { - bool IsSigned = Rem->getOpcode() == Instruction::SRem; - if (IsSigned || Rem->getOpcode() == Instruction::URem) { - eliminateIVRemainder(Rem, IVOperand, IsSigned); + if (BinaryOperator *Bin = dyn_cast<BinaryOperator>(UseInst)) { + bool IsSRem = Bin->getOpcode() == Instruction::SRem; + if (IsSRem || Bin->getOpcode() == Instruction::URem) { + eliminateIVRemainder(Bin, IVOperand, IsSRem); return true; } + + if (Bin->getOpcode() == Instruction::SDiv) + return eliminateSDiv(Bin); } if (auto *CI = dyn_cast<CallInst>(UseInst)) @@ -496,8 +544,7 @@ bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, return false; const SCEV *(ScalarEvolution::*GetExprForBO)(const SCEV *, const SCEV *, - SCEV::NoWrapFlags); - + SCEV::NoWrapFlags, unsigned); switch (BO->getOpcode()) { default: return false; @@ -526,7 +573,7 @@ bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, const SCEV *ExtendAfterOp = SE->getZeroExtendExpr(SE->getSCEV(BO), WideTy); const SCEV *OpAfterExtend = (SE->*GetExprForBO)( SE->getZeroExtendExpr(LHS, WideTy), SE->getZeroExtendExpr(RHS, WideTy), - SCEV::FlagAnyWrap); + SCEV::FlagAnyWrap, 0u); if (ExtendAfterOp == OpAfterExtend) { BO->setHasNoUnsignedWrap(); SE->forgetValue(BO); @@ -538,7 +585,7 @@ bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, const SCEV *ExtendAfterOp = SE->getSignExtendExpr(SE->getSCEV(BO), WideTy); const SCEV *OpAfterExtend = (SE->*GetExprForBO)( SE->getSignExtendExpr(LHS, WideTy), SE->getSignExtendExpr(RHS, WideTy), - SCEV::FlagAnyWrap); + SCEV::FlagAnyWrap, 0u); if (ExtendAfterOp == OpAfterExtend) { BO->setHasNoSignedWrap(); SE->forgetValue(BO); @@ -549,6 +596,35 @@ bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, return Changed; } +/// Annotate the Shr in (X << IVOperand) >> C as exact using the +/// information from the IV's range. Returns true if anything changed, false +/// otherwise. 
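/// For example (illustrative): given
///   %shl = shl i32 %x, %iv
///   %shr = lshr i32 %shl, 2
/// if SE shows the unsigned minimum of %iv is at least 2, the two low bits of
/// %shl are known zero, so the lshr (or an ashr) can be marked exact.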
+bool SimplifyIndvar::strengthenRightShift(BinaryOperator *BO, + Value *IVOperand) { + using namespace llvm::PatternMatch; + + if (BO->getOpcode() == Instruction::Shl) { + bool Changed = false; + ConstantRange IVRange = SE->getUnsignedRange(SE->getSCEV(IVOperand)); + for (auto *U : BO->users()) { + const APInt *C; + if (match(U, + m_AShr(m_Shl(m_Value(), m_Specific(IVOperand)), m_APInt(C))) || + match(U, + m_LShr(m_Shl(m_Value(), m_Specific(IVOperand)), m_APInt(C)))) { + BinaryOperator *Shr = cast<BinaryOperator>(U); + if (!Shr->isExact() && IVRange.getUnsignedMin().uge(*C)) { + Shr->setIsExact(true); + Changed = true; + } + } + } + return Changed; + } + + return false; +} + /// Add all uses of Def to the current IV's worklist. static void pushIVUsers( Instruction *Def, @@ -641,8 +717,9 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { } if (BinaryOperator *BO = dyn_cast<BinaryOperator>(UseOper.first)) { - if (isa<OverflowingBinaryOperator>(BO) && - strengthenOverflowingOperation(BO, IVOperand)) { + if ((isa<OverflowingBinaryOperator>(BO) && + strengthenOverflowingOperation(BO, IVOperand)) || + (isa<ShlOperator>(BO) && strengthenRightShift(BO, IVOperand))) { // re-queue uses of the now modified binary operator and fall // through to the checks that remain. pushIVUsers(IVOperand, Simplified, SimpleIVUsers); @@ -667,7 +744,7 @@ void IVVisitor::anchor() { } /// Simplify instructions that use this induction variable /// by using ScalarEvolution to analyze the IV's recurrence. bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, DominatorTree *DT, - LoopInfo *LI, SmallVectorImpl<WeakVH> &Dead, + LoopInfo *LI, SmallVectorImpl<WeakTrackingVH> &Dead, IVVisitor *V) { SimplifyIndvar SIV(LI->getLoopFor(CurrIV->getParent()), SE, DT, LI, Dead); SIV.simplifyUsers(CurrIV, V); @@ -677,7 +754,7 @@ bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, DominatorTree *DT, /// Simplify users of induction variables within this /// loop. This does not actually change or add IVs. 
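/// A typical caller looks like this (sketch; the surrounding names are
/// hypothetical):
///   SmallVector<WeakTrackingVH, 16> DeadInsts;
///   simplifyLoopIVs(L, SE, DT, LI, DeadInsts);
///   while (!DeadInsts.empty()) {
///     Value *V = DeadInsts.pop_back_val();
///     if (auto *I = dyn_cast_or_null<Instruction>(V))
///       RecursivelyDeleteTriviallyDeadInstructions(I);
///   }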
bool simplifyLoopIVs(Loop *L, ScalarEvolution *SE, DominatorTree *DT, - LoopInfo *LI, SmallVectorImpl<WeakVH> &Dead) { + LoopInfo *LI, SmallVectorImpl<WeakTrackingVH> &Dead) { bool Changed = false; for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { Changed |= simplifyUsersOfIV(cast<PHINode>(I), SE, DT, LI, Dead); diff --git a/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp b/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp index 1220490..2ea15f6 100644 --- a/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -20,23 +20,23 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "instsimplify" STATISTIC(NumSimplified, "Number of redundant instructions removed"); -static bool runImpl(Function &F, const DominatorTree *DT, - const TargetLibraryInfo *TLI, AssumptionCache *AC) { - const DataLayout &DL = F.getParent()->getDataLayout(); +static bool runImpl(Function &F, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; bool Changed = false; @@ -54,7 +54,7 @@ static bool runImpl(Function &F, const DominatorTree *DT, // Don't waste time simplifying unused instructions. if (!I->use_empty()) { - if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { + if (Value *V = SimplifyInstruction(I, SQ, ORE)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) Next->insert(cast<Instruction>(U)); @@ -63,7 +63,7 @@ static bool runImpl(Function &F, const DominatorTree *DT, Changed = true; } } - if (RecursivelyDeleteTriviallyDeadInstructions(I, TLI)) { + if (RecursivelyDeleteTriviallyDeadInstructions(I, SQ.TLI)) { // RecursivelyDeleteTriviallyDeadInstruction can remove more than one // instruction, so simply incrementing the iterator does not work. // When instructions get deleted re-iterate instead. @@ -95,6 +95,7 @@ namespace { AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } /// runOnFunction - Remove instructions that simplify. 
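/// e.g. (illustrative) %a = add i32 %x, 0 simplifies to %x; the users of %a
/// are queued for another round, and the now-dead add is removed by
/// RecursivelyDeleteTriviallyDeadInstructions.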
@@ -108,7 +109,11 @@ namespace { &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - return runImpl(F, DT, TLI, AC); + OptimizationRemarkEmitter *ORE = + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + const DataLayout &DL = F.getParent()->getDataLayout(); + const SimplifyQuery SQ(DL, TLI, DT, AC); + return runImpl(F, SQ, ORE); } }; } @@ -119,6 +124,7 @@ INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) char &llvm::InstructionSimplifierID = InstSimplifier::ID; @@ -133,9 +139,14 @@ PreservedAnalyses InstSimplifierPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - bool Changed = runImpl(F, &DT, &TLI, &AC); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + const DataLayout &DL = F.getParent()->getDataLayout(); + const SimplifyQuery SQ(DL, &TLI, &DT, &AC); + bool Changed = runImpl(F, SQ, &ORE); if (!Changed) return PreservedAnalyses::all(); - // FIXME: This should also 'preserve the CFG'. - return PreservedAnalyses::none(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; } diff --git a/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 8eaeb10..77c0a41 100644 --- a/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -30,6 +30,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" @@ -37,10 +38,6 @@ using namespace llvm; using namespace PatternMatch; static cl::opt<bool> - ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden, - cl::desc("Treat error-reporting calls as cold")); - -static cl::opt<bool> EnableUnsafeFPShrink("enable-double-float-shrink", cl::Hidden, cl::init(false), cl::desc("Enable unsafe double to float " @@ -51,9 +48,9 @@ static cl::opt<bool> // Helper Functions //===----------------------------------------------------------------------===// -static bool ignoreCallingConv(LibFunc::Func Func) { - return Func == LibFunc::abs || Func == LibFunc::labs || - Func == LibFunc::llabs || Func == LibFunc::strlen; +static bool ignoreCallingConv(LibFunc Func) { + return Func == LibFunc_abs || Func == LibFunc_labs || + Func == LibFunc_llabs || Func == LibFunc_strlen; } static bool isCallingConvCCompatible(CallInst *CI) { @@ -88,20 +85,6 @@ static bool isCallingConvCCompatible(CallInst *CI) { return false; } -/// Return true if it only matters that the value is equal or not-equal to zero. -static bool isOnlyUsedInZeroEqualityComparison(Value *V) { - for (User *U : V->users()) { - if (ICmpInst *IC = dyn_cast<ICmpInst>(U)) - if (IC->isEquality()) - if (Constant *C = dyn_cast<Constant>(IC->getOperand(1))) - if (C->isNullValue()) - continue; - // Unknown instruction. - return false; - } - return true; -} - /// Return true if it is only used in equality comparisons with With. 
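/// e.g. returns true for V when every user is an icmp eq/ne of V against
/// With, so callers know that only (in)equality with With matters, not the
/// ordering of the two values.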
static bool isOnlyUsedInEqualityComparison(Value *V, Value *With) { for (User *U : V->users()) { @@ -123,8 +106,8 @@ static bool callHasFloatingPointArgument(const CallInst *CI) { /// \brief Check whether the overloaded unary floating point function /// corresponding to \a Ty is available. static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, - LibFunc::Func DoubleFn, LibFunc::Func FloatFn, - LibFunc::Func LongDoubleFn) { + LibFunc DoubleFn, LibFunc FloatFn, + LibFunc LongDoubleFn) { switch (Ty->getTypeID()) { case Type::FloatTyID: return TLI->has(FloatFn); @@ -429,59 +412,68 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { return Dst; } -Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { +Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilder<> &B, + unsigned CharSize) { Value *Src = CI->getArgOperand(0); // Constant folding: strlen("xyz") -> 3 - if (uint64_t Len = GetStringLength(Src)) + if (uint64_t Len = GetStringLength(Src, CharSize)) return ConstantInt::get(CI->getType(), Len - 1); // If s is a constant pointer pointing to a string literal, we can fold - // strlen(s + x) to strlen(s) - x, when x is known to be in the range + // strlen(s + x) to strlen(s) - x, when x is known to be in the range // [0, strlen(s)] or the string has a single null terminator '\0' at the end. - // We only try to simplify strlen when the pointer s points to an array + // We only try to simplify strlen when the pointer s points to an array // of i8. Otherwise, we would need to scale the offset x before doing the - // subtraction. This will make the optimization more complex, and it's not - // very useful because calling strlen for a pointer of other types is + // subtraction. This will make the optimization more complex, and it's not + // very useful because calling strlen for a pointer of other types is // very uncommon. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) { - if (!isGEPBasedOnPointerToString(GEP)) + if (!isGEPBasedOnPointerToString(GEP, CharSize)) return nullptr; - StringRef Str; - if (getConstantStringInfo(GEP->getOperand(0), Str, 0, false)) { - size_t NullTermIdx = Str.find('\0'); - - // If the string does not have '\0', leave it to strlen to compute - // its length. - if (NullTermIdx == StringRef::npos) - return nullptr; - + ConstantDataArraySlice Slice; + if (getConstantDataArrayInfo(GEP->getOperand(0), Slice, CharSize)) { + uint64_t NullTermIdx; + if (Slice.Array == nullptr) { + NullTermIdx = 0; + } else { + NullTermIdx = ~((uint64_t)0); + for (uint64_t I = 0, E = Slice.Length; I < E; ++I) { + if (Slice.Array->getElementAsInteger(I + Slice.Offset) == 0) { + NullTermIdx = I; + break; + } + } + // If the string does not have '\0', leave it to strlen to compute + // its length. 
+        if (NullTermIdx == ~((uint64_t)0))
+          return nullptr;
+      }
+
       Value *Offset = GEP->getOperand(2);
-      unsigned BitWidth = Offset->getType()->getIntegerBitWidth();
-      APInt KnownZero(BitWidth, 0);
-      APInt KnownOne(BitWidth, 0);
-      computeKnownBits(Offset, KnownZero, KnownOne, DL, 0, nullptr, CI,
-                       nullptr);
-      KnownZero.flipAllBits();
-      size_t ArrSize =
+      KnownBits Known = computeKnownBits(Offset, DL, 0, nullptr, CI, nullptr);
+      Known.Zero.flipAllBits();
+      uint64_t ArrSize =
             cast<ArrayType>(GEP->getSourceElementType())->getNumElements();
 
-      // KnownZero's bits are flipped, so zeros in KnownZero now represent
-      // bits known to be zeros in Offset, and ones in KnowZero represent
+      // KnownZero's bits are flipped, so zeros in KnownZero now represent
+      // bits known to be zeros in Offset, and ones in KnownZero represent
       // bits unknown in Offset. Therefore, Offset is known to be in range
-      // [0, NullTermIdx] when the flipped KnownZero is non-negative and
+      // [0, NullTermIdx] when the flipped KnownZero is non-negative and
       // unsigned-less-than NullTermIdx.
       //
-      // If Offset is not provably in the range [0, NullTermIdx], we can still
-      // optimize if we can prove that the program has undefined behavior when
-      // Offset is outside that range. That is the case when GEP->getOperand(0)
+      // If Offset is not provably in the range [0, NullTermIdx], we can still
+      // optimize if we can prove that the program has undefined behavior when
+      // Offset is outside that range. That is the case when GEP->getOperand(0)
       // is a pointer to an object whose memory extent is NullTermIdx+1.
-      if ((KnownZero.isNonNegative() && KnownZero.ule(NullTermIdx)) ||
+      if ((Known.Zero.isNonNegative() && Known.Zero.ule(NullTermIdx)) ||
           (GEP->isInBounds() && isa<GlobalVariable>(GEP->getOperand(0)) &&
-           NullTermIdx == ArrSize - 1))
-        return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx),
+           NullTermIdx == ArrSize - 1)) {
+        Offset = B.CreateSExtOrTrunc(Offset, CI->getType());
+        return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx),
                            Offset);
+      }
     }
 
     return nullptr;
@@ -489,8 +481,8 @@ Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) {
 
   // strlen(x?"foo":"bars") --> x ? 3 : 4
   if (SelectInst *SI = dyn_cast<SelectInst>(Src)) {
-    uint64_t LenTrue = GetStringLength(SI->getTrueValue());
-    uint64_t LenFalse = GetStringLength(SI->getFalseValue());
+    uint64_t LenTrue = GetStringLength(SI->getTrueValue(), CharSize);
+    uint64_t LenFalse = GetStringLength(SI->getFalseValue(), CharSize);
     if (LenTrue && LenFalse) {
       Function *Caller = CI->getParent()->getParent();
       emitOptimizationRemark(CI->getContext(), "simplify-libcalls", *Caller,
@@ -510,6 +502,17 @@ Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) {
   return nullptr;
 }
 
+Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) {
+  return optimizeStringLength(CI, B, 8);
+}
+
+Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilder<> &B) {
+  Module &M = *CI->getParent()->getParent()->getParent();
+  unsigned WCharSize = TLI->getWCharSize(M) * 8;
+
+  return optimizeStringLength(CI, B, WCharSize);
+}
+
 Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) {
   StringRef S1, S2;
   bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1);
@@ -542,7 +545,7 @@ Value *LibCallSimplifier::optimizeStrTo(CallInst *CI, IRBuilder<> &B) {
   if (isa<ConstantPointerNull>(EndPtr)) {
     // With a null EndPtr, this function won't capture the main argument.
     // It would be readonly too, except that it still may write to errno.
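    // Note (assumed from the AttributeList interface): addParamAttr takes a
    // zero-based parameter index, while the old addAttribute interface used
    // attribute-list slots where index 0 is the return value and parameters
    // start at 1; hence the 1 -> 0 change below.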
-    CI->addAttribute(1, Attribute::NoCapture);
+    CI->addParamAttr(0, Attribute::NoCapture);
   }
 
   return nullptr;
@@ -653,7 +656,7 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) {
   ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
 
   // memchr(x, y, 0) -> null
-  if (LenC && LenC->isNullValue())
+  if (LenC && LenC->isZero())
     return Constant::getNullValue(CI->getType());
 
   // From now on we need at least constant length and string.
@@ -735,8 +738,8 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) {
   ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
   if (!LenC)
     return nullptr;
-  uint64_t Len = LenC->getZExtValue();
 
+  uint64_t Len = LenC->getZExtValue();
   if (Len == 0) // memcmp(s1,s2,0) -> 0
     return Constant::getNullValue(CI->getType());
 
@@ -809,9 +812,9 @@ Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) {
 
 // TODO: Does this belong in BuildLibCalls or should all of those similar
 // functions be moved here?
-static Value *emitCalloc(Value *Num, Value *Size, const AttributeSet &Attrs,
+static Value *emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs,
                          IRBuilder<> &B, const TargetLibraryInfo &TLI) {
-  LibFunc::Func Func;
+  LibFunc Func;
   if (!TLI.getLibFunc("calloc", Func) || !TLI.has(Func))
     return nullptr;
 
@@ -819,7 +822,7 @@ static Value *emitCalloc(Value *Num, Value *Size, const AttributeSet &Attrs,
   const DataLayout &DL = M->getDataLayout();
   IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext()));
   Value *Calloc = M->getOrInsertFunction("calloc", Attrs, B.getInt8PtrTy(),
-                                         PtrType, PtrType, nullptr);
+                                         PtrType, PtrType);
   CallInst *CI = B.CreateCall(Calloc, { Num, Size }, "calloc");
 
   if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts()))
@@ -846,9 +849,12 @@ static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B,
 
   // Is the inner call really malloc()?
   Function *InnerCallee = Malloc->getCalledFunction();
-  LibFunc::Func Func;
+  if (!InnerCallee)
+    return nullptr;
+
+  LibFunc Func;
   if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) ||
-      Func != LibFunc::malloc)
+      Func != LibFunc_malloc)
     return nullptr;
 
   // The memset must cover the same number of bytes that are malloc'd.
@@ -930,6 +936,24 @@ static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
   if (V == nullptr)
     return nullptr;
 
+  // If the call isn't an intrinsic, check that it isn't defined within a
+  // function with the same name as the float version of this call, e.g.
+  //
+  //   inline float expf(float val) { return (float) exp((double) val); }
+  //
+  // A similar definition exists in the MinGW-w64 math.h header file which,
+  // when compiled with -O2 -ffast-math, produces infinite loops where expf
+  // is called.
+  if (!Callee->isIntrinsic()) {
+    const Function *F = CI->getFunction();
+    StringRef FName = F->getName();
+    StringRef CalleeName = Callee->getName();
+    if ((FName.size() == (CalleeName.size() + 1)) &&
+        (FName.back() == 'f') &&
+        FName.startswith(CalleeName))
+      return nullptr;
+  }
+
   // Propagate fast-math flags from the existing call to the new call.
IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); @@ -948,6 +972,20 @@ static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, return B.CreateFPExt(V, B.getDoubleTy()); } +// Replace a libcall \p CI with a call to intrinsic \p IID +static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { + // Propagate fast-math flags from the existing call to the new call. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + Module *M = CI->getModule(); + Value *V = CI->getArgOperand(0); + Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); + CallInst *NewCall = B.CreateCall(F, V); + NewCall->takeName(CI); + return NewCall; +} + /// Shrink double -> float for binary functions like 'fmin/fmax'. static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); @@ -1041,9 +1079,9 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { // pow(10.0, x) -> exp10(x) if (Op1C->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, - LibFunc::exp10l)) - return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, + hasUnaryFloatFn(TLI, Op1->getType(), LibFunc_exp10, LibFunc_exp10f, + LibFunc_exp10l)) + return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc_exp10), B, Callee->getAttributes()); } @@ -1055,10 +1093,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1). auto *OpC = dyn_cast<CallInst>(Op1); if (OpC && OpC->hasUnsafeAlgebra() && CI->hasUnsafeAlgebra()) { - LibFunc::Func Func; + LibFunc Func; Function *OpCCallee = OpC->getCalledFunction(); if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) && - TLI->has(Func) && (Func == LibFunc::exp || Func == LibFunc::exp2)) { + TLI->has(Func) && (Func == LibFunc_exp || Func == LibFunc_exp2)) { IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); Value *FMul = B.CreateFMul(OpC->getArgOperand(0), Op2, "mul"); @@ -1075,17 +1113,20 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { return ConstantFP::get(CI->getType(), 1.0); if (Op2C->isExactlyValue(-0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, - LibFunc::sqrtl)) { + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) { // If -ffast-math: // pow(x, -0.5) -> 1.0 / sqrt(x) if (CI->hasUnsafeAlgebra()) { IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); - // Here we cannot lower to an intrinsic because C99 sqrt() and llvm.sqrt - // are not guaranteed to have the same semantics. - Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, + // TODO: If the pow call is an intrinsic, we should lower to the sqrt + // intrinsic, so we match errno semantics. We also should check that the + // target can in fact lower the sqrt intrinsic -- we currently have no way + // to ask this question other than asking whether the target has a sqrt + // libcall, which is a sufficient but not necessary condition. 
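      // Illustrative result (fast-math assumed): pow(x, -0.5) becomes
      //   %sqrt      = call double @sqrt(double %x)
      //   %sqrtrecip = fdiv fast double 1.0, %sqrt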
+ Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B, Callee->getAttributes()); return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Sqrt, "sqrtrecip"); @@ -1093,19 +1134,17 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { } if (Op2C->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, - LibFunc::sqrtl) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, - LibFunc::fabsl)) { + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) { // In -ffast-math, pow(x, 0.5) -> sqrt(x). if (CI->hasUnsafeAlgebra()) { IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); - // Unlike other math intrinsics, sqrt has differerent semantics - // from the libc function. See LangRef for details. - return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, + // TODO: As above, we should lower to the sqrt intrinsic if the pow is an + // intrinsic, to match errno semantics. + return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B, Callee->getAttributes()); } @@ -1115,9 +1154,16 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // TODO: In finite-only mode, this could be just fabs(sqrt(x)). Value *Inf = ConstantFP::getInfinity(CI->getType()); Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); + + // TODO: As above, we should lower to the sqrt intrinsic if the pow is an + // intrinsic, to match errno semantics. Value *Sqrt = emitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); - Value *FAbs = - emitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes()); + + Module *M = Callee->getParent(); + Function *FabsF = Intrinsic::getDeclaration(M, Intrinsic::fabs, + CI->getType()); + Value *FAbs = B.CreateCall(FabsF, Sqrt); + Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); return Sel; @@ -1173,11 +1219,11 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { Value *Op = CI->getArgOperand(0); // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 - LibFunc::Func LdExp = LibFunc::ldexpl; + LibFunc LdExp = LibFunc_ldexpl; if (Op->getType()->isFloatTy()) - LdExp = LibFunc::ldexpf; + LdExp = LibFunc_ldexpf; else if (Op->getType()->isDoubleTy()) - LdExp = LibFunc::ldexp; + LdExp = LibFunc_ldexp; if (TLI->has(LdExp)) { Value *LdExpArg = nullptr; @@ -1197,7 +1243,7 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { Module *M = CI->getModule(); Value *NewCallee = M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), - Op->getType(), B.getInt32Ty(), nullptr); + Op->getType(), B.getInt32Ty()); CallInst *CI = B.CreateCall(NewCallee, {One, LdExpArg}); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1208,15 +1254,6 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { return Ret; } -Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - StringRef Name = Callee->getName(); - if (Name == "fabs" && hasFloatVersion(Name)) - return optimizeUnaryDoubleFP(CI, B, false); - - return nullptr; -} - Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); // If we can shrink the call to a float function rather than a double @@ -1280,17 +1317,17 @@ 
Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
   FMF.setUnsafeAlgebra();
   B.setFastMathFlags(FMF);
 
-  LibFunc::Func Func;
+  LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) &&
-       Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow))
+       Func == LibFunc_pow) || F->getIntrinsicID() == Intrinsic::pow))
     return B.CreateFMul(OpC->getArgOperand(1),
            emitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,
                                 Callee->getAttributes()), "mul");
 
   // log(exp2(y)) -> y*log(2)
   if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) &&
-      TLI->has(Func) && Func == LibFunc::exp2)
+      TLI->has(Func) && Func == LibFunc_exp2)
     return B.CreateFMul(
         OpC->getArgOperand(0),
         emitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0),
@@ -1302,8 +1339,11 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
-  if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" ||
-                                   Callee->getIntrinsicID() == Intrinsic::sqrt))
+  // TODO: Once we have a way (other than checking for the existence of the
+  // libcall) to tell whether our target can lower @llvm.sqrt, relax the
+  // condition below.
+  if (TLI->has(LibFunc_sqrtf) && (Callee->getName() == "sqrt" ||
+                                  Callee->getIntrinsicID() == Intrinsic::sqrt))
     Ret = optimizeUnaryDoubleFP(CI, B, true);
 
   if (!CI->hasUnsafeAlgebra())
@@ -1385,12 +1425,12 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) {
   // tan(atan(x)) -> x
   // tanf(atanf(x)) -> x
   // tanl(atanl(x)) -> x
-  LibFunc::Func Func;
+  LibFunc Func;
   Function *F = OpC->getCalledFunction();
   if (F && TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) &&
-      ((Func == LibFunc::atan && Callee->getName() == "tan") ||
-       (Func == LibFunc::atanf && Callee->getName() == "tanf") ||
-       (Func == LibFunc::atanl && Callee->getName() == "tanl")))
+      ((Func == LibFunc_atan && Callee->getName() == "tan") ||
+       (Func == LibFunc_atanf && Callee->getName() == "tanf") ||
+       (Func == LibFunc_atanl && Callee->getName() == "tanl")))
     Ret = OpC->getArgOperand(0);
   return Ret;
 }
@@ -1418,16 +1458,16 @@ static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
     // x86_64 can't use {float, float} since that would be returned in both
     // xmm0 and xmm1, which isn't what a real struct would do.
     ResTy = T.getArch() == Triple::x86_64
-                ?
static_cast<Type *>(VectorType::get(ArgTy, 2)) + : static_cast<Type *>(StructType::get(ArgTy, ArgTy)); } else { Name = "__sincospi_stret"; - ResTy = StructType::get(ArgTy, ArgTy, nullptr); + ResTy = StructType::get(ArgTy, ArgTy); } Module *M = OrigCallee->getParent(); Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), - ResTy, ArgTy, nullptr); + ResTy, ArgTy); if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { // If the argument is an instruction, it must dominate all uses so put our @@ -1508,24 +1548,24 @@ void LibCallSimplifier::classifyArgUse( return; Function *Callee = CI->getCalledFunction(); - LibFunc::Func Func; + LibFunc Func; if (!Callee || !TLI->getLibFunc(*Callee, Func) || !TLI->has(Func) || !isTrigLibCall(CI)) return; if (IsFloat) { - if (Func == LibFunc::sinpif) + if (Func == LibFunc_sinpif) SinCalls.push_back(CI); - else if (Func == LibFunc::cospif) + else if (Func == LibFunc_cospif) CosCalls.push_back(CI); - else if (Func == LibFunc::sincospif_stret) + else if (Func == LibFunc_sincospif_stret) SinCosCalls.push_back(CI); } else { - if (Func == LibFunc::sinpi) + if (Func == LibFunc_sinpi) SinCalls.push_back(CI); - else if (Func == LibFunc::cospi) + else if (Func == LibFunc_cospi) CosCalls.push_back(CI); - else if (Func == LibFunc::sincospi_stret) + else if (Func == LibFunc_sincospi_stret) SinCosCalls.push_back(CI); } } @@ -1609,14 +1649,14 @@ Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B, // Proceedings of PACT'98, Oct. 1998, IEEE if (!CI->hasFnAttr(Attribute::Cold) && isReportingError(Callee, CI, StreamArg)) { - CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); + CI->addAttribute(AttributeList::FunctionIndex, Attribute::Cold); } return nullptr; } static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg) { - if (!ColdErrorCalls || !Callee || !Callee->isDeclaration()) + if (!Callee || !Callee->isDeclaration()) return false; if (StreamArg < 0) @@ -1699,7 +1739,7 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) { // printf(format, ...) -> iprintf(format, ...) if no floating point // arguments. - if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { + if (TLI->has(LibFunc_iprintf) && !callHasFloatingPointArgument(CI)) { Module *M = B.GetInsertBlock()->getParent()->getParent(); Constant *IPrintFFn = M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); @@ -1780,7 +1820,7 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating // point arguments. - if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { + if (TLI->has(LibFunc_siprintf) && !callHasFloatingPointArgument(CI)) { Module *M = B.GetInsertBlock()->getParent()->getParent(); Constant *SIPrintFFn = M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); @@ -1850,7 +1890,7 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) { // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no // floating point arguments. 
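  // (Background, assumed: iprintf/siprintf/fiprintf are integer-only printf
  // variants provided by some embedded C libraries such as newlib; programs
  // that never print floating point can then avoid linking the FP formatting
  // code.)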
- if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) { + if (TLI->has(LibFunc_fiprintf) && !callHasFloatingPointArgument(CI)) { Module *M = B.GetInsertBlock()->getParent()->getParent(); Constant *FIPrintFFn = M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); @@ -1929,7 +1969,7 @@ Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { } bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { - LibFunc::Func Func; + LibFunc Func; SmallString<20> FloatFuncName = FuncName; FloatFuncName += 'f'; if (TLI->getLibFunc(FloatFuncName, Func)) @@ -1939,7 +1979,7 @@ bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, IRBuilder<> &Builder) { - LibFunc::Func Func; + LibFunc Func; Function *Callee = CI->getCalledFunction(); // Check for string/memory library functions. if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { @@ -1948,52 +1988,54 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, isCallingConvCCompatible(CI)) && "Optimizing string/memory libcall would change the calling convention"); switch (Func) { - case LibFunc::strcat: + case LibFunc_strcat: return optimizeStrCat(CI, Builder); - case LibFunc::strncat: + case LibFunc_strncat: return optimizeStrNCat(CI, Builder); - case LibFunc::strchr: + case LibFunc_strchr: return optimizeStrChr(CI, Builder); - case LibFunc::strrchr: + case LibFunc_strrchr: return optimizeStrRChr(CI, Builder); - case LibFunc::strcmp: + case LibFunc_strcmp: return optimizeStrCmp(CI, Builder); - case LibFunc::strncmp: + case LibFunc_strncmp: return optimizeStrNCmp(CI, Builder); - case LibFunc::strcpy: + case LibFunc_strcpy: return optimizeStrCpy(CI, Builder); - case LibFunc::stpcpy: + case LibFunc_stpcpy: return optimizeStpCpy(CI, Builder); - case LibFunc::strncpy: + case LibFunc_strncpy: return optimizeStrNCpy(CI, Builder); - case LibFunc::strlen: + case LibFunc_strlen: return optimizeStrLen(CI, Builder); - case LibFunc::strpbrk: + case LibFunc_strpbrk: return optimizeStrPBrk(CI, Builder); - case LibFunc::strtol: - case LibFunc::strtod: - case LibFunc::strtof: - case LibFunc::strtoul: - case LibFunc::strtoll: - case LibFunc::strtold: - case LibFunc::strtoull: + case LibFunc_strtol: + case LibFunc_strtod: + case LibFunc_strtof: + case LibFunc_strtoul: + case LibFunc_strtoll: + case LibFunc_strtold: + case LibFunc_strtoull: return optimizeStrTo(CI, Builder); - case LibFunc::strspn: + case LibFunc_strspn: return optimizeStrSpn(CI, Builder); - case LibFunc::strcspn: + case LibFunc_strcspn: return optimizeStrCSpn(CI, Builder); - case LibFunc::strstr: + case LibFunc_strstr: return optimizeStrStr(CI, Builder); - case LibFunc::memchr: + case LibFunc_memchr: return optimizeMemChr(CI, Builder); - case LibFunc::memcmp: + case LibFunc_memcmp: return optimizeMemCmp(CI, Builder); - case LibFunc::memcpy: + case LibFunc_memcpy: return optimizeMemCpy(CI, Builder); - case LibFunc::memmove: + case LibFunc_memmove: return optimizeMemMove(CI, Builder); - case LibFunc::memset: + case LibFunc_memset: return optimizeMemSet(CI, Builder); + case LibFunc_wcslen: + return optimizeWcslen(CI, Builder); default: break; } @@ -2005,7 +2047,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { if (CI->isNoBuiltin()) return nullptr; - LibFunc::Func Func; + LibFunc Func; Function *Callee = CI->getCalledFunction(); StringRef FuncName = Callee->getName(); @@ -2029,8 +2071,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return 
optimizePow(CI, Builder); case Intrinsic::exp2: return optimizeExp2(CI, Builder); - case Intrinsic::fabs: - return optimizeFabs(CI, Builder); case Intrinsic::log: return optimizeLog(CI, Builder); case Intrinsic::sqrt: @@ -2067,114 +2107,117 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { if (Value *V = optimizeStringMemoryLibCall(CI, Builder)) return V; switch (Func) { - case LibFunc::cosf: - case LibFunc::cos: - case LibFunc::cosl: + case LibFunc_cosf: + case LibFunc_cos: + case LibFunc_cosl: return optimizeCos(CI, Builder); - case LibFunc::sinpif: - case LibFunc::sinpi: - case LibFunc::cospif: - case LibFunc::cospi: + case LibFunc_sinpif: + case LibFunc_sinpi: + case LibFunc_cospif: + case LibFunc_cospi: return optimizeSinCosPi(CI, Builder); - case LibFunc::powf: - case LibFunc::pow: - case LibFunc::powl: + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: return optimizePow(CI, Builder); - case LibFunc::exp2l: - case LibFunc::exp2: - case LibFunc::exp2f: + case LibFunc_exp2l: + case LibFunc_exp2: + case LibFunc_exp2f: return optimizeExp2(CI, Builder); - case LibFunc::fabsf: - case LibFunc::fabs: - case LibFunc::fabsl: - return optimizeFabs(CI, Builder); - case LibFunc::sqrtf: - case LibFunc::sqrt: - case LibFunc::sqrtl: + case LibFunc_fabsf: + case LibFunc_fabs: + case LibFunc_fabsl: + return replaceUnaryCall(CI, Builder, Intrinsic::fabs); + case LibFunc_sqrtf: + case LibFunc_sqrt: + case LibFunc_sqrtl: return optimizeSqrt(CI, Builder); - case LibFunc::ffs: - case LibFunc::ffsl: - case LibFunc::ffsll: + case LibFunc_ffs: + case LibFunc_ffsl: + case LibFunc_ffsll: return optimizeFFS(CI, Builder); - case LibFunc::fls: - case LibFunc::flsl: - case LibFunc::flsll: + case LibFunc_fls: + case LibFunc_flsl: + case LibFunc_flsll: return optimizeFls(CI, Builder); - case LibFunc::abs: - case LibFunc::labs: - case LibFunc::llabs: + case LibFunc_abs: + case LibFunc_labs: + case LibFunc_llabs: return optimizeAbs(CI, Builder); - case LibFunc::isdigit: + case LibFunc_isdigit: return optimizeIsDigit(CI, Builder); - case LibFunc::isascii: + case LibFunc_isascii: return optimizeIsAscii(CI, Builder); - case LibFunc::toascii: + case LibFunc_toascii: return optimizeToAscii(CI, Builder); - case LibFunc::printf: + case LibFunc_printf: return optimizePrintF(CI, Builder); - case LibFunc::sprintf: + case LibFunc_sprintf: return optimizeSPrintF(CI, Builder); - case LibFunc::fprintf: + case LibFunc_fprintf: return optimizeFPrintF(CI, Builder); - case LibFunc::fwrite: + case LibFunc_fwrite: return optimizeFWrite(CI, Builder); - case LibFunc::fputs: + case LibFunc_fputs: return optimizeFPuts(CI, Builder); - case LibFunc::log: - case LibFunc::log10: - case LibFunc::log1p: - case LibFunc::log2: - case LibFunc::logb: + case LibFunc_log: + case LibFunc_log10: + case LibFunc_log1p: + case LibFunc_log2: + case LibFunc_logb: return optimizeLog(CI, Builder); - case LibFunc::puts: + case LibFunc_puts: return optimizePuts(CI, Builder); - case LibFunc::tan: - case LibFunc::tanf: - case LibFunc::tanl: + case LibFunc_tan: + case LibFunc_tanf: + case LibFunc_tanl: return optimizeTan(CI, Builder); - case LibFunc::perror: + case LibFunc_perror: return optimizeErrorReporting(CI, Builder); - case LibFunc::vfprintf: - case LibFunc::fiprintf: + case LibFunc_vfprintf: + case LibFunc_fiprintf: return optimizeErrorReporting(CI, Builder, 0); - case LibFunc::fputc: + case LibFunc_fputc: return optimizeErrorReporting(CI, Builder, 1); - case LibFunc::ceil: - case LibFunc::floor: - case LibFunc::rint: - case LibFunc::round: 
- case LibFunc::nearbyint: - case LibFunc::trunc: - if (hasFloatVersion(FuncName)) - return optimizeUnaryDoubleFP(CI, Builder, false); - return nullptr; - case LibFunc::acos: - case LibFunc::acosh: - case LibFunc::asin: - case LibFunc::asinh: - case LibFunc::atan: - case LibFunc::atanh: - case LibFunc::cbrt: - case LibFunc::cosh: - case LibFunc::exp: - case LibFunc::exp10: - case LibFunc::expm1: - case LibFunc::sin: - case LibFunc::sinh: - case LibFunc::tanh: + case LibFunc_ceil: + return replaceUnaryCall(CI, Builder, Intrinsic::ceil); + case LibFunc_floor: + return replaceUnaryCall(CI, Builder, Intrinsic::floor); + case LibFunc_round: + return replaceUnaryCall(CI, Builder, Intrinsic::round); + case LibFunc_nearbyint: + return replaceUnaryCall(CI, Builder, Intrinsic::nearbyint); + case LibFunc_rint: + return replaceUnaryCall(CI, Builder, Intrinsic::rint); + case LibFunc_trunc: + return replaceUnaryCall(CI, Builder, Intrinsic::trunc); + case LibFunc_acos: + case LibFunc_acosh: + case LibFunc_asin: + case LibFunc_asinh: + case LibFunc_atan: + case LibFunc_atanh: + case LibFunc_cbrt: + case LibFunc_cosh: + case LibFunc_exp: + case LibFunc_exp10: + case LibFunc_expm1: + case LibFunc_sin: + case LibFunc_sinh: + case LibFunc_tanh: if (UnsafeFPShrink && hasFloatVersion(FuncName)) return optimizeUnaryDoubleFP(CI, Builder, true); return nullptr; - case LibFunc::copysign: + case LibFunc_copysign: if (hasFloatVersion(FuncName)) return optimizeBinaryDoubleFP(CI, Builder); return nullptr; - case LibFunc::fminf: - case LibFunc::fmin: - case LibFunc::fminl: - case LibFunc::fmaxf: - case LibFunc::fmax: - case LibFunc::fmaxl: + case LibFunc_fminf: + case LibFunc_fmin: + case LibFunc_fminl: + case LibFunc_fmaxf: + case LibFunc_fmax: + case LibFunc_fmaxl: return optimizeFMinFMax(CI, Builder); default: return nullptr; @@ -2211,16 +2254,10 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // * log(exp10(y)) -> y*log(10) // * log(sqrt(x)) -> 0.5*log(x) // -// lround, lroundf, lroundl: -// * lround(cnst) -> cnst' -// // pow, powf, powl: // * pow(sqrt(x),y) -> pow(x,y*0.5) // * pow(pow(x,y),z)-> pow(x,y*z) // -// round, roundf, roundl: -// * round(cnst) -> cnst' -// // signbit: // * signbit(cnst) -> cnst' // * signbit(nncst) -> 0 (if pstv is a non-negative constant) @@ -2230,10 +2267,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // * sqrt(Nroot(x)) -> pow(x,1/(2*N)) // * sqrt(pow(x,y)) -> pow(|x|,y*0.5) // -// trunc, truncf, truncl: -// * trunc(cnst) -> cnst' -// -// //===----------------------------------------------------------------------===// // Fortified Library Call Optimizations @@ -2247,7 +2280,7 @@ bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, return true; if (ConstantInt *ObjSizeCI = dyn_cast<ConstantInt>(CI->getArgOperand(ObjSizeOp))) { - if (ObjSizeCI->isAllOnesValue()) + if (ObjSizeCI->isMinusOne()) return true; // If the object size wasn't -1 (unknown), bail out if we were asked to. if (OnlyLowerUnknownSize) @@ -2300,7 +2333,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, IRBuilder<> &B, - LibFunc::Func Func) { + LibFunc Func) { Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); const DataLayout &DL = CI->getModule()->getDataLayout(); @@ -2308,7 +2341,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, *ObjSize = CI->getArgOperand(2); // __stpcpy_chk(x,x,...) 
-> x+strlen(x) - if (Func == LibFunc::stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) { + if (Func == LibFunc_stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) { Value *StrLen = emitStrLen(Src, B, DL, TLI); return StrLen ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr; } @@ -2334,14 +2367,14 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, Value *Ret = emitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI); // If the function was an __stpcpy_chk, and we were able to fold it into // a __memcpy_chk, we still need to return the correct end pointer. - if (Ret && Func == LibFunc::stpcpy_chk) + if (Ret && Func == LibFunc_stpcpy_chk) return B.CreateGEP(B.getInt8Ty(), Dst, ConstantInt::get(SizeTTy, Len - 1)); return Ret; } Value *FortifiedLibCallSimplifier::optimizeStrpNCpyChk(CallInst *CI, IRBuilder<> &B, - LibFunc::Func Func) { + LibFunc Func) { Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); if (isFortifiedCallFoldable(CI, 3, 2, false)) { @@ -2366,7 +2399,7 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) { // // PR23093. - LibFunc::Func Func; + LibFunc Func; Function *Callee = CI->getCalledFunction(); SmallVector<OperandBundleDef, 2> OpBundles; @@ -2384,17 +2417,17 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) { return nullptr; switch (Func) { - case LibFunc::memcpy_chk: + case LibFunc_memcpy_chk: return optimizeMemCpyChk(CI, Builder); - case LibFunc::memmove_chk: + case LibFunc_memmove_chk: return optimizeMemMoveChk(CI, Builder); - case LibFunc::memset_chk: + case LibFunc_memset_chk: return optimizeMemSetChk(CI, Builder); - case LibFunc::stpcpy_chk: - case LibFunc::strcpy_chk: + case LibFunc_stpcpy_chk: + case LibFunc_strcpy_chk: return optimizeStrpCpyChk(CI, Builder, Func); - case LibFunc::stpncpy_chk: - case LibFunc::strncpy_chk: + case LibFunc_stpncpy_chk: + case LibFunc_strncpy_chk: return optimizeStrpNCpyChk(CI, Builder, Func); default: break; diff --git a/contrib/llvm/lib/Transforms/Utils/StripGCRelocates.cpp b/contrib/llvm/lib/Transforms/Utils/StripGCRelocates.cpp index f3d3fad..49dc15c 100644 --- a/contrib/llvm/lib/Transforms/Utils/StripGCRelocates.cpp +++ b/contrib/llvm/lib/Transforms/Utils/StripGCRelocates.cpp @@ -20,8 +20,8 @@ #include "llvm/IR/Statepoint.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; diff --git a/contrib/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp b/contrib/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp index 66dbf33..cd0378e 100644 --- a/contrib/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp +++ b/contrib/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" #include "llvm/IR/DebugInfo.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" using namespace llvm; namespace { diff --git a/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp index 6d13663..2010755 100644 --- a/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp @@ -59,9 +59,9 @@ #define DEBUG_TYPE "symbol-rewriter" #include "llvm/Transforms/Utils/SymbolRewriter.h" -#include "llvm/Pass.h" #include "llvm/ADT/SmallString.h" #include "llvm/IR/LegacyPassManager.h" +#include 
"llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MemoryBuffer.h" diff --git a/contrib/llvm/lib/Transforms/Utils/Utils.cpp b/contrib/llvm/lib/Transforms/Utils/Utils.cpp index 7b9de2e..f6c7d1c 100644 --- a/contrib/llvm/lib/Transforms/Utils/Utils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/Utils.cpp @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/InitializePasses.h" #include "llvm-c/Initialization.h" +#include "llvm/InitializePasses.h" #include "llvm/PassRegistry.h" using namespace llvm; @@ -35,9 +35,8 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) { initializeUnifyFunctionExitNodesPass(Registry); initializeInstSimplifierPass(Registry); initializeMetaRenamerPass(Registry); - initializeMemorySSAWrapperPassPass(Registry); - initializeMemorySSAPrinterLegacyPassPass(Registry); initializeStripGCRelocatesPass(Registry); + initializePredicateInfoPrinterLegacyPassPass(Registry); } /// LLVMInitializeTransformUtils - C binding for initializeTransformUtilsPasses. diff --git a/contrib/llvm/lib/Transforms/Utils/VNCoercion.cpp b/contrib/llvm/lib/Transforms/Utils/VNCoercion.cpp new file mode 100644 index 0000000..c3feea6 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/VNCoercion.cpp @@ -0,0 +1,495 @@ +#include "llvm/Transforms/Utils/VNCoercion.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "vncoerce" +namespace llvm { +namespace VNCoercion { + +/// Return true if coerceAvailableValueToLoadType will succeed. +bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, + const DataLayout &DL) { + // If the loaded or stored value is an first class array or struct, don't try + // to transform them. We need to be able to bitcast to integer. + if (LoadTy->isStructTy() || LoadTy->isArrayTy() || + StoredVal->getType()->isStructTy() || StoredVal->getType()->isArrayTy()) + return false; + + // The store has to be at least as big as the load. + if (DL.getTypeSizeInBits(StoredVal->getType()) < DL.getTypeSizeInBits(LoadTy)) + return false; + + // Don't coerce non-integral pointers to integers or vice versa. + if (DL.isNonIntegralPointerType(StoredVal->getType()) != + DL.isNonIntegralPointerType(LoadTy)) + return false; + + return true; +} + +template <class T, class HelperClass> +static T *coerceAvailableValueToLoadTypeHelper(T *StoredVal, Type *LoadedTy, + HelperClass &Helper, + const DataLayout &DL) { + assert(canCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) && + "precondition violation - materialization can't fail"); + if (auto *C = dyn_cast<Constant>(StoredVal)) + if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) + StoredVal = FoldedStoredVal; + + // If this is already the right type, just return it. + Type *StoredValTy = StoredVal->getType(); + + uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy); + uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy); + + // If the store and reload are the same size, we can always reuse it. + if (StoredValSize == LoadedValSize) { + // Pointer to Pointer -> use bitcast. 
+ if (StoredValTy->isPtrOrPtrVectorTy() && LoadedTy->isPtrOrPtrVectorTy()) { + StoredVal = Helper.CreateBitCast(StoredVal, LoadedTy); + } else { + // Convert source pointers to integers, which can be bitcast. + if (StoredValTy->isPtrOrPtrVectorTy()) { + StoredValTy = DL.getIntPtrType(StoredValTy); + StoredVal = Helper.CreatePtrToInt(StoredVal, StoredValTy); + } + + Type *TypeToCastTo = LoadedTy; + if (TypeToCastTo->isPtrOrPtrVectorTy()) + TypeToCastTo = DL.getIntPtrType(TypeToCastTo); + + if (StoredValTy != TypeToCastTo) + StoredVal = Helper.CreateBitCast(StoredVal, TypeToCastTo); + + // Cast to pointer if the load needs a pointer type. + if (LoadedTy->isPtrOrPtrVectorTy()) + StoredVal = Helper.CreateIntToPtr(StoredVal, LoadedTy); + } + + if (auto *C = dyn_cast<ConstantExpr>(StoredVal)) + if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) + StoredVal = FoldedStoredVal; + + return StoredVal; + } + // If the loaded value is smaller than the available value, then we can + // extract out a piece from it. If the available value is too small, then we + // can't do anything. + assert(StoredValSize >= LoadedValSize && + "canCoerceMustAliasedValueToLoad fail"); + + // Convert source pointers to integers, which can be manipulated. + if (StoredValTy->isPtrOrPtrVectorTy()) { + StoredValTy = DL.getIntPtrType(StoredValTy); + StoredVal = Helper.CreatePtrToInt(StoredVal, StoredValTy); + } + + // Convert vectors and fp to integer, which can be manipulated. + if (!StoredValTy->isIntegerTy()) { + StoredValTy = IntegerType::get(StoredValTy->getContext(), StoredValSize); + StoredVal = Helper.CreateBitCast(StoredVal, StoredValTy); + } + + // If this is a big-endian system, we need to shift the value down to the low + // bits so that a truncate will work. + if (DL.isBigEndian()) { + uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy) - + DL.getTypeStoreSizeInBits(LoadedTy); + StoredVal = Helper.CreateLShr( + StoredVal, ConstantInt::get(StoredVal->getType(), ShiftAmt)); + } + + // Truncate the integer to the right size now. + Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadedValSize); + StoredVal = Helper.CreateTruncOrBitCast(StoredVal, NewIntTy); + + if (LoadedTy != NewIntTy) { + // If the result is a pointer, inttoptr. + if (LoadedTy->isPtrOrPtrVectorTy()) + StoredVal = Helper.CreateIntToPtr(StoredVal, LoadedTy); + else + // Otherwise, bitcast. + StoredVal = Helper.CreateBitCast(StoredVal, LoadedTy); + } + + if (auto *C = dyn_cast<Constant>(StoredVal)) + if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) + StoredVal = FoldedStoredVal; + + return StoredVal; +} + +/// If we saw a store of a value to memory, and +/// then a load from a must-aliased pointer of a different type, try to coerce +/// the stored value. LoadedTy is the type of the load we want to replace. +/// IRB is IRBuilder used to insert new instructions. +/// +/// If we can't do it, return null. +Value *coerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, + IRBuilder<> &IRB, const DataLayout &DL) { + return coerceAvailableValueToLoadTypeHelper(StoredVal, LoadedTy, IRB, DL); +} + +/// This function is called when we have a memdep query of a load that ends up +/// being a clobbering memory write (store, memset, memcpy, memmove). This +/// means that the write *may* provide bits used by the load but we can't be +/// sure because the pointers don't must-alias. +/// +/// Check this case to see if there is anything more we can do before we give +/// up. 
+/// value of the piece that feeds the load.
+static int analyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr,
+                                          Value *WritePtr,
+                                          uint64_t WriteSizeInBits,
+                                          const DataLayout &DL) {
+  // If the loaded or stored value is a first class array or struct, don't try
+  // to transform them. We need to be able to bitcast to integer.
+  if (LoadTy->isStructTy() || LoadTy->isArrayTy())
+    return -1;
+
+  int64_t StoreOffset = 0, LoadOffset = 0;
+  Value *StoreBase =
+      GetPointerBaseWithConstantOffset(WritePtr, StoreOffset, DL);
+  Value *LoadBase = GetPointerBaseWithConstantOffset(LoadPtr, LoadOffset, DL);
+  if (StoreBase != LoadBase)
+    return -1;
+
+  // If the load and store are to the exact same address, they should have been
+  // a must alias. AA must have gotten confused.
+  // FIXME: Study to see if/when this happens. One case is forwarding a memset
+  // to a load from the base of the memset.
+
+  // If the load and store don't overlap at all, the store doesn't provide
+  // anything to the load. In this case, they really don't alias at all, and AA
+  // must have gotten confused.
+  uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy);
+
+  if ((WriteSizeInBits & 7) | (LoadSize & 7))
+    return -1;
+  uint64_t StoreSize = WriteSizeInBits / 8; // Convert to bytes.
+  LoadSize /= 8;
+
+  bool isAAFailure = false;
+  if (StoreOffset < LoadOffset)
+    isAAFailure = StoreOffset + int64_t(StoreSize) <= LoadOffset;
+  else
+    isAAFailure = LoadOffset + int64_t(LoadSize) <= StoreOffset;
+
+  if (isAAFailure)
+    return -1;
+
+  // If the load isn't completely contained within the stored bits, we don't
+  // have all the bits to feed it. We could do something crazy in the future
+  // (issue a smaller load and then merge the bits in), but this seems unlikely
+  // to be valuable.
+  if (StoreOffset > LoadOffset ||
+      StoreOffset + StoreSize < LoadOffset + LoadSize)
+    return -1;
+
+  // Okay, we can do this transformation. Return the number of bytes into the
+  // store that the load is.
+  return LoadOffset - StoreOffset;
+}
+
+/// This function is called when we have a
+/// memdep query of a load that ends up being a clobbering store.
+int analyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr,
+                                   StoreInst *DepSI, const DataLayout &DL) {
+  // Cannot handle reading from store of first-class aggregate yet.
+  if (DepSI->getValueOperand()->getType()->isStructTy() ||
+      DepSI->getValueOperand()->getType()->isArrayTy())
+    return -1;
+
+  Value *StorePtr = DepSI->getPointerOperand();
+  uint64_t StoreSize =
+      DL.getTypeSizeInBits(DepSI->getValueOperand()->getType());
+  return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, StorePtr, StoreSize,
+                                        DL);
+}
+
+/// This function is called when we have a
+/// memdep query of a load that ends up being clobbered by another load. See if
+/// the other load can feed into the second load.
+int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI,
+                                  const DataLayout &DL) {
+  // Cannot handle reading from store of first-class aggregate yet.
+  if (DepLI->getType()->isStructTy() || DepLI->getType()->isArrayTy())
+    return -1;
+
+  Value *DepPtr = DepLI->getPointerOperand();
+  uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType());
+  int R = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL);
+  if (R != -1)
+    return R;
+
+  // If we have a load/load clobber and DepLI can be widened to cover this
+  // load, then we should widen it!
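
// A standalone sketch (illustrative, with hypothetical names) of the
// containment test analyzeLoadFromClobberingWrite performs above, restated
// with plain byte offsets relative to an already-matched common base:
// returns -1 when the store cannot feed the load, else the load's byte
// offset into the stored bytes.
#include <cstdint>

int analyzeOverlapSketch(int64_t StoreOffset, uint64_t StoreSize, // bytes
                         int64_t LoadOffset, uint64_t LoadSize) { // bytes
  // Disjoint accesses cannot forward anything; AA should have caught these.
  if (StoreOffset + (int64_t)StoreSize <= LoadOffset ||
      LoadOffset + (int64_t)LoadSize <= StoreOffset)
    return -1;
  // The load must be fully contained within the stored bytes.
  if (StoreOffset > LoadOffset ||
      StoreOffset + (int64_t)StoreSize < LoadOffset + (int64_t)LoadSize)
    return -1;
  return (int)(LoadOffset - StoreOffset);
}
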
+  int64_t LoadOffs = 0;
+  const Value *LoadBase =
+      GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL);
+  unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
+
+  unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize(
+      LoadBase, LoadOffs, LoadSize, DepLI);
+  if (Size == 0)
+    return -1;
+
+  // Check non-obvious conditions enforced by MDA which we rely on for being
+  // able to materialize this potentially available value.
+  assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!");
+  assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load");
+
+  return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size * 8, DL);
+}
+
+int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr,
+                                     MemIntrinsic *MI, const DataLayout &DL) {
+  // If the mem operation has a non-constant size, we can't handle it.
+  ConstantInt *SizeCst = dyn_cast<ConstantInt>(MI->getLength());
+  if (!SizeCst)
+    return -1;
+  uint64_t MemSizeInBits = SizeCst->getZExtValue() * 8;
+
+  // If this is a memset, we just need to see if the offset is valid in the
+  // size of the memset.
+  if (MI->getIntrinsicID() == Intrinsic::memset)
+    return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(),
+                                          MemSizeInBits, DL);
+
+  // If we have a memcpy/memmove, the only case we can handle is if this is a
+  // copy from constant memory. In that case, we can read directly from the
+  // constant memory.
+  MemTransferInst *MTI = cast<MemTransferInst>(MI);
+
+  Constant *Src = dyn_cast<Constant>(MTI->getSource());
+  if (!Src)
+    return -1;
+
+  GlobalVariable *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(Src, DL));
+  if (!GV || !GV->isConstant())
+    return -1;
+
+  // See if the access is within the bounds of the transfer.
+  int Offset = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(),
+                                              MemSizeInBits, DL);
+  if (Offset == -1)
+    return Offset;
+
+  unsigned AS = Src->getType()->getPointerAddressSpace();
+  // Otherwise, see if we can constant fold a load from the constant with the
+  // offset applied as appropriate.
+  Src =
+      ConstantExpr::getBitCast(Src, Type::getInt8PtrTy(Src->getContext(), AS));
+  Constant *OffsetCst =
+      ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset);
+  Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src,
+                                       OffsetCst);
+  Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS));
+  if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL))
+    return Offset;
+  return -1;
+}
+
+template <class T, class HelperClass>
+static T *getStoreValueForLoadHelper(T *SrcVal, unsigned Offset, Type *LoadTy,
+                                     HelperClass &Helper,
+                                     const DataLayout &DL) {
+  LLVMContext &Ctx = SrcVal->getType()->getContext();
+
+  // If two pointers are in the same address space, they have the same size,
+  // so we don't need to do any truncation, etc. This avoids introducing
+  // ptrtoint instructions for pointers that may be non-integral.
+  if (SrcVal->getType()->isPointerTy() && LoadTy->isPointerTy() &&
+      cast<PointerType>(SrcVal->getType())->getAddressSpace() ==
+          cast<PointerType>(LoadTy)->getAddressSpace()) {
+    return SrcVal;
+  }
+
+  uint64_t StoreSize = (DL.getTypeSizeInBits(SrcVal->getType()) + 7) / 8;
+  uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy) + 7) / 8;
+  // Compute which bits of the stored value are being used by the load. Convert
+  // to an integer type to start with.
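
// A standalone sketch (illustrative, with hypothetical names) of the
// extraction step the lines just below implement: once the stored value is
// an integer, the load's bytes are shifted to the least-significant end and
// truncated. uint64_t stands in for LLVM values; sizes are in bytes.
#include <cassert>
#include <cstdint>

uint64_t extractBytesSketch(uint64_t StoredBits, unsigned StoreSize,
                            unsigned LoadSize, unsigned Offset,
                            bool LittleEndian) {
  assert(Offset + LoadSize <= StoreSize && StoreSize <= 8);
  unsigned ShiftAmt = LittleEndian ? Offset * 8
                                   : (StoreSize - LoadSize - Offset) * 8;
  StoredBits >>= ShiftAmt;                      // shift wanted bytes down
  if (LoadSize < 8)
    StoredBits &= (1ULL << (LoadSize * 8)) - 1; // the "truncate"
  return StoredBits;
}
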
+  if (SrcVal->getType()->isPtrOrPtrVectorTy())
+    SrcVal = Helper.CreatePtrToInt(SrcVal, DL.getIntPtrType(SrcVal->getType()));
+  if (!SrcVal->getType()->isIntegerTy())
+    SrcVal = Helper.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize * 8));
+
+  // Shift the bits to the least significant depending on endianness.
+  unsigned ShiftAmt;
+  if (DL.isLittleEndian())
+    ShiftAmt = Offset * 8;
+  else
+    ShiftAmt = (StoreSize - LoadSize - Offset) * 8;
+  if (ShiftAmt)
+    SrcVal = Helper.CreateLShr(SrcVal,
+                               ConstantInt::get(SrcVal->getType(), ShiftAmt));
+
+  if (LoadSize != StoreSize)
+    SrcVal = Helper.CreateTruncOrBitCast(SrcVal,
+                                         IntegerType::get(Ctx, LoadSize * 8));
+  return SrcVal;
+}
+
+/// This function is called when we have a memdep query of a load that ends up
+/// being a clobbering store. This means that the store provides bits used by
+/// the load but the pointers don't must-alias. Check this case to see if
+/// there is anything more we can do before we give up.
+Value *getStoreValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy,
+                            Instruction *InsertPt, const DataLayout &DL) {
+
+  IRBuilder<> Builder(InsertPt);
+  SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, Builder, DL);
+  return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, Builder, DL);
+}
+
+Constant *getConstantStoreValueForLoad(Constant *SrcVal, unsigned Offset,
+                                       Type *LoadTy, const DataLayout &DL) {
+  ConstantFolder F;
+  SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, F, DL);
+  return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, F, DL);
+}
+
+/// This function is called when we have a memdep query of a load that ends up
+/// being a clobbering load. This means that the earlier load *may* provide
+/// bits used by this load but we can't be sure because the pointers don't
+/// must-alias. Check this case to see if there is anything more we can do
+/// before we give up.
+Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy,
+                           Instruction *InsertPt, const DataLayout &DL) {
+  // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to
+  // widen SrcVal out to a larger load.
+  unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType());
+  unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
+  if (Offset + LoadSize > SrcValStoreSize) {
+    assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!");
+    assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load");
+    // If we have a load/load clobber and SrcVal can be widened to cover this
+    // load, then we should widen it to the next power of 2 size big enough!
+    unsigned NewLoadSize = Offset + LoadSize;
+    if (!isPowerOf2_32(NewLoadSize))
+      NewLoadSize = NextPowerOf2(NewLoadSize);
+
+    Value *PtrVal = SrcVal->getPointerOperand();
+    // Insert the new load after the old load. This ensures that subsequent
+    // memdep queries will find the new load. We can't easily remove the old
+    // load completely because it is already in the value numbering table.
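
// A standalone sketch (illustrative, with hypothetical names) of the size
// rounding used just above: the widened load must cover Offset + LoadSize
// bytes, rounded up to the next power of two when it is not one already.
#include <cstdint>

uint64_t widenedLoadSizeSketch(uint64_t Offset, uint64_t LoadSize) {
  uint64_t NewSize = Offset + LoadSize;
  while (NewSize & (NewSize - 1))   // not a power of two yet
    NewSize += NewSize & -NewSize;  // add the lowest set bit until it is
  return NewSize;                   // e.g. Offset 4, LoadSize 3 -> 8
}
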
+    IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal));
+    Type *DestPTy = IntegerType::get(LoadTy->getContext(), NewLoadSize * 8);
+    DestPTy =
+        PointerType::get(DestPTy, PtrVal->getType()->getPointerAddressSpace());
+    Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc());
+    PtrVal = Builder.CreateBitCast(PtrVal, DestPTy);
+    LoadInst *NewLoad = Builder.CreateLoad(PtrVal);
+    NewLoad->takeName(SrcVal);
+    NewLoad->setAlignment(SrcVal->getAlignment());
+
+    DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n");
+    DEBUG(dbgs() << "TO: " << *NewLoad << "\n");
+
+    // Replace uses of the original load with the wider load. On a big endian
+    // system, we need to shift down to get the relevant bits.
+    Value *RV = NewLoad;
+    if (DL.isBigEndian())
+      RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8);
+    RV = Builder.CreateTrunc(RV, SrcVal->getType());
+    SrcVal->replaceAllUsesWith(RV);
+
+    SrcVal = NewLoad;
+  }
+
+  return getStoreValueForLoad(SrcVal, Offset, LoadTy, InsertPt, DL);
+}
+
+Constant *getConstantLoadValueForLoad(Constant *SrcVal, unsigned Offset,
+                                      Type *LoadTy, const DataLayout &DL) {
+  unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType());
+  unsigned LoadSize = DL.getTypeStoreSize(LoadTy);
+  if (Offset + LoadSize > SrcValStoreSize)
+    return nullptr;
+  return getConstantStoreValueForLoad(SrcVal, Offset, LoadTy, DL);
+}
+
+template <class T, class HelperClass>
+T *getMemInstValueForLoadHelper(MemIntrinsic *SrcInst, unsigned Offset,
+                                Type *LoadTy, HelperClass &Helper,
+                                const DataLayout &DL) {
+  LLVMContext &Ctx = LoadTy->getContext();
+  uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy) / 8;
+
+  // We know that this method is only called when the mem transfer fully
+  // provides the bits for the load.
+  if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) {
+    // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and
+    // independently of what the offset is.
+    T *Val = cast<T>(MSI->getValue());
+    if (LoadSize != 1)
+      Val =
+          Helper.CreateZExtOrBitCast(Val, IntegerType::get(Ctx, LoadSize * 8));
+    T *OneElt = Val;
+
+    // Splat the value out to the right number of bits.
+    for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize;) {
+      // If we can double the number of bytes set, do it.
+      if (NumBytesSet * 2 <= LoadSize) {
+        T *ShVal = Helper.CreateShl(
+            Val, ConstantInt::get(Val->getType(), NumBytesSet * 8));
+        Val = Helper.CreateOr(Val, ShVal);
+        NumBytesSet <<= 1;
+        continue;
+      }
+
+      // Otherwise insert one byte at a time.
+      T *ShVal = Helper.CreateShl(Val, ConstantInt::get(Val->getType(), 1 * 8));
+      Val = Helper.CreateOr(OneElt, ShVal);
+      ++NumBytesSet;
+    }
+
+    return coerceAvailableValueToLoadTypeHelper(Val, LoadTy, Helper, DL);
+  }
+
+  // Otherwise, this is a memcpy/memmove from a constant global.
+  MemTransferInst *MTI = cast<MemTransferInst>(SrcInst);
+  Constant *Src = cast<Constant>(MTI->getSource());
+  unsigned AS = Src->getType()->getPointerAddressSpace();
+
+  // Otherwise, see if we can constant fold a load from the constant with the
+  // offset applied as appropriate.
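
// A standalone sketch (illustrative, with hypothetical names) of the splat
// loop above: one memset byte is replicated across the loaded width by
// OR-ing in shifted copies, doubling the filled prefix whenever possible.
#include <cstdint>

uint64_t splatByteSketch(uint8_t Byte, unsigned LoadSize) { // LoadSize <= 8
  uint64_t Val = Byte;
  const uint64_t OneElt = Byte;
  for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize;) {
    if (NumBytesSet * 2 <= LoadSize) {          // double the bytes set
      Val |= Val << (NumBytesSet * 8);
      NumBytesSet *= 2;
      continue;
    }
    Val = OneElt | (Val << 8);                  // else add one byte at a time
    ++NumBytesSet;
  }
  return Val; // e.g. Byte 0xAB, LoadSize 4 -> 0xABABABAB
}
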
+  Src =
+      ConstantExpr::getBitCast(Src, Type::getInt8PtrTy(Src->getContext(), AS));
+  Constant *OffsetCst =
+      ConstantInt::get(Type::getInt64Ty(Src->getContext()), (unsigned)Offset);
+  Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src,
+                                       OffsetCst);
+  Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS));
+  return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL);
+}
+
+/// This function is called when we have a
+/// memdep query of a load that ends up being a clobbering mem intrinsic.
+Value *getMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset,
+                              Type *LoadTy, Instruction *InsertPt,
+                              const DataLayout &DL) {
+  IRBuilder<> Builder(InsertPt);
+  return getMemInstValueForLoadHelper<Value, IRBuilder<>>(SrcInst, Offset,
+                                                          LoadTy, Builder, DL);
+}
+
+Constant *getConstantMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset,
+                                         Type *LoadTy, const DataLayout &DL) {
+  // The only case where analyzeLoadFromClobberingMemInst's result cannot be
+  // converted to a constant is when it's a memset of a non-constant.
+  if (auto *MSI = dyn_cast<MemSetInst>(SrcInst))
+    if (!isa<Constant>(MSI->getValue()))
+      return nullptr;
+  ConstantFolder F;
+  return getMemInstValueForLoadHelper<Constant, ConstantFolder>(SrcInst, Offset,
+                                                                LoadTy, F, DL);
+}
+} // namespace VNCoercion
+} // namespace llvm
diff --git a/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp b/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 0e9baaf..9309729 100644
--- a/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -121,6 +121,8 @@ public:
 
   void addFlags(RemapFlags Flags);
 
+  void remapGlobalObjectMetadata(GlobalObject &GO);
+
   Value *mapValue(const Value *V);
   void remapInstruction(Instruction *I);
   void remapFunction(Function &F);
@@ -681,6 +683,7 @@ void MDNodeMapper::mapNodesInPOT(UniquedGraph &G) {
         remapOperands(*ClonedN, [this, &D, &G](Metadata *Old) {
           if (Optional<Metadata *> MappedOp = getMappedOp(Old))
             return *MappedOp;
+          (void)D;
           assert(G.Info[Old].ID > D.ID && "Expected a forward reference");
           return &G.getFwdReference(*cast<MDNode>(Old));
         });
@@ -801,6 +804,7 @@ void Mapper::flush() {
     switch (E.Kind) {
     case WorklistEntry::MapGlobalInit:
       E.Data.GVInit.GV->setInitializer(mapConstant(E.Data.GVInit.Init));
+      remapGlobalObjectMetadata(*E.Data.GVInit.GV);
       break;
     case WorklistEntry::MapAppendingVar: {
       unsigned PrefixSize = AppendingInits.size() - E.AppendingGVNumNewMembers;
@@ -891,6 +895,14 @@ void Mapper::remapInstruction(Instruction *I) {
     I->mutateType(TypeMapper->remapType(I->getType()));
 }
 
+void Mapper::remapGlobalObjectMetadata(GlobalObject &GO) {
+  SmallVector<std::pair<unsigned, MDNode *>, 8> MDs;
+  GO.getAllMetadata(MDs);
+  GO.clearMetadata();
+  for (const auto &I : MDs)
+    GO.addMetadata(I.first, *cast<MDNode>(mapMetadata(I.second)));
+}
+
 void Mapper::remapFunction(Function &F) {
   // Remap the operands.
   for (Use &Op : F.operands())
@@ -898,11 +910,7 @@
       Op = mapValue(Op);
 
   // Remap the metadata attachments.
-  SmallVector<std::pair<unsigned, MDNode *>, 8> MDs;
-  F.getAllMetadata(MDs);
-  F.clearMetadata();
-  for (const auto &I : MDs)
-    F.addMetadata(I.first, *cast<MDNode>(mapMetadata(I.second)));
+  remapGlobalObjectMetadata(F);
 
   // Remap the argument types.
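
// A standalone sketch (illustrative, with hypothetical names) of the
// collect/clear/re-attach pattern that the new remapGlobalObjectMetadata
// above factors out of remapFunction; a plain vector stands in for the
// metadata attachment list, and mapNode for mapMetadata.
#include <utility>
#include <vector>

struct Node {};
static Node *mapNode(Node *N) { return N; } // identity stand-in for mapMetadata

void remapAttachmentsSketch(std::vector<std::pair<unsigned, Node *>> &MDList) {
  std::vector<std::pair<unsigned, Node *>> MDs;
  MDs.swap(MDList);                 // getAllMetadata + clearMetadata
  for (const auto &KV : MDs)
    MDList.emplace_back(KV.first, mapNode(KV.second)); // addMetadata(mapped)
}
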
if (TypeMapper) @@ -941,11 +949,10 @@ void Mapper::mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, Constant *NewV; if (IsOldCtorDtor) { auto *S = cast<ConstantStruct>(V); - auto *E1 = mapValue(S->getOperand(0)); - auto *E2 = mapValue(S->getOperand(1)); - Value *Null = Constant::getNullValue(VoidPtrTy); - NewV = - ConstantStruct::get(cast<StructType>(EltTy), E1, E2, Null, nullptr); + auto *E1 = cast<Constant>(mapValue(S->getOperand(0))); + auto *E2 = cast<Constant>(mapValue(S->getOperand(1))); + Constant *Null = Constant::getNullValue(VoidPtrTy); + NewV = ConstantStruct::get(cast<StructType>(EltTy), E1, E2, Null); } else { NewV = cast_or_null<Constant>(mapValue(V)); } diff --git a/contrib/llvm/lib/Transforms/Vectorize/BBVectorize.cpp b/contrib/llvm/lib/Transforms/Vectorize/BBVectorize.cpp deleted file mode 100644 index c01740b..0000000 --- a/contrib/llvm/lib/Transforms/Vectorize/BBVectorize.cpp +++ /dev/null @@ -1,3269 +0,0 @@ -//===- BBVectorize.cpp - A Basic-Block Vectorizer -------------------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This file implements a basic-block vectorization pass. The algorithm was -// inspired by that used by the Vienna MAP Vectorizor by Franchetti and Kral, -// et al. It works by looking for chains of pairable operations and then -// pairing them. -// -//===----------------------------------------------------------------------===// - -#define BBV_NAME "bb-vectorize" -#include "llvm/Transforms/Vectorize.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AliasSetTracker.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Metadata.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/ValueHandle.h" -#include "llvm/Pass.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/Local.h" -#include <algorithm> -using namespace llvm; - -#define DEBUG_TYPE BBV_NAME - -static cl::opt<bool> -IgnoreTargetInfo("bb-vectorize-ignore-target-info", cl::init(false), - cl::Hidden, cl::desc("Ignore target information")); - -static cl::opt<unsigned> -ReqChainDepth("bb-vectorize-req-chain-depth", cl::init(6), cl::Hidden, - cl::desc("The required chain depth for vectorization")); - -static cl::opt<bool> -UseChainDepthWithTI("bb-vectorize-use-chain-depth", cl::init(false), - cl::Hidden, cl::desc("Use the chain depth requirement with" - " target information")); - -static cl::opt<unsigned> 
-SearchLimit("bb-vectorize-search-limit", cl::init(400), cl::Hidden, - cl::desc("The maximum search distance for instruction pairs")); - -static cl::opt<bool> -SplatBreaksChain("bb-vectorize-splat-breaks-chain", cl::init(false), cl::Hidden, - cl::desc("Replicating one element to a pair breaks the chain")); - -static cl::opt<unsigned> -VectorBits("bb-vectorize-vector-bits", cl::init(128), cl::Hidden, - cl::desc("The size of the native vector registers")); - -static cl::opt<unsigned> -MaxIter("bb-vectorize-max-iter", cl::init(0), cl::Hidden, - cl::desc("The maximum number of pairing iterations")); - -static cl::opt<bool> -Pow2LenOnly("bb-vectorize-pow2-len-only", cl::init(false), cl::Hidden, - cl::desc("Don't try to form non-2^n-length vectors")); - -static cl::opt<unsigned> -MaxInsts("bb-vectorize-max-instr-per-group", cl::init(500), cl::Hidden, - cl::desc("The maximum number of pairable instructions per group")); - -static cl::opt<unsigned> -MaxPairs("bb-vectorize-max-pairs-per-group", cl::init(3000), cl::Hidden, - cl::desc("The maximum number of candidate instruction pairs per group")); - -static cl::opt<unsigned> -MaxCandPairsForCycleCheck("bb-vectorize-max-cycle-check-pairs", cl::init(200), - cl::Hidden, cl::desc("The maximum number of candidate pairs with which to use" - " a full cycle check")); - -static cl::opt<bool> -NoBools("bb-vectorize-no-bools", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize boolean (i1) values")); - -static cl::opt<bool> -NoInts("bb-vectorize-no-ints", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize integer values")); - -static cl::opt<bool> -NoFloats("bb-vectorize-no-floats", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize floating-point values")); - -// FIXME: This should default to false once pointer vector support works. 
-static cl::opt<bool> -NoPointers("bb-vectorize-no-pointers", cl::init(/*false*/ true), cl::Hidden, - cl::desc("Don't try to vectorize pointer values")); - -static cl::opt<bool> -NoCasts("bb-vectorize-no-casts", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize casting (conversion) operations")); - -static cl::opt<bool> -NoMath("bb-vectorize-no-math", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize floating-point math intrinsics")); - -static cl::opt<bool> - NoBitManipulation("bb-vectorize-no-bitmanip", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize BitManipulation intrinsics")); - -static cl::opt<bool> -NoFMA("bb-vectorize-no-fma", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize the fused-multiply-add intrinsic")); - -static cl::opt<bool> -NoSelect("bb-vectorize-no-select", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize select instructions")); - -static cl::opt<bool> -NoCmp("bb-vectorize-no-cmp", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize comparison instructions")); - -static cl::opt<bool> -NoGEP("bb-vectorize-no-gep", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize getelementptr instructions")); - -static cl::opt<bool> -NoMemOps("bb-vectorize-no-mem-ops", cl::init(false), cl::Hidden, - cl::desc("Don't try to vectorize loads and stores")); - -static cl::opt<bool> -AlignedOnly("bb-vectorize-aligned-only", cl::init(false), cl::Hidden, - cl::desc("Only generate aligned loads and stores")); - -static cl::opt<bool> -NoMemOpBoost("bb-vectorize-no-mem-op-boost", - cl::init(false), cl::Hidden, - cl::desc("Don't boost the chain-depth contribution of loads and stores")); - -static cl::opt<bool> -FastDep("bb-vectorize-fast-dep", cl::init(false), cl::Hidden, - cl::desc("Use a fast instruction dependency analysis")); - -#ifndef NDEBUG -static cl::opt<bool> -DebugInstructionExamination("bb-vectorize-debug-instruction-examination", - cl::init(false), cl::Hidden, - cl::desc("When debugging is enabled, output information on the" - " instruction-examination process")); -static cl::opt<bool> -DebugCandidateSelection("bb-vectorize-debug-candidate-selection", - cl::init(false), cl::Hidden, - cl::desc("When debugging is enabled, output information on the" - " candidate-selection process")); -static cl::opt<bool> -DebugPairSelection("bb-vectorize-debug-pair-selection", - cl::init(false), cl::Hidden, - cl::desc("When debugging is enabled, output information on the" - " pair-selection process")); -static cl::opt<bool> -DebugCycleCheck("bb-vectorize-debug-cycle-check", - cl::init(false), cl::Hidden, - cl::desc("When debugging is enabled, output information on the" - " cycle-checking process")); - -static cl::opt<bool> -PrintAfterEveryPair("bb-vectorize-debug-print-after-every-pair", - cl::init(false), cl::Hidden, - cl::desc("When debugging is enabled, dump the basic block after" - " every pair is fused")); -#endif - -STATISTIC(NumFusedOps, "Number of operations fused by bb-vectorize"); - -namespace { - struct BBVectorize : public BasicBlockPass { - static char ID; // Pass identification, replacement for typeid - - const VectorizeConfig Config; - - BBVectorize(const VectorizeConfig &C = VectorizeConfig()) - : BasicBlockPass(ID), Config(C) { - initializeBBVectorizePass(*PassRegistry::getPassRegistry()); - } - - BBVectorize(Pass *P, Function &F, const VectorizeConfig &C) - : BasicBlockPass(ID), Config(C) { - AA = &P->getAnalysis<AAResultsWrapperPass>().getAAResults(); - DT = 
&P->getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - SE = &P->getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - TLI = &P->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - TTI = IgnoreTargetInfo - ? nullptr - : &P->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - } - - typedef std::pair<Value *, Value *> ValuePair; - typedef std::pair<ValuePair, int> ValuePairWithCost; - typedef std::pair<ValuePair, size_t> ValuePairWithDepth; - typedef std::pair<ValuePair, ValuePair> VPPair; // A ValuePair pair - typedef std::pair<VPPair, unsigned> VPPairWithType; - - AliasAnalysis *AA; - DominatorTree *DT; - ScalarEvolution *SE; - const TargetLibraryInfo *TLI; - const TargetTransformInfo *TTI; - - // FIXME: const correct? - - bool vectorizePairs(BasicBlock &BB, bool NonPow2Len = false); - - bool getCandidatePairs(BasicBlock &BB, - BasicBlock::iterator &Start, - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &FixedOrderPairs, - DenseMap<ValuePair, int> &CandidatePairCostSavings, - std::vector<Value *> &PairableInsts, bool NonPow2Len); - - // FIXME: The current implementation does not account for pairs that - // are connected in multiple ways. For example: - // C1 = A1 / A2; C2 = A2 / A1 (which may be both direct and a swap) - enum PairConnectionType { - PairConnectionDirect, - PairConnectionSwap, - PairConnectionSplat - }; - - void computeConnectedPairs( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - std::vector<Value *> &PairableInsts, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes); - - void buildDepMap(BasicBlock &BB, - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - std::vector<Value *> &PairableInsts, - DenseSet<ValuePair> &PairableInstUsers); - - void choosePairs(DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - DenseMap<ValuePair, int> &CandidatePairCostSavings, - std::vector<Value *> &PairableInsts, - DenseSet<ValuePair> &FixedOrderPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairDeps, - DenseSet<ValuePair> &PairableInstUsers, - DenseMap<Value *, Value *>& ChosenPairs); - - void fuseChosenPairs(BasicBlock &BB, - std::vector<Value *> &PairableInsts, - DenseMap<Value *, Value *>& ChosenPairs, - DenseSet<ValuePair> &FixedOrderPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairDeps); - - - bool isInstVectorizable(Instruction *I, bool &IsSimpleLoadStore); - - bool areInstsCompatible(Instruction *I, Instruction *J, - bool IsSimpleLoadStore, bool NonPow2Len, - int &CostSavings, int &FixedOrder); - - bool trackUsesOfI(DenseSet<Value *> &Users, - AliasSetTracker &WriteSet, Instruction *I, - Instruction *J, bool UpdateUsers = true, - DenseSet<ValuePair> *LoadMoveSetPairs = nullptr); - - void computePairsConnectedTo( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - std::vector<Value *> &PairableInsts, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes, - ValuePair P); - - bool pairsConflict(ValuePair P, ValuePair Q, - DenseSet<ValuePair> &PairableInstUsers, - DenseMap<ValuePair, 
std::vector<ValuePair> > - *PairableInstUserMap = nullptr, - DenseSet<VPPair> *PairableInstUserPairSet = nullptr); - - bool pairWillFormCycle(ValuePair P, - DenseMap<ValuePair, std::vector<ValuePair> > &PairableInstUsers, - DenseSet<ValuePair> &CurrentPairs); - - void pruneDAGFor( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - std::vector<Value *> &PairableInsts, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseSet<ValuePair> &PairableInstUsers, - DenseMap<ValuePair, std::vector<ValuePair> > &PairableInstUserMap, - DenseSet<VPPair> &PairableInstUserPairSet, - DenseMap<Value *, Value *> &ChosenPairs, - DenseMap<ValuePair, size_t> &DAG, - DenseSet<ValuePair> &PrunedDAG, ValuePair J, - bool UseCycleCheck); - - void buildInitialDAGFor( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - std::vector<Value *> &PairableInsts, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseSet<ValuePair> &PairableInstUsers, - DenseMap<Value *, Value *> &ChosenPairs, - DenseMap<ValuePair, size_t> &DAG, ValuePair J); - - void findBestDAGFor( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - DenseMap<ValuePair, int> &CandidatePairCostSavings, - std::vector<Value *> &PairableInsts, - DenseSet<ValuePair> &FixedOrderPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairDeps, - DenseSet<ValuePair> &PairableInstUsers, - DenseMap<ValuePair, std::vector<ValuePair> > &PairableInstUserMap, - DenseSet<VPPair> &PairableInstUserPairSet, - DenseMap<Value *, Value *> &ChosenPairs, - DenseSet<ValuePair> &BestDAG, size_t &BestMaxDepth, - int &BestEffSize, Value *II, std::vector<Value *>&JJ, - bool UseCycleCheck); - - Value *getReplacementPointerInput(LLVMContext& Context, Instruction *I, - Instruction *J, unsigned o); - - void fillNewShuffleMask(LLVMContext& Context, Instruction *J, - unsigned MaskOffset, unsigned NumInElem, - unsigned NumInElem1, unsigned IdxOffset, - std::vector<Constant*> &Mask); - - Value *getReplacementShuffleMask(LLVMContext& Context, Instruction *I, - Instruction *J); - - bool expandIEChain(LLVMContext& Context, Instruction *I, Instruction *J, - unsigned o, Value *&LOp, unsigned numElemL, - Type *ArgTypeL, Type *ArgTypeR, bool IBeforeJ, - unsigned IdxOff = 0); - - Value *getReplacementInput(LLVMContext& Context, Instruction *I, - Instruction *J, unsigned o, bool IBeforeJ); - - void getReplacementInputsForPair(LLVMContext& Context, Instruction *I, - Instruction *J, SmallVectorImpl<Value *> &ReplacedOperands, - bool IBeforeJ); - - void replaceOutputsOfPair(LLVMContext& Context, Instruction *I, - Instruction *J, Instruction *K, - Instruction *&InsertionPt, Instruction *&K1, - Instruction *&K2); - - void collectPairLoadMoveSet(BasicBlock &BB, - DenseMap<Value *, Value *> &ChosenPairs, - DenseMap<Value *, std::vector<Value *> > &LoadMoveSet, - DenseSet<ValuePair> &LoadMoveSetPairs, - Instruction *I); - - void collectLoadMoveSet(BasicBlock &BB, - std::vector<Value *> &PairableInsts, - DenseMap<Value *, Value *> &ChosenPairs, - DenseMap<Value *, std::vector<Value *> > &LoadMoveSet, - DenseSet<ValuePair> &LoadMoveSetPairs); - - bool canMoveUsesOfIAfterJ(BasicBlock &BB, - DenseSet<ValuePair> &LoadMoveSetPairs, - Instruction *I, Instruction *J); - - void moveUsesOfIAfterJ(BasicBlock &BB, - DenseSet<ValuePair> &LoadMoveSetPairs, 
- Instruction *&InsertionPt, - Instruction *I, Instruction *J); - - bool vectorizeBB(BasicBlock &BB) { - if (skipBasicBlock(BB)) - return false; - if (!DT->isReachableFromEntry(&BB)) { - DEBUG(dbgs() << "BBV: skipping unreachable " << BB.getName() << - " in " << BB.getParent()->getName() << "\n"); - return false; - } - - DEBUG(if (TTI) dbgs() << "BBV: using target information\n"); - - bool changed = false; - // Iterate a sufficient number of times to merge types of size 1 bit, - // then 2 bits, then 4, etc. up to half of the target vector width of the - // target vector register. - unsigned n = 1; - for (unsigned v = 2; - (TTI || v <= Config.VectorBits) && - (!Config.MaxIter || n <= Config.MaxIter); - v *= 2, ++n) { - DEBUG(dbgs() << "BBV: fusing loop #" << n << - " for " << BB.getName() << " in " << - BB.getParent()->getName() << "...\n"); - if (vectorizePairs(BB)) - changed = true; - else - break; - } - - if (changed && !Pow2LenOnly) { - ++n; - for (; !Config.MaxIter || n <= Config.MaxIter; ++n) { - DEBUG(dbgs() << "BBV: fusing for non-2^n-length vectors loop #: " << - n << " for " << BB.getName() << " in " << - BB.getParent()->getName() << "...\n"); - if (!vectorizePairs(BB, true)) break; - } - } - - DEBUG(dbgs() << "BBV: done!\n"); - return changed; - } - - bool runOnBasicBlock(BasicBlock &BB) override { - // OptimizeNone check deferred to vectorizeBB(). - - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - TTI = IgnoreTargetInfo - ? nullptr - : &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *BB.getParent()); - - return vectorizeBB(BB); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - BasicBlockPass::getAnalysisUsage(AU); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<SCEVAAWrapperPass>(); - AU.setPreservesCFG(); - } - - static inline VectorType *getVecTypeForPair(Type *ElemTy, Type *Elem2Ty) { - assert(ElemTy->getScalarType() == Elem2Ty->getScalarType() && - "Cannot form vector from incompatible scalar types"); - Type *STy = ElemTy->getScalarType(); - - unsigned numElem; - if (VectorType *VTy = dyn_cast<VectorType>(ElemTy)) { - numElem = VTy->getNumElements(); - } else { - numElem = 1; - } - - if (VectorType *VTy = dyn_cast<VectorType>(Elem2Ty)) { - numElem += VTy->getNumElements(); - } else { - numElem += 1; - } - - return VectorType::get(STy, numElem); - } - - static inline void getInstructionTypes(Instruction *I, - Type *&T1, Type *&T2) { - if (StoreInst *SI = dyn_cast<StoreInst>(I)) { - // For stores, it is the value type, not the pointer type that matters - // because the value is what will come from a vector register. 
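
// A standalone sketch (illustrative, with hypothetical names) of the
// iteration bound in the vectorizeBB loop above: element widths double each
// pass, so roughly log2(VectorBits) passes reach the full register width.
unsigned fusingIterationsSketch(unsigned VectorBits) {
  unsigned N = 0;
  for (unsigned V = 2; V <= VectorBits; V *= 2)
    ++N;
  return N; // VectorBits == 128 -> 7 passes
}
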
- - Value *IVal = SI->getValueOperand(); - T1 = IVal->getType(); - } else { - T1 = I->getType(); - } - - if (CastInst *CI = dyn_cast<CastInst>(I)) - T2 = CI->getSrcTy(); - else - T2 = T1; - - if (SelectInst *SI = dyn_cast<SelectInst>(I)) { - T2 = SI->getCondition()->getType(); - } else if (ShuffleVectorInst *SI = dyn_cast<ShuffleVectorInst>(I)) { - T2 = SI->getOperand(0)->getType(); - } else if (CmpInst *CI = dyn_cast<CmpInst>(I)) { - T2 = CI->getOperand(0)->getType(); - } - } - - // Returns the weight associated with the provided value. A chain of - // candidate pairs has a length given by the sum of the weights of its - // members (one weight per pair; the weight of each member of the pair - // is assumed to be the same). This length is then compared to the - // chain-length threshold to determine if a given chain is significant - // enough to be vectorized. The length is also used in comparing - // candidate chains where longer chains are considered to be better. - // Note: when this function returns 0, the resulting instructions are - // not actually fused. - inline size_t getDepthFactor(Value *V) { - // InsertElement and ExtractElement have a depth factor of zero. This is - // for two reasons: First, they cannot be usefully fused. Second, because - // the pass generates a lot of these, they can confuse the simple metric - // used to compare the dags in the next iteration. Thus, giving them a - // weight of zero allows the pass to essentially ignore them in - // subsequent iterations when looking for vectorization opportunities - // while still tracking dependency chains that flow through those - // instructions. - if (isa<InsertElementInst>(V) || isa<ExtractElementInst>(V)) - return 0; - - // Give a load or store half of the required depth so that load/store - // pairs will vectorize. - if (!Config.NoMemOpBoost && (isa<LoadInst>(V) || isa<StoreInst>(V))) - return Config.ReqChainDepth/2; - - return 1; - } - - // Returns the cost of the provided instruction using TTI. - // This does not handle loads and stores. - unsigned getInstrCost(unsigned Opcode, Type *T1, Type *T2, - TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue, - TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_AnyValue) { - switch (Opcode) { - default: break; - case Instruction::GetElementPtr: - // We mark this instruction as zero-cost because scalar GEPs are usually - // lowered to the instruction addressing mode. At the moment we don't - // generate vector GEPs. 
- return 0; - case Instruction::Br: - return TTI->getCFInstrCost(Opcode); - case Instruction::PHI: - return 0; - case Instruction::Add: - case Instruction::FAdd: - case Instruction::Sub: - case Instruction::FSub: - case Instruction::Mul: - case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: - case Instruction::FRem: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - return TTI->getArithmeticInstrCost(Opcode, T1, Op1VK, Op2VK); - case Instruction::Select: - case Instruction::ICmp: - case Instruction::FCmp: - return TTI->getCmpSelInstrCost(Opcode, T1, T2); - case Instruction::ZExt: - case Instruction::SExt: - case Instruction::FPToUI: - case Instruction::FPToSI: - case Instruction::FPExt: - case Instruction::PtrToInt: - case Instruction::IntToPtr: - case Instruction::SIToFP: - case Instruction::UIToFP: - case Instruction::Trunc: - case Instruction::FPTrunc: - case Instruction::BitCast: - case Instruction::ShuffleVector: - return TTI->getCastInstrCost(Opcode, T1, T2); - } - - return 1; - } - - // This determines the relative offset of two loads or stores, returning - // true if the offset could be determined to be some constant value. - // For example, if OffsetInElmts == 1, then J accesses the memory directly - // after I; if OffsetInElmts == -1 then I accesses the memory - // directly after J. - bool getPairPtrInfo(Instruction *I, Instruction *J, - Value *&IPtr, Value *&JPtr, unsigned &IAlignment, unsigned &JAlignment, - unsigned &IAddressSpace, unsigned &JAddressSpace, - int64_t &OffsetInElmts, bool ComputeOffset = true) { - OffsetInElmts = 0; - if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - LoadInst *LJ = cast<LoadInst>(J); - IPtr = LI->getPointerOperand(); - JPtr = LJ->getPointerOperand(); - IAlignment = LI->getAlignment(); - JAlignment = LJ->getAlignment(); - IAddressSpace = LI->getPointerAddressSpace(); - JAddressSpace = LJ->getPointerAddressSpace(); - } else { - StoreInst *SI = cast<StoreInst>(I), *SJ = cast<StoreInst>(J); - IPtr = SI->getPointerOperand(); - JPtr = SJ->getPointerOperand(); - IAlignment = SI->getAlignment(); - JAlignment = SJ->getAlignment(); - IAddressSpace = SI->getPointerAddressSpace(); - JAddressSpace = SJ->getPointerAddressSpace(); - } - - if (!ComputeOffset) - return true; - - const SCEV *IPtrSCEV = SE->getSCEV(IPtr); - const SCEV *JPtrSCEV = SE->getSCEV(JPtr); - - // If this is a trivial offset, then we'll get something like - // 1*sizeof(type). With target data, which we need anyway, this will get - // constant folded into a number. - const SCEV *OffsetSCEV = SE->getMinusSCEV(JPtrSCEV, IPtrSCEV); - if (const SCEVConstant *ConstOffSCEV = - dyn_cast<SCEVConstant>(OffsetSCEV)) { - ConstantInt *IntOff = ConstOffSCEV->getValue(); - int64_t Offset = IntOff->getSExtValue(); - const DataLayout &DL = I->getModule()->getDataLayout(); - Type *VTy = IPtr->getType()->getPointerElementType(); - int64_t VTyTSS = (int64_t)DL.getTypeStoreSize(VTy); - - Type *VTy2 = JPtr->getType()->getPointerElementType(); - if (VTy != VTy2 && Offset < 0) { - int64_t VTy2TSS = (int64_t)DL.getTypeStoreSize(VTy2); - OffsetInElmts = Offset/VTy2TSS; - return (std::abs(Offset) % VTy2TSS) == 0; - } - - OffsetInElmts = Offset/VTyTSS; - return (std::abs(Offset) % VTyTSS) == 0; - } - - return false; - } - - // Returns true if the provided CallInst represents an intrinsic that can - // be vectorized. 
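
// A standalone sketch (illustrative, with hypothetical names) of the
// adjacency test inside getPairPtrInfo above: with a constant byte offset
// between two accesses, the pair is usable only when the offset is an exact
// multiple of the element store size, and fixed-order only when it works out
// to +/-1 element.
#include <cstdint>
#include <cstdlib>

bool offsetInElementsSketch(int64_t ByteOffset, int64_t ElemStoreSize,
                            int64_t &OffsetInElmts) {
  OffsetInElmts = ByteOffset / ElemStoreSize;
  return std::llabs(ByteOffset) % ElemStoreSize == 0;
}
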
- bool isVectorizableIntrinsic(CallInst* I) { - Function *F = I->getCalledFunction(); - if (!F) return false; - - Intrinsic::ID IID = F->getIntrinsicID(); - if (!IID) return false; - - switch(IID) { - default: - return false; - case Intrinsic::sqrt: - case Intrinsic::powi: - case Intrinsic::sin: - case Intrinsic::cos: - case Intrinsic::log: - case Intrinsic::log2: - case Intrinsic::log10: - case Intrinsic::exp: - case Intrinsic::exp2: - case Intrinsic::pow: - case Intrinsic::round: - case Intrinsic::copysign: - case Intrinsic::ceil: - case Intrinsic::nearbyint: - case Intrinsic::rint: - case Intrinsic::trunc: - case Intrinsic::floor: - case Intrinsic::fabs: - case Intrinsic::minnum: - case Intrinsic::maxnum: - return Config.VectorizeMath; - case Intrinsic::bswap: - case Intrinsic::ctpop: - case Intrinsic::ctlz: - case Intrinsic::cttz: - return Config.VectorizeBitManipulations; - case Intrinsic::fma: - case Intrinsic::fmuladd: - return Config.VectorizeFMA; - } - } - - bool isPureIEChain(InsertElementInst *IE) { - InsertElementInst *IENext = IE; - do { - if (!isa<UndefValue>(IENext->getOperand(0)) && - !isa<InsertElementInst>(IENext->getOperand(0))) { - return false; - } - } while ((IENext = - dyn_cast<InsertElementInst>(IENext->getOperand(0)))); - - return true; - } - }; - - // This function implements one vectorization iteration on the provided - // basic block. It returns true if the block is changed. - bool BBVectorize::vectorizePairs(BasicBlock &BB, bool NonPow2Len) { - bool ShouldContinue; - BasicBlock::iterator Start = BB.getFirstInsertionPt(); - - std::vector<Value *> AllPairableInsts; - DenseMap<Value *, Value *> AllChosenPairs; - DenseSet<ValuePair> AllFixedOrderPairs; - DenseMap<VPPair, unsigned> AllPairConnectionTypes; - DenseMap<ValuePair, std::vector<ValuePair> > AllConnectedPairs, - AllConnectedPairDeps; - - do { - std::vector<Value *> PairableInsts; - DenseMap<Value *, std::vector<Value *> > CandidatePairs; - DenseSet<ValuePair> FixedOrderPairs; - DenseMap<ValuePair, int> CandidatePairCostSavings; - ShouldContinue = getCandidatePairs(BB, Start, CandidatePairs, - FixedOrderPairs, - CandidatePairCostSavings, - PairableInsts, NonPow2Len); - if (PairableInsts.empty()) continue; - - // Build the candidate pair set for faster lookups. - DenseSet<ValuePair> CandidatePairsSet; - for (DenseMap<Value *, std::vector<Value *> >::iterator I = - CandidatePairs.begin(), E = CandidatePairs.end(); I != E; ++I) - for (std::vector<Value *>::iterator J = I->second.begin(), - JE = I->second.end(); J != JE; ++J) - CandidatePairsSet.insert(ValuePair(I->first, *J)); - - // Now we have a map of all of the pairable instructions and we need to - // select the best possible pairing. A good pairing is one such that the - // users of the pair are also paired. This defines a (directed) forest - // over the pairs such that two pairs are connected iff the second pair - // uses the first. - - // Note that it only matters that both members of the second pair use some - // element of the first pair (to allow for splatting). 
- - DenseMap<ValuePair, std::vector<ValuePair> > ConnectedPairs, - ConnectedPairDeps; - DenseMap<VPPair, unsigned> PairConnectionTypes; - computeConnectedPairs(CandidatePairs, CandidatePairsSet, - PairableInsts, ConnectedPairs, PairConnectionTypes); - if (ConnectedPairs.empty()) continue; - - for (DenseMap<ValuePair, std::vector<ValuePair> >::iterator - I = ConnectedPairs.begin(), IE = ConnectedPairs.end(); - I != IE; ++I) - for (std::vector<ValuePair>::iterator J = I->second.begin(), - JE = I->second.end(); J != JE; ++J) - ConnectedPairDeps[*J].push_back(I->first); - - // Build the pairable-instruction dependency map - DenseSet<ValuePair> PairableInstUsers; - buildDepMap(BB, CandidatePairs, PairableInsts, PairableInstUsers); - - // There is now a graph of the connected pairs. For each variable, pick - // the pairing with the largest dag meeting the depth requirement on at - // least one branch. Then select all pairings that are part of that dag - // and remove them from the list of available pairings and pairable - // variables. - - DenseMap<Value *, Value *> ChosenPairs; - choosePairs(CandidatePairs, CandidatePairsSet, - CandidatePairCostSavings, - PairableInsts, FixedOrderPairs, PairConnectionTypes, - ConnectedPairs, ConnectedPairDeps, - PairableInstUsers, ChosenPairs); - - if (ChosenPairs.empty()) continue; - AllPairableInsts.insert(AllPairableInsts.end(), PairableInsts.begin(), - PairableInsts.end()); - AllChosenPairs.insert(ChosenPairs.begin(), ChosenPairs.end()); - - // Only for the chosen pairs, propagate information on fixed-order pairs, - // pair connections, and their types to the data structures used by the - // pair fusion procedures. - for (DenseMap<Value *, Value *>::iterator I = ChosenPairs.begin(), - IE = ChosenPairs.end(); I != IE; ++I) { - if (FixedOrderPairs.count(*I)) - AllFixedOrderPairs.insert(*I); - else if (FixedOrderPairs.count(ValuePair(I->second, I->first))) - AllFixedOrderPairs.insert(ValuePair(I->second, I->first)); - - for (DenseMap<Value *, Value *>::iterator J = ChosenPairs.begin(); - J != IE; ++J) { - DenseMap<VPPair, unsigned>::iterator K = - PairConnectionTypes.find(VPPair(*I, *J)); - if (K != PairConnectionTypes.end()) { - AllPairConnectionTypes.insert(*K); - } else { - K = PairConnectionTypes.find(VPPair(*J, *I)); - if (K != PairConnectionTypes.end()) - AllPairConnectionTypes.insert(*K); - } - } - } - - for (DenseMap<ValuePair, std::vector<ValuePair> >::iterator - I = ConnectedPairs.begin(), IE = ConnectedPairs.end(); - I != IE; ++I) - for (std::vector<ValuePair>::iterator J = I->second.begin(), - JE = I->second.end(); J != JE; ++J) - if (AllPairConnectionTypes.count(VPPair(I->first, *J))) { - AllConnectedPairs[I->first].push_back(*J); - AllConnectedPairDeps[*J].push_back(I->first); - } - } while (ShouldContinue); - - if (AllChosenPairs.empty()) return false; - NumFusedOps += AllChosenPairs.size(); - - // A set of pairs has now been selected. It is now necessary to replace the - // paired instructions with vector instructions. For this procedure each - // operand must be replaced with a vector operand. This vector is formed - // by using build_vector on the old operands. The replaced values are then - // replaced with a vector_extract on the result. Subsequent optimization - // passes should coalesce the build/extract combinations. 
- - fuseChosenPairs(BB, AllPairableInsts, AllChosenPairs, AllFixedOrderPairs, - AllPairConnectionTypes, - AllConnectedPairs, AllConnectedPairDeps); - - // It is important to cleanup here so that future iterations of this - // function have less work to do. - (void)SimplifyInstructionsInBlock(&BB, TLI); - return true; - } - - // This function returns true if the provided instruction is capable of being - // fused into a vector instruction. This determination is based only on the - // type and other attributes of the instruction. - bool BBVectorize::isInstVectorizable(Instruction *I, - bool &IsSimpleLoadStore) { - IsSimpleLoadStore = false; - - if (CallInst *C = dyn_cast<CallInst>(I)) { - if (!isVectorizableIntrinsic(C)) - return false; - } else if (LoadInst *L = dyn_cast<LoadInst>(I)) { - // Vectorize simple loads if possbile: - IsSimpleLoadStore = L->isSimple(); - if (!IsSimpleLoadStore || !Config.VectorizeMemOps) - return false; - } else if (StoreInst *S = dyn_cast<StoreInst>(I)) { - // Vectorize simple stores if possbile: - IsSimpleLoadStore = S->isSimple(); - if (!IsSimpleLoadStore || !Config.VectorizeMemOps) - return false; - } else if (CastInst *C = dyn_cast<CastInst>(I)) { - // We can vectorize casts, but not casts of pointer types, etc. - if (!Config.VectorizeCasts) - return false; - - Type *SrcTy = C->getSrcTy(); - if (!SrcTy->isSingleValueType()) - return false; - - Type *DestTy = C->getDestTy(); - if (!DestTy->isSingleValueType()) - return false; - } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { - if (!Config.VectorizeSelect) - return false; - // We can vectorize a select if either all operands are scalars, - // or all operands are vectors. Trying to "widen" a select between - // vectors that has a scalar condition results in a malformed select. - // FIXME: We could probably be smarter about this by rewriting the select - // with different types instead. - return (SI->getCondition()->getType()->isVectorTy() == - SI->getTrueValue()->getType()->isVectorTy()); - } else if (isa<CmpInst>(I)) { - if (!Config.VectorizeCmp) - return false; - } else if (GetElementPtrInst *G = dyn_cast<GetElementPtrInst>(I)) { - if (!Config.VectorizeGEP) - return false; - - // Currently, vector GEPs exist only with one index. - if (G->getNumIndices() != 1) - return false; - } else if (!(I->isBinaryOp() || isa<ShuffleVectorInst>(I) || - isa<ExtractElementInst>(I) || isa<InsertElementInst>(I))) { - return false; - } - - Type *T1, *T2; - getInstructionTypes(I, T1, T2); - - // Not every type can be vectorized... - if (!(VectorType::isValidElementType(T1) || T1->isVectorTy()) || - !(VectorType::isValidElementType(T2) || T2->isVectorTy())) - return false; - - if (T1->getScalarSizeInBits() == 1) { - if (!Config.VectorizeBools) - return false; - } else { - if (!Config.VectorizeInts && T1->isIntOrIntVectorTy()) - return false; - } - - if (T2->getScalarSizeInBits() == 1) { - if (!Config.VectorizeBools) - return false; - } else { - if (!Config.VectorizeInts && T2->isIntOrIntVectorTy()) - return false; - } - - if (!Config.VectorizeFloats - && (T1->isFPOrFPVectorTy() || T2->isFPOrFPVectorTy())) - return false; - - // Don't vectorize target-specific types. 
- if (T1->isX86_FP80Ty() || T1->isPPC_FP128Ty() || T1->isX86_MMXTy()) - return false; - if (T2->isX86_FP80Ty() || T2->isPPC_FP128Ty() || T2->isX86_MMXTy()) - return false; - - if (!Config.VectorizePointers && (T1->getScalarType()->isPointerTy() || - T2->getScalarType()->isPointerTy())) - return false; - - if (!TTI && (T1->getPrimitiveSizeInBits() >= Config.VectorBits || - T2->getPrimitiveSizeInBits() >= Config.VectorBits)) - return false; - - return true; - } - - // This function returns true if the two provided instructions are compatible - // (meaning that they can be fused into a vector instruction). This assumes - // that I has already been determined to be vectorizable and that J is not - // in the use dag of I. - bool BBVectorize::areInstsCompatible(Instruction *I, Instruction *J, - bool IsSimpleLoadStore, bool NonPow2Len, - int &CostSavings, int &FixedOrder) { - DEBUG(if (DebugInstructionExamination) dbgs() << "BBV: looking at " << *I << - " <-> " << *J << "\n"); - - CostSavings = 0; - FixedOrder = 0; - - // Loads and stores can be merged if they have different alignments, - // but are otherwise the same. - if (!J->isSameOperationAs(I, Instruction::CompareIgnoringAlignment | - (NonPow2Len ? Instruction::CompareUsingScalarTypes : 0))) - return false; - - Type *IT1, *IT2, *JT1, *JT2; - getInstructionTypes(I, IT1, IT2); - getInstructionTypes(J, JT1, JT2); - unsigned MaxTypeBits = std::max( - IT1->getPrimitiveSizeInBits() + JT1->getPrimitiveSizeInBits(), - IT2->getPrimitiveSizeInBits() + JT2->getPrimitiveSizeInBits()); - if (!TTI && MaxTypeBits > Config.VectorBits) - return false; - - // FIXME: handle addsub-type operations! - - if (IsSimpleLoadStore) { - Value *IPtr, *JPtr; - unsigned IAlignment, JAlignment, IAddressSpace, JAddressSpace; - int64_t OffsetInElmts = 0; - if (getPairPtrInfo(I, J, IPtr, JPtr, IAlignment, JAlignment, - IAddressSpace, JAddressSpace, OffsetInElmts) && - std::abs(OffsetInElmts) == 1) { - FixedOrder = (int) OffsetInElmts; - unsigned BottomAlignment = IAlignment; - if (OffsetInElmts < 0) BottomAlignment = JAlignment; - - Type *aTypeI = isa<StoreInst>(I) ? - cast<StoreInst>(I)->getValueOperand()->getType() : I->getType(); - Type *aTypeJ = isa<StoreInst>(J) ? - cast<StoreInst>(J)->getValueOperand()->getType() : J->getType(); - Type *VType = getVecTypeForPair(aTypeI, aTypeJ); - - if (Config.AlignedOnly) { - // An aligned load or store is possible only if the instruction - // with the lower offset has an alignment suitable for the - // vector type. - const DataLayout &DL = I->getModule()->getDataLayout(); - unsigned VecAlignment = DL.getPrefTypeAlignment(VType); - if (BottomAlignment < VecAlignment) - return false; - } - - if (TTI) { - unsigned ICost = TTI->getMemoryOpCost(I->getOpcode(), aTypeI, - IAlignment, IAddressSpace); - unsigned JCost = TTI->getMemoryOpCost(J->getOpcode(), aTypeJ, - JAlignment, JAddressSpace); - unsigned VCost = TTI->getMemoryOpCost(I->getOpcode(), VType, - BottomAlignment, - IAddressSpace); - - ICost += TTI->getAddressComputationCost(aTypeI); - JCost += TTI->getAddressComputationCost(aTypeJ); - VCost += TTI->getAddressComputationCost(VType); - - if (VCost > ICost + JCost) - return false; - - // We don't want to fuse to a type that will be split, even - // if the two input types will also be split and there is no other - // associated cost. 
- unsigned VParts = TTI->getNumberOfParts(VType); - if (VParts > 1) - return false; - else if (!VParts && VCost == ICost + JCost) - return false; - - CostSavings = ICost + JCost - VCost; - } - } else { - return false; - } - } else if (TTI) { - unsigned ICost = getInstrCost(I->getOpcode(), IT1, IT2); - unsigned JCost = getInstrCost(J->getOpcode(), JT1, JT2); - Type *VT1 = getVecTypeForPair(IT1, JT1), - *VT2 = getVecTypeForPair(IT2, JT2); - TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue; - TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_AnyValue; - - // On some targets (example X86) the cost of a vector shift may vary - // depending on whether the second operand is a Uniform or - // NonUniform Constant. - switch (I->getOpcode()) { - default : break; - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - - // If both I and J are scalar shifts by constant, then the - // merged vector shift count would be either a constant splat value - // or a non-uniform vector of constants. - if (ConstantInt *CII = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (ConstantInt *CIJ = dyn_cast<ConstantInt>(J->getOperand(1))) - Op2VK = CII == CIJ ? TargetTransformInfo::OK_UniformConstantValue : - TargetTransformInfo::OK_NonUniformConstantValue; - } else { - // Check for a splat of a constant or for a non uniform vector - // of constants. - Value *IOp = I->getOperand(1); - Value *JOp = J->getOperand(1); - if ((isa<ConstantVector>(IOp) || isa<ConstantDataVector>(IOp)) && - (isa<ConstantVector>(JOp) || isa<ConstantDataVector>(JOp))) { - Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; - Constant *SplatValue = cast<Constant>(IOp)->getSplatValue(); - if (SplatValue != nullptr && - SplatValue == cast<Constant>(JOp)->getSplatValue()) - Op2VK = TargetTransformInfo::OK_UniformConstantValue; - } - } - } - - // Note that this procedure is incorrect for insert and extract element - // instructions (because combining these often results in a shuffle), - // but this cost is ignored (because insert and extract element - // instructions are assigned a zero depth factor and are not really - // fused in general). - unsigned VCost = getInstrCost(I->getOpcode(), VT1, VT2, Op1VK, Op2VK); - - if (VCost > ICost + JCost) - return false; - - // We don't want to fuse to a type that will be split, even - // if the two input types will also be split and there is no other - // associated cost. - unsigned VParts1 = TTI->getNumberOfParts(VT1), - VParts2 = TTI->getNumberOfParts(VT2); - if (VParts1 > 1 || VParts2 > 1) - return false; - else if ((!VParts1 || !VParts2) && VCost == ICost + JCost) - return false; - - CostSavings = ICost + JCost - VCost; - } - - // The powi,ctlz,cttz intrinsics are special because only the first - // argument is vectorized, the second arguments must be equal. 
- CallInst *CI = dyn_cast<CallInst>(I); - Function *FI; - if (CI && (FI = CI->getCalledFunction())) { - Intrinsic::ID IID = FI->getIntrinsicID(); - if (IID == Intrinsic::powi || IID == Intrinsic::ctlz || - IID == Intrinsic::cttz) { - Value *A1I = CI->getArgOperand(1), - *A1J = cast<CallInst>(J)->getArgOperand(1); - const SCEV *A1ISCEV = SE->getSCEV(A1I), - *A1JSCEV = SE->getSCEV(A1J); - return (A1ISCEV == A1JSCEV); - } - - if (IID && TTI) { - FastMathFlags FMFCI; - if (auto *FPMOCI = dyn_cast<FPMathOperator>(CI)) - FMFCI = FPMOCI->getFastMathFlags(); - - SmallVector<Type*, 4> Tys; - for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) - Tys.push_back(CI->getArgOperand(i)->getType()); - unsigned ICost = TTI->getIntrinsicInstrCost(IID, IT1, Tys, FMFCI); - - Tys.clear(); - CallInst *CJ = cast<CallInst>(J); - - FastMathFlags FMFCJ; - if (auto *FPMOCJ = dyn_cast<FPMathOperator>(CJ)) - FMFCJ = FPMOCJ->getFastMathFlags(); - - for (unsigned i = 0, ie = CJ->getNumArgOperands(); i != ie; ++i) - Tys.push_back(CJ->getArgOperand(i)->getType()); - unsigned JCost = TTI->getIntrinsicInstrCost(IID, JT1, Tys, FMFCJ); - - Tys.clear(); - assert(CI->getNumArgOperands() == CJ->getNumArgOperands() && - "Intrinsic argument counts differ"); - for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { - if ((IID == Intrinsic::powi || IID == Intrinsic::ctlz || - IID == Intrinsic::cttz) && i == 1) - Tys.push_back(CI->getArgOperand(i)->getType()); - else - Tys.push_back(getVecTypeForPair(CI->getArgOperand(i)->getType(), - CJ->getArgOperand(i)->getType())); - } - - FastMathFlags FMFV = FMFCI; - FMFV &= FMFCJ; - Type *RetTy = getVecTypeForPair(IT1, JT1); - unsigned VCost = TTI->getIntrinsicInstrCost(IID, RetTy, Tys, FMFV); - - if (VCost > ICost + JCost) - return false; - - // We don't want to fuse to a type that will be split, even - // if the two input types will also be split and there is no other - // associated cost. - unsigned RetParts = TTI->getNumberOfParts(RetTy); - if (RetParts > 1) - return false; - else if (!RetParts && VCost == ICost + JCost) - return false; - - for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { - if (!Tys[i]->isVectorTy()) - continue; - - unsigned NumParts = TTI->getNumberOfParts(Tys[i]); - if (NumParts > 1) - return false; - else if (!NumParts && VCost == ICost + JCost) - return false; - } - - CostSavings = ICost + JCost - VCost; - } - } - - return true; - } - - // Figure out whether or not J uses I and update the users and write-set - // structures associated with I. Specifically, Users represents the set of - // instructions that depend on I. WriteSet represents the set - // of memory locations that are dependent on I. If UpdateUsers is true, - // and J uses I, then Users is updated to contain J and WriteSet is updated - // to contain any memory locations to which J writes. The function returns - // true if J uses I. By default, alias analysis is used to determine - // whether J reads from memory that overlaps with a location in WriteSet. - // If LoadMoveSet is not null, then it is a previously-computed map - // where the key is the memory-based user instruction and the value is - // the instruction to be compared with I. So, if LoadMoveSet is provided, - // then the alias analysis is not used. This is necessary because this - // function is called during the process of moving instructions during - // vectorization and the results of the alias analysis are not stable during - // that process. 
- bool BBVectorize::trackUsesOfI(DenseSet<Value *> &Users, - AliasSetTracker &WriteSet, Instruction *I, - Instruction *J, bool UpdateUsers, - DenseSet<ValuePair> *LoadMoveSetPairs) { - bool UsesI = false; - - // This instruction may already be marked as a user due, for example, to - // being a member of a selected pair. - if (Users.count(J)) - UsesI = true; - - if (!UsesI) - for (User::op_iterator JU = J->op_begin(), JE = J->op_end(); - JU != JE; ++JU) { - Value *V = *JU; - if (I == V || Users.count(V)) { - UsesI = true; - break; - } - } - if (!UsesI && J->mayReadFromMemory()) { - if (LoadMoveSetPairs) { - UsesI = LoadMoveSetPairs->count(ValuePair(J, I)); - } else { - for (AliasSetTracker::iterator W = WriteSet.begin(), - WE = WriteSet.end(); W != WE; ++W) { - if (W->aliasesUnknownInst(J, *AA)) { - UsesI = true; - break; - } - } - } - } - - if (UsesI && UpdateUsers) { - if (J->mayWriteToMemory()) WriteSet.add(J); - Users.insert(J); - } - - return UsesI; - } - - // This function iterates over all instruction pairs in the provided - // basic block and collects all candidate pairs for vectorization. - bool BBVectorize::getCandidatePairs(BasicBlock &BB, - BasicBlock::iterator &Start, - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &FixedOrderPairs, - DenseMap<ValuePair, int> &CandidatePairCostSavings, - std::vector<Value *> &PairableInsts, bool NonPow2Len) { - size_t TotalPairs = 0; - BasicBlock::iterator E = BB.end(); - if (Start == E) return false; - - bool ShouldContinue = false, IAfterStart = false; - for (BasicBlock::iterator I = Start++; I != E; ++I) { - if (I == Start) IAfterStart = true; - - bool IsSimpleLoadStore; - if (!isInstVectorizable(&*I, IsSimpleLoadStore)) - continue; - - // Look for an instruction with which to pair instruction *I... - DenseSet<Value *> Users; - AliasSetTracker WriteSet(*AA); - if (I->mayWriteToMemory()) - WriteSet.add(&*I); - - bool JAfterStart = IAfterStart; - BasicBlock::iterator J = std::next(I); - for (unsigned ss = 0; J != E && ss <= Config.SearchLimit; ++J, ++ss) { - if (J == Start) - JAfterStart = true; - - // Determine if J uses I, if so, exit the loop. - bool UsesI = trackUsesOfI(Users, WriteSet, &*I, &*J, !Config.FastDep); - if (Config.FastDep) { - // Note: For this heuristic to be effective, independent operations - // must tend to be intermixed. This is likely to be true from some - // kinds of grouped loop unrolling (but not the generic LLVM pass), - // but otherwise may require some kind of reordering pass. - - // When using fast dependency analysis, - // stop searching after first use: - if (UsesI) break; - } else { - if (UsesI) continue; - } - - // J does not use I, and comes before the first use of I, so it can be - // merged with I if the instructions are compatible. - int CostSavings, FixedOrder; - if (!areInstsCompatible(&*I, &*J, IsSimpleLoadStore, NonPow2Len, - CostSavings, FixedOrder)) - continue; - - // J is a candidate for merging with I. - if (PairableInsts.empty() || - PairableInsts[PairableInsts.size() - 1] != &*I) { - PairableInsts.push_back(&*I); - } - - CandidatePairs[&*I].push_back(&*J); - ++TotalPairs; - if (TTI) - CandidatePairCostSavings.insert( - ValuePairWithCost(ValuePair(&*I, &*J), CostSavings)); - - if (FixedOrder == 1) - FixedOrderPairs.insert(ValuePair(&*I, &*J)); - else if (FixedOrder == -1) - FixedOrderPairs.insert(ValuePair(&*J, &*I)); - - // The next call to this function must start after the last instruction - // selected during this invocation. 
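The candidate scan in getCandidatePairs above boils down to a bounded forward search. A standalone sketch of that window, with usesI and compatible standing in for trackUsesOfI and areInstsCompatible (all names hypothetical):

#include <cstdio>
#include <utility>
#include <vector>

// For each instruction index i, look at most SearchLimit instructions ahead
// for a compatible partner that does not depend on i.
std::vector<std::pair<int, int>>
findCandidatePairs(int N, unsigned SearchLimit,
                   bool (*usesI)(int, int), bool (*compatible)(int, int)) {
  std::vector<std::pair<int, int>> Pairs;
  for (int i = 0; i < N; ++i)
    for (int j = i + 1, ss = 0; j < N && ss <= (int)SearchLimit; ++j, ++ss) {
      if (usesI(i, j)) continue;     // j depends on i: not pairable
      if (compatible(i, j)) Pairs.push_back({i, j});
    }
  return Pairs;
}

int main() {
  auto noUse = [](int, int) { return false; };
  auto adjacent = [](int i, int j) { return j == i + 1; };
  auto P = findCandidatePairs(4, 10, noUse, adjacent);
  std::printf("%zu pairs\n", P.size()); // 3 pairs
}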
- if (JAfterStart) { - Start = std::next(J); - IAfterStart = JAfterStart = false; - } - - DEBUG(if (DebugCandidateSelection) dbgs() << "BBV: candidate pair " - << *I << " <-> " << *J << " (cost savings: " << - CostSavings << ")\n"); - - // If we have already found too many pairs, break here and this function - // will be called again starting after the last instruction selected - // during this invocation. - if (PairableInsts.size() >= Config.MaxInsts || - TotalPairs >= Config.MaxPairs) { - ShouldContinue = true; - break; - } - } - - if (ShouldContinue) - break; - } - - DEBUG(dbgs() << "BBV: found " << PairableInsts.size() - << " instructions with candidate pairs\n"); - - return ShouldContinue; - } - - // Finds candidate pairs connected to the pair P = <PI, PJ>. This means that - // it looks for pairs such that both members have an input which is an - // output of PI or PJ. - void BBVectorize::computePairsConnectedTo( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - std::vector<Value *> &PairableInsts, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes, - ValuePair P) { - StoreInst *SI, *SJ; - - // For each possible pairing for this variable, look at the uses of - // the first value... - for (Value::user_iterator I = P.first->user_begin(), - E = P.first->user_end(); - I != E; ++I) { - User *UI = *I; - if (isa<LoadInst>(UI)) { - // A pair cannot be connected to a load because the load only takes one - // operand (the address) and it is a scalar even after vectorization. - continue; - } else if ((SI = dyn_cast<StoreInst>(UI)) && - P.first == SI->getPointerOperand()) { - // Similarly, a pair cannot be connected to a store through its - // pointer operand. - continue; - } - - // For each use of the first variable, look for uses of the second - // variable... - for (User *UJ : P.second->users()) { - if ((SJ = dyn_cast<StoreInst>(UJ)) && - P.second == SJ->getPointerOperand()) - continue; - - // Look for <I, J>: - if (CandidatePairsSet.count(ValuePair(UI, UJ))) { - VPPair VP(P, ValuePair(UI, UJ)); - ConnectedPairs[VP.first].push_back(VP.second); - PairConnectionTypes.insert(VPPairWithType(VP, PairConnectionDirect)); - } - - // Look for <J, I>: - if (CandidatePairsSet.count(ValuePair(UJ, UI))) { - VPPair VP(P, ValuePair(UJ, UI)); - ConnectedPairs[VP.first].push_back(VP.second); - PairConnectionTypes.insert(VPPairWithType(VP, PairConnectionSwap)); - } - } - - if (Config.SplatBreaksChain) continue; - // Look for cases where just the first value in the pair is used by - // both members of another pair (splatting). - for (Value::user_iterator J = P.first->user_begin(); J != E; ++J) { - User *UJ = *J; - if ((SJ = dyn_cast<StoreInst>(UJ)) && - P.first == SJ->getPointerOperand()) - continue; - - if (CandidatePairsSet.count(ValuePair(UI, UJ))) { - VPPair VP(P, ValuePair(UI, UJ)); - ConnectedPairs[VP.first].push_back(VP.second); - PairConnectionTypes.insert(VPPairWithType(VP, PairConnectionSplat)); - } - } - } - - if (Config.SplatBreaksChain) return; - // Look for cases where just the second value in the pair is used by - // both members of another pair (splatting). 
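The connection classification performed by computePairsConnectedTo can be sketched independently of LLVM's types. A minimal illustration (all names hypothetical) of how a pair <A, B> acquires direct, swapped, and splat connections through its users:

#include <cstdio>
#include <set>
#include <utility>
#include <vector>

using ValuePair = std::pair<int, int>;
enum ConnType { Direct, Swap, Splat };

// For pair P = <A,B>: a user UI of A and a user UJ of B connect as <UI,UJ>
// (direct) or <UJ,UI> (swap) if that pair is itself a candidate; two
// distinct users of A alone form a splat connection.
std::vector<std::pair<ValuePair, ConnType>>
connectionsOf(const std::vector<int> &UsersA, const std::vector<int> &UsersB,
              const std::set<ValuePair> &Candidates) {
  std::vector<std::pair<ValuePair, ConnType>> Out;
  for (int UI : UsersA)
    for (int UJ : UsersB) {
      if (Candidates.count({UI, UJ})) Out.push_back({{UI, UJ}, Direct});
      if (Candidates.count({UJ, UI})) Out.push_back({{UJ, UI}, Swap});
    }
  for (int UI : UsersA)
    for (int UJ : UsersA)
      if (UI != UJ && Candidates.count({UI, UJ}))
        Out.push_back({{UI, UJ}, Splat});
  return Out;
}

int main() {
  std::set<ValuePair> Cand = {{3, 4}};
  auto Out = connectionsOf({3}, {4}, Cand);
  std::printf("%zu connections\n", Out.size()); // 1 (direct)
}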
- for (Value::user_iterator I = P.second->user_begin(), - E = P.second->user_end(); - I != E; ++I) { - User *UI = *I; - if (isa<LoadInst>(UI)) - continue; - else if ((SI = dyn_cast<StoreInst>(UI)) && - P.second == SI->getPointerOperand()) - continue; - - for (Value::user_iterator J = P.second->user_begin(); J != E; ++J) { - User *UJ = *J; - if ((SJ = dyn_cast<StoreInst>(UJ)) && - P.second == SJ->getPointerOperand()) - continue; - - if (CandidatePairsSet.count(ValuePair(UI, UJ))) { - VPPair VP(P, ValuePair(UI, UJ)); - ConnectedPairs[VP.first].push_back(VP.second); - PairConnectionTypes.insert(VPPairWithType(VP, PairConnectionSplat)); - } - } - } - } - - // This function figures out which pairs are connected. Two pairs are - // connected if some output of the first pair forms an input to both members - // of the second pair. - void BBVectorize::computeConnectedPairs( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - std::vector<Value *> &PairableInsts, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes) { - for (std::vector<Value *>::iterator PI = PairableInsts.begin(), - PE = PairableInsts.end(); PI != PE; ++PI) { - DenseMap<Value *, std::vector<Value *> >::iterator PP = - CandidatePairs.find(*PI); - if (PP == CandidatePairs.end()) - continue; - - for (std::vector<Value *>::iterator P = PP->second.begin(), - E = PP->second.end(); P != E; ++P) - computePairsConnectedTo(CandidatePairs, CandidatePairsSet, - PairableInsts, ConnectedPairs, - PairConnectionTypes, ValuePair(*PI, *P)); - } - - DEBUG(size_t TotalPairs = 0; - for (DenseMap<ValuePair, std::vector<ValuePair> >::iterator I = - ConnectedPairs.begin(), IE = ConnectedPairs.end(); I != IE; ++I) - TotalPairs += I->second.size(); - dbgs() << "BBV: found " << TotalPairs - << " pair connections.\n"); - } - - // This function builds a set of use tuples such that <A, B> is in the set - // if B is in the use dag of A. If B is in the use dag of A, then B - // depends on the output of A. - void BBVectorize::buildDepMap( - BasicBlock &BB, - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - std::vector<Value *> &PairableInsts, - DenseSet<ValuePair> &PairableInstUsers) { - DenseSet<Value *> IsInPair; - for (DenseMap<Value *, std::vector<Value *> >::iterator C = - CandidatePairs.begin(), E = CandidatePairs.end(); C != E; ++C) { - IsInPair.insert(C->first); - IsInPair.insert(C->second.begin(), C->second.end()); - } - - // Iterate through the basic block, recording all users of each - // pairable instruction. - - BasicBlock::iterator E = BB.end(), EL = - BasicBlock::iterator(cast<Instruction>(PairableInsts.back())); - for (BasicBlock::iterator I = BB.getFirstInsertionPt(); I != E; ++I) { - if (IsInPair.find(&*I) == IsInPair.end()) - continue; - - DenseSet<Value *> Users; - AliasSetTracker WriteSet(*AA); - if (I->mayWriteToMemory()) - WriteSet.add(&*I); - - for (BasicBlock::iterator J = std::next(I); J != E; ++J) { - (void)trackUsesOfI(Users, WriteSet, &*I, &*J); - - if (J == EL) - break; - } - - for (DenseSet<Value *>::iterator U = Users.begin(), E = Users.end(); - U != E; ++U) { - if (IsInPair.find(*U) == IsInPair.end()) continue; - PairableInstUsers.insert(ValuePair(&*I, *U)); - } - - if (I == EL) - break; - } - } - - // Returns true if an input to pair P is an output of pair Q and also an - // input of pair Q is an output of pair P. If this is the case, then these - // two pairs cannot be simultaneously fused. 
- bool BBVectorize::pairsConflict(ValuePair P, ValuePair Q,
- DenseSet<ValuePair> &PairableInstUsers,
- DenseMap<ValuePair, std::vector<ValuePair> > *PairableInstUserMap,
- DenseSet<VPPair> *PairableInstUserPairSet) {
- // Two pairs are in conflict if they are mutual Users of each other.
- bool QUsesP = PairableInstUsers.count(ValuePair(P.first, Q.first)) ||
- PairableInstUsers.count(ValuePair(P.first, Q.second)) ||
- PairableInstUsers.count(ValuePair(P.second, Q.first)) ||
- PairableInstUsers.count(ValuePair(P.second, Q.second));
- bool PUsesQ = PairableInstUsers.count(ValuePair(Q.first, P.first)) ||
- PairableInstUsers.count(ValuePair(Q.first, P.second)) ||
- PairableInstUsers.count(ValuePair(Q.second, P.first)) ||
- PairableInstUsers.count(ValuePair(Q.second, P.second));
- if (PairableInstUserMap) {
- // FIXME: The expensive part of the cycle check is not so much the cycle
- // check itself but this edge insertion procedure. This needs some
- // profiling and probably a different data structure.
- if (PUsesQ) {
- if (PairableInstUserPairSet->insert(VPPair(Q, P)).second)
- (*PairableInstUserMap)[Q].push_back(P);
- }
- if (QUsesP) {
- if (PairableInstUserPairSet->insert(VPPair(P, Q)).second)
- (*PairableInstUserMap)[P].push_back(Q);
- }
- }
-
- return (QUsesP && PUsesQ);
- }
-
- // This function walks the use graph of current pairs to see if, starting
- // from P, the walk returns to P.
- bool BBVectorize::pairWillFormCycle(ValuePair P,
- DenseMap<ValuePair, std::vector<ValuePair> > &PairableInstUserMap,
- DenseSet<ValuePair> &CurrentPairs) {
- DEBUG(if (DebugCycleCheck)
- dbgs() << "BBV: starting cycle check for : " << *P.first << " <-> "
- << *P.second << "\n");
- // A lookup table of visited pairs is kept because the PairableInstUserMap
- // contains non-direct associations.
- DenseSet<ValuePair> Visited;
- SmallVector<ValuePair, 32> Q;
- // General depth-first post-order traversal:
- Q.push_back(P);
- do {
- ValuePair QTop = Q.pop_back_val();
- Visited.insert(QTop);
-
- DEBUG(if (DebugCycleCheck)
- dbgs() << "BBV: cycle check visiting: " << *QTop.first << " <-> "
- << *QTop.second << "\n");
- DenseMap<ValuePair, std::vector<ValuePair> >::iterator QQ =
- PairableInstUserMap.find(QTop);
- if (QQ == PairableInstUserMap.end())
- continue;
-
- for (std::vector<ValuePair>::iterator C = QQ->second.begin(),
- CE = QQ->second.end(); C != CE; ++C) {
- if (*C == P) {
- DEBUG(dbgs()
- << "BBV: rejected to prevent non-trivial cycle formation: "
- << QTop.first << " <-> " << C->second << "\n");
- return true;
- }
-
- if (CurrentPairs.count(*C) && !Visited.count(*C))
- Q.push_back(*C);
- }
- } while (!Q.empty());
-
- return false;
- }
-
- // This function builds the initial dag of connected pairs with the
- // pair J at the root.
- void BBVectorize::buildInitialDAGFor(
- DenseMap<Value *, std::vector<Value *> > &CandidatePairs,
- DenseSet<ValuePair> &CandidatePairsSet,
- std::vector<Value *> &PairableInsts,
- DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs,
- DenseSet<ValuePair> &PairableInstUsers,
- DenseMap<Value *, Value *> &ChosenPairs,
- DenseMap<ValuePair, size_t> &DAG, ValuePair J) {
- // Each of these pairs is viewed as the root node of a DAG. The DAG
- // is then walked (depth-first). As this happens, we keep track of
- // the pairs that compose the DAG and the maximum depth of the DAG.
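The cycle test in pairWillFormCycle above is an ordinary depth-first walk over the pair-use graph. A standalone sketch under the same restriction to already-relevant pairs (all names hypothetical):

#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <vector>

using ValuePair = std::pair<int, int>;

// Starting from P, walk the pair-use graph restricted to pairs in
// CurrentPairs; if the walk returns to P, fusing P would create a cycle.
bool wouldFormCycle(ValuePair P,
                    const std::map<ValuePair, std::vector<ValuePair>> &Users,
                    const std::set<ValuePair> &CurrentPairs) {
  std::set<ValuePair> Visited;
  std::vector<ValuePair> Q{P};
  while (!Q.empty()) {
    ValuePair Top = Q.back();
    Q.pop_back();
    Visited.insert(Top);
    auto It = Users.find(Top);
    if (It == Users.end()) continue;
    for (ValuePair C : It->second) {
      if (C == P) return true; // walked back to the starting pair
      if (CurrentPairs.count(C) && !Visited.count(C)) Q.push_back(C);
    }
  }
  return false;
}

int main() {
  std::map<ValuePair, std::vector<ValuePair>> U;
  U[{1, 2}] = {{3, 4}};
  U[{3, 4}] = {{1, 2}};
  std::printf("%d\n", wouldFormCycle({1, 2}, U, {{3, 4}})); // 1
}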
- SmallVector<ValuePairWithDepth, 32> Q; - // General depth-first post-order traversal: - Q.push_back(ValuePairWithDepth(J, getDepthFactor(J.first))); - do { - ValuePairWithDepth QTop = Q.back(); - - // Push each child onto the queue: - bool MoreChildren = false; - size_t MaxChildDepth = QTop.second; - DenseMap<ValuePair, std::vector<ValuePair> >::iterator QQ = - ConnectedPairs.find(QTop.first); - if (QQ != ConnectedPairs.end()) - for (std::vector<ValuePair>::iterator k = QQ->second.begin(), - ke = QQ->second.end(); k != ke; ++k) { - // Make sure that this child pair is still a candidate: - if (CandidatePairsSet.count(*k)) { - DenseMap<ValuePair, size_t>::iterator C = DAG.find(*k); - if (C == DAG.end()) { - size_t d = getDepthFactor(k->first); - Q.push_back(ValuePairWithDepth(*k, QTop.second+d)); - MoreChildren = true; - } else { - MaxChildDepth = std::max(MaxChildDepth, C->second); - } - } - } - - if (!MoreChildren) { - // Record the current pair as part of the DAG: - DAG.insert(ValuePairWithDepth(QTop.first, MaxChildDepth)); - Q.pop_back(); - } - } while (!Q.empty()); - } - - // Given some initial dag, prune it by removing conflicting pairs (pairs - // that cannot be simultaneously chosen for vectorization). - void BBVectorize::pruneDAGFor( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - std::vector<Value *> &PairableInsts, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseSet<ValuePair> &PairableInstUsers, - DenseMap<ValuePair, std::vector<ValuePair> > &PairableInstUserMap, - DenseSet<VPPair> &PairableInstUserPairSet, - DenseMap<Value *, Value *> &ChosenPairs, - DenseMap<ValuePair, size_t> &DAG, - DenseSet<ValuePair> &PrunedDAG, ValuePair J, - bool UseCycleCheck) { - SmallVector<ValuePairWithDepth, 32> Q; - // General depth-first post-order traversal: - Q.push_back(ValuePairWithDepth(J, getDepthFactor(J.first))); - do { - ValuePairWithDepth QTop = Q.pop_back_val(); - PrunedDAG.insert(QTop.first); - - // Visit each child, pruning as necessary... - SmallVector<ValuePairWithDepth, 8> BestChildren; - DenseMap<ValuePair, std::vector<ValuePair> >::iterator QQ = - ConnectedPairs.find(QTop.first); - if (QQ == ConnectedPairs.end()) - continue; - - for (std::vector<ValuePair>::iterator K = QQ->second.begin(), - KE = QQ->second.end(); K != KE; ++K) { - DenseMap<ValuePair, size_t>::iterator C = DAG.find(*K); - if (C == DAG.end()) continue; - - // This child is in the DAG, now we need to make sure it is the - // best of any conflicting children. There could be multiple - // conflicting children, so first, determine if we're keeping - // this child, then delete conflicting children as necessary. - - // It is also necessary to guard against pairing-induced - // dependencies. Consider instructions a .. x .. y .. b - // such that (a,b) are to be fused and (x,y) are to be fused - // but a is an input to x and b is an output from y. This - // means that y cannot be moved after b but x must be moved - // after b for (a,b) to be fused. In other words, after - // fusing (a,b) we have y .. a/b .. x where y is an input - // to a/b and x is an output to a/b: x and y can no longer - // be legally fused. To prevent this condition, we must - // make sure that a child pair added to the DAG is not - // both an input and output of an already-selected pair. - - // Pairing-induced dependencies can also form from more complicated - // cycles. The pair vs. pair conflicts are easy to check, and so - // that is done explicitly for "fast rejection", and because for - // child vs. 
child conflicts, we may prefer to keep the current - // pair in preference to the already-selected child. - DenseSet<ValuePair> CurrentPairs; - - bool CanAdd = true; - for (SmallVectorImpl<ValuePairWithDepth>::iterator C2 - = BestChildren.begin(), E2 = BestChildren.end(); - C2 != E2; ++C2) { - if (C2->first.first == C->first.first || - C2->first.first == C->first.second || - C2->first.second == C->first.first || - C2->first.second == C->first.second || - pairsConflict(C2->first, C->first, PairableInstUsers, - UseCycleCheck ? &PairableInstUserMap : nullptr, - UseCycleCheck ? &PairableInstUserPairSet - : nullptr)) { - if (C2->second >= C->second) { - CanAdd = false; - break; - } - - CurrentPairs.insert(C2->first); - } - } - if (!CanAdd) continue; - - // Even worse, this child could conflict with another node already - // selected for the DAG. If that is the case, ignore this child. - for (DenseSet<ValuePair>::iterator T = PrunedDAG.begin(), - E2 = PrunedDAG.end(); T != E2; ++T) { - if (T->first == C->first.first || - T->first == C->first.second || - T->second == C->first.first || - T->second == C->first.second || - pairsConflict(*T, C->first, PairableInstUsers, - UseCycleCheck ? &PairableInstUserMap : nullptr, - UseCycleCheck ? &PairableInstUserPairSet - : nullptr)) { - CanAdd = false; - break; - } - - CurrentPairs.insert(*T); - } - if (!CanAdd) continue; - - // And check the queue too... - for (SmallVectorImpl<ValuePairWithDepth>::iterator C2 = Q.begin(), - E2 = Q.end(); C2 != E2; ++C2) { - if (C2->first.first == C->first.first || - C2->first.first == C->first.second || - C2->first.second == C->first.first || - C2->first.second == C->first.second || - pairsConflict(C2->first, C->first, PairableInstUsers, - UseCycleCheck ? &PairableInstUserMap : nullptr, - UseCycleCheck ? &PairableInstUserPairSet - : nullptr)) { - CanAdd = false; - break; - } - - CurrentPairs.insert(C2->first); - } - if (!CanAdd) continue; - - // Last but not least, check for a conflict with any of the - // already-chosen pairs. - for (DenseMap<Value *, Value *>::iterator C2 = - ChosenPairs.begin(), E2 = ChosenPairs.end(); - C2 != E2; ++C2) { - if (pairsConflict(*C2, C->first, PairableInstUsers, - UseCycleCheck ? &PairableInstUserMap : nullptr, - UseCycleCheck ? &PairableInstUserPairSet - : nullptr)) { - CanAdd = false; - break; - } - - CurrentPairs.insert(*C2); - } - if (!CanAdd) continue; - - // To check for non-trivial cycles formed by the addition of the - // current pair we've formed a list of all relevant pairs, now use a - // graph walk to check for a cycle. We start from the current pair and - // walk the use dag to see if we again reach the current pair. If we - // do, then the current pair is rejected. - - // FIXME: It may be more efficient to use a topological-ordering - // algorithm to improve the cycle check. This should be investigated. - if (UseCycleCheck && - pairWillFormCycle(C->first, PairableInstUserMap, CurrentPairs)) - continue; - - // This child can be added, but we may have chosen it in preference - // to an already-selected child. Check for this here, and if a - // conflict is found, then remove the previously-selected child - // before adding this one in its place. 
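The guards above reduce to one question: may this child join everything already committed? A simplified standalone sketch (all names hypothetical; it omits the depth tie-break that lets a deeper child displace a shallower best child):

#include <cstdio>
#include <set>
#include <utility>
#include <vector>

using ValuePair = std::pair<int, int>;

// A child pair C may not share an instruction with anything committed.
bool sharesMember(ValuePair A, ValuePair B) {
  return A.first == B.first || A.first == B.second ||
         A.second == B.first || A.second == B.second;
}

// Committed stands in for the best children so far, the pruned DAG, the
// worklist, and the previously chosen pairs; conflict for pairsConflict.
bool canAdd(ValuePair C, const std::vector<std::set<ValuePair>> &Committed,
            bool (*conflict)(ValuePair, ValuePair)) {
  for (const auto &Set : Committed)
    for (ValuePair P : Set)
      if (sharesMember(P, C) || conflict(P, C))
        return false;
  return true;
}

int main() {
  auto noConf = [](ValuePair, ValuePair) { return false; };
  std::vector<std::set<ValuePair>> Committed = {{{1, 2}}};
  std::printf("%d\n", canAdd({2, 3}, Committed, noConf)); // 0: shares value 2
}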
- for (SmallVectorImpl<ValuePairWithDepth>::iterator C2
- = BestChildren.begin(); C2 != BestChildren.end();) {
- if (C2->first.first == C->first.first ||
- C2->first.first == C->first.second ||
- C2->first.second == C->first.first ||
- C2->first.second == C->first.second ||
- pairsConflict(C2->first, C->first, PairableInstUsers))
- C2 = BestChildren.erase(C2);
- else
- ++C2;
- }
-
- BestChildren.push_back(ValuePairWithDepth(C->first, C->second));
- }
-
- for (SmallVectorImpl<ValuePairWithDepth>::iterator C
- = BestChildren.begin(), E2 = BestChildren.end();
- C != E2; ++C) {
- size_t DepthF = getDepthFactor(C->first.first);
- Q.push_back(ValuePairWithDepth(C->first, QTop.second+DepthF));
- }
- } while (!Q.empty());
- }
-
- // This function finds the best dag of mutually-compatible connected
- // pairs, given the choice of root pairs as an iterator range.
- void BBVectorize::findBestDAGFor(
- DenseMap<Value *, std::vector<Value *> > &CandidatePairs,
- DenseSet<ValuePair> &CandidatePairsSet,
- DenseMap<ValuePair, int> &CandidatePairCostSavings,
- std::vector<Value *> &PairableInsts,
- DenseSet<ValuePair> &FixedOrderPairs,
- DenseMap<VPPair, unsigned> &PairConnectionTypes,
- DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs,
- DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairDeps,
- DenseSet<ValuePair> &PairableInstUsers,
- DenseMap<ValuePair, std::vector<ValuePair> > &PairableInstUserMap,
- DenseSet<VPPair> &PairableInstUserPairSet,
- DenseMap<Value *, Value *> &ChosenPairs,
- DenseSet<ValuePair> &BestDAG, size_t &BestMaxDepth,
- int &BestEffSize, Value *II, std::vector<Value *>&JJ,
- bool UseCycleCheck) {
- for (std::vector<Value *>::iterator J = JJ.begin(), JE = JJ.end();
- J != JE; ++J) {
- ValuePair IJ(II, *J);
- if (!CandidatePairsSet.count(IJ))
- continue;
-
- // Before going any further, make sure that this pair does not
- // conflict with any already-selected pairs (see comment below
- // near the DAG pruning for more details).
- DenseSet<ValuePair> ChosenPairSet;
- bool DoesConflict = false;
- for (DenseMap<Value *, Value *>::iterator C = ChosenPairs.begin(),
- E = ChosenPairs.end(); C != E; ++C) {
- if (pairsConflict(*C, IJ, PairableInstUsers,
- UseCycleCheck ? &PairableInstUserMap : nullptr,
- UseCycleCheck ? &PairableInstUserPairSet : nullptr)) {
- DoesConflict = true;
- break;
- }
-
- ChosenPairSet.insert(*C);
- }
- if (DoesConflict) continue;
-
- if (UseCycleCheck &&
- pairWillFormCycle(IJ, PairableInstUserMap, ChosenPairSet))
- continue;
-
- DenseMap<ValuePair, size_t> DAG;
- buildInitialDAGFor(CandidatePairs, CandidatePairsSet,
- PairableInsts, ConnectedPairs,
- PairableInstUsers, ChosenPairs, DAG, IJ);
-
- // Because we'll keep the child with the largest depth, the largest
- // depth is still the same in the unpruned DAG.
- size_t MaxDepth = DAG.lookup(IJ);
-
- DEBUG(if (DebugPairSelection) dbgs() << "BBV: found DAG for pair {"
- << *IJ.first << " <-> " << *IJ.second << "} of depth " <<
- MaxDepth << " and size " << DAG.size() << "\n");
-
- // At this point the DAG has been constructed, but may contain
- // contradictory children (meaning that different children of
- // some dag node may be attempting to fuse the same instruction).
- // So now we walk the dag again; in the case of a conflict, we
- // keep only the child with the largest depth. To break a tie,
- // favor the first child.
- - DenseSet<ValuePair> PrunedDAG; - pruneDAGFor(CandidatePairs, PairableInsts, ConnectedPairs, - PairableInstUsers, PairableInstUserMap, - PairableInstUserPairSet, - ChosenPairs, DAG, PrunedDAG, IJ, UseCycleCheck); - - int EffSize = 0; - if (TTI) { - DenseSet<Value *> PrunedDAGInstrs; - for (DenseSet<ValuePair>::iterator S = PrunedDAG.begin(), - E = PrunedDAG.end(); S != E; ++S) { - PrunedDAGInstrs.insert(S->first); - PrunedDAGInstrs.insert(S->second); - } - - // The set of pairs that have already contributed to the total cost. - DenseSet<ValuePair> IncomingPairs; - - // If the cost model were perfect, this might not be necessary; but we - // need to make sure that we don't get stuck vectorizing our own - // shuffle chains. - bool HasNontrivialInsts = false; - - // The node weights represent the cost savings associated with - // fusing the pair of instructions. - for (DenseSet<ValuePair>::iterator S = PrunedDAG.begin(), - E = PrunedDAG.end(); S != E; ++S) { - if (!isa<ShuffleVectorInst>(S->first) && - !isa<InsertElementInst>(S->first) && - !isa<ExtractElementInst>(S->first)) - HasNontrivialInsts = true; - - bool FlipOrder = false; - - if (getDepthFactor(S->first)) { - int ESContrib = CandidatePairCostSavings.find(*S)->second; - DEBUG(if (DebugPairSelection) dbgs() << "\tweight {" - << *S->first << " <-> " << *S->second << "} = " << - ESContrib << "\n"); - EffSize += ESContrib; - } - - // The edge weights contribute in a negative sense: they represent - // the cost of shuffles. - DenseMap<ValuePair, std::vector<ValuePair> >::iterator SS = - ConnectedPairDeps.find(*S); - if (SS != ConnectedPairDeps.end()) { - unsigned NumDepsDirect = 0, NumDepsSwap = 0; - for (std::vector<ValuePair>::iterator T = SS->second.begin(), - TE = SS->second.end(); T != TE; ++T) { - VPPair Q(*S, *T); - if (!PrunedDAG.count(Q.second)) - continue; - DenseMap<VPPair, unsigned>::iterator R = - PairConnectionTypes.find(VPPair(Q.second, Q.first)); - assert(R != PairConnectionTypes.end() && - "Cannot find pair connection type"); - if (R->second == PairConnectionDirect) - ++NumDepsDirect; - else if (R->second == PairConnectionSwap) - ++NumDepsSwap; - } - - // If there are more swaps than direct connections, then - // the pair order will be flipped during fusion. So the real - // number of swaps is the minimum number. 
- FlipOrder = !FixedOrderPairs.count(*S) && - ((NumDepsSwap > NumDepsDirect) || - FixedOrderPairs.count(ValuePair(S->second, S->first))); - - for (std::vector<ValuePair>::iterator T = SS->second.begin(), - TE = SS->second.end(); T != TE; ++T) { - VPPair Q(*S, *T); - if (!PrunedDAG.count(Q.second)) - continue; - DenseMap<VPPair, unsigned>::iterator R = - PairConnectionTypes.find(VPPair(Q.second, Q.first)); - assert(R != PairConnectionTypes.end() && - "Cannot find pair connection type"); - Type *Ty1 = Q.second.first->getType(), - *Ty2 = Q.second.second->getType(); - Type *VTy = getVecTypeForPair(Ty1, Ty2); - if ((R->second == PairConnectionDirect && FlipOrder) || - (R->second == PairConnectionSwap && !FlipOrder) || - R->second == PairConnectionSplat) { - int ESContrib = (int) getInstrCost(Instruction::ShuffleVector, - VTy, VTy); - - if (VTy->getVectorNumElements() == 2) { - if (R->second == PairConnectionSplat) - ESContrib = std::min(ESContrib, (int) TTI->getShuffleCost( - TargetTransformInfo::SK_Broadcast, VTy)); - else - ESContrib = std::min(ESContrib, (int) TTI->getShuffleCost( - TargetTransformInfo::SK_Reverse, VTy)); - } - - DEBUG(if (DebugPairSelection) dbgs() << "\tcost {" << - *Q.second.first << " <-> " << *Q.second.second << - "} -> {" << - *S->first << " <-> " << *S->second << "} = " << - ESContrib << "\n"); - EffSize -= ESContrib; - } - } - } - - // Compute the cost of outgoing edges. We assume that edges outgoing - // to shuffles, inserts or extracts can be merged, and so contribute - // no additional cost. - if (!S->first->getType()->isVoidTy()) { - Type *Ty1 = S->first->getType(), - *Ty2 = S->second->getType(); - Type *VTy = getVecTypeForPair(Ty1, Ty2); - - bool NeedsExtraction = false; - for (User *U : S->first->users()) { - if (ShuffleVectorInst *SI = dyn_cast<ShuffleVectorInst>(U)) { - // Shuffle can be folded if it has no other input - if (isa<UndefValue>(SI->getOperand(1))) - continue; - } - if (isa<ExtractElementInst>(U)) - continue; - if (PrunedDAGInstrs.count(U)) - continue; - NeedsExtraction = true; - break; - } - - if (NeedsExtraction) { - int ESContrib; - if (Ty1->isVectorTy()) { - ESContrib = (int) getInstrCost(Instruction::ShuffleVector, - Ty1, VTy); - ESContrib = std::min(ESContrib, (int) TTI->getShuffleCost( - TargetTransformInfo::SK_ExtractSubvector, VTy, 0, Ty1)); - } else - ESContrib = (int) TTI->getVectorInstrCost( - Instruction::ExtractElement, VTy, 0); - - DEBUG(if (DebugPairSelection) dbgs() << "\tcost {" << - *S->first << "} = " << ESContrib << "\n"); - EffSize -= ESContrib; - } - - NeedsExtraction = false; - for (User *U : S->second->users()) { - if (ShuffleVectorInst *SI = dyn_cast<ShuffleVectorInst>(U)) { - // Shuffle can be folded if it has no other input - if (isa<UndefValue>(SI->getOperand(1))) - continue; - } - if (isa<ExtractElementInst>(U)) - continue; - if (PrunedDAGInstrs.count(U)) - continue; - NeedsExtraction = true; - break; - } - - if (NeedsExtraction) { - int ESContrib; - if (Ty2->isVectorTy()) { - ESContrib = (int) getInstrCost(Instruction::ShuffleVector, - Ty2, VTy); - ESContrib = std::min(ESContrib, (int) TTI->getShuffleCost( - TargetTransformInfo::SK_ExtractSubvector, VTy, - Ty1->isVectorTy() ? Ty1->getVectorNumElements() : 1, Ty2)); - } else - ESContrib = (int) TTI->getVectorInstrCost( - Instruction::ExtractElement, VTy, 1); - DEBUG(if (DebugPairSelection) dbgs() << "\tcost {" << - *S->second << "} = " << ESContrib << "\n"); - EffSize -= ESContrib; - } - } - - // Compute the cost of incoming edges. 
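At this point the bookkeeping is simple arithmetic: node weights add savings, edge shuffles subtract cost. A minimal sketch of the effective-size accounting (all names hypothetical):

#include <cstdio>
#include <vector>

// The DAG's effective size is the sum of per-pair cost savings (node
// weights) minus the shuffle costs charged on its edges; the DAG is only
// worth vectorizing if this stays positive.
int effectiveSize(const std::vector<int> &NodeSavings,
                  const std::vector<int> &EdgeShuffleCosts) {
  int EffSize = 0;
  for (int S : NodeSavings) EffSize += S;
  for (int C : EdgeShuffleCosts) EffSize -= C;
  return EffSize;
}

int main() {
  // Two fused pairs saving 2 each, one connecting shuffle costing 1.
  std::printf("%d\n", effectiveSize({2, 2}, {1})); // 3 -> profitable
}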
- if (!isa<LoadInst>(S->first) && !isa<StoreInst>(S->first)) {
- Instruction *S1 = cast<Instruction>(S->first),
- *S2 = cast<Instruction>(S->second);
- for (unsigned o = 0; o < S1->getNumOperands(); ++o) {
- Value *O1 = S1->getOperand(o), *O2 = S2->getOperand(o);
-
- // Combining constants into vector constants (or small vector
- // constants into larger ones) is assumed free.
- if (isa<Constant>(O1) && isa<Constant>(O2))
- continue;
-
- if (FlipOrder)
- std::swap(O1, O2);
-
- ValuePair VP = ValuePair(O1, O2);
- ValuePair VPR = ValuePair(O2, O1);
-
- // Internal edges are not handled here.
- if (PrunedDAG.count(VP) || PrunedDAG.count(VPR))
- continue;
-
- Type *Ty1 = O1->getType(),
- *Ty2 = O2->getType();
- Type *VTy = getVecTypeForPair(Ty1, Ty2);
-
- // Combining vector operations of the same type is also assumed
- // folded with other operations.
- if (Ty1 == Ty2) {
- // If both are insert elements, then both can be widened.
- InsertElementInst *IEO1 = dyn_cast<InsertElementInst>(O1),
- *IEO2 = dyn_cast<InsertElementInst>(O2);
- if (IEO1 && IEO2 && isPureIEChain(IEO1) && isPureIEChain(IEO2))
- continue;
- // If both are extract elements, and both have the same input
- // type, then they can be replaced with a shuffle
- ExtractElementInst *EIO1 = dyn_cast<ExtractElementInst>(O1),
- *EIO2 = dyn_cast<ExtractElementInst>(O2);
- if (EIO1 && EIO2 &&
- EIO1->getOperand(0)->getType() ==
- EIO2->getOperand(0)->getType())
- continue;
- // If both are a shuffle with equal operand types and only two
- // unique operands, then they can be replaced with a single
- // shuffle
- ShuffleVectorInst *SIO1 = dyn_cast<ShuffleVectorInst>(O1),
- *SIO2 = dyn_cast<ShuffleVectorInst>(O2);
- if (SIO1 && SIO2 &&
- SIO1->getOperand(0)->getType() ==
- SIO2->getOperand(0)->getType()) {
- SmallSet<Value *, 4> SIOps;
- SIOps.insert(SIO1->getOperand(0));
- SIOps.insert(SIO1->getOperand(1));
- SIOps.insert(SIO2->getOperand(0));
- SIOps.insert(SIO2->getOperand(1));
- if (SIOps.size() <= 2)
- continue;
- }
- }
-
- int ESContrib;
- // This pair has already been formed.
- if (IncomingPairs.count(VP)) {
- continue;
- } else if (IncomingPairs.count(VPR)) {
- ESContrib = (int) getInstrCost(Instruction::ShuffleVector,
- VTy, VTy);
-
- if (VTy->getVectorNumElements() == 2)
- ESContrib = std::min(ESContrib, (int) TTI->getShuffleCost(
- TargetTransformInfo::SK_Reverse, VTy));
- } else if (!Ty1->isVectorTy() && !Ty2->isVectorTy()) {
- ESContrib = (int) TTI->getVectorInstrCost(
- Instruction::InsertElement, VTy, 0);
- ESContrib += (int) TTI->getVectorInstrCost(
- Instruction::InsertElement, VTy, 1);
- } else if (!Ty1->isVectorTy()) {
- // O1 needs to be inserted into a vector of size O2, and then
- // both need to be shuffled together.
- ESContrib = (int) TTI->getVectorInstrCost(
- Instruction::InsertElement, Ty2, 0);
- ESContrib += (int) getInstrCost(Instruction::ShuffleVector,
- VTy, Ty2);
- } else if (!Ty2->isVectorTy()) {
- // O2 needs to be inserted into a vector of size O1, and then
- // both need to be shuffled together.
- ESContrib = (int) TTI->getVectorInstrCost( - Instruction::InsertElement, Ty1, 0); - ESContrib += (int) getInstrCost(Instruction::ShuffleVector, - VTy, Ty1); - } else { - Type *TyBig = Ty1, *TySmall = Ty2; - if (Ty2->getVectorNumElements() > Ty1->getVectorNumElements()) - std::swap(TyBig, TySmall); - - ESContrib = (int) getInstrCost(Instruction::ShuffleVector, - VTy, TyBig); - if (TyBig != TySmall) - ESContrib += (int) getInstrCost(Instruction::ShuffleVector, - TyBig, TySmall); - } - - DEBUG(if (DebugPairSelection) dbgs() << "\tcost {" - << *O1 << " <-> " << *O2 << "} = " << - ESContrib << "\n"); - EffSize -= ESContrib; - IncomingPairs.insert(VP); - } - } - } - - if (!HasNontrivialInsts) { - DEBUG(if (DebugPairSelection) dbgs() << - "\tNo non-trivial instructions in DAG;" - " override to zero effective size\n"); - EffSize = 0; - } - } else { - for (DenseSet<ValuePair>::iterator S = PrunedDAG.begin(), - E = PrunedDAG.end(); S != E; ++S) - EffSize += (int) getDepthFactor(S->first); - } - - DEBUG(if (DebugPairSelection) - dbgs() << "BBV: found pruned DAG for pair {" - << *IJ.first << " <-> " << *IJ.second << "} of depth " << - MaxDepth << " and size " << PrunedDAG.size() << - " (effective size: " << EffSize << ")\n"); - if (((TTI && !UseChainDepthWithTI) || - MaxDepth >= Config.ReqChainDepth) && - EffSize > 0 && EffSize > BestEffSize) { - BestMaxDepth = MaxDepth; - BestEffSize = EffSize; - BestDAG = PrunedDAG; - } - } - } - - // Given the list of candidate pairs, this function selects those - // that will be fused into vector instructions. - void BBVectorize::choosePairs( - DenseMap<Value *, std::vector<Value *> > &CandidatePairs, - DenseSet<ValuePair> &CandidatePairsSet, - DenseMap<ValuePair, int> &CandidatePairCostSavings, - std::vector<Value *> &PairableInsts, - DenseSet<ValuePair> &FixedOrderPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairDeps, - DenseSet<ValuePair> &PairableInstUsers, - DenseMap<Value *, Value *>& ChosenPairs) { - bool UseCycleCheck = - CandidatePairsSet.size() <= Config.MaxCandPairsForCycleCheck; - - DenseMap<Value *, std::vector<Value *> > CandidatePairs2; - for (DenseSet<ValuePair>::iterator I = CandidatePairsSet.begin(), - E = CandidatePairsSet.end(); I != E; ++I) { - std::vector<Value *> &JJ = CandidatePairs2[I->second]; - if (JJ.empty()) JJ.reserve(32); - JJ.push_back(I->first); - } - - DenseMap<ValuePair, std::vector<ValuePair> > PairableInstUserMap; - DenseSet<VPPair> PairableInstUserPairSet; - for (std::vector<Value *>::iterator I = PairableInsts.begin(), - E = PairableInsts.end(); I != E; ++I) { - // The number of possible pairings for this variable: - size_t NumChoices = CandidatePairs.lookup(*I).size(); - if (!NumChoices) continue; - - std::vector<Value *> &JJ = CandidatePairs[*I]; - - // The best pair to choose and its dag: - size_t BestMaxDepth = 0; - int BestEffSize = 0; - DenseSet<ValuePair> BestDAG; - findBestDAGFor(CandidatePairs, CandidatePairsSet, - CandidatePairCostSavings, - PairableInsts, FixedOrderPairs, PairConnectionTypes, - ConnectedPairs, ConnectedPairDeps, - PairableInstUsers, PairableInstUserMap, - PairableInstUserPairSet, ChosenPairs, - BestDAG, BestMaxDepth, BestEffSize, *I, JJ, - UseCycleCheck); - - if (BestDAG.empty()) - continue; - - // A dag has been chosen (or not) at this point. If no dag was - // chosen, then this instruction, I, cannot be paired (and is no longer - // considered). 
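The acceptance test for a pruned DAG can be shown in isolation. A rough sketch of the update rule (all names hypothetical; the real condition also honors a chain-depth override flag when a cost model is present):

#include <cstdio>

// The incumbent best DAG is replaced only if the candidate is deep enough
// (or a cost model vouches for it) and strictly improves the effective size.
struct Best { unsigned MaxDepth = 0; int EffSize = 0; };

bool consider(Best &B, unsigned MaxDepth, int EffSize, bool HasCostModel,
              unsigned ReqChainDepth) {
  if (!(HasCostModel || MaxDepth >= ReqChainDepth)) return false;
  if (EffSize <= 0 || EffSize <= B.EffSize) return false;
  B = {MaxDepth, EffSize};
  return true;
}

int main() {
  Best B;
  std::printf("%d\n", consider(B, 3, 4, false, 6)); // 0: chain too shallow
  std::printf("%d\n", consider(B, 8, 4, false, 6)); // 1: accepted
}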
- - DEBUG(dbgs() << "BBV: selected pairs in the best DAG for: " - << *cast<Instruction>(*I) << "\n"); - - for (DenseSet<ValuePair>::iterator S = BestDAG.begin(), - SE2 = BestDAG.end(); S != SE2; ++S) { - // Insert the members of this dag into the list of chosen pairs. - ChosenPairs.insert(ValuePair(S->first, S->second)); - DEBUG(dbgs() << "BBV: selected pair: " << *S->first << " <-> " << - *S->second << "\n"); - - // Remove all candidate pairs that have values in the chosen dag. - std::vector<Value *> &KK = CandidatePairs[S->first]; - for (std::vector<Value *>::iterator K = KK.begin(), KE = KK.end(); - K != KE; ++K) { - if (*K == S->second) - continue; - - CandidatePairsSet.erase(ValuePair(S->first, *K)); - } - - std::vector<Value *> &LL = CandidatePairs2[S->second]; - for (std::vector<Value *>::iterator L = LL.begin(), LE = LL.end(); - L != LE; ++L) { - if (*L == S->first) - continue; - - CandidatePairsSet.erase(ValuePair(*L, S->second)); - } - - std::vector<Value *> &MM = CandidatePairs[S->second]; - for (std::vector<Value *>::iterator M = MM.begin(), ME = MM.end(); - M != ME; ++M) { - assert(*M != S->first && "Flipped pair in candidate list?"); - CandidatePairsSet.erase(ValuePair(S->second, *M)); - } - - std::vector<Value *> &NN = CandidatePairs2[S->first]; - for (std::vector<Value *>::iterator N = NN.begin(), NE = NN.end(); - N != NE; ++N) { - assert(*N != S->second && "Flipped pair in candidate list?"); - CandidatePairsSet.erase(ValuePair(*N, S->first)); - } - } - } - - DEBUG(dbgs() << "BBV: selected " << ChosenPairs.size() << " pairs.\n"); - } - - std::string getReplacementName(Instruction *I, bool IsInput, unsigned o, - unsigned n = 0) { - if (!I->hasName()) - return ""; - - return (I->getName() + (IsInput ? ".v.i" : ".v.r") + utostr(o) + - (n > 0 ? "." + utostr(n) : "")).str(); - } - - // Returns the value that is to be used as the pointer input to the vector - // instruction that fuses I with J. - Value *BBVectorize::getReplacementPointerInput(LLVMContext& Context, - Instruction *I, Instruction *J, unsigned o) { - Value *IPtr, *JPtr; - unsigned IAlignment, JAlignment, IAddressSpace, JAddressSpace; - int64_t OffsetInElmts; - - // Note: the analysis might fail here, that is why the pair order has - // been precomputed (OffsetInElmts must be unused here). - (void) getPairPtrInfo(I, J, IPtr, JPtr, IAlignment, JAlignment, - IAddressSpace, JAddressSpace, - OffsetInElmts, false); - - // The pointer value is taken to be the one with the lowest offset. 
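The cleanup after committing a pair is worth seeing on its own: every remaining candidate that shares a member with the chosen pair is dropped, so no instruction is fused twice. A standalone sketch (all names hypothetical):

#include <cstdio>
#include <iterator>
#include <set>
#include <utility>

using ValuePair = std::pair<int, int>;

// Erase every candidate (other than Chosen itself) containing a member of
// the chosen pair.
void eraseOverlapping(std::set<ValuePair> &Candidates, ValuePair Chosen) {
  for (auto It = Candidates.begin(); It != Candidates.end();) {
    bool Overlaps = (*It != Chosen) &&
                    (It->first == Chosen.first || It->first == Chosen.second ||
                     It->second == Chosen.first || It->second == Chosen.second);
    It = Overlaps ? Candidates.erase(It) : std::next(It);
  }
}

int main() {
  std::set<ValuePair> C = {{1, 2}, {1, 3}, {4, 2}, {5, 6}};
  eraseOverlapping(C, {1, 2});
  std::printf("%zu left\n", C.size()); // 2: {1,2} itself and {5,6}
}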
- Value *VPtr = IPtr;
-
- Type *ArgTypeI = IPtr->getType()->getPointerElementType();
- Type *ArgTypeJ = JPtr->getType()->getPointerElementType();
- Type *VArgType = getVecTypeForPair(ArgTypeI, ArgTypeJ);
- Type *VArgPtrType
- = PointerType::get(VArgType,
- IPtr->getType()->getPointerAddressSpace());
- return new BitCastInst(VPtr, VArgPtrType, getReplacementName(I, true, o),
- /* insert before */ I);
- }
-
- void BBVectorize::fillNewShuffleMask(LLVMContext& Context, Instruction *J,
- unsigned MaskOffset, unsigned NumInElem,
- unsigned NumInElem1, unsigned IdxOffset,
- std::vector<Constant*> &Mask) {
- unsigned NumElem1 = J->getType()->getVectorNumElements();
- for (unsigned v = 0; v < NumElem1; ++v) {
- int m = cast<ShuffleVectorInst>(J)->getMaskValue(v);
- if (m < 0) {
- Mask[v+MaskOffset] = UndefValue::get(Type::getInt32Ty(Context));
- } else {
- unsigned mm = m + (int) IdxOffset;
- if (m >= (int) NumInElem1)
- mm += (int) NumInElem;
-
- Mask[v+MaskOffset] =
- ConstantInt::get(Type::getInt32Ty(Context), mm);
- }
- }
- }
-
- // Returns the value that is to be used as the vector-shuffle mask to the
- // vector instruction that fuses I with J.
- Value *BBVectorize::getReplacementShuffleMask(LLVMContext& Context,
- Instruction *I, Instruction *J) {
- // This is the shuffle mask. We need to append the second
- // mask to the first, and the numbers need to be adjusted.
-
- Type *ArgTypeI = I->getType();
- Type *ArgTypeJ = J->getType();
- Type *VArgType = getVecTypeForPair(ArgTypeI, ArgTypeJ);
-
- unsigned NumElemI = ArgTypeI->getVectorNumElements();
-
- // Get the total number of elements in the fused vector type.
- // By definition, this must equal the number of elements in
- // the final mask.
- unsigned NumElem = VArgType->getVectorNumElements();
- std::vector<Constant*> Mask(NumElem);
-
- Type *OpTypeI = I->getOperand(0)->getType();
- unsigned NumInElemI = OpTypeI->getVectorNumElements();
- Type *OpTypeJ = J->getOperand(0)->getType();
- unsigned NumInElemJ = OpTypeJ->getVectorNumElements();
-
- // The fused vector will be:
- // -----------------------------------------------------
- // | NumInElemI | NumInElemJ | NumInElemI | NumInElemJ |
- // -----------------------------------------------------
- // from which we'll extract NumElem total elements (where the first NumElemI
- // of them come from the mask in I and the remainder come from the mask
- // in J).
-
- // For the mask from the first pair...
- fillNewShuffleMask(Context, I, 0, NumInElemJ, NumInElemI,
- 0, Mask);
-
- // For the mask from the second pair...
- fillNewShuffleMask(Context, J, NumElemI, NumInElemI, NumInElemJ,
- NumInElemI, Mask);
-
- return ConstantVector::get(Mask);
- }
-
- bool BBVectorize::expandIEChain(LLVMContext& Context, Instruction *I,
- Instruction *J, unsigned o, Value *&LOp,
- unsigned numElemL,
- Type *ArgTypeL, Type *ArgTypeH,
- bool IBeforeJ, unsigned IdxOff) {
- bool ExpandedIEChain = false;
- if (InsertElementInst *LIE = dyn_cast<InsertElementInst>(LOp)) {
- // If we have a pure insertelement chain, then this can be rewritten
- // into a chain that directly builds the larger type.
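The mask arithmetic in fillNewShuffleMask and getReplacementShuffleMask above is easy to get wrong, so here is a standalone sketch of the same index adjustment, using -1 for undef lanes (all names hypothetical):

#include <cstdio>
#include <vector>

// The fused inputs are laid out as |I op0|J op0|I op1|J op1|, so J's mask
// entries shift by NumInElemI, and any entry that referred to a second
// input additionally skips over the other pair member's lanes.
std::vector<int> fuseMasks(const std::vector<int> &MaskI,
                           const std::vector<int> &MaskJ,
                           int NumInElemI, int NumInElemJ) {
  std::vector<int> Fused;
  for (int m : MaskI)
    Fused.push_back(m < 0 ? -1 : (m < NumInElemI ? m : m + NumInElemJ));
  for (int m : MaskJ)
    Fused.push_back(m < 0 ? -1
                          : (m < NumInElemJ ? m + NumInElemI
                                            : m + 2 * NumInElemI));
  return Fused;
}

int main() {
  // I and J each shuffle two 2-wide inputs: fused layout |I0|J0|I1|J1|.
  auto F = fuseMasks({0, 2}, {0, 2}, 2, 2);
  for (int m : F) std::printf("%d ", m); // 0 4 2 6
  std::printf("\n");
}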
- if (isPureIEChain(LIE)) { - SmallVector<Value *, 8> VectElemts(numElemL, - UndefValue::get(ArgTypeL->getScalarType())); - InsertElementInst *LIENext = LIE; - do { - unsigned Idx = - cast<ConstantInt>(LIENext->getOperand(2))->getSExtValue(); - VectElemts[Idx] = LIENext->getOperand(1); - } while ((LIENext = - dyn_cast<InsertElementInst>(LIENext->getOperand(0)))); - - LIENext = nullptr; - Value *LIEPrev = UndefValue::get(ArgTypeH); - for (unsigned i = 0; i < numElemL; ++i) { - if (isa<UndefValue>(VectElemts[i])) continue; - LIENext = InsertElementInst::Create(LIEPrev, VectElemts[i], - ConstantInt::get(Type::getInt32Ty(Context), - i + IdxOff), - getReplacementName(IBeforeJ ? I : J, - true, o, i+1)); - LIENext->insertBefore(IBeforeJ ? J : I); - LIEPrev = LIENext; - } - - LOp = LIENext ? (Value*) LIENext : UndefValue::get(ArgTypeH); - ExpandedIEChain = true; - } - } - - return ExpandedIEChain; - } - - static unsigned getNumScalarElements(Type *Ty) { - if (VectorType *VecTy = dyn_cast<VectorType>(Ty)) - return VecTy->getNumElements(); - return 1; - } - - // Returns the value to be used as the specified operand of the vector - // instruction that fuses I with J. - Value *BBVectorize::getReplacementInput(LLVMContext& Context, Instruction *I, - Instruction *J, unsigned o, bool IBeforeJ) { - Value *CV0 = ConstantInt::get(Type::getInt32Ty(Context), 0); - Value *CV1 = ConstantInt::get(Type::getInt32Ty(Context), 1); - - // Compute the fused vector type for this operand - Type *ArgTypeI = I->getOperand(o)->getType(); - Type *ArgTypeJ = J->getOperand(o)->getType(); - VectorType *VArgType = getVecTypeForPair(ArgTypeI, ArgTypeJ); - - Instruction *L = I, *H = J; - Type *ArgTypeL = ArgTypeI, *ArgTypeH = ArgTypeJ; - - unsigned numElemL = getNumScalarElements(ArgTypeL); - unsigned numElemH = getNumScalarElements(ArgTypeH); - - Value *LOp = L->getOperand(o); - Value *HOp = H->getOperand(o); - unsigned numElem = VArgType->getNumElements(); - - // First, we check if we can reuse the "original" vector outputs (if these - // exist). We might need a shuffle. - ExtractElementInst *LEE = dyn_cast<ExtractElementInst>(LOp); - ExtractElementInst *HEE = dyn_cast<ExtractElementInst>(HOp); - ShuffleVectorInst *LSV = dyn_cast<ShuffleVectorInst>(LOp); - ShuffleVectorInst *HSV = dyn_cast<ShuffleVectorInst>(HOp); - - // FIXME: If we're fusing shuffle instructions, then we can't apply this - // optimization. The input vectors to the shuffle might be a different - // length from the shuffle outputs. Unfortunately, the replacement - // shuffle mask has already been formed, and the mask entries are sensitive - // to the sizes of the inputs. - bool IsSizeChangeShuffle = - isa<ShuffleVectorInst>(L) && - (LOp->getType() != L->getType() || HOp->getType() != H->getType()); - - if ((LEE || LSV) && (HEE || HSV) && !IsSizeChangeShuffle) { - // We can have at most two unique vector inputs. 
- bool CanUseInputs = true; - Value *I1, *I2 = nullptr; - if (LEE) { - I1 = LEE->getOperand(0); - } else { - I1 = LSV->getOperand(0); - I2 = LSV->getOperand(1); - if (I2 == I1 || isa<UndefValue>(I2)) - I2 = nullptr; - } - - if (HEE) { - Value *I3 = HEE->getOperand(0); - if (!I2 && I3 != I1) - I2 = I3; - else if (I3 != I1 && I3 != I2) - CanUseInputs = false; - } else { - Value *I3 = HSV->getOperand(0); - if (!I2 && I3 != I1) - I2 = I3; - else if (I3 != I1 && I3 != I2) - CanUseInputs = false; - - if (CanUseInputs) { - Value *I4 = HSV->getOperand(1); - if (!isa<UndefValue>(I4)) { - if (!I2 && I4 != I1) - I2 = I4; - else if (I4 != I1 && I4 != I2) - CanUseInputs = false; - } - } - } - - if (CanUseInputs) { - unsigned LOpElem = - cast<Instruction>(LOp)->getOperand(0)->getType() - ->getVectorNumElements(); - - unsigned HOpElem = - cast<Instruction>(HOp)->getOperand(0)->getType() - ->getVectorNumElements(); - - // We have one or two input vectors. We need to map each index of the - // operands to the index of the original vector. - SmallVector<std::pair<int, int>, 8> II(numElem); - for (unsigned i = 0; i < numElemL; ++i) { - int Idx, INum; - if (LEE) { - Idx = - cast<ConstantInt>(LEE->getOperand(1))->getSExtValue(); - INum = LEE->getOperand(0) == I1 ? 0 : 1; - } else { - Idx = LSV->getMaskValue(i); - if (Idx < (int) LOpElem) { - INum = LSV->getOperand(0) == I1 ? 0 : 1; - } else { - Idx -= LOpElem; - INum = LSV->getOperand(1) == I1 ? 0 : 1; - } - } - - II[i] = std::pair<int, int>(Idx, INum); - } - for (unsigned i = 0; i < numElemH; ++i) { - int Idx, INum; - if (HEE) { - Idx = - cast<ConstantInt>(HEE->getOperand(1))->getSExtValue(); - INum = HEE->getOperand(0) == I1 ? 0 : 1; - } else { - Idx = HSV->getMaskValue(i); - if (Idx < (int) HOpElem) { - INum = HSV->getOperand(0) == I1 ? 0 : 1; - } else { - Idx -= HOpElem; - INum = HSV->getOperand(1) == I1 ? 0 : 1; - } - } - - II[i + numElemL] = std::pair<int, int>(Idx, INum); - } - - // We now have an array which tells us from which index of which - // input vector each element of the operand comes. - VectorType *I1T = cast<VectorType>(I1->getType()); - unsigned I1Elem = I1T->getNumElements(); - - if (!I2) { - // In this case there is only one underlying vector input. Check for - // the trivial case where we can use the input directly. - if (I1Elem == numElem) { - bool ElemInOrder = true; - for (unsigned i = 0; i < numElem; ++i) { - if (II[i].first != (int) i && II[i].first != -1) { - ElemInOrder = false; - break; - } - } - - if (ElemInOrder) - return I1; - } - - // A shuffle is needed. - std::vector<Constant *> Mask(numElem); - for (unsigned i = 0; i < numElem; ++i) { - int Idx = II[i].first; - if (Idx == -1) - Mask[i] = UndefValue::get(Type::getInt32Ty(Context)); - else - Mask[i] = ConstantInt::get(Type::getInt32Ty(Context), Idx); - } - - Instruction *S = - new ShuffleVectorInst(I1, UndefValue::get(I1T), - ConstantVector::get(Mask), - getReplacementName(IBeforeJ ? I : J, - true, o)); - S->insertBefore(IBeforeJ ? J : I); - return S; - } - - VectorType *I2T = cast<VectorType>(I2->getType()); - unsigned I2Elem = I2T->getNumElements(); - - // This input comes from two distinct vectors. The first step is to - // make sure that both vectors are the same length. If not, the - // smaller one will need to grow before they can be shuffled together. 
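The widening step described above is a single identity-plus-undef mask. A minimal sketch (names hypothetical, -1 meaning undef):

#include <cstdio>
#include <vector>

// Extend a Small-wide vector to Big lanes with an identity mask padded by
// undefs, so both shuffle inputs end up the same length.
std::vector<int> growMask(unsigned Small, unsigned Big) {
  std::vector<int> Mask(Big, -1);
  for (unsigned v = 0; v < Small; ++v) Mask[v] = (int)v;
  return Mask;
}

int main() {
  auto M = growMask(2, 4);
  for (int m : M) std::printf("%d ", m); // 0 1 -1 -1
  std::printf("\n");
}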
- if (I1Elem < I2Elem) { - std::vector<Constant *> Mask(I2Elem); - unsigned v = 0; - for (; v < I1Elem; ++v) - Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v); - for (; v < I2Elem; ++v) - Mask[v] = UndefValue::get(Type::getInt32Ty(Context)); - - Instruction *NewI1 = - new ShuffleVectorInst(I1, UndefValue::get(I1T), - ConstantVector::get(Mask), - getReplacementName(IBeforeJ ? I : J, - true, o, 1)); - NewI1->insertBefore(IBeforeJ ? J : I); - I1 = NewI1; - I1Elem = I2Elem; - } else if (I1Elem > I2Elem) { - std::vector<Constant *> Mask(I1Elem); - unsigned v = 0; - for (; v < I2Elem; ++v) - Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v); - for (; v < I1Elem; ++v) - Mask[v] = UndefValue::get(Type::getInt32Ty(Context)); - - Instruction *NewI2 = - new ShuffleVectorInst(I2, UndefValue::get(I2T), - ConstantVector::get(Mask), - getReplacementName(IBeforeJ ? I : J, - true, o, 1)); - NewI2->insertBefore(IBeforeJ ? J : I); - I2 = NewI2; - } - - // Now that both I1 and I2 are the same length we can shuffle them - // together (and use the result). - std::vector<Constant *> Mask(numElem); - for (unsigned v = 0; v < numElem; ++v) { - if (II[v].first == -1) { - Mask[v] = UndefValue::get(Type::getInt32Ty(Context)); - } else { - int Idx = II[v].first + II[v].second * I1Elem; - Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), Idx); - } - } - - Instruction *NewOp = - new ShuffleVectorInst(I1, I2, ConstantVector::get(Mask), - getReplacementName(IBeforeJ ? I : J, true, o)); - NewOp->insertBefore(IBeforeJ ? J : I); - return NewOp; - } - } - - Type *ArgType = ArgTypeL; - if (numElemL < numElemH) { - if (numElemL == 1 && expandIEChain(Context, I, J, o, HOp, numElemH, - ArgTypeL, VArgType, IBeforeJ, 1)) { - // This is another short-circuit case: we're combining a scalar into - // a vector that is formed by an IE chain. We've just expanded the IE - // chain, now insert the scalar and we're done. - - Instruction *S = InsertElementInst::Create(HOp, LOp, CV0, - getReplacementName(IBeforeJ ? I : J, true, o)); - S->insertBefore(IBeforeJ ? J : I); - return S; - } else if (!expandIEChain(Context, I, J, o, LOp, numElemL, ArgTypeL, - ArgTypeH, IBeforeJ)) { - // The two vector inputs to the shuffle must be the same length, - // so extend the smaller vector to be the same length as the larger one. - Instruction *NLOp; - if (numElemL > 1) { - - std::vector<Constant *> Mask(numElemH); - unsigned v = 0; - for (; v < numElemL; ++v) - Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v); - for (; v < numElemH; ++v) - Mask[v] = UndefValue::get(Type::getInt32Ty(Context)); - - NLOp = new ShuffleVectorInst(LOp, UndefValue::get(ArgTypeL), - ConstantVector::get(Mask), - getReplacementName(IBeforeJ ? I : J, - true, o, 1)); - } else { - NLOp = InsertElementInst::Create(UndefValue::get(ArgTypeH), LOp, CV0, - getReplacementName(IBeforeJ ? I : J, - true, o, 1)); - } - - NLOp->insertBefore(IBeforeJ ? J : I); - LOp = NLOp; - } - - ArgType = ArgTypeH; - } else if (numElemL > numElemH) { - if (numElemH == 1 && expandIEChain(Context, I, J, o, LOp, numElemL, - ArgTypeH, VArgType, IBeforeJ)) { - Instruction *S = - InsertElementInst::Create(LOp, HOp, - ConstantInt::get(Type::getInt32Ty(Context), - numElemL), - getReplacementName(IBeforeJ ? I : J, - true, o)); - S->insertBefore(IBeforeJ ? 
J : I); - return S; - } else if (!expandIEChain(Context, I, J, o, HOp, numElemH, ArgTypeH, - ArgTypeL, IBeforeJ)) { - Instruction *NHOp; - if (numElemH > 1) { - std::vector<Constant *> Mask(numElemL); - unsigned v = 0; - for (; v < numElemH; ++v) - Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), v); - for (; v < numElemL; ++v) - Mask[v] = UndefValue::get(Type::getInt32Ty(Context)); - - NHOp = new ShuffleVectorInst(HOp, UndefValue::get(ArgTypeH), - ConstantVector::get(Mask), - getReplacementName(IBeforeJ ? I : J, - true, o, 1)); - } else { - NHOp = InsertElementInst::Create(UndefValue::get(ArgTypeL), HOp, CV0, - getReplacementName(IBeforeJ ? I : J, - true, o, 1)); - } - - NHOp->insertBefore(IBeforeJ ? J : I); - HOp = NHOp; - } - } - - if (ArgType->isVectorTy()) { - unsigned numElem = VArgType->getVectorNumElements(); - std::vector<Constant*> Mask(numElem); - for (unsigned v = 0; v < numElem; ++v) { - unsigned Idx = v; - // If the low vector was expanded, we need to skip the extra - // undefined entries. - if (v >= numElemL && numElemH > numElemL) - Idx += (numElemH - numElemL); - Mask[v] = ConstantInt::get(Type::getInt32Ty(Context), Idx); - } - - Instruction *BV = new ShuffleVectorInst(LOp, HOp, - ConstantVector::get(Mask), - getReplacementName(IBeforeJ ? I : J, true, o)); - BV->insertBefore(IBeforeJ ? J : I); - return BV; - } - - Instruction *BV1 = InsertElementInst::Create( - UndefValue::get(VArgType), LOp, CV0, - getReplacementName(IBeforeJ ? I : J, - true, o, 1)); - BV1->insertBefore(IBeforeJ ? J : I); - Instruction *BV2 = InsertElementInst::Create(BV1, HOp, CV1, - getReplacementName(IBeforeJ ? I : J, - true, o, 2)); - BV2->insertBefore(IBeforeJ ? J : I); - return BV2; - } - - // This function creates an array of values that will be used as the inputs - // to the vector instruction that fuses I with J. - void BBVectorize::getReplacementInputsForPair(LLVMContext& Context, - Instruction *I, Instruction *J, - SmallVectorImpl<Value *> &ReplacedOperands, - bool IBeforeJ) { - unsigned NumOperands = I->getNumOperands(); - - for (unsigned p = 0, o = NumOperands-1; p < NumOperands; ++p, --o) { - // Iterate backward so that we look at the store pointer - // first and know whether or not we need to flip the inputs. - - if (isa<LoadInst>(I) || (o == 1 && isa<StoreInst>(I))) { - // This is the pointer for a load/store instruction. - ReplacedOperands[o] = getReplacementPointerInput(Context, I, J, o); - continue; - } else if (isa<CallInst>(I)) { - Function *F = cast<CallInst>(I)->getCalledFunction(); - Intrinsic::ID IID = F->getIntrinsicID(); - if (o == NumOperands-1) { - BasicBlock &BB = *I->getParent(); - - Module *M = BB.getParent()->getParent(); - Type *ArgTypeI = I->getType(); - Type *ArgTypeJ = J->getType(); - Type *VArgType = getVecTypeForPair(ArgTypeI, ArgTypeJ); - - ReplacedOperands[o] = Intrinsic::getDeclaration(M, IID, VArgType); - continue; - } else if ((IID == Intrinsic::powi || IID == Intrinsic::ctlz || - IID == Intrinsic::cttz) && o == 1) { - // The second argument of powi/ctlz/cttz is a single integer/constant - // and we've already checked that both arguments are equal. - // As a result, we just keep I's second argument. 
- ReplacedOperands[o] = I->getOperand(o); - continue; - } - } else if (isa<ShuffleVectorInst>(I) && o == NumOperands-1) { - ReplacedOperands[o] = getReplacementShuffleMask(Context, I, J); - continue; - } - - ReplacedOperands[o] = getReplacementInput(Context, I, J, o, IBeforeJ); - } - } - - // This function creates two values that represent the outputs of the - // original I and J instructions. These are generally vector shuffles - // or extracts. In many cases, these will end up being unused and, thus, - // eliminated by later passes. - void BBVectorize::replaceOutputsOfPair(LLVMContext& Context, Instruction *I, - Instruction *J, Instruction *K, - Instruction *&InsertionPt, - Instruction *&K1, Instruction *&K2) { - if (isa<StoreInst>(I)) - return; - - Type *IType = I->getType(); - Type *JType = J->getType(); - - VectorType *VType = getVecTypeForPair(IType, JType); - unsigned numElem = VType->getNumElements(); - - unsigned numElemI = getNumScalarElements(IType); - unsigned numElemJ = getNumScalarElements(JType); - - if (IType->isVectorTy()) { - std::vector<Constant *> Mask1(numElemI), Mask2(numElemI); - for (unsigned v = 0; v < numElemI; ++v) { - Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v); - Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemJ + v); - } - - K1 = new ShuffleVectorInst(K, UndefValue::get(VType), - ConstantVector::get(Mask1), - getReplacementName(K, false, 1)); - } else { - Value *CV0 = ConstantInt::get(Type::getInt32Ty(Context), 0); - K1 = ExtractElementInst::Create(K, CV0, getReplacementName(K, false, 1)); - } - - if (JType->isVectorTy()) { - std::vector<Constant *> Mask1(numElemJ), Mask2(numElemJ); - for (unsigned v = 0; v < numElemJ; ++v) { - Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v); - Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemI + v); - } - - K2 = new ShuffleVectorInst(K, UndefValue::get(VType), - ConstantVector::get(Mask2), - getReplacementName(K, false, 2)); - } else { - Value *CV1 = ConstantInt::get(Type::getInt32Ty(Context), numElem - 1); - K2 = ExtractElementInst::Create(K, CV1, getReplacementName(K, false, 2)); - } - - K1->insertAfter(K); - K2->insertAfter(K1); - InsertionPt = K2; - } - - // Move all uses of the function I (including pairing-induced uses) after J. - bool BBVectorize::canMoveUsesOfIAfterJ(BasicBlock &BB, - DenseSet<ValuePair> &LoadMoveSetPairs, - Instruction *I, Instruction *J) { - // Skip to the first instruction past I. - BasicBlock::iterator L = std::next(BasicBlock::iterator(I)); - - DenseSet<Value *> Users; - AliasSetTracker WriteSet(*AA); - if (I->mayWriteToMemory()) WriteSet.add(I); - - for (; cast<Instruction>(L) != J; ++L) - (void)trackUsesOfI(Users, WriteSet, I, &*L, true, &LoadMoveSetPairs); - - assert(cast<Instruction>(L) == J && - "Tracking has not proceeded far enough to check for dependencies"); - // If J is now in the use set of I, then trackUsesOfI will return true - // and we have a dependency cycle (and the fusing operation must abort). - return !trackUsesOfI(Users, WriteSet, I, J, true, &LoadMoveSetPairs); - } - - // Move all uses of the function I (including pairing-induced uses) after J. - void BBVectorize::moveUsesOfIAfterJ(BasicBlock &BB, - DenseSet<ValuePair> &LoadMoveSetPairs, - Instruction *&InsertionPt, - Instruction *I, Instruction *J) { - // Skip to the first instruction past I. 
- BasicBlock::iterator L = std::next(BasicBlock::iterator(I)); - - DenseSet<Value *> Users; - AliasSetTracker WriteSet(*AA); - if (I->mayWriteToMemory()) WriteSet.add(I); - - for (; cast<Instruction>(L) != J;) { - if (trackUsesOfI(Users, WriteSet, I, &*L, true, &LoadMoveSetPairs)) { - // Move this instruction - Instruction *InstToMove = &*L++; - - DEBUG(dbgs() << "BBV: moving: " << *InstToMove << - " to after " << *InsertionPt << "\n"); - InstToMove->removeFromParent(); - InstToMove->insertAfter(InsertionPt); - InsertionPt = InstToMove; - } else { - ++L; - } - } - } - - // Collect all load instruction that are in the move set of a given first - // pair member. These loads depend on the first instruction, I, and so need - // to be moved after J (the second instruction) when the pair is fused. - void BBVectorize::collectPairLoadMoveSet(BasicBlock &BB, - DenseMap<Value *, Value *> &ChosenPairs, - DenseMap<Value *, std::vector<Value *> > &LoadMoveSet, - DenseSet<ValuePair> &LoadMoveSetPairs, - Instruction *I) { - // Skip to the first instruction past I. - BasicBlock::iterator L = std::next(BasicBlock::iterator(I)); - - DenseSet<Value *> Users; - AliasSetTracker WriteSet(*AA); - if (I->mayWriteToMemory()) WriteSet.add(I); - - // Note: We cannot end the loop when we reach J because J could be moved - // farther down the use chain by another instruction pairing. Also, J - // could be before I if this is an inverted input. - for (BasicBlock::iterator E = BB.end(); L != E; ++L) { - if (trackUsesOfI(Users, WriteSet, I, &*L)) { - if (L->mayReadFromMemory()) { - LoadMoveSet[&*L].push_back(I); - LoadMoveSetPairs.insert(ValuePair(&*L, I)); - } - } - } - } - - // In cases where both load/stores and the computation of their pointers - // are chosen for vectorization, we can end up in a situation where the - // aliasing analysis starts returning different query results as the - // process of fusing instruction pairs continues. Because the algorithm - // relies on finding the same use dags here as were found earlier, we'll - // need to precompute the necessary aliasing information here and then - // manually update it during the fusion process. - void BBVectorize::collectLoadMoveSet(BasicBlock &BB, - std::vector<Value *> &PairableInsts, - DenseMap<Value *, Value *> &ChosenPairs, - DenseMap<Value *, std::vector<Value *> > &LoadMoveSet, - DenseSet<ValuePair> &LoadMoveSetPairs) { - for (std::vector<Value *>::iterator PI = PairableInsts.begin(), - PIE = PairableInsts.end(); PI != PIE; ++PI) { - DenseMap<Value *, Value *>::iterator P = ChosenPairs.find(*PI); - if (P == ChosenPairs.end()) continue; - - Instruction *I = cast<Instruction>(P->first); - collectPairLoadMoveSet(BB, ChosenPairs, LoadMoveSet, - LoadMoveSetPairs, I); - } - } - - // This function fuses the chosen instruction pairs into vector instructions, - // taking care preserve any needed scalar outputs and, then, it reorders the - // remaining instructions as needed (users of the first member of the pair - // need to be moved to after the location of the second member of the pair - // because the vector instruction is inserted in the location of the pair's - // second member). 
- void BBVectorize::fuseChosenPairs(BasicBlock &BB, - std::vector<Value *> &PairableInsts, - DenseMap<Value *, Value *> &ChosenPairs, - DenseSet<ValuePair> &FixedOrderPairs, - DenseMap<VPPair, unsigned> &PairConnectionTypes, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairs, - DenseMap<ValuePair, std::vector<ValuePair> > &ConnectedPairDeps) { - LLVMContext& Context = BB.getContext(); - - // During the vectorization process, the order of the pairs to be fused - // could be flipped. So we'll add each pair, flipped, into the ChosenPairs - // list. After a pair is fused, the flipped pair is removed from the list. - DenseSet<ValuePair> FlippedPairs; - for (DenseMap<Value *, Value *>::iterator P = ChosenPairs.begin(), - E = ChosenPairs.end(); P != E; ++P) - FlippedPairs.insert(ValuePair(P->second, P->first)); - for (DenseSet<ValuePair>::iterator P = FlippedPairs.begin(), - E = FlippedPairs.end(); P != E; ++P) - ChosenPairs.insert(*P); - - DenseMap<Value *, std::vector<Value *> > LoadMoveSet; - DenseSet<ValuePair> LoadMoveSetPairs; - collectLoadMoveSet(BB, PairableInsts, ChosenPairs, - LoadMoveSet, LoadMoveSetPairs); - - DEBUG(dbgs() << "BBV: initial: \n" << BB << "\n"); - - for (BasicBlock::iterator PI = BB.getFirstInsertionPt(); PI != BB.end();) { - DenseMap<Value *, Value *>::iterator P = ChosenPairs.find(&*PI); - if (P == ChosenPairs.end()) { - ++PI; - continue; - } - - if (getDepthFactor(P->first) == 0) { - // These instructions are not really fused, but are tracked as though - // they are. Any case in which it would be interesting to fuse them - // will be taken care of by InstCombine. - --NumFusedOps; - ++PI; - continue; - } - - Instruction *I = cast<Instruction>(P->first), - *J = cast<Instruction>(P->second); - - DEBUG(dbgs() << "BBV: fusing: " << *I << - " <-> " << *J << "\n"); - - // Remove the pair and flipped pair from the list. - DenseMap<Value *, Value *>::iterator FP = ChosenPairs.find(P->second); - assert(FP != ChosenPairs.end() && "Flipped pair not found in list"); - ChosenPairs.erase(FP); - ChosenPairs.erase(P); - - if (!canMoveUsesOfIAfterJ(BB, LoadMoveSetPairs, I, J)) { - DEBUG(dbgs() << "BBV: fusion of: " << *I << - " <-> " << *J << - " aborted because of non-trivial dependency cycle\n"); - --NumFusedOps; - ++PI; - continue; - } - - // If the pair must have the other order, then flip it. - bool FlipPairOrder = FixedOrderPairs.count(ValuePair(J, I)); - if (!FlipPairOrder && !FixedOrderPairs.count(ValuePair(I, J))) { - // This pair does not have a fixed order, and so we might want to - // flip it if that will yield fewer shuffles. We count the number - // of dependencies connected via swaps, and those directly connected, - // and flip the order if the number of swaps is greater. 
- bool OrigOrder = true; - DenseMap<ValuePair, std::vector<ValuePair> >::iterator IJ = - ConnectedPairDeps.find(ValuePair(I, J)); - if (IJ == ConnectedPairDeps.end()) { - IJ = ConnectedPairDeps.find(ValuePair(J, I)); - OrigOrder = false; - } - - if (IJ != ConnectedPairDeps.end()) { - unsigned NumDepsDirect = 0, NumDepsSwap = 0; - for (std::vector<ValuePair>::iterator T = IJ->second.begin(), - TE = IJ->second.end(); T != TE; ++T) { - VPPair Q(IJ->first, *T); - DenseMap<VPPair, unsigned>::iterator R = - PairConnectionTypes.find(VPPair(Q.second, Q.first)); - assert(R != PairConnectionTypes.end() && - "Cannot find pair connection type"); - if (R->second == PairConnectionDirect) - ++NumDepsDirect; - else if (R->second == PairConnectionSwap) - ++NumDepsSwap; - } - - if (!OrigOrder) - std::swap(NumDepsDirect, NumDepsSwap); - - if (NumDepsSwap > NumDepsDirect) { - FlipPairOrder = true; - DEBUG(dbgs() << "BBV: reordering pair: " << *I << - " <-> " << *J << "\n"); - } - } - } - - Instruction *L = I, *H = J; - if (FlipPairOrder) - std::swap(H, L); - - // If the pair being fused uses the opposite order from that in the pair - // connection map, then we need to flip the types. - DenseMap<ValuePair, std::vector<ValuePair> >::iterator HL = - ConnectedPairs.find(ValuePair(H, L)); - if (HL != ConnectedPairs.end()) - for (std::vector<ValuePair>::iterator T = HL->second.begin(), - TE = HL->second.end(); T != TE; ++T) { - VPPair Q(HL->first, *T); - DenseMap<VPPair, unsigned>::iterator R = PairConnectionTypes.find(Q); - assert(R != PairConnectionTypes.end() && - "Cannot find pair connection type"); - if (R->second == PairConnectionDirect) - R->second = PairConnectionSwap; - else if (R->second == PairConnectionSwap) - R->second = PairConnectionDirect; - } - - bool LBeforeH = !FlipPairOrder; - unsigned NumOperands = I->getNumOperands(); - SmallVector<Value *, 3> ReplacedOperands(NumOperands); - getReplacementInputsForPair(Context, L, H, ReplacedOperands, - LBeforeH); - - // Make a copy of the original operation, change its type to the vector - // type and replace its operands with the vector operands. - Instruction *K = L->clone(); - if (L->hasName()) - K->takeName(L); - else if (H->hasName()) - K->takeName(H); - - if (auto CS = CallSite(K)) { - SmallVector<Type *, 3> Tys; - FunctionType *Old = CS.getFunctionType(); - unsigned NumOld = Old->getNumParams(); - assert(NumOld <= ReplacedOperands.size()); - for (unsigned i = 0; i != NumOld; ++i) - Tys.push_back(ReplacedOperands[i]->getType()); - CS.mutateFunctionType( - FunctionType::get(getVecTypeForPair(L->getType(), H->getType()), - Tys, Old->isVarArg())); - } else if (!isa<StoreInst>(K)) - K->mutateType(getVecTypeForPair(L->getType(), H->getType())); - - unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_fpmath, - LLVMContext::MD_invariant_group}; - combineMetadata(K, H, KnownIDs); - K->andIRFlags(H); - - for (unsigned o = 0; o < NumOperands; ++o) - K->setOperand(o, ReplacedOperands[o]); - - K->insertAfter(J); - - // Instruction insertion point: - Instruction *InsertionPt = K; - Instruction *K1 = nullptr, *K2 = nullptr; - replaceOutputsOfPair(Context, L, H, K, InsertionPt, K1, K2); - - // The use dag of the first original instruction must be moved to after - // the location of the second instruction. The entire use dag of the - // first instruction is disjoint from the input dag of the second - // (by definition), and so commutes with it. 
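The clone-and-retype sequence above is the core of pair fusion and is easy to miss in the surrounding bookkeeping. A reduced sketch, assuming L and H are the matched scalar instructions and VecOps holds the already-widened operands (both hypothetical placeholders; getVecTypeForPair is the pass's own helper):

    Instruction *K = L->clone();          // reuse L's opcode
    K->mutateType(getVecTypeForPair(L->getType(), H->getType()));
    unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
                           LLVMContext::MD_noalias, LLVMContext::MD_fpmath};
    combineMetadata(K, H, KnownIDs);      // keep metadata valid for both halves
    K->andIRFlags(H);                     // intersect nsw/nuw/fast-math flags
    for (unsigned o = 0, e = K->getNumOperands(); o != e; ++o)
      K->setOperand(o, VecOps[o]);
    K->insertAfter(H);  // simplified; the real code inserts after J, the
                        // second original in block order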
- - moveUsesOfIAfterJ(BB, LoadMoveSetPairs, InsertionPt, I, J); - - if (!isa<StoreInst>(I)) { - L->replaceAllUsesWith(K1); - H->replaceAllUsesWith(K2); - } - - // Instructions that may read from memory may be in the load move set. - // Once an instruction is fused, we no longer need its move set, and so - // the values of the map never need to be updated. However, when a load - // is fused, we need to merge the entries from both instructions in the - // pair in case those instructions were in the move set of some other - // yet-to-be-fused pair. The loads in question are the keys of the map. - if (I->mayReadFromMemory()) { - std::vector<ValuePair> NewSetMembers; - DenseMap<Value *, std::vector<Value *> >::iterator II = - LoadMoveSet.find(I); - if (II != LoadMoveSet.end()) - for (std::vector<Value *>::iterator N = II->second.begin(), - NE = II->second.end(); N != NE; ++N) - NewSetMembers.push_back(ValuePair(K, *N)); - DenseMap<Value *, std::vector<Value *> >::iterator JJ = - LoadMoveSet.find(J); - if (JJ != LoadMoveSet.end()) - for (std::vector<Value *>::iterator N = JJ->second.begin(), - NE = JJ->second.end(); N != NE; ++N) - NewSetMembers.push_back(ValuePair(K, *N)); - for (std::vector<ValuePair>::iterator A = NewSetMembers.begin(), - AE = NewSetMembers.end(); A != AE; ++A) { - LoadMoveSet[A->first].push_back(A->second); - LoadMoveSetPairs.insert(*A); - } - } - - // Before removing I, set the iterator to the next instruction. - PI = std::next(BasicBlock::iterator(I)); - if (cast<Instruction>(PI) == J) - ++PI; - - SE->forgetValue(I); - SE->forgetValue(J); - I->eraseFromParent(); - J->eraseFromParent(); - - DEBUG(if (PrintAfterEveryPair) dbgs() << "BBV: block is now: \n" << - BB << "\n"); - } - - DEBUG(dbgs() << "BBV: final: \n" << BB << "\n"); - } -} - -char BBVectorize::ID = 0; -static const char bb_vectorize_name[] = "Basic-Block Vectorization"; -INITIALIZE_PASS_BEGIN(BBVectorize, BBV_NAME, bb_vectorize_name, false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) -INITIALIZE_PASS_END(BBVectorize, BBV_NAME, bb_vectorize_name, false, false) - -BasicBlockPass *llvm::createBBVectorizePass(const VectorizeConfig &C) { - return new BBVectorize(C); -} - -bool -llvm::vectorizeBasicBlock(Pass *P, BasicBlock &BB, const VectorizeConfig &C) { - BBVectorize BBVectorizer(P, *BB.getParent(), C); - return BBVectorizer.vectorizeBB(BB); -} - -//===----------------------------------------------------------------------===// -VectorizeConfig::VectorizeConfig() { - VectorBits = ::VectorBits; - VectorizeBools = !::NoBools; - VectorizeInts = !::NoInts; - VectorizeFloats = !::NoFloats; - VectorizePointers = !::NoPointers; - VectorizeCasts = !::NoCasts; - VectorizeMath = !::NoMath; - VectorizeBitManipulations = !::NoBitManipulation; - VectorizeFMA = !::NoFMA; - VectorizeSelect = !::NoSelect; - VectorizeCmp = !::NoCmp; - VectorizeGEP = !::NoGEP; - VectorizeMemOps = !::NoMemOps; - AlignedOnly = ::AlignedOnly; - ReqChainDepth= ::ReqChainDepth; - SearchLimit = ::SearchLimit; - MaxCandPairsForCycleCheck = ::MaxCandPairsForCycleCheck; - SplatBreaksChain = ::SplatBreaksChain; - MaxInsts = ::MaxInsts; - MaxPairs = ::MaxPairs; - MaxIter = 
::MaxIter; - Pow2LenOnly = ::Pow2LenOnly; - NoMemOpBoost = ::NoMemOpBoost; - FastDep = ::FastDep; -} diff --git a/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index c44a393..9cf6638 100644 --- a/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -30,6 +30,7 @@ #include "llvm/IR/Value.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Vectorize.h" @@ -65,7 +66,9 @@ public: bool run(); private: - Value *getPointerOperand(Value *I); + Value *getPointerOperand(Value *I) const; + + GetElementPtrInst *getSourceGEP(Value *Src) const; unsigned getPointerAddressSpace(Value *I); @@ -215,7 +218,7 @@ bool Vectorizer::run() { return Changed; } -Value *Vectorizer::getPointerOperand(Value *I) { +Value *Vectorizer::getPointerOperand(Value *I) const { if (LoadInst *LI = dyn_cast<LoadInst>(I)) return LI->getPointerOperand(); if (StoreInst *SI = dyn_cast<StoreInst>(I)) @@ -231,6 +234,19 @@ unsigned Vectorizer::getPointerAddressSpace(Value *I) { return -1; } +GetElementPtrInst *Vectorizer::getSourceGEP(Value *Src) const { + // First strip pointer bitcasts. Make sure pointee size is the same with + // and without casts. + // TODO: a stride set by the add instruction below can match the difference + // in pointee type size here. Currently it will not be vectorized. + Value *SrcPtr = getPointerOperand(Src); + Value *SrcBase = SrcPtr->stripPointerCasts(); + if (DL.getTypeStoreSize(SrcPtr->getType()->getPointerElementType()) == + DL.getTypeStoreSize(SrcBase->getType()->getPointerElementType())) + SrcPtr = SrcBase; + return dyn_cast<GetElementPtrInst>(SrcPtr); +} + // FIXME: Merge with llvm::isConsecutiveAccess bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { Value *PtrA = getPointerOperand(A); @@ -283,8 +299,8 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { // Look through GEPs after checking they're the same except for the last // index. - GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(getPointerOperand(A)); - GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(getPointerOperand(B)); + GetElementPtrInst *GEPA = getSourceGEP(A); + GetElementPtrInst *GEPB = getSourceGEP(B); if (!GEPA || !GEPB || GEPA->getNumOperands() != GEPB->getNumOperands()) return false; unsigned FinalIndex = GEPA->getNumOperands() - 1; @@ -328,11 +344,9 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { // If any bits are known to be zero other than the sign bit in OpA, we can // add 1 to it while guaranteeing no overflow of any sort. 
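The comment above states the invariant behind the hunk that follows; restated as a free-standing predicate (the helper name is hypothetical, the KnownBits calls match the new code):

    static bool addOneCannotOverflow(Value *OpA, const DataLayout &DL,
                                     const DominatorTree *DT) {
      unsigned BitWidth = OpA->getType()->getIntegerBitWidth();
      KnownBits Known(BitWidth);
      computeKnownBits(OpA, Known, DL, /*Depth=*/0, /*AC=*/nullptr,
                       dyn_cast<Instruction>(OpA), DT);
      // countMaxTrailingOnes() bounds the longest possible run of low one
      // bits. If it is shorter than BitWidth - 1, some low bit is known zero,
      // so adding 1 is absorbed before any carry can reach the sign bit.
      return Known.countMaxTrailingOnes() < BitWidth - 1;
    }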
if (!Safe) { - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(OpA, KnownZero, KnownOne, DL, 0, nullptr, OpA, &DT); - KnownZero &= ~APInt::getHighBitsSet(BitWidth, 1); - if (KnownZero != 0) + KnownBits Known(BitWidth); + computeKnownBits(OpA, Known, DL, 0, nullptr, OpA, &DT); + if (Known.countMaxTrailingOnes() < (BitWidth - 1)) Safe = true; } @@ -432,9 +446,12 @@ Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBytes = ElementSizeBits / 8; unsigned SizeBytes = ElementSizeBytes * Chain.size(); unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes; - if (NumLeft == Chain.size()) - --NumLeft; - else if (NumLeft == 0) + if (NumLeft == Chain.size()) { + if ((NumLeft & 1) == 0) + NumLeft /= 2; // Split even in half + else + --NumLeft; // Split off last element + } else if (NumLeft == 0) NumLeft = 1; return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft)); } @@ -588,7 +605,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) { continue; // Make sure all the users of a vector are constant-index extracts. - if (isa<VectorType>(Ty) && !all_of(LI->users(), [LI](const User *U) { + if (isa<VectorType>(Ty) && !all_of(LI->users(), [](const User *U) { const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U); return EEI && isa<ConstantInt>(EEI->getOperand(1)); })) @@ -622,7 +639,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) { if (TySize > VecRegSize / 2) continue; - if (isa<VectorType>(Ty) && !all_of(SI->users(), [SI](const User *U) { + if (isa<VectorType>(Ty) && !all_of(SI->users(), [](const User *U) { const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U); return EEI && isa<ConstantInt>(EEI->getOperand(1)); })) diff --git a/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index dac7032..012b10c 100644 --- a/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -50,6 +50,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -92,6 +93,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" #include "llvm/Transforms/Vectorize.h" @@ -112,12 +114,13 @@ static cl::opt<bool> EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, cl::desc("Enable if-conversion during vectorization.")); -/// We don't vectorize loops with a known constant trip count below this number. +/// Loops with a known constant trip count below this number are vectorized only +/// if no scalar iteration overheads are incurred. 
static cl::opt<unsigned> TinyTripCountVectorThreshold( "vectorizer-min-trip-count", cl::init(16), cl::Hidden, - cl::desc("Don't vectorize loops with a constant " - "trip count that is smaller than this " - "value.")); + cl::desc("Loops with a constant trip count that is smaller than this " + "value are vectorized only if no scalar iteration overheads " + "are incurred.")); static cl::opt<bool> MaximizeBandwidth( "vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, @@ -266,21 +269,6 @@ static bool hasCyclesInLoopBody(const Loop &L) { return false; } -/// \brief This modifies LoopAccessReport to initialize message with -/// loop-vectorizer-specific part. -class VectorizationReport : public LoopAccessReport { -public: - VectorizationReport(Instruction *I = nullptr) - : LoopAccessReport("loop not vectorized: ", I) {} - - /// \brief This allows promotion of the loop-access analysis report into the - /// loop-vectorizer report. It modifies the message to add the - /// loop-vectorizer-specific part of the message. - explicit VectorizationReport(const LoopAccessReport &R) - : LoopAccessReport(Twine("loop not vectorized: ") + R.str(), - R.getInstr()) {} -}; - /// A helper function for converting Scalar types to vector types. /// If the incoming type is void, we return void. If the VF is 1, we return /// the scalar type. @@ -290,31 +278,9 @@ static Type *ToVectorTy(Type *Scalar, unsigned VF) { return VectorType::get(Scalar, VF); } -/// A helper function that returns GEP instruction and knows to skip a -/// 'bitcast'. The 'bitcast' may be skipped if the source and the destination -/// pointee types of the 'bitcast' have the same size. -/// For example: -/// bitcast double** %var to i64* - can be skipped -/// bitcast double** %var to i8* - can not -static GetElementPtrInst *getGEPInstruction(Value *Ptr) { - - if (isa<GetElementPtrInst>(Ptr)) - return cast<GetElementPtrInst>(Ptr); - - if (isa<BitCastInst>(Ptr) && - isa<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0))) { - Type *BitcastTy = Ptr->getType(); - Type *GEPTy = cast<BitCastInst>(Ptr)->getSrcTy(); - if (!isa<PointerType>(BitcastTy) || !isa<PointerType>(GEPTy)) - return nullptr; - Type *Pointee1Ty = cast<PointerType>(BitcastTy)->getPointerElementType(); - Type *Pointee2Ty = cast<PointerType>(GEPTy)->getPointerElementType(); - const DataLayout &DL = cast<BitCastInst>(Ptr)->getModule()->getDataLayout(); - if (DL.getTypeSizeInBits(Pointee1Ty) == DL.getTypeSizeInBits(Pointee2Ty)) - return cast<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0)); - } - return nullptr; -} +// FIXME: The following helper functions have multiple implementations +// in the project. They can be effectively organized in a common Load/Store +// utilities unit. /// A helper function that returns the pointer operand of a load or store /// instruction. @@ -326,6 +292,34 @@ static Value *getPointerOperand(Value *I) { return nullptr; } +/// A helper function that returns the type of loaded or stored value. +static Type *getMemInstValueType(Value *I) { + assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && + "Expected Load or Store instruction"); + if (auto *LI = dyn_cast<LoadInst>(I)) + return LI->getType(); + return cast<StoreInst>(I)->getValueOperand()->getType(); +} + +/// A helper function that returns the alignment of load or store instruction. 
+static unsigned getMemInstAlignment(Value *I) { + assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && + "Expected Load or Store instruction"); + if (auto *LI = dyn_cast<LoadInst>(I)) + return LI->getAlignment(); + return cast<StoreInst>(I)->getAlignment(); +} + +/// A helper function that returns the address space of the pointer operand of +/// load or store instruction. +static unsigned getMemInstAddressSpace(Value *I) { + assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && + "Expected Load or Store instruction"); + if (auto *LI = dyn_cast<LoadInst>(I)) + return LI->getPointerAddressSpace(); + return cast<StoreInst>(I)->getPointerAddressSpace(); +} + /// A helper function that returns true if the given type is irregular. The /// type is irregular if its allocated size doesn't equal the store size of an /// element of the corresponding vector type at the given vectorization factor. @@ -351,6 +345,23 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL, unsigned VF) { /// we always assume predicated blocks have a 50% chance of executing. static unsigned getReciprocalPredBlockProb() { return 2; } +/// A helper function that adds a 'fast' flag to floating-point operations. +static Value *addFastMathFlag(Value *V) { + if (isa<FPMathOperator>(V)) { + FastMathFlags Flags; + Flags.setUnsafeAlgebra(); + cast<Instruction>(V)->setFastMathFlags(Flags); + } + return V; +} + +/// A helper function that returns an integer or floating-point constant with +/// value C. +static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) { + return Ty->isIntegerTy() ? ConstantInt::getSigned(Ty, C) + : ConstantFP::get(Ty, C); +} + /// InnerLoopVectorizer vectorizes loops which contain only one basic /// block to a specified vectorization factor (VF). /// This class performs the widening of scalars into vectors, or multiple @@ -381,13 +392,14 @@ public: TripCount(nullptr), VectorTripCount(nullptr), Legal(LVL), Cost(CM), AddedSafetyChecks(false) {} - // Perform the actual loop widening (vectorization). - void vectorize() { - // Create a new empty loop. Unlink the old loop and connect the new one. - createEmptyLoop(); - // Widen each instruction in the old loop to a new one in the new loop. - vectorizeLoop(); - } + /// Create a new empty loop. Unlink the old loop and connect the new one. + void createVectorizedLoopSkeleton(); + + /// Vectorize a single instruction within the innermost loop. + void vectorizeInstruction(Instruction &I); + + /// Fix the vectorized code, taking care of header phi's, live-outs, and more. + void fixVectorizedLoop(); // Return true if any runtime check is added. bool areSafetyChecksAdded() { return AddedSafetyChecks; } @@ -412,10 +424,8 @@ protected: // When we if-convert we need to create edge masks. We have to cache values // so that we don't end up with exponential recursion/IR. typedef DenseMap<std::pair<BasicBlock *, BasicBlock *>, VectorParts> - EdgeMaskCache; - - /// Create an empty loop, based on the loop ranges of the old loop. - void createEmptyLoop(); + EdgeMaskCacheTy; + typedef DenseMap<BasicBlock *, VectorParts> BlockMaskCacheTy; /// Set up the values of the IVs correctly when exiting the vector loop. void fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, @@ -425,17 +435,22 @@ protected: /// Create a new induction variable inside L. PHINode *createInductionVariable(Loop *L, Value *Start, Value *End, Value *Step, Instruction *DL); - /// Copy and widen the instructions from the old loop. 
- virtual void vectorizeLoop(); + + /// Handle all cross-iteration phis in the header. + void fixCrossIterationPHIs(); /// Fix a first-order recurrence. This is the second phase of vectorizing /// this phi node. void fixFirstOrderRecurrence(PHINode *Phi); - /// \brief The Loop exit block may have single value PHI nodes where the - /// incoming value is 'Undef'. While vectorizing we only handled real values - /// that were defined inside the loop. Here we fix the 'undef case'. - /// See PR14725. + /// Fix a reduction cross-iteration phi. This is the second phase of + /// vectorizing this phi node. + void fixReduction(PHINode *Phi); + + /// \brief The Loop exit block may have single value PHI nodes with some + /// incoming value. While vectorizing we only handled real values + /// that were defined inside the loop and we should have one value for + /// each predecessor of its parent basic block. See PR14725. void fixLCSSAPHIs(); /// Iteratively sink the scalarized operands of a predicated instruction into @@ -446,10 +461,6 @@ protected: /// respective conditions. void predicateInstructions(); - /// Collect the instructions from the original loop that would be trivially - /// dead in the vectorized loop if generated. - void collectTriviallyDeadInstructions(); - /// Shrinks vector element sizes to the smallest bitwidth they can be legally /// represented as. void truncateToMinimalBitwidths(); @@ -462,14 +473,10 @@ protected: /// and DST. VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst); - /// A helper function to vectorize a single BB within the innermost loop. - void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV); - /// Vectorize a single PHINode in a block. This method handles the induction /// variable canonicalization. It supports both VF = 1 for unrolled loops and /// arbitrary length vectors. - void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF, - PhiVector *PV); + void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF); /// Insert the new loop to the loop hierarchy and pass manager /// and update the analysis passes. @@ -479,8 +486,7 @@ protected: /// of scalars. If \p IfPredicateInstr is true we need to 'hide' each /// scalarized instruction behind an if block predicated on the control /// dependence of the instruction. - virtual void scalarizeInstruction(Instruction *Instr, - bool IfPredicateInstr = false); + void scalarizeInstruction(Instruction *Instr, bool IfPredicateInstr = false); /// Vectorize Load and Store instructions, virtual void vectorizeMemoryInstruction(Instruction *Instr); @@ -504,20 +510,21 @@ protected: /// \p EntryVal is the value from the original loop that maps to the steps. /// Note that \p EntryVal doesn't have to be an induction variable (e.g., it /// can be a truncate instruction). - void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal); - - /// Create a vector induction phi node based on an existing scalar one. This - /// currently only works for integer induction variables with a constant - /// step. \p EntryVal is the value from the original loop that maps to the - /// vector phi node. If \p EntryVal is a truncate instruction, instead of - /// widening the original IV, we widen a version of the IV truncated to \p - /// EntryVal's type. - void createVectorIntInductionPHI(const InductionDescriptor &II, - Instruction *EntryVal); - - /// Widen an integer induction variable \p IV. If \p Trunc is provided, the - /// induction variable will first be truncated to the corresponding type. 
- void widenIntInduction(PHINode *IV, TruncInst *Trunc = nullptr); + void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal, + const InductionDescriptor &ID); + + /// Create a vector induction phi node based on an existing scalar one. \p + /// EntryVal is the value from the original loop that maps to the vector phi + /// node, and \p Step is the loop-invariant step. If \p EntryVal is a + /// truncate instruction, instead of widening the original IV, we widen a + /// version of the IV truncated to \p EntryVal's type. + void createVectorIntOrFpInductionPHI(const InductionDescriptor &II, + Value *Step, Instruction *EntryVal); + + /// Widen an integer or floating-point induction variable \p IV. If \p Trunc + /// is provided, the integer induction variable will first be truncated to + /// the corresponding type. + void widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc = nullptr); /// Returns true if an instruction \p I should be scalarized instead of /// vectorized for the chosen vectorization factor. @@ -526,21 +533,34 @@ protected: /// Returns true if we should generate a scalar version of \p IV. bool needsScalarInduction(Instruction *IV) const; - /// Return a constant reference to the VectorParts corresponding to \p V from - /// the original loop. If the value has already been vectorized, the - /// corresponding vector entry in VectorLoopValueMap is returned. If, + /// getOrCreateVectorValue and getOrCreateScalarValue coordinate to generate a + /// vector or scalar value on-demand if one is not yet available. When + /// vectorizing a loop, we visit the definition of an instruction before its + /// uses. When visiting the definition, we either vectorize or scalarize the + /// instruction, creating an entry for it in the corresponding map. (In some + /// cases, such as induction variables, we will create both vector and scalar + /// entries.) Then, as we encounter uses of the definition, we derive values + /// for each scalar or vector use unless such a value is already available. + /// For example, if we scalarize a definition and one of its uses is vector, + /// we build the required vector on-demand with an insertelement sequence + /// when visiting the use. Otherwise, if the use is scalar, we can use the + /// existing scalar definition. + /// + /// Return a value in the new loop corresponding to \p V from the original + /// loop at unroll index \p Part. If the value has already been vectorized, + /// the corresponding vector entry in VectorLoopValueMap is returned. If, /// however, the value has a scalar entry in VectorLoopValueMap, we construct - /// new vector values on-demand by inserting the scalar values into vectors + /// a new vector value on-demand by inserting the scalar values into a vector /// with an insertelement sequence. If the value has been neither vectorized /// nor scalarized, it must be loop invariant, so we simply broadcast the - /// value into vectors. - const VectorParts &getVectorValue(Value *V); + /// value into a vector. + Value *getOrCreateVectorValue(Value *V, unsigned Part); /// Return a value in the new loop corresponding to \p V from the original /// loop at unroll index \p Part and vector index \p Lane. If the value has /// been vectorized but not scalarized, the necessary extractelement /// instruction will be generated. - Value *getScalarValue(Value *V, unsigned Part, unsigned Lane); + Value *getOrCreateScalarValue(Value *V, unsigned Part, unsigned Lane); /// Try to vectorize the interleaved access group that \p Instr belongs to. 
void vectorizeInterleaveGroup(Instruction *Instr); @@ -554,11 +574,9 @@ protected: /// Returns (and creates if needed) the trip count of the widened loop. Value *getOrCreateVectorTripCount(Loop *NewLoop); - /// Emit a bypass check to see if the trip count would overflow, or we - /// wouldn't have enough iterations to execute one vector loop. + /// Emit a bypass check to see if the vector trip count is zero, including if + /// it overflows. void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); - /// Emit a bypass check to see if the vector trip count is nonzero. - void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass); /// Emit a bypass check to see if all of the SCEV assumptions we've /// had to make are correct. void emitSCEVChecks(Loop *L, BasicBlock *Bypass); @@ -583,6 +601,10 @@ protected: /// vector of instructions. void addMetadata(ArrayRef<Value *> To, Instruction *From); + /// \brief Set the debug location in the builder using the debug location in + /// the instruction. + void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr); + /// This is a helper class for maintaining vectorization state. It's used for /// mapping values from the original loop to their corresponding values in /// the new loop. Two mappings are maintained: one for vectorized values and @@ -591,90 +613,103 @@ protected: /// UF x VF scalar values in the new loop. UF and VF are the unroll and /// vectorization factors, respectively. /// - /// Entries can be added to either map with initVector and initScalar, which - /// initialize and return a constant reference to the new entry. If a - /// non-constant reference to a vector entry is required, getVector can be - /// used to retrieve a mutable entry. We currently directly modify the mapped - /// values during "fix-up" operations that occur once the first phase of - /// widening is complete. These operations include type truncation and the - /// second phase of recurrence widening. + /// Entries can be added to either map with setVectorValue and setScalarValue, + /// which assert that an entry was not already added before. If an entry is to + /// replace an existing one, call resetVectorValue. This is currently needed + /// to modify the mapped values during "fix-up" operations that occur once the + /// first phase of widening is complete. These operations include type + /// truncation and the second phase of recurrence widening. /// - /// Otherwise, entries from either map should be accessed using the - /// getVectorValue or getScalarValue functions from InnerLoopVectorizer. - /// getVectorValue and getScalarValue coordinate to generate a vector or - /// scalar value on-demand if one is not yet available. When vectorizing a - /// loop, we visit the definition of an instruction before its uses. When - /// visiting the definition, we either vectorize or scalarize the - /// instruction, creating an entry for it in the corresponding map. (In some - /// cases, such as induction variables, we will create both vector and scalar - /// entries.) Then, as we encounter uses of the definition, we derive values - /// for each scalar or vector use unless such a value is already available. - /// For example, if we scalarize a definition and one of its uses is vector, - /// we build the required vector on-demand with an insertelement sequence - /// when visiting the use. Otherwise, if the use is scalar, we can use the - /// existing scalar definition. 
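The on-demand widening described above boils down to a small insertelement loop. A sketch, assuming Builder, V, Part, and VF match the declarations above and that only scalar entries exist for V:

    Value *VectorValue = UndefValue::get(VectorType::get(V->getType(), VF));
    for (unsigned Lane = 0; Lane < VF; ++Lane)
      VectorValue = Builder.CreateInsertElement(
          VectorValue, getOrCreateScalarValue(V, Part, Lane),
          Builder.getInt32(Lane));
    // VectorValue now holds all VF lanes and can be cached for later uses.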
+  /// Entries from either map can be retrieved using the getVectorValue and
+  /// getScalarValue functions, which assert that the desired value exists.
+
   struct ValueMap {
     /// Construct an empty map with the given unroll and vectorization factors.
-    ValueMap(unsigned UnrollFactor, unsigned VecWidth)
-        : UF(UnrollFactor), VF(VecWidth) {
-      // The unroll and vectorization factors are only used in asserts builds
-      // to verify map entries are sized appropriately.
-      (void)UF;
-      (void)VF;
+    ValueMap(unsigned UF, unsigned VF) : UF(UF), VF(VF) {}
+
+    /// \return True if the map has any vector entry for \p Key.
+    bool hasAnyVectorValue(Value *Key) const {
+      return VectorMapStorage.count(Key);
     }

-    /// \return True if the map has a vector entry for \p Key.
-    bool hasVector(Value *Key) const { return VectorMapStorage.count(Key); }
-
-    /// \return True if the map has a scalar entry for \p Key.
-    bool hasScalar(Value *Key) const { return ScalarMapStorage.count(Key); }
-
-    /// \brief Map \p Key to the given VectorParts \p Entry, and return a
-    /// constant reference to the new vector map entry. The given key should
-    /// not already be in the map, and the given VectorParts should be
-    /// correctly sized for the current unroll factor.
-    const VectorParts &initVector(Value *Key, const VectorParts &Entry) {
-      assert(!hasVector(Key) && "Vector entry already initialized");
-      assert(Entry.size() == UF && "VectorParts has wrong dimensions");
-      VectorMapStorage[Key] = Entry;
-      return VectorMapStorage[Key];
+    /// \return True if the map has a vector entry for \p Key and \p Part.
+    bool hasVectorValue(Value *Key, unsigned Part) const {
+      assert(Part < UF && "Queried Vector Part is too large.");
+      if (!hasAnyVectorValue(Key))
+        return false;
+      const VectorParts &Entry = VectorMapStorage.find(Key)->second;
+      assert(Entry.size() == UF && "VectorParts has wrong dimensions.");
+      return Entry[Part] != nullptr;
+    }
+
+    /// \return True if the map has any scalar entry for \p Key.
+    bool hasAnyScalarValue(Value *Key) const {
+      return ScalarMapStorage.count(Key);
+    }
+
+    /// \return True if the map has a scalar entry for \p Key, \p Part and
+    /// \p Lane.
+    bool hasScalarValue(Value *Key, unsigned Part, unsigned Lane) const {
+      assert(Part < UF && "Queried Scalar Part is too large.");
+      assert(Lane < VF && "Queried Scalar Lane is too large.");
+      if (!hasAnyScalarValue(Key))
+        return false;
+      const ScalarParts &Entry = ScalarMapStorage.find(Key)->second;
+      assert(Entry.size() == UF && "ScalarParts has wrong dimensions.");
+      assert(Entry[Part].size() == VF && "ScalarParts has wrong dimensions.");
+      return Entry[Part][Lane] != nullptr;
     }

-    /// \brief Map \p Key to the given ScalarParts \p Entry, and return a
-    /// constant reference to the new scalar map entry. The given key should
-    /// not already be in the map, and the given ScalarParts should be
-    /// correctly sized for the current unroll and vectorization factors.
-    const ScalarParts &initScalar(Value *Key, const ScalarParts &Entry) {
-      assert(!hasScalar(Key) && "Scalar entry already initialized");
-      assert(Entry.size() == UF &&
-             all_of(make_range(Entry.begin(), Entry.end()),
-                    [&](const SmallVectorImpl<Value *> &Values) -> bool {
-                      return Values.size() == VF;
-                    }) &&
-             "ScalarParts has wrong dimensions");
-      ScalarMapStorage[Key] = Entry;
-      return ScalarMapStorage[Key];
+    /// Retrieve the existing vector value that corresponds to \p Key and
+    /// \p Part.
+ Value *getVectorValue(Value *Key, unsigned Part) { + assert(hasVectorValue(Key, Part) && "Getting non-existent value."); + return VectorMapStorage[Key][Part]; } - /// \return A reference to the vector map entry corresponding to \p Key. - /// The key should already be in the map. This function should only be used - /// when it's necessary to update values that have already been vectorized. - /// This is the case for "fix-up" operations including type truncation and - /// the second phase of recurrence vectorization. If a non-const reference - /// isn't required, getVectorValue should be used instead. - VectorParts &getVector(Value *Key) { - assert(hasVector(Key) && "Vector entry not initialized"); - return VectorMapStorage.find(Key)->second; + /// Retrieve the existing scalar value that corresponds to \p Key, \p Part + /// and \p Lane. + Value *getScalarValue(Value *Key, unsigned Part, unsigned Lane) { + assert(hasScalarValue(Key, Part, Lane) && "Getting non-existent value."); + return ScalarMapStorage[Key][Part][Lane]; } - /// Retrieve an entry from the vector or scalar maps. The preferred way to - /// access an existing mapped entry is with getVectorValue or - /// getScalarValue from InnerLoopVectorizer. Until those functions can be - /// moved inside ValueMap, we have to declare them as friends. - friend const VectorParts &InnerLoopVectorizer::getVectorValue(Value *V); - friend Value *InnerLoopVectorizer::getScalarValue(Value *V, unsigned Part, - unsigned Lane); + /// Set a vector value associated with \p Key and \p Part. Assumes such a + /// value is not already set. If it is, use resetVectorValue() instead. + void setVectorValue(Value *Key, unsigned Part, Value *Vector) { + assert(!hasVectorValue(Key, Part) && "Vector value already set for part"); + if (!VectorMapStorage.count(Key)) { + VectorParts Entry(UF); + VectorMapStorage[Key] = Entry; + } + VectorMapStorage[Key][Part] = Vector; + } + + /// Set a scalar value associated with \p Key for \p Part and \p Lane. + /// Assumes such a value is not already set. + void setScalarValue(Value *Key, unsigned Part, unsigned Lane, + Value *Scalar) { + assert(!hasScalarValue(Key, Part, Lane) && "Scalar value already set"); + if (!ScalarMapStorage.count(Key)) { + ScalarParts Entry(UF); + for (unsigned Part = 0; Part < UF; ++Part) + Entry[Part].resize(VF, nullptr); + // TODO: Consider storing uniform values only per-part, as they occupy + // lane 0 only, keeping the other VF-1 redundant entries null. + ScalarMapStorage[Key] = Entry; + } + ScalarMapStorage[Key][Part][Lane] = Scalar; + } + + /// Reset the vector value associated with \p Key for the given \p Part. + /// This function can be used to update values that have already been + /// vectorized. This is the case for "fix-up" operations including type + /// truncation and the second phase of recurrence vectorization. + void resetVectorValue(Value *Key, unsigned Part, Value *Vector) { + assert(hasVectorValue(Key, Part) && "Vector value not set for part"); + VectorMapStorage[Key][Part] = Vector; + } private: /// The unroll factor. Each entry in the vector map contains UF vector @@ -762,7 +797,8 @@ protected: /// Store instructions that should be predicated, as a pair /// <StoreInst, Predicate> SmallVector<std::pair<Instruction *, Value *>, 4> PredicatedInstructions; - EdgeMaskCache MaskCache; + EdgeMaskCacheTy EdgeMaskCache; + BlockMaskCacheTy BlockMaskCache; /// Trip count of the original loop. 
Value *TripCount; /// Trip count of the widened loop (TripCount - TripCount % (VF*UF)) @@ -777,14 +813,6 @@ protected: // Record whether runtime checks are added. bool AddedSafetyChecks; - // Holds instructions from the original loop whose counterparts in the - // vectorized loop would be trivially dead if generated. For example, - // original induction update instructions can become dead because we - // separately emit induction "steps" when generating code for the new loop. - // Similarly, we create a new latch condition when setting up the structure - // of the new loop, so the old one can become dead. - SmallPtrSet<Instruction *, 4> DeadInstructions; - // Holds the end values for each induction variable. We save the end values // so we can later fix-up the external users of the induction variables. DenseMap<PHINode *, Value *> IVEndValues; @@ -803,8 +831,6 @@ public: UnrollFactor, LVL, CM) {} private: - void scalarizeInstruction(Instruction *Instr, - bool IfPredicateInstr = false) override; void vectorizeMemoryInstruction(Instruction *Instr) override; Value *getBroadcastInstrs(Value *V) override; Value *getStepVector(Value *Val, int StartIdx, Value *Step, @@ -832,12 +858,14 @@ static Instruction *getDebugLocFromInstOrOperands(Instruction *I) { return I; } -/// \brief Set the debug location in the builder using the debug location in the -/// instruction. -static void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) { - if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr)) - B.SetCurrentDebugLocation(Inst->getDebugLoc()); - else +void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) { + if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr)) { + const DILocation *DIL = Inst->getDebugLoc(); + if (DIL && Inst->getFunction()->isDebugInfoForProfiling()) + B.SetCurrentDebugLocation(DIL->cloneWithDuplicationFactor(UF * VF)); + else + B.SetCurrentDebugLocation(DIL); + } else B.SetCurrentDebugLocation(DebugLoc()); } @@ -1497,14 +1525,6 @@ private: OptimizationRemarkEmitter &ORE; }; -static void emitAnalysisDiag(const Loop *TheLoop, - const LoopVectorizeHints &Hints, - OptimizationRemarkEmitter &ORE, - const LoopAccessReport &Message) { - const char *Name = Hints.vectorizeAnalysisPassName(); - LoopAccessReport::emitAnalysis(Message, TheLoop, Name, ORE); -} - static void emitMissedWarning(Function *F, Loop *L, const LoopVectorizeHints &LH, OptimizationRemarkEmitter *ORE) { @@ -1512,13 +1532,17 @@ static void emitMissedWarning(Function *F, Loop *L, if (LH.getForce() == LoopVectorizeHints::FK_Enabled) { if (LH.getWidth() != 1) - emitLoopVectorizeWarning( - F->getContext(), *F, L->getStartLoc(), - "failed explicitly specified loop vectorization"); + ORE->emit(DiagnosticInfoOptimizationFailure( + DEBUG_TYPE, "FailedRequestedVectorization", + L->getStartLoc(), L->getHeader()) + << "loop not vectorized: " + << "failed explicitly specified loop vectorization"); else if (LH.getInterleave() != 1) - emitLoopInterleaveWarning( - F->getContext(), *F, L->getStartLoc(), - "failed explicitly specified loop interleaving"); + ORE->emit(DiagnosticInfoOptimizationFailure( + DEBUG_TYPE, "FailedRequestedInterleaving", L->getStartLoc(), + L->getHeader()) + << "loop not interleaved: " + << "failed explicitly specified loop interleaving"); } } @@ -1546,7 +1570,7 @@ public: LoopVectorizeHints *H) : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT), GetLAA(GetLAA), LAI(nullptr), ORE(ORE), InterleaveInfo(PSE, L, DT, LI), - Induction(nullptr), 
WidestIndTy(nullptr), HasFunNoNaNAttr(false), + PrimaryInduction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H) {} /// ReductionList contains the reduction descriptors for all @@ -1566,8 +1590,8 @@ public: /// loop, only that it is legal to do so. bool canVectorize(); - /// Returns the Induction variable. - PHINode *getInduction() { return Induction; } + /// Returns the primary induction variable. + PHINode *getPrimaryInduction() { return PrimaryInduction; } /// Returns the reduction variables found in the loop. ReductionList *getReductionVars() { return &Reductions; } @@ -1578,6 +1602,9 @@ public: /// Return the first-order recurrences found in the loop. RecurrenceSet *getFirstOrderRecurrences() { return &FirstOrderRecurrences; } + /// Return the set of instructions to sink to handle first-order recurrences. + DenseMap<Instruction *, Instruction *> &getSinkAfter() { return SinkAfter; } + /// Returns the widest induction type. Type *getWidestInductionType() { return WidestIndTy; } @@ -1607,12 +1634,6 @@ public: /// Returns true if the value V is uniform within the loop. bool isUniform(Value *V); - /// Returns true if \p I is known to be uniform after vectorization. - bool isUniformAfterVectorization(Instruction *I) { return Uniforms.count(I); } - - /// Returns true if \p I is known to be scalar after vectorization. - bool isScalarAfterVectorization(Instruction *I) { return Scalars.count(I); } - /// Returns the information that we collected about runtime memory check. const RuntimePointerChecking *getRuntimePointerChecking() const { return LAI->getRuntimePointerChecking(); @@ -1689,15 +1710,12 @@ public: /// instructions that may divide by zero. bool isScalarWithPredication(Instruction *I); - /// Returns true if \p I is a memory instruction that has a consecutive or - /// consecutive-like pointer operand. Consecutive-like pointers are pointers - /// that are treated like consecutive pointers during vectorization. The - /// pointer operands of interleaved accesses are an example. - bool hasConsecutiveLikePtrOperand(Instruction *I); + /// Returns true if \p I is a memory instruction with consecutive memory + /// access that can be widened. + bool memoryInstructionCanBeWidened(Instruction *I, unsigned VF = 1); - /// Returns true if \p I is a memory instruction that must be scalarized - /// during vectorization. - bool memoryInstructionMustBeScalarized(Instruction *I, unsigned VF = 1); + // Returns true if the NoNaN attribute is set on the function. + bool hasFunNoNaNAttr() const { return HasFunNoNaNAttr; } private: /// Check if a single basic block loop is vectorizable. @@ -1715,24 +1733,6 @@ private: /// transformation. bool canVectorizeWithIfConvert(); - /// Collect the instructions that are uniform after vectorization. An - /// instruction is uniform if we represent it with a single scalar value in - /// the vectorized loop corresponding to each vector iteration. Examples of - /// uniform instructions include pointer operands of consecutive or - /// interleaved memory accesses. Note that although uniformity implies an - /// instruction will be scalar, the reverse is not true. In general, a - /// scalarized instruction will be represented by VF scalar values in the - /// vectorized loop, each corresponding to an iteration of the original - /// scalar loop. - void collectLoopUniforms(); - - /// Collect the instructions that are scalar after vectorization. 
An - /// instruction is scalar if it is known to be uniform or will be scalarized - /// during vectorization. Non-uniform scalarized instructions will be - /// represented by VF values in the vectorized loop, each corresponding to an - /// iteration of the original scalar loop. - void collectLoopScalars(); - /// Return true if all of the instructions in the block can be speculatively /// executed. \p SafePtrs is a list of addresses that are known to be legal /// and we know that we can read from them without segfault. @@ -1744,14 +1744,6 @@ private: void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID, SmallPtrSetImpl<Value *> &AllowedExit); - /// Report an analysis message to assist the user in diagnosing loops that are - /// not vectorized. These are handled as LoopAccessReport rather than - /// VectorizationReport because the << operator of VectorizationReport returns - /// LoopAccessReport. - void emitAnalysis(const LoopAccessReport &Message) const { - emitAnalysisDiag(TheLoop, *Hints, *ORE, Message); - } - /// Create an analysis remark that explains why vectorization failed /// /// \p RemarkName is the identifier for the remark. If \p I is passed it is @@ -1804,9 +1796,9 @@ private: // --- vectorization state --- // - /// Holds the integer induction variable. This is the counter of the + /// Holds the primary induction variable. This is the counter of the /// loop. - PHINode *Induction; + PHINode *PrimaryInduction; /// Holds the reduction variables. ReductionList Reductions; /// Holds all of the induction variables that we found in the loop. @@ -1815,6 +1807,9 @@ private: InductionList Inductions; /// Holds the phi nodes that are first-order recurrences. RecurrenceSet FirstOrderRecurrences; + /// Holds instructions that need to sink past other instructions to handle + /// first-order recurrences. + DenseMap<Instruction *, Instruction *> SinkAfter; /// Holds the widest induction type encountered. Type *WidestIndTy; @@ -1822,12 +1817,6 @@ private: /// vars which can be accessed from outside the loop. SmallPtrSet<Value *, 4> AllowedExit; - /// Holds the instructions known to be uniform after vectorization. - SmallPtrSet<Instruction *, 4> Uniforms; - - /// Holds the instructions known to be scalar after vectorization. - SmallPtrSet<Instruction *, 4> Scalars; - /// Can we assume the absence of NaNs. bool HasFunNoNaNAttr; @@ -1861,16 +1850,26 @@ public: : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), AC(AC), ORE(ORE), TheFunction(F), Hints(Hints) {} + /// \return An upper bound for the vectorization factor, or None if + /// vectorization should be avoided up front. + Optional<unsigned> computeMaxVF(bool OptForSize); + /// Information about vectorization costs struct VectorizationFactor { unsigned Width; // Vector width with best cost unsigned Cost; // Cost of the loop with that width }; /// \return The most profitable vectorization factor and the cost of that VF. - /// This method checks every power of two up to VF. If UserVF is not ZERO + /// This method checks every power of two up to MaxVF. If UserVF is not ZERO /// then this vectorization factor will be selected if vectorization is /// possible. - VectorizationFactor selectVectorizationFactor(bool OptForSize); + VectorizationFactor selectVectorizationFactor(unsigned MaxVF); + + /// Setup cost-based decisions for user vectorization factor. 
+  void selectUserVectorizationFactor(unsigned UserVF) {
+    collectUniformsAndScalars(UserVF);
+    collectInstsToScalarize(UserVF);
+  }

   /// \return The size (in bits) of the smallest and widest types in the code
   /// that needs to be vectorized. We ignore values that remain scalar such as
@@ -1884,6 +1883,15 @@ public:
   unsigned selectInterleaveCount(bool OptForSize, unsigned VF,
                                  unsigned LoopCost);

+  /// A memory access instruction may be vectorized in more than one way;
+  /// its form after vectorization depends on cost. This function makes
+  /// cost-based decisions for Load/Store instructions and collects them in
+  /// a map. This decision map is later used to build the lists of
+  /// loop-uniform and loop-scalar instructions. The calculated cost is
+  /// saved with the widening decision to avoid redundant calculations.
+  void setCostBasedWideningDecision(unsigned VF);
+
   /// \brief A struct that represents some properties of the register usage
   /// of a loop.
   struct RegisterUsage {
@@ -1918,14 +1926,118 @@ public:
     return Scalars->second.count(I);
   }

+  /// Returns true if \p I is known to be uniform after vectorization.
+  bool isUniformAfterVectorization(Instruction *I, unsigned VF) const {
+    if (VF == 1)
+      return true;
+    assert(Uniforms.count(VF) && "VF not yet analyzed for uniformity");
+    auto UniformsPerVF = Uniforms.find(VF);
+    return UniformsPerVF->second.count(I);
+  }
+
+  /// Returns true if \p I is known to be scalar after vectorization.
+  bool isScalarAfterVectorization(Instruction *I, unsigned VF) const {
+    if (VF == 1)
+      return true;
+    assert(Scalars.count(VF) && "Scalar values are not calculated for VF");
+    auto ScalarsPerVF = Scalars.find(VF);
+    return ScalarsPerVF->second.count(I);
+  }
+
   /// \returns True if instruction \p I can be truncated to a smaller bitwidth
   /// for vectorization factor \p VF.
   bool canTruncateToMinimalBitwidth(Instruction *I, unsigned VF) const {
     return VF > 1 && MinBWs.count(I) && !isProfitableToScalarize(I, VF) &&
-           !Legal->isScalarAfterVectorization(I);
+           !isScalarAfterVectorization(I, VF);
+  }
+
+  /// Decision that was taken during cost calculation for memory instruction.
+  enum InstWidening {
+    CM_Unknown,
+    CM_Widen,
+    CM_Interleave,
+    CM_GatherScatter,
+    CM_Scalarize
+  };
+
+  /// Save vectorization decision \p W and \p Cost taken by the cost model for
+  /// instruction \p I and vector width \p VF.
+  void setWideningDecision(Instruction *I, unsigned VF, InstWidening W,
+                           unsigned Cost) {
+    assert(VF >= 2 && "Expected VF >=2");
+    WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, Cost);
+  }
+
+  /// Save vectorization decision \p W and \p Cost taken by the cost model for
+  /// interleaving group \p Grp and vector width \p VF.
+  void setWideningDecision(const InterleaveGroup *Grp, unsigned VF,
+                           InstWidening W, unsigned Cost) {
+    assert(VF >= 2 && "Expected VF >=2");
+    // Broadcast this decision to all instructions inside the group.
+    // But the cost will be assigned to one instruction only.
+    for (unsigned i = 0; i < Grp->getFactor(); ++i) {
+      if (auto *I = Grp->getMember(i)) {
+        if (Grp->getInsertPos() == I)
+          WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, Cost);
+        else
+          WideningDecisions[std::make_pair(I, VF)] = std::make_pair(W, 0);
+      }
+    }
+  }
+
+  /// Return the cost model decision for the given instruction \p I and vector
+  /// width \p VF. Return CM_Unknown if this instruction did not pass
+  /// through the cost modeling.
+ InstWidening getWideningDecision(Instruction *I, unsigned VF) {
+ assert(VF >= 2 && "Expected VF >=2");
+ std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF);
+ auto Itr = WideningDecisions.find(InstOnVF);
+ if (Itr == WideningDecisions.end())
+ return CM_Unknown;
+ return Itr->second.first;
+ }
+
+ /// Return the vectorization cost for the given instruction \p I and vector
+ /// width \p VF.
+ unsigned getWideningCost(Instruction *I, unsigned VF) {
+ assert(VF >= 2 && "Expected VF >=2");
+ std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF);
+ assert(WideningDecisions.count(InstOnVF) && "The cost is not calculated");
+ return WideningDecisions[InstOnVF].second;
+ }
+
+ /// Return true if instruction \p I is an optimizable truncate whose operand
+ /// is an induction variable. Such a truncate will be removed by adding a new
+ /// induction variable with the destination type.
+ bool isOptimizableIVTruncate(Instruction *I, unsigned VF) {
+
+ // If the instruction is not a truncate, return false.
+ auto *Trunc = dyn_cast<TruncInst>(I);
+ if (!Trunc)
+ return false;
+
+ // Get the source and destination types of the truncate.
+ Type *SrcTy = ToVectorTy(cast<CastInst>(I)->getSrcTy(), VF);
+ Type *DestTy = ToVectorTy(cast<CastInst>(I)->getDestTy(), VF);
+
+ // If the truncate is free for the given types, return false. Replacing a
+ // free truncate with an induction variable would add an induction variable
+ // update instruction to each iteration of the loop. We exclude from this
+ // check the primary induction variable since it will need an update
+ // instruction regardless.
+ Value *Op = Trunc->getOperand(0);
+ if (Op != Legal->getPrimaryInduction() && TTI.isTruncateFree(SrcTy, DestTy))
+ return false;
+
+ // If the truncated value is not an induction variable, return false.
+ return Legal->isInductionVariable(Op);
 }

 private:
+ /// \return An upper bound for the vectorization factor, larger than zero.
+ /// One is returned if vectorization should best be avoided due to cost.
+ unsigned computeFeasibleMaxVF(bool OptForSize);
+
 /// The vectorization cost is a combination of the cost itself and a boolean
 /// indicating whether any of the contributing operations will actually
 /// operate on
@@ -1949,6 +2061,26 @@ private:
 /// the vector type as an output parameter.
 unsigned getInstructionCost(Instruction *I, unsigned VF, Type *&VectorTy);

+ /// Calculate the vectorization cost of memory instruction \p I.
+ unsigned getMemoryInstructionCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for a scalarized memory instruction.
+ unsigned getMemInstScalarizationCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for an interleaving group of memory instructions.
+ unsigned getInterleaveGroupCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for a Gather/Scatter instruction.
+ unsigned getGatherScatterCost(Instruction *I, unsigned VF);
+
+ /// The cost computation for widening instruction \p I with consecutive
+ /// memory access.
+ unsigned getConsecutiveMemOpCost(Instruction *I, unsigned VF);
+
+ /// The cost calculation for a Load instruction \p I with a uniform pointer -
+ /// scalar load + broadcast.
+ unsigned getUniformMemOpCost(Instruction *I, unsigned VF);
+
 /// Returns whether the instruction is a load or store and will be emitted
 /// as a vector operation.
 bool isConsecutiveLoadOrStore(Instruction *I);
@@ -1972,12 +2104,28 @@ private:
 /// pairs.
 typedef DenseMap<Instruction *, unsigned> ScalarCostsTy;

+ /// A set containing all BasicBlocks that are known to be present after
+ /// vectorization as predicated blocks.
+ SmallPtrSet<BasicBlock *, 4> PredicatedBBsAfterVectorization;
+
 /// A map holding scalar costs for different vectorization factors. The
 /// presence of a cost for an instruction in the mapping indicates that the
 /// instruction will be scalarized when vectorizing with the associated
 /// vectorization factor. The entries are VF-ScalarCostTy pairs.
 DenseMap<unsigned, ScalarCostsTy> InstsToScalarize;

+ /// Holds the instructions known to be uniform after vectorization.
+ /// The data is collected per VF.
+ DenseMap<unsigned, SmallPtrSet<Instruction *, 4>> Uniforms;
+
+ /// Holds the instructions known to be scalar after vectorization.
+ /// The data is collected per VF.
+ DenseMap<unsigned, SmallPtrSet<Instruction *, 4>> Scalars;
+
+ /// Holds the instructions (address computations) that are forced to be
+ /// scalarized.
+ DenseMap<unsigned, SmallPtrSet<Instruction *, 4>> ForcedScalars;
+
 /// Returns the expected difference in cost from scalarizing the expression
 /// feeding a predicated instruction \p PredInst. The instructions to
 /// scalarize and their scalar costs are collected in \p ScalarCosts. A
@@ -1990,6 +2138,44 @@ private:
 /// the loop.
 void collectInstsToScalarize(unsigned VF);

+ /// Collect the instructions that are uniform after vectorization. An
+ /// instruction is uniform if we represent it with a single scalar value in
+ /// the vectorized loop corresponding to each vector iteration. Examples of
+ /// uniform instructions include pointer operands of consecutive or
+ /// interleaved memory accesses. Note that although uniformity implies an
+ /// instruction will be scalar, the reverse is not true. In general, a
+ /// scalarized instruction will be represented by VF scalar values in the
+ /// vectorized loop, each corresponding to an iteration of the original
+ /// scalar loop.
+ void collectLoopUniforms(unsigned VF);
+
+ /// Collect the instructions that are scalar after vectorization. An
+ /// instruction is scalar if it is known to be uniform or will be scalarized
+ /// during vectorization. Non-uniform scalarized instructions will be
+ /// represented by VF values in the vectorized loop, each corresponding to an
+ /// iteration of the original scalar loop.
+ void collectLoopScalars(unsigned VF);
+
+ /// Collect Uniform and Scalar values for the given \p VF.
+ /// The sets depend on the CM decision for Load/Store instructions
+ /// that may be vectorized as interleave, gather-scatter or scalarized.
+ void collectUniformsAndScalars(unsigned VF) {
+ // Do the analysis once.
+ if (VF == 1 || Uniforms.count(VF))
+ return;
+ setCostBasedWideningDecision(VF);
+ collectLoopUniforms(VF);
+ collectLoopScalars(VF);
+ }
+
+ /// Keeps the cost model's vectorization decision and cost for instructions.
+ /// Right now it is used for memory instructions only.
+ typedef DenseMap<std::pair<Instruction *, unsigned>,
+ std::pair<InstWidening, unsigned>>
+ DecisionList;
+
+ DecisionList WideningDecisions;
+
 public:
 /// The loop that we evaluate.
 Loop *TheLoop;
@@ -2019,6 +2205,44 @@ public:
 SmallPtrSet<const Value *, 16> VecValuesToIgnore;
 };

+/// LoopVectorizationPlanner - drives the vectorization process after having
+/// passed Legality checks.
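+/// Illustrative driver sequence (a sketch, assuming an existing legality
+/// analysis LVL, cost model CM and InnerLoopVectorizer LB):
+///   LoopVectorizationPlanner LVP(L, LI, &LVL, CM);
+///   auto VF = LVP.plan(OptForSize, UserVF); // pick a VF and its cost
+///   LVP.executePlan(LB);                    // generate the vector loop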
+class LoopVectorizationPlanner {
+public:
+ LoopVectorizationPlanner(Loop *OrigLoop, LoopInfo *LI,
+ LoopVectorizationLegality *Legal,
+ LoopVectorizationCostModel &CM)
+ : OrigLoop(OrigLoop), LI(LI), Legal(Legal), CM(CM) {}
+
+ ~LoopVectorizationPlanner() {}
+
+ /// Plan how to best vectorize, return the best VF and its cost.
+ LoopVectorizationCostModel::VectorizationFactor plan(bool OptForSize,
+ unsigned UserVF);
+
+ /// Generate the IR code for the vectorized loop.
+ void executePlan(InnerLoopVectorizer &ILV);
+
+protected:
+ /// Collect the instructions from the original loop that would be trivially
+ /// dead in the vectorized loop if generated.
+ void collectTriviallyDeadInstructions(
+ SmallPtrSetImpl<Instruction *> &DeadInstructions);
+
+private:
+ /// The loop that we evaluate.
+ Loop *OrigLoop;
+
+ /// Loop Info analysis.
+ LoopInfo *LI;
+
+ /// The legality analysis.
+ LoopVectorizationLegality *Legal;
+
+ /// The profitability analysis.
+ LoopVectorizationCostModel &CM;
+};
+
 /// \brief This holds vectorization requirements that must be verified late in
 /// the process. The requirements are set by legalize and costmodel. Once
 /// vectorization has been determined to be possible and profitable the
@@ -2134,8 +2358,6 @@ struct LoopVectorize : public FunctionPass {
 void getAnalysisUsage(AnalysisUsage &AU) const override {
 AU.addRequired<AssumptionCacheTracker>();
- AU.addRequiredID(LoopSimplifyID);
- AU.addRequiredID(LCSSAID);
 AU.addRequired<BlockFrequencyInfoWrapperPass>();
 AU.addRequired<DominatorTreeWrapperPass>();
 AU.addRequired<LoopInfoWrapperPass>();
@@ -2156,7 +2378,7 @@ struct LoopVectorize : public FunctionPass {

//===----------------------------------------------------------------------===//
// Implementation of LoopVectorizationLegality, InnerLoopVectorizer and
-// LoopVectorizationCostModel.
+// LoopVectorizationCostModel and LoopVectorizationPlanner.
//===----------------------------------------------------------------------===//

 Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) {
@@ -2176,41 +2398,63 @@ Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) {
 return Shuf;
 }

-void InnerLoopVectorizer::createVectorIntInductionPHI(
- const InductionDescriptor &II, Instruction *EntryVal) {
+void InnerLoopVectorizer::createVectorIntOrFpInductionPHI(
+ const InductionDescriptor &II, Value *Step, Instruction *EntryVal) {
 Value *Start = II.getStartValue();
- ConstantInt *Step = II.getConstIntStepValue();
- assert(Step && "Can not widen an IV with a non-constant step");

 // Construct the initial value of the vector IV in the vector loop preheader
 auto CurrIP = Builder.saveIP();
 Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator());
 if (isa<TruncInst>(EntryVal)) {
+ assert(Start->getType()->isIntegerTy() &&
+ "Truncation requires an integer type");
 auto *TruncType = cast<IntegerType>(EntryVal->getType());
- Step = ConstantInt::getSigned(TruncType, Step->getSExtValue());
+ Step = Builder.CreateTrunc(Step, TruncType);
 Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType);
 }
 Value *SplatStart = Builder.CreateVectorSplat(VF, Start);
- Value *SteppedStart = getStepVector(SplatStart, 0, Step);
+ Value *SteppedStart =
+ getStepVector(SplatStart, 0, Step, II.getInductionOpcode());
+
+ // We create vector phi nodes for both integer and floating-point induction
+ // variables. Here, we determine the kind of arithmetic we will perform.
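+ // For example (illustrative): an i64 induction selects Add/Mul below,
+ // while a float induction with step 2.0 selects its recurrence FAdd
+ // together with FMul.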
+ Instruction::BinaryOps AddOp; + Instruction::BinaryOps MulOp; + if (Step->getType()->isIntegerTy()) { + AddOp = Instruction::Add; + MulOp = Instruction::Mul; + } else { + AddOp = II.getInductionOpcode(); + MulOp = Instruction::FMul; + } + + // Multiply the vectorization factor by the step using integer or + // floating-point arithmetic as appropriate. + Value *ConstVF = getSignedIntOrFpConstant(Step->getType(), VF); + Value *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, Step, ConstVF)); + + // Create a vector splat to use in the induction update. + // + // FIXME: If the step is non-constant, we create the vector splat with + // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't + // handle a constant vector splat. + Value *SplatVF = isa<Constant>(Mul) + ? ConstantVector::getSplat(VF, cast<Constant>(Mul)) + : Builder.CreateVectorSplat(VF, Mul); Builder.restoreIP(CurrIP); - Value *SplatVF = - ConstantVector::getSplat(VF, ConstantInt::getSigned(Start->getType(), - VF * Step->getSExtValue())); // We may need to add the step a number of times, depending on the unroll // factor. The last of those goes into the PHI. PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", &*LoopVectorBody->getFirstInsertionPt()); Instruction *LastInduction = VecInd; - VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { - Entry[Part] = LastInduction; - LastInduction = cast<Instruction>( - Builder.CreateAdd(LastInduction, SplatVF, "step.add")); + VectorLoopValueMap.setVectorValue(EntryVal, Part, LastInduction); + if (isa<TruncInst>(EntryVal)) + addMetadata(LastInduction, EntryVal); + LastInduction = cast<Instruction>(addFastMathFlag( + Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add"))); } - VectorLoopValueMap.initVector(EntryVal, Entry); - if (isa<TruncInst>(EntryVal)) - addMetadata(Entry, EntryVal); // Move the last step to the end of the latch block. This ensures consistent // placement of all induction updates. @@ -2225,7 +2469,7 @@ void InnerLoopVectorizer::createVectorIntInductionPHI( } bool InnerLoopVectorizer::shouldScalarizeInstruction(Instruction *I) const { - return Legal->isScalarAfterVectorization(I) || + return Cost->isScalarAfterVectorization(I, VF) || Cost->isProfitableToScalarize(I, VF); } @@ -2239,7 +2483,10 @@ bool InnerLoopVectorizer::needsScalarInduction(Instruction *IV) const { return any_of(IV->users(), isScalarInst); } -void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) { +void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) { + + assert((IV->getType()->isIntegerTy() || IV != OldInduction) && + "Primary induction variable must have an integer type"); auto II = Legal->getInductionVars()->find(IV); assert(II != Legal->getInductionVars()->end() && "IV is not an induction"); @@ -2251,9 +2498,6 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) { // induction variable. Value *ScalarIV = nullptr; - // The step of the induction. - Value *Step = nullptr; - // The value from the original loop to which we are mapping the new induction // variable. Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV; @@ -2266,45 +2510,49 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) { // least one user in the loop that is not widened. auto NeedsScalarIV = VF > 1 && needsScalarInduction(EntryVal); - // If the induction variable has a constant integer step value, go ahead and - // get it now. 
- if (ID.getConstIntStepValue()) - Step = ID.getConstIntStepValue(); + // Generate code for the induction step. Note that induction steps are + // required to be loop-invariant + assert(PSE.getSE()->isLoopInvariant(ID.getStep(), OrigLoop) && + "Induction step should be loop invariant"); + auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + Value *Step = nullptr; + if (PSE.getSE()->isSCEVable(IV->getType())) { + SCEVExpander Exp(*PSE.getSE(), DL, "induction"); + Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(), + LoopVectorPreHeader->getTerminator()); + } else { + Step = cast<SCEVUnknown>(ID.getStep())->getValue(); + } // Try to create a new independent vector induction variable. If we can't // create the phi node, we will splat the scalar induction variable in each // loop iteration. - if (VF > 1 && IV->getType() == Induction->getType() && Step && - !shouldScalarizeInstruction(EntryVal)) { - createVectorIntInductionPHI(ID, EntryVal); + if (VF > 1 && !shouldScalarizeInstruction(EntryVal)) { + createVectorIntOrFpInductionPHI(ID, Step, EntryVal); VectorizedIV = true; } // If we haven't yet vectorized the induction variable, or if we will create // a scalar one, we need to define the scalar induction variable and step // values. If we were given a truncation type, truncate the canonical - // induction variable and constant step. Otherwise, derive these values from - // the induction descriptor. + // induction variable and step. Otherwise, derive these values from the + // induction descriptor. if (!VectorizedIV || NeedsScalarIV) { + ScalarIV = Induction; + if (IV != OldInduction) { + ScalarIV = IV->getType()->isIntegerTy() + ? Builder.CreateSExtOrTrunc(Induction, IV->getType()) + : Builder.CreateCast(Instruction::SIToFP, Induction, + IV->getType()); + ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); + ScalarIV->setName("offset.idx"); + } if (Trunc) { auto *TruncType = cast<IntegerType>(Trunc->getType()); - assert(Step && "Truncation requires constant integer step"); - auto StepInt = cast<ConstantInt>(Step)->getSExtValue(); - ScalarIV = Builder.CreateCast(Instruction::Trunc, Induction, TruncType); - Step = ConstantInt::getSigned(TruncType, StepInt); - } else { - ScalarIV = Induction; - auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); - if (IV != OldInduction) { - ScalarIV = Builder.CreateSExtOrTrunc(ScalarIV, IV->getType()); - ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); - ScalarIV->setName("offset.idx"); - } - if (!Step) { - SCEVExpander Exp(*PSE.getSE(), DL, "induction"); - Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(), - &*Builder.GetInsertPoint()); - } + assert(Step->getType()->isIntegerTy() && + "Truncation requires an integer step"); + ScalarIV = Builder.CreateTrunc(ScalarIV, TruncType); + Step = Builder.CreateTrunc(Step, TruncType); } } @@ -2312,12 +2560,13 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) { // induction variable, and build the necessary step vectors. 
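+ // E.g. (illustrative) with ScalarIV = 0, Step = 2 and VF = 4, part 0 of
+ // the step vector built below is <0, 2, 4, 6>.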
if (!VectorizedIV) { Value *Broadcasted = getBroadcastInstrs(ScalarIV); - VectorParts Entry(UF); - for (unsigned Part = 0; Part < UF; ++Part) - Entry[Part] = getStepVector(Broadcasted, VF * Part, Step); - VectorLoopValueMap.initVector(EntryVal, Entry); - if (Trunc) - addMetadata(Entry, Trunc); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *EntryPart = + getStepVector(Broadcasted, VF * Part, Step, ID.getInductionOpcode()); + VectorLoopValueMap.setVectorValue(EntryVal, Part, EntryPart); + if (Trunc) + addMetadata(EntryPart, Trunc); + } } // If an induction variable is only used for counting loop iterations or @@ -2327,7 +2576,7 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) { // in the loop in the common case prior to InstCombine. We will be trading // one vector extract for each scalar step. if (NeedsScalarIV) - buildScalarSteps(ScalarIV, Step, EntryVal); + buildScalarSteps(ScalarIV, Step, EntryVal, ID); } Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, @@ -2387,34 +2636,44 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, } void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, - Value *EntryVal) { + Value *EntryVal, + const InductionDescriptor &ID) { // We shouldn't have to build scalar steps if we aren't vectorizing. assert(VF > 1 && "VF should be greater than one"); // Get the value type and ensure it and the step have the same integer type. Type *ScalarIVTy = ScalarIV->getType()->getScalarType(); - assert(ScalarIVTy->isIntegerTy() && ScalarIVTy == Step->getType() && - "Val and Step should have the same integer type"); + assert(ScalarIVTy == Step->getType() && + "Val and Step should have the same type"); + + // We build scalar steps for both integer and floating-point induction + // variables. Here, we determine the kind of arithmetic we will perform. + Instruction::BinaryOps AddOp; + Instruction::BinaryOps MulOp; + if (ScalarIVTy->isIntegerTy()) { + AddOp = Instruction::Add; + MulOp = Instruction::Mul; + } else { + AddOp = ID.getInductionOpcode(); + MulOp = Instruction::FMul; + } // Determine the number of scalars we need to generate for each unroll // iteration. If EntryVal is uniform, we only need to generate the first // lane. Otherwise, we generate all VF values. unsigned Lanes = - Legal->isUniformAfterVectorization(cast<Instruction>(EntryVal)) ? 1 : VF; + Cost->isUniformAfterVectorization(cast<Instruction>(EntryVal), VF) ? 1 : VF; // Compute the scalar steps and save the results in VectorLoopValueMap. 
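+ // E.g. (illustrative) with VF = 4, UF = 2 and Step = 1, part 0 defines
+ // ScalarIV + {0,1,2,3} and part 1 defines ScalarIV + {4,5,6,7}.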
- ScalarParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { - Entry[Part].resize(VF); for (unsigned Lane = 0; Lane < Lanes; ++Lane) { - auto *StartIdx = ConstantInt::get(ScalarIVTy, VF * Part + Lane); - auto *Mul = Builder.CreateMul(StartIdx, Step); - auto *Add = Builder.CreateAdd(ScalarIV, Mul); - Entry[Part][Lane] = Add; + auto *StartIdx = getSignedIntOrFpConstant(ScalarIVTy, VF * Part + Lane); + auto *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, StartIdx, Step)); + auto *Add = addFastMathFlag(Builder.CreateBinOp(AddOp, ScalarIV, Mul)); + VectorLoopValueMap.setScalarValue(EntryVal, Part, Lane, Add); } } - VectorLoopValueMap.initScalar(EntryVal, Entry); } int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { @@ -2432,8 +2691,7 @@ bool LoopVectorizationLegality::isUniform(Value *V) { return LAI->isUniform(V); } -const InnerLoopVectorizer::VectorParts & -InnerLoopVectorizer::getVectorValue(Value *V) { +Value *InnerLoopVectorizer::getOrCreateVectorValue(Value *V, unsigned Part) { assert(V != Induction && "The new induction variable should not be used."); assert(!V->getType()->isVectorTy() && "Can't widen a vector"); assert(!V->getType()->isVoidTy() && "Type does not produce a value"); @@ -2442,17 +2700,16 @@ InnerLoopVectorizer::getVectorValue(Value *V) { if (Legal->hasStride(V)) V = ConstantInt::get(V->getType(), 1); - // If we have this scalar in the map, return it. - if (VectorLoopValueMap.hasVector(V)) - return VectorLoopValueMap.VectorMapStorage[V]; + // If we have a vector mapped to this value, return it. + if (VectorLoopValueMap.hasVectorValue(V, Part)) + return VectorLoopValueMap.getVectorValue(V, Part); // If the value has not been vectorized, check if it has been scalarized // instead. If it has been scalarized, and we actually need the value in // vector form, we will construct the vector values on demand. - if (VectorLoopValueMap.hasScalar(V)) { + if (VectorLoopValueMap.hasAnyScalarValue(V)) { - // Initialize a new vector map entry. - VectorParts Entry(UF); + Value *ScalarValue = VectorLoopValueMap.getScalarValue(V, Part, 0); // If we've scalarized a value, that value should be an instruction. auto *I = cast<Instruction>(V); @@ -2460,17 +2717,17 @@ InnerLoopVectorizer::getVectorValue(Value *V) { // If we aren't vectorizing, we can just copy the scalar map values over to // the vector map. if (VF == 1) { - for (unsigned Part = 0; Part < UF; ++Part) - Entry[Part] = getScalarValue(V, Part, 0); - return VectorLoopValueMap.initVector(V, Entry); + VectorLoopValueMap.setVectorValue(V, Part, ScalarValue); + return ScalarValue; } - // Get the last scalar instruction we generated for V. If the value is - // known to be uniform after vectorization, this corresponds to lane zero - // of the last unroll iteration. Otherwise, the last instruction is the one - // we created for the last vector lane of the last unroll iteration. - unsigned LastLane = Legal->isUniformAfterVectorization(I) ? 0 : VF - 1; - auto *LastInst = cast<Instruction>(getScalarValue(V, UF - 1, LastLane)); + // Get the last scalar instruction we generated for V and Part. If the value + // is known to be uniform after vectorization, this corresponds to lane zero + // of the Part unroll iteration. Otherwise, the last instruction is the one + // we created for the last vector lane of the Part unroll iteration. + unsigned LastLane = Cost->isUniformAfterVectorization(I, VF) ? 
0 : VF - 1; + auto *LastInst = + cast<Instruction>(VectorLoopValueMap.getScalarValue(V, Part, LastLane)); // Set the insert point after the last scalarized instruction. This ensures // the insertelement sequence will directly follow the scalar definitions. @@ -2484,51 +2741,50 @@ InnerLoopVectorizer::getVectorValue(Value *V) { // iteration. Otherwise, we construct the vector values using insertelement // instructions. Since the resulting vectors are stored in // VectorLoopValueMap, we will only generate the insertelements once. - for (unsigned Part = 0; Part < UF; ++Part) { - Value *VectorValue = nullptr; - if (Legal->isUniformAfterVectorization(I)) { - VectorValue = getBroadcastInstrs(getScalarValue(V, Part, 0)); - } else { - VectorValue = UndefValue::get(VectorType::get(V->getType(), VF)); - for (unsigned Lane = 0; Lane < VF; ++Lane) - VectorValue = Builder.CreateInsertElement( - VectorValue, getScalarValue(V, Part, Lane), - Builder.getInt32(Lane)); - } - Entry[Part] = VectorValue; + Value *VectorValue = nullptr; + if (Cost->isUniformAfterVectorization(I, VF)) { + VectorValue = getBroadcastInstrs(ScalarValue); + } else { + VectorValue = UndefValue::get(VectorType::get(V->getType(), VF)); + for (unsigned Lane = 0; Lane < VF; ++Lane) + VectorValue = Builder.CreateInsertElement( + VectorValue, getOrCreateScalarValue(V, Part, Lane), + Builder.getInt32(Lane)); } + VectorLoopValueMap.setVectorValue(V, Part, VectorValue); Builder.restoreIP(OldIP); - return VectorLoopValueMap.initVector(V, Entry); + return VectorValue; } // If this scalar is unknown, assume that it is a constant or that it is // loop invariant. Broadcast V and save the value for future uses. Value *B = getBroadcastInstrs(V); - return VectorLoopValueMap.initVector(V, VectorParts(UF, B)); + VectorLoopValueMap.setVectorValue(V, Part, B); + return B; } -Value *InnerLoopVectorizer::getScalarValue(Value *V, unsigned Part, - unsigned Lane) { +Value *InnerLoopVectorizer::getOrCreateScalarValue(Value *V, unsigned Part, + unsigned Lane) { // If the value is not an instruction contained in the loop, it should // already be scalar. if (OrigLoop->isLoopInvariant(V)) return V; - assert(Lane > 0 ? !Legal->isUniformAfterVectorization(cast<Instruction>(V)) + assert(Lane > 0 ? !Cost->isUniformAfterVectorization(cast<Instruction>(V), VF) : true && "Uniform values only have lane zero"); // If the value from the original loop has not been vectorized, it is // represented by UF x VF scalar values in the new loop. Return the requested // scalar value. - if (VectorLoopValueMap.hasScalar(V)) - return VectorLoopValueMap.ScalarMapStorage[V][Part][Lane]; + if (VectorLoopValueMap.hasScalarValue(V, Part, Lane)) + return VectorLoopValueMap.getScalarValue(V, Part, Lane); // If the value has not been scalarized, get its entry in VectorLoopValueMap // for the given unroll part. If this entry is not a vector type (i.e., the // vectorization factor is one), there is no need to generate an // extractelement instruction. - auto *U = getVectorValue(V)[Part]; + auto *U = getOrCreateVectorValue(V, Part); if (!U->getType()->isVectorTy()) { assert(VF == 1 && "Value not scalarized has non-vector type"); return U; @@ -2551,102 +2807,6 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) { "reverse"); } -// Get a mask to interleave \p NumVec vectors into a wide vector. -// I.e. <0, VF, VF*2, ..., VF*(NumVec-1), 1, VF+1, VF*2+1, ...> -// E.g. 
For 2 interleaved vectors, if VF is 4, the mask is: -// <0, 4, 1, 5, 2, 6, 3, 7> -static Constant *getInterleavedMask(IRBuilder<> &Builder, unsigned VF, - unsigned NumVec) { - SmallVector<Constant *, 16> Mask; - for (unsigned i = 0; i < VF; i++) - for (unsigned j = 0; j < NumVec; j++) - Mask.push_back(Builder.getInt32(j * VF + i)); - - return ConstantVector::get(Mask); -} - -// Get the strided mask starting from index \p Start. -// I.e. <Start, Start + Stride, ..., Start + Stride*(VF-1)> -static Constant *getStridedMask(IRBuilder<> &Builder, unsigned Start, - unsigned Stride, unsigned VF) { - SmallVector<Constant *, 16> Mask; - for (unsigned i = 0; i < VF; i++) - Mask.push_back(Builder.getInt32(Start + i * Stride)); - - return ConstantVector::get(Mask); -} - -// Get a mask of two parts: The first part consists of sequential integers -// starting from 0, The second part consists of UNDEFs. -// I.e. <0, 1, 2, ..., NumInt - 1, undef, ..., undef> -static Constant *getSequentialMask(IRBuilder<> &Builder, unsigned NumInt, - unsigned NumUndef) { - SmallVector<Constant *, 16> Mask; - for (unsigned i = 0; i < NumInt; i++) - Mask.push_back(Builder.getInt32(i)); - - Constant *Undef = UndefValue::get(Builder.getInt32Ty()); - for (unsigned i = 0; i < NumUndef; i++) - Mask.push_back(Undef); - - return ConstantVector::get(Mask); -} - -// Concatenate two vectors with the same element type. The 2nd vector should -// not have more elements than the 1st vector. If the 2nd vector has less -// elements, extend it with UNDEFs. -static Value *ConcatenateTwoVectors(IRBuilder<> &Builder, Value *V1, - Value *V2) { - VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType()); - VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType()); - assert(VecTy1 && VecTy2 && - VecTy1->getScalarType() == VecTy2->getScalarType() && - "Expect two vectors with the same element type"); - - unsigned NumElts1 = VecTy1->getNumElements(); - unsigned NumElts2 = VecTy2->getNumElements(); - assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements"); - - if (NumElts1 > NumElts2) { - // Extend with UNDEFs. - Constant *ExtMask = - getSequentialMask(Builder, NumElts2, NumElts1 - NumElts2); - V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask); - } - - Constant *Mask = getSequentialMask(Builder, NumElts1 + NumElts2, 0); - return Builder.CreateShuffleVector(V1, V2, Mask); -} - -// Concatenate vectors in the given list. All vectors have the same type. -static Value *ConcatenateVectors(IRBuilder<> &Builder, - ArrayRef<Value *> InputList) { - unsigned NumVec = InputList.size(); - assert(NumVec > 1 && "Should be at least two vectors"); - - SmallVector<Value *, 8> ResList; - ResList.append(InputList.begin(), InputList.end()); - do { - SmallVector<Value *, 8> TmpList; - for (unsigned i = 0; i < NumVec - 1; i += 2) { - Value *V0 = ResList[i], *V1 = ResList[i + 1]; - assert((V0->getType() == V1->getType() || i == NumVec - 2) && - "Only the last vector may have a different type"); - - TmpList.push_back(ConcatenateTwoVectors(Builder, V0, V1)); - } - - // Push the last vector if the total number of vectors is odd. - if (NumVec % 2 != 0) - TmpList.push_back(ResList[NumVec - 1]); - - ResList = TmpList; - NumVec = ResList.size(); - } while (NumVec > 1); - - return ResList[0]; -} - // Try to vectorize the interleave group that \p Instr belongs to. // // E.g. 
Translate following interleaved load group (factor = 3): @@ -2683,15 +2843,13 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { if (Instr != Group->getInsertPos()) return; - LoadInst *LI = dyn_cast<LoadInst>(Instr); - StoreInst *SI = dyn_cast<StoreInst>(Instr); Value *Ptr = getPointerOperand(Instr); // Prepare for the vector type of the interleaved load/store. - Type *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType(); + Type *ScalarTy = getMemInstValueType(Instr); unsigned InterleaveFactor = Group->getFactor(); Type *VecTy = VectorType::get(ScalarTy, InterleaveFactor * VF); - Type *PtrTy = VecTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); + Type *PtrTy = VecTy->getPointerTo(getMemInstAddressSpace(Instr)); // Prepare for the new pointers. setDebugLocFromInst(Builder, Ptr); @@ -2708,7 +2866,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Index += (VF - 1) * Group->getFactor(); for (unsigned Part = 0; Part < UF; Part++) { - Value *NewPtr = getScalarValue(Ptr, Part, 0); + Value *NewPtr = getOrCreateScalarValue(Ptr, Part, 0); // Notice current instruction could be any index. Need to adjust the address // to the member of index 0. @@ -2731,7 +2889,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Value *UndefVec = UndefValue::get(VecTy); // Vectorize the interleaved load group. - if (LI) { + if (isa<LoadInst>(Instr)) { // For each unroll part, create a wide load for the group. SmallVector<Value *, 2> NewLoads; @@ -2751,8 +2909,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { if (!Member) continue; - VectorParts Entry(UF); - Constant *StrideMask = getStridedMask(Builder, I, InterleaveFactor, VF); + Constant *StrideMask = createStrideMask(Builder, I, InterleaveFactor, VF); for (unsigned Part = 0; Part < UF; Part++) { Value *StridedVec = Builder.CreateShuffleVector( NewLoads[Part], UndefVec, StrideMask, "strided.vec"); @@ -2763,10 +2920,11 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { StridedVec = Builder.CreateBitOrPointerCast(StridedVec, OtherVTy); } - Entry[Part] = - Group->isReverse() ? reverseVector(StridedVec) : StridedVec; + if (Group->isReverse()) + StridedVec = reverseVector(StridedVec); + + VectorLoopValueMap.setVectorValue(Member, Part, StridedVec); } - VectorLoopValueMap.initVector(Member, Entry); } return; } @@ -2783,8 +2941,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Instruction *Member = Group->getMember(i); assert(Member && "Fail to get a member from an interleaved store group"); - Value *StoredVec = - getVectorValue(cast<StoreInst>(Member)->getValueOperand())[Part]; + Value *StoredVec = getOrCreateVectorValue( + cast<StoreInst>(Member)->getValueOperand(), Part); if (Group->isReverse()) StoredVec = reverseVector(StoredVec); @@ -2796,10 +2954,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { } // Concatenate all vectors into a wide vector. - Value *WideVec = ConcatenateVectors(Builder, StoredVecs); + Value *WideVec = concatenateVectors(Builder, StoredVecs); // Interleave the elements in the wide vector. 
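+ // E.g. (illustrative) for factor 3 and VF 4 the interleave mask is
+ // <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11>.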
- Constant *IMask = getInterleavedMask(Builder, VF, InterleaveFactor); + Constant *IMask = createInterleaveMask(Builder, VF, InterleaveFactor); Value *IVec = Builder.CreateShuffleVector(WideVec, UndefVec, IMask, "interleaved.vec"); @@ -2816,104 +2974,43 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { assert((LI || SI) && "Invalid Load/Store instruction"); - // Try to vectorize the interleave group if this access is interleaved. - if (Legal->isAccessInterleaved(Instr)) + LoopVectorizationCostModel::InstWidening Decision = + Cost->getWideningDecision(Instr, VF); + assert(Decision != LoopVectorizationCostModel::CM_Unknown && + "CM decision should be taken at this point"); + if (Decision == LoopVectorizationCostModel::CM_Interleave) return vectorizeInterleaveGroup(Instr); - Type *ScalarDataTy = LI ? LI->getType() : SI->getValueOperand()->getType(); + Type *ScalarDataTy = getMemInstValueType(Instr); Type *DataTy = VectorType::get(ScalarDataTy, VF); Value *Ptr = getPointerOperand(Instr); - unsigned Alignment = LI ? LI->getAlignment() : SI->getAlignment(); + unsigned Alignment = getMemInstAlignment(Instr); // An alignment of 0 means target abi alignment. We need to use the scalar's // target abi alignment in such a case. const DataLayout &DL = Instr->getModule()->getDataLayout(); if (!Alignment) Alignment = DL.getABITypeAlignment(ScalarDataTy); - unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); + unsigned AddressSpace = getMemInstAddressSpace(Instr); // Scalarize the memory instruction if necessary. - if (Legal->memoryInstructionMustBeScalarized(Instr, VF)) + if (Decision == LoopVectorizationCostModel::CM_Scalarize) return scalarizeInstruction(Instr, Legal->isScalarWithPredication(Instr)); // Determine if the pointer operand of the access is either consecutive or // reverse consecutive. int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); bool Reverse = ConsecutiveStride < 0; - - // Determine if either a gather or scatter operation is legal. bool CreateGatherScatter = - !ConsecutiveStride && Legal->isLegalGatherOrScatter(Instr); + (Decision == LoopVectorizationCostModel::CM_GatherScatter); - VectorParts VectorGep; + // Either Ptr feeds a vector load/store, or a vector GEP should feed a vector + // gather/scatter. Otherwise Decision should have been to Scalarize. + assert((ConsecutiveStride || CreateGatherScatter) && + "The instruction should be scalarized"); // Handle consecutive loads/stores. - GetElementPtrInst *Gep = getGEPInstruction(Ptr); - if (ConsecutiveStride) { - if (Gep) { - unsigned NumOperands = Gep->getNumOperands(); -#ifndef NDEBUG - // The original GEP that identified as a consecutive memory access - // should have only one loop-variant operand. - unsigned NumOfLoopVariantOps = 0; - for (unsigned i = 0; i < NumOperands; ++i) - if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), - OrigLoop)) - NumOfLoopVariantOps++; - assert(NumOfLoopVariantOps == 1 && - "Consecutive GEP should have only one loop-variant operand"); -#endif - GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); - Gep2->setName("gep.indvar"); - - // A new GEP is created for a 0-lane value of the first unroll iteration. - // The GEPs for the rest of the unroll iterations are computed below as an - // offset from this GEP. - for (unsigned i = 0; i < NumOperands; ++i) - // We can apply getScalarValue() for all GEP indices. It returns an - // original value for loop-invariant operand and 0-lane for consecutive - // operand. 
- Gep2->setOperand(i, getScalarValue(Gep->getOperand(i), - 0, /* First unroll iteration */ - 0 /* 0-lane of the vector */ )); - setDebugLocFromInst(Builder, Gep); - Ptr = Builder.Insert(Gep2); - - } else { // No GEP - setDebugLocFromInst(Builder, Ptr); - Ptr = getScalarValue(Ptr, 0, 0); - } - } else { - // At this point we should vector version of GEP for Gather or Scatter - assert(CreateGatherScatter && "The instruction should be scalarized"); - if (Gep) { - // Vectorizing GEP, across UF parts. We want to get a vector value for base - // and each index that's defined inside the loop, even if it is - // loop-invariant but wasn't hoisted out. Otherwise we want to keep them - // scalar. - SmallVector<VectorParts, 4> OpsV; - for (Value *Op : Gep->operands()) { - Instruction *SrcInst = dyn_cast<Instruction>(Op); - if (SrcInst && OrigLoop->contains(SrcInst)) - OpsV.push_back(getVectorValue(Op)); - else - OpsV.push_back(VectorParts(UF, Op)); - } - for (unsigned Part = 0; Part < UF; ++Part) { - SmallVector<Value *, 4> Ops; - Value *GEPBasePtr = OpsV[0][Part]; - for (unsigned i = 1; i < Gep->getNumOperands(); i++) - Ops.push_back(OpsV[i][Part]); - Value *NewGep = Builder.CreateGEP(GEPBasePtr, Ops, "VectorGep"); - cast<GetElementPtrInst>(NewGep)->setIsInBounds(Gep->isInBounds()); - assert(NewGep->getType()->isVectorTy() && "Expected vector GEP"); - - NewGep = - Builder.CreateBitCast(NewGep, VectorType::get(Ptr->getType(), VF)); - VectorGep.push_back(NewGep); - } - } else - VectorGep = getVectorValue(Ptr); - } + if (ConsecutiveStride) + Ptr = getOrCreateScalarValue(Ptr, 0, 0); VectorParts Mask = createBlockInMask(Instr->getParent()); // Handle Stores: @@ -2921,16 +3018,15 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { assert(!Legal->isUniform(SI->getPointerOperand()) && "We do not allow storing to uniform addresses"); setDebugLocFromInst(Builder, SI); - // We don't want to update the value in the map as it might be used in - // another expression. So don't use a reference type for "StoredVal". - VectorParts StoredVal = getVectorValue(SI->getValueOperand()); for (unsigned Part = 0; Part < UF; ++Part) { Instruction *NewSI = nullptr; + Value *StoredVal = getOrCreateVectorValue(SI->getValueOperand(), Part); if (CreateGatherScatter) { Value *MaskPart = Legal->isMaskRequired(SI) ? Mask[Part] : nullptr; - NewSI = Builder.CreateMaskedScatter(StoredVal[Part], VectorGep[Part], - Alignment, MaskPart); + Value *VectorGep = getOrCreateVectorValue(Ptr, Part); + NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, + MaskPart); } else { // Calculate the pointer for the specific unroll-part. Value *PartPtr = @@ -2939,7 +3035,10 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { if (Reverse) { // If we store to reverse consecutive memory locations, then we need // to reverse the order of elements in the stored value. - StoredVal[Part] = reverseVector(StoredVal[Part]); + StoredVal = reverseVector(StoredVal); + // We don't want to update the value in the map as it might be used in + // another expression. So don't call resetVectorValue(StoredVal). + // If the address is consecutive but reversed, then the // wide store needs to start at the last vector element. 
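+ // E.g. (illustrative) with VF = 4, part 0 writes to [Ptr - 3, Ptr] and
+ // part 1 writes to [Ptr - 7, Ptr - 4].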
PartPtr = @@ -2953,11 +3052,10 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); if (Legal->isMaskRequired(SI)) - NewSI = Builder.CreateMaskedStore(StoredVal[Part], VecPtr, Alignment, + NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment, Mask[Part]); else - NewSI = - Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); + NewSI = Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment); } addMetadata(NewSI, SI); } @@ -2967,14 +3065,14 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { // Handle loads. assert(LI && "Must have a load instruction"); setDebugLocFromInst(Builder, LI); - VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { - Instruction *NewLI; + Value *NewLI; if (CreateGatherScatter) { Value *MaskPart = Legal->isMaskRequired(LI) ? Mask[Part] : nullptr; - NewLI = Builder.CreateMaskedGather(VectorGep[Part], Alignment, MaskPart, - 0, "wide.masked.gather"); - Entry[Part] = NewLI; + Value *VectorGep = getOrCreateVectorValue(Ptr, Part); + NewLI = Builder.CreateMaskedGather(VectorGep, Alignment, MaskPart, + nullptr, "wide.masked.gather"); + addMetadata(NewLI, LI); } else { // Calculate the pointer for the specific unroll-part. Value *PartPtr = @@ -2996,11 +3094,14 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { "wide.masked.load"); else NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); - Entry[Part] = Reverse ? reverseVector(NewLI) : NewLI; + + // Add metadata to the load, but setVectorValue to the reverse shuffle. + addMetadata(NewLI, LI); + if (Reverse) + NewLI = reverseVector(NewLI); } - addMetadata(NewLI, LI); + VectorLoopValueMap.setVectorValue(Instr, Part, NewLI); } - VectorLoopValueMap.initVector(Instr, Entry); } void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, @@ -3017,9 +3118,6 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, // Does this instruction return a value ? bool IsVoidRetTy = Instr->getType()->isVoidTy(); - // Initialize a new scalar map entry. - ScalarParts Entry(UF); - VectorParts Cond; if (IfPredicateInstr) Cond = createBlockInMask(Instr->getParent()); @@ -3027,18 +3125,19 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, // Determine the number of scalars we need to generate for each unroll // iteration. If the instruction is uniform, we only need to generate the // first lane. Otherwise, we generate all VF values. - unsigned Lanes = Legal->isUniformAfterVectorization(Instr) ? 1 : VF; + unsigned Lanes = Cost->isUniformAfterVectorization(Instr, VF) ? 1 : VF; // For each vector unroll 'part': for (unsigned Part = 0; Part < UF; ++Part) { - Entry[Part].resize(VF); // For each scalar that we create: for (unsigned Lane = 0; Lane < Lanes; ++Lane) { // Start if-block. Value *Cmp = nullptr; if (IfPredicateInstr) { - Cmp = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(Lane)); + Cmp = Cond[Part]; + if (Cmp->getType()->isVectorTy()) + Cmp = Builder.CreateExtractElement(Cmp, Builder.getInt32(Lane)); Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp, ConstantInt::get(Cmp->getType(), 1)); } @@ -3050,7 +3149,7 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, // Replace the operands of the cloned instructions with their scalar // equivalents in the new loop. 
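+ // E.g. (illustrative) an operand that only exists in widened form is
+ // replaced by its (Part, Lane) scalar, extracted on demand.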
for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - auto *NewOp = getScalarValue(Instr->getOperand(op), Part, Lane); + auto *NewOp = getOrCreateScalarValue(Instr->getOperand(op), Part, Lane); Cloned->setOperand(op, NewOp); } addNewMetadata(Cloned, Instr); @@ -3059,7 +3158,7 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, Builder.Insert(Cloned); // Add the cloned scalar to the scalar map entry. - Entry[Part][Lane] = Cloned; + VectorLoopValueMap.setScalarValue(Instr, Part, Lane, Cloned); // If we just cloned a new assumption, add it the assumption cache. if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) @@ -3071,7 +3170,6 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, PredicatedInstructions.push_back(std::make_pair(Cloned, Cmp)); } } - VectorLoopValueMap.initScalar(Instr, Entry); } PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, @@ -3189,37 +3287,16 @@ void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, BasicBlock *BB = L->getLoopPreheader(); IRBuilder<> Builder(BB->getTerminator()); - // Generate code to check that the loop's trip count that we computed by - // adding one to the backedge-taken count will not overflow. - Value *CheckMinIters = Builder.CreateICmpULT( - Count, ConstantInt::get(Count->getType(), VF * UF), "min.iters.check"); - - BasicBlock *NewBB = - BB->splitBasicBlock(BB->getTerminator(), "min.iters.checked"); - // Update dominator tree immediately if the generated block is a - // LoopBypassBlock because SCEV expansions to generate loop bypass - // checks may query it before the current function is finished. - DT->addNewBlock(NewBB, BB); - if (L->getParentLoop()) - L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); - ReplaceInstWithInst(BB->getTerminator(), - BranchInst::Create(Bypass, NewBB, CheckMinIters)); - LoopBypassBlocks.push_back(BB); -} - -void InnerLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L, - BasicBlock *Bypass) { - Value *TC = getOrCreateVectorTripCount(L); - BasicBlock *BB = L->getLoopPreheader(); - IRBuilder<> Builder(BB->getTerminator()); - - // Now, compare the new count to zero. If it is zero skip the vector loop and - // jump to the scalar loop. - Value *Cmp = Builder.CreateICmpEQ(TC, Constant::getNullValue(TC->getType()), - "cmp.zero"); + // Generate code to check if the loop's trip count is less than VF * UF, or + // equal to it in case a scalar epilogue is required; this implies that the + // vector trip count is zero. This check also covers the case where adding one + // to the backedge-taken count overflowed leading to an incorrect trip count + // of zero. In this case we will also jump to the scalar loop. + auto P = Legal->requiresScalarEpilogue() ? ICmpInst::ICMP_ULE + : ICmpInst::ICMP_ULT; + Value *CheckMinIters = Builder.CreateICmp( + P, Count, ConstantInt::get(Count->getType(), VF * UF), "min.iters.check"); - // Generate code to check that the loop's trip count that we computed by - // adding one to the backedge-taken count will not overflow. 
BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); // Update dominator tree immediately if the generated block is a // LoopBypassBlock because SCEV expansions to generate loop bypass @@ -3228,7 +3305,7 @@ void InnerLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L, if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), - BranchInst::Create(Bypass, NewBB, Cmp)); + BranchInst::Create(Bypass, NewBB, CheckMinIters)); LoopBypassBlocks.push_back(BB); } @@ -3296,7 +3373,7 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { LVer->prepareNoAliasMetadata(); } -void InnerLoopVectorizer::createEmptyLoop() { +void InnerLoopVectorizer::createVectorizedLoopSkeleton() { /* In this function we generate a new loop. The new loop will contain the vectorized instructions while the old loop will continue to run the @@ -3346,7 +3423,7 @@ void InnerLoopVectorizer::createEmptyLoop() { // - counts from zero, stepping by one // - is the size of the widest induction variable type // then we create a new one. - OldInduction = Legal->getInduction(); + OldInduction = Legal->getPrimaryInduction(); Type *IdxTy = Legal->getWidestInductionType(); // Split the single block loop into the two loop structure described above. @@ -3377,14 +3454,13 @@ void InnerLoopVectorizer::createEmptyLoop() { Value *StartIdx = ConstantInt::get(IdxTy, 0); - // We need to test whether the backedge-taken count is uint##_max. Adding one - // to it will cause overflow and an incorrect loop trip count in the vector - // body. In case of overflow we want to directly jump to the scalar remainder - // loop. - emitMinimumIterationCountCheck(Lp, ScalarPH); // Now, compare the new count to zero. If it is zero skip the vector loop and - // jump to the scalar loop. - emitVectorLoopEnteredCheck(Lp, ScalarPH); + // jump to the scalar loop. This check also covers the case where the + // backedge-taken count is uint##_max: adding one to it will overflow leading + // to an incorrect trip count of zero. In this (rare) case we will also jump + // to the scalar loop. + emitMinimumIterationCountCheck(Lp, ScalarPH); + // Generate the code to check any assumptions that we've made for SCEV // expressions. emitSCEVChecks(Lp, ScalarPH); @@ -3427,7 +3503,7 @@ void InnerLoopVectorizer::createEmptyLoop() { // We know what the end value is. EndValue = CountRoundDown; } else { - IRBuilder<> B(LoopBypassBlocks.back()->getTerminator()); + IRBuilder<> B(Lp->getLoopPreheader()->getTerminator()); Type *StepType = II.getStep()->getType(); Instruction::CastOps CastOp = CastInst::getCastOpcode(CountRoundDown, true, StepType, true); @@ -3521,8 +3597,12 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, IRBuilder<> B(MiddleBlock->getTerminator()); Value *CountMinusOne = B.CreateSub( CountRoundDown, ConstantInt::get(CountRoundDown->getType(), 1)); - Value *CMO = B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType(), - "cast.cmo"); + Value *CMO = + !II.getStep()->getType()->isIntegerTy() + ? 
B.CreateCast(Instruction::SIToFP, CountMinusOne, + II.getStep()->getType()) + : B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType()); + CMO->setName("cast.cmo"); Value *Escape = II.transform(B, CMO, PSE.getSE(), DL); Escape->setName("ind.escape"); MissingVals[UI] = Escape; @@ -3543,7 +3623,7 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, namespace { struct CSEDenseMapInfo { - static bool canHandle(Instruction *I) { + static bool canHandle(const Instruction *I) { return isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) || isa<ShuffleVectorInst>(I) || isa<GetElementPtrInst>(I); } @@ -3553,12 +3633,12 @@ struct CSEDenseMapInfo { static inline Instruction *getTombstoneKey() { return DenseMapInfo<Instruction *>::getTombstoneKey(); } - static unsigned getHashValue(Instruction *I) { + static unsigned getHashValue(const Instruction *I) { assert(canHandle(I) && "Unknown instruction!"); return hash_combine(I->getOpcode(), hash_combine_range(I->value_op_begin(), I->value_op_end())); } - static bool isEqual(Instruction *LHS, Instruction *RHS) { + static bool isEqual(const Instruction *LHS, const Instruction *RHS) { if (LHS == getEmptyKey() || RHS == getEmptyKey() || LHS == getTombstoneKey() || RHS == getTombstoneKey()) return LHS == RHS; @@ -3589,51 +3669,6 @@ static void cse(BasicBlock *BB) { } } -/// \brief Adds a 'fast' flag to floating point operations. -static Value *addFastMathFlag(Value *V) { - if (isa<FPMathOperator>(V)) { - FastMathFlags Flags; - Flags.setUnsafeAlgebra(); - cast<Instruction>(V)->setFastMathFlags(Flags); - } - return V; -} - -/// \brief Estimate the overhead of scalarizing a value based on its type. -/// Insert and Extract are set if the result needs to be inserted and/or -/// extracted from vectors. -static unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract, - const TargetTransformInfo &TTI) { - if (Ty->isVoidTy()) - return 0; - - assert(Ty->isVectorTy() && "Can only scalarize vectors"); - unsigned Cost = 0; - - for (unsigned I = 0, E = Ty->getVectorNumElements(); I < E; ++I) { - if (Extract) - Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, Ty, I); - if (Insert) - Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, I); - } - - return Cost; -} - -/// \brief Estimate the overhead of scalarizing an Instruction based on the -/// types of its operands and return value. -static unsigned getScalarizationOverhead(SmallVectorImpl<Type *> &OpTys, - Type *RetTy, - const TargetTransformInfo &TTI) { - unsigned ScalarizationCost = - getScalarizationOverhead(RetTy, true, false, TTI); - - for (Type *Ty : OpTys) - ScalarizationCost += getScalarizationOverhead(Ty, false, true, TTI); - - return ScalarizationCost; -} - /// \brief Estimate the overhead of scalarizing an instruction. This is a /// convenience wrapper for the type-based getScalarizationOverhead API. 
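+/// E.g. (illustrative) scalarizing a <4 x i32> add costs four inserts for
+/// the result plus four extracts per vector operand, absent cheaper
+/// element accesses on the target.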
static unsigned getScalarizationOverhead(Instruction *I, unsigned VF, @@ -3641,14 +3676,24 @@ static unsigned getScalarizationOverhead(Instruction *I, unsigned VF, if (VF == 1) return 0; + unsigned Cost = 0; Type *RetTy = ToVectorTy(I->getType(), VF); + if (!RetTy->isVoidTy() && + (!isa<LoadInst>(I) || + !TTI.supportsEfficientVectorElementLoadStore())) + Cost += TTI.getScalarizationOverhead(RetTy, true, false); - SmallVector<Type *, 4> OpTys; - unsigned OperandsNum = I->getNumOperands(); - for (unsigned OpInd = 0; OpInd < OperandsNum; ++OpInd) - OpTys.push_back(ToVectorTy(I->getOperand(OpInd)->getType(), VF)); + if (CallInst *CI = dyn_cast<CallInst>(I)) { + SmallVector<const Value *, 4> Operands(CI->arg_operands()); + Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); + } + else if (!isa<StoreInst>(I) || + !TTI.supportsEfficientVectorElementLoadStore()) { + SmallVector<const Value *, 4> Operands(I->operand_values()); + Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); + } - return getScalarizationOverhead(OpTys, RetTy, TTI); + return Cost; } // Estimate cost of a call instruction CI if it were vectorized with factor VF. @@ -3681,7 +3726,7 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF, // Compute costs of unpacking argument values for the scalar calls and // packing the return values to a vector. - unsigned ScalarizationCost = getScalarizationOverhead(Tys, RetTy, TTI); + unsigned ScalarizationCost = getScalarizationOverhead(CI, VF, TTI); unsigned Cost = ScalarCallCost * VF + ScalarizationCost; @@ -3709,16 +3754,12 @@ static unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF, Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); assert(ID && "Expected intrinsic call!"); - Type *RetTy = ToVectorTy(CI->getType(), VF); - SmallVector<Type *, 4> Tys; - for (Value *ArgOperand : CI->arg_operands()) - Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); - FastMathFlags FMF; if (auto *FPMO = dyn_cast<FPMathOperator>(CI)) FMF = FPMO->getFastMathFlags(); - return TTI.getIntrinsicInstrCost(ID, RetTy, Tys, FMF); + SmallVector<Value *, 4> Operands(CI->arg_operands()); + return TTI.getIntrinsicInstrCost(ID, CI->getType(), Operands, FMF, VF); } static Type *smallestIntegerVectorType(Type *T1, Type *T2) { @@ -3742,10 +3783,10 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { // If the value wasn't vectorized, we must maintain the original scalar // type. The absence of the value from VectorLoopValueMap indicates that it // wasn't vectorized. - if (!VectorLoopValueMap.hasVector(KV.first)) + if (!VectorLoopValueMap.hasAnyVectorValue(KV.first)) continue; - VectorParts &Parts = VectorLoopValueMap.getVector(KV.first); - for (Value *&I : Parts) { + for (unsigned Part = 0; Part < UF; ++Part) { + Value *I = getOrCreateVectorValue(KV.first, Part); if (Erased.count(I) || I->use_empty() || !isa<Instruction>(I)) continue; Type *OriginalTy = I->getType(); @@ -3770,7 +3811,11 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { if (auto *BO = dyn_cast<BinaryOperator>(I)) { NewI = B.CreateBinOp(BO->getOpcode(), ShrinkOperand(BO->getOperand(0)), ShrinkOperand(BO->getOperand(1))); - cast<BinaryOperator>(NewI)->copyIRFlags(I); + + // Any wrapping introduced by shrinking this operation shouldn't be + // considered undefined behavior. So, we can't unconditionally copy + // arithmetic wrapping flags to NewI. 
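+ // E.g. (illustrative) an i32 'add nuw' narrowed to i8 may wrap at 256
+ // even though the original could not, so nuw/nsw are dropped to keep
+ // the narrowed IR well-defined.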
+ cast<BinaryOperator>(NewI)->copyIRFlags(I, /*IncludeWrapFlags=*/false); } else if (auto *CI = dyn_cast<ICmpInst>(I)) { NewI = B.CreateICmp(CI->getPredicate(), ShrinkOperand(CI->getOperand(0)), @@ -3830,7 +3875,7 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { I->replaceAllUsesWith(Res); cast<Instruction>(I)->eraseFromParent(); Erased.insert(I); - I = Res; + VectorLoopValueMap.resetVectorValue(KV.first, Part, Res); } } @@ -3839,277 +3884,31 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { // If the value wasn't vectorized, we must maintain the original scalar // type. The absence of the value from VectorLoopValueMap indicates that it // wasn't vectorized. - if (!VectorLoopValueMap.hasVector(KV.first)) + if (!VectorLoopValueMap.hasAnyVectorValue(KV.first)) continue; - VectorParts &Parts = VectorLoopValueMap.getVector(KV.first); - for (Value *&I : Parts) { + for (unsigned Part = 0; Part < UF; ++Part) { + Value *I = getOrCreateVectorValue(KV.first, Part); ZExtInst *Inst = dyn_cast<ZExtInst>(I); if (Inst && Inst->use_empty()) { Value *NewI = Inst->getOperand(0); Inst->eraseFromParent(); - I = NewI; + VectorLoopValueMap.resetVectorValue(KV.first, Part, NewI); } } } } -void InnerLoopVectorizer::vectorizeLoop() { - //===------------------------------------------------===// - // - // Notice: any optimization or new instruction that go - // into the code below should be also be implemented in - // the cost-model. - // - //===------------------------------------------------===// - Constant *Zero = Builder.getInt32(0); - - // In order to support recurrences we need to be able to vectorize Phi nodes. - // Phi nodes have cycles, so we need to vectorize them in two stages. First, - // we create a new vector PHI node with no incoming edges. We use this value - // when we vectorize all of the instructions that use the PHI. Next, after - // all of the instructions in the block are complete we add the new incoming - // edges to the PHI. At this point all of the instructions in the basic block - // are vectorized, so we can use them to construct the PHI. - PhiVector PHIsToFix; - - // Collect instructions from the original loop that will become trivially - // dead in the vectorized loop. We don't need to vectorize these - // instructions. - collectTriviallyDeadInstructions(); - - // Scan the loop in a topological order to ensure that defs are vectorized - // before users. - LoopBlocksDFS DFS(OrigLoop); - DFS.perform(LI); - - // Vectorize all of the blocks in the original loop. - for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) - vectorizeBlockInLoop(BB, &PHIsToFix); - +void InnerLoopVectorizer::fixVectorizedLoop() { // Insert truncates and extends for any truncated instructions as hints to // InstCombine. if (VF > 1) truncateToMinimalBitwidths(); // At this point every instruction in the original loop is widened to a - // vector form. Now we need to fix the recurrences in PHIsToFix. These PHI + // vector form. Now we need to fix the recurrences in the loop. These PHI // nodes are currently empty because we did not want to introduce cycles. // This is the second stage of vectorizing recurrences. - for (PHINode *Phi : PHIsToFix) { - assert(Phi && "Unable to recover vectorized PHI"); - - // Handle first-order recurrences that need to be fixed. - if (Legal->isFirstOrderRecurrence(Phi)) { - fixFirstOrderRecurrence(Phi); - continue; - } - - // If the phi node is not a first-order recurrence, it must be a reduction. - // Get it's reduction variable descriptor. 
- assert(Legal->isReductionVariable(Phi) && - "Unable to find the reduction variable"); - RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi]; - - RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); - TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); - Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); - RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind = - RdxDesc.getMinMaxRecurrenceKind(); - setDebugLocFromInst(Builder, ReductionStartValue); - - // We need to generate a reduction vector from the incoming scalar. - // To do so, we need to generate the 'identity' vector and override - // one of the elements with the incoming scalar reduction. We need - // to do it in the vector-loop preheader. - Builder.SetInsertPoint(LoopBypassBlocks[1]->getTerminator()); - - // This is the vector-clone of the value that leaves the loop. - const VectorParts &VectorExit = getVectorValue(LoopExitInst); - Type *VecTy = VectorExit[0]->getType(); - - // Find the reduction identity variable. Zero for addition, or, xor, - // one for multiplication, -1 for And. - Value *Identity; - Value *VectorStart; - if (RK == RecurrenceDescriptor::RK_IntegerMinMax || - RK == RecurrenceDescriptor::RK_FloatMinMax) { - // MinMax reduction have the start value as their identify. - if (VF == 1) { - VectorStart = Identity = ReductionStartValue; - } else { - VectorStart = Identity = - Builder.CreateVectorSplat(VF, ReductionStartValue, "minmax.ident"); - } - } else { - // Handle other reduction kinds: - Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity( - RK, VecTy->getScalarType()); - if (VF == 1) { - Identity = Iden; - // This vector is the Identity vector where the first element is the - // incoming scalar reduction. - VectorStart = ReductionStartValue; - } else { - Identity = ConstantVector::getSplat(VF, Iden); - - // This vector is the Identity vector where the first element is the - // incoming scalar reduction. - VectorStart = - Builder.CreateInsertElement(Identity, ReductionStartValue, Zero); - } - } - - // Fix the vector-loop phi. - - // Reductions do not have to start at zero. They can start with - // any loop invariant values. - const VectorParts &VecRdxPhi = getVectorValue(Phi); - BasicBlock *Latch = OrigLoop->getLoopLatch(); - Value *LoopVal = Phi->getIncomingValueForBlock(Latch); - const VectorParts &Val = getVectorValue(LoopVal); - for (unsigned part = 0; part < UF; ++part) { - // Make sure to add the reduction stat value only to the - // first unroll part. - Value *StartVal = (part == 0) ? VectorStart : Identity; - cast<PHINode>(VecRdxPhi[part]) - ->addIncoming(StartVal, LoopVectorPreHeader); - cast<PHINode>(VecRdxPhi[part]) - ->addIncoming(Val[part], LoopVectorBody); - } - - // Before each round, move the insertion point right between - // the PHIs and the values we are going to write. - // This allows us to write both PHINodes and the extractelement - // instructions. - Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); - - VectorParts &RdxParts = VectorLoopValueMap.getVector(LoopExitInst); - setDebugLocFromInst(Builder, LoopExitInst); - - // If the vector reduction can be performed in a smaller type, we truncate - // then extend the loop exit value to enable InstCombine to evaluate the - // entire expression in the smaller type. 
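
The truncate-then-extend sequence described above is only sound because the recurrence has already been proven to fit in the narrower recurrence type. A hedged scalar model of the equivalence that makes it legal (reduceNarrow and reduceWide are illustrative names, not APIs from the patch):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // The recurrence is assumed to fit in 8 bits, so the sum may be formed
    // in int8_t and sign-extended once at the loop exit instead of being
    // carried at full width.
    int32_t reduceNarrow(const std::vector<int8_t> &V) {
      int8_t Acc = 0;
      for (int8_t X : V)
        Acc = int8_t(Acc + X);   // truncated per-iteration add
      return int32_t(Acc);       // single sign-extend at the exit
    }

    int32_t reduceWide(const std::vector<int8_t> &V) {
      int32_t Acc = 0;
      for (int8_t X : V)
        Acc += X;                // full-width add
      return Acc;
    }

    int main() {
      std::vector<int8_t> V{5, -3, 20, 7, -9};  // sum 20, fits in int8_t
      assert(reduceNarrow(V) == reduceWide(V));
      return 0;
    }
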
- if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) { - Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); - Builder.SetInsertPoint(LoopVectorBody->getTerminator()); - for (unsigned part = 0; part < UF; ++part) { - Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy); - Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy) - : Builder.CreateZExt(Trunc, VecTy); - for (Value::user_iterator UI = RdxParts[part]->user_begin(); - UI != RdxParts[part]->user_end();) - if (*UI != Trunc) { - (*UI++)->replaceUsesOfWith(RdxParts[part], Extnd); - RdxParts[part] = Extnd; - } else { - ++UI; - } - } - Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); - for (unsigned part = 0; part < UF; ++part) - RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy); - } - - // Reduce all of the unrolled parts into a single vector. - Value *ReducedPartRdx = RdxParts[0]; - unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); - setDebugLocFromInst(Builder, ReducedPartRdx); - for (unsigned part = 1; part < UF; ++part) { - if (Op != Instruction::ICmp && Op != Instruction::FCmp) - // Floating point operations had to be 'fast' to enable the reduction. - ReducedPartRdx = addFastMathFlag( - Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part], - ReducedPartRdx, "bin.rdx")); - else - ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp( - Builder, MinMaxKind, ReducedPartRdx, RdxParts[part]); - } - - if (VF > 1) { - // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles - // and vector ops, reducing the set of values being computed by half each - // round. - assert(isPowerOf2_32(VF) && - "Reduction emission only supported for pow2 vectors!"); - Value *TmpVec = ReducedPartRdx; - SmallVector<Constant *, 32> ShuffleMask(VF, nullptr); - for (unsigned i = VF; i != 1; i >>= 1) { - // Move the upper half of the vector to the lower half. - for (unsigned j = 0; j != i / 2; ++j) - ShuffleMask[j] = Builder.getInt32(i / 2 + j); - - // Fill the rest of the mask with undef. - std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), - UndefValue::get(Builder.getInt32Ty())); - - Value *Shuf = Builder.CreateShuffleVector( - TmpVec, UndefValue::get(TmpVec->getType()), - ConstantVector::get(ShuffleMask), "rdx.shuf"); - - if (Op != Instruction::ICmp && Op != Instruction::FCmp) - // Floating point operations had to be 'fast' to enable the reduction. - TmpVec = addFastMathFlag(Builder.CreateBinOp( - (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx")); - else - TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, - TmpVec, Shuf); - } - - // The result is in the first element of the vector. - ReducedPartRdx = - Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); - - // If the reduction can be performed in a smaller type, we need to extend - // the reduction to the wider type before we branch to the original loop. - if (Phi->getType() != RdxDesc.getRecurrenceType()) - ReducedPartRdx = - RdxDesc.isSigned() - ? Builder.CreateSExt(ReducedPartRdx, Phi->getType()) - : Builder.CreateZExt(ReducedPartRdx, Phi->getType()); - } - - // Create a phi node that merges control-flow from the backedge-taken check - // block and the middle block. 
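
The removed block above (and the createTargetReduction path the patch adds later) emit the horizontal reduction as log2(VF) rounds of shuffle plus binop, halving the live lanes each round until the result sits in lane 0. A scalar model of that schedule, assuming an integer add reduction and a power-of-two VF:

    #include <array>
    #include <cassert>
    #include <cstddef>

    // Each round adds the upper half of the vector onto the lower half;
    // after log2(VF) rounds the reduction value is in lane 0.
    template <std::size_t VF>
    int treeReduceAdd(std::array<int, VF> Lanes) {
      static_assert((VF & (VF - 1)) == 0, "pow2 vectors only");
      for (std::size_t I = VF; I != 1; I >>= 1)
        for (std::size_t J = 0; J != I / 2; ++J)
          Lanes[J] += Lanes[I / 2 + J];  // "move the upper half down" + binop
      return Lanes[0];
    }

    int main() {
      std::array<int, 8> L{1, 2, 3, 4, 5, 6, 7, 8};
      assert(treeReduceAdd(L) == 36);
      return 0;
    }
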
- PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx", - LoopScalarPreHeader->getTerminator()); - for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) - BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]); - BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); - - // Now, we need to fix the users of the reduction variable - // inside and outside of the scalar remainder loop. - // We know that the loop is in LCSSA form. We need to update the - // PHI nodes in the exit blocks. - for (BasicBlock::iterator LEI = LoopExitBlock->begin(), - LEE = LoopExitBlock->end(); - LEI != LEE; ++LEI) { - PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI); - if (!LCSSAPhi) - break; - - // All PHINodes need to have a single entry edge, or two if - // we already fixed them. - assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); - - // We found our reduction value exit-PHI. Update it with the - // incoming bypass edge. - if (LCSSAPhi->getIncomingValue(0) == LoopExitInst) { - // Add an edge coming from the bypass. - LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); - break; - } - } // end of the LCSSA phi scan. - - // Fix the scalar loop reduction variable with the incoming reduction sum - // from the vector body and from the backedge value. - int IncomingEdgeBlockIdx = - Phi->getBasicBlockIndex(OrigLoop->getLoopLatch()); - assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index"); - // Pick the other block. - int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1); - Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); - Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); - } // end of for each Phi in PHIsToFix. + fixCrossIterationPHIs(); // Update the dominator tree. // @@ -4134,6 +3933,25 @@ void InnerLoopVectorizer::vectorizeLoop() { cse(LoopVectorBody); } +void InnerLoopVectorizer::fixCrossIterationPHIs() { + // In order to support recurrences we need to be able to vectorize Phi nodes. + // Phi nodes have cycles, so we need to vectorize them in two stages. This is + // stage #2: We now need to fix the recurrences by adding incoming edges to + // the currently empty PHI nodes. At this point every instruction in the + // original loop is widened to a vector form so we can use them to construct + // the incoming edges. + for (Instruction &I : *OrigLoop->getHeader()) { + PHINode *Phi = dyn_cast<PHINode>(&I); + if (!Phi) + break; + // Handle first-order recurrences and reductions that need to be fixed. + if (Legal->isFirstOrderRecurrence(Phi)) + fixFirstOrderRecurrence(Phi); + else if (Legal->isReductionVariable(Phi)) + fixReduction(Phi); + } +} + void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { // This is the second phase of vectorizing first-order recurrences. An @@ -4204,23 +4022,29 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { // We constructed a temporary phi node in the first phase of vectorization. // This phi node will eventually be deleted. - VectorParts &PhiParts = VectorLoopValueMap.getVector(Phi); - Builder.SetInsertPoint(cast<Instruction>(PhiParts[0])); + Builder.SetInsertPoint( + cast<Instruction>(VectorLoopValueMap.getVectorValue(Phi, 0))); // Create a phi node for the new recurrence. The current value will either be // the initial value inserted into a vector or loop-varying vector value. auto *VecPhi = Builder.CreatePHI(VectorInit->getType(), 2, "vector.recur"); VecPhi->addIncoming(VectorInit, LoopVectorPreHeader); - // Get the vectorized previous value. 
We ensured the previous values was an - // instruction when detecting the recurrence. - auto &PreviousParts = getVectorValue(Previous); - - // Set the insertion point to be after this instruction. We ensured the - // previous value dominated all uses of the phi when detecting the - // recurrence. - Builder.SetInsertPoint( - &*++BasicBlock::iterator(cast<Instruction>(PreviousParts[UF - 1]))); + // Get the vectorized previous value of the last part UF - 1. It appears last + // among all unrolled iterations, due to the order of their construction. + Value *PreviousLastPart = getOrCreateVectorValue(Previous, UF - 1); + + // Set the insertion point after the previous value if it is an instruction. + // Note that the previous value may have been constant-folded so it is not + // guaranteed to be an instruction in the vector loop. Also, if the previous + // value is a phi node, we should insert after all the phi nodes to avoid + // breaking basic block verification. + if (LI->getLoopFor(LoopVectorBody)->isLoopInvariant(PreviousLastPart) || + isa<PHINode>(PreviousLastPart)) + Builder.SetInsertPoint(&*LoopVectorBody->getFirstInsertionPt()); + else + Builder.SetInsertPoint( + &*++BasicBlock::iterator(cast<Instruction>(PreviousLastPart))); // We will construct a vector for the recurrence by combining the values for // the current and previous iterations. This is the required shuffle mask. @@ -4235,15 +4059,16 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { // Shuffle the current and previous vector and update the vector parts. for (unsigned Part = 0; Part < UF; ++Part) { + Value *PreviousPart = getOrCreateVectorValue(Previous, Part); + Value *PhiPart = VectorLoopValueMap.getVectorValue(Phi, Part); auto *Shuffle = - VF > 1 - ? Builder.CreateShuffleVector(Incoming, PreviousParts[Part], - ConstantVector::get(ShuffleMask)) - : Incoming; - PhiParts[Part]->replaceAllUsesWith(Shuffle); - cast<Instruction>(PhiParts[Part])->eraseFromParent(); - PhiParts[Part] = Shuffle; - Incoming = PreviousParts[Part]; + VF > 1 ? Builder.CreateShuffleVector(Incoming, PreviousPart, + ConstantVector::get(ShuffleMask)) + : Incoming; + PhiPart->replaceAllUsesWith(Shuffle); + cast<Instruction>(PhiPart)->eraseFromParent(); + VectorLoopValueMap.resetVectorValue(Phi, Part, Shuffle); + Incoming = PreviousPart; } // Fix the latch value of the new recurrence in the vector loop. @@ -4251,18 +4076,33 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { // Extract the last vector element in the middle block. This will be the // initial value for the recurrence when jumping to the scalar loop. - auto *Extract = Incoming; + auto *ExtractForScalar = Incoming; if (VF > 1) { Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); - Extract = Builder.CreateExtractElement(Extract, Builder.getInt32(VF - 1), - "vector.recur.extract"); - } + ExtractForScalar = Builder.CreateExtractElement( + ExtractForScalar, Builder.getInt32(VF - 1), "vector.recur.extract"); + } + // Extract the second last element in the middle block if the + // Phi is used outside the loop. We need to extract the phi itself + // and not the last element (the phi update in the current iteration). This + // will be the value when jumping to the exit block from the LoopMiddleBlock, + // when the scalar loop is not run at all. 
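
A small worked model of the two extracts introduced in the hunk below, assuming VF = 4 and a first-order recurrence phi whose update values are the lanes of Incoming:

    #include <array>
    #include <cassert>

    // Scalar loop shape:  for (i...) { use(phi); phi = x[i]; }
    // After vectorizing with VF = 4, the last vector of update values is:
    int main() {
      std::array<int, 4> Incoming{10, 11, 12, 13};  // x[n-4] .. x[n-1]
      // The scalar remainder loop resumes with phi == x[n-1], the value the
      // next iteration would observe: lane VF-1.
      int ExtractForScalar = Incoming[3];
      // An LCSSA phi user outside the loop sees the phi as of the last
      // executed iteration, i.e. x[n-2], the phi itself rather than its
      // update: lane VF-2.
      int ExtractForPhiUsedOutsideLoop = Incoming[2];
      assert(ExtractForScalar == 13 && ExtractForPhiUsedOutsideLoop == 12);
      return 0;
    }
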
+ Value *ExtractForPhiUsedOutsideLoop = nullptr; + if (VF > 1) + ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement( + Incoming, Builder.getInt32(VF - 2), "vector.recur.extract.for.phi"); + // When loop is unrolled without vectorizing, initialize + // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled value of + // `Incoming`. This is analogous to the vectorized case above: extracting the + // second last element when VF > 1. + else if (UF > 1) + ExtractForPhiUsedOutsideLoop = getOrCreateVectorValue(Previous, UF - 2); // Fix the initial value of the original recurrence in the scalar loop. Builder.SetInsertPoint(&*LoopScalarPreHeader->begin()); auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init"); for (auto *BB : predecessors(LoopScalarPreHeader)) { - auto *Incoming = BB == LoopMiddleBlock ? Extract : ScalarInit; + auto *Incoming = BB == LoopMiddleBlock ? ExtractForScalar : ScalarInit; Start->addIncoming(Incoming, BB); } @@ -4279,43 +4119,200 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { if (!LCSSAPhi) break; if (LCSSAPhi->getIncomingValue(0) == Phi) { - LCSSAPhi->addIncoming(Extract, LoopMiddleBlock); + LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); break; } } } -void InnerLoopVectorizer::fixLCSSAPHIs() { - for (Instruction &LEI : *LoopExitBlock) { - auto *LCSSAPhi = dyn_cast<PHINode>(&LEI); - if (!LCSSAPhi) - break; - if (LCSSAPhi->getNumIncomingValues() == 1) - LCSSAPhi->addIncoming(UndefValue::get(LCSSAPhi->getType()), - LoopMiddleBlock); +void InnerLoopVectorizer::fixReduction(PHINode *Phi) { + Constant *Zero = Builder.getInt32(0); + + // Get it's reduction variable descriptor. + assert(Legal->isReductionVariable(Phi) && + "Unable to find the reduction variable"); + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi]; + + RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); + TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); + Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); + RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind = + RdxDesc.getMinMaxRecurrenceKind(); + setDebugLocFromInst(Builder, ReductionStartValue); + + // We need to generate a reduction vector from the incoming scalar. + // To do so, we need to generate the 'identity' vector and override + // one of the elements with the incoming scalar reduction. We need + // to do it in the vector-loop preheader. + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + + // This is the vector-clone of the value that leaves the loop. + Type *VecTy = getOrCreateVectorValue(LoopExitInst, 0)->getType(); + + // Find the reduction identity variable. Zero for addition, or, xor, + // one for multiplication, -1 for And. + Value *Identity; + Value *VectorStart; + if (RK == RecurrenceDescriptor::RK_IntegerMinMax || + RK == RecurrenceDescriptor::RK_FloatMinMax) { + // MinMax reduction have the start value as their identify. + if (VF == 1) { + VectorStart = Identity = ReductionStartValue; + } else { + VectorStart = Identity = + Builder.CreateVectorSplat(VF, ReductionStartValue, "minmax.ident"); + } + } else { + // Handle other reduction kinds: + Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity( + RK, VecTy->getScalarType()); + if (VF == 1) { + Identity = Iden; + // This vector is the Identity vector where the first element is the + // incoming scalar reduction. 
+ VectorStart = ReductionStartValue; + } else { + Identity = ConstantVector::getSplat(VF, Iden); + + // This vector is the Identity vector where the first element is the + // incoming scalar reduction. + VectorStart = + Builder.CreateInsertElement(Identity, ReductionStartValue, Zero); + } } -} -void InnerLoopVectorizer::collectTriviallyDeadInstructions() { + // Fix the vector-loop phi. + + // Reductions do not have to start at zero. They can start with + // any loop invariant values. BasicBlock *Latch = OrigLoop->getLoopLatch(); + Value *LoopVal = Phi->getIncomingValueForBlock(Latch); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *VecRdxPhi = getOrCreateVectorValue(Phi, Part); + Value *Val = getOrCreateVectorValue(LoopVal, Part); + // Make sure to add the reduction stat value only to the + // first unroll part. + Value *StartVal = (Part == 0) ? VectorStart : Identity; + cast<PHINode>(VecRdxPhi)->addIncoming(StartVal, LoopVectorPreHeader); + cast<PHINode>(VecRdxPhi) + ->addIncoming(Val, LI->getLoopFor(LoopVectorBody)->getLoopLatch()); + } + + // Before each round, move the insertion point right between + // the PHIs and the values we are going to write. + // This allows us to write both PHINodes and the extractelement + // instructions. + Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); - // We create new control-flow for the vectorized loop, so the original - // condition will be dead after vectorization if it's only used by the - // branch. - auto *Cmp = dyn_cast<Instruction>(Latch->getTerminator()->getOperand(0)); - if (Cmp && Cmp->hasOneUse()) - DeadInstructions.insert(Cmp); + setDebugLocFromInst(Builder, LoopExitInst); - // We create new "steps" for induction variable updates to which the original - // induction variables map. An original update instruction will be dead if - // all its users except the induction variable are dead. - for (auto &Induction : *Legal->getInductionVars()) { - PHINode *Ind = Induction.first; - auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); - if (all_of(IndUpdate->users(), [&](User *U) -> bool { - return U == Ind || DeadInstructions.count(cast<Instruction>(U)); - })) - DeadInstructions.insert(IndUpdate); + // If the vector reduction can be performed in a smaller type, we truncate + // then extend the loop exit value to enable InstCombine to evaluate the + // entire expression in the smaller type. + if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) { + Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); + Builder.SetInsertPoint(LoopVectorBody->getTerminator()); + VectorParts RdxParts(UF); + for (unsigned Part = 0; Part < UF; ++Part) { + RdxParts[Part] = VectorLoopValueMap.getVectorValue(LoopExitInst, Part); + Value *Trunc = Builder.CreateTrunc(RdxParts[Part], RdxVecTy); + Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy) + : Builder.CreateZExt(Trunc, VecTy); + for (Value::user_iterator UI = RdxParts[Part]->user_begin(); + UI != RdxParts[Part]->user_end();) + if (*UI != Trunc) { + (*UI++)->replaceUsesOfWith(RdxParts[Part], Extnd); + RdxParts[Part] = Extnd; + } else { + ++UI; + } + } + Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); + for (unsigned Part = 0; Part < UF; ++Part) { + RdxParts[Part] = Builder.CreateTrunc(RdxParts[Part], RdxVecTy); + VectorLoopValueMap.resetVectorValue(LoopExitInst, Part, RdxParts[Part]); + } + } + + // Reduce all of the unrolled parts into a single vector. 
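
The loop that follows folds the UF unrolled parts into part 0 with the recurrence binop before any horizontal reduction happens. A scalar model, assuming an integer add reduction (the UF and VF values are arbitrary):

    #include <array>
    #include <cassert>

    int main() {
      constexpr unsigned UF = 4, VF = 2;
      // Each unrolled part owns a partial reduction vector.
      std::array<std::array<int, VF>, UF> Parts{
          {{1, 2}, {3, 4}, {5, 6}, {7, 8}}};
      std::array<int, VF> Rdx = Parts[0];
      for (unsigned Part = 1; Part < UF; ++Part)
        for (unsigned Lane = 0; Lane < VF; ++Lane)
          Rdx[Lane] += Parts[Part][Lane];  // one lane-wise "bin.rdx" per part
      assert(Rdx[0] == 16 && Rdx[1] == 20); // horizontal step still to come
      return 0;
    }
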
+ Value *ReducedPartRdx = VectorLoopValueMap.getVectorValue(LoopExitInst, 0); + unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); + setDebugLocFromInst(Builder, ReducedPartRdx); + for (unsigned Part = 1; Part < UF; ++Part) { + Value *RdxPart = VectorLoopValueMap.getVectorValue(LoopExitInst, Part); + if (Op != Instruction::ICmp && Op != Instruction::FCmp) + // Floating point operations had to be 'fast' to enable the reduction. + ReducedPartRdx = addFastMathFlag( + Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxPart, + ReducedPartRdx, "bin.rdx")); + else + ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp( + Builder, MinMaxKind, ReducedPartRdx, RdxPart); + } + + if (VF > 1) { + bool NoNaN = Legal->hasFunNoNaNAttr(); + ReducedPartRdx = + createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, NoNaN); + // If the reduction can be performed in a smaller type, we need to extend + // the reduction to the wider type before we branch to the original loop. + if (Phi->getType() != RdxDesc.getRecurrenceType()) + ReducedPartRdx = + RdxDesc.isSigned() + ? Builder.CreateSExt(ReducedPartRdx, Phi->getType()) + : Builder.CreateZExt(ReducedPartRdx, Phi->getType()); + } + + // Create a phi node that merges control-flow from the backedge-taken check + // block and the middle block. + PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx", + LoopScalarPreHeader->getTerminator()); + for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) + BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]); + BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); + + // Now, we need to fix the users of the reduction variable + // inside and outside of the scalar remainder loop. + // We know that the loop is in LCSSA form. We need to update the + // PHI nodes in the exit blocks. + for (BasicBlock::iterator LEI = LoopExitBlock->begin(), + LEE = LoopExitBlock->end(); + LEI != LEE; ++LEI) { + PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI); + if (!LCSSAPhi) + break; + + // All PHINodes need to have a single entry edge, or two if + // we already fixed them. + assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); + + // We found a reduction value exit-PHI. Update it with the + // incoming bypass edge. + if (LCSSAPhi->getIncomingValue(0) == LoopExitInst) + LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); + } // end of the LCSSA phi scan. + + // Fix the scalar loop reduction variable with the incoming reduction sum + // from the vector body and from the backedge value. + int IncomingEdgeBlockIdx = + Phi->getBasicBlockIndex(OrigLoop->getLoopLatch()); + assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index"); + // Pick the other block. + int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 
0 : 1); + Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); + Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); +} + +void InnerLoopVectorizer::fixLCSSAPHIs() { + for (Instruction &LEI : *LoopExitBlock) { + auto *LCSSAPhi = dyn_cast<PHINode>(&LEI); + if (!LCSSAPhi) + break; + if (LCSSAPhi->getNumIncomingValues() == 1) { + assert(OrigLoop->isLoopInvariant(LCSSAPhi->getIncomingValue(0)) && + "Incoming value isn't loop invariant"); + LCSSAPhi->addIncoming(LCSSAPhi->getIncomingValue(0), LoopMiddleBlock); + } } } @@ -4464,14 +4461,15 @@ void InnerLoopVectorizer::predicateInstructions() { for (auto KV : PredicatedInstructions) { BasicBlock::iterator I(KV.first); BasicBlock *Head = I->getParent(); - auto *BB = SplitBlock(Head, &*std::next(I), DT, LI); auto *T = SplitBlockAndInsertIfThen(KV.second, &*I, /*Unreachable=*/false, /*BranchWeights=*/nullptr, DT, LI); I->moveBefore(T); sinkScalarOperands(&*I); - I->getParent()->setName(Twine("pred.") + I->getOpcodeName() + ".if"); - BB->setName(Twine("pred.") + I->getOpcodeName() + ".continue"); + BasicBlock *PredicatedBlock = I->getParent(); + Twine BBNamePrefix = Twine("pred.") + I->getOpcodeName(); + PredicatedBlock->setName(BBNamePrefix + ".if"); + PredicatedBlock->getSingleSuccessor()->setName(BBNamePrefix + ".continue"); // If the instruction is non-void create a Phi node at reconvergence point. if (!I->getType()->isVoidTy()) { @@ -4510,8 +4508,8 @@ InnerLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { // Look for cached value. std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst); - EdgeMaskCache::iterator ECEntryIt = MaskCache.find(Edge); - if (ECEntryIt != MaskCache.end()) + EdgeMaskCacheTy::iterator ECEntryIt = EdgeMaskCache.find(Edge); + if (ECEntryIt != EdgeMaskCache.end()) return ECEntryIt->second; VectorParts SrcMask = createBlockInMask(Src); @@ -4521,20 +4519,22 @@ InnerLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { assert(BI && "Unexpected terminator found"); if (BI->isConditional()) { - VectorParts EdgeMask = getVectorValue(BI->getCondition()); - if (BI->getSuccessor(0) != Dst) - for (unsigned part = 0; part < UF; ++part) - EdgeMask[part] = Builder.CreateNot(EdgeMask[part]); + VectorParts EdgeMask(UF); + for (unsigned Part = 0; Part < UF; ++Part) { + auto *EdgeMaskPart = getOrCreateVectorValue(BI->getCondition(), Part); + if (BI->getSuccessor(0) != Dst) + EdgeMaskPart = Builder.CreateNot(EdgeMaskPart); - for (unsigned part = 0; part < UF; ++part) - EdgeMask[part] = Builder.CreateAnd(EdgeMask[part], SrcMask[part]); + EdgeMaskPart = Builder.CreateAnd(EdgeMaskPart, SrcMask[Part]); + EdgeMask[Part] = EdgeMaskPart; + } - MaskCache[Edge] = EdgeMask; + EdgeMaskCache[Edge] = EdgeMask; return EdgeMask; } - MaskCache[Edge] = SrcMask; + EdgeMaskCache[Edge] = SrcMask; return SrcMask; } @@ -4542,41 +4542,54 @@ InnerLoopVectorizer::VectorParts InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) { assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); + // Look for cached value. + BlockMaskCacheTy::iterator BCEntryIt = BlockMaskCache.find(BB); + if (BCEntryIt != BlockMaskCache.end()) + return BCEntryIt->second; + + VectorParts BlockMask(UF); + // Loop incoming mask is all-one. if (OrigLoop->getHeader() == BB) { Value *C = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 1); - return getVectorValue(C); + for (unsigned Part = 0; Part < UF; ++Part) + BlockMask[Part] = getOrCreateVectorValue(C, Part); + BlockMaskCache[BB] = BlockMask; + return BlockMask; } // This is the block mask. 
We OR all incoming edges, and with zero. Value *Zero = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 0); - VectorParts BlockMask = getVectorValue(Zero); + for (unsigned Part = 0; Part < UF; ++Part) + BlockMask[Part] = getOrCreateVectorValue(Zero, Part); // For each pred: - for (pred_iterator it = pred_begin(BB), e = pred_end(BB); it != e; ++it) { - VectorParts EM = createEdgeMask(*it, BB); - for (unsigned part = 0; part < UF; ++part) - BlockMask[part] = Builder.CreateOr(BlockMask[part], EM[part]); + for (pred_iterator It = pred_begin(BB), E = pred_end(BB); It != E; ++It) { + VectorParts EM = createEdgeMask(*It, BB); + for (unsigned Part = 0; Part < UF; ++Part) + BlockMask[Part] = Builder.CreateOr(BlockMask[Part], EM[Part]); } + BlockMaskCache[BB] = BlockMask; return BlockMask; } void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, - unsigned VF, PhiVector *PV) { + unsigned VF) { PHINode *P = cast<PHINode>(PN); - // Handle recurrences. + // In order to support recurrences we need to be able to vectorize Phi nodes. + // Phi nodes have cycles, so we need to vectorize them in two stages. This is + // stage #1: We create a new vector PHI node with no incoming edges. We'll use + // this value when we vectorize all of the instructions that use the PHI. if (Legal->isReductionVariable(P) || Legal->isFirstOrderRecurrence(P)) { - VectorParts Entry(UF); - for (unsigned part = 0; part < UF; ++part) { + for (unsigned Part = 0; Part < UF; ++Part) { // This is phase one of vectorizing PHIs. Type *VecTy = (VF == 1) ? PN->getType() : VectorType::get(PN->getType(), VF); - Entry[part] = PHINode::Create( + Value *EntryPart = PHINode::Create( VecTy, 2, "vec.phi", &*LoopVectorBody->getFirstInsertionPt()); + VectorLoopValueMap.setVectorValue(P, Part, EntryPart); } - VectorLoopValueMap.initVector(P, Entry); - PV->push_back(P); return; } @@ -4600,21 +4613,22 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, for (unsigned In = 0; In < NumIncoming; In++) { VectorParts Cond = createEdgeMask(P->getIncomingBlock(In), P->getParent()); - const VectorParts &In0 = getVectorValue(P->getIncomingValue(In)); - for (unsigned part = 0; part < UF; ++part) { + for (unsigned Part = 0; Part < UF; ++Part) { + Value *In0 = getOrCreateVectorValue(P->getIncomingValue(In), Part); // We might have single edge PHIs (blocks) - use an identity // 'select' for the first PHI operand. if (In == 0) - Entry[part] = Builder.CreateSelect(Cond[part], In0[part], In0[part]); + Entry[Part] = Builder.CreateSelect(Cond[Part], In0, In0); else // Select between the current value and the previous incoming edge // based on the incoming mask. - Entry[part] = Builder.CreateSelect(Cond[part], In0[part], Entry[part], + Entry[Part] = Builder.CreateSelect(Cond[Part], In0, Entry[Part], "predphi"); } } - VectorLoopValueMap.initVector(P, Entry); + for (unsigned Part = 0; Part < UF; ++Part) + VectorLoopValueMap.setVectorValue(P, Part, Entry[Part]); return; } @@ -4631,7 +4645,8 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, case InductionDescriptor::IK_NoInduction: llvm_unreachable("Unknown induction"); case InductionDescriptor::IK_IntInduction: - return widenIntInduction(P); + case InductionDescriptor::IK_FpInduction: + return widenIntOrFpInduction(P); case InductionDescriptor::IK_PtrInduction: { // Handle the pointer induction variable case. 
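
The IK_PtrInduction case handled next emits one scalar next.gep per part and lane rather than a vector GEP, since scalar GEPs give better code here. A standalone model of the index arithmetic, with Data standing in for the pointed-to storage:

    #include <array>
    #include <cassert>

    int main() {
      constexpr unsigned UF = 2, VF = 4;
      std::array<int, UF * VF> Data{};
      int *PtrInd = Data.data();  // pointer IV value entering this chunk
      for (unsigned Part = 0; Part < UF; ++Part)
        for (unsigned Lane = 0; Lane < VF; ++Lane) {
          // Mirrors: GlobalIdx = PtrInd + (Lane + Part * VF); "next.gep"
          int *SclrGep = PtrInd + (Part * VF + Lane);
          assert(SclrGep == &Data[Part * VF + Lane]);
        }
      return 0;
    }
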
assert(P->getType()->isPointerTy() && "Unexpected type."); @@ -4641,45 +4656,18 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, // Determine the number of scalars we need to generate for each unroll // iteration. If the instruction is uniform, we only need to generate the // first lane. Otherwise, we generate all VF values. - unsigned Lanes = Legal->isUniformAfterVectorization(P) ? 1 : VF; + unsigned Lanes = Cost->isUniformAfterVectorization(P, VF) ? 1 : VF; // These are the scalar results. Notice that we don't generate vector GEPs // because scalar GEPs result in better code. - ScalarParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { - Entry[Part].resize(VF); for (unsigned Lane = 0; Lane < Lanes; ++Lane) { Constant *Idx = ConstantInt::get(PtrInd->getType(), Lane + Part * VF); Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); SclrGep->setName("next.gep"); - Entry[Part][Lane] = SclrGep; + VectorLoopValueMap.setScalarValue(P, Part, Lane, SclrGep); } } - VectorLoopValueMap.initScalar(P, Entry); - return; - } - case InductionDescriptor::IK_FpInduction: { - assert(P->getType() == II.getStartValue()->getType() && - "Types must match"); - // Handle other induction variables that are now based on the - // canonical one. - assert(P != OldInduction && "Primary induction can be integer only"); - - Value *V = Builder.CreateCast(Instruction::SIToFP, Induction, P->getType()); - V = II.transform(Builder, V, PSE.getSE(), DL); - V->setName("fp.offset.idx"); - - // Now we have scalar op: %fp.offset.idx = StartVal +/- Induction*StepVal - - Value *Broadcasted = getBroadcastInstrs(V); - // After broadcasting the induction variable we need to make the vector - // consecutive by adding StepVal*0, StepVal*1, StepVal*2, etc. - Value *StepVal = cast<SCEVUnknown>(II.getStep())->getValue(); - VectorParts Entry(UF); - for (unsigned part = 0; part < UF; ++part) - Entry[part] = getStepVector(Broadcasted, VF * part, StepVal, - II.getInductionOpcode()); - VectorLoopValueMap.initVector(P, Entry); return; } } @@ -4703,269 +4691,317 @@ static bool mayDivideByZero(Instruction &I) { return !CInt || CInt->isZero(); } -void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { - // For each instruction in the old loop. - for (Instruction &I : *BB) { +void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { + // Scalarize instructions that should remain scalar after vectorization. + if (VF > 1 && + !(isa<BranchInst>(&I) || isa<PHINode>(&I) || isa<DbgInfoIntrinsic>(&I)) && + shouldScalarizeInstruction(&I)) { + scalarizeInstruction(&I, Legal->isScalarWithPredication(&I)); + return; + } - // If the instruction will become trivially dead when vectorized, we don't - // need to generate it. - if (DeadInstructions.count(&I)) - continue; + switch (I.getOpcode()) { + case Instruction::Br: + // Nothing to do for PHIs and BR, since we already took care of the + // loop control flow instructions. + break; + case Instruction::PHI: { + // Vectorize PHINodes. + widenPHIInstruction(&I, UF, VF); + break; + } // End of PHI. + case Instruction::GetElementPtr: { + // Construct a vector GEP by widening the operands of the scalar GEP as + // necessary. We mark the vector GEP 'inbounds' if appropriate. A GEP + // results in a vector of pointers when at least one operand of the GEP + // is vector-typed. Thus, to keep the representation compact, we only use + // vector-typed operands for loop-varying values. 
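
A model of the compact representation described above: keep loop-invariant GEP operands scalar, widen only the loop-varying index, and the result is still one pointer per lane (the names here are illustrative, not taken from the patch):

    #include <array>
    #include <cassert>

    int main() {
      constexpr unsigned VF = 4;
      std::array<double, 16> A{};
      double *Base = A.data();               // invariant operand, kept scalar
      std::array<long, VF> Idx{3, 4, 5, 6};  // loop-varying operand, widened
      std::array<double *, VF> NewGEP;       // the GEP result: a pointer vector
      for (unsigned Lane = 0; Lane < VF; ++Lane)
        NewGEP[Lane] = Base + Idx[Lane];     // lane-wise &Base[Idx]
      assert(NewGEP[0] == &A[3] && NewGEP[3] == &A[6]);
      return 0;
    }
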
+ auto *GEP = cast<GetElementPtrInst>(&I); + + if (VF > 1 && OrigLoop->hasLoopInvariantOperands(GEP)) { + // If we are vectorizing, but the GEP has only loop-invariant operands, + // the GEP we build (by only using vector-typed operands for + // loop-varying values) would be a scalar pointer. Thus, to ensure we + // produce a vector of pointers, we need to either arbitrarily pick an + // operand to broadcast, or broadcast a clone of the original GEP. + // Here, we broadcast a clone of the original. + // + // TODO: If at some point we decide to scalarize instructions having + // loop-invariant operands, this special case will no longer be + // required. We would add the scalarization decision to + // collectLoopScalars() and teach getVectorValue() to broadcast + // the lane-zero scalar value. + auto *Clone = Builder.Insert(GEP->clone()); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *EntryPart = Builder.CreateVectorSplat(VF, Clone); + VectorLoopValueMap.setVectorValue(&I, Part, EntryPart); + addMetadata(EntryPart, GEP); + } + } else { + // If the GEP has at least one loop-varying operand, we are sure to + // produce a vector of pointers. But if we are only unrolling, we want + // to produce a scalar GEP for each unroll part. Thus, the GEP we + // produce with the code below will be scalar (if VF == 1) or vector + // (otherwise). Note that for the unroll-only case, we still maintain + // values in the vector mapping with initVector, as we do for other + // instructions. + for (unsigned Part = 0; Part < UF; ++Part) { - // Scalarize instructions that should remain scalar after vectorization. - if (VF > 1 && - !(isa<BranchInst>(&I) || isa<PHINode>(&I) || - isa<DbgInfoIntrinsic>(&I)) && - shouldScalarizeInstruction(&I)) { - scalarizeInstruction(&I, Legal->isScalarWithPredication(&I)); - continue; - } + // The pointer operand of the new GEP. If it's loop-invariant, we + // won't broadcast it. + auto *Ptr = + OrigLoop->isLoopInvariant(GEP->getPointerOperand()) + ? GEP->getPointerOperand() + : getOrCreateVectorValue(GEP->getPointerOperand(), Part); + + // Collect all the indices for the new GEP. If any index is + // loop-invariant, we won't broadcast it. + SmallVector<Value *, 4> Indices; + for (auto &U : make_range(GEP->idx_begin(), GEP->idx_end())) { + if (OrigLoop->isLoopInvariant(U.get())) + Indices.push_back(U.get()); + else + Indices.push_back(getOrCreateVectorValue(U.get(), Part)); + } - switch (I.getOpcode()) { - case Instruction::Br: - // Nothing to do for PHIs and BR, since we already took care of the - // loop control flow instructions. - continue; - case Instruction::PHI: { - // Vectorize PHINodes. - widenPHIInstruction(&I, UF, VF, PV); - continue; - } // End of PHI. - - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::SRem: - case Instruction::URem: - // Scalarize with predication if this instruction may divide by zero and - // block execution is conditional, otherwise fallthrough. - if (Legal->isScalarWithPredication(&I)) { - scalarizeInstruction(&I, true); - continue; + // Create the new GEP. Note that this GEP may be a scalar if VF == 1, + // but it should be a vector, otherwise. + auto *NewGEP = GEP->isInBounds() + ? 
Builder.CreateInBoundsGEP(Ptr, Indices) + : Builder.CreateGEP(Ptr, Indices); + assert((VF == 1 || NewGEP->getType()->isVectorTy()) && + "NewGEP is not a pointer vector"); + VectorLoopValueMap.setVectorValue(&I, Part, NewGEP); + addMetadata(NewGEP, GEP); } - case Instruction::Add: - case Instruction::FAdd: - case Instruction::Sub: - case Instruction::FSub: - case Instruction::Mul: - case Instruction::FMul: - case Instruction::FDiv: - case Instruction::FRem: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: { - // Just widen binops. - auto *BinOp = cast<BinaryOperator>(&I); - setDebugLocFromInst(Builder, BinOp); - const VectorParts &A = getVectorValue(BinOp->getOperand(0)); - const VectorParts &B = getVectorValue(BinOp->getOperand(1)); + } + + break; + } + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::SRem: + case Instruction::URem: + // Scalarize with predication if this instruction may divide by zero and + // block execution is conditional, otherwise fallthrough. + if (Legal->isScalarWithPredication(&I)) { + scalarizeInstruction(&I, true); + break; + } + LLVM_FALLTHROUGH; + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + // Just widen binops. + auto *BinOp = cast<BinaryOperator>(&I); + setDebugLocFromInst(Builder, BinOp); + + for (unsigned Part = 0; Part < UF; ++Part) { + Value *A = getOrCreateVectorValue(BinOp->getOperand(0), Part); + Value *B = getOrCreateVectorValue(BinOp->getOperand(1), Part); + Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A, B); + + if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V)) + VecOp->copyIRFlags(BinOp); // Use this vector value for all users of the original instruction. - VectorParts Entry(UF); - for (unsigned Part = 0; Part < UF; ++Part) { - Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]); + VectorLoopValueMap.setVectorValue(&I, Part, V); + addMetadata(V, BinOp); + } - if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V)) - VecOp->copyIRFlags(BinOp); + break; + } + case Instruction::Select: { + // Widen selects. + // If the selector is loop invariant we can create a select + // instruction with a scalar condition. Otherwise, use vector-select. + auto *SE = PSE.getSE(); + bool InvariantCond = + SE->isLoopInvariant(PSE.getSCEV(I.getOperand(0)), OrigLoop); + setDebugLocFromInst(Builder, &I); - Entry[Part] = V; - } + // The condition can be loop invariant but still defined inside the + // loop. This means that we can't just use the original 'cond' value. + // We have to take the 'vectorized' value and pick the first lane. + // Instcombine will make this a no-op. - VectorLoopValueMap.initVector(&I, Entry); - addMetadata(Entry, BinOp); - break; + auto *ScalarCond = getOrCreateScalarValue(I.getOperand(0), 0, 0); + + for (unsigned Part = 0; Part < UF; ++Part) { + Value *Cond = getOrCreateVectorValue(I.getOperand(0), Part); + Value *Op0 = getOrCreateVectorValue(I.getOperand(1), Part); + Value *Op1 = getOrCreateVectorValue(I.getOperand(2), Part); + Value *Sel = + Builder.CreateSelect(InvariantCond ? 
ScalarCond : Cond, Op0, Op1); + VectorLoopValueMap.setVectorValue(&I, Part, Sel); + addMetadata(Sel, &I); } - case Instruction::Select: { - // Widen selects. - // If the selector is loop invariant we can create a select - // instruction with a scalar condition. Otherwise, use vector-select. - auto *SE = PSE.getSE(); - bool InvariantCond = - SE->isLoopInvariant(PSE.getSCEV(I.getOperand(0)), OrigLoop); - setDebugLocFromInst(Builder, &I); - - // The condition can be loop invariant but still defined inside the - // loop. This means that we can't just use the original 'cond' value. - // We have to take the 'vectorized' value and pick the first lane. - // Instcombine will make this a no-op. - const VectorParts &Cond = getVectorValue(I.getOperand(0)); - const VectorParts &Op0 = getVectorValue(I.getOperand(1)); - const VectorParts &Op1 = getVectorValue(I.getOperand(2)); - - auto *ScalarCond = getScalarValue(I.getOperand(0), 0, 0); - - VectorParts Entry(UF); - for (unsigned Part = 0; Part < UF; ++Part) { - Entry[Part] = Builder.CreateSelect( - InvariantCond ? ScalarCond : Cond[Part], Op0[Part], Op1[Part]); + + break; + } + + case Instruction::ICmp: + case Instruction::FCmp: { + // Widen compares. Generate vector compares. + bool FCmp = (I.getOpcode() == Instruction::FCmp); + auto *Cmp = dyn_cast<CmpInst>(&I); + setDebugLocFromInst(Builder, Cmp); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *A = getOrCreateVectorValue(Cmp->getOperand(0), Part); + Value *B = getOrCreateVectorValue(Cmp->getOperand(1), Part); + Value *C = nullptr; + if (FCmp) { + C = Builder.CreateFCmp(Cmp->getPredicate(), A, B); + cast<FCmpInst>(C)->copyFastMathFlags(Cmp); + } else { + C = Builder.CreateICmp(Cmp->getPredicate(), A, B); } + VectorLoopValueMap.setVectorValue(&I, Part, C); + addMetadata(C, &I); + } - VectorLoopValueMap.initVector(&I, Entry); - addMetadata(Entry, &I); + break; + } + + case Instruction::Store: + case Instruction::Load: + vectorizeMemoryInstruction(&I); + break; + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: { + auto *CI = dyn_cast<CastInst>(&I); + setDebugLocFromInst(Builder, CI); + + // Optimize the special case where the source is a constant integer + // induction variable. Notice that we can only optimize the 'trunc' case + // because (a) FP conversions lose precision, (b) sext/zext may wrap, and + // (c) other casts depend on pointer size. + if (Cost->isOptimizableIVTruncate(CI, VF)) { + widenIntOrFpInduction(cast<PHINode>(CI->getOperand(0)), + cast<TruncInst>(CI)); break; } - case Instruction::ICmp: - case Instruction::FCmp: { - // Widen compares. Generate vector compares. - bool FCmp = (I.getOpcode() == Instruction::FCmp); - auto *Cmp = dyn_cast<CmpInst>(&I); - setDebugLocFromInst(Builder, Cmp); - const VectorParts &A = getVectorValue(Cmp->getOperand(0)); - const VectorParts &B = getVectorValue(Cmp->getOperand(1)); - VectorParts Entry(UF); - for (unsigned Part = 0; Part < UF; ++Part) { - Value *C = nullptr; - if (FCmp) { - C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]); - cast<FCmpInst>(C)->copyFastMathFlags(Cmp); - } else { - C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]); - } - Entry[Part] = C; - } + /// Vectorize casts. + Type *DestTy = + (VF == 1) ? 
CI->getType() : VectorType::get(CI->getType(), VF); - VectorLoopValueMap.initVector(&I, Entry); - addMetadata(Entry, &I); - break; + for (unsigned Part = 0; Part < UF; ++Part) { + Value *A = getOrCreateVectorValue(CI->getOperand(0), Part); + Value *Cast = Builder.CreateCast(CI->getOpcode(), A, DestTy); + VectorLoopValueMap.setVectorValue(&I, Part, Cast); + addMetadata(Cast, &I); } + break; + } - case Instruction::Store: - case Instruction::Load: - vectorizeMemoryInstruction(&I); + case Instruction::Call: { + // Ignore dbg intrinsics. + if (isa<DbgInfoIntrinsic>(I)) break; - case Instruction::ZExt: - case Instruction::SExt: - case Instruction::FPToUI: - case Instruction::FPToSI: - case Instruction::FPExt: - case Instruction::PtrToInt: - case Instruction::IntToPtr: - case Instruction::SIToFP: - case Instruction::UIToFP: - case Instruction::Trunc: - case Instruction::FPTrunc: - case Instruction::BitCast: { - auto *CI = dyn_cast<CastInst>(&I); - setDebugLocFromInst(Builder, CI); - - // Optimize the special case where the source is a constant integer - // induction variable. Notice that we can only optimize the 'trunc' case - // because (a) FP conversions lose precision, (b) sext/zext may wrap, and - // (c) other casts depend on pointer size. - auto ID = Legal->getInductionVars()->lookup(OldInduction); - if (isa<TruncInst>(CI) && CI->getOperand(0) == OldInduction && - ID.getConstIntStepValue()) { - widenIntInduction(OldInduction, cast<TruncInst>(CI)); - break; - } + setDebugLocFromInst(Builder, &I); - /// Vectorize casts. - Type *DestTy = - (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF); + Module *M = I.getParent()->getParent()->getParent(); + auto *CI = cast<CallInst>(&I); - const VectorParts &A = getVectorValue(CI->getOperand(0)); - VectorParts Entry(UF); - for (unsigned Part = 0; Part < UF; ++Part) - Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy); - VectorLoopValueMap.initVector(&I, Entry); - addMetadata(Entry, &I); + StringRef FnName = CI->getCalledFunction()->getName(); + Function *F = CI->getCalledFunction(); + Type *RetTy = ToVectorTy(CI->getType(), VF); + SmallVector<Type *, 4> Tys; + for (Value *ArgOperand : CI->arg_operands()) + Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); + + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || + ID == Intrinsic::lifetime_start)) { + scalarizeInstruction(&I); + break; + } + // The flag shows whether we use Intrinsic or a usual Call for vectorized + // version of the instruction. + // Is it beneficial to perform intrinsic call compared to lib call? + bool NeedToScalarize; + unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); + bool UseVectorIntrinsic = + ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; + if (!UseVectorIntrinsic && NeedToScalarize) { + scalarizeInstruction(&I); break; } - case Instruction::Call: { - // Ignore dbg intrinsics. 
- if (isa<DbgInfoIntrinsic>(I)) - break; - setDebugLocFromInst(Builder, &I); - - Module *M = BB->getParent()->getParent(); - auto *CI = cast<CallInst>(&I); - - StringRef FnName = CI->getCalledFunction()->getName(); - Function *F = CI->getCalledFunction(); - Type *RetTy = ToVectorTy(CI->getType(), VF); - SmallVector<Type *, 4> Tys; - for (Value *ArgOperand : CI->arg_operands()) - Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); - - Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || - ID == Intrinsic::lifetime_start)) { - scalarizeInstruction(&I); - break; - } - // The flag shows whether we use Intrinsic or a usual Call for vectorized - // version of the instruction. - // Is it beneficial to perform intrinsic call compared to lib call? - bool NeedToScalarize; - unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); - bool UseVectorIntrinsic = - ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; - if (!UseVectorIntrinsic && NeedToScalarize) { - scalarizeInstruction(&I); - break; + for (unsigned Part = 0; Part < UF; ++Part) { + SmallVector<Value *, 4> Args; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { + Value *Arg = CI->getArgOperand(i); + // Some intrinsics have a scalar argument - don't replace it with a + // vector. + if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) + Arg = getOrCreateVectorValue(CI->getArgOperand(i), Part); + Args.push_back(Arg); } - VectorParts Entry(UF); - for (unsigned Part = 0; Part < UF; ++Part) { - SmallVector<Value *, 4> Args; - for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { - Value *Arg = CI->getArgOperand(i); - // Some intrinsics have a scalar argument - don't replace it with a - // vector. - if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) { - const VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i)); - Arg = VectorArg[Part]; - } - Args.push_back(Arg); - } - - Function *VectorF; - if (UseVectorIntrinsic) { - // Use vector version of the intrinsic. - Type *TysForDecl[] = {CI->getType()}; - if (VF > 1) - TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF); - VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); - } else { - // Use vector version of the library call. - StringRef VFnName = TLI->getVectorizedFunction(FnName, VF); - assert(!VFnName.empty() && "Vector function name is empty."); - VectorF = M->getFunction(VFnName); - if (!VectorF) { - // Generate a declaration - FunctionType *FTy = FunctionType::get(RetTy, Tys, false); - VectorF = - Function::Create(FTy, Function::ExternalLinkage, VFnName, M); - VectorF->copyAttributesFrom(F); - } + Function *VectorF; + if (UseVectorIntrinsic) { + // Use vector version of the intrinsic. + Type *TysForDecl[] = {CI->getType()}; + if (VF > 1) + TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF); + VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); + } else { + // Use vector version of the library call. 
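
The UseVectorIntrinsic/NeedToScalarize logic above boils down to a three-way minimum-cost choice. A sketch with made-up costs standing in for getVectorCallCost and getVectorIntrinsicCost:

    #include <cassert>

    enum class Widening { Scalarize, VectorLibCall, VectorIntrinsic };

    // Prefer the vector intrinsic when one exists and is no more expensive
    // than the best call lowering; otherwise fall back to a vector library
    // call, or to scalarization when no vector variant is usable.
    Widening chooseWidening(bool HasIntrinsicID, unsigned IntrinsicCost,
                            unsigned CallCost, bool NeedToScalarize) {
      bool UseVectorIntrinsic = HasIntrinsicID && IntrinsicCost <= CallCost;
      if (UseVectorIntrinsic)
        return Widening::VectorIntrinsic;
      if (NeedToScalarize)  // no vectorized library function available
        return Widening::Scalarize;
      return Widening::VectorLibCall;
    }

    int main() {
      assert(chooseWidening(true, 4, 10, false) == Widening::VectorIntrinsic);
      assert(chooseWidening(false, 0, 10, true) == Widening::Scalarize);
      assert(chooseWidening(false, 0, 10, false) == Widening::VectorLibCall);
      return 0;
    }
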
+ StringRef VFnName = TLI->getVectorizedFunction(FnName, VF); + assert(!VFnName.empty() && "Vector function name is empty."); + VectorF = M->getFunction(VFnName); + if (!VectorF) { + // Generate a declaration + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + VectorF = + Function::Create(FTy, Function::ExternalLinkage, VFnName, M); + VectorF->copyAttributesFrom(F); } - assert(VectorF && "Can't create vector function."); + } + assert(VectorF && "Can't create vector function."); - SmallVector<OperandBundleDef, 1> OpBundles; - CI->getOperandBundlesAsDefs(OpBundles); - CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); + SmallVector<OperandBundleDef, 1> OpBundles; + CI->getOperandBundlesAsDefs(OpBundles); + CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); - if (isa<FPMathOperator>(V)) - V->copyFastMathFlags(CI); + if (isa<FPMathOperator>(V)) + V->copyFastMathFlags(CI); - Entry[Part] = V; - } - - VectorLoopValueMap.initVector(&I, Entry); - addMetadata(Entry, &I); - break; + VectorLoopValueMap.setVectorValue(&I, Part, V); + addMetadata(V, &I); } - default: - // All other instructions are unsupported. Scalarize them. - scalarizeInstruction(&I); - break; - } // end of switch. - } // end of for_each instr. + break; + } + + default: + // All other instructions are unsupported. Scalarize them. + scalarizeInstruction(&I); + break; + } // end of switch. } void InnerLoopVectorizer::updateAnalysis() { @@ -4976,11 +5012,10 @@ void InnerLoopVectorizer::updateAnalysis() { assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) && "Entry does not dominate exit."); - // We don't predicate stores by this point, so the vector body should be a - // single loop. - DT->addNewBlock(LoopVectorBody, LoopVectorPreHeader); - - DT->addNewBlock(LoopMiddleBlock, LoopVectorBody); + DT->addNewBlock(LI->getLoopFor(LoopVectorBody)->getHeader(), + LoopVectorPreHeader); + DT->addNewBlock(LoopMiddleBlock, + LI->getLoopFor(LoopVectorBody)->getLoopLatch()); DT->addNewBlock(LoopScalarPreHeader, LoopBypassBlocks[0]); DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader); DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]); @@ -5056,12 +5091,18 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { } bool LoopVectorizationLegality::canVectorize() { + // Store the result and return it at the end instead of exiting early, in case + // allowExtraAnalysis is used to report multiple reasons for not vectorizing. + bool Result = true; // We must have a loop in canonical form. Loops with indirectbr in them cannot // be canonicalized. if (!TheLoop->getLoopPreheader()) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } // FIXME: The code is currently dead, since the loop gets sent to @@ -5071,21 +5112,30 @@ bool LoopVectorizationLegality::canVectorize() { if (!TheLoop->empty()) { ORE->emit(createMissedAnalysis("NotInnermostLoop") << "loop is not the innermost loop"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } // We must have a single backedge. if (TheLoop->getNumBackEdges() != 1) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } // We must have a single exiting block. 
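
The repeated "if (ORE->allowExtraAnalysis()) Result = false; else return false;" edits in canVectorize() all implement one pattern: keep running the remaining legality checks after a failure so that every missed-optimization remark is emitted, instead of stopping at the first one. A condensed sketch (runChecks and its string remarks are hypothetical stand-ins for the ORE plumbing):

    #include <cassert>
    #include <cstddef>
    #include <string>
    #include <vector>

    bool runChecks(const std::vector<bool> &Checks, bool AllowExtraAnalysis,
                   std::vector<std::string> &Remarks) {
      bool Result = true;
      for (std::size_t I = 0; I < Checks.size(); ++I) {
        if (Checks[I])
          continue;
        Remarks.push_back("check " + std::to_string(I) + " failed");
        if (AllowExtraAnalysis)
          Result = false;  // remember the failure, keep analyzing
        else
          return false;    // old behavior: first failure wins
      }
      return Result;
    }

    int main() {
      std::vector<std::string> Remarks;
      assert(!runChecks({true, false, false}, /*AllowExtraAnalysis=*/true,
                        Remarks));
      assert(Remarks.size() == 2);  // both reasons reported
      return 0;
    }
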
if (!TheLoop->getExitingBlock()) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } // We only handle bottom-tested loops, i.e. loop in which the condition is @@ -5094,7 +5144,10 @@ bool LoopVectorizationLegality::canVectorize() { if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } // We need to have a loop header. @@ -5105,28 +5158,28 @@ bool LoopVectorizationLegality::canVectorize() { unsigned NumBlocks = TheLoop->getNumBlocks(); if (NumBlocks != 1 && !canVectorizeWithIfConvert()) { DEBUG(dbgs() << "LV: Can't if-convert the loop.\n"); - return false; - } - - // ScalarEvolution needs to be able to find the exit count. - const SCEV *ExitCount = PSE.getBackedgeTakenCount(); - if (ExitCount == PSE.getSE()->getCouldNotCompute()) { - ORE->emit(createMissedAnalysis("CantComputeNumberOfIterations") - << "could not determine number of loop iterations"); - DEBUG(dbgs() << "LV: SCEV could not compute the loop exit count.\n"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } // Check if we can vectorize the instructions and CFG in this loop. if (!canVectorizeInstrs()) { DEBUG(dbgs() << "LV: Can't vectorize the instructions or CFG\n"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } // Go over each instruction and look at memory deps. if (!canVectorizeMemory()) { DEBUG(dbgs() << "LV: Can't vectorize due to memory conflicts\n"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } DEBUG(dbgs() << "LV: We can vectorize this loop" @@ -5145,12 +5198,6 @@ bool LoopVectorizationLegality::canVectorize() { if (UseInterleaved) InterleaveInfo.analyzeInterleaving(*getSymbolicStrides()); - // Collect all instructions that are known to be uniform after vectorization. - collectLoopUniforms(); - - // Collect all instructions that are known to be scalar after vectorization. - collectLoopScalars(); - unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; @@ -5160,13 +5207,17 @@ bool LoopVectorizationLegality::canVectorize() { << "Too many SCEV assumptions need to be made and checked " << "at runtime"); DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); - return false; + if (ORE->allowExtraAnalysis()) + Result = false; + else + return false; } - // Okay! We can vectorize. At this point we don't have any other mem analysis + // Okay! We've done all the tests. If any have failed, return false. Otherwise + // we can vectorize, and at this point we don't have any other mem analysis // which may limit our maximum vectorization factor, so just return true with // no restrictions. - return true; + return Result; } static Type *convertPointerToIntegerType(const DataLayout &DL, Type *Ty) { @@ -5234,14 +5285,19 @@ void LoopVectorizationLegality::addInductionPhi( // one if there are multiple (no good reason for doing this other // than it is expedient). We've checked that it begins at zero and // steps by one, so this is a canonical induction variable. 
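
For reference, the "canonical" property asserted above in scalar form: the primary induction variable begins at zero and steps by one, and other affine IVs can be rewritten in terms of it, which is why the legality code keeps a single PrimaryInduction:

    #include <cassert>

    int main() {
      int j = 7;                      // non-canonical IV: j == 7 + 3*i
      for (int i = 0; i < 10; ++i) {  // canonical: begins at 0, steps by 1
        assert(j == 7 + 3 * i);       // j is expressible via the primary IV
        j += 3;
      }
      return 0;
    }
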
- if (!Induction || PhiTy == WidestIndTy) - Induction = Phi; + if (!PrimaryInduction || PhiTy == WidestIndTy) + PrimaryInduction = Phi; } // Both the PHI node itself, and the "post-increment" value feeding // back into the PHI node may have external users. - AllowedExit.insert(Phi); - AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); + // We can allow those uses, except if the SCEVs we have for them rely + // on predicates that only hold within the loop, since allowing the exit + // currently means re-using this SCEV outside the loop. + if (PSE.getUnionPredicate().isAlwaysTrue()) { + AllowedExit.insert(Phi); + AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); + } DEBUG(dbgs() << "LV: Found an induction variable.\n"); return; @@ -5309,7 +5365,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } - if (RecurrenceDescriptor::isFirstOrderRecurrence(Phi, TheLoop, DT)) { + if (RecurrenceDescriptor::isFirstOrderRecurrence(Phi, TheLoop, + SinkAfter, DT)) { FirstOrderRecurrences.insert(Phi); continue; } @@ -5398,7 +5455,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } // next instr. } - if (!Induction) { + if (!PrimaryInduction) { DEBUG(dbgs() << "LV: Did not find one integer induction var.\n"); if (Inductions.empty()) { ORE->emit(createMissedAnalysis("NoInductionVariable") @@ -5410,46 +5467,173 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Now we know the widest induction type, check if our found induction // is the same size. If it's not, unset it here and InnerLoopVectorizer // will create another. - if (Induction && WidestIndTy != Induction->getType()) - Induction = nullptr; + if (PrimaryInduction && WidestIndTy != PrimaryInduction->getType()) + PrimaryInduction = nullptr; return true; } -void LoopVectorizationLegality::collectLoopScalars() { +void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { + + // We should not collect Scalars more than once per VF. Right now, this + // function is called from collectUniformsAndScalars(), which already does + // this check. Collecting Scalars for VF=1 does not make any sense. + assert(VF >= 2 && !Scalars.count(VF) && + "This function should not be visited twice for the same VF"); + + SmallSetVector<Instruction *, 8> Worklist; + + // These sets are used to seed the analysis with pointers used by memory + // accesses that will remain scalar. + SmallSetVector<Instruction *, 8> ScalarPtrs; + SmallPtrSet<Instruction *, 8> PossibleNonScalarPtrs; + + // A helper that returns true if the use of Ptr by MemAccess will be scalar. + // The pointer operands of loads and stores will be scalar as long as the + // memory access is not a gather or scatter operation. The value operand of a + // store will remain scalar if the store is scalarized. + auto isScalarUse = [&](Instruction *MemAccess, Value *Ptr) { + InstWidening WideningDecision = getWideningDecision(MemAccess, VF); + assert(WideningDecision != CM_Unknown && + "Widening decision should be ready at this moment"); + if (auto *Store = dyn_cast<StoreInst>(MemAccess)) + if (Ptr == Store->getValueOperand()) + return WideningDecision == CM_Scalarize; + assert(Ptr == getPointerOperand(MemAccess) && + "Ptr is neither a value or pointer operand"); + return WideningDecision != CM_GatherScatter; + }; + + // A helper that returns true if the given value is a bitcast or + // getelementptr instruction contained in the loop. 
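// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// The isScalarUse helper above encodes an asymmetry between the two operands
// a store contributes: the value operand stays scalar only when the whole
// store is scalarized, while a pointer operand stays scalar for every
// decision except gather/scatter, which needs a distinct address per lane.
// The rule reduced to a plain predicate (these types are hypothetical
// stand-ins for the cost model's bookkeeping):
enum class WideningModel { Widen, Interleave, GatherScatter, Scalarize };

// True if this particular use of the pointer/value remains scalar.
bool isScalarUseModel(WideningModel Decision, bool IsStoreValueOperand) {
  if (IsStoreValueOperand)
    return Decision == WideningModel::Scalarize;
  return Decision != WideningModel::GatherScatter;
}
// ---- [end sketch] ----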
+ auto isLoopVaryingBitCastOrGEP = [&](Value *V) {
+ return ((isa<BitCastInst>(V) && V->getType()->isPointerTy()) ||
+ isa<GetElementPtrInst>(V)) &&
+ !TheLoop->isLoopInvariant(V);
+ };
+
+ // A helper that evaluates a memory access's use of a pointer. If the use
+ // will be a scalar use, and the pointer is only used by memory accesses, we
+ // place the pointer in ScalarPtrs. Otherwise, the pointer is placed in
+ // PossibleNonScalarPtrs.
+ auto evaluatePtrUse = [&](Instruction *MemAccess, Value *Ptr) {
+
+ // We only care about bitcast and getelementptr instructions contained in
+ // the loop.
+ if (!isLoopVaryingBitCastOrGEP(Ptr))
+ return;
- // If an instruction is uniform after vectorization, it will remain scalar.
- Scalars.insert(Uniforms.begin(), Uniforms.end());
+ // If the pointer has already been identified as scalar (e.g., if it was
+ // also identified as uniform), there's nothing to do.
+ auto *I = cast<Instruction>(Ptr);
+ if (Worklist.count(I))
+ return;
- // Collect the getelementptr instructions that will not be vectorized. A
- // getelementptr instruction is only vectorized if it is used for a legal
- // gather or scatter operation.
+ // If the use of the pointer will be a scalar use, and all users of the
+ // pointer are memory accesses, place the pointer in ScalarPtrs. Otherwise,
+ // place the pointer in PossibleNonScalarPtrs.
+ if (isScalarUse(MemAccess, Ptr) && all_of(I->users(), [&](User *U) {
+ return isa<LoadInst>(U) || isa<StoreInst>(U);
+ }))
+ ScalarPtrs.insert(I);
+ else
+ PossibleNonScalarPtrs.insert(I);
+ };
+
+ // We seed the scalars analysis with three classes of instructions: (1)
+ // instructions marked uniform-after-vectorization, (2) bitcast and
+ // getelementptr instructions used by memory accesses requiring a scalar use,
+ // and (3) pointer induction variables and their update instructions (we
+ // currently only scalarize these).
+ //
+ // (1) Add to the worklist all instructions that have been identified as
+ // uniform-after-vectorization.
+ Worklist.insert(Uniforms[VF].begin(), Uniforms[VF].end());
+
+ // (2) Add to the worklist all bitcast and getelementptr instructions used by
+ // memory accesses requiring a scalar use. The pointer operands of loads and
+ // stores will be scalar as long as the memory access is not a gather or
+ // scatter operation. The value operand of a store will remain scalar if the
+ // store is scalarized.
 for (auto *BB : TheLoop->blocks())
 for (auto &I : *BB) {
- if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
- Scalars.insert(GEP);
- continue;
+ if (auto *Load = dyn_cast<LoadInst>(&I)) {
+ evaluatePtrUse(Load, Load->getPointerOperand());
+ } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
+ evaluatePtrUse(Store, Store->getPointerOperand());
+ evaluatePtrUse(Store, Store->getValueOperand());
 }
- auto *Ptr = getPointerOperand(&I);
- if (!Ptr)
- continue;
- auto *GEP = getGEPInstruction(Ptr);
- if (GEP && isLegalGatherOrScatter(&I))
- Scalars.erase(GEP);
+ }
+ for (auto *I : ScalarPtrs)
+ if (!PossibleNonScalarPtrs.count(I)) {
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *I << "\n");
+ Worklist.insert(I);
 }
+ // (3) Add to the worklist all pointer induction variables and their update
+ // instructions.
+ //
+ // TODO: Once we are able to vectorize pointer induction variables we should
+ // no longer insert them into the worklist here.
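// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// After seeding, collectLoopScalars grows the scalar set to a fixed point
// (the expansion loop just below): a bitcast or getelementptr source joins
// the worklist once all of its in-loop users are already known to be scalar.
// A self-contained model of that index-based expansion over a toy dependence
// graph (plain C++; the graph encoding is hypothetical):
#include <map>
#include <set>
#include <vector>

void expandToFixedPoint(std::vector<int> &Worklist,
                        const std::map<int, int> &OperandOf,
                        const std::map<int, std::vector<int>> &UsersOf) {
  std::set<int> InList(Worklist.begin(), Worklist.end());
  // Indexing instead of iterators: the worklist grows while we scan it.
  for (size_t Idx = 0; Idx != Worklist.size(); ++Idx) {
    auto OpIt = OperandOf.find(Worklist[Idx]);
    if (OpIt == OperandOf.end())
      continue;                        // no candidate operand to pull in
    int Src = OpIt->second;
    bool AllUsersScalar = true;
    auto UIt = UsersOf.find(Src);
    if (UIt != UsersOf.end())
      for (int U : UIt->second)
        AllUsersScalar &= (InList.count(U) != 0);
    if (AllUsersScalar && InList.insert(Src).second)
      Worklist.push_back(Src);         // Src is scalar too; keep expanding
  }
}
// ---- [end sketch] ----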
+ auto *Latch = TheLoop->getLoopLatch(); + for (auto &Induction : *Legal->getInductionVars()) { + auto *Ind = Induction.first; + auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); + if (Induction.second.getKind() != InductionDescriptor::IK_PtrInduction) + continue; + Worklist.insert(Ind); + Worklist.insert(IndUpdate); + DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n"); + DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n"); + } + + // Insert the forced scalars. + // FIXME: Currently widenPHIInstruction() often creates a dead vector + // induction variable when the PHI user is scalarized. + if (ForcedScalars.count(VF)) + for (auto *I : ForcedScalars.find(VF)->second) + Worklist.insert(I); + + // Expand the worklist by looking through any bitcasts and getelementptr + // instructions we've already identified as scalar. This is similar to the + // expansion step in collectLoopUniforms(); however, here we're only + // expanding to include additional bitcasts and getelementptr instructions. + unsigned Idx = 0; + while (Idx != Worklist.size()) { + Instruction *Dst = Worklist[Idx++]; + if (!isLoopVaryingBitCastOrGEP(Dst->getOperand(0))) + continue; + auto *Src = cast<Instruction>(Dst->getOperand(0)); + if (all_of(Src->users(), [&](User *U) -> bool { + auto *J = cast<Instruction>(U); + return !TheLoop->contains(J) || Worklist.count(J) || + ((isa<LoadInst>(J) || isa<StoreInst>(J)) && + isScalarUse(J, Src)); + })) { + Worklist.insert(Src); + DEBUG(dbgs() << "LV: Found scalar instruction: " << *Src << "\n"); + } + } + // An induction variable will remain scalar if all users of the induction // variable and induction variable update remain scalar. - auto *Latch = TheLoop->getLoopLatch(); - for (auto &Induction : *getInductionVars()) { + for (auto &Induction : *Legal->getInductionVars()) { auto *Ind = Induction.first; auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); + // We already considered pointer induction variables, so there's no reason + // to look at their users again. + // + // TODO: Once we are able to vectorize pointer induction variables we + // should no longer skip over them here. + if (Induction.second.getKind() == InductionDescriptor::IK_PtrInduction) + continue; + // Determine if all users of the induction variable are scalar after // vectorization. auto ScalarInd = all_of(Ind->users(), [&](User *U) -> bool { auto *I = cast<Instruction>(U); - return I == IndUpdate || !TheLoop->contains(I) || Scalars.count(I); + return I == IndUpdate || !TheLoop->contains(I) || Worklist.count(I); }); if (!ScalarInd) continue; @@ -5458,23 +5642,19 @@ void LoopVectorizationLegality::collectLoopScalars() { // scalar after vectorization. auto ScalarIndUpdate = all_of(IndUpdate->users(), [&](User *U) -> bool { auto *I = cast<Instruction>(U); - return I == Ind || !TheLoop->contains(I) || Scalars.count(I); + return I == Ind || !TheLoop->contains(I) || Worklist.count(I); }); if (!ScalarIndUpdate) continue; // The induction variable and its update instruction will remain scalar. 
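// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// An induction phi and its update instruction form a two-node cycle, so each
// is checked against the worklist while excusing the other; the first all_of
// is above and its twin continues just below. The pair reduced to plain C++
// over hypothetical user lists:
#include <algorithm>
#include <set>
#include <vector>

bool inductionPairStaysScalar(int Ind, int IndUpdate,
                              const std::vector<int> &IndUsers,
                              const std::vector<int> &UpdateUsers,
                              const std::set<int> &ScalarSet) {
  bool ScalarInd = std::all_of(IndUsers.begin(), IndUsers.end(), [&](int U) {
    return U == IndUpdate || ScalarSet.count(U);  // the update excuses itself
  });
  bool ScalarUpd =
      std::all_of(UpdateUsers.begin(), UpdateUsers.end(), [&](int U) {
        return U == Ind || ScalarSet.count(U);    // ...and vice versa
      });
  return ScalarInd && ScalarUpd;
}
// ---- [end sketch] ----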
- Scalars.insert(Ind);
- Scalars.insert(IndUpdate);
+ Worklist.insert(Ind);
+ Worklist.insert(IndUpdate);
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n");
+ DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n");
 }
-}
-bool LoopVectorizationLegality::hasConsecutiveLikePtrOperand(Instruction *I) {
- if (isAccessInterleaved(I))
- return true;
- if (auto *Ptr = getPointerOperand(I))
- return isConsecutivePtr(Ptr);
- return false;
+ Scalars[VF].insert(Worklist.begin(), Worklist.end());
}
bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) {
@@ -5494,48 +5674,48 @@
return false;
}
-bool LoopVectorizationLegality::memoryInstructionMustBeScalarized(
- Instruction *I, unsigned VF) {
-
- // If the memory instruction is in an interleaved group, it will be
- // vectorized and its pointer will remain uniform.
- if (isAccessInterleaved(I))
- return false;
-
+bool LoopVectorizationLegality::memoryInstructionCanBeWidened(Instruction *I,
+ unsigned VF) {
 // Get and ensure we have a valid memory instruction.
 LoadInst *LI = dyn_cast<LoadInst>(I);
 StoreInst *SI = dyn_cast<StoreInst>(I);
 assert((LI || SI) && "Invalid memory instruction");
- // If the pointer operand is uniform (loop invariant), the memory instruction
- // will be scalarized.
 auto *Ptr = getPointerOperand(I);
- if (LI && isUniform(Ptr))
- return true;
- // If the pointer operand is non-consecutive and neither a gather nor a
- // scatter operation is legal, the memory instruction will be scalarized.
- if (!isConsecutivePtr(Ptr) && !isLegalGatherOrScatter(I))
- return true;
+ // In order to be widened, the pointer must first be consecutive.
+ if (!isConsecutivePtr(Ptr))
+ return false;
 // If the instruction is a store located in a predicated block, it will be
 // scalarized.
 if (isScalarWithPredication(I))
- return true;
+ return false;
 // If the instruction's allocated size doesn't equal its type size, it
 // requires padding and will be scalarized.
 auto &DL = I->getModule()->getDataLayout();
 auto *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType();
 if (hasIrregularType(ScalarTy, DL, VF))
- return true;
+ return false;
- // Otherwise, the memory instruction should be vectorized if the rest of the
- // loop is.
- return false;
+ return true;
}
-void LoopVectorizationLegality::collectLoopUniforms() {
+void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) {
+
+ // We should not collect Uniforms more than once per VF. Right now,
+ // this function is called from collectUniformsAndScalars(), which
+ // already does this check. Collecting Uniforms for VF=1 does not make any
+ // sense.
+
+ assert(VF >= 2 && !Uniforms.count(VF) &&
+ "This function should not be visited twice for the same VF");
+
+ // Visit the list of Uniforms. Even if we do not find any uniform value, we
+ // will not analyze it again: Uniforms.count(VF) will return 1.
+ Uniforms[VF].clear();
+
 // We now know that the loop is vectorizable!
 // Collect instructions inside the loop that will remain uniform after
 // vectorization.
@@ -5568,6 +5748,14 @@
 // Holds pointer operands of instructions that are possibly non-uniform.
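// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// memoryInstructionCanBeWidened above inverts the old "must be scalarized"
// query into a positive one: a load or store is widenable only if its
// pointer is consecutive, it is not a predicated store, and its type needs
// no padding. The decision chain as a plain predicate over hypothetical
// precomputed flags:
struct MemAccessTraits {
  bool ConsecutivePtr;   // pointer stride is +1 or -1
  bool ScalarWithPred;   // predicated access that must be scalarized
  bool IrregularType;    // allocated size != store size, i.e. padding
};

bool canBeWidenedModel(const MemAccessTraits &T) {
  if (!T.ConsecutivePtr)
    return false;        // widening requires a consecutive pointer
  if (T.ScalarWithPred)
    return false;        // predicated stores are scalarized instead
  if (T.IrregularType)
    return false;        // padded types cannot use a plain wide load/store
  return true;
}
// ---- [end sketch] ----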
SmallPtrSet<Instruction *, 8> PossibleNonUniformPtrs; + auto isUniformDecision = [&](Instruction *I, unsigned VF) { + InstWidening WideningDecision = getWideningDecision(I, VF); + assert(WideningDecision != CM_Unknown && + "Widening decision should be ready at this moment"); + + return (WideningDecision == CM_Widen || + WideningDecision == CM_Interleave); + }; // Iterate over the instructions in the loop, and collect all // consecutive-like pointer operands in ConsecutiveLikePtrs. If it's possible // that a consecutive-like pointer operand will be scalarized, we collect it @@ -5590,25 +5778,18 @@ void LoopVectorizationLegality::collectLoopUniforms() { return getPointerOperand(U) == Ptr; }); - // Ensure the memory instruction will not be scalarized, making its - // pointer operand non-uniform. If the pointer operand is used by some - // instruction other than a memory access, we're not going to check if - // that other instruction may be scalarized here. Thus, conservatively - // assume the pointer operand may be non-uniform. - if (!UsersAreMemAccesses || memoryInstructionMustBeScalarized(&I)) + // Ensure the memory instruction will not be scalarized or used by + // gather/scatter, making its pointer operand non-uniform. If the pointer + // operand is used by any instruction other than a memory access, we + // conservatively assume the pointer operand may be non-uniform. + if (!UsersAreMemAccesses || !isUniformDecision(&I, VF)) PossibleNonUniformPtrs.insert(Ptr); // If the memory instruction will be vectorized and its pointer operand - // is consecutive-like, the pointer operand should remain uniform. - else if (hasConsecutiveLikePtrOperand(&I)) - ConsecutiveLikePtrs.insert(Ptr); - - // Otherwise, if the memory instruction will be vectorized and its - // pointer operand is non-consecutive-like, the memory instruction should - // be a gather or scatter operation. Its pointer operand will be - // non-uniform. + // is consecutive-like, or interleaving - the pointer operand should + // remain uniform. else - PossibleNonUniformPtrs.insert(Ptr); + ConsecutiveLikePtrs.insert(Ptr); } // Add to the Worklist all consecutive and consecutive-like pointers that @@ -5632,7 +5813,9 @@ void LoopVectorizationLegality::collectLoopUniforms() { continue; auto *OI = cast<Instruction>(OV); if (all_of(OI->users(), [&](User *U) -> bool { - return isOutOfScope(U) || Worklist.count(cast<Instruction>(U)); + auto *J = cast<Instruction>(U); + return !TheLoop->contains(J) || Worklist.count(J) || + (OI == getPointerOperand(J) && isUniformDecision(J, VF)); })) { Worklist.insert(OI); DEBUG(dbgs() << "LV: Found uniform instruction: " << *OI << "\n"); @@ -5643,7 +5826,7 @@ void LoopVectorizationLegality::collectLoopUniforms() { // Returns true if Ptr is the pointer operand of a memory access instruction // I, and I is known to not require scalarization. auto isVectorizedMemAccessUse = [&](Instruction *I, Value *Ptr) -> bool { - return getPointerOperand(I) == Ptr && !memoryInstructionMustBeScalarized(I); + return getPointerOperand(I) == Ptr && isUniformDecision(I, VF); }; // For an instruction to be added into Worklist above, all its users inside @@ -5652,7 +5835,7 @@ void LoopVectorizationLegality::collectLoopUniforms() { // nodes separately. An induction variable will remain uniform if all users // of the induction variable and induction variable update remain uniform. // The code below handles both pointer and non-pointer induction variables. 
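// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// With isUniformDecision in place, a pointer stays uniform only when every
// memory access using it received a CM_Widen or CM_Interleave decision;
// gather/scatter and scalarized accesses need a distinct address per lane.
// The classification step in isolation (plain C++, hypothetical encoding):
#include <set>

enum class DecisionModel { Widen, Interleave, GatherScatter, Scalarize };

void classifyPtr(int Ptr, bool UsersAreMemAccesses, DecisionModel D,
                 std::set<int> &ConsecutiveLikePtrs,
                 std::set<int> &PossibleNonUniformPtrs) {
  bool UniformUse =
      D == DecisionModel::Widen || D == DecisionModel::Interleave;
  if (!UsersAreMemAccesses || !UniformUse)
    PossibleNonUniformPtrs.insert(Ptr);  // conservatively non-uniform
  else
    ConsecutiveLikePtrs.insert(Ptr);     // the pointer operand stays uniform
}
// ---- [end sketch] ----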
- for (auto &Induction : Inductions) { + for (auto &Induction : *Legal->getInductionVars()) { auto *Ind = Induction.first; auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); @@ -5683,7 +5866,7 @@ void LoopVectorizationLegality::collectLoopUniforms() { DEBUG(dbgs() << "LV: Found uniform instruction: " << *IndUpdate << "\n"); } - Uniforms.insert(Worklist.begin(), Worklist.end()); + Uniforms[VF].insert(Worklist.begin(), Worklist.end()); } bool LoopVectorizationLegality::canVectorizeMemory() { @@ -5808,10 +5991,10 @@ void InterleavedAccessInfo::collectConstStrideAccesses( continue; Value *Ptr = getPointerOperand(&I); - // We don't check wrapping here because we don't know yet if Ptr will be - // part of a full group or a group with gaps. Checking wrapping for all + // We don't check wrapping here because we don't know yet if Ptr will be + // part of a full group or a group with gaps. Checking wrapping for all // pointers (even those that end up in groups with no gaps) will be overly - // conservative. For full groups, wrapping should be ok since if we would + // conservative. For full groups, wrapping should be ok since if we would // wrap around the address space we would do a memory access at nullptr // even without the transformation. The wrapping checks are therefore // deferred until after we've formed the interleaved groups. @@ -5823,7 +6006,7 @@ void InterleavedAccessInfo::collectConstStrideAccesses( uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); // An alignment of 0 means target ABI alignment. - unsigned Align = LI ? LI->getAlignment() : SI->getAlignment(); + unsigned Align = getMemInstAlignment(&I); if (!Align) Align = DL.getABITypeAlignment(PtrTy->getElementType()); @@ -5978,6 +6161,11 @@ void InterleavedAccessInfo::analyzeInterleaving( if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size) continue; + // Ignore A if the memory object of A and B don't belong to the same + // address space + if (getMemInstAddressSpace(A) != getMemInstAddressSpace(B)) + continue; + // Calculate the distance from A to B. const SCEVConstant *DistToB = dyn_cast<SCEVConstant>( PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev)); @@ -6020,36 +6208,36 @@ void InterleavedAccessInfo::analyzeInterleaving( if (Group->getNumMembers() != Group->getFactor()) releaseGroup(Group); - // Remove interleaved groups with gaps (currently only loads) whose memory - // accesses may wrap around. We have to revisit the getPtrStride analysis, - // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does + // Remove interleaved groups with gaps (currently only loads) whose memory + // accesses may wrap around. We have to revisit the getPtrStride analysis, + // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does // not check wrapping (see documentation there). - // FORNOW we use Assume=false; - // TODO: Change to Assume=true but making sure we don't exceed the threshold + // FORNOW we use Assume=false; + // TODO: Change to Assume=true but making sure we don't exceed the threshold // of runtime SCEV assumptions checks (thereby potentially failing to - // vectorize altogether). + // vectorize altogether). // Additional optional optimizations: - // TODO: If we are peeling the loop and we know that the first pointer doesn't + // TODO: If we are peeling the loop and we know that the first pointer doesn't // wrap then we can deduce that all pointers in the group don't wrap. 
- // This means that we can forcefully peel the loop in order to only have to
- // check the first pointer for no-wrap. When we'll change to use Assume=true
+ // This means that we can forcefully peel the loop in order to only have to
+ // check the first pointer for no-wrap. When we change to use Assume=true
// we'll only need at most one runtime check per interleaved group.
//
for (InterleaveGroup *Group : LoadGroups) {
// Case 1: A full group. Can skip the checks; for full groups, if the wide
- // load would wrap around the address space we would do a memory access at
- // nullptr even without the transformation.
- if (Group->getNumMembers() == Group->getFactor())
+ // load would wrap around the address space we would do a memory access at
+ // nullptr even without the transformation.
+ if (Group->getNumMembers() == Group->getFactor())
continue;
- // Case 2: If first and last members of the group don't wrap this implies
+ // Case 2: If first and last members of the group don't wrap this implies
// that all the pointers in the group don't wrap.
// So we check only group member 0 (which is always guaranteed to exist),
- // and group member Factor - 1; If the latter doesn't exist we rely on
+ // and group member Factor - 1; If the latter doesn't exist we rely on
// peeling (if it is a non-reversed access -- see Case 3).
Value *FirstMemberPtr = getPointerOperand(Group->getMember(0));
- if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
+ if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
/*ShouldCheckWrap=*/true)) {
DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
"first group member potentially pointer-wrapping.\n");
@@ -6059,18 +6247,17 @@ void InterleavedAccessInfo::analyzeInterleaving(
Instruction *LastMember = Group->getMember(Group->getFactor() - 1);
if (LastMember) {
Value *LastMemberPtr = getPointerOperand(LastMember);
- if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
+ if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
/*ShouldCheckWrap=*/true)) {
DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
"last group member potentially pointer-wrapping.\n");
releaseGroup(Group);
}
- }
- else {
+ } else {
// Case 3: A non-reversed interleaved load group with gaps: We need
- // to execute at least one scalar epilogue iteration. This will ensure
+ // to execute at least one scalar epilogue iteration. This will ensure
// we don't speculatively access memory out-of-bounds. We only need
- // to look for a member at index factor - 1, since every group must have
+ // to look for a member at index factor - 1, since every group must have
// a member at index zero.
if (Group->isReverse()) {
releaseGroup(Group);
@@ -6082,27 +6269,62 @@ void InterleavedAccessInfo::analyzeInterleaving(
}
}
-LoopVectorizationCostModel::VectorizationFactor
-LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
- // Width 1 means no vectorize
- VectorizationFactor Factor = {1U, 0U};
- if (OptForSize && Legal->getRuntimePointerChecking()->Need) {
+Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) {
+ if (!EnableCondStoresVectorization && Legal->getNumPredStores()) {
+ ORE->emit(createMissedAnalysis("ConditionalStore")
+ << "store that is conditionally executed prevents vectorization");
+ DEBUG(dbgs() << "LV: No vectorization. 
There are conditional stores.\n"); + return None; + } + + if (!OptForSize) // Remaining checks deal with scalar loop when OptForSize. + return computeFeasibleMaxVF(OptForSize); + + if (Legal->getRuntimePointerChecking()->Need) { ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") << "runtime pointer checks needed. Enable vectorization of this " "loop with '#pragma clang loop vectorize(enable)' when " "compiling with -Os/-Oz"); DEBUG(dbgs() << "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); - return Factor; + return None; } - if (!EnableCondStoresVectorization && Legal->getNumPredStores()) { - ORE->emit(createMissedAnalysis("ConditionalStore") - << "store that is conditionally executed prevents vectorization"); - DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n"); - return Factor; + // If we optimize the program for size, avoid creating the tail loop. + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); + DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); + + // If we don't know the precise trip count, don't try to vectorize. + if (TC < 2) { + ORE->emit( + createMissedAnalysis("UnknownLoopCountComplexCFG") + << "unable to calculate the loop count due to complex control flow"); + DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + return None; } + unsigned MaxVF = computeFeasibleMaxVF(OptForSize); + + if (TC % MaxVF != 0) { + // If the trip count that we found modulo the vectorization factor is not + // zero then we require a tail. + // FIXME: look for a smaller MaxVF that does divide TC rather than give up. + // FIXME: return None if loop requiresScalarEpilog(<MaxVF>), or look for a + // smaller MaxVF that does not require a scalar epilog. + + ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") + << "cannot optimize for size and vectorize at the " + "same time. Enable vectorization of this loop " + "with '#pragma clang loop vectorize(enable)' " + "when compiling with -Os/-Oz"); + DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + return None; + } + + return MaxVF; +} + +unsigned LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize) { MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); unsigned SmallestType, WidestType; std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); @@ -6136,7 +6358,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" " into one vector!"); - unsigned VF = MaxVectorSize; + unsigned MaxVF = MaxVectorSize; if (MaximizeBandwidth && !OptForSize) { // Collect all viable vectorization factors. SmallVector<unsigned, 8> VFs; @@ -6152,54 +6374,16 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { unsigned TargetNumRegisters = TTI.getNumberOfRegisters(true); for (int i = RUs.size() - 1; i >= 0; --i) { if (RUs[i].MaxLocalUsers <= TargetNumRegisters) { - VF = VFs[i]; + MaxVF = VFs[i]; break; } } } + return MaxVF; +} - // If we optimize the program for size, avoid creating the tail loop. - if (OptForSize) { - unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); - DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); - - // If we don't know the precise trip count, don't try to vectorize. - if (TC < 2) { - ORE->emit( - createMissedAnalysis("UnknownLoopCountComplexCFG") - << "unable to calculate the loop count due to complex control flow"); - DEBUG(dbgs() << "LV: Aborting. 
A tail loop is required with -Os/-Oz.\n"); - return Factor; - } - - // Find the maximum SIMD width that can fit within the trip count. - VF = TC % MaxVectorSize; - - if (VF == 0) - VF = MaxVectorSize; - else { - // If the trip count that we found modulo the vectorization factor is not - // zero then we require a tail. - ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") - << "cannot optimize for size and vectorize at the " - "same time. Enable vectorization of this loop " - "with '#pragma clang loop vectorize(enable)' " - "when compiling with -Os/-Oz"); - DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); - return Factor; - } - } - - int UserVF = Hints->getWidth(); - if (UserVF != 0) { - assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); - DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); - - Factor.Width = UserVF; - collectInstsToScalarize(UserVF); - return Factor; - } - +LoopVectorizationCostModel::VectorizationFactor +LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) { float Cost = expectedCost(1).first; #ifndef NDEBUG const float ScalarCost = Cost; @@ -6209,12 +6393,12 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled; // Ignore scalar width, because the user explicitly wants vectorization. - if (ForceVectorization && VF > 1) { + if (ForceVectorization && MaxVF > 1) { Width = 2; Cost = expectedCost(Width).first / (float)Width; } - for (unsigned i = 2; i <= VF; i *= 2) { + for (unsigned i = 2; i <= MaxVF; i *= 2) { // Notice that the vector loop needs to be executed less times, so // we need to divide the cost of the vector loops by the width of // the vector elements. @@ -6238,8 +6422,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { << "LV: Vectorization seems to be not beneficial, " << "but was forced by a user.\n"); DEBUG(dbgs() << "LV: Selecting VF: " << Width << ".\n"); - Factor.Width = Width; - Factor.Cost = Width * Cost; + VectorizationFactor Factor = {Width, (unsigned)(Width * Cost)}; return Factor; } @@ -6277,9 +6460,16 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { T = ST->getValueOperand()->getType(); // Ignore loaded pointer types and stored pointer types that are not - // consecutive. However, we do want to take consecutive stores/loads of - // pointer vectors into account. - if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I)) + // vectorizable. + // + // FIXME: The check here attempts to predict whether a load or store will + // be vectorized. We only know this for certain after a VF has + // been selected. Here, we assume that if an access can be + // vectorized, it will be. We should also look at extending this + // optimization to non-pointer types. + // + if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I) && + !Legal->isAccessInterleaved(&I) && !Legal->isLegalGatherOrScatter(&I)) continue; MinWidth = std::min(MinWidth, @@ -6562,12 +6752,13 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { MaxUsages[j] = std::max(MaxUsages[j], OpenIntervals.size()); continue; } - + collectUniformsAndScalars(VFs[j]); // Count the number of live intervals. unsigned RegUsage = 0; for (auto Inst : OpenIntervals) { // Skip ignored values for VF > 1. 
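// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// selectVectorizationFactor now receives a precomputed MaxVF and keeps only
// the cost comparison: the per-lane cost at width W is expectedCost(W) / W,
// and the cheapest power-of-two width up to MaxVF wins. The selection loop
// in isolation (ExpectedCost is a hypothetical stand-in for the cost query):
#include <functional>

unsigned pickWidth(unsigned MaxVF,
                   const std::function<float(unsigned)> &ExpectedCost) {
  float Cost = ExpectedCost(1);  // scalar baseline, width 1
  unsigned Width = 1;
  for (unsigned W = 2; W <= MaxVF; W *= 2) {
    // The vector loop executes W times fewer iterations, so normalize the
    // cost per lane before comparing against the current best.
    float VectorCost = ExpectedCost(W) / static_cast<float>(W);
    if (VectorCost < Cost) {
      Cost = VectorCost;
      Width = W;
    }
  }
  return Width;                  // Width == 1 means "do not vectorize"
}
// ---- [end sketch] ----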
- if (VecValuesToIgnore.count(Inst)) + if (VecValuesToIgnore.count(Inst) || + isScalarAfterVectorization(Inst, VFs[j])) continue; RegUsage += GetRegUsage(Inst->getType(), VFs[j]); } @@ -6628,6 +6819,9 @@ void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { ScalarCostsTy ScalarCosts; if (computePredInstDiscount(&I, ScalarCosts, VF) >= 0) ScalarCostsVF.insert(ScalarCosts.begin(), ScalarCosts.end()); + + // Remember that BB will remain after vectorization. + PredicatedBBsAfterVectorization.insert(BB); } } } @@ -6636,7 +6830,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( Instruction *PredInst, DenseMap<Instruction *, unsigned> &ScalarCosts, unsigned VF) { - assert(!Legal->isUniformAfterVectorization(PredInst) && + assert(!isUniformAfterVectorization(PredInst, VF) && "Instruction marked uniform-after-vectorization will be predicated"); // Initialize the discount to zero, meaning that the scalar version and the @@ -6657,7 +6851,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( // already be scalar to avoid traversing chains that are unlikely to be // beneficial. if (!I->hasOneUse() || PredInst->getParent() != I->getParent() || - Legal->isScalarAfterVectorization(I)) + isScalarAfterVectorization(I, VF)) return false; // If the instruction is scalar with predication, it will be analyzed @@ -6677,7 +6871,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( // the lane zero values for uniforms rather than asserting. for (Use &U : I->operands()) if (auto *J = dyn_cast<Instruction>(U.get())) - if (Legal->isUniformAfterVectorization(J)) + if (isUniformAfterVectorization(J, VF)) return false; // Otherwise, we can scalarize the instruction. @@ -6690,7 +6884,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( // and their return values are inserted into vectors. Thus, an extract would // still be required. auto needsExtract = [&](Instruction *I) -> bool { - return TheLoop->contains(I) && !Legal->isScalarAfterVectorization(I); + return TheLoop->contains(I) && !isScalarAfterVectorization(I, VF); }; // Compute the expected cost discount from scalarizing the entire expression @@ -6717,8 +6911,8 @@ int LoopVectorizationCostModel::computePredInstDiscount( // Compute the scalarization overhead of needed insertelement instructions // and phi nodes. if (Legal->isScalarWithPredication(I) && !I->getType()->isVoidTy()) { - ScalarCost += getScalarizationOverhead(ToVectorTy(I->getType(), VF), true, - false, TTI); + ScalarCost += TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF), + true, false); ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI); } @@ -6733,8 +6927,8 @@ int LoopVectorizationCostModel::computePredInstDiscount( if (canBeScalarized(J)) Worklist.push_back(J); else if (needsExtract(J)) - ScalarCost += getScalarizationOverhead(ToVectorTy(J->getType(), VF), - false, true, TTI); + ScalarCost += TTI.getScalarizationOverhead( + ToVectorTy(J->getType(),VF), false, true); } // Scale the total scalar cost by block probability. @@ -6753,6 +6947,9 @@ LoopVectorizationCostModel::VectorizationCostTy LoopVectorizationCostModel::expectedCost(unsigned VF) { VectorizationCostTy Cost; + // Collect Uniform and Scalar instructions after vectorization with VF. + collectUniformsAndScalars(VF); + // Collect the instructions (and their associated costs) that will be more // profitable to scalarize. 
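// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// computePredInstDiscount weighs "keep the chain vector and pay for the
// inserts/extracts at its boundary" against "scalarize the whole chain,
// which then only executes with the predicated block's probability". A
// reduced arithmetic model of that comparison (all inputs hypothetical):
int predDiscountModel(int VectorCost, int ScalarChainCost, int OverheadCost,
                      int ReciprocalProb) {
  // Scalarized work sits inside a predicated block, so it runs roughly once
  // every ReciprocalProb iterations; the boundary overhead always runs.
  int Scalarized = ScalarChainCost / ReciprocalProb + OverheadCost;
  // A non-negative result means scalarizing the expression is the better
  // deal for this instruction chain.
  return VectorCost - Scalarized;
}
// ---- [end sketch] ----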
collectInstsToScalarize(VF); @@ -6832,31 +7029,295 @@ static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) { Legal->hasStride(I->getOperand(1)); } +unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, + unsigned VF) { + Type *ValTy = getMemInstValueType(I); + auto SE = PSE.getSE(); + + unsigned Alignment = getMemInstAlignment(I); + unsigned AS = getMemInstAddressSpace(I); + Value *Ptr = getPointerOperand(I); + Type *PtrTy = ToVectorTy(Ptr->getType(), VF); + + // Figure out whether the access is strided and get the stride value + // if it's known in compile time + const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop); + + // Get the cost of the scalar memory instruction and address computation. + unsigned Cost = VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); + + Cost += VF * + TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment, + AS, I); + + // Get the overhead of the extractelement and insertelement instructions + // we might create due to scalarization. + Cost += getScalarizationOverhead(I, VF, TTI); + + // If we have a predicated store, it may not be executed for each vector + // lane. Scale the cost by the probability of executing the predicated + // block. + if (Legal->isScalarWithPredication(I)) + Cost /= getReciprocalPredBlockProb(); + + return Cost; +} + +unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, + unsigned VF) { + Type *ValTy = getMemInstValueType(I); + Type *VectorTy = ToVectorTy(ValTy, VF); + unsigned Alignment = getMemInstAlignment(I); + Value *Ptr = getPointerOperand(I); + unsigned AS = getMemInstAddressSpace(I); + int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); + + assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && + "Stride should be 1 or -1 for consecutive memory access"); + unsigned Cost = 0; + if (Legal->isMaskRequired(I)) + Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); + else + Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS, I); + + bool Reverse = ConsecutiveStride < 0; + if (Reverse) + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); + return Cost; +} + +unsigned LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I, + unsigned VF) { + LoadInst *LI = cast<LoadInst>(I); + Type *ValTy = LI->getType(); + Type *VectorTy = ToVectorTy(ValTy, VF); + unsigned Alignment = LI->getAlignment(); + unsigned AS = LI->getPointerAddressSpace(); + + return TTI.getAddressComputationCost(ValTy) + + TTI.getMemoryOpCost(Instruction::Load, ValTy, Alignment, AS) + + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VectorTy); +} + +unsigned LoopVectorizationCostModel::getGatherScatterCost(Instruction *I, + unsigned VF) { + Type *ValTy = getMemInstValueType(I); + Type *VectorTy = ToVectorTy(ValTy, VF); + unsigned Alignment = getMemInstAlignment(I); + Value *Ptr = getPointerOperand(I); + + return TTI.getAddressComputationCost(VectorTy) + + TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr, + Legal->isMaskRequired(I), Alignment); +} + +unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, + unsigned VF) { + Type *ValTy = getMemInstValueType(I); + Type *VectorTy = ToVectorTy(ValTy, VF); + unsigned AS = getMemInstAddressSpace(I); + + auto Group = Legal->getInterleavedAccessGroup(I); + assert(Group && "Fail to get an interleaved access group."); + + unsigned InterleaveFactor = Group->getFactor(); + Type *WideVecTy = VectorType::get(ValTy, VF * 
InterleaveFactor); + + // Holds the indices of existing members in an interleaved load group. + // An interleaved store group doesn't need this as it doesn't allow gaps. + SmallVector<unsigned, 4> Indices; + if (isa<LoadInst>(I)) { + for (unsigned i = 0; i < InterleaveFactor; i++) + if (Group->getMember(i)) + Indices.push_back(i); + } + + // Calculate the cost of the whole interleaved group. + unsigned Cost = TTI.getInterleavedMemoryOpCost(I->getOpcode(), WideVecTy, + Group->getFactor(), Indices, + Group->getAlignment(), AS); + + if (Group->isReverse()) + Cost += Group->getNumMembers() * + TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); + return Cost; +} + +unsigned LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I, + unsigned VF) { + + // Calculate scalar cost only. Vectorization cost should be ready at this + // moment. + if (VF == 1) { + Type *ValTy = getMemInstValueType(I); + unsigned Alignment = getMemInstAlignment(I); + unsigned AS = getMemInstAddressSpace(I); + + return TTI.getAddressComputationCost(ValTy) + + TTI.getMemoryOpCost(I->getOpcode(), ValTy, Alignment, AS, I); + } + return getWideningCost(I, VF); +} + LoopVectorizationCostModel::VectorizationCostTy LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { // If we know that this instruction will remain uniform, check the cost of // the scalar version. - if (Legal->isUniformAfterVectorization(I)) + if (isUniformAfterVectorization(I, VF)) VF = 1; if (VF > 1 && isProfitableToScalarize(I, VF)) return VectorizationCostTy(InstsToScalarize[VF][I], false); + // Forced scalars do not have any scalarization overhead. + if (VF > 1 && ForcedScalars.count(VF) && + ForcedScalars.find(VF)->second.count(I)) + return VectorizationCostTy((getInstructionCost(I, 1).first * VF), false); + Type *VectorTy; unsigned C = getInstructionCost(I, VF, VectorTy); bool TypeNotScalarized = - VF > 1 && !VectorTy->isVoidTy() && TTI.getNumberOfParts(VectorTy) < VF; + VF > 1 && VectorTy->isVectorTy() && TTI.getNumberOfParts(VectorTy) < VF; return VectorizationCostTy(C, TypeNotScalarized); } +void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { + if (VF == 1) + return; + for (BasicBlock *BB : TheLoop->blocks()) { + // For each instruction in the old loop. + for (Instruction &I : *BB) { + Value *Ptr = getPointerOperand(&I); + if (!Ptr) + continue; + + if (isa<LoadInst>(&I) && Legal->isUniform(Ptr)) { + // Scalar load + broadcast + unsigned Cost = getUniformMemOpCost(&I, VF); + setWideningDecision(&I, VF, CM_Scalarize, Cost); + continue; + } + + // We assume that widening is the best solution when possible. + if (Legal->memoryInstructionCanBeWidened(&I, VF)) { + unsigned Cost = getConsecutiveMemOpCost(&I, VF); + setWideningDecision(&I, VF, CM_Widen, Cost); + continue; + } + + // Choose between Interleaving, Gather/Scatter or Scalarization. + unsigned InterleaveCost = UINT_MAX; + unsigned NumAccesses = 1; + if (Legal->isAccessInterleaved(&I)) { + auto Group = Legal->getInterleavedAccessGroup(&I); + assert(Group && "Fail to get an interleaved access group."); + + // Make one decision for the whole group. + if (getWideningDecision(&I, VF) != CM_Unknown) + continue; + + NumAccesses = Group->getNumMembers(); + InterleaveCost = getInterleaveGroupCost(&I, VF); + } + + unsigned GatherScatterCost = + Legal->isLegalGatherOrScatter(&I) + ? 
getGatherScatterCost(&I, VF) * NumAccesses + : UINT_MAX; + + unsigned ScalarizationCost = + getMemInstScalarizationCost(&I, VF) * NumAccesses; + + // Choose better solution for the current VF, + // write down this decision and use it during vectorization. + unsigned Cost; + InstWidening Decision; + if (InterleaveCost <= GatherScatterCost && + InterleaveCost < ScalarizationCost) { + Decision = CM_Interleave; + Cost = InterleaveCost; + } else if (GatherScatterCost < ScalarizationCost) { + Decision = CM_GatherScatter; + Cost = GatherScatterCost; + } else { + Decision = CM_Scalarize; + Cost = ScalarizationCost; + } + // If the instructions belongs to an interleave group, the whole group + // receives the same decision. The whole group receives the cost, but + // the cost will actually be assigned to one instruction. + if (auto Group = Legal->getInterleavedAccessGroup(&I)) + setWideningDecision(Group, VF, Decision, Cost); + else + setWideningDecision(&I, VF, Decision, Cost); + } + } + + // Make sure that any load of address and any other address computation + // remains scalar unless there is gather/scatter support. This avoids + // inevitable extracts into address registers, and also has the benefit of + // activating LSR more, since that pass can't optimize vectorized + // addresses. + if (TTI.prefersVectorizedAddressing()) + return; + + // Start with all scalar pointer uses. + SmallPtrSet<Instruction *, 8> AddrDefs; + for (BasicBlock *BB : TheLoop->blocks()) + for (Instruction &I : *BB) { + Instruction *PtrDef = + dyn_cast_or_null<Instruction>(getPointerOperand(&I)); + if (PtrDef && TheLoop->contains(PtrDef) && + getWideningDecision(&I, VF) != CM_GatherScatter) + AddrDefs.insert(PtrDef); + } + + // Add all instructions used to generate the addresses. + SmallVector<Instruction *, 4> Worklist; + for (auto *I : AddrDefs) + Worklist.push_back(I); + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + for (auto &Op : I->operands()) + if (auto *InstOp = dyn_cast<Instruction>(Op)) + if ((InstOp->getParent() == I->getParent()) && !isa<PHINode>(InstOp) && + AddrDefs.insert(InstOp).second == true) + Worklist.push_back(InstOp); + } + + for (auto *I : AddrDefs) { + if (isa<LoadInst>(I)) { + // Setting the desired widening decision should ideally be handled in + // by cost functions, but since this involves the task of finding out + // if the loaded register is involved in an address computation, it is + // instead changed here when we know this is the case. + if (getWideningDecision(I, VF) == CM_Widen) + // Scalarize a widened load of address. + setWideningDecision(I, VF, CM_Scalarize, + (VF * getMemoryInstructionCost(I, 1))); + else if (auto Group = Legal->getInterleavedAccessGroup(I)) { + // Scalarize an interleave group of address loads. + for (unsigned I = 0; I < Group->getFactor(); ++I) { + if (Instruction *Member = Group->getMember(I)) + setWideningDecision(Member, VF, CM_Scalarize, + (VF * getMemoryInstructionCost(Member, 1))); + } + } + } else + // Make sure I gets scalarized and a cost estimate without + // scalarization overhead. + ForcedScalars[VF].insert(I); + } +} + unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF, Type *&VectorTy) { Type *RetTy = I->getType(); if (canTruncateToMinimalBitwidth(I, VF)) RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]); - VectorTy = ToVectorTy(RetTy, VF); + VectorTy = isScalarAfterVectorization(I, VF) ? 
RetTy : ToVectorTy(RetTy, VF); auto SE = PSE.getSE(); // TODO: We need to estimate the cost of intrinsic calls. @@ -6868,7 +7329,31 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, // instruction cost. return 0; case Instruction::Br: { - return TTI.getCFInstrCost(I->getOpcode()); + // In cases of scalarized and predicated instructions, there will be VF + // predicated blocks in the vectorized loop. Each branch around these + // blocks requires also an extract of its vector compare i1 element. + bool ScalarPredicatedBB = false; + BranchInst *BI = cast<BranchInst>(I); + if (VF > 1 && BI->isConditional() && + (PredicatedBBsAfterVectorization.count(BI->getSuccessor(0)) || + PredicatedBBsAfterVectorization.count(BI->getSuccessor(1)))) + ScalarPredicatedBB = true; + + if (ScalarPredicatedBB) { + // Return cost for branches around scalarized and predicated blocks. + Type *Vec_i1Ty = + VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF); + return (TTI.getScalarizationOverhead(Vec_i1Ty, false, true) + + (TTI.getCFInstrCost(Instruction::Br) * VF)); + } else if (I->getParent() == TheLoop->getLoopLatch() || VF == 1) + // The back-edge branch will remain, as will all scalar branches. + return TTI.getCFInstrCost(Instruction::Br); + else + // This branch will be eliminated by if-conversion. + return 0; + // Note: We currently assume zero cost for an unconditional branch inside + // a predicated block since it will become a fall-through, although we + // may decide in the future to call TTI for all branches. } case Instruction::PHI: { auto *Phi = cast<PHINode>(I); @@ -6878,8 +7363,16 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, return TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, VectorTy, VF - 1, VectorTy); - // TODO: IF-converted IFs become selects. - return 0; + // Phi nodes in non-header blocks (not inductions, reductions, etc.) are + // converted into select instructions. We require N - 1 selects per phi + // node, where N is the number of incoming values. + if (VF > 1 && Phi->getParent() != TheLoop->getHeader()) + return (Phi->getNumIncomingValues() - 1) * + TTI.getCmpSelInstrCost( + Instruction::Select, ToVectorTy(Phi->getType(), VF), + ToVectorTy(Type::getInt1Ty(Phi->getContext()), VF)); + + return TTI.getCFInstrCost(Instruction::PHI); } case Instruction::UDiv: case Instruction::SDiv: @@ -6910,6 +7403,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, // likely. return Cost / getReciprocalPredBlockProb(); } + LLVM_FALLTHROUGH; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -6957,9 +7451,10 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, } else if (Legal->isUniform(Op2)) { Op2VK = TargetTransformInfo::OK_UniformValue; } - SmallVector<const Value *, 4> Operands(I->operand_values()); - return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, - Op2VK, Op1VP, Op2VP, Operands); + SmallVector<const Value *, 4> Operands(I->operand_values()); + unsigned N = isScalarAfterVectorization(I, VF) ? 
VF : 1; + return N * TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, + Op2VK, Op1VP, Op2VP, Operands); } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); @@ -6969,7 +7464,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, if (!ScalarCond) CondTy = VectorType::get(CondTy, VF); - return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy); + return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy, I); } case Instruction::ICmp: case Instruction::FCmp: { @@ -6978,130 +7473,20 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, if (canTruncateToMinimalBitwidth(Op0AsInstruction, VF)) ValTy = IntegerType::get(ValTy->getContext(), MinBWs[Op0AsInstruction]); VectorTy = ToVectorTy(ValTy, VF); - return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy); + return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, nullptr, I); } case Instruction::Store: case Instruction::Load: { - StoreInst *SI = dyn_cast<StoreInst>(I); - LoadInst *LI = dyn_cast<LoadInst>(I); - Type *ValTy = (SI ? SI->getValueOperand()->getType() : LI->getType()); - VectorTy = ToVectorTy(ValTy, VF); - - unsigned Alignment = SI ? SI->getAlignment() : LI->getAlignment(); - unsigned AS = - SI ? SI->getPointerAddressSpace() : LI->getPointerAddressSpace(); - Value *Ptr = getPointerOperand(I); - // We add the cost of address computation here instead of with the gep - // instruction because only here we know whether the operation is - // scalarized. - if (VF == 1) - return TTI.getAddressComputationCost(VectorTy) + - TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); - - if (LI && Legal->isUniform(Ptr)) { - // Scalar load + broadcast - unsigned Cost = TTI.getAddressComputationCost(ValTy->getScalarType()); - Cost += TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), - Alignment, AS); - return Cost + - TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, ValTy); - } - - // For an interleaved access, calculate the total cost of the whole - // interleave group. - if (Legal->isAccessInterleaved(I)) { - auto Group = Legal->getInterleavedAccessGroup(I); - assert(Group && "Fail to get an interleaved access group."); - - // Only calculate the cost once at the insert position. - if (Group->getInsertPos() != I) - return 0; - - unsigned InterleaveFactor = Group->getFactor(); - Type *WideVecTy = - VectorType::get(VectorTy->getVectorElementType(), - VectorTy->getVectorNumElements() * InterleaveFactor); - - // Holds the indices of existing members in an interleaved load group. - // An interleaved store group doesn't need this as it doesn't allow gaps. - SmallVector<unsigned, 4> Indices; - if (LI) { - for (unsigned i = 0; i < InterleaveFactor; i++) - if (Group->getMember(i)) - Indices.push_back(i); - } - - // Calculate the cost of the whole interleaved group. - unsigned Cost = TTI.getInterleavedMemoryOpCost( - I->getOpcode(), WideVecTy, Group->getFactor(), Indices, - Group->getAlignment(), AS); - - if (Group->isReverse()) - Cost += - Group->getNumMembers() * - TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); - - // FIXME: The interleaved load group with a huge gap could be even more - // expensive than scalar operations. Then we could ignore such group and - // use scalar operations instead. 
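// ---- [Editor's sketch: illustrative only, not part of the patch] ----
// With setCostBasedWideningDecision computed up front, the new Load/Store
// cost case (continued just below) no longer re-derives how an access will
// be emitted; it trusts the recorded decision and only picks the type width
// from it. The dispatch shape, with hypothetical names:
enum class WideningKind { Unknown, Widen, Interleave, GatherScatter, Scalarize };

// Width used to build the cost-model type for a memory access.
unsigned memCostTypeWidth(unsigned VF, WideningKind Decision) {
  if (VF <= 1)
    return 1;  // scalar loop: plain scalar cost
  // Decisions are computed before costing, so Unknown is a logic error here.
  return Decision == WideningKind::Scalarize ? 1 : VF;
}
// ---- [end sketch] ----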
- return Cost; + unsigned Width = VF; + if (Width > 1) { + InstWidening Decision = getWideningDecision(I, Width); + assert(Decision != CM_Unknown && + "CM decision should be taken at this point"); + if (Decision == CM_Scalarize) + Width = 1; } - - // Check if the memory instruction will be scalarized. - if (Legal->memoryInstructionMustBeScalarized(I, VF)) { - unsigned Cost = 0; - Type *PtrTy = ToVectorTy(Ptr->getType(), VF); - - // Figure out whether the access is strided and get the stride value - // if it's known in compile time - const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop); - - // Get the cost of the scalar memory instruction and address computation. - Cost += VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); - Cost += VF * - TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), - Alignment, AS); - - // Get the overhead of the extractelement and insertelement instructions - // we might create due to scalarization. - Cost += getScalarizationOverhead(I, VF, TTI); - - // If we have a predicated store, it may not be executed for each vector - // lane. Scale the cost by the probability of executing the predicated - // block. - if (Legal->isScalarWithPredication(I)) - Cost /= getReciprocalPredBlockProb(); - - return Cost; - } - - // Determine if the pointer operand of the access is either consecutive or - // reverse consecutive. - int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); - bool Reverse = ConsecutiveStride < 0; - - // Determine if either a gather or scatter operation is legal. - bool UseGatherOrScatter = - !ConsecutiveStride && Legal->isLegalGatherOrScatter(I); - - unsigned Cost = TTI.getAddressComputationCost(VectorTy); - if (UseGatherOrScatter) { - assert(ConsecutiveStride == 0 && - "Gather/Scatter are not used for consecutive stride"); - return Cost + - TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr, - Legal->isMaskRequired(I), Alignment); - } - // Wide load/stores. - if (Legal->isMaskRequired(I)) - Cost += - TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); - else - Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); - - if (Reverse) - Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); - return Cost; + VectorTy = ToVectorTy(getMemInstValueType(I), Width); + return getMemoryInstructionCost(I, VF); } case Instruction::ZExt: case Instruction::SExt: @@ -7115,15 +7500,18 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - // We optimize the truncation of induction variable. - // The cost of these is the same as the scalar operation. - if (I->getOpcode() == Instruction::Trunc && - Legal->isInductionVariable(I->getOperand(0))) - return TTI.getCastInstrCost(I->getOpcode(), I->getType(), - I->getOperand(0)->getType()); + // We optimize the truncation of induction variables having constant + // integer steps. The cost of these truncations is the same as the scalar + // operation. + if (isOptimizableIVTruncate(I, VF)) { + auto *Trunc = cast<TruncInst>(I); + return TTI.getCastInstrCost(Instruction::Trunc, Trunc->getDestTy(), + Trunc->getSrcTy(), Trunc); + } Type *SrcScalarTy = I->getOperand(0)->getType(); - Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF); + Type *SrcVecTy = + VectorTy->isVectorTy() ? ToVectorTy(SrcScalarTy, VF) : SrcScalarTy; if (canTruncateToMinimalBitwidth(I, VF)) { // This cast is going to be shrunk. 
This may remove the cast or it might // turn it into slightly different cast. For example, if MinBW == 16, @@ -7143,7 +7531,8 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, } } - return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy); + unsigned N = isScalarAfterVectorization(I, VF) ? VF : 1; + return N * TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy, I); } case Instruction::Call: { bool NeedToScalarize; @@ -7172,9 +7561,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) @@ -7206,81 +7593,109 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts(); VecValuesToIgnore.insert(Casts.begin(), Casts.end()); } - - // Insert values known to be scalar into VecValuesToIgnore. This is a - // conservative estimation of the values that will later be scalarized. - // - // FIXME: Even though an instruction is not scalar-after-vectoriztion, it may - // still be scalarized. For example, we may find an instruction to be - // more profitable for a given vectorization factor if it were to be - // scalarized. But at this point, we haven't yet computed the - // vectorization factor. - for (auto *BB : TheLoop->getBlocks()) - for (auto &I : *BB) - if (Legal->isScalarAfterVectorization(&I)) - VecValuesToIgnore.insert(&I); } -void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, - bool IfPredicateInstr) { - assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); - // Holds vector parameters or scalars, in case of uniform vals. - SmallVector<VectorParts, 4> Params; +LoopVectorizationCostModel::VectorizationFactor +LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { - setDebugLocFromInst(Builder, Instr); + // Width 1 means no vectorize, cost 0 means uncomputed cost. + const LoopVectorizationCostModel::VectorizationFactor NoVectorization = {1U, + 0U}; + Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(OptForSize); + if (!MaybeMaxVF.hasValue()) // Cases considered too costly to vectorize. + return NoVectorization; - // Does this instruction return a value ? - bool IsVoidRetTy = Instr->getType()->isVoidTy(); + if (UserVF) { + DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); + assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); + // Collect the instructions (and their associated costs) that will be more + // profitable to scalarize. + CM.selectUserVectorizationFactor(UserVF); + return {UserVF, 0}; + } - // Initialize a new scalar map entry. - ScalarParts Entry(UF); + unsigned MaxVF = MaybeMaxVF.getValue(); + assert(MaxVF != 0 && "MaxVF is zero."); + if (MaxVF == 1) + return NoVectorization; - VectorParts Cond; - if (IfPredicateInstr) - Cond = createBlockInMask(Instr->getParent()); + // Select the optimal vectorization factor. 
+ return CM.selectVectorizationFactor(MaxVF); +} - // For each vector unroll 'part': - for (unsigned Part = 0; Part < UF; ++Part) { - Entry[Part].resize(1); - // For each scalar that we create: +void LoopVectorizationPlanner::executePlan(InnerLoopVectorizer &ILV) { + // Perform the actual loop transformation. - // Start an "if (pred) a[i] = ..." block. - Value *Cmp = nullptr; - if (IfPredicateInstr) { - if (Cond[Part]->getType()->isVectorTy()) - Cond[Part] = - Builder.CreateExtractElement(Cond[Part], Builder.getInt32(0)); - Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cond[Part], - ConstantInt::get(Cond[Part]->getType(), 1)); - } + // 1. Create a new empty loop. Unlink the old loop and connect the new one. + ILV.createVectorizedLoopSkeleton(); - Instruction *Cloned = Instr->clone(); - if (!IsVoidRetTy) - Cloned->setName(Instr->getName() + ".cloned"); + //===------------------------------------------------===// + // + // Notice: any optimization or new instruction that goes + // into the code below should also be implemented in + // the cost-model. + // + //===------------------------------------------------===// - // Replace the operands of the cloned instructions with their scalar - // equivalents in the new loop. - for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - auto *NewOp = getScalarValue(Instr->getOperand(op), Part, 0); - Cloned->setOperand(op, NewOp); - } + // 2. Copy and widen instructions from the old loop into the new loop. + + // Move instructions to handle first-order recurrences. + DenseMap<Instruction *, Instruction *> SinkAfter = Legal->getSinkAfter(); + for (auto &Entry : SinkAfter) { + Entry.first->removeFromParent(); + Entry.first->insertAfter(Entry.second); + DEBUG(dbgs() << "Sinking" << *Entry.first << " after" << *Entry.second + << " to vectorize a 1st order recurrence.\n"); + } - // Place the cloned scalar in the new loop. - Builder.Insert(Cloned); + // Collect instructions from the original loop that will become trivially dead + // in the vectorized loop. We don't need to vectorize these instructions. For + // example, original induction update instructions can become dead because we + // separately emit induction "steps" when generating code for the new loop. + // Similarly, we create a new latch condition when setting up the structure + // of the new loop, so the old one can become dead. + SmallPtrSet<Instruction *, 4> DeadInstructions; + collectTriviallyDeadInstructions(DeadInstructions); - // Add the cloned scalar to the scalar map entry. - Entry[Part][0] = Cloned; + // Scan the loop in a topological order to ensure that defs are vectorized + // before users. + LoopBlocksDFS DFS(OrigLoop); + DFS.perform(LI); - // If we just cloned a new assumption, add it the assumption cache. - if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) - if (II->getIntrinsicID() == Intrinsic::assume) - AC->registerAssumption(II); + // Vectorize all instructions in the original loop that will not become + // trivially dead when vectorized. + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) + for (Instruction &I : *BB) + if (!DeadInstructions.count(&I)) + ILV.vectorizeInstruction(I); + + // 3. Fix the vectorized code: take care of header phi's, live-outs, + // predication, updating analyses. + ILV.fixVectorizedLoop(); +} - // End if-block.
- if (IfPredicateInstr) - PredicatedInstructions.push_back(std::make_pair(Cloned, Cmp)); +void LoopVectorizationPlanner::collectTriviallyDeadInstructions( + SmallPtrSetImpl<Instruction *> &DeadInstructions) { + BasicBlock *Latch = OrigLoop->getLoopLatch(); + + // We create new control-flow for the vectorized loop, so the original + // condition will be dead after vectorization if it's only used by the + // branch. + auto *Cmp = dyn_cast<Instruction>(Latch->getTerminator()->getOperand(0)); + if (Cmp && Cmp->hasOneUse()) + DeadInstructions.insert(Cmp); + + // We create new "steps" for induction variable updates to which the original + // induction variables map. An original update instruction will be dead if + // all its users except the induction variable are dead. + for (auto &Induction : *Legal->getInductionVars()) { + PHINode *Ind = Induction.first; + auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); + if (all_of(IndUpdate->users(), [&](User *U) -> bool { + return U == Ind || DeadInstructions.count(cast<Instruction>(U)); + })) + DeadInstructions.insert(IndUpdate); } - VectorLoopValueMap.initScalar(Instr, Entry); } void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) { @@ -7384,24 +7799,6 @@ bool LoopVectorizePass::processLoop(Loop *L) { return false; } - // Check the loop for a trip count threshold: - // do not vectorize loops with a tiny trip count. - const unsigned MaxTC = SE->getSmallConstantMaxTripCount(L); - if (MaxTC > 0u && MaxTC < TinyTripCountVectorThreshold) { - DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " - << "This loop is not worth vectorizing."); - if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) - DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); - else { - DEBUG(dbgs() << "\n"); - ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), - "NotBeneficial", L) - << "vectorization is not beneficial " - "and is not explicitly forced"); - return false; - } - } - PredicatedScalarEvolution PSE(*SE, *L); // Check if it is legal to vectorize the loop. @@ -7414,26 +7811,37 @@ bool LoopVectorizePass::processLoop(Loop *L) { return false; } - // Use the cost model. - LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, - &Hints); - CM.collectValuesToIgnore(); - // Check the function attributes to find out if this function should be // optimized for size. bool OptForSize = Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); - // Compute the weighted frequency of this loop being executed and see if it - // is less than 20% of the function entry baseline frequency. Note that we - // always have a canonical loop here because we think we *can* vectorize. - // FIXME: This is hidden behind a flag due to pervasive problems with - // exactly what block frequency models. - if (LoopVectorizeWithBlockFrequency) { - BlockFrequency LoopEntryFreq = BFI->getBlockFreq(L->getLoopPreheader()); - if (Hints.getForce() != LoopVectorizeHints::FK_Enabled && - LoopEntryFreq < ColdEntryFreq) + // Check the loop for a trip count threshold: vectorize loops with a tiny trip + // count by optimizing for size, to minimize overheads. 
+ unsigned ExpectedTC = SE->getSmallConstantMaxTripCount(L); + bool HasExpectedTC = (ExpectedTC > 0); + + if (!HasExpectedTC && LoopVectorizeWithBlockFrequency) { + auto EstimatedTC = getLoopEstimatedTripCount(L); + if (EstimatedTC) { + ExpectedTC = *EstimatedTC; + HasExpectedTC = true; + } + } + + if (HasExpectedTC && ExpectedTC < TinyTripCountVectorThreshold) { + DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " + << "This loop is worth vectorizing only if no scalar " + << "iteration overheads are incurred."); + if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) + DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); + else { + DEBUG(dbgs() << "\n"); + // Loops with a very small trip count are considered for vectorization + // under OptForSize, thereby making sure the cost of their loop body is + // dominant, free of runtime guards and scalar iteration overheads. OptForSize = true; + } } // Check the function attributes to see if implicit floats are allowed. @@ -7464,9 +7872,20 @@ bool LoopVectorizePass::processLoop(Loop *L) { return false; } - // Select the optimal vectorization factor. - const LoopVectorizationCostModel::VectorizationFactor VF = - CM.selectVectorizationFactor(OptForSize); + // Use the cost model. + LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, + &Hints); + CM.collectValuesToIgnore(); + + // Use the planner for vectorization. + LoopVectorizationPlanner LVP(L, LI, &LVL, CM); + + // Get user vectorization factor. + unsigned UserVF = Hints.getWidth(); + + // Plan how best to vectorize, returning the best VF and its cost. + LoopVectorizationCostModel::VectorizationFactor VF = + LVP.plan(OptForSize, UserVF); // Select the interleave count. unsigned IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); @@ -7522,10 +7941,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { const char *VAPassName = Hints.vectorizeAnalysisPassName(); if (!VectorizeLoop && !InterleaveLoop) { // Do not vectorize or interleave the loop. - ORE->emit(OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, + ORE->emit(OptimizationRemarkMissed(VAPassName, VecDiagMsg.first, L->getStartLoc(), L->getHeader()) << VecDiagMsg.second); - ORE->emit(OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, + ORE->emit(OptimizationRemarkMissed(LV_NAME, IntDiagMsg.first, L->getStartLoc(), L->getHeader()) << IntDiagMsg.second); return false; @@ -7553,7 +7972,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // interleave it. InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, ORE, IC, &LVL, &CM); - Unroller.vectorize(); + LVP.executePlan(Unroller); ORE->emit(OptimizationRemark(LV_NAME, "Interleaved", L->getStartLoc(), L->getHeader()) @@ -7563,7 +7982,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // If we decided that it is *legal* to vectorize the loop, then do it. InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, IC, &LVL, &CM); - LB.vectorize(); + LVP.executePlan(LB); ++LoopsVectorized; // Add metadata to disable runtime unrolling a scalar loop when there are @@ -7606,11 +8025,6 @@ bool LoopVectorizePass::runImpl( DB = &DB_; ORE = &ORE_; - // Compute some weights outside of the loop over the loops. Compute this - // using a BranchProbability to re-use its scaling math. - const BranchProbability ColdProb(1, 5); // 20% - ColdEntryFreq = BlockFrequency(BFI->getEntryFreq()) * ColdProb; - // Don't attempt if // 1. the target claims to have no vector registers, and // 2. interleaving won't help ILP.
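The trip-count hunks above change policy: a loop with a tiny known or profile-estimated trip count is no longer rejected outright; instead it is vectorized under OptForSize so that runtime guards and scalar-iteration overheads are not generated. The following self-contained sketch summarizes that decision flow; expectedTripCount, shouldVectorizeForSize, and the threshold constant are illustrative stand-ins, not LLVM APIs.

#include <optional>

// Illustrative threshold; LoopVectorize uses its own cl::opt-controlled value.
constexpr unsigned TinyTripCountThreshold = 16;

// ConstMaxTC models SE->getSmallConstantMaxTripCount (0 when unknown);
// ProfileTC models an estimate recovered from branch-weight metadata.
std::optional<unsigned> expectedTripCount(unsigned ConstMaxTC,
                                          std::optional<unsigned> ProfileTC) {
  if (ConstMaxTC > 0)
    return ConstMaxTC; // prefer the compile-time bound when it exists
  return ProfileTC;    // otherwise fall back to profile data, if any
}

// True when the loop should be vectorized in size-optimizing mode: the trip
// count is known to be tiny and the user did not force vectorization.
bool shouldVectorizeForSize(std::optional<unsigned> TC, bool ForceEnabled) {
  return TC.has_value() && *TC < TinyTripCountThreshold && !ForceEnabled;
}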
@@ -7621,6 +8035,16 @@ bool LoopVectorizePass::runImpl( if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2) return false; + bool Changed = false; + + // The vectorizer requires loops to be in simplified form. + // Since simplification may add new inner loops, it has to run before the + // legality and profitability checks. This means running the loop vectorizer + // will simplify all loops, regardless of whether anything ends up being + // vectorized. + for (auto &L : *LI) + Changed |= simplifyLoop(L, DT, LI, SE, AC, false /* PreserveLCSSA */); + // Build up a worklist of inner-loops to vectorize. This is necessary as // the act of vectorizing or partially unrolling a loop creates new loops // and can invalidate iterators across the loops. @@ -7632,9 +8056,15 @@ bool LoopVectorizePass::runImpl( LoopsAnalyzed += Worklist.size(); // Now walk the identified inner loops. - bool Changed = false; - while (!Worklist.empty()) - Changed |= processLoop(Worklist.pop_back_val()); + while (!Worklist.empty()) { + Loop *L = Worklist.pop_back_val(); + + // For the inner loops we actually process, form LCSSA to simplify the + // transform. + Changed |= formLCSSARecursively(*L, *DT, LI, SE); + + Changed |= processLoop(L); + } // Process each loop nest in the function. return Changed; diff --git a/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 328f270..dcbcab4 100644 --- a/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -39,7 +39,10 @@ #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Vectorize.h" #include <algorithm> #include <memory> @@ -90,6 +93,10 @@ static cl::opt<unsigned> MinTreeSize( "slp-min-tree-size", cl::init(3), cl::Hidden, cl::desc("Only vectorize small trees if they are fully vectorizable")); +static cl::opt<bool> + ViewSLPTree("view-slp-tree", cl::Hidden, + cl::desc("Display the SLP trees with Graphviz")); + // Limit the number of alias checks. The limit is chosen so that // it has no negative effect on the llvm benchmarks. static const unsigned AliasedCheckLimit = 10; @@ -166,6 +173,11 @@ static unsigned getAltOpcode(unsigned Op) { } } +/// \returns true if \p Value is odd, false otherwise. +static bool isOdd(unsigned Value) { + return Value & 1; +} + ///\returns bool representing if Opcode \p Op can be part /// of an alternate sequence which can later be merged as /// a ShuffleVector instruction. @@ -183,7 +195,7 @@ static unsigned isAltInst(ArrayRef<Value *> VL) { unsigned AltOpcode = getAltOpcode(Opcode); for (int i = 1, e = VL.size(); i < e; i++) { Instruction *I = dyn_cast<Instruction>(VL[i]); - if (!I || I->getOpcode() != ((i & 1) ? AltOpcode : Opcode)) + if (!I || I->getOpcode() != (isOdd(i) ? AltOpcode : Opcode)) return 0; } return Instruction::ShuffleVector; @@ -207,23 +219,6 @@ static unsigned getSameOpcode(ArrayRef<Value *> VL) { return Opcode; } -/// Get the intersection (logical and) of all of the potential IR flags -/// of each scalar operation (VL) that will be converted into a vector (I). -/// Flag set: NSW, NUW, exact, and all of fast-math.
-static void propagateIRFlags(Value *I, ArrayRef<Value *> VL) { - if (auto *VecOp = dyn_cast<Instruction>(I)) { - if (auto *Intersection = dyn_cast<Instruction>(VL[0])) { - // Intersection is initialized to the 0th scalar, - // so start counting from index '1'. - for (int i = 1, e = VL.size(); i < e; ++i) { - if (auto *Scalar = dyn_cast<Instruction>(VL[i])) - Intersection->andIRFlags(Scalar); - } - VecOp->copyIRFlags(Intersection); - } - } -} - /// \returns true if all of the values in \p VL have the same type or false /// otherwise. static bool allSameType(ArrayRef<Value *> VL) { @@ -269,6 +264,7 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, if (hasVectorInstrinsicScalarOpd(ID, 1)) { return (CI->getArgOperand(1) == Scalar); } + LLVM_FALLTHROUGH; } default: return false; @@ -304,14 +300,16 @@ public: typedef SmallVector<Instruction *, 16> InstrList; typedef SmallPtrSet<Value *, 16> ValueSet; typedef SmallVector<StoreInst *, 8> StoreList; + typedef MapVector<Value *, SmallVector<Instruction *, 2>> + ExtraValueToDebugLocsMap; BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB, - const DataLayout *DL) + const DataLayout *DL, OptimizationRemarkEmitter *ORE) : NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func), SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), DB(DB), - DL(DL), Builder(Se->getContext()) { + DL(DL), ORE(ORE), Builder(Se->getContext()) { CodeMetrics::collectEphemeralValues(F, AC, EphValues); // Use the vector register size specified by the target unless overridden // by a command-line option. @@ -324,12 +322,19 @@ public: else MaxVecRegSize = TTI->getRegisterBitWidth(true); - MinVecRegSize = MinVectorRegSizeOption; + if (MinVectorRegSizeOption.getNumOccurrences()) + MinVecRegSize = MinVectorRegSizeOption; + else + MinVecRegSize = TTI->getMinVectorRegisterBitWidth(); } /// \brief Vectorize the tree that starts with the elements in \p VL. /// Returns the vectorized root. Value *vectorizeTree(); + /// Vectorize the tree but with the list of externally used values \p + /// ExternallyUsedValues. Values in this MapVector can be replaced by the + /// generated extractelement instructions. + Value *vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues); /// \returns the cost incurred by unwanted spills and fills, caused by /// holding live values over call sites. @@ -343,6 +348,13 @@ public: /// the purpose of scheduling and extraction in the \p UserIgnoreLst. void buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst = None); + /// Construct a vectorizable tree that starts at \p Roots, ignoring users for + /// the purpose of scheduling and extraction in the \p UserIgnoreLst, taking + /// into account (and updating it, if required) the list of externally used + /// values stored in \p ExternallyUsedValues. + void buildTree(ArrayRef<Value *> Roots, + ExtraValueToDebugLocsMap &ExternallyUsedValues, + ArrayRef<Value *> UserIgnoreLst = None); /// Clear the internal data structures that are created by 'buildTree'. void deleteTree() { @@ -359,6 +371,8 @@ public: MinBWs.clear(); } + unsigned getTreeSize() const { return VectorizableTree.size(); } + /// \brief Perform LICM and CSE on the newly generated gather sequences. void optimizeGatherSequence(); @@ -397,6 +411,8 @@ public: /// vectorizable. We do not vectorize such trees.
bool isTreeTinyAndNotFullyVectorizable(); + OptimizationRemarkEmitter *getORE() { return ORE; } + private: struct TreeEntry; @@ -404,7 +420,7 @@ private: int getEntryCost(TreeEntry *E); /// This is the recursive part of buildTree. - void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth); + void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, int); /// \returns True if the ExtractElement/ExtractValue instructions in VL can /// be vectorized to use the original vector (or aggregate "bitcast" to a vector). @@ -418,7 +434,7 @@ private: /// \returns the pointer to the vectorized value if \p VL is already /// vectorized, or NULL. They may happen in cycles. - Value *alreadyVectorized(ArrayRef<Value *> VL) const; + Value *alreadyVectorized(ArrayRef<Value *> VL, Value *OpValue) const; /// \returns the scalarization cost for this type. Scalarization in this /// context means the creation of vectors from a group of scalars. @@ -451,8 +467,9 @@ private: SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right); struct TreeEntry { - TreeEntry() : Scalars(), VectorizedValue(nullptr), - NeedToGather(0) {} + TreeEntry(std::vector<TreeEntry> &Container) + : Scalars(), VectorizedValue(nullptr), NeedToGather(0), + Container(Container) {} /// \returns true if the scalars in VL are equal to this entry. bool isSame(ArrayRef<Value *> VL) const { @@ -468,23 +485,40 @@ private: /// Do we need to gather this sequence ? bool NeedToGather; + + /// Points back to the VectorizableTree. + /// + /// Only used for Graphviz right now. Unfortunately GraphTrait::NodeRef has + /// to be a pointer and needs to be able to initialize the child iterator. + /// Thus we need a reference back to the container to translate the indices + /// to entries. + std::vector<TreeEntry> &Container; + + /// The TreeEntry index containing the user of this entry. We can actually + /// have multiple users so the data structure is not truly a tree. + SmallVector<int, 1> UserTreeIndices; }; /// Create a new VectorizableTree entry. - TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized) { - VectorizableTree.emplace_back(); + TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, + int &UserTreeIdx) { + VectorizableTree.emplace_back(VectorizableTree); int idx = VectorizableTree.size() - 1; TreeEntry *Last = &VectorizableTree[idx]; Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); Last->NeedToGather = !Vectorized; if (Vectorized) { for (int i = 0, e = VL.size(); i != e; ++i) { - assert(!ScalarToTreeEntry.count(VL[i]) && "Scalar already in tree!"); + assert(!getTreeEntry(VL[i]) && "Scalar already in tree!"); ScalarToTreeEntry[VL[i]] = idx; } } else { MustGather.insert(VL.begin(), VL.end()); } + + if (UserTreeIdx >= 0) + Last->UserTreeIndices.push_back(UserTreeIdx); + UserTreeIdx = idx; return Last; } @@ -492,6 +526,20 @@ private: /// Holds all of the tree entries. std::vector<TreeEntry> VectorizableTree; + TreeEntry *getTreeEntry(Value *V) { + auto I = ScalarToTreeEntry.find(V); + if (I != ScalarToTreeEntry.end()) + return &VectorizableTree[I->second]; + return nullptr; + } + + const TreeEntry *getTreeEntry(Value *V) const { + auto I = ScalarToTreeEntry.find(V); + if (I != ScalarToTreeEntry.end()) + return &VectorizableTree[I->second]; + return nullptr; + } + /// Maps a specific scalar to its tree entry. 
SmallDenseMap<Value*, int> ScalarToTreeEntry; @@ -550,15 +598,17 @@ private: void eraseInstruction(Instruction *I) { I->removeFromParent(); I->dropAllReferences(); - DeletedInstructions.push_back(std::unique_ptr<Instruction>(I)); + DeletedInstructions.emplace_back(I); } /// Temporary store for deleted instructions. Instructions will be deleted /// eventually when the BoUpSLP is destructed. - SmallVector<std::unique_ptr<Instruction>, 8> DeletedInstructions; + SmallVector<unique_value, 8> DeletedInstructions; /// A list of values that need to be extracted out of the tree. - /// This list holds pairs of (Internal Scalar : External User). + /// This list holds pairs of (Internal Scalar : External User). External User + /// can be nullptr, which means that this Internal Scalar will be used later, + /// after vectorization. UserList ExternalUses; /// Values used only by @llvm.assume calls. @@ -706,6 +756,8 @@ private: return os; } #endif + friend struct GraphTraits<BoUpSLP *>; + friend struct DOTGraphTraits<BoUpSLP *>; /// Contains all scheduling data for a basic block. /// @@ -805,10 +857,10 @@ private: /// Checks if a bundle of instructions can be scheduled, i.e. has no /// cyclic dependencies. This is only a dry run; no instructions are /// actually moved at this stage. - bool tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP); + bool tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, Value *OpValue); /// Un-bundles a group of instructions. - void cancelScheduling(ArrayRef<Value *> VL); + void cancelScheduling(ArrayRef<Value *> VL, Value *OpValue); /// Extends the scheduling region so that V is inside the region. /// \returns true if the region size is within the limit. @@ -904,6 +956,8 @@ private: AssumptionCache *AC; DemandedBits *DB; const DataLayout *DL; + OptimizationRemarkEmitter *ORE; + unsigned MaxVecRegSize; // This is set by TTI or overridden by cl::opt. unsigned MinVecRegSize; // Set by cl::opt (default: 128). /// Instruction builder to construct the vectorized tree. @@ -916,30 +970,119 @@ private: /// original width. MapVector<Value *, std::pair<uint64_t, bool>> MinBWs; }; +} // end namespace slpvectorizer + +template <> struct GraphTraits<BoUpSLP *> { + typedef BoUpSLP::TreeEntry TreeEntry; + + /// NodeRef has to be a pointer per the GraphWriter. + typedef TreeEntry *NodeRef; + + /// \brief Add the VectorizableTree to the index iterator to be able to return + /// TreeEntry pointers. + struct ChildIteratorType + : public iterator_adaptor_base<ChildIteratorType, + SmallVector<int, 1>::iterator> { + + std::vector<TreeEntry> &VectorizableTree; + + ChildIteratorType(SmallVector<int, 1>::iterator W, + std::vector<TreeEntry> &VT) + : ChildIteratorType::iterator_adaptor_base(W), VectorizableTree(VT) {} + + NodeRef operator*() { return &VectorizableTree[*I]; } + }; + + static NodeRef getEntryNode(BoUpSLP &R) { return &R.VectorizableTree[0]; } + + static ChildIteratorType child_begin(NodeRef N) { + return {N->UserTreeIndices.begin(), N->Container}; + } + static ChildIteratorType child_end(NodeRef N) { + return {N->UserTreeIndices.end(), N->Container}; + } + + /// For the node iterator we just need to turn the TreeEntry iterator into a + /// TreeEntry* iterator so that it dereferences to NodeRef.
+ typedef pointer_iterator<std::vector<TreeEntry>::iterator> nodes_iterator; + + static nodes_iterator nodes_begin(BoUpSLP *R) { + return nodes_iterator(R->VectorizableTree.begin()); + } + static nodes_iterator nodes_end(BoUpSLP *R) { + return nodes_iterator(R->VectorizableTree.end()); + } + + static unsigned size(BoUpSLP *R) { return R->VectorizableTree.size(); } +}; + +template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { + typedef BoUpSLP::TreeEntry TreeEntry; + + DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} + + std::string getNodeLabel(const TreeEntry *Entry, const BoUpSLP *R) { + std::string Str; + raw_string_ostream OS(Str); + if (isSplat(Entry->Scalars)) { + OS << "<splat> " << *Entry->Scalars[0]; + return Str; + } + for (auto V : Entry->Scalars) { + OS << *V; + if (std::any_of( + R->ExternalUses.begin(), R->ExternalUses.end(), + [&](const BoUpSLP::ExternalUser &EU) { return EU.Scalar == V; })) + OS << " <extract>"; + OS << "\n"; + } + return Str; + } + + static std::string getNodeAttributes(const TreeEntry *Entry, + const BoUpSLP *) { + if (Entry->NeedToGather) + return "color=red"; + return ""; + } +}; } // end namespace llvm -} // end namespace slpvectorizer void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst) { + ExtraValueToDebugLocsMap ExternallyUsedValues; + buildTree(Roots, ExternallyUsedValues, UserIgnoreLst); +} +void BoUpSLP::buildTree(ArrayRef<Value *> Roots, + ExtraValueToDebugLocsMap &ExternallyUsedValues, + ArrayRef<Value *> UserIgnoreLst) { deleteTree(); UserIgnoreList = UserIgnoreLst; if (!allSameType(Roots)) return; - buildTree_rec(Roots, 0); + buildTree_rec(Roots, 0, -1); // Collect the values that we need to extract from the tree. for (TreeEntry &EIdx : VectorizableTree) { TreeEntry *Entry = &EIdx; + // No need to handle users of gathered values. + if (Entry->NeedToGather) + continue; + // For each lane: for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { Value *Scalar = Entry->Scalars[Lane]; - // No need to handle users of gathered values. - if (Entry->NeedToGather) + // Check if the scalar is externally used as an extra arg. + auto ExtI = ExternallyUsedValues.find(Scalar); + if (ExtI != ExternallyUsedValues.end()) { + DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane " << + Lane << " from " << *Scalar << ".\n"); + ExternalUses.emplace_back(Scalar, nullptr, Lane); continue; - + } for (User *U : Scalar->users()) { DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n"); @@ -948,9 +1091,7 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, continue; // Skip in-tree scalars that become vectors - if (ScalarToTreeEntry.count(U)) { - int Idx = ScalarToTreeEntry[U]; - TreeEntry *UseEntry = &VectorizableTree[Idx]; + if (TreeEntry *UseEntry = getTreeEntry(U)) { Value *UseScalar = UseEntry->Scalars[0]; // Some in-tree scalars will remain as scalar in vectorized // instructions. 
If that is the case, the one in Lane 0 will @@ -959,7 +1100,7 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, !InTreeUserNeedToExtract(Scalar, UserInst, TLI)) { DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U << ".\n"); - assert(!VectorizableTree[Idx].NeedToGather && "Bad state"); + assert(!UseEntry->NeedToGather && "Bad state"); continue; } } @@ -976,28 +1117,28 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, } } - -void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { +void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, + int UserTreeIdx) { bool isAltShuffle = false; assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); if (Depth == RecursionMaxDepth) { DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } // Don't handle vectors. if (VL[0]->getType()->isVectorTy()) { DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) if (SI->getValueOperand()->getType()->isVectorTy()) { DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } unsigned Opcode = getSameOpcode(VL); @@ -1014,7 +1155,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // If all of the operands are identical or constant we have a simple solution. if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !Opcode) { DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } @@ -1026,23 +1167,24 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (EphValues.count(VL[i])) { DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] << ") is ephemeral.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } } // Check if this is a duplicate of another entry. - if (ScalarToTreeEntry.count(VL[0])) { - int Idx = ScalarToTreeEntry[VL[0]]; - TreeEntry *E = &VectorizableTree[Idx]; + if (TreeEntry *E = getTreeEntry(VL[0])) { for (unsigned i = 0, e = VL.size(); i != e; ++i) { DEBUG(dbgs() << "SLP: \tChecking bundle: " << *VL[i] << ".\n"); if (E->Scalars[i] != VL[i]) { DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } } + // Record the reuse of the tree node. FIXME, currently this is only used to + // properly draw the graph rather than for the actual vectorization. + E->UserTreeIndices.push_back(UserTreeIdx); DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *VL[0] << ".\n"); return; } @@ -1052,7 +1194,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (ScalarToTreeEntry.count(VL[i])) { DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] << ") is already in tree.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } } @@ -1062,7 +1204,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = VL.size(); i != e; ++i) { if (MustGather.count(VL[i])) { DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } } @@ -1070,13 +1212,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // Check that all of the users of the scalars that we want to vectorize are // schedulable. 
Instruction *VL0 = cast<Instruction>(VL[0]); - BasicBlock *BB = cast<Instruction>(VL0)->getParent(); + BasicBlock *BB = VL0->getParent(); if (!DT->isReachableFromEntry(BB)) { // Don't go into unreachable blocks. They may contain instructions with // dependency cycles which confuse the final scheduling. DEBUG(dbgs() << "SLP: bundle in unreachable block.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } @@ -1085,7 +1227,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned j = i+1; j < e; ++j) if (VL[i] == VL[j]) { DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } @@ -1095,12 +1237,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { } BlockScheduling &BS = *BSRef.get(); - if (!BS.tryScheduleBundle(VL, this)) { + if (!BS.tryScheduleBundle(VL, this, VL0)) { DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n"); assert((!BS.getScheduleData(VL[0]) || !BS.getScheduleData(VL[0])->isPartOfBundle()) && "tryScheduleBundle should cancelScheduling on failure"); - newTreeEntry(VL, false); + newTreeEntry(VL, false, UserTreeIdx); return; } DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); @@ -1116,13 +1258,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { cast<PHINode>(VL[j])->getIncomingValueForBlock(PH->getIncomingBlock(i))); if (Term) { DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n"); - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); return; } } - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { @@ -1132,7 +1274,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock( PH->getIncomingBlock(i))); - buildTree_rec(Operands, Depth + 1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); } return; } @@ -1142,9 +1284,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (Reuse) { DEBUG(dbgs() << "SLP: Reusing extract sequence.\n"); } else { - BS.cancelScheduling(VL); + BS.cancelScheduling(VL, VL0); } - newTreeEntry(VL, Reuse); + newTreeEntry(VL, Reuse, UserTreeIdx); return; } case Instruction::Load: { @@ -1159,8 +1301,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy)) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); return; } @@ -1170,8 +1312,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) { LoadInst *L = cast<LoadInst>(VL[i]); if (!L->isSimple()) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); return; } @@ -1193,7 +1335,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (Consecutive) { ++NumLoadsWantToKeepOrder; - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a vector of loads.\n"); return; } @@ -1207,8 +1349,8 @@ void 
BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { break; } - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); if (ReverseConsecutive) { ++NumLoadsWantToChangeOrder; @@ -1234,13 +1376,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0; i < VL.size(); ++i) { Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType(); if (Ty != SrcTy || !isValidElementType(Ty)) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n"); return; } } - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a vector of casts.\n"); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { @@ -1249,7 +1391,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth+1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); } return; } @@ -1262,14 +1404,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { CmpInst *Cmp = cast<CmpInst>(VL[i]); if (Cmp->getPredicate() != P0 || Cmp->getOperand(0)->getType() != ComparedTy) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n"); return; } } - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a vector of compares.\n"); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { @@ -1278,7 +1420,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth+1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); } return; } @@ -1301,7 +1443,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { case Instruction::And: case Instruction::Or: case Instruction::Xor: { - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a vector of bin op.\n"); // Sort operands of the instructions so that each side is more likely to @@ -1309,8 +1451,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) { ValueList Left, Right; reorderInputsAccordingToOpcode(VL, Left, Right); - buildTree_rec(Left, Depth + 1); - buildTree_rec(Right, Depth + 1); + buildTree_rec(Left, Depth + 1, UserTreeIdx); + buildTree_rec(Right, Depth + 1, UserTreeIdx); return; } @@ -1320,7 +1462,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth+1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); } return; } @@ -1329,8 +1471,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned j = 0; j < VL.size(); ++j) { if (cast<Instruction>(VL[j])->getNumOperands() != 2) { DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); return; } } @@ -1342,8 +1484,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { Type *CurTy = 
cast<Instruction>(VL[j])->getOperand(0)->getType(); if (Ty0 != CurTy) { DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n"); - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); return; } } @@ -1354,13 +1496,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (!isa<ConstantInt>(Op)) { DEBUG( dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n"); - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); return; } } - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); for (unsigned i = 0, e = 2; i < e; ++i) { ValueList Operands; @@ -1368,7 +1510,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); } return; } @@ -1376,20 +1518,20 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // Check if the stores are consecutive or if we need to swizzle them. for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); return; } - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a vector of stores.\n"); ValueList Operands; for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(0)); - buildTree_rec(Operands, Depth + 1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); return; } case Instruction::Call: { @@ -1399,8 +1541,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // represented by an intrinsic call Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); if (!isTriviallyVectorizable(ID)) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); return; } @@ -1413,8 +1555,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (!CI2 || CI2->getCalledFunction() != Int || getVectorIntrinsicIDForCall(CI2, TLI) != ID || !CI->hasIdenticalOperandBundleSchema(*CI2)) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i] << "\n"); return; @@ -1424,8 +1566,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { if (hasVectorInstrinsicScalarOpd(ID, 1)) { Value *A1J = CI2->getArgOperand(1); if (A1I != A1J) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI << " argument "<< A1I<<"!=" << A1J << "\n"); @@ -1437,15 +1579,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { !std::equal(CI->op_begin() + CI->getBundleOperandsStartIndex(), CI->op_begin() + CI->getBundleOperandsEndIndex(), CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: mismatched bundle operands in
calls:" << *CI << "!=" << *VL[i] << '\n'); return; } } - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { ValueList Operands; // Prepare the operand vector. @@ -1453,7 +1595,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { CallInst *CI2 = dyn_cast<CallInst>(j); Operands.push_back(CI2->getArgOperand(i)); } - buildTree_rec(Operands, Depth + 1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); } return; } @@ -1461,20 +1603,20 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // If this is not an alternate sequence of opcode like add-sub // then do not vectorize this instruction. if (!isAltShuffle) { - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); return; } - newTreeEntry(VL, true); + newTreeEntry(VL, true, UserTreeIdx); DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); // Reorder operands if reordering would enable vectorization. if (isa<BinaryOperator>(VL0)) { ValueList Left, Right; reorderAltShuffleOperands(VL, Left, Right); - buildTree_rec(Left, Depth + 1); - buildTree_rec(Right, Depth + 1); + buildTree_rec(Left, Depth + 1, UserTreeIdx); + buildTree_rec(Right, Depth + 1, UserTreeIdx); return; } @@ -1484,13 +1626,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1); + buildTree_rec(Operands, Depth + 1, UserTreeIdx); } return; } default: - BS.cancelScheduling(VL); - newTreeEntry(VL, false); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx); DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); return; } @@ -1570,6 +1712,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { Type *ScalarTy = VL[0]->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) ScalarTy = SI->getValueOperand()->getType(); + else if (CmpInst *CI = dyn_cast<CmpInst>(VL[0])) + ScalarTy = CI->getOperand(0)->getType(); VectorType *VecTy = VectorType::get(ScalarTy, VL.size()); // If we have computed a smaller type for the expression, update VecTy so @@ -1599,7 +1743,13 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { int DeadCost = 0; for (unsigned i = 0, e = VL.size(); i < e; ++i) { Instruction *E = cast<Instruction>(VL[i]); - if (E->hasOneUse()) + // If all users are going to be vectorized, instruction can be + // considered as dead. + // The same, if have only one user, it will be vectorized for sure. + if (E->hasOneUse() || + std::all_of(E->user_begin(), E->user_end(), [this](User *U) { + return ScalarToTreeEntry.count(U) > 0; + })) // Take credit for instruction that will become dead. DeadCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i); @@ -1624,10 +1774,10 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { // Calculate the cost of this instruction. int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(), - VL0->getType(), SrcTy); + VL0->getType(), SrcTy, VL0); VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size()); - int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy); + int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy, VL0); return VecCost - ScalarCost; } case Instruction::FCmp: @@ -1636,8 +1786,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { // Calculate the cost of this instruction. 
VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size()); int ScalarCost = VecTy->getNumElements() * - TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty()); - int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy); + TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty(), VL0); + int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy, VL0); return VecCost - ScalarCost; } case Instruction::Add: @@ -1695,11 +1845,13 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { CInt->getValue().isPowerOf2()) Op2VP = TargetTransformInfo::OP_PowerOf2; - int ScalarCost = VecTy->getNumElements() * - TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, - Op2VK, Op1VP, Op2VP); + SmallVector<const Value *, 4> Operands(VL0->operand_values()); + int ScalarCost = + VecTy->getNumElements() * + TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, Op2VK, Op1VP, + Op2VP, Operands); int VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy, Op1VK, Op2VK, - Op1VP, Op2VP); + Op1VP, Op2VP, Operands); return VecCost - ScalarCost; } case Instruction::GetElementPtr: { @@ -1720,18 +1872,18 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { // Cost of wide load - cost of scalar loads. unsigned alignment = dyn_cast<LoadInst>(VL0)->getAlignment(); int ScalarLdCost = VecTy->getNumElements() * - TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0); + TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0, VL0); int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, - VecTy, alignment, 0); + VecTy, alignment, 0, VL0); return VecLdCost - ScalarLdCost; } case Instruction::Store: { // We know that we can merge the stores. Calculate the cost. unsigned alignment = dyn_cast<StoreInst>(VL0)->getAlignment(); int ScalarStCost = VecTy->getNumElements() * - TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0); + TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0, VL0); int VecStCost = TTI->getMemoryOpCost(Instruction::Store, - VecTy, alignment, 0); + VecTy, alignment, 0, VL0); return VecStCost - ScalarStCost; } case Instruction::Call: { @@ -1739,12 +1891,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); // Calculate the cost of the scalar and vector calls. - SmallVector<Type*, 4> ScalarTys, VecTys; - for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op) { + SmallVector<Type*, 4> ScalarTys; + for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op) ScalarTys.push_back(CI->getArgOperand(op)->getType()); - VecTys.push_back(VectorType::get(CI->getArgOperand(op)->getType(), - VecTy->getNumElements())); - } FastMathFlags FMF; if (auto *FPMO = dyn_cast<FPMathOperator>(CI)) @@ -1753,7 +1902,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { int ScalarCallCost = VecTy->getNumElements() * TTI->getIntrinsicInstrCost(ID, ScalarTy, ScalarTys, FMF); - int VecCallCost = TTI->getIntrinsicInstrCost(ID, VecTy, VecTys, FMF); + SmallVector<Value *, 4> Args(CI->arg_operands()); + int VecCallCost = TTI->getIntrinsicInstrCost(ID, CI->getType(), Args, FMF, + VecTy->getNumElements()); DEBUG(dbgs() << "SLP: Call cost "<< VecCallCost - ScalarCallCost << " (" << VecCallCost << "-" << ScalarCallCost << ")" @@ -1861,7 +2012,7 @@ int BoUpSLP::getSpillCost() { // Update LiveValues. 
LiveValues.erase(PrevInst); for (auto &J : PrevInst->operands()) { - if (isa<Instruction>(&*J) && ScalarToTreeEntry.count(&*J)) + if (isa<Instruction>(&*J) && getTreeEntry(&*J)) LiveValues.insert(cast<Instruction>(&*J)); } @@ -1947,9 +2098,18 @@ int BoUpSLP::getTreeCost() { int SpillCost = getSpillCost(); Cost += SpillCost + ExtractCost; - DEBUG(dbgs() << "SLP: Spill Cost = " << SpillCost << ".\n" - << "SLP: Extract Cost = " << ExtractCost << ".\n" - << "SLP: Total Cost = " << Cost << ".\n"); + std::string Str; + { + raw_string_ostream OS(Str); + OS << "SLP: Spill Cost = " << SpillCost << ".\n" + << "SLP: Extract Cost = " << ExtractCost << ".\n" + << "SLP: Total Cost = " << Cost << ".\n"; + } + DEBUG(dbgs() << Str); + + if (ViewSLPTree) + ViewGraph(this, "SLP" + F->getName(), false, Str); + return Cost; } @@ -2248,9 +2408,7 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) { CSEBlocks.insert(Insrt->getParent()); // Add to our 'need-to-extract' list. - if (ScalarToTreeEntry.count(VL[i])) { - int Idx = ScalarToTreeEntry[VL[i]]; - TreeEntry *E = &VectorizableTree[Idx]; + if (TreeEntry *E = getTreeEntry(VL[i])) { // Find which lane we need to extract. int FoundLane = -1; for (unsigned Lane = 0, LE = VL.size(); Lane != LE; ++Lane) { @@ -2269,12 +2427,8 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) { return Vec; } -Value *BoUpSLP::alreadyVectorized(ArrayRef<Value *> VL) const { - SmallDenseMap<Value*, int>::const_iterator Entry - = ScalarToTreeEntry.find(VL[0]); - if (Entry != ScalarToTreeEntry.end()) { - int Idx = Entry->second; - const TreeEntry *En = &VectorizableTree[Idx]; +Value *BoUpSLP::alreadyVectorized(ArrayRef<Value *> VL, Value *OpValue) const { + if (const TreeEntry *En = getTreeEntry(OpValue)) { if (En->isSame(VL) && En->VectorizedValue) return En->VectorizedValue; } @@ -2282,12 +2436,9 @@ Value *BoUpSLP::alreadyVectorized(ArrayRef<Value *> VL) const { } Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { - if (ScalarToTreeEntry.count(VL[0])) { - int Idx = ScalarToTreeEntry[VL[0]]; - TreeEntry *E = &VectorizableTree[Idx]; + if (TreeEntry *E = getTreeEntry(VL[0])) if (E->isSame(VL)) return vectorizeTree(E); - } Type *ScalarTy = VL[0]->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) @@ -2402,7 +2553,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *InVec = vectorizeTree(INVL); - if (Value *V = alreadyVectorized(E->Scalars)) + if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; CastInst *CI = dyn_cast<CastInst>(VL0); @@ -2424,7 +2575,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *L = vectorizeTree(LHSV); Value *R = vectorizeTree(RHSV); - if (Value *V = alreadyVectorized(E->Scalars)) + if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); @@ -2453,7 +2604,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *True = vectorizeTree(TrueVec); Value *False = vectorizeTree(FalseVec); - if (Value *V = alreadyVectorized(E->Scalars)) + if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; Value *V = Builder.CreateSelect(Cond, True, False); @@ -2493,7 +2644,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *LHS = vectorizeTree(LHSVL); Value *RHS = vectorizeTree(RHSVL); - if (Value *V = alreadyVectorized(E->Scalars)) + if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; BinaryOperator *BinOp = cast<BinaryOperator>(VL0); @@ -2522,9 +2673,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // The pointer operand uses an in-tree scalar 
so we add the new BitCast to // ExternalUses list to make sure that an extract will be generated in the // future. - if (ScalarToTreeEntry.count(LI->getPointerOperand())) - ExternalUses.push_back( - ExternalUser(LI->getPointerOperand(), cast<User>(VecPtr), 0)); + Value *PO = LI->getPointerOperand(); + if (getTreeEntry(PO)) + ExternalUses.push_back(ExternalUser(PO, cast<User>(VecPtr), 0)); unsigned Alignment = LI->getAlignment(); LI = Builder.CreateLoad(VecPtr); @@ -2555,9 +2706,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // The pointer operand uses an in-tree scalar so we add the new BitCast to // ExternalUses list to make sure that an extract will be generated in the // future. - if (ScalarToTreeEntry.count(SI->getPointerOperand())) - ExternalUses.push_back( - ExternalUser(SI->getPointerOperand(), cast<User>(VecPtr), 0)); + Value *PO = SI->getPointerOperand(); + if (getTreeEntry(PO)) + ExternalUses.push_back(ExternalUser(PO, cast<User>(VecPtr), 0)); if (!Alignment) { Alignment = DL->getABITypeAlignment(SI->getValueOperand()->getType()); @@ -2638,7 +2789,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // The scalar argument uses an in-tree scalar so we add the new vectorized // call to ExternalUses list to make sure that an extract will be // generated in the future. - if (ScalarArg && ScalarToTreeEntry.count(ScalarArg)) + if (ScalarArg && getTreeEntry(ScalarArg)) ExternalUses.push_back(ExternalUser(ScalarArg, cast<User>(V), 0)); E->VectorizedValue = V; @@ -2655,7 +2806,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *LHS = vectorizeTree(LHSVL); Value *RHS = vectorizeTree(RHSVL); - if (Value *V = alreadyVectorized(E->Scalars)) + if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; // Create a vector of LHS op1 RHS @@ -2674,7 +2825,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { unsigned e = E->Scalars.size(); SmallVector<Constant *, 8> Mask(e); for (unsigned i = 0; i < e; ++i) { - if (i & 1) { + if (isOdd(i)) { Mask[i] = Builder.getInt32(e + i); OddScalars.push_back(E->Scalars[i]); } else { @@ -2702,6 +2853,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *BoUpSLP::vectorizeTree() { + ExtraValueToDebugLocsMap ExternallyUsedValues; + return vectorizeTree(ExternallyUsedValues); +} + +Value * +BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { // All blocks must be scheduled before any instructions are inserted. for (auto &BSIter : BlocksSchedules) { @@ -2744,18 +2901,38 @@ Value *BoUpSLP::vectorizeTree() { // Skip users that we already RAUW. This happens when one instruction // has multiple uses of the same value. - if (!is_contained(Scalar->users(), User)) + if (User && !is_contained(Scalar->users(), User)) continue; - assert(ScalarToTreeEntry.count(Scalar) && "Invalid scalar"); - - int Idx = ScalarToTreeEntry[Scalar]; - TreeEntry *E = &VectorizableTree[Idx]; + TreeEntry *E = getTreeEntry(Scalar); + assert(E && "Invalid scalar"); assert(!E->NeedToGather && "Extracting from a gather list"); Value *Vec = E->VectorizedValue; assert(Vec && "Can't find vectorizable value"); Value *Lane = Builder.getInt32(ExternalUse.Lane); + // If User == nullptr, the Scalar is used as extra arg. Generate + // ExtractElement instruction and update the record for this scalar in + // ExternallyUsedValues. 
+ if (!User) { + assert(ExternallyUsedValues.count(Scalar) && + "Scalar with nullptr as an external user must be registered in " + "ExternallyUsedValues map"); + if (auto *VecI = dyn_cast<Instruction>(Vec)) { + Builder.SetInsertPoint(VecI->getParent(), + std::next(VecI->getIterator())); + } else { + Builder.SetInsertPoint(&F->getEntryBlock().front()); + } + Value *Ex = Builder.CreateExtractElement(Vec, Lane); + Ex = extend(ScalarRoot, Ex, Scalar->getType()); + CSEBlocks.insert(cast<Instruction>(Scalar)->getParent()); + auto &Locs = ExternallyUsedValues[Scalar]; + ExternallyUsedValues.insert({Ex, Locs}); + ExternallyUsedValues.erase(Scalar); + continue; + } + // Generate extracts for out-of-tree users. // Find the insertion point for the extractelement lane. if (auto *VecI = dyn_cast<Instruction>(Vec)) { @@ -2813,7 +2990,7 @@ Value *BoUpSLP::vectorizeTree() { for (User *U : Scalar->users()) { DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n"); - assert((ScalarToTreeEntry.count(U) || + assert((getTreeEntry(U) || // It is legal to replace users in the ignorelist by undef. is_contained(UserIgnoreList, U)) && "Replacing out-of-tree value with undef"); @@ -2920,8 +3097,8 @@ void BoUpSLP::optimizeGatherSequence() { // Groups the instructions into a bundle (which is then a single scheduling entity) // and schedules instructions until the bundle gets ready. bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, - BoUpSLP *SLP) { - if (isa<PHINode>(VL[0])) + BoUpSLP *SLP, Value *OpValue) { + if (isa<PHINode>(OpValue)) return true; // Initialize the instruction bundle. @@ -2929,7 +3106,7 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, ScheduleData *PrevInBundle = nullptr; ScheduleData *Bundle = nullptr; bool ReSchedule = false; - DEBUG(dbgs() << "SLP: bundle: " << *VL[0] << "\n"); + DEBUG(dbgs() << "SLP: bundle: " << *OpValue << "\n"); // Make sure that the scheduling region contains all // instructions of the bundle. @@ -3000,17 +3177,18 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, } } if (!Bundle->isReady()) { - cancelScheduling(VL); + cancelScheduling(VL, OpValue); return false; } return true; } -void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL) { - if (isa<PHINode>(VL[0])) +void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, + Value *OpValue) { + if (isa<PHINode>(OpValue)) return; - ScheduleData *Bundle = getScheduleData(VL[0]); + ScheduleData *Bundle = getScheduleData(OpValue); DEBUG(dbgs() << "SLP: cancel scheduling of " << *Bundle << "\n"); assert(!Bundle->IsScheduled && "Can't cancel bundle which is already scheduled"); @@ -3154,12 +3332,10 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, if (UseSD && isInSchedulingRegion(UseSD->FirstInBundle)) { BundleMember->Dependencies++; ScheduleData *DestBundle = UseSD->FirstInBundle; - if (!DestBundle->IsScheduled) { + if (!DestBundle->IsScheduled) BundleMember->incrementUnscheduledDeps(1); - } - if (!DestBundle->hasValidDependencies()) { + if (!DestBundle->hasValidDependencies()) WorkList.push_back(DestBundle); - } } } else { // I'm not sure if this can ever happen. But we need to be safe. @@ -3264,7 +3440,7 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { // sorted by the original instruction location. This lets the final schedule // be as close as possible to the original instruction order.
 struct ScheduleDataCompare {
-    bool operator()(ScheduleData *SD1, ScheduleData *SD2) {
+    bool operator()(ScheduleData *SD1, ScheduleData *SD2) const {
       return SD2->SchedulingPriority < SD1->SchedulingPriority;
     }
   };
@@ -3278,7 +3454,7 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) {
        I = I->getNextNode()) {
     ScheduleData *SD = BS->getScheduleData(I);
     assert(
-        SD->isPartOfBundle() == (ScalarToTreeEntry.count(SD->Inst) != 0) &&
+        SD->isPartOfBundle() == (getTreeEntry(SD->Inst) != nullptr) &&
         "scheduler and vectorizer have different opinion on what is a bundle");
     SD->FirstInBundle->SchedulingPriority = Idx++;
     if (SD->isSchedulingEntity()) {
@@ -3527,10 +3703,8 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Determine if the sign bit of all the roots is known to be zero. If not,
   // IsKnownPositive is set to False.
   IsKnownPositive = all_of(TreeRoot, [&](Value *R) {
-    bool KnownZero = false;
-    bool KnownOne = false;
-    ComputeSignBit(R, KnownZero, KnownOne, *DL);
-    return KnownZero;
+    KnownBits Known = computeKnownBits(R, *DL);
+    return Known.isNonNegative();
   });
 
   // Determine the maximum number of bits required to store the scalar
@@ -3610,8 +3784,9 @@ struct SLPVectorizer : public FunctionPass {
     auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
     auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits();
+    auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
 
-    return Impl.runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB);
+    return Impl.runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, ORE);
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -3623,6 +3798,7 @@ struct SLPVectorizer : public FunctionPass {
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<DemandedBitsWrapperPass>();
+    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
     AU.addPreserved<LoopInfoWrapperPass>();
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addPreserved<AAResultsWrapperPass>();
@@ -3641,13 +3817,14 @@ PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &A
   auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
   auto *AC = &AM.getResult<AssumptionAnalysis>(F);
   auto *DB = &AM.getResult<DemandedBitsAnalysis>(F);
+  auto *ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 
-  bool Changed = runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB);
+  bool Changed = runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB, ORE);
   if (!Changed)
     return PreservedAnalyses::all();
+
   PreservedAnalyses PA;
-  PA.preserve<LoopAnalysis>();
-  PA.preserve<DominatorTreeAnalysis>();
+  PA.preserveSet<CFGAnalyses>();
   PA.preserve<AAManager>();
   PA.preserve<GlobalsAA>();
   return PA;
@@ -3657,7 +3834,8 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_,
                                 TargetTransformInfo *TTI_,
                                 TargetLibraryInfo *TLI_, AliasAnalysis *AA_,
                                 LoopInfo *LI_, DominatorTree *DT_,
-                                AssumptionCache *AC_, DemandedBits *DB_) {
+                                AssumptionCache *AC_, DemandedBits *DB_,
+                                OptimizationRemarkEmitter *ORE_) {
   SE = SE_;
   TTI = TTI_;
   TLI = TLI_;
@@ -3685,7 +3863,7 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_,
 
   // Use the bottom up slp vectorizer to construct chains that start with
   // store instructions.
-  BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB, DL);
+  BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB, DL, ORE_);
 
   // A general note: the vectorizer must use BoUpSLP::eraseInstruction() to
   // delete instructions.
@@ -3723,11 +3901,13 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_,
 }
 
 /// \brief Check that the Values in the slice in VL array are still existent in
-/// the WeakVH array.
+/// the WeakTrackingVH array.
 /// Vectorization of part of the VL array may cause later values in the VL array
-/// to become invalid. We track when this has happened in the WeakVH array.
-static bool hasValueBeenRAUWed(ArrayRef<Value *> VL, ArrayRef<WeakVH> VH,
-                               unsigned SliceBegin, unsigned SliceSize) {
+/// to become invalid. We track when this has happened in the WeakTrackingVH
+/// array.
+static bool hasValueBeenRAUWed(ArrayRef<Value *> VL,
+                               ArrayRef<WeakTrackingVH> VH, unsigned SliceBegin,
+                               unsigned SliceSize) {
   VL = VL.slice(SliceBegin, SliceSize);
   VH = VH.slice(SliceBegin, SliceSize);
   return !std::equal(VL.begin(), VL.end(), VH.begin());
@@ -3745,7 +3925,7 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
     return false;
 
   // Keep track of values that were deleted by vectorizing in the loop below.
-  SmallVector<WeakVH, 8> TrackValues(Chain.begin(), Chain.end());
+  SmallVector<WeakTrackingVH, 8> TrackValues(Chain.begin(), Chain.end());
 
   bool Changed = false;
   // Look for profitable vectorizable trees at all offsets, starting at zero.
@@ -3772,6 +3952,13 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
       DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n");
       if (Cost < -SLPCostThreshold) {
         DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n");
+        using namespace ore;
+        R.getORE()->emit(OptimizationRemark(SV_NAME, "StoresVectorized",
+                                            cast<StoreInst>(Chain[i]))
+                         << "Stores SLP vectorized with cost " << NV("Cost", Cost)
+                         << " and with tree size "
+                         << NV("TreeSize", R.getTreeSize()));
+
         R.vectorizeTree();
 
         // Move to the next bundle.
@@ -3931,7 +4118,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
   bool Changed = false;
 
   // Keep track of values that were deleted by vectorizing in the loop below.
-  SmallVector<WeakVH, 8> TrackValues(VL.begin(), VL.end());
+  SmallVector<WeakTrackingVH, 8> TrackValues(VL.begin(), VL.end());
 
   unsigned NextInst = 0, MaxInst = VL.size();
   for (unsigned VF = MaxVF; NextInst + 1 < MaxInst && VF >= MinVF;
@@ -3970,8 +4157,8 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
       if (AllowReorder && R.shouldReorder()) {
         // Conceptually, there is nothing actually preventing us from trying to
         // reorder a larger list. In fact, we do exactly this when vectorizing
-        // reductions. However, at this point, we only expect to get here from
-        // tryToVectorizePair().
+        // reductions. However, at this point, we only expect to get here when
+        // there are exactly two operations.
         assert(Ops.size() == 2);
         assert(BuildVectorSlice.empty());
         Value *ReorderedOps[] = {Ops[1], Ops[0]};
@@ -3985,6 +4172,12 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
       if (Cost < -SLPCostThreshold) {
         DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n");
+        R.getORE()->emit(OptimizationRemark(SV_NAME, "VectorizedList",
+                                            cast<Instruction>(Ops[0]))
+                         << "SLP vectorized with cost " << ore::NV("Cost", Cost)
+                         << " and with tree size "
+                         << ore::NV("TreeSize", R.getTreeSize()));
+
         Value *VectorizedRoot = R.vectorizeTree();
 
         // Reconstruct the build vector by extracting the vectorized root. This
@@ -4026,36 +4219,40 @@ bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) {
   if (!V)
     return false;
 
+  Value *P = V->getParent();
+
+  // Vectorize in current basic block only.
+  auto *Op0 = dyn_cast<Instruction>(V->getOperand(0));
+  auto *Op1 = dyn_cast<Instruction>(V->getOperand(1));
+  if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P)
+    return false;
+
   // Try to vectorize V.
-  if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
+  if (tryToVectorizePair(Op0, Op1, R))
     return true;
 
-  BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
-  BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
+  auto *A = dyn_cast<BinaryOperator>(Op0);
+  auto *B = dyn_cast<BinaryOperator>(Op1);
 
   // Try to skip B.
   if (B && B->hasOneUse()) {
-    BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
-    BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
-    if (tryToVectorizePair(A, B0, R)) {
+    auto *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
+    auto *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
+    if (B0 && B0->getParent() == P && tryToVectorizePair(A, B0, R))
       return true;
-    }
-    if (tryToVectorizePair(A, B1, R)) {
+    if (B1 && B1->getParent() == P && tryToVectorizePair(A, B1, R))
      return true;
-    }
   }
 
   // Try to skip A.
   if (A && A->hasOneUse()) {
-    BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
-    BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
-    if (tryToVectorizePair(A0, B, R)) {
+    auto *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
+    auto *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
+    if (A0 && A0->getParent() == P && tryToVectorizePair(A0, B, R))
      return true;
-    }
-    if (tryToVectorizePair(A1, B, R)) {
+    if (A1 && A1->getParent() == P && tryToVectorizePair(A1, B, R))
      return true;
-    }
  }
-  return 0;
+
+  return false;
 }
 
 /// \brief Generate a shuffle mask to be used in a reduction tree.
@@ -4119,37 +4316,41 @@ namespace {
 class HorizontalReduction {
   SmallVector<Value *, 16> ReductionOps;
   SmallVector<Value *, 32> ReducedVals;
+  // Use map vector to make stable output.
+  MapVector<Instruction *, Value *> ExtraArgs;
 
-  BinaryOperator *ReductionRoot;
-  // After successfull horizontal reduction vectorization attempt for PHI node
-  // vectorizer tries to update root binary op by combining vectorized tree and
-  // the ReductionPHI node. But during vectorization this ReductionPHI can be
-  // vectorized itself and replaced by the undef value, while the instruction
-  // itself is marked for deletion. This 'marked for deletion' PHI node then can
-  // be used in new binary operation, causing "Use still stuck around after Def
-  // is destroyed" crash upon PHI node deletion.
-  WeakVH ReductionPHI;
+  BinaryOperator *ReductionRoot = nullptr;
 
   /// The opcode of the reduction.
-  unsigned ReductionOpcode;
+  Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd;
 
   /// The opcode of the values we perform a reduction on.
-  unsigned ReducedValueOpcode;
+  unsigned ReducedValueOpcode = 0;
 
   /// Should we model this reduction as a pairwise reduction tree or a tree that
   /// splits the vector in halves and adds those halves.
-  bool IsPairwiseReduction;
+  bool IsPairwiseReduction = false;
+
+  /// Checks if the ParentStackElem.first should be marked as a reduction
+  /// operation with an extra argument or as extra argument itself.
+  void markExtraArg(std::pair<Instruction *, unsigned> &ParentStackElem,
+                    Value *ExtraArg) {
+    if (ExtraArgs.count(ParentStackElem.first)) {
+      ExtraArgs[ParentStackElem.first] = nullptr;
+      // We ran into something like:
+      // ParentStackElem.first = ExtraArgs[ParentStackElem.first] + ExtraArg.
+      // The whole ParentStackElem.first should be considered as an extra value
+      // in this case.
+      // Do not perform analysis of remaining operands of ParentStackElem.first
+      // instruction, this whole instruction is an extra argument.
+      ParentStackElem.second = ParentStackElem.first->getNumOperands();
+    } else {
+      // We ran into something like:
+      // ParentStackElem.first += ... + ExtraArg + ...
+      ExtraArgs[ParentStackElem.first] = ExtraArg;
+    }
+  }
 
 public:
-  /// The width of one full horizontal reduction operation.
-  unsigned ReduxWidth;
-
-  /// Minimal width of available vector registers. It's used to determine
-  /// ReduxWidth.
-  unsigned MinVecRegSize;
-
-  HorizontalReduction(unsigned MinVecRegSize)
-      : ReductionRoot(nullptr), ReductionOpcode(0), ReducedValueOpcode(0),
-        IsPairwiseReduction(false), ReduxWidth(0),
-        MinVecRegSize(MinVecRegSize) {}
+  HorizontalReduction() = default;
 
   /// \brief Try to find a reduction tree.
   bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) {
@@ -4176,21 +4377,14 @@ public:
     if (!isValidElementType(Ty))
       return false;
 
-    const DataLayout &DL = B->getModule()->getDataLayout();
     ReductionOpcode = B->getOpcode();
     ReducedValueOpcode = 0;
-    // FIXME: Register size should be a parameter to this function, so we can
-    // try different vectorization factors.
-    ReduxWidth = MinVecRegSize / DL.getTypeSizeInBits(Ty);
     ReductionRoot = B;
-    ReductionPHI = Phi;
-
-    if (ReduxWidth < 4)
-      return false;
 
     // We currently only support adds.
-    if (ReductionOpcode != Instruction::Add &&
-        ReductionOpcode != Instruction::FAdd)
+    if ((ReductionOpcode != Instruction::Add &&
+         ReductionOpcode != Instruction::FAdd) ||
+        !B->isAssociative())
       return false;
 
     // Post order traverse the reduction tree starting at B. We only handle true
@@ -4202,30 +4396,26 @@ public:
       unsigned EdgeToVist = Stack.back().second++;
       bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode;
 
-      // Only handle trees in the current basic block.
-      if (TreeN->getParent() != B->getParent())
-        return false;
-
-      // Each tree node needs to have one user except for the ultimate
-      // reduction.
-      if (!TreeN->hasOneUse() && TreeN != B)
-        return false;
-
       // Postorder vist.
       if (EdgeToVist == 2 || IsReducedValue) {
-        if (IsReducedValue) {
-          // Make sure that the opcodes of the operations that we are going to
-          // reduce match.
-          if (!ReducedValueOpcode)
-            ReducedValueOpcode = TreeN->getOpcode();
-          else if (ReducedValueOpcode != TreeN->getOpcode())
-            return false;
+        if (IsReducedValue)
           ReducedVals.push_back(TreeN);
-        } else {
-          // We need to be able to reassociate the adds.
-          if (!TreeN->isAssociative())
-            return false;
-          ReductionOps.push_back(TreeN);
+        else {
+          auto I = ExtraArgs.find(TreeN);
+          if (I != ExtraArgs.end() && !I->second) {
+            // Check if TreeN is an extra argument of its parent operation.
+            if (Stack.size() <= 1) {
+              // TreeN can't be an extra argument as it is a root reduction
+              // operation.
+              return false;
+            }
+            // Yes, TreeN is an extra argument, do not add it to a list of
+            // reduction operations.
+            // Stack[Stack.size() - 2] always points to the parent operation.
+            markExtraArg(Stack[Stack.size() - 2], TreeN);
+            ExtraArgs.erase(TreeN);
+          } else
+            ReductionOps.push_back(TreeN);
        }
         // Retract.
         Stack.pop_back();
@@ -4242,13 +4432,44 @@ public:
       // reduced value class.
       if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode ||
                 I->getOpcode() == ReductionOpcode)) {
-        if (!ReducedValueOpcode && I->getOpcode() != ReductionOpcode)
+        // Only handle trees in the current basic block.
+        if (I->getParent() != B->getParent()) {
+          // I is an extra argument for TreeN (its parent operation).
+          markExtraArg(Stack.back(), I);
+          continue;
+        }
+
+        // Each tree node needs to have one user except for the ultimate
+        // reduction.
+        if (!I->hasOneUse() && I != B) {
+          // I is an extra argument for TreeN (its parent operation).
+          markExtraArg(Stack.back(), I);
+          continue;
+        }
+
+        if (I->getOpcode() == ReductionOpcode) {
+          // We need to be able to reassociate the reduction operations.
+          if (!I->isAssociative()) {
+            // I is an extra argument for TreeN (its parent operation).
+            markExtraArg(Stack.back(), I);
+            continue;
+          }
+        } else if (ReducedValueOpcode &&
+                   ReducedValueOpcode != I->getOpcode()) {
+          // Make sure that the opcodes of the operations that we are going to
+          // reduce match.
+          // I is an extra argument for TreeN (its parent operation).
+          markExtraArg(Stack.back(), I);
+          continue;
+        } else if (!ReducedValueOpcode)
           ReducedValueOpcode = I->getOpcode();
+
         Stack.push_back(std::make_pair(I, 0));
         continue;
       }
-        return false;
       }
+      // NextV is an extra argument for TreeN (its parent operation).
+      markExtraArg(Stack.back(), NextV);
     }
     return true;
   }
@@ -4259,10 +4480,15 @@ public:
     if (ReducedVals.empty())
       return false;
 
+    // If there is a sufficient number of reduction values, reduce
+    // to a nearby power-of-2. Can safely generate oversized
+    // vectors and rely on the backend to split them to legal sizes.
     unsigned NumReducedVals = ReducedVals.size();
-    if (NumReducedVals < ReduxWidth)
+    if (NumReducedVals < 4)
      return false;
 
+    unsigned ReduxWidth = PowerOf2Floor(NumReducedVals);
+
     Value *VectorizedTree = nullptr;
     IRBuilder<> Builder(ReductionRoot);
     FastMathFlags Unsafe;
@@ -4270,55 +4496,78 @@ public:
     Builder.setFastMathFlags(Unsafe);
     unsigned i = 0;
 
-    for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
+    BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
+    // The same extra argument may be used several time, so log each attempt
+    // to use it.
+    for (auto &Pair : ExtraArgs)
+      ExternallyUsedValues[Pair.second].push_back(Pair.first);
+    while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
       auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth);
-      V.buildTree(VL, ReductionOps);
+      V.buildTree(VL, ExternallyUsedValues, ReductionOps);
       if (V.shouldReorder()) {
         SmallVector<Value *, 8> Reversed(VL.rbegin(), VL.rend());
-        V.buildTree(Reversed, ReductionOps);
+        V.buildTree(Reversed, ExternallyUsedValues, ReductionOps);
       }
       if (V.isTreeTinyAndNotFullyVectorizable())
-        continue;
+        break;
 
       V.computeMinimumValueSizes();
 
       // Estimate cost.
-      int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
+      int Cost =
+          V.getTreeCost() + getReductionCost(TTI, ReducedVals[i], ReduxWidth);
      if (Cost >= -SLPCostThreshold)
        break;
 
       DEBUG(dbgs() << "SLP: Vectorizing horizontal reduction at cost:" << Cost
                    << ". (HorRdx)\n");
+      auto *I0 = cast<Instruction>(VL[0]);
+      V.getORE()->emit(
+          OptimizationRemark(SV_NAME, "VectorizedHorizontalReduction", I0)
+          << "Vectorized horizontal reduction with cost "
+          << ore::NV("Cost", Cost) << " and with tree size "
+          << ore::NV("TreeSize", V.getTreeSize()));
 
       // Vectorize a tree.
       DebugLoc Loc = cast<Instruction>(ReducedVals[i])->getDebugLoc();
-      Value *VectorizedRoot = V.vectorizeTree();
+      Value *VectorizedRoot = V.vectorizeTree(ExternallyUsedValues);
 
       // Emit a reduction.
-      Value *ReducedSubTree = emitReduction(VectorizedRoot, Builder);
+      Value *ReducedSubTree =
+          emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps, TTI);
       if (VectorizedTree) {
         Builder.SetCurrentDebugLocation(Loc);
-        VectorizedTree = createBinOp(Builder, ReductionOpcode, VectorizedTree,
-                                     ReducedSubTree, "bin.rdx");
+        VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
+                                             ReducedSubTree, "bin.rdx");
+        propagateIRFlags(VectorizedTree, ReductionOps);
       } else
         VectorizedTree = ReducedSubTree;
+      i += ReduxWidth;
+      ReduxWidth = PowerOf2Floor(NumReducedVals - i);
     }
 
     if (VectorizedTree) {
       // Finish the reduction.
       for (; i < NumReducedVals; ++i) {
-        Builder.SetCurrentDebugLocation(
-            cast<Instruction>(ReducedVals[i])->getDebugLoc());
-        VectorizedTree = createBinOp(Builder, ReductionOpcode, VectorizedTree,
-                                     ReducedVals[i]);
+        auto *I = cast<Instruction>(ReducedVals[i]);
+        Builder.SetCurrentDebugLocation(I->getDebugLoc());
+        VectorizedTree =
+            Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I);
+        propagateIRFlags(VectorizedTree, ReductionOps);
+      }
+      for (auto &Pair : ExternallyUsedValues) {
+        assert(!Pair.second.empty() &&
+               "At least one DebugLoc must be inserted");
+        // Add each externally used value to the final reduction.
+        for (auto *I : Pair.second) {
+          Builder.SetCurrentDebugLocation(I->getDebugLoc());
+          VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree,
+                                               Pair.first, "bin.extra");
+          propagateIRFlags(VectorizedTree, I);
+        }
       }
       // Update users.
-      if (ReductionPHI && !isa<UndefValue>(ReductionPHI)) {
-        assert(ReductionRoot && "Need a reduction operation");
-        ReductionRoot->setOperand(0, VectorizedTree);
-        ReductionRoot->setOperand(1, ReductionPHI);
-      } else
-        ReductionRoot->replaceAllUsesWith(VectorizedTree);
+      ReductionRoot->replaceAllUsesWith(VectorizedTree);
     }
     return VectorizedTree != nullptr;
   }
@@ -4329,7 +4578,8 @@ public:
 
 private:
   /// \brief Calculate the cost of a reduction.
-  int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal) {
+  int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
+                       unsigned ReduxWidth) {
     Type *ScalarTy = FirstReducedVal->getType();
     Type *VecTy = VectorType::get(ScalarTy, ReduxWidth);
@@ -4352,41 +4602,34 @@ private:
     return VecReduxCost - ScalarReduxCost;
   }
 
-  static Value *createBinOp(IRBuilder<> &Builder, unsigned Opcode, Value *L,
-                            Value *R, const Twine &Name = "") {
-    if (Opcode == Instruction::FAdd)
-      return Builder.CreateFAdd(L, R, Name);
-    return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, L, R, Name);
-  }
-
   /// \brief Emit a horizontal reduction of the vectorized value.
-  Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder) {
+  Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder,
+                       unsigned ReduxWidth, ArrayRef<Value *> RedOps,
+                       const TargetTransformInfo *TTI) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(isPowerOf2_32(ReduxWidth) &&
            "We only handle power-of-two reductions for now");
 
+    if (!IsPairwiseReduction)
+      return createSimpleTargetReduction(
+          Builder, TTI, ReductionOpcode, VectorizedValue,
+          TargetTransformInfo::ReductionFlags(), RedOps);
+
     Value *TmpVec = VectorizedValue;
     for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) {
-      if (IsPairwiseReduction) {
-        Value *LeftMask =
+      Value *LeftMask =
           createRdxShuffleMask(ReduxWidth, i, true, true, Builder);
-        Value *RightMask =
+      Value *RightMask =
           createRdxShuffleMask(ReduxWidth, i, true, false, Builder);
-        Value *LeftShuf = Builder.CreateShuffleVector(
+      Value *LeftShuf = Builder.CreateShuffleVector(
           TmpVec, UndefValue::get(TmpVec->getType()), LeftMask, "rdx.shuf.l");
-        Value *RightShuf = Builder.CreateShuffleVector(
+      Value *RightShuf = Builder.CreateShuffleVector(
          TmpVec, UndefValue::get(TmpVec->getType()), (RightMask),
          "rdx.shuf.r");
-        TmpVec = createBinOp(Builder, ReductionOpcode, LeftShuf, RightShuf,
-                             "bin.rdx");
-      } else {
-        Value *UpperHalf =
-            createRdxShuffleMask(ReduxWidth, i, false, false, Builder);
-        Value *Shuf = Builder.CreateShuffleVector(
-            TmpVec, UndefValue::get(TmpVec->getType()), UpperHalf, "rdx.shuf");
-        TmpVec = createBinOp(Builder, ReductionOpcode, TmpVec, Shuf, "bin.rdx");
-      }
+      TmpVec =
+          Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf, "bin.rdx");
+      propagateIRFlags(TmpVec, RedOps);
     }
 
     // The result is in the first element of the vector.
@@ -4438,16 +4681,19 @@ static bool findBuildVector(InsertElementInst *FirstInsertElem,
 static bool findBuildAggregate(InsertValueInst *IV,
                                SmallVectorImpl<Value *> &BuildVector,
                                SmallVectorImpl<Value *> &BuildVectorOpds) {
-  if (!IV->hasOneUse())
-    return false;
-  Value *V = IV->getAggregateOperand();
-  if (!isa<UndefValue>(V)) {
-    InsertValueInst *I = dyn_cast<InsertValueInst>(V);
-    if (!I || !findBuildAggregate(I, BuildVector, BuildVectorOpds))
+  Value *V;
+  do {
+    BuildVector.push_back(IV);
+    BuildVectorOpds.push_back(IV->getInsertedValueOperand());
+    V = IV->getAggregateOperand();
+    if (isa<UndefValue>(V))
+      break;
+    IV = dyn_cast<InsertValueInst>(V);
+    if (!IV || !IV->hasOneUse())
       return false;
-  }
-  BuildVector.push_back(IV);
-  BuildVectorOpds.push_back(IV->getInsertedValueOperand());
+  } while (true);
+
+  std::reverse(BuildVector.begin(), BuildVector.end());
+  std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
   return true;
 }
@@ -4507,29 +4753,105 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P,
   return nullptr;
 }
 
-/// \brief Attempt to reduce a horizontal reduction.
-/// If it is legal to match a horizontal reduction feeding
-/// the phi node P with reduction operators BI, then check if it
-/// can be done.
-/// \returns true if a horizontal reduction was matched and reduced.
-/// \returns false if a horizontal reduction was not matched.
-static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI,
-                                        BoUpSLP &R, TargetTransformInfo *TTI,
-                                        unsigned MinRegSize) {
+/// Attempt to reduce a horizontal reduction.
+/// If it is legal to match a horizontal reduction feeding the phi node \a P
+/// with reduction operators \a Root (or one of its operands) in a basic block
+/// \a BB, then check if it can be done. If horizontal reduction is not found
+/// and root instruction is a binary operation, vectorization of the operands is
+/// attempted.
+/// \returns true if a horizontal reduction was matched and reduced or operands
+/// of one of the binary instruction were vectorized.
+/// \returns false if a horizontal reduction was not matched (or not possible)
+/// or no vectorization of any binary operation feeding \a Root instruction was
+/// performed.
+static bool tryToVectorizeHorReductionOrInstOperands(
+    PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R,
+    TargetTransformInfo *TTI,
+    const function_ref<bool(BinaryOperator *, BoUpSLP &)> Vectorize) {
   if (!ShouldVectorizeHor)
     return false;
 
-  HorizontalReduction HorRdx(MinRegSize);
-  if (!HorRdx.matchAssociativeReduction(P, BI))
+  if (!Root)
     return false;
 
-  // If there is a sufficient number of reduction values, reduce
-  // to a nearby power-of-2. Can safely generate oversized
-  // vectors and rely on the backend to split them to legal sizes.
-  HorRdx.ReduxWidth =
-      std::max((uint64_t)4, PowerOf2Floor(HorRdx.numReductionValues()));
+  if (Root->getParent() != BB)
+    return false;
+  // Start analysis starting from Root instruction. If horizontal reduction is
+  // found, try to vectorize it. If it is not a horizontal reduction or
+  // vectorization is not possible or not effective, and currently analyzed
+  // instruction is a binary operation, try to vectorize the operands, using
+  // pre-order DFS traversal order. If the operands were not vectorized, repeat
+  // the same procedure considering each operand as a possible root of the
+  // horizontal reduction.
+  // Interrupt the process if the Root instruction itself was vectorized or all
+  // sub-trees not higher that RecursionMaxDepth were analyzed/vectorized.
+  SmallVector<std::pair<WeakTrackingVH, unsigned>, 8> Stack(1, {Root, 0});
+  SmallSet<Value *, 8> VisitedInstrs;
+  bool Res = false;
+  while (!Stack.empty()) {
+    Value *V;
+    unsigned Level;
+    std::tie(V, Level) = Stack.pop_back_val();
+    if (!V)
+      continue;
+    auto *Inst = dyn_cast<Instruction>(V);
+    if (!Inst || isa<PHINode>(Inst))
+      continue;
+    if (auto *BI = dyn_cast<BinaryOperator>(Inst)) {
+      HorizontalReduction HorRdx;
+      if (HorRdx.matchAssociativeReduction(P, BI)) {
+        if (HorRdx.tryToReduce(R, TTI)) {
+          Res = true;
+          // Set P to nullptr to avoid re-analysis of phi node in
+          // matchAssociativeReduction function unless this is the root node.
+          P = nullptr;
+          continue;
+        }
+      }
+      if (P) {
+        Inst = dyn_cast<Instruction>(BI->getOperand(0));
+        if (Inst == P)
+          Inst = dyn_cast<Instruction>(BI->getOperand(1));
+        if (!Inst) {
+          // Set P to nullptr to avoid re-analysis of phi node in
+          // matchAssociativeReduction function unless this is the root node.
+          P = nullptr;
+          continue;
+        }
+      }
+    }
+    // Set P to nullptr to avoid re-analysis of phi node in
+    // matchAssociativeReduction function unless this is the root node.
+    P = nullptr;
+    if (Vectorize(dyn_cast<BinaryOperator>(Inst), R)) {
+      Res = true;
+      continue;
+    }
+
+    // Try to vectorize operands.
+    if (++Level < RecursionMaxDepth)
+      for (auto *Op : Inst->operand_values())
+        Stack.emplace_back(Op, Level);
+  }
+  return Res;
+}
+
+bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V,
+                                                 BasicBlock *BB, BoUpSLP &R,
+                                                 TargetTransformInfo *TTI) {
+  if (!V)
+    return false;
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return false;
 
-  return HorRdx.tryToReduce(R, TTI);
+  if (!isa<BinaryOperator>(I))
+    P = nullptr;
+  // Try to match and vectorize a horizontal reduction.
+  return tryToVectorizeHorReductionOrInstOperands(
+      P, I, BB, R, TTI, [this](BinaryOperator *BI, BoUpSLP &R) -> bool {
+        return tryToVectorize(BI, R);
+      });
 }
 
 bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
@@ -4571,7 +4893,13 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
       // Try to vectorize them.
       unsigned NumElts = (SameTypeIt - IncIt);
      DEBUG(errs() << "SLP: Trying to vectorize starting at PHIs (" << NumElts << ")\n");
-      if (NumElts > 1 && tryToVectorizeList(makeArrayRef(IncIt, NumElts), R)) {
+      // The order in which the phi nodes appear in the program does not matter.
+      // So allow tryToVectorizeList to reorder them if it is beneficial. This
+      // is done when there are exactly two elements since tryToVectorizeList
+      // asserts that there are only two values when AllowReorder is true.
+      bool AllowReorder = NumElts == 2;
+      if (NumElts > 1 && tryToVectorizeList(makeArrayRef(IncIt, NumElts), R,
+                                            None, AllowReorder)) {
         // Success start over because instructions might have been changed.
         HaveVectorizedPhiNodes = true;
         Changed = true;
@@ -4599,67 +4927,42 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
       if (P->getNumIncomingValues() != 2)
         return Changed;
 
-      Value *Rdx = getReductionValue(DT, P, BB, LI);
-
-      // Check if this is a Binary Operator.
-      BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
-      if (!BI)
-        continue;
-
-      // Try to match and vectorize a horizontal reduction.
-      if (canMatchHorizontalReduction(P, BI, R, TTI, R.getMinVecRegSize())) {
-        Changed = true;
-        it = BB->begin();
-        e = BB->end();
-        continue;
-      }
-
-      Value *Inst = BI->getOperand(0);
-      if (Inst == P)
-        Inst = BI->getOperand(1);
-
-      if (tryToVectorize(dyn_cast<BinaryOperator>(Inst), R)) {
-        // We would like to start over since some instructions are deleted
-        // and the iterator may become invalid value.
+      if (vectorizeRootInstruction(P, getReductionValue(DT, P, BB, LI), BB, R,
+                                   TTI)) {
         Changed = true;
         it = BB->begin();
         e = BB->end();
        continue;
      }
-
-      continue;
     }
 
-    if (ShouldStartVectorizeHorAtStore)
-      if (StoreInst *SI = dyn_cast<StoreInst>(it))
-        if (BinaryOperator *BinOp =
-                dyn_cast<BinaryOperator>(SI->getValueOperand())) {
-          if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI,
-                                          R.getMinVecRegSize()) ||
-              tryToVectorize(BinOp, R)) {
-            Changed = true;
-            it = BB->begin();
-            e = BB->end();
-            continue;
-          }
+    if (ShouldStartVectorizeHorAtStore) {
+      if (StoreInst *SI = dyn_cast<StoreInst>(it)) {
+        // Try to match and vectorize a horizontal reduction.
+        if (vectorizeRootInstruction(nullptr, SI->getValueOperand(), BB, R,
+                                     TTI)) {
+          Changed = true;
+          it = BB->begin();
+          e = BB->end();
+          continue;
        }
+      }
+    }
 
     // Try to vectorize horizontal reductions feeding into a return.
-    if (ReturnInst *RI = dyn_cast<ReturnInst>(it))
-      if (RI->getNumOperands() != 0)
-        if (BinaryOperator *BinOp =
-                dyn_cast<BinaryOperator>(RI->getOperand(0))) {
-          DEBUG(dbgs() << "SLP: Found a return to vectorize.\n");
-          if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI,
-                                          R.getMinVecRegSize()) ||
-              tryToVectorizePair(BinOp->getOperand(0), BinOp->getOperand(1),
-                                 R)) {
-            Changed = true;
-            it = BB->begin();
-            e = BB->end();
-            continue;
-          }
+    if (ReturnInst *RI = dyn_cast<ReturnInst>(it)) {
+      if (RI->getNumOperands() != 0) {
+        // Try to match and vectorize a horizontal reduction.
+        if (vectorizeRootInstruction(nullptr, RI->getOperand(0), BB, R, TTI)) {
+          Changed = true;
+          it = BB->begin();
+          e = BB->end();
+          continue;
        }
+      }
+    }
 
     // Try to vectorize trees that start at compare instructions.
     if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
@@ -4672,16 +4975,14 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
         continue;
       }
 
-      for (int i = 0; i < 2; ++i) {
-        if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) {
-          if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R)) {
-            Changed = true;
-            // We would like to start over since some instructions are deleted
-            // and the iterator may become invalid value.
-            it = BB->begin();
-            e = BB->end();
-            break;
-          }
+      for (int I = 0; I < 2; ++I) {
+        if (vectorizeRootInstruction(nullptr, CI->getOperand(I), BB, R, TTI)) {
+          Changed = true;
+          // We would like to start over since some instructions are deleted
+          // and the iterator may become invalid value.
+          it = BB->begin();
+          e = BB->end();
+          break;
        }
      }
       continue;
@@ -4757,7 +5058,8 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) {
       SetVector<Value *> Candidates(GEPList.begin(), GEPList.end());
 
       // Some of the candidates may have already been vectorized after we
-      // initially collected them. If so, the WeakVHs will have nullified the
+      // initially collected them. If so, the WeakTrackingVHs will have
+      // nullified the
       // values, so remove them from the set of candidates.
       Candidates.remove(nullptr);
@@ -4847,6 +5149,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
 INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
 
 namespace llvm {
diff --git a/contrib/llvm/lib/Transforms/Vectorize/Vectorize.cpp b/contrib/llvm/lib/Transforms/Vectorize/Vectorize.cpp
index 28e0b2e..fb2f509 100644
--- a/contrib/llvm/lib/Transforms/Vectorize/Vectorize.cpp
+++ b/contrib/llvm/lib/Transforms/Vectorize/Vectorize.cpp
@@ -17,16 +17,15 @@
 #include "llvm-c/Initialization.h"
 #include "llvm-c/Transforms/Vectorize.h"
 #include "llvm/Analysis/Passes.h"
+#include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/InitializePasses.h"
-#include "llvm/IR/LegacyPassManager.h"
 
 using namespace llvm;
 
 /// initializeVectorizationPasses - Initialize all passes linked into the
 /// Vectorization library.
 void llvm::initializeVectorization(PassRegistry &Registry) {
-  initializeBBVectorizePass(Registry);
   initializeLoopVectorizePass(Registry);
   initializeSLPVectorizerPass(Registry);
   initializeLoadStoreVectorizerPass(Registry);
@@ -36,8 +35,8 @@ void LLVMInitializeVectorization(LLVMPassRegistryRef R) {
   initializeVectorization(*unwrap(R));
 }
 
+// DEPRECATED: Remove after the LLVM 5 release.
 void LLVMAddBBVectorizePass(LLVMPassManagerRef PM) {
-  unwrap(PM)->add(createBBVectorizePass());
 }
 
 void LLVMAddLoopVectorizePass(LLVMPassManagerRef PM) {