author     ed <ed@FreeBSD.org>  2009-06-02 17:52:33 +0000
committer  ed <ed@FreeBSD.org>  2009-06-02 17:52:33 +0000
commit     3277b69d734b9c90b44ebde4ede005717e2c3b2e
tree       64ba909838c23261cace781ece27d106134ea451 /lib/Transforms/Scalar
Import LLVM, at r72732.
Diffstat (limited to 'lib/Transforms/Scalar')
33 files changed, 39319 insertions, 0 deletions
diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp new file mode 100644 index 0000000..9c55f66 --- /dev/null +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -0,0 +1,98 @@ +//===- DCE.cpp - Code to perform dead code elimination --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Aggressive Dead Code Elimination pass. This pass +// optimistically assumes that all instructions are dead until proven otherwise, +// allowing it to eliminate dead computations that other DCE passes do not +// catch, particularly involving loop computations. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "adce" +#include "llvm/Transforms/Scalar.h" +#include "llvm/BasicBlock.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" + +using namespace llvm; + +STATISTIC(NumRemoved, "Number of instructions removed"); + +namespace { + struct VISIBILITY_HIDDEN ADCE : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + ADCE() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function& F); + + virtual void getAnalysisUsage(AnalysisUsage& AU) const { + AU.setPreservesCFG(); + } + + }; +} + +char ADCE::ID = 0; +static RegisterPass<ADCE> X("adce", "Aggressive Dead Code Elimination"); + +bool ADCE::runOnFunction(Function& F) { + SmallPtrSet<Instruction*, 128> alive; + SmallVector<Instruction*, 128> worklist; + + // Collect the set of "root" instructions that are known live. + for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) + if (isa<TerminatorInst>(I.getInstructionIterator()) || + isa<DbgInfoIntrinsic>(I.getInstructionIterator()) || + I->mayHaveSideEffects()) { + alive.insert(I.getInstructionIterator()); + worklist.push_back(I.getInstructionIterator()); + } + + // Propagate liveness backwards to operands. + while (!worklist.empty()) { + Instruction* curr = worklist.back(); + worklist.pop_back(); + + for (Instruction::op_iterator OI = curr->op_begin(), OE = curr->op_end(); + OI != OE; ++OI) + if (Instruction* Inst = dyn_cast<Instruction>(OI)) + if (alive.insert(Inst)) + worklist.push_back(Inst); + } + + // The inverse of the live set is the dead set. These are those instructions + // which have no side effects and do not influence the control flow or return + // value of the function, and may therefore be deleted safely. + // NOTE: We reuse the worklist vector here for memory efficiency. 
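+ // Dead instructions can still reference one another (a dead phi cycle in
+ // a loop, for instance), so drop all operand references first; the erase
+ // loop that follows can then delete them in any order.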
+ for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) + if (!alive.count(I.getInstructionIterator())) { + worklist.push_back(I.getInstructionIterator()); + I->dropAllReferences(); + } + + for (SmallVector<Instruction*, 1024>::iterator I = worklist.begin(), + E = worklist.end(); I != E; ++I) { + NumRemoved++; + (*I)->eraseFromParent(); + } + + return !worklist.empty(); +} + +FunctionPass *llvm::createAggressiveDCEPass() { + return new ADCE(); +} diff --git a/lib/Transforms/Scalar/BasicBlockPlacement.cpp b/lib/Transforms/Scalar/BasicBlockPlacement.cpp new file mode 100644 index 0000000..fb9b880 --- /dev/null +++ b/lib/Transforms/Scalar/BasicBlockPlacement.cpp @@ -0,0 +1,148 @@ +//===-- BasicBlockPlacement.cpp - Basic Block Code Layout optimization ----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a very simple profile guided basic block placement +// algorithm. The idea is to put frequently executed blocks together at the +// start of the function, and hopefully increase the number of fall-through +// conditional branches. If there is no profile information for a particular +// function, this pass basically orders blocks in depth-first order +// +// The algorithm implemented here is basically "Algo1" from "Profile Guided Code +// Positioning" by Pettis and Hansen, except that it uses basic block counts +// instead of edge counts. This should be improved in many ways, but is very +// simple for now. +// +// Basically we "place" the entry block, then loop over all successors in a DFO, +// placing the most frequently executed successor until we run out of blocks. I +// told you this was _extremely_ simplistic. :) This is also much slower than it +// could be. When it becomes important, this pass will be rewritten to use a +// better algorithm, and then we can worry about efficiency. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "block-placement" +#include "llvm/Analysis/ProfileInfo.h" +#include "llvm/Function.h" +#include "llvm/Pass.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Transforms/Scalar.h" +#include <set> +using namespace llvm; + +STATISTIC(NumMoved, "Number of basic blocks moved"); + +namespace { + struct VISIBILITY_HIDDEN BlockPlacement : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + BlockPlacement() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired<ProfileInfo>(); + //AU.addPreserved<ProfileInfo>(); // Does this work? + } + private: + /// PI - The profile information that is guiding us. + /// + ProfileInfo *PI; + + /// NumMovedBlocks - Every time we move a block, increment this counter. + /// + unsigned NumMovedBlocks; + + /// PlacedBlocks - Every time we place a block, remember it so we don't get + /// into infinite loops. + std::set<BasicBlock*> PlacedBlocks; + + /// InsertPos - This an iterator to the next place we want to insert a + /// block. + Function::iterator InsertPos; + + /// PlaceBlocks - Recursively place the specified blocks and any unplaced + /// successors. 
+ void PlaceBlocks(BasicBlock *BB); + }; +} + +char BlockPlacement::ID = 0; +static RegisterPass<BlockPlacement> +X("block-placement", "Profile Guided Basic Block Placement"); + +FunctionPass *llvm::createBlockPlacementPass() { return new BlockPlacement(); } + +bool BlockPlacement::runOnFunction(Function &F) { + PI = &getAnalysis<ProfileInfo>(); + + NumMovedBlocks = 0; + InsertPos = F.begin(); + + // Recursively place all blocks. + PlaceBlocks(F.begin()); + + PlacedBlocks.clear(); + NumMoved += NumMovedBlocks; + return NumMovedBlocks != 0; +} + + +/// PlaceBlocks - Recursively place the specified blocks and any unplaced +/// successors. +void BlockPlacement::PlaceBlocks(BasicBlock *BB) { + assert(!PlacedBlocks.count(BB) && "Already placed this block!"); + PlacedBlocks.insert(BB); + + // Place the specified block. + if (&*InsertPos != BB) { + // Use splice to move the block into the right place. This avoids having to + // remove the block from the function then readd it, which causes a bunch of + // symbol table traffic that is entirely pointless. + Function::BasicBlockListType &Blocks = BB->getParent()->getBasicBlockList(); + Blocks.splice(InsertPos, Blocks, BB); + + ++NumMovedBlocks; + } else { + // This block is already in the right place, we don't have to do anything. + ++InsertPos; + } + + // Keep placing successors until we run out of ones to place. Note that this + // loop is very inefficient (N^2) for blocks with many successors, like switch + // statements. FIXME! + while (1) { + // Okay, now place any unplaced successors. + succ_iterator SI = succ_begin(BB), E = succ_end(BB); + + // Scan for the first unplaced successor. + for (; SI != E && PlacedBlocks.count(*SI); ++SI) + /*empty*/; + if (SI == E) return; // No more successors to place. + + unsigned MaxExecutionCount = PI->getExecutionCount(*SI); + BasicBlock *MaxSuccessor = *SI; + + // Scan for more frequently executed successors + for (; SI != E; ++SI) + if (!PlacedBlocks.count(*SI)) { + unsigned Count = PI->getExecutionCount(*SI); + if (Count > MaxExecutionCount || + // Prefer to not disturb the code. + (Count == MaxExecutionCount && *SI == &*InsertPos)) { + MaxExecutionCount = Count; + MaxSuccessor = *SI; + } + } + + // Now that we picked the maximally executed successor, place it. 
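+ // The recursion lays out the hottest path first; as it unwinds, the
+ // enclosing loop re-scans this block for any remaining, colder unplaced
+ // successors.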
+ PlaceBlocks(MaxSuccessor); + } +} diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt new file mode 100644 index 0000000..7a7c48b --- /dev/null +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -0,0 +1,33 @@ +add_llvm_library(LLVMScalarOpts + ADCE.cpp + BasicBlockPlacement.cpp + CodeGenPrepare.cpp + CondPropagate.cpp + ConstantProp.cpp + DCE.cpp + DeadStoreElimination.cpp + GVN.cpp + GVNPRE.cpp + IndVarSimplify.cpp + InstructionCombining.cpp + JumpThreading.cpp + LICM.cpp + LoopDeletion.cpp + LoopIndexSplit.cpp + LoopRotation.cpp + LoopStrengthReduce.cpp + LoopUnroll.cpp + LoopUnswitch.cpp + MemCpyOptimizer.cpp + PredicateSimplifier.cpp + Reassociate.cpp + Reg2Mem.cpp + SCCP.cpp + Scalar.cpp + ScalarReplAggregates.cpp + SimplifyCFGPass.cpp + SimplifyHalfPowrLibCalls.cpp + SimplifyLibCalls.cpp + TailDuplication.cpp + TailRecursionElimination.cpp + ) diff --git a/lib/Transforms/Scalar/CodeGenPrepare.cpp b/lib/Transforms/Scalar/CodeGenPrepare.cpp new file mode 100644 index 0000000..342b1e5 --- /dev/null +++ b/lib/Transforms/Scalar/CodeGenPrepare.cpp @@ -0,0 +1,873 @@ +//===- CodeGenPrepare.cpp - Prepare a function for code generation --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass munges the code in the input function to better prepare it for +// SelectionDAG-based code generation. This works around limitations in it's +// basic-block-at-a-time approach. It should eventually be removed. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "codegenprepare" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/InlineAsm.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Target/TargetAsmInfo.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Target/TargetLowering.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Utils/AddrModeMatcher.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/PatternMatch.h" +using namespace llvm; +using namespace llvm::PatternMatch; + +static cl::opt<bool> FactorCommonPreds("split-critical-paths-tweak", + cl::init(false), cl::Hidden); + +namespace { + class VISIBILITY_HIDDEN CodeGenPrepare : public FunctionPass { + /// TLI - Keep a pointer of a TargetLowering to consult for determining + /// transformation profitability. + const TargetLowering *TLI; + + /// BackEdges - Keep a set of all the loop back edges. 
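+ /// (SplitEdgeNicely consults this set below so that loop backedges are
+ /// never split.)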
+ /// + SmallSet<std::pair<const BasicBlock*, const BasicBlock*>, 8> BackEdges; + public: + static char ID; // Pass identification, replacement for typeid + explicit CodeGenPrepare(const TargetLowering *tli = 0) + : FunctionPass(&ID), TLI(tli) {} + bool runOnFunction(Function &F); + + private: + bool EliminateMostlyEmptyBlocks(Function &F); + bool CanMergeBlocks(const BasicBlock *BB, const BasicBlock *DestBB) const; + void EliminateMostlyEmptyBlock(BasicBlock *BB); + bool OptimizeBlock(BasicBlock &BB); + bool OptimizeMemoryInst(Instruction *I, Value *Addr, const Type *AccessTy, + DenseMap<Value*,Value*> &SunkAddrs); + bool OptimizeInlineAsmInst(Instruction *I, CallSite CS, + DenseMap<Value*,Value*> &SunkAddrs); + bool OptimizeExtUses(Instruction *I); + void findLoopBackEdges(const Function &F); + }; +} + +char CodeGenPrepare::ID = 0; +static RegisterPass<CodeGenPrepare> X("codegenprepare", + "Optimize for code generation"); + +FunctionPass *llvm::createCodeGenPreparePass(const TargetLowering *TLI) { + return new CodeGenPrepare(TLI); +} + +/// findLoopBackEdges - Do a DFS walk to find loop back edges. +/// +void CodeGenPrepare::findLoopBackEdges(const Function &F) { + SmallVector<std::pair<const BasicBlock*,const BasicBlock*>, 32> Edges; + FindFunctionBackedges(F, Edges); + + BackEdges.insert(Edges.begin(), Edges.end()); +} + + +bool CodeGenPrepare::runOnFunction(Function &F) { + bool EverMadeChange = false; + + // First pass, eliminate blocks that contain only PHI nodes and an + // unconditional branch. + EverMadeChange |= EliminateMostlyEmptyBlocks(F); + + // Now find loop back edges. + findLoopBackEdges(F); + + bool MadeChange = true; + while (MadeChange) { + MadeChange = false; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + MadeChange |= OptimizeBlock(*BB); + EverMadeChange |= MadeChange; + } + return EverMadeChange; +} + +/// EliminateMostlyEmptyBlocks - eliminate blocks that contain only PHI nodes, +/// debug info directives, and an unconditional branch. Passes before isel +/// (e.g. LSR/loopsimplify) often split edges in ways that are non-optimal for +/// isel. Start by eliminating these blocks so we can split them the way we +/// want them. +bool CodeGenPrepare::EliminateMostlyEmptyBlocks(Function &F) { + bool MadeChange = false; + // Note that this intentionally skips the entry block. + for (Function::iterator I = ++F.begin(), E = F.end(); I != E; ) { + BasicBlock *BB = I++; + + // If this block doesn't end with an uncond branch, ignore it. + BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BI || !BI->isUnconditional()) + continue; + + // If the instruction before the branch (skipping debug info) isn't a phi + // node, then other stuff is happening here. + BasicBlock::iterator BBI = BI; + if (BBI != BB->begin()) { + --BBI; + while (isa<DbgInfoIntrinsic>(BBI)) { + if (BBI == BB->begin()) + break; + --BBI; + } + if (!isa<DbgInfoIntrinsic>(BBI) && !isa<PHINode>(BBI)) + continue; + } + + // Do not break infinite loops. + BasicBlock *DestBB = BI->getSuccessor(0); + if (DestBB == BB) + continue; + + if (!CanMergeBlocks(BB, DestBB)) + continue; + + EliminateMostlyEmptyBlock(BB); + MadeChange = true; + } + return MadeChange; +} + +/// CanMergeBlocks - Return true if we can merge BB into DestBB if there is a +/// single uncond branch between them, and BB contains no other non-phi +/// instructions. 
+bool CodeGenPrepare::CanMergeBlocks(const BasicBlock *BB, + const BasicBlock *DestBB) const { + // We only want to eliminate blocks whose phi nodes are used by phi nodes in + // the successor. If there are more complex condition (e.g. preheaders), + // don't mess around with them. + BasicBlock::const_iterator BBI = BB->begin(); + while (const PHINode *PN = dyn_cast<PHINode>(BBI++)) { + for (Value::use_const_iterator UI = PN->use_begin(), E = PN->use_end(); + UI != E; ++UI) { + const Instruction *User = cast<Instruction>(*UI); + if (User->getParent() != DestBB || !isa<PHINode>(User)) + return false; + // If User is inside DestBB block and it is a PHINode then check + // incoming value. If incoming value is not from BB then this is + // a complex condition (e.g. preheaders) we want to avoid here. + if (User->getParent() == DestBB) { + if (const PHINode *UPN = dyn_cast<PHINode>(User)) + for (unsigned I = 0, E = UPN->getNumIncomingValues(); I != E; ++I) { + Instruction *Insn = dyn_cast<Instruction>(UPN->getIncomingValue(I)); + if (Insn && Insn->getParent() == BB && + Insn->getParent() != UPN->getIncomingBlock(I)) + return false; + } + } + } + } + + // If BB and DestBB contain any common predecessors, then the phi nodes in BB + // and DestBB may have conflicting incoming values for the block. If so, we + // can't merge the block. + const PHINode *DestBBPN = dyn_cast<PHINode>(DestBB->begin()); + if (!DestBBPN) return true; // no conflict. + + // Collect the preds of BB. + SmallPtrSet<const BasicBlock*, 16> BBPreds; + if (const PHINode *BBPN = dyn_cast<PHINode>(BB->begin())) { + // It is faster to get preds from a PHI than with pred_iterator. + for (unsigned i = 0, e = BBPN->getNumIncomingValues(); i != e; ++i) + BBPreds.insert(BBPN->getIncomingBlock(i)); + } else { + BBPreds.insert(pred_begin(BB), pred_end(BB)); + } + + // Walk the preds of DestBB. + for (unsigned i = 0, e = DestBBPN->getNumIncomingValues(); i != e; ++i) { + BasicBlock *Pred = DestBBPN->getIncomingBlock(i); + if (BBPreds.count(Pred)) { // Common predecessor? + BBI = DestBB->begin(); + while (const PHINode *PN = dyn_cast<PHINode>(BBI++)) { + const Value *V1 = PN->getIncomingValueForBlock(Pred); + const Value *V2 = PN->getIncomingValueForBlock(BB); + + // If V2 is a phi node in BB, look up what the mapped value will be. + if (const PHINode *V2PN = dyn_cast<PHINode>(V2)) + if (V2PN->getParent() == BB) + V2 = V2PN->getIncomingValueForBlock(Pred); + + // If there is a conflict, bail out. + if (V1 != V2) return false; + } + } + } + + return true; +} + + +/// EliminateMostlyEmptyBlock - Eliminate a basic block that have only phi's and +/// an unconditional branch in it. +void CodeGenPrepare::EliminateMostlyEmptyBlock(BasicBlock *BB) { + BranchInst *BI = cast<BranchInst>(BB->getTerminator()); + BasicBlock *DestBB = BI->getSuccessor(0); + + DOUT << "MERGING MOSTLY EMPTY BLOCKS - BEFORE:\n" << *BB << *DestBB; + + // If the destination block has a single pred, then this is a trivial edge, + // just collapse it. + if (BasicBlock *SinglePred = DestBB->getSinglePredecessor()) { + if (SinglePred != DestBB) { + // Remember if SinglePred was the entry block of the function. If so, we + // will need to move BB back to the entry position. 
+ bool isEntry = SinglePred == &SinglePred->getParent()->getEntryBlock(); + MergeBasicBlockIntoOnlyPred(DestBB); + + if (isEntry && BB != &BB->getParent()->getEntryBlock()) + BB->moveBefore(&BB->getParent()->getEntryBlock()); + + DOUT << "AFTER:\n" << *DestBB << "\n\n\n"; + return; + } + } + + // Otherwise, we have multiple predecessors of BB. Update the PHIs in DestBB + // to handle the new incoming edges it is about to have. + PHINode *PN; + for (BasicBlock::iterator BBI = DestBB->begin(); + (PN = dyn_cast<PHINode>(BBI)); ++BBI) { + // Remove the incoming value for BB, and remember it. + Value *InVal = PN->removeIncomingValue(BB, false); + + // Two options: either the InVal is a phi node defined in BB or it is some + // value that dominates BB. + PHINode *InValPhi = dyn_cast<PHINode>(InVal); + if (InValPhi && InValPhi->getParent() == BB) { + // Add all of the input values of the input PHI as inputs of this phi. + for (unsigned i = 0, e = InValPhi->getNumIncomingValues(); i != e; ++i) + PN->addIncoming(InValPhi->getIncomingValue(i), + InValPhi->getIncomingBlock(i)); + } else { + // Otherwise, add one instance of the dominating value for each edge that + // we will be adding. + if (PHINode *BBPN = dyn_cast<PHINode>(BB->begin())) { + for (unsigned i = 0, e = BBPN->getNumIncomingValues(); i != e; ++i) + PN->addIncoming(InVal, BBPN->getIncomingBlock(i)); + } else { + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + PN->addIncoming(InVal, *PI); + } + } + } + + // The PHIs are now updated, change everything that refers to BB to use + // DestBB and remove BB. + BB->replaceAllUsesWith(DestBB); + BB->eraseFromParent(); + + DOUT << "AFTER:\n" << *DestBB << "\n\n\n"; +} + + +/// SplitEdgeNicely - Split the critical edge from TI to its specified +/// successor if it will improve codegen. We only do this if the successor has +/// phi nodes (otherwise critical edges are ok). If there is already another +/// predecessor of the succ that is empty (and thus has no phi nodes), use it +/// instead of introducing a new block. +static void SplitEdgeNicely(TerminatorInst *TI, unsigned SuccNum, + SmallSet<std::pair<const BasicBlock*, + const BasicBlock*>, 8> &BackEdges, + Pass *P) { + BasicBlock *TIBB = TI->getParent(); + BasicBlock *Dest = TI->getSuccessor(SuccNum); + assert(isa<PHINode>(Dest->begin()) && + "This should only be called if Dest has a PHI!"); + + // Do not split edges to EH landing pads. + if (InvokeInst *Invoke = dyn_cast<InvokeInst>(TI)) { + if (Invoke->getSuccessor(1) == Dest) + return; + } + + // As a hack, never split backedges of loops. Even though the copy for any + // PHIs inserted on the backedge would be dead for exits from the loop, we + // assume that the cost of *splitting* the backedge would be too high. + if (BackEdges.count(std::make_pair(TIBB, Dest))) + return; + + if (!FactorCommonPreds) { + /// TIPHIValues - This array is lazily computed to determine the values of + /// PHIs in Dest that TI would provide. + SmallVector<Value*, 32> TIPHIValues; + + // Check to see if Dest has any blocks that can be used as a split edge for + // this terminator. + for (pred_iterator PI = pred_begin(Dest), E = pred_end(Dest); PI != E; ++PI) { + BasicBlock *Pred = *PI; + // To be usable, the pred has to end with an uncond branch to the dest. + BranchInst *PredBr = dyn_cast<BranchInst>(Pred->getTerminator()); + if (!PredBr || !PredBr->isUnconditional()) + continue; + // Must be empty other than the branch and debug info. 
+ BasicBlock::iterator I = Pred->begin(); + while (isa<DbgInfoIntrinsic>(I)) + I++; + if (dyn_cast<Instruction>(I) != PredBr) + continue; + // Cannot be the entry block; its label does not get emitted. + if (Pred == &(Dest->getParent()->getEntryBlock())) + continue; + + // Finally, since we know that Dest has phi nodes in it, we have to make + // sure that jumping to Pred will have the same effect as going to Dest in + // terms of PHI values. + PHINode *PN; + unsigned PHINo = 0; + bool FoundMatch = true; + for (BasicBlock::iterator I = Dest->begin(); + (PN = dyn_cast<PHINode>(I)); ++I, ++PHINo) { + if (PHINo == TIPHIValues.size()) + TIPHIValues.push_back(PN->getIncomingValueForBlock(TIBB)); + + // If the PHI entry doesn't work, we can't use this pred. + if (TIPHIValues[PHINo] != PN->getIncomingValueForBlock(Pred)) { + FoundMatch = false; + break; + } + } + + // If we found a workable predecessor, change TI to branch to Succ. + if (FoundMatch) { + Dest->removePredecessor(TIBB); + TI->setSuccessor(SuccNum, Pred); + return; + } + } + + SplitCriticalEdge(TI, SuccNum, P, true); + return; + } + + PHINode *PN; + SmallVector<Value*, 8> TIPHIValues; + for (BasicBlock::iterator I = Dest->begin(); + (PN = dyn_cast<PHINode>(I)); ++I) + TIPHIValues.push_back(PN->getIncomingValueForBlock(TIBB)); + + SmallVector<BasicBlock*, 8> IdenticalPreds; + for (pred_iterator PI = pred_begin(Dest), E = pred_end(Dest); PI != E; ++PI) { + BasicBlock *Pred = *PI; + if (BackEdges.count(std::make_pair(Pred, Dest))) + continue; + if (PI == TIBB) + IdenticalPreds.push_back(Pred); + else { + bool Identical = true; + unsigned PHINo = 0; + for (BasicBlock::iterator I = Dest->begin(); + (PN = dyn_cast<PHINode>(I)); ++I, ++PHINo) + if (TIPHIValues[PHINo] != PN->getIncomingValueForBlock(Pred)) { + Identical = false; + break; + } + if (Identical) + IdenticalPreds.push_back(Pred); + } + } + + assert(!IdenticalPreds.empty()); + SplitBlockPredecessors(Dest, &IdenticalPreds[0], IdenticalPreds.size(), + ".critedge", P); +} + + +/// OptimizeNoopCopyExpression - If the specified cast instruction is a noop +/// copy (e.g. it's casting from one pointer type to another, int->uint, or +/// int->sbyte on PPC), sink it into user blocks to reduce the number of virtual +/// registers that must be created and coalesced. +/// +/// Return true if any changes are made. +/// +static bool OptimizeNoopCopyExpression(CastInst *CI, const TargetLowering &TLI){ + // If this is a noop copy, + MVT SrcVT = TLI.getValueType(CI->getOperand(0)->getType()); + MVT DstVT = TLI.getValueType(CI->getType()); + + // This is an fp<->int conversion? + if (SrcVT.isInteger() != DstVT.isInteger()) + return false; + + // If this is an extension, it will be a zero or sign extension, which + // isn't a noop. + if (SrcVT.bitsLT(DstVT)) return false; + + // If these values will be promoted, find out what they will be promoted + // to. This helps us consider truncates on PPC as noop copies when they + // are. + if (TLI.getTypeAction(SrcVT) == TargetLowering::Promote) + SrcVT = TLI.getTypeToTransformTo(SrcVT); + if (TLI.getTypeAction(DstVT) == TargetLowering::Promote) + DstVT = TLI.getTypeToTransformTo(DstVT); + + // If, after promotion, these are the same types, this is a noop copy. + if (SrcVT != DstVT) + return false; + + BasicBlock *DefBB = CI->getParent(); + + /// InsertedCasts - Only insert a cast in each block once. 
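+ /// (Keyed by the block containing the use, so several uses in one block
+ /// share a single sunk copy of the cast.)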
+ DenseMap<BasicBlock*, CastInst*> InsertedCasts; + + bool MadeChange = false; + for (Value::use_iterator UI = CI->use_begin(), E = CI->use_end(); + UI != E; ) { + Use &TheUse = UI.getUse(); + Instruction *User = cast<Instruction>(*UI); + + // Figure out which BB this cast is used in. For PHI's this is the + // appropriate predecessor block. + BasicBlock *UserBB = User->getParent(); + if (PHINode *PN = dyn_cast<PHINode>(User)) { + UserBB = PN->getIncomingBlock(UI); + } + + // Preincrement use iterator so we don't invalidate it. + ++UI; + + // If this user is in the same block as the cast, don't change the cast. + if (UserBB == DefBB) continue; + + // If we have already inserted a cast into this block, use it. + CastInst *&InsertedCast = InsertedCasts[UserBB]; + + if (!InsertedCast) { + BasicBlock::iterator InsertPt = UserBB->getFirstNonPHI(); + + InsertedCast = + CastInst::Create(CI->getOpcode(), CI->getOperand(0), CI->getType(), "", + InsertPt); + MadeChange = true; + } + + // Replace a use of the cast with a use of the new cast. + TheUse = InsertedCast; + } + + // If we removed all uses, nuke the cast. + if (CI->use_empty()) { + CI->eraseFromParent(); + MadeChange = true; + } + + return MadeChange; +} + +/// OptimizeCmpExpression - sink the given CmpInst into user blocks to reduce +/// the number of virtual registers that must be created and coalesced. This is +/// a clear win except on targets with multiple condition code registers +/// (PowerPC), where it might lose; some adjustment may be wanted there. +/// +/// Return true if any changes are made. +static bool OptimizeCmpExpression(CmpInst *CI) { + BasicBlock *DefBB = CI->getParent(); + + /// InsertedCmp - Only insert a cmp in each block once. + DenseMap<BasicBlock*, CmpInst*> InsertedCmps; + + bool MadeChange = false; + for (Value::use_iterator UI = CI->use_begin(), E = CI->use_end(); + UI != E; ) { + Use &TheUse = UI.getUse(); + Instruction *User = cast<Instruction>(*UI); + + // Preincrement use iterator so we don't invalidate it. + ++UI; + + // Don't bother for PHI nodes. + if (isa<PHINode>(User)) + continue; + + // Figure out which BB this cmp is used in. + BasicBlock *UserBB = User->getParent(); + + // If this user is in the same block as the cmp, don't change the cmp. + if (UserBB == DefBB) continue; + + // If we have already inserted a cmp into this block, use it. + CmpInst *&InsertedCmp = InsertedCmps[UserBB]; + + if (!InsertedCmp) { + BasicBlock::iterator InsertPt = UserBB->getFirstNonPHI(); + + InsertedCmp = + CmpInst::Create(CI->getOpcode(), CI->getPredicate(), CI->getOperand(0), + CI->getOperand(1), "", InsertPt); + MadeChange = true; + } + + // Replace a use of the cmp with a use of the new cmp. + TheUse = InsertedCmp; + } + + // If we removed all uses, nuke the cmp. + if (CI->use_empty()) + CI->eraseFromParent(); + + return MadeChange; +} + +//===----------------------------------------------------------------------===// +// Memory Optimization +//===----------------------------------------------------------------------===// + +/// IsNonLocalValue - Return true if the specified values are defined in a +/// different basic block than BB. +static bool IsNonLocalValue(Value *V, BasicBlock *BB) { + if (Instruction *I = dyn_cast<Instruction>(V)) + return I->getParent() != BB; + return false; +} + +/// OptimizeMemoryInst - Load and Store Instructions have often have +/// addressing modes that can do significant amounts of computation. 
As such, +/// instruction selection will try to get the load or store to do as much +/// computation as possible for the program. The problem is that isel can only +/// see within a single block. As such, we sink as much legal addressing mode +/// stuff into the block as possible. +/// +/// This method is used to optimize both load/store and inline asms with memory +/// operands. +bool CodeGenPrepare::OptimizeMemoryInst(Instruction *MemoryInst, Value *Addr, + const Type *AccessTy, + DenseMap<Value*,Value*> &SunkAddrs) { + // Figure out what addressing mode will be built up for this operation. + SmallVector<Instruction*, 16> AddrModeInsts; + ExtAddrMode AddrMode = AddressingModeMatcher::Match(Addr, AccessTy,MemoryInst, + AddrModeInsts, *TLI); + + // Check to see if any of the instructions supersumed by this addr mode are + // non-local to I's BB. + bool AnyNonLocal = false; + for (unsigned i = 0, e = AddrModeInsts.size(); i != e; ++i) { + if (IsNonLocalValue(AddrModeInsts[i], MemoryInst->getParent())) { + AnyNonLocal = true; + break; + } + } + + // If all the instructions matched are already in this BB, don't do anything. + if (!AnyNonLocal) { + DEBUG(cerr << "CGP: Found local addrmode: " << AddrMode << "\n"); + return false; + } + + // Insert this computation right after this user. Since our caller is + // scanning from the top of the BB to the bottom, reuse of the expr are + // guaranteed to happen later. + BasicBlock::iterator InsertPt = MemoryInst; + + // Now that we determined the addressing expression we want to use and know + // that we have to sink it into this block. Check to see if we have already + // done this for some other load/store instr in this block. If so, reuse the + // computation. + Value *&SunkAddr = SunkAddrs[Addr]; + if (SunkAddr) { + DEBUG(cerr << "CGP: Reusing nonlocal addrmode: " << AddrMode << " for " + << *MemoryInst); + if (SunkAddr->getType() != Addr->getType()) + SunkAddr = new BitCastInst(SunkAddr, Addr->getType(), "tmp", InsertPt); + } else { + DEBUG(cerr << "CGP: SINKING nonlocal addrmode: " << AddrMode << " for " + << *MemoryInst); + const Type *IntPtrTy = TLI->getTargetData()->getIntPtrType(); + + Value *Result = 0; + // Start with the scale value. + if (AddrMode.Scale) { + Value *V = AddrMode.ScaledReg; + if (V->getType() == IntPtrTy) { + // done. + } else if (isa<PointerType>(V->getType())) { + V = new PtrToIntInst(V, IntPtrTy, "sunkaddr", InsertPt); + } else if (cast<IntegerType>(IntPtrTy)->getBitWidth() < + cast<IntegerType>(V->getType())->getBitWidth()) { + V = new TruncInst(V, IntPtrTy, "sunkaddr", InsertPt); + } else { + V = new SExtInst(V, IntPtrTy, "sunkaddr", InsertPt); + } + if (AddrMode.Scale != 1) + V = BinaryOperator::CreateMul(V, ConstantInt::get(IntPtrTy, + AddrMode.Scale), + "sunkaddr", InsertPt); + Result = V; + } + + // Add in the base register. + if (AddrMode.BaseReg) { + Value *V = AddrMode.BaseReg; + if (V->getType() != IntPtrTy) + V = new PtrToIntInst(V, IntPtrTy, "sunkaddr", InsertPt); + if (Result) + Result = BinaryOperator::CreateAdd(Result, V, "sunkaddr", InsertPt); + else + Result = V; + } + + // Add in the BaseGV if present. + if (AddrMode.BaseGV) { + Value *V = new PtrToIntInst(AddrMode.BaseGV, IntPtrTy, "sunkaddr", + InsertPt); + if (Result) + Result = BinaryOperator::CreateAdd(Result, V, "sunkaddr", InsertPt); + else + Result = V; + } + + // Add in the Base Offset if present. 
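+ // (At this point Result holds Scale*ScaledReg + BaseReg + BaseGV; adding
+ // the constant offset completes the address before it is cast back to
+ // the original pointer type.)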
+ if (AddrMode.BaseOffs) { + Value *V = ConstantInt::get(IntPtrTy, AddrMode.BaseOffs); + if (Result) + Result = BinaryOperator::CreateAdd(Result, V, "sunkaddr", InsertPt); + else + Result = V; + } + + if (Result == 0) + SunkAddr = Constant::getNullValue(Addr->getType()); + else + SunkAddr = new IntToPtrInst(Result, Addr->getType(), "sunkaddr",InsertPt); + } + + MemoryInst->replaceUsesOfWith(Addr, SunkAddr); + + if (Addr->use_empty()) + RecursivelyDeleteTriviallyDeadInstructions(Addr); + return true; +} + +/// OptimizeInlineAsmInst - If there are any memory operands, use +/// OptimizeMemoryInst to sink their address computing into the block when +/// possible / profitable. +bool CodeGenPrepare::OptimizeInlineAsmInst(Instruction *I, CallSite CS, + DenseMap<Value*,Value*> &SunkAddrs) { + bool MadeChange = false; + InlineAsm *IA = cast<InlineAsm>(CS.getCalledValue()); + + // Do a prepass over the constraints, canonicalizing them, and building up the + // ConstraintOperands list. + std::vector<InlineAsm::ConstraintInfo> + ConstraintInfos = IA->ParseConstraints(); + + /// ConstraintOperands - Information about all of the constraints. + std::vector<TargetLowering::AsmOperandInfo> ConstraintOperands; + unsigned ArgNo = 0; // ArgNo - The argument of the CallInst. + for (unsigned i = 0, e = ConstraintInfos.size(); i != e; ++i) { + ConstraintOperands. + push_back(TargetLowering::AsmOperandInfo(ConstraintInfos[i])); + TargetLowering::AsmOperandInfo &OpInfo = ConstraintOperands.back(); + + // Compute the value type for each operand. + switch (OpInfo.Type) { + case InlineAsm::isOutput: + if (OpInfo.isIndirect) + OpInfo.CallOperandVal = CS.getArgument(ArgNo++); + break; + case InlineAsm::isInput: + OpInfo.CallOperandVal = CS.getArgument(ArgNo++); + break; + case InlineAsm::isClobber: + // Nothing to do. + break; + } + + // Compute the constraint code and ConstraintType to use. + TLI->ComputeConstraintToUse(OpInfo, SDValue(), + OpInfo.ConstraintType == TargetLowering::C_Memory); + + if (OpInfo.ConstraintType == TargetLowering::C_Memory && + OpInfo.isIndirect) { + Value *OpVal = OpInfo.CallOperandVal; + MadeChange |= OptimizeMemoryInst(I, OpVal, OpVal->getType(), SunkAddrs); + } + } + + return MadeChange; +} + +bool CodeGenPrepare::OptimizeExtUses(Instruction *I) { + BasicBlock *DefBB = I->getParent(); + + // If both result of the {s|z}xt and its source are live out, rewrite all + // other uses of the source with result of extension. + Value *Src = I->getOperand(0); + if (Src->hasOneUse()) + return false; + + // Only do this xform if truncating is free. + if (TLI && !TLI->isTruncateFree(I->getType(), Src->getType())) + return false; + + // Only safe to perform the optimization if the source is also defined in + // this block. + if (!isa<Instruction>(Src) || DefBB != cast<Instruction>(Src)->getParent()) + return false; + + bool DefIsLiveOut = false; + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) { + Instruction *User = cast<Instruction>(*UI); + + // Figure out which BB this ext is used in. + BasicBlock *UserBB = User->getParent(); + if (UserBB == DefBB) continue; + DefIsLiveOut = true; + break; + } + if (!DefIsLiveOut) + return false; + + // Make sure non of the uses are PHI nodes. + for (Value::use_iterator UI = Src->use_begin(), E = Src->use_end(); + UI != E; ++UI) { + Instruction *User = cast<Instruction>(*UI); + BasicBlock *UserBB = User->getParent(); + if (UserBB == DefBB) continue; + // Be conservative. 
We don't want this xform to end up introducing + // reloads just before load / store instructions. + if (isa<PHINode>(User) || isa<LoadInst>(User) || isa<StoreInst>(User)) + return false; + } + + // InsertedTruncs - Only insert one trunc in each block once. + DenseMap<BasicBlock*, Instruction*> InsertedTruncs; + + bool MadeChange = false; + for (Value::use_iterator UI = Src->use_begin(), E = Src->use_end(); + UI != E; ++UI) { + Use &TheUse = UI.getUse(); + Instruction *User = cast<Instruction>(*UI); + + // Figure out which BB this ext is used in. + BasicBlock *UserBB = User->getParent(); + if (UserBB == DefBB) continue; + + // Both src and def are live in this block. Rewrite the use. + Instruction *&InsertedTrunc = InsertedTruncs[UserBB]; + + if (!InsertedTrunc) { + BasicBlock::iterator InsertPt = UserBB->getFirstNonPHI(); + + InsertedTrunc = new TruncInst(I, Src->getType(), "", InsertPt); + } + + // Replace a use of the {s|z}ext source with a use of the result. + TheUse = InsertedTrunc; + + MadeChange = true; + } + + return MadeChange; +} + +// In this pass we look for GEP and cast instructions that are used +// across basic blocks and rewrite them to improve basic-block-at-a-time +// selection. +bool CodeGenPrepare::OptimizeBlock(BasicBlock &BB) { + bool MadeChange = false; + + // Split all critical edges where the dest block has a PHI. + TerminatorInst *BBTI = BB.getTerminator(); + if (BBTI->getNumSuccessors() > 1) { + for (unsigned i = 0, e = BBTI->getNumSuccessors(); i != e; ++i) { + BasicBlock *SuccBB = BBTI->getSuccessor(i); + if (isa<PHINode>(SuccBB->begin()) && isCriticalEdge(BBTI, i, true)) + SplitEdgeNicely(BBTI, i, BackEdges, this); + } + } + + // Keep track of non-local addresses that have been sunk into this block. + // This allows us to avoid inserting duplicate code for blocks with multiple + // load/stores of the same address. + DenseMap<Value*, Value*> SunkAddrs; + + for (BasicBlock::iterator BBI = BB.begin(), E = BB.end(); BBI != E; ) { + Instruction *I = BBI++; + + if (CastInst *CI = dyn_cast<CastInst>(I)) { + // If the source of the cast is a constant, then this should have + // already been constant folded. The only reason NOT to constant fold + // it is if something (e.g. LSR) was careful to place the constant + // evaluation in a block other than then one that uses it (e.g. to hoist + // the address of globals out of a loop). If this is the case, we don't + // want to forward-subst the cast. 
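+ // (For example, LSR may have hoisted a cast of a global's address into a
+ // loop preheader; sinking it back next to its users would undo that.)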
+ if (isa<Constant>(CI->getOperand(0))) + continue; + + bool Change = false; + if (TLI) { + Change = OptimizeNoopCopyExpression(CI, *TLI); + MadeChange |= Change; + } + + if (!Change && (isa<ZExtInst>(I) || isa<SExtInst>(I))) + MadeChange |= OptimizeExtUses(I); + } else if (CmpInst *CI = dyn_cast<CmpInst>(I)) { + MadeChange |= OptimizeCmpExpression(CI); + } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + if (TLI) + MadeChange |= OptimizeMemoryInst(I, I->getOperand(0), LI->getType(), + SunkAddrs); + } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { + if (TLI) + MadeChange |= OptimizeMemoryInst(I, SI->getOperand(1), + SI->getOperand(0)->getType(), + SunkAddrs); + } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) { + if (GEPI->hasAllZeroIndices()) { + /// The GEP operand must be a pointer, so must its result -> BitCast + Instruction *NC = new BitCastInst(GEPI->getOperand(0), GEPI->getType(), + GEPI->getName(), GEPI); + GEPI->replaceAllUsesWith(NC); + GEPI->eraseFromParent(); + MadeChange = true; + BBI = NC; + } + } else if (CallInst *CI = dyn_cast<CallInst>(I)) { + // If we found an inline asm expession, and if the target knows how to + // lower it to normal LLVM code, do so now. + if (TLI && isa<InlineAsm>(CI->getCalledValue())) + if (const TargetAsmInfo *TAI = + TLI->getTargetMachine().getTargetAsmInfo()) { + if (TAI->ExpandInlineAsm(CI)) { + BBI = BB.begin(); + // Avoid processing instructions out of order, which could cause + // reuse before a value is defined. + SunkAddrs.clear(); + } else + // Sink address computing for memory operands into the block. + MadeChange |= OptimizeInlineAsmInst(I, &(*CI), SunkAddrs); + } + } + } + + return MadeChange; +} diff --git a/lib/Transforms/Scalar/CondPropagate.cpp b/lib/Transforms/Scalar/CondPropagate.cpp new file mode 100644 index 0000000..c85d031 --- /dev/null +++ b/lib/Transforms/Scalar/CondPropagate.cpp @@ -0,0 +1,295 @@ +//===-- CondPropagate.cpp - Propagate Conditional Expressions -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass propagates information about conditional expressions through the +// program, allowing it to eliminate conditional branches in some cases. 
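+// For example, if a block's conditional branch tests a phi node with
+// constant incoming values, the predecessors supplying those constants can
+// be rewritten to jump directly to the branch's known destination.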
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "condprop" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Type.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +using namespace llvm; + +STATISTIC(NumBrThread, "Number of CFG edges threaded through branches"); +STATISTIC(NumSwThread, "Number of CFG edges threaded through switches"); + +namespace { + struct VISIBILITY_HIDDEN CondProp : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + CondProp() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(BreakCriticalEdgesID); + //AU.addRequired<DominanceFrontier>(); + } + + private: + bool MadeChange; + SmallVector<BasicBlock *, 4> DeadBlocks; + void SimplifyBlock(BasicBlock *BB); + void SimplifyPredecessors(BranchInst *BI); + void SimplifyPredecessors(SwitchInst *SI); + void RevectorBlockTo(BasicBlock *FromBB, BasicBlock *ToBB); + bool RevectorBlockTo(BasicBlock *FromBB, Value *Cond, BranchInst *BI); + }; +} + +char CondProp::ID = 0; +static RegisterPass<CondProp> X("condprop", "Conditional Propagation"); + +FunctionPass *llvm::createCondPropagationPass() { + return new CondProp(); +} + +bool CondProp::runOnFunction(Function &F) { + bool EverMadeChange = false; + DeadBlocks.clear(); + + // While we are simplifying blocks, keep iterating. + do { + MadeChange = false; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E;) + SimplifyBlock(BB++); + EverMadeChange = EverMadeChange || MadeChange; + } while (MadeChange); + + if (EverMadeChange) { + while (!DeadBlocks.empty()) { + BasicBlock *BB = DeadBlocks.back(); DeadBlocks.pop_back(); + DeleteDeadBlock(BB); + } + } + return EverMadeChange; +} + +void CondProp::SimplifyBlock(BasicBlock *BB) { + if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) { + // If this is a conditional branch based on a phi node that is defined in + // this block, see if we can simplify predecessors of this block. + if (BI->isConditional() && isa<PHINode>(BI->getCondition()) && + cast<PHINode>(BI->getCondition())->getParent() == BB) + SimplifyPredecessors(BI); + + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { + if (isa<PHINode>(SI->getCondition()) && + cast<PHINode>(SI->getCondition())->getParent() == BB) + SimplifyPredecessors(SI); + } + + // If possible, simplify the terminator of this block. + if (ConstantFoldTerminator(BB)) + MadeChange = true; + + // If this block ends with an unconditional branch and the only successor has + // only this block as a predecessor, merge the two blocks together. + if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) + if (BI->isUnconditional() && BI->getSuccessor(0)->getSinglePredecessor() && + BB != BI->getSuccessor(0)) { + BasicBlock *Succ = BI->getSuccessor(0); + + // If Succ has any PHI nodes, they are all single-entry PHI's. Eliminate + // them. + FoldSingleEntryPHINodes(Succ); + + // Remove BI. + BI->eraseFromParent(); + + // Move over all of the instructions. 
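+ // (splice transfers Succ's instructions, terminator included, into BB
+ // without copying or re-creating any of them.)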
+ BB->getInstList().splice(BB->end(), Succ->getInstList()); + + // Any phi nodes that had entries for Succ now have entries from BB. + Succ->replaceAllUsesWith(BB); + + // Succ is now dead, but we cannot delete it without potentially + // invalidating iterators elsewhere. Just insert an unreachable + // instruction in it and delete this block later on. + new UnreachableInst(Succ); + DeadBlocks.push_back(Succ); + MadeChange = true; + } +} + +// SimplifyPredecessors(branches) - We know that BI is a conditional branch +// based on a PHI node defined in this block. If the phi node contains constant +// operands, then the blocks corresponding to those operands can be modified to +// jump directly to the destination instead of going through this block. +void CondProp::SimplifyPredecessors(BranchInst *BI) { + // TODO: We currently only handle the most trival case, where the PHI node has + // one use (the branch), and is the only instruction besides the branch and dbg + // intrinsics in the block. + PHINode *PN = cast<PHINode>(BI->getCondition()); + + if (PN->getNumIncomingValues() == 1) { + // Eliminate single-entry PHI nodes. + FoldSingleEntryPHINodes(PN->getParent()); + return; + } + + + if (!PN->hasOneUse()) return; + + BasicBlock *BB = BI->getParent(); + if (&*BB->begin() != PN) + return; + BasicBlock::iterator BBI = BB->begin(); + BasicBlock::iterator BBE = BB->end(); + while (BBI != BBE && isa<DbgInfoIntrinsic>(++BBI)) /* empty */; + if (&*BBI != BI) + return; + + // Ok, we have this really simple case, walk the PHI operands, looking for + // constants. Walk from the end to remove operands from the end when + // possible, and to avoid invalidating "i". + for (unsigned i = PN->getNumIncomingValues(); i != 0; --i) { + Value *InVal = PN->getIncomingValue(i-1); + if (!RevectorBlockTo(PN->getIncomingBlock(i-1), InVal, BI)) + continue; + + ++NumBrThread; + + // If there were two predecessors before this simplification, or if the + // PHI node contained all the same value except for the one we just + // substituted, the PHI node may be deleted. Don't iterate through it the + // last time. + if (BI->getCondition() != PN) return; + } +} + +// SimplifyPredecessors(switch) - We know that SI is switch based on a PHI node +// defined in this block. If the phi node contains constant operands, then the +// blocks corresponding to those operands can be modified to jump directly to +// the destination instead of going through this block. +void CondProp::SimplifyPredecessors(SwitchInst *SI) { + // TODO: We currently only handle the most trival case, where the PHI node has + // one use (the branch), and is the only instruction besides the branch and + // dbg intrinsics in the block. + PHINode *PN = cast<PHINode>(SI->getCondition()); + if (!PN->hasOneUse()) return; + + BasicBlock *BB = SI->getParent(); + if (&*BB->begin() != PN) + return; + BasicBlock::iterator BBI = BB->begin(); + BasicBlock::iterator BBE = BB->end(); + while (BBI != BBE && isa<DbgInfoIntrinsic>(++BBI)) /* empty */; + if (&*BBI != SI) + return; + + bool RemovedPreds = false; + + // Ok, we have this really simple case, walk the PHI operands, looking for + // constants. Walk from the end to remove operands from the end when + // possible, and to avoid invalidating "i". + for (unsigned i = PN->getNumIncomingValues(); i != 0; --i) + if (ConstantInt *CI = dyn_cast<ConstantInt>(PN->getIncomingValue(i-1))) { + // If we have a constant, forward the edge from its current to its + // ultimate destination. 
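+ // (findCaseValue returns the index of the matching case, or the default
+ // case when the constant is not explicitly listed.)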
+ unsigned DestCase = SI->findCaseValue(CI); + RevectorBlockTo(PN->getIncomingBlock(i-1), + SI->getSuccessor(DestCase)); + ++NumSwThread; + RemovedPreds = true; + + // If there were two predecessors before this simplification, or if the + // PHI node contained all the same value except for the one we just + // substituted, the PHI node may be deleted. Don't iterate through it the + // last time. + if (SI->getCondition() != PN) return; + } +} + + +// RevectorBlockTo - Revector the unconditional branch at the end of FromBB to +// the ToBB block, which is one of the successors of its current successor. +void CondProp::RevectorBlockTo(BasicBlock *FromBB, BasicBlock *ToBB) { + BranchInst *FromBr = cast<BranchInst>(FromBB->getTerminator()); + assert(FromBr->isUnconditional() && "FromBB should end with uncond br!"); + + // Get the old block we are threading through. + BasicBlock *OldSucc = FromBr->getSuccessor(0); + + // OldSucc had multiple successors. If ToBB has multiple predecessors, then + // the edge between them would be critical, which we already took care of. + // If ToBB has single operand PHI node then take care of it here. + FoldSingleEntryPHINodes(ToBB); + + // Update PHI nodes in OldSucc to know that FromBB no longer branches to it. + OldSucc->removePredecessor(FromBB); + + // Change FromBr to branch to the new destination. + FromBr->setSuccessor(0, ToBB); + + MadeChange = true; +} + +bool CondProp::RevectorBlockTo(BasicBlock *FromBB, Value *Cond, BranchInst *BI){ + BranchInst *FromBr = cast<BranchInst>(FromBB->getTerminator()); + if (!FromBr->isUnconditional()) + return false; + + // Get the old block we are threading through. + BasicBlock *OldSucc = FromBr->getSuccessor(0); + + // If the condition is a constant, simply revector the unconditional branch at + // the end of FromBB to one of the successors of its current successor. + if (ConstantInt *CB = dyn_cast<ConstantInt>(Cond)) { + BasicBlock *ToBB = BI->getSuccessor(CB->isZero()); + + // OldSucc had multiple successors. If ToBB has multiple predecessors, then + // the edge between them would be critical, which we already took care of. + // If ToBB has single operand PHI node then take care of it here. + FoldSingleEntryPHINodes(ToBB); + + // Update PHI nodes in OldSucc to know that FromBB no longer branches to it. + OldSucc->removePredecessor(FromBB); + + // Change FromBr to branch to the new destination. + FromBr->setSuccessor(0, ToBB); + } else { + BasicBlock *Succ0 = BI->getSuccessor(0); + // Do not perform transform if the new destination has PHI nodes. The + // transform will add new preds to the PHI's. + if (isa<PHINode>(Succ0->begin())) + return false; + + BasicBlock *Succ1 = BI->getSuccessor(1); + if (isa<PHINode>(Succ1->begin())) + return false; + + // Insert the new conditional branch. + BranchInst::Create(Succ0, Succ1, Cond, FromBr); + + FoldSingleEntryPHINodes(Succ0); + FoldSingleEntryPHINodes(Succ1); + + // Update PHI nodes in OldSucc to know that FromBB no longer branches to it. + OldSucc->removePredecessor(FromBB); + + // Delete the old branch. 
+ FromBr->eraseFromParent(); + } + + MadeChange = true; + return true; +} diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp new file mode 100644 index 0000000..b933488 --- /dev/null +++ b/lib/Transforms/Scalar/ConstantProp.cpp @@ -0,0 +1,90 @@ +//===- ConstantProp.cpp - Code to perform Simple Constant Propagation -----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements constant propagation and merging: +// +// Specifically, this: +// * Converts instructions like "add int 1, 2" into 3 +// +// Notice that: +// * This pass has a habit of making definitions be dead. It is a good idea +// to run a DIE pass sometime after running this pass. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "constprop" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Constant.h" +#include "llvm/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/ADT/Statistic.h" +#include <set> +using namespace llvm; + +STATISTIC(NumInstKilled, "Number of instructions killed"); + +namespace { + struct VISIBILITY_HIDDEN ConstantPropagation : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + ConstantPropagation() : FunctionPass(&ID) {} + + bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + }; +} + +char ConstantPropagation::ID = 0; +static RegisterPass<ConstantPropagation> +X("constprop", "Simple constant propagation"); + +FunctionPass *llvm::createConstantPropagationPass() { + return new ConstantPropagation(); +} + + +bool ConstantPropagation::runOnFunction(Function &F) { + // Initialize the worklist to all of the instructions ready to process... + std::set<Instruction*> WorkList; + for(inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) { + WorkList.insert(&*i); + } + bool Changed = false; + + while (!WorkList.empty()) { + Instruction *I = *WorkList.begin(); + WorkList.erase(WorkList.begin()); // Get an element from the worklist... + + if (!I->use_empty()) // Don't muck with dead instructions... + if (Constant *C = ConstantFoldInstruction(I)) { + // Add all of the users of this instruction to the worklist, they might + // be constant propagatable now... + for (Value::use_iterator UI = I->use_begin(), UE = I->use_end(); + UI != UE; ++UI) + WorkList.insert(cast<Instruction>(*UI)); + + // Replace all of the uses of a variable with uses of the constant. + I->replaceAllUsesWith(C); + + // Remove the dead instruction. + WorkList.erase(I); + I->eraseFromParent(); + + // We made a change to the function... + Changed = true; + ++NumInstKilled; + } + } + return Changed; +} diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp new file mode 100644 index 0000000..8bb504c --- /dev/null +++ b/lib/Transforms/Scalar/DCE.cpp @@ -0,0 +1,133 @@ +//===- DCE.cpp - Code to perform dead code elimination --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements dead inst elimination and dead code elimination. +// +// Dead Inst Elimination performs a single pass over the function removing +// instructions that are obviously dead. Dead Code Elimination is similar, but +// it rechecks instructions that were used by removed instructions to see if +// they are newly dead. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "dce" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/ADT/Statistic.h" +#include <set> +using namespace llvm; + +STATISTIC(DIEEliminated, "Number of insts removed by DIE pass"); +STATISTIC(DCEEliminated, "Number of insts removed"); + +namespace { + //===--------------------------------------------------------------------===// + // DeadInstElimination pass implementation + // + struct VISIBILITY_HIDDEN DeadInstElimination : public BasicBlockPass { + static char ID; // Pass identification, replacement for typeid + DeadInstElimination() : BasicBlockPass(&ID) {} + virtual bool runOnBasicBlock(BasicBlock &BB) { + bool Changed = false; + for (BasicBlock::iterator DI = BB.begin(); DI != BB.end(); ) { + Instruction *Inst = DI++; + if (isInstructionTriviallyDead(Inst)) { + Inst->eraseFromParent(); + Changed = true; + ++DIEEliminated; + } + } + return Changed; + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + }; +} + +char DeadInstElimination::ID = 0; +static RegisterPass<DeadInstElimination> +X("die", "Dead Instruction Elimination"); + +Pass *llvm::createDeadInstEliminationPass() { + return new DeadInstElimination(); +} + + +namespace { + //===--------------------------------------------------------------------===// + // DeadCodeElimination pass implementation + // + struct DCE : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + DCE() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + } + }; +} + +char DCE::ID = 0; +static RegisterPass<DCE> Y("dce", "Dead Code Elimination"); + +bool DCE::runOnFunction(Function &F) { + // Start out with all of the instructions in the worklist... + std::vector<Instruction*> WorkList; + for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) + WorkList.push_back(&*i); + + // Loop over the worklist finding instructions that are dead. If they are + // dead make them drop all of their uses, making other instructions + // potentially dead, and work until the worklist is empty. + // + bool MadeChange = false; + while (!WorkList.empty()) { + Instruction *I = WorkList.back(); + WorkList.pop_back(); + + if (isInstructionTriviallyDead(I)) { // If the instruction is dead. + // Loop over all of the values that the instruction uses, if there are + // instructions being used, add them to the worklist, because they might + // go dead after this one is removed. + // + for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI) + if (Instruction *Used = dyn_cast<Instruction>(*OI)) + WorkList.push_back(Used); + + // Remove the instruction. + I->eraseFromParent(); + + // Remove the instruction from the worklist if it still exists in it. 
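+ // (The same instruction may have been pushed more than once, as an
+ // operand of several dead instructions, so scan the whole vector.)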
+ for (std::vector<Instruction*>::iterator WI = WorkList.begin(); + WI != WorkList.end(); ) { + if (*WI == I) + WI = WorkList.erase(WI); + else + ++WI; + } + + MadeChange = true; + ++DCEEliminated; + } + } + return MadeChange; +} + +FunctionPass *llvm::createDeadCodeEliminationPass() { + return new DCE(); +} + diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp new file mode 100644 index 0000000..b923c92 --- /dev/null +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -0,0 +1,461 @@ +//===- DeadStoreElimination.cpp - Fast Dead Store Elimination -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a trivial dead store elimination that only considers +// basic-block local redundant stores. +// +// FIXME: This should eventually be extended to be a post-dominator tree +// traversal. Doing so would be pretty trivial. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "dse" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/Compiler.h" +using namespace llvm; + +STATISTIC(NumFastStores, "Number of stores deleted"); +STATISTIC(NumFastOther , "Number of other instrs removed"); + +namespace { + struct VISIBILITY_HIDDEN DSE : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + DSE() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F) { + bool Changed = false; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) + Changed |= runOnBasicBlock(*I); + return Changed; + } + + bool runOnBasicBlock(BasicBlock &BB); + bool handleFreeWithNonTrivialDependency(FreeInst *F, MemDepResult Dep); + bool handleEndBlock(BasicBlock &BB); + bool RemoveUndeadPointers(Value* Ptr, uint64_t killPointerSize, + BasicBlock::iterator& BBI, + SmallPtrSet<Value*, 64>& deadPointers); + void DeleteDeadInstruction(Instruction *I, + SmallPtrSet<Value*, 64> *deadPointers = 0); + + + // getAnalysisUsage - We require post dominance frontiers (aka Control + // Dependence Graph) + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired<DominatorTree>(); + AU.addRequired<TargetData>(); + AU.addRequired<AliasAnalysis>(); + AU.addRequired<MemoryDependenceAnalysis>(); + AU.addPreserved<DominatorTree>(); + AU.addPreserved<AliasAnalysis>(); + AU.addPreserved<MemoryDependenceAnalysis>(); + } + }; +} + +char DSE::ID = 0; +static RegisterPass<DSE> X("dse", "Dead Store Elimination"); + +FunctionPass *llvm::createDeadStoreEliminationPass() { return new DSE(); } + +bool DSE::runOnBasicBlock(BasicBlock &BB) { + MemoryDependenceAnalysis& MD = getAnalysis<MemoryDependenceAnalysis>(); + TargetData &TD = getAnalysis<TargetData>(); + + bool MadeChange = false; + + // Do a top-down walk on the BB + for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; 
) { + Instruction *Inst = BBI++; + + // If we find a store or a free, get it's memory dependence. + if (!isa<StoreInst>(Inst) && !isa<FreeInst>(Inst)) + continue; + + // Don't molest volatile stores or do queries that will return "clobber". + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) + if (SI->isVolatile()) + continue; + + MemDepResult InstDep = MD.getDependency(Inst); + + // Ignore non-local stores. + // FIXME: cross-block DSE would be fun. :) + if (InstDep.isNonLocal()) continue; + + // Handle frees whose dependencies are non-trivial. + if (FreeInst *FI = dyn_cast<FreeInst>(Inst)) { + MadeChange |= handleFreeWithNonTrivialDependency(FI, InstDep); + continue; + } + + StoreInst *SI = cast<StoreInst>(Inst); + + // If not a definite must-alias dependency, ignore it. + if (!InstDep.isDef()) + continue; + + // If this is a store-store dependence, then the previous store is dead so + // long as this store is at least as big as it. + if (StoreInst *DepStore = dyn_cast<StoreInst>(InstDep.getInst())) + if (TD.getTypeStoreSize(DepStore->getOperand(0)->getType()) <= + TD.getTypeStoreSize(SI->getOperand(0)->getType())) { + // Delete the store and now-dead instructions that feed it. + DeleteDeadInstruction(DepStore); + NumFastStores++; + MadeChange = true; + + if (BBI != BB.begin()) + --BBI; + continue; + } + + // If we're storing the same value back to a pointer that we just + // loaded from, then the store can be removed. + if (LoadInst *DepLoad = dyn_cast<LoadInst>(InstDep.getInst())) { + if (SI->getPointerOperand() == DepLoad->getPointerOperand() && + SI->getOperand(0) == DepLoad) { + DeleteDeadInstruction(SI); + if (BBI != BB.begin()) + --BBI; + NumFastStores++; + MadeChange = true; + continue; + } + } + } + + // If this block ends in a return, unwind, or unreachable, all allocas are + // dead at its end, which means stores to them are also dead. + if (BB.getTerminator()->getNumSuccessors() == 0) + MadeChange |= handleEndBlock(BB); + + return MadeChange; +} + +/// handleFreeWithNonTrivialDependency - Handle frees of entire structures whose +/// dependency is a store to a field of that structure. +bool DSE::handleFreeWithNonTrivialDependency(FreeInst *F, MemDepResult Dep) { + AliasAnalysis &AA = getAnalysis<AliasAnalysis>(); + + StoreInst *Dependency = dyn_cast_or_null<StoreInst>(Dep.getInst()); + if (!Dependency || Dependency->isVolatile()) + return false; + + Value *DepPointer = Dependency->getPointerOperand()->getUnderlyingObject(); + + // Check for aliasing. + if (AA.alias(F->getPointerOperand(), 1, DepPointer, 1) != + AliasAnalysis::MustAlias) + return false; + + // DCE instructions only used to calculate that store + DeleteDeadInstruction(Dependency); + NumFastStores++; + return true; +} + +/// handleEndBlock - Remove dead stores to stack-allocated locations in the +/// function end block. Ex: +/// %A = alloca i32 +/// ... +/// store i32 1, i32* %A +/// ret void +bool DSE::handleEndBlock(BasicBlock &BB) { + TargetData &TD = getAnalysis<TargetData>(); + AliasAnalysis &AA = getAnalysis<AliasAnalysis>(); + + bool MadeChange = false; + + // Pointers alloca'd in this function are dead in the end block + SmallPtrSet<Value*, 64> deadPointers; + + // Find all of the alloca'd pointers in the entry block. + BasicBlock *Entry = BB.getParent()->begin(); + for (BasicBlock::iterator I = Entry->begin(), E = Entry->end(); I != E; ++I) + if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) + deadPointers.insert(AI); + + // Treat byval arguments the same, stores to them are dead at the end of the + // function. 
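+  // (A byval argument is a callee-local copy of the caller's value, so once
+  // the function returns nothing can observe stores into it.)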
+ for (Function::arg_iterator AI = BB.getParent()->arg_begin(), + AE = BB.getParent()->arg_end(); AI != AE; ++AI) + if (AI->hasByValAttr()) + deadPointers.insert(AI); + + // Scan the basic block backwards + for (BasicBlock::iterator BBI = BB.end(); BBI != BB.begin(); ){ + --BBI; + + // If we find a store whose pointer is dead. + if (StoreInst* S = dyn_cast<StoreInst>(BBI)) { + if (!S->isVolatile()) { + // See through pointer-to-pointer bitcasts + Value* pointerOperand = S->getPointerOperand()->getUnderlyingObject(); + + // Alloca'd pointers or byval arguments (which are functionally like + // alloca's) are valid candidates for removal. + if (deadPointers.count(pointerOperand)) { + // DCE instructions only used to calculate that store. + BBI++; + DeleteDeadInstruction(S, &deadPointers); + NumFastStores++; + MadeChange = true; + } + } + + continue; + } + + // We can also remove memcpy's to local variables at the end of a function. + if (MemCpyInst *M = dyn_cast<MemCpyInst>(BBI)) { + Value *dest = M->getDest()->getUnderlyingObject(); + + if (deadPointers.count(dest)) { + BBI++; + DeleteDeadInstruction(M, &deadPointers); + NumFastOther++; + MadeChange = true; + continue; + } + + // Because a memcpy is also a load, we can't skip it if we didn't remove + // it. + } + + Value* killPointer = 0; + uint64_t killPointerSize = ~0UL; + + // If we encounter a use of the pointer, it is no longer considered dead + if (LoadInst *L = dyn_cast<LoadInst>(BBI)) { + // However, if this load is unused and not volatile, we can go ahead and + // remove it, and not have to worry about it making our pointer undead! + if (L->use_empty() && !L->isVolatile()) { + BBI++; + DeleteDeadInstruction(L, &deadPointers); + NumFastOther++; + MadeChange = true; + continue; + } + + killPointer = L->getPointerOperand(); + } else if (VAArgInst* V = dyn_cast<VAArgInst>(BBI)) { + killPointer = V->getOperand(0); + } else if (isa<MemCpyInst>(BBI) && + isa<ConstantInt>(cast<MemCpyInst>(BBI)->getLength())) { + killPointer = cast<MemCpyInst>(BBI)->getSource(); + killPointerSize = cast<ConstantInt>( + cast<MemCpyInst>(BBI)->getLength())->getZExtValue(); + } else if (AllocaInst* A = dyn_cast<AllocaInst>(BBI)) { + deadPointers.erase(A); + + // Dead alloca's can be DCE'd when we reach them + if (A->use_empty()) { + BBI++; + DeleteDeadInstruction(A, &deadPointers); + NumFastOther++; + MadeChange = true; + } + + continue; + } else if (CallSite::get(BBI).getInstruction() != 0) { + // If this call does not access memory, it can't + // be undeadifying any of our pointers. + CallSite CS = CallSite::get(BBI); + if (AA.doesNotAccessMemory(CS)) + continue; + + unsigned modRef = 0; + unsigned other = 0; + + // Remove any pointers made undead by the call from the dead set + std::vector<Value*> dead; + for (SmallPtrSet<Value*, 64>::iterator I = deadPointers.begin(), + E = deadPointers.end(); I != E; ++I) { + // HACK: if we detect that our AA is imprecise, it's not + // worth it to scan the rest of the deadPointers set. Just + // assume that the AA will return ModRef for everything, and + // go ahead and bail. 
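+        // (Sixteen ModRef answers in a row, with no other result seen, is
+        // treated as a sign that the AA has nothing precise to say about
+        // this call.)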
+ if (modRef >= 16 && other == 0) { + deadPointers.clear(); + return MadeChange; + } + + // Get size information for the alloca + unsigned pointerSize = ~0U; + if (AllocaInst* A = dyn_cast<AllocaInst>(*I)) { + if (ConstantInt* C = dyn_cast<ConstantInt>(A->getArraySize())) + pointerSize = C->getZExtValue() * + TD.getTypeAllocSize(A->getAllocatedType()); + } else { + const PointerType* PT = cast<PointerType>( + cast<Argument>(*I)->getType()); + pointerSize = TD.getTypeAllocSize(PT->getElementType()); + } + + // See if the call site touches it + AliasAnalysis::ModRefResult A = AA.getModRefInfo(CS, *I, pointerSize); + + if (A == AliasAnalysis::ModRef) + modRef++; + else + other++; + + if (A == AliasAnalysis::ModRef || A == AliasAnalysis::Ref) + dead.push_back(*I); + } + + for (std::vector<Value*>::iterator I = dead.begin(), E = dead.end(); + I != E; ++I) + deadPointers.erase(*I); + + continue; + } else if (isInstructionTriviallyDead(BBI)) { + // For any non-memory-affecting non-terminators, DCE them as we reach them + Instruction *Inst = BBI; + BBI++; + DeleteDeadInstruction(Inst, &deadPointers); + NumFastOther++; + MadeChange = true; + continue; + } + + if (!killPointer) + continue; + + killPointer = killPointer->getUnderlyingObject(); + + // Deal with undead pointers + MadeChange |= RemoveUndeadPointers(killPointer, killPointerSize, BBI, + deadPointers); + } + + return MadeChange; +} + +/// RemoveUndeadPointers - check for uses of a pointer that make it +/// undead when scanning for dead stores to alloca's. +bool DSE::RemoveUndeadPointers(Value* killPointer, uint64_t killPointerSize, + BasicBlock::iterator &BBI, + SmallPtrSet<Value*, 64>& deadPointers) { + TargetData &TD = getAnalysis<TargetData>(); + AliasAnalysis &AA = getAnalysis<AliasAnalysis>(); + + // If the kill pointer can be easily reduced to an alloca, + // don't bother doing extraneous AA queries. + if (deadPointers.count(killPointer)) { + deadPointers.erase(killPointer); + return false; + } + + // A global can't be in the dead pointer set. + if (isa<GlobalValue>(killPointer)) + return false; + + bool MadeChange = false; + + SmallVector<Value*, 16> undead; + + for (SmallPtrSet<Value*, 64>::iterator I = deadPointers.begin(), + E = deadPointers.end(); I != E; ++I) { + // Get size information for the alloca. + unsigned pointerSize = ~0U; + if (AllocaInst* A = dyn_cast<AllocaInst>(*I)) { + if (ConstantInt* C = dyn_cast<ConstantInt>(A->getArraySize())) + pointerSize = C->getZExtValue() * + TD.getTypeAllocSize(A->getAllocatedType()); + } else { + const PointerType* PT = cast<PointerType>(cast<Argument>(*I)->getType()); + pointerSize = TD.getTypeAllocSize(PT->getElementType()); + } + + // See if this pointer could alias it + AliasAnalysis::AliasResult A = AA.alias(*I, pointerSize, + killPointer, killPointerSize); + + // If it must-alias and a store, we can delete it + if (isa<StoreInst>(BBI) && A == AliasAnalysis::MustAlias) { + StoreInst* S = cast<StoreInst>(BBI); + + // Remove it! + BBI++; + DeleteDeadInstruction(S, &deadPointers); + NumFastStores++; + MadeChange = true; + + continue; + + // Otherwise, it is undead + } else if (A != AliasAnalysis::NoAlias) + undead.push_back(*I); + } + + for (SmallVector<Value*, 16>::iterator I = undead.begin(), E = undead.end(); + I != E; ++I) + deadPointers.erase(*I); + + return MadeChange; +} + +/// DeleteDeadInstruction - Delete this instruction. Before we do, go through +/// and zero out all the operands of this instruction. 
If any of them become +/// dead, delete them and the computation tree that feeds them. +/// +/// If ValueSet is non-null, remove any deleted instructions from it as well. +/// +void DSE::DeleteDeadInstruction(Instruction *I, + SmallPtrSet<Value*, 64> *ValueSet) { + SmallVector<Instruction*, 32> NowDeadInsts; + + NowDeadInsts.push_back(I); + --NumFastOther; + + // Before we touch this instruction, remove it from memdep! + MemoryDependenceAnalysis &MDA = getAnalysis<MemoryDependenceAnalysis>(); + while (!NowDeadInsts.empty()) { + Instruction *DeadInst = NowDeadInsts.back(); + NowDeadInsts.pop_back(); + + ++NumFastOther; + + // This instruction is dead, zap it, in stages. Start by removing it from + // MemDep, which needs to know the operands and needs it to be in the + // function. + MDA.removeInstruction(DeadInst); + + for (unsigned op = 0, e = DeadInst->getNumOperands(); op != e; ++op) { + Value *Op = DeadInst->getOperand(op); + DeadInst->setOperand(op, 0); + + // If this operand just became dead, add it to the NowDeadInsts list. + if (!Op->use_empty()) continue; + + if (Instruction *OpI = dyn_cast<Instruction>(Op)) + if (isInstructionTriviallyDead(OpI)) + NowDeadInsts.push_back(OpI); + } + + DeadInst->eraseFromParent(); + + if (ValueSet) ValueSet->erase(DeadInst); + } +} diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp new file mode 100644 index 0000000..733dfa9 --- /dev/null +++ b/lib/Transforms/Scalar/GVN.cpp @@ -0,0 +1,1738 @@ +//===- GVN.cpp - Eliminate redundant values and loads ---------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs global value numbering to eliminate fully redundant +// instructions. It also performs simple dead load elimination. +// +// Note that this pass does the value numbering itself; it does not use the +// ValueNumbering analysis passes. 
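+//
+// For example (illustrative IR), given
+//
+//     %a = add i32 %x, %y
+//     %b = add i32 %x, %y
+//
+// both adds receive the same value number, so uses of %b can be rewritten to
+// use %a and %b can be deleted.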
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "gvn" +#include "llvm/Transforms/Scalar.h" +#include "llvm/BasicBlock.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Value.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <cstdio> +using namespace llvm; + +STATISTIC(NumGVNInstr, "Number of instructions deleted"); +STATISTIC(NumGVNLoad, "Number of loads deleted"); +STATISTIC(NumGVNPRE, "Number of instructions PRE'd"); +STATISTIC(NumGVNBlocks, "Number of blocks merged"); +STATISTIC(NumPRELoad, "Number of loads PRE'd"); + +static cl::opt<bool> EnablePRE("enable-pre", + cl::init(true), cl::Hidden); +cl::opt<bool> EnableLoadPRE("enable-load-pre", cl::init(true)); + +//===----------------------------------------------------------------------===// +// ValueTable Class +//===----------------------------------------------------------------------===// + +/// This class holds the mapping between values and value numbers. It is used +/// as an efficient mechanism to determine the expression-wise equivalence of +/// two values. +namespace { + struct VISIBILITY_HIDDEN Expression { + enum ExpressionOpcode { ADD, SUB, MUL, UDIV, SDIV, FDIV, UREM, SREM, + FREM, SHL, LSHR, ASHR, AND, OR, XOR, ICMPEQ, + ICMPNE, ICMPUGT, ICMPUGE, ICMPULT, ICMPULE, + ICMPSGT, ICMPSGE, ICMPSLT, ICMPSLE, FCMPOEQ, + FCMPOGT, FCMPOGE, FCMPOLT, FCMPOLE, FCMPONE, + FCMPORD, FCMPUNO, FCMPUEQ, FCMPUGT, FCMPUGE, + FCMPULT, FCMPULE, FCMPUNE, EXTRACT, INSERT, + SHUFFLE, SELECT, TRUNC, ZEXT, SEXT, FPTOUI, + FPTOSI, UITOFP, SITOFP, FPTRUNC, FPEXT, + PTRTOINT, INTTOPTR, BITCAST, GEP, CALL, CONSTANT, + EMPTY, TOMBSTONE }; + + ExpressionOpcode opcode; + const Type* type; + uint32_t firstVN; + uint32_t secondVN; + uint32_t thirdVN; + SmallVector<uint32_t, 4> varargs; + Value* function; + + Expression() { } + Expression(ExpressionOpcode o) : opcode(o) { } + + bool operator==(const Expression &other) const { + if (opcode != other.opcode) + return false; + else if (opcode == EMPTY || opcode == TOMBSTONE) + return true; + else if (type != other.type) + return false; + else if (function != other.function) + return false; + else if (firstVN != other.firstVN) + return false; + else if (secondVN != other.secondVN) + return false; + else if (thirdVN != other.thirdVN) + return false; + else { + if (varargs.size() != other.varargs.size()) + return false; + + for (size_t i = 0; i < varargs.size(); ++i) + if (varargs[i] != other.varargs[i]) + return false; + + return true; + } + } + + bool operator!=(const Expression &other) const { + return !(*this == other); + } + }; + + class VISIBILITY_HIDDEN ValueTable { + private: + DenseMap<Value*, uint32_t> valueNumbering; + DenseMap<Expression, uint32_t> expressionNumbering; + AliasAnalysis* AA; + MemoryDependenceAnalysis* MD; + DominatorTree* DT; + + uint32_t nextValueNumber; + + Expression::ExpressionOpcode getOpcode(BinaryOperator* BO); + 
Expression::ExpressionOpcode getOpcode(CmpInst* C); + Expression::ExpressionOpcode getOpcode(CastInst* C); + Expression create_expression(BinaryOperator* BO); + Expression create_expression(CmpInst* C); + Expression create_expression(ShuffleVectorInst* V); + Expression create_expression(ExtractElementInst* C); + Expression create_expression(InsertElementInst* V); + Expression create_expression(SelectInst* V); + Expression create_expression(CastInst* C); + Expression create_expression(GetElementPtrInst* G); + Expression create_expression(CallInst* C); + Expression create_expression(Constant* C); + public: + ValueTable() : nextValueNumber(1) { } + uint32_t lookup_or_add(Value* V); + uint32_t lookup(Value* V) const; + void add(Value* V, uint32_t num); + void clear(); + void erase(Value* v); + unsigned size(); + void setAliasAnalysis(AliasAnalysis* A) { AA = A; } + AliasAnalysis *getAliasAnalysis() const { return AA; } + void setMemDep(MemoryDependenceAnalysis* M) { MD = M; } + void setDomTree(DominatorTree* D) { DT = D; } + uint32_t getNextUnusedValueNumber() { return nextValueNumber; } + void verifyRemoved(const Value *) const; + }; +} + +namespace llvm { +template <> struct DenseMapInfo<Expression> { + static inline Expression getEmptyKey() { + return Expression(Expression::EMPTY); + } + + static inline Expression getTombstoneKey() { + return Expression(Expression::TOMBSTONE); + } + + static unsigned getHashValue(const Expression e) { + unsigned hash = e.opcode; + + hash = e.firstVN + hash * 37; + hash = e.secondVN + hash * 37; + hash = e.thirdVN + hash * 37; + + hash = ((unsigned)((uintptr_t)e.type >> 4) ^ + (unsigned)((uintptr_t)e.type >> 9)) + + hash * 37; + + for (SmallVector<uint32_t, 4>::const_iterator I = e.varargs.begin(), + E = e.varargs.end(); I != E; ++I) + hash = *I + hash * 37; + + hash = ((unsigned)((uintptr_t)e.function >> 4) ^ + (unsigned)((uintptr_t)e.function >> 9)) + + hash * 37; + + return hash; + } + static bool isEqual(const Expression &LHS, const Expression &RHS) { + return LHS == RHS; + } + static bool isPod() { return true; } +}; +} + +//===----------------------------------------------------------------------===// +// ValueTable Internal Functions +//===----------------------------------------------------------------------===// +Expression::ExpressionOpcode ValueTable::getOpcode(BinaryOperator* BO) { + switch(BO->getOpcode()) { + default: // THIS SHOULD NEVER HAPPEN + assert(0 && "Binary operator with unknown opcode?"); + case Instruction::Add: return Expression::ADD; + case Instruction::Sub: return Expression::SUB; + case Instruction::Mul: return Expression::MUL; + case Instruction::UDiv: return Expression::UDIV; + case Instruction::SDiv: return Expression::SDIV; + case Instruction::FDiv: return Expression::FDIV; + case Instruction::URem: return Expression::UREM; + case Instruction::SRem: return Expression::SREM; + case Instruction::FRem: return Expression::FREM; + case Instruction::Shl: return Expression::SHL; + case Instruction::LShr: return Expression::LSHR; + case Instruction::AShr: return Expression::ASHR; + case Instruction::And: return Expression::AND; + case Instruction::Or: return Expression::OR; + case Instruction::Xor: return Expression::XOR; + } +} + +Expression::ExpressionOpcode ValueTable::getOpcode(CmpInst* C) { + if (isa<ICmpInst>(C) || isa<VICmpInst>(C)) { + switch (C->getPredicate()) { + default: // THIS SHOULD NEVER HAPPEN + assert(0 && "Comparison with unknown predicate?"); + case ICmpInst::ICMP_EQ: return Expression::ICMPEQ; + case 
ICmpInst::ICMP_NE: return Expression::ICMPNE; + case ICmpInst::ICMP_UGT: return Expression::ICMPUGT; + case ICmpInst::ICMP_UGE: return Expression::ICMPUGE; + case ICmpInst::ICMP_ULT: return Expression::ICMPULT; + case ICmpInst::ICMP_ULE: return Expression::ICMPULE; + case ICmpInst::ICMP_SGT: return Expression::ICMPSGT; + case ICmpInst::ICMP_SGE: return Expression::ICMPSGE; + case ICmpInst::ICMP_SLT: return Expression::ICMPSLT; + case ICmpInst::ICMP_SLE: return Expression::ICMPSLE; + } + } + assert((isa<FCmpInst>(C) || isa<VFCmpInst>(C)) && "Unknown compare"); + switch (C->getPredicate()) { + default: // THIS SHOULD NEVER HAPPEN + assert(0 && "Comparison with unknown predicate?"); + case FCmpInst::FCMP_OEQ: return Expression::FCMPOEQ; + case FCmpInst::FCMP_OGT: return Expression::FCMPOGT; + case FCmpInst::FCMP_OGE: return Expression::FCMPOGE; + case FCmpInst::FCMP_OLT: return Expression::FCMPOLT; + case FCmpInst::FCMP_OLE: return Expression::FCMPOLE; + case FCmpInst::FCMP_ONE: return Expression::FCMPONE; + case FCmpInst::FCMP_ORD: return Expression::FCMPORD; + case FCmpInst::FCMP_UNO: return Expression::FCMPUNO; + case FCmpInst::FCMP_UEQ: return Expression::FCMPUEQ; + case FCmpInst::FCMP_UGT: return Expression::FCMPUGT; + case FCmpInst::FCMP_UGE: return Expression::FCMPUGE; + case FCmpInst::FCMP_ULT: return Expression::FCMPULT; + case FCmpInst::FCMP_ULE: return Expression::FCMPULE; + case FCmpInst::FCMP_UNE: return Expression::FCMPUNE; + } +} + +Expression::ExpressionOpcode ValueTable::getOpcode(CastInst* C) { + switch(C->getOpcode()) { + default: // THIS SHOULD NEVER HAPPEN + assert(0 && "Cast operator with unknown opcode?"); + case Instruction::Trunc: return Expression::TRUNC; + case Instruction::ZExt: return Expression::ZEXT; + case Instruction::SExt: return Expression::SEXT; + case Instruction::FPToUI: return Expression::FPTOUI; + case Instruction::FPToSI: return Expression::FPTOSI; + case Instruction::UIToFP: return Expression::UITOFP; + case Instruction::SIToFP: return Expression::SITOFP; + case Instruction::FPTrunc: return Expression::FPTRUNC; + case Instruction::FPExt: return Expression::FPEXT; + case Instruction::PtrToInt: return Expression::PTRTOINT; + case Instruction::IntToPtr: return Expression::INTTOPTR; + case Instruction::BitCast: return Expression::BITCAST; + } +} + +Expression ValueTable::create_expression(CallInst* C) { + Expression e; + + e.type = C->getType(); + e.firstVN = 0; + e.secondVN = 0; + e.thirdVN = 0; + e.function = C->getCalledFunction(); + e.opcode = Expression::CALL; + + for (CallInst::op_iterator I = C->op_begin()+1, E = C->op_end(); + I != E; ++I) + e.varargs.push_back(lookup_or_add(*I)); + + return e; +} + +Expression ValueTable::create_expression(BinaryOperator* BO) { + Expression e; + + e.firstVN = lookup_or_add(BO->getOperand(0)); + e.secondVN = lookup_or_add(BO->getOperand(1)); + e.thirdVN = 0; + e.function = 0; + e.type = BO->getType(); + e.opcode = getOpcode(BO); + + return e; +} + +Expression ValueTable::create_expression(CmpInst* C) { + Expression e; + + e.firstVN = lookup_or_add(C->getOperand(0)); + e.secondVN = lookup_or_add(C->getOperand(1)); + e.thirdVN = 0; + e.function = 0; + e.type = C->getType(); + e.opcode = getOpcode(C); + + return e; +} + +Expression ValueTable::create_expression(CastInst* C) { + Expression e; + + e.firstVN = lookup_or_add(C->getOperand(0)); + e.secondVN = 0; + e.thirdVN = 0; + e.function = 0; + e.type = C->getType(); + e.opcode = getOpcode(C); + + return e; +} + +Expression 
ValueTable::create_expression(ShuffleVectorInst* S) { + Expression e; + + e.firstVN = lookup_or_add(S->getOperand(0)); + e.secondVN = lookup_or_add(S->getOperand(1)); + e.thirdVN = lookup_or_add(S->getOperand(2)); + e.function = 0; + e.type = S->getType(); + e.opcode = Expression::SHUFFLE; + + return e; +} + +Expression ValueTable::create_expression(ExtractElementInst* E) { + Expression e; + + e.firstVN = lookup_or_add(E->getOperand(0)); + e.secondVN = lookup_or_add(E->getOperand(1)); + e.thirdVN = 0; + e.function = 0; + e.type = E->getType(); + e.opcode = Expression::EXTRACT; + + return e; +} + +Expression ValueTable::create_expression(InsertElementInst* I) { + Expression e; + + e.firstVN = lookup_or_add(I->getOperand(0)); + e.secondVN = lookup_or_add(I->getOperand(1)); + e.thirdVN = lookup_or_add(I->getOperand(2)); + e.function = 0; + e.type = I->getType(); + e.opcode = Expression::INSERT; + + return e; +} + +Expression ValueTable::create_expression(SelectInst* I) { + Expression e; + + e.firstVN = lookup_or_add(I->getCondition()); + e.secondVN = lookup_or_add(I->getTrueValue()); + e.thirdVN = lookup_or_add(I->getFalseValue()); + e.function = 0; + e.type = I->getType(); + e.opcode = Expression::SELECT; + + return e; +} + +Expression ValueTable::create_expression(GetElementPtrInst* G) { + Expression e; + + e.firstVN = lookup_or_add(G->getPointerOperand()); + e.secondVN = 0; + e.thirdVN = 0; + e.function = 0; + e.type = G->getType(); + e.opcode = Expression::GEP; + + for (GetElementPtrInst::op_iterator I = G->idx_begin(), E = G->idx_end(); + I != E; ++I) + e.varargs.push_back(lookup_or_add(*I)); + + return e; +} + +//===----------------------------------------------------------------------===// +// ValueTable External Functions +//===----------------------------------------------------------------------===// + +/// add - Insert a value into the table with a specified value number. +void ValueTable::add(Value* V, uint32_t num) { + valueNumbering.insert(std::make_pair(V, num)); +} + +/// lookup_or_add - Returns the value number for the specified value, assigning +/// it a new number if it did not have one before. 
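+///
+/// Calls are only numbered as expressions when alias analysis shows they do
+/// not write memory; readonly calls additionally consult memory dependence,
+/// so two identical readonly calls only share a number when no intervening
+/// write can change the result.  All other calls get fresh numbers.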
+uint32_t ValueTable::lookup_or_add(Value* V) { + DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + return VI->second; + + if (CallInst* C = dyn_cast<CallInst>(V)) { + if (AA->doesNotAccessMemory(C)) { + Expression e = create_expression(C); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (AA->onlyReadsMemory(C)) { + Expression e = create_expression(C); + + if (expressionNumbering.find(e) == expressionNumbering.end()) { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + + MemDepResult local_dep = MD->getDependency(C); + + if (!local_dep.isDef() && !local_dep.isNonLocal()) { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + + if (local_dep.isDef()) { + CallInst* local_cdep = cast<CallInst>(local_dep.getInst()); + + if (local_cdep->getNumOperands() != C->getNumOperands()) { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + + for (unsigned i = 1; i < C->getNumOperands(); ++i) { + uint32_t c_vn = lookup_or_add(C->getOperand(i)); + uint32_t cd_vn = lookup_or_add(local_cdep->getOperand(i)); + if (c_vn != cd_vn) { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + } + + uint32_t v = lookup_or_add(local_cdep); + valueNumbering.insert(std::make_pair(V, v)); + return v; + } + + // Non-local case. + const MemoryDependenceAnalysis::NonLocalDepInfo &deps = + MD->getNonLocalCallDependency(CallSite(C)); + // FIXME: call/call dependencies for readonly calls should return def, not + // clobber! Move the checking logic to MemDep! + CallInst* cdep = 0; + + // Check to see if we have a single dominating call instruction that is + // identical to C. + for (unsigned i = 0, e = deps.size(); i != e; ++i) { + const MemoryDependenceAnalysis::NonLocalDepEntry *I = &deps[i]; + // Ignore non-local dependencies. + if (I->second.isNonLocal()) + continue; + + // We don't handle non-depedencies. If we already have a call, reject + // instruction dependencies. + if (I->second.isClobber() || cdep != 0) { + cdep = 0; + break; + } + + CallInst *NonLocalDepCall = dyn_cast<CallInst>(I->second.getInst()); + // FIXME: All duplicated with non-local case. 
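+      // A reusable candidate must be a call that properly dominates C's
+      // block; anything else ends the search with no candidate.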
+ if (NonLocalDepCall && DT->properlyDominates(I->first, C->getParent())){ + cdep = NonLocalDepCall; + continue; + } + + cdep = 0; + break; + } + + if (!cdep) { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + + if (cdep->getNumOperands() != C->getNumOperands()) { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + for (unsigned i = 1; i < C->getNumOperands(); ++i) { + uint32_t c_vn = lookup_or_add(C->getOperand(i)); + uint32_t cd_vn = lookup_or_add(cdep->getOperand(i)); + if (c_vn != cd_vn) { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + } + + uint32_t v = lookup_or_add(cdep); + valueNumbering.insert(std::make_pair(V, v)); + return v; + + } else { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } + } else if (BinaryOperator* BO = dyn_cast<BinaryOperator>(V)) { + Expression e = create_expression(BO); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (CmpInst* C = dyn_cast<CmpInst>(V)) { + Expression e = create_expression(C); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (ShuffleVectorInst* U = dyn_cast<ShuffleVectorInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (ExtractElementInst* U = dyn_cast<ExtractElementInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (InsertElementInst* U = dyn_cast<InsertElementInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (SelectInst* U = dyn_cast<SelectInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } 
else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (CastInst* U = dyn_cast<CastInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (GetElementPtrInst* U = dyn_cast<GetElementPtrInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } +} + +/// lookup - Returns the value number of the specified value. Fails if +/// the value has not yet been numbered. +uint32_t ValueTable::lookup(Value* V) const { + DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); + assert(VI != valueNumbering.end() && "Value not numbered?"); + return VI->second; +} + +/// clear - Remove all entries from the ValueTable +void ValueTable::clear() { + valueNumbering.clear(); + expressionNumbering.clear(); + nextValueNumber = 1; +} + +/// erase - Remove a value from the value numbering +void ValueTable::erase(Value* V) { + valueNumbering.erase(V); +} + +/// verifyRemoved - Verify that the value is removed from all internal data +/// structures. 
+void ValueTable::verifyRemoved(const Value *V) const {
+  for (DenseMap<Value*, uint32_t>::iterator
+         I = valueNumbering.begin(), E = valueNumbering.end(); I != E; ++I) {
+    assert(I->first != V && "Inst still occurs in value numbering map!");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+//                                GVN Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+  struct VISIBILITY_HIDDEN ValueNumberScope {
+    ValueNumberScope* parent;
+    DenseMap<uint32_t, Value*> table;
+
+    ValueNumberScope(ValueNumberScope* p) : parent(p) { }
+  };
+}
+
+namespace {
+
+  class VISIBILITY_HIDDEN GVN : public FunctionPass {
+    bool runOnFunction(Function &F);
+  public:
+    static char ID; // Pass identification, replacement for typeid
+    GVN() : FunctionPass(&ID) { }
+
+  private:
+    MemoryDependenceAnalysis *MD;
+    DominatorTree *DT;
+
+    ValueTable VN;
+    DenseMap<BasicBlock*, ValueNumberScope*> localAvail;
+
+    typedef DenseMap<Value*, SmallPtrSet<Instruction*, 4> > PhiMapType;
+    PhiMapType phiMap;
+
+
+    // This transformation requires dominator info.
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.addRequired<DominatorTree>();
+      AU.addRequired<MemoryDependenceAnalysis>();
+      AU.addRequired<AliasAnalysis>();
+
+      AU.addPreserved<DominatorTree>();
+      AU.addPreserved<AliasAnalysis>();
+    }
+
+    // Helper functions
+    // FIXME: eliminate or document these better
+    bool processLoad(LoadInst* L,
+                     SmallVectorImpl<Instruction*> &toErase);
+    bool processInstruction(Instruction* I,
+                            SmallVectorImpl<Instruction*> &toErase);
+    bool processNonLocalLoad(LoadInst* L,
+                             SmallVectorImpl<Instruction*> &toErase);
+    bool processBlock(BasicBlock* BB);
+    Value *GetValueForBlock(BasicBlock *BB, Instruction* orig,
+                            DenseMap<BasicBlock*, Value*> &Phis,
+                            bool top_level = false);
+    void dump(DenseMap<uint32_t, Value*>& d);
+    bool iterateOnFunction(Function &F);
+    Value* CollapsePhi(PHINode* p);
+    bool isSafeReplacement(PHINode* p, Instruction* inst);
+    bool performPRE(Function& F);
+    Value* lookupNumber(BasicBlock* BB, uint32_t num);
+    bool mergeBlockIntoPredecessor(BasicBlock* BB);
+    Value* AttemptRedundancyElimination(Instruction* orig, unsigned valno);
+    void cleanupGlobalSets();
+    void verifyRemoved(const Instruction *I) const;
+  };
+
+  char GVN::ID = 0;
+}
+
+// createGVNPass - The public interface to this file...
+FunctionPass *llvm::createGVNPass() { return new GVN(); }
+
+static RegisterPass<GVN> X("gvn",
+                           "Global Value Numbering");
+
+void GVN::dump(DenseMap<uint32_t, Value*>& d) {
+  printf("{\n");
+  for (DenseMap<uint32_t, Value*>::iterator I = d.begin(),
+       E = d.end(); I != E; ++I) {
+    printf("%d\n", I->first);
+    I->second->dump();
+  }
+  printf("}\n");
+}
+
+Value* GVN::CollapsePhi(PHINode* p) {
+  Value* constVal = p->hasConstantValue();
+  if (!constVal) return 0;
+
+  Instruction* inst = dyn_cast<Instruction>(constVal);
+  if (!inst)
+    return constVal;
+
+  if (DT->dominates(inst, p))
+    if (isSafeReplacement(p, inst))
+      return inst;
+  return 0;
+}
+
+bool GVN::isSafeReplacement(PHINode* p, Instruction* inst) {
+  if (!isa<PHINode>(inst))
+    return true;
+
+  for (Instruction::use_iterator UI = p->use_begin(), E = p->use_end();
+       UI != E; ++UI)
+    if (PHINode* use_phi = dyn_cast<PHINode>(UI))
+      if (use_phi->getParent() == inst->getParent())
+        return false;
+
+  return true;
+}
+
+/// GetValueForBlock - Get the value to use within the specified basic block.
+/// Available values are in Phis.
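+///
+/// The search recurses through predecessors, inserting PHI nodes at join
+/// points where the incoming values differ; trivially redundant PHIs are
+/// collapsed again before the result is returned.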
+Value *GVN::GetValueForBlock(BasicBlock *BB, Instruction* orig, + DenseMap<BasicBlock*, Value*> &Phis, + bool top_level) { + + // If we have already computed this value, return the previously computed val. + DenseMap<BasicBlock*, Value*>::iterator V = Phis.find(BB); + if (V != Phis.end() && !top_level) return V->second; + + // If the block is unreachable, just return undef, since this path + // can't actually occur at runtime. + if (!DT->isReachableFromEntry(BB)) + return Phis[BB] = UndefValue::get(orig->getType()); + + if (BasicBlock *Pred = BB->getSinglePredecessor()) { + Value *ret = GetValueForBlock(Pred, orig, Phis); + Phis[BB] = ret; + return ret; + } + + // Get the number of predecessors of this block so we can reserve space later. + // If there is already a PHI in it, use the #preds from it, otherwise count. + // Getting it from the PHI is constant time. + unsigned NumPreds; + if (PHINode *ExistingPN = dyn_cast<PHINode>(BB->begin())) + NumPreds = ExistingPN->getNumIncomingValues(); + else + NumPreds = std::distance(pred_begin(BB), pred_end(BB)); + + // Otherwise, the idom is the loop, so we need to insert a PHI node. Do so + // now, then get values to fill in the incoming values for the PHI. + PHINode *PN = PHINode::Create(orig->getType(), orig->getName()+".rle", + BB->begin()); + PN->reserveOperandSpace(NumPreds); + + Phis.insert(std::make_pair(BB, PN)); + + // Fill in the incoming values for the block. + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { + Value* val = GetValueForBlock(*PI, orig, Phis); + PN->addIncoming(val, *PI); + } + + VN.getAliasAnalysis()->copyValue(orig, PN); + + // Attempt to collapse PHI nodes that are trivially redundant + Value* v = CollapsePhi(PN); + if (!v) { + // Cache our phi construction results + if (LoadInst* L = dyn_cast<LoadInst>(orig)) + phiMap[L->getPointerOperand()].insert(PN); + else + phiMap[orig].insert(PN); + + return PN; + } + + PN->replaceAllUsesWith(v); + if (isa<PointerType>(v->getType())) + MD->invalidateCachedPointerInfo(v); + + for (DenseMap<BasicBlock*, Value*>::iterator I = Phis.begin(), + E = Phis.end(); I != E; ++I) + if (I->second == PN) + I->second = v; + + DEBUG(cerr << "GVN removed: " << *PN); + MD->removeInstruction(PN); + PN->eraseFromParent(); + DEBUG(verifyRemoved(PN)); + + Phis[BB] = v; + return v; +} + +/// IsValueFullyAvailableInBlock - Return true if we can prove that the value +/// we're analyzing is fully available in the specified block. As we go, keep +/// track of which blocks we know are fully alive in FullyAvailableBlocks. This +/// map is actually a tri-state map with the following values: +/// 0) we know the block *is not* fully available. +/// 1) we know the block *is* fully available. +/// 2) we do not know whether the block is fully available or not, but we are +/// currently speculating that it will be. +/// 3) we are speculating for this block and have used that to speculate for +/// other blocks. +static bool IsValueFullyAvailableInBlock(BasicBlock *BB, + DenseMap<BasicBlock*, char> &FullyAvailableBlocks) { + // Optimistically assume that the block is fully available and check to see + // if we already know about this block in one lookup. + std::pair<DenseMap<BasicBlock*, char>::iterator, char> IV = + FullyAvailableBlocks.insert(std::make_pair(BB, 2)); + + // If the entry already existed for this block, return the precomputed value. + if (!IV.second) { + // If this is a speculative "available" value, mark it as being used for + // speculation of other blocks. 
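+    // (State 3 tells SpeculationFailure below that other blocks may have
+    // been marked available on the strength of this one and must be
+    // invalidated.)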
+ if (IV.first->second == 2) + IV.first->second = 3; + return IV.first->second != 0; + } + + // Otherwise, see if it is fully available in all predecessors. + pred_iterator PI = pred_begin(BB), PE = pred_end(BB); + + // If this block has no predecessors, it isn't live-in here. + if (PI == PE) + goto SpeculationFailure; + + for (; PI != PE; ++PI) + // If the value isn't fully available in one of our predecessors, then it + // isn't fully available in this block either. Undo our previous + // optimistic assumption and bail out. + if (!IsValueFullyAvailableInBlock(*PI, FullyAvailableBlocks)) + goto SpeculationFailure; + + return true; + +// SpeculationFailure - If we get here, we found out that this is not, after +// all, a fully-available block. We have a problem if we speculated on this and +// used the speculation to mark other blocks as available. +SpeculationFailure: + char &BBVal = FullyAvailableBlocks[BB]; + + // If we didn't speculate on this, just return with it set to false. + if (BBVal == 2) { + BBVal = 0; + return false; + } + + // If we did speculate on this value, we could have blocks set to 1 that are + // incorrect. Walk the (transitive) successors of this block and mark them as + // 0 if set to one. + SmallVector<BasicBlock*, 32> BBWorklist; + BBWorklist.push_back(BB); + + while (!BBWorklist.empty()) { + BasicBlock *Entry = BBWorklist.pop_back_val(); + // Note that this sets blocks to 0 (unavailable) if they happen to not + // already be in FullyAvailableBlocks. This is safe. + char &EntryVal = FullyAvailableBlocks[Entry]; + if (EntryVal == 0) continue; // Already unavailable. + + // Mark as unavailable. + EntryVal = 0; + + for (succ_iterator I = succ_begin(Entry), E = succ_end(Entry); I != E; ++I) + BBWorklist.push_back(*I); + } + + return false; +} + +/// processNonLocalLoad - Attempt to eliminate a load whose dependencies are +/// non-local by performing PHI construction. +bool GVN::processNonLocalLoad(LoadInst *LI, + SmallVectorImpl<Instruction*> &toErase) { + // Find the non-local dependencies of the load. + SmallVector<MemoryDependenceAnalysis::NonLocalDepEntry, 64> Deps; + MD->getNonLocalPointerDependency(LI->getOperand(0), true, LI->getParent(), + Deps); + //DEBUG(cerr << "INVESTIGATING NONLOCAL LOAD: " << Deps.size() << *LI); + + // If we had to process more than one hundred blocks to find the + // dependencies, this load isn't worth worrying about. Optimizing + // it will be too expensive. + if (Deps.size() > 100) + return false; + + // If we had a phi translation failure, we'll have a single entry which is a + // clobber in the current block. Reject this early. + if (Deps.size() == 1 && Deps[0].second.isClobber()) + return false; + + // Filter out useless results (non-locals, etc). Keep track of the blocks + // where we have a value available in repl, also keep track of whether we see + // dependencies that produce an unknown value for the load (such as a call + // that could potentially clobber the load). + SmallVector<std::pair<BasicBlock*, Value*>, 16> ValuesPerBlock; + SmallVector<BasicBlock*, 16> UnavailableBlocks; + + for (unsigned i = 0, e = Deps.size(); i != e; ++i) { + BasicBlock *DepBB = Deps[i].first; + MemDepResult DepInfo = Deps[i].second; + + if (DepInfo.isClobber()) { + UnavailableBlocks.push_back(DepBB); + continue; + } + + Instruction *DepInst = DepInfo.getInst(); + + // Loading the allocation -> undef. 
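+    // (The dependency being the allocation itself means no store occurs
+    // between the alloca/malloc and this load, so the loaded value is
+    // undefined and undef is a legal substitute.)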
+    if (isa<AllocationInst>(DepInst)) {
+      ValuesPerBlock.push_back(std::make_pair(DepBB,
+                                            UndefValue::get(LI->getType())));
+      continue;
+    }
+
+    if (StoreInst* S = dyn_cast<StoreInst>(DepInst)) {
+      // Reject loads and stores that are to the same address but are of
+      // different types.
+      // NOTE: 403.gcc does have this case (e.g. in readonly_fields_p) because
+      // of bitfield access, it would be interesting to optimize for it at some
+      // point.
+      if (S->getOperand(0)->getType() != LI->getType()) {
+        UnavailableBlocks.push_back(DepBB);
+        continue;
+      }
+
+      ValuesPerBlock.push_back(std::make_pair(DepBB, S->getOperand(0)));
+
+    } else if (LoadInst* LD = dyn_cast<LoadInst>(DepInst)) {
+      if (LD->getType() != LI->getType()) {
+        UnavailableBlocks.push_back(DepBB);
+        continue;
+      }
+      ValuesPerBlock.push_back(std::make_pair(DepBB, LD));
+    } else {
+      UnavailableBlocks.push_back(DepBB);
+      continue;
+    }
+  }
+
+  // If we have no predecessors that produce a known value for this load, exit
+  // early.
+  if (ValuesPerBlock.empty()) return false;
+
+  // If all of the instructions we depend on produce a known value for this
+  // load, then it is fully redundant and we can use PHI insertion to compute
+  // its value.  Insert PHIs and remove the fully redundant value now.
+  if (UnavailableBlocks.empty()) {
+    // Use cached PHI construction information from previous runs
+    SmallPtrSet<Instruction*, 4> &p = phiMap[LI->getPointerOperand()];
+    // FIXME: What does phiMap do? Are we positive it isn't getting invalidated?
+    for (SmallPtrSet<Instruction*, 4>::iterator I = p.begin(), E = p.end();
+         I != E; ++I) {
+      if ((*I)->getParent() == LI->getParent()) {
+        DEBUG(cerr << "GVN REMOVING NONLOCAL LOAD #1: " << *LI);
+        LI->replaceAllUsesWith(*I);
+        if (isa<PointerType>((*I)->getType()))
+          MD->invalidateCachedPointerInfo(*I);
+        toErase.push_back(LI);
+        NumGVNLoad++;
+        return true;
+      }
+
+      ValuesPerBlock.push_back(std::make_pair((*I)->getParent(), *I));
+    }
+
+    DEBUG(cerr << "GVN REMOVING NONLOCAL LOAD: " << *LI);
+
+    DenseMap<BasicBlock*, Value*> BlockReplValues;
+    BlockReplValues.insert(ValuesPerBlock.begin(), ValuesPerBlock.end());
+    // Perform PHI construction.
+    Value* v = GetValueForBlock(LI->getParent(), LI, BlockReplValues, true);
+    LI->replaceAllUsesWith(v);
+
+    if (isa<PHINode>(v))
+      v->takeName(LI);
+    if (isa<PointerType>(v->getType()))
+      MD->invalidateCachedPointerInfo(v);
+    toErase.push_back(LI);
+    NumGVNLoad++;
+    return true;
+  }
+
+  if (!EnablePRE || !EnableLoadPRE)
+    return false;
+
+  // Okay, we have *some* definitions of the value.  This means that the value
+  // is available in some of our (transitive) predecessors.  Let's think about
+  // doing PRE of this load.  This will involve inserting a new load into the
+  // predecessor when it's not available.  We could do this in general, but
+  // prefer to not increase code size.  As such, we only do this when we know
+  // that we only have to insert *one* load (which means we're basically moving
+  // the load, not inserting a new one).
+
+  SmallPtrSet<BasicBlock *, 4> Blockers;
+  for (unsigned i = 0, e = UnavailableBlocks.size(); i != e; ++i)
+    Blockers.insert(UnavailableBlocks[i]);
+
+  // Let's find the first basic block with more than one predecessor.  Walk
+  // backwards through predecessors if needed.
+  BasicBlock *LoadBB = LI->getParent();
+  BasicBlock *TmpBB = LoadBB;
+
+  bool isSinglePred = false;
+  while (TmpBB->getSinglePredecessor()) {
+    isSinglePred = true;
+    TmpBB = TmpBB->getSinglePredecessor();
+    if (!TmpBB) // If we haven't found any, bail now.
+ return false; + if (TmpBB == LoadBB) // Infinite (unreachable) loop. + return false; + if (Blockers.count(TmpBB)) + return false; + } + + assert(TmpBB); + LoadBB = TmpBB; + + // If we have a repl set with LI itself in it, this means we have a loop where + // at least one of the values is LI. Since this means that we won't be able + // to eliminate LI even if we insert uses in the other predecessors, we will + // end up increasing code size. Reject this by scanning for LI. + for (unsigned i = 0, e = ValuesPerBlock.size(); i != e; ++i) + if (ValuesPerBlock[i].second == LI) + return false; + + if (isSinglePred) { + bool isHot = false; + for (unsigned i = 0, e = ValuesPerBlock.size(); i != e; ++i) + if (Instruction *I = dyn_cast<Instruction>(ValuesPerBlock[i].second)) + // "Hot" Instruction is in some loop (because it dominates its dep. + // instruction). + if (DT->dominates(LI, I)) { + isHot = true; + break; + } + + // We are interested only in "hot" instructions. We don't want to do any + // mis-optimizations here. + if (!isHot) + return false; + } + + // Okay, we have some hope :). Check to see if the loaded value is fully + // available in all but one predecessor. + // FIXME: If we could restructure the CFG, we could make a common pred with + // all the preds that don't have an available LI and insert a new load into + // that one block. + BasicBlock *UnavailablePred = 0; + + DenseMap<BasicBlock*, char> FullyAvailableBlocks; + for (unsigned i = 0, e = ValuesPerBlock.size(); i != e; ++i) + FullyAvailableBlocks[ValuesPerBlock[i].first] = true; + for (unsigned i = 0, e = UnavailableBlocks.size(); i != e; ++i) + FullyAvailableBlocks[UnavailableBlocks[i]] = false; + + for (pred_iterator PI = pred_begin(LoadBB), E = pred_end(LoadBB); + PI != E; ++PI) { + if (IsValueFullyAvailableInBlock(*PI, FullyAvailableBlocks)) + continue; + + // If this load is not available in multiple predecessors, reject it. + if (UnavailablePred && UnavailablePred != *PI) + return false; + UnavailablePred = *PI; + } + + assert(UnavailablePred != 0 && + "Fully available value should be eliminated above!"); + + // If the loaded pointer is PHI node defined in this block, do PHI translation + // to get its value in the predecessor. + Value *LoadPtr = LI->getOperand(0)->DoPHITranslation(LoadBB, UnavailablePred); + + // Make sure the value is live in the predecessor. If it was defined by a + // non-PHI instruction in this block, we don't know how to recompute it above. + if (Instruction *LPInst = dyn_cast<Instruction>(LoadPtr)) + if (!DT->dominates(LPInst->getParent(), UnavailablePred)) { + DEBUG(cerr << "COULDN'T PRE LOAD BECAUSE PTR IS UNAVAILABLE IN PRED: " + << *LPInst << *LI << "\n"); + return false; + } + + // We don't currently handle critical edges :( + if (UnavailablePred->getTerminator()->getNumSuccessors() != 1) { + DEBUG(cerr << "COULD NOT PRE LOAD BECAUSE OF CRITICAL EDGE '" + << UnavailablePred->getName() << "': " << *LI); + return false; + } + + // Okay, we can eliminate this load by inserting a reload in the predecessor + // and using PHI construction to get the value in the other predecessors, do + // it. 
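+  //
+  // For example (illustrative IR): a load of %P after a diamond join, where
+  // one arm already loads %P, becomes a PHI of the existing load and a new
+  // load inserted at the end of the other arm.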
+ DEBUG(cerr << "GVN REMOVING PRE LOAD: " << *LI); + + Value *NewLoad = new LoadInst(LoadPtr, LI->getName()+".pre", false, + LI->getAlignment(), + UnavailablePred->getTerminator()); + + SmallPtrSet<Instruction*, 4> &p = phiMap[LI->getPointerOperand()]; + for (SmallPtrSet<Instruction*, 4>::iterator I = p.begin(), E = p.end(); + I != E; ++I) + ValuesPerBlock.push_back(std::make_pair((*I)->getParent(), *I)); + + DenseMap<BasicBlock*, Value*> BlockReplValues; + BlockReplValues.insert(ValuesPerBlock.begin(), ValuesPerBlock.end()); + BlockReplValues[UnavailablePred] = NewLoad; + + // Perform PHI construction. + Value* v = GetValueForBlock(LI->getParent(), LI, BlockReplValues, true); + LI->replaceAllUsesWith(v); + if (isa<PHINode>(v)) + v->takeName(LI); + if (isa<PointerType>(v->getType())) + MD->invalidateCachedPointerInfo(v); + toErase.push_back(LI); + NumPRELoad++; + return true; +} + +/// processLoad - Attempt to eliminate a load, first by eliminating it +/// locally, and then attempting non-local elimination if that fails. +bool GVN::processLoad(LoadInst *L, SmallVectorImpl<Instruction*> &toErase) { + if (L->isVolatile()) + return false; + + Value* pointer = L->getPointerOperand(); + + // ... to a pointer that has been loaded from before... + MemDepResult dep = MD->getDependency(L); + + // If the value isn't available, don't do anything! + if (dep.isClobber()) { + DEBUG( + // fast print dep, using operator<< on instruction would be too slow + DOUT << "GVN: load "; + WriteAsOperand(*DOUT.stream(), L); + Instruction *I = dep.getInst(); + DOUT << " is clobbered by " << *I; + ); + return false; + } + + // If it is defined in another block, try harder. + if (dep.isNonLocal()) + return processNonLocalLoad(L, toErase); + + Instruction *DepInst = dep.getInst(); + if (StoreInst *DepSI = dyn_cast<StoreInst>(DepInst)) { + // Only forward substitute stores to loads of the same type. + // FIXME: Could do better! + if (DepSI->getPointerOperand()->getType() != pointer->getType()) + return false; + + // Remove it! + L->replaceAllUsesWith(DepSI->getOperand(0)); + if (isa<PointerType>(DepSI->getOperand(0)->getType())) + MD->invalidateCachedPointerInfo(DepSI->getOperand(0)); + toErase.push_back(L); + NumGVNLoad++; + return true; + } + + if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInst)) { + // Only forward substitute stores to loads of the same type. + // FIXME: Could do better! load i32 -> load i8 -> truncate on little endian. + if (DepLI->getType() != L->getType()) + return false; + + // Remove it! + L->replaceAllUsesWith(DepLI); + if (isa<PointerType>(DepLI->getType())) + MD->invalidateCachedPointerInfo(DepLI); + toErase.push_back(L); + NumGVNLoad++; + return true; + } + + // If this load really doesn't depend on anything, then we must be loading an + // undef value. This can happen when loading for a fresh allocation with no + // intervening stores, for example. 
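+  // e.g. (illustrative IR):
+  //   %P = alloca i32
+  //   %V = load i32* %P    ; nothing has been stored to %P, so %V is undef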
+ if (isa<AllocationInst>(DepInst)) { + L->replaceAllUsesWith(UndefValue::get(L->getType())); + toErase.push_back(L); + NumGVNLoad++; + return true; + } + + return false; +} + +Value* GVN::lookupNumber(BasicBlock* BB, uint32_t num) { + DenseMap<BasicBlock*, ValueNumberScope*>::iterator I = localAvail.find(BB); + if (I == localAvail.end()) + return 0; + + ValueNumberScope* locals = I->second; + + while (locals) { + DenseMap<uint32_t, Value*>::iterator I = locals->table.find(num); + if (I != locals->table.end()) + return I->second; + else + locals = locals->parent; + } + + return 0; +} + +/// AttemptRedundancyElimination - If the "fast path" of redundancy elimination +/// by inheritance from the dominator fails, see if we can perform phi +/// construction to eliminate the redundancy. +Value* GVN::AttemptRedundancyElimination(Instruction* orig, unsigned valno) { + BasicBlock* BaseBlock = orig->getParent(); + + SmallPtrSet<BasicBlock*, 4> Visited; + SmallVector<BasicBlock*, 8> Stack; + Stack.push_back(BaseBlock); + + DenseMap<BasicBlock*, Value*> Results; + + // Walk backwards through our predecessors, looking for instances of the + // value number we're looking for. Instances are recorded in the Results + // map, which is then used to perform phi construction. + while (!Stack.empty()) { + BasicBlock* Current = Stack.back(); + Stack.pop_back(); + + // If we've walked all the way to a proper dominator, then give up. Cases + // where the instance is in the dominator will have been caught by the fast + // path, and any cases that require phi construction further than this are + // probably not worth it anyways. Note that this is a SIGNIFICANT compile + // time improvement. + if (DT->properlyDominates(Current, orig->getParent())) return 0; + + DenseMap<BasicBlock*, ValueNumberScope*>::iterator LA = + localAvail.find(Current); + if (LA == localAvail.end()) return 0; + DenseMap<uint32_t, Value*>::iterator V = LA->second->table.find(valno); + + if (V != LA->second->table.end()) { + // Found an instance, record it. + Results.insert(std::make_pair(Current, V->second)); + continue; + } + + // If we reach the beginning of the function, then give up. + if (pred_begin(Current) == pred_end(Current)) + return 0; + + for (pred_iterator PI = pred_begin(Current), PE = pred_end(Current); + PI != PE; ++PI) + if (Visited.insert(*PI)) + Stack.push_back(*PI); + } + + // If we didn't find instances, give up. Otherwise, perform phi construction. 
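+  // (Results maps each visited block that had an instance to that instance;
+  // GetValueForBlock stitches them together with PHIs.)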
+  if (Results.size() == 0)
+    return 0;
+  else
+    return GetValueForBlock(BaseBlock, orig, Results, true);
+}
+
+/// processInstruction - When calculating availability, handle an instruction
+/// by inserting it into the appropriate sets.
+bool GVN::processInstruction(Instruction *I,
+                             SmallVectorImpl<Instruction*> &toErase) {
+  if (LoadInst* L = dyn_cast<LoadInst>(I)) {
+    bool changed = processLoad(L, toErase);
+    
+    if (!changed) {
+      unsigned num = VN.lookup_or_add(L);
+      localAvail[I->getParent()]->table.insert(std::make_pair(num, L));
+    }
+    
+    return changed;
+  }
+  
+  uint32_t nextNum = VN.getNextUnusedValueNumber();
+  unsigned num = VN.lookup_or_add(I);
+  
+  if (BranchInst* BI = dyn_cast<BranchInst>(I)) {
+    localAvail[I->getParent()]->table.insert(std::make_pair(num, I));
+    
+    if (!BI->isConditional() || isa<Constant>(BI->getCondition()))
+      return false;
+    
+    Value* branchCond = BI->getCondition();
+    uint32_t condVN = VN.lookup_or_add(branchCond);
+    
+    BasicBlock* trueSucc = BI->getSuccessor(0);
+    BasicBlock* falseSucc = BI->getSuccessor(1);
+    
+    if (trueSucc->getSinglePredecessor())
+      localAvail[trueSucc]->table[condVN] = ConstantInt::getTrue();
+    if (falseSucc->getSinglePredecessor())
+      localAvail[falseSucc]->table[condVN] = ConstantInt::getFalse();
+
+    return false;
+    
+  // Allocations are always uniquely numbered, so we can save time and memory
+  // by fast failing them.
+  } else if (isa<AllocationInst>(I) || isa<TerminatorInst>(I)) {
+    localAvail[I->getParent()]->table.insert(std::make_pair(num, I));
+    return false;
+  }
+  
+  // Collapse PHI nodes
+  if (PHINode* p = dyn_cast<PHINode>(I)) {
+    Value* constVal = CollapsePhi(p);
+    
+    if (constVal) {
+      for (PhiMapType::iterator PI = phiMap.begin(), PE = phiMap.end();
+           PI != PE; ++PI)
+        PI->second.erase(p);
+        
+      p->replaceAllUsesWith(constVal);
+      if (isa<PointerType>(constVal->getType()))
+        MD->invalidateCachedPointerInfo(constVal);
+      VN.erase(p);
+      
+      toErase.push_back(p);
+    } else {
+      localAvail[I->getParent()]->table.insert(std::make_pair(num, I));
+    }
+  
+  // If the number we were assigned was a brand new VN, then we don't
+  // need to do a lookup to see if the number already exists
+  // somewhere in the domtree: it can't!
+  } else if (num == nextNum) {
+    localAvail[I->getParent()]->table.insert(std::make_pair(num, I));
+    
+  // Perform fast-path value-number based elimination of values inherited from
+  // dominators.
+  } else if (Value* repl = lookupNumber(I->getParent(), num)) {
+    // Remove it!
+    VN.erase(I);
+    I->replaceAllUsesWith(repl);
+    if (isa<PointerType>(repl->getType()))
+      MD->invalidateCachedPointerInfo(repl);
+    toErase.push_back(I);
+    return true;
+
+#if 0
+  // Perform slow-path value-number based elimination with phi construction.
+  } else if (Value* repl = AttemptRedundancyElimination(I, num)) {
+    // Remove it!
+    VN.erase(I);
+    I->replaceAllUsesWith(repl);
+    if (isa<PointerType>(repl->getType()))
+      MD->invalidateCachedPointerInfo(repl);
+    toErase.push_back(I);
+    return true;
+#endif
+  } else {
+    localAvail[I->getParent()]->table.insert(std::make_pair(num, I));
+  }
+  
+  return false;
+}
+
+/// runOnFunction - This is the main transformation entry point for a function.
+bool GVN::runOnFunction(Function& F) {
+  MD = &getAnalysis<MemoryDependenceAnalysis>();
+  DT = &getAnalysis<DominatorTree>();
+  VN.setAliasAnalysis(&getAnalysis<AliasAnalysis>());
+  VN.setMemDep(MD);
+  VN.setDomTree(DT);
+  
+  bool changed = false;
+  bool shouldContinue = true;
+  
+  // Merge unconditional branches, allowing PRE to catch more
+  // optimization opportunities.
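+  // For example (hypothetical blocks): a block B2 whose single predecessor
+  // B1 ends in an unconditional branch to B2 is folded into B1, so values
+  // that previously lived in two blocks can be value-numbered together by
+  // the iteration below.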
+  for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ) {
+    BasicBlock* BB = FI;
+    ++FI;
+    bool removedBlock = MergeBlockIntoPredecessor(BB, this);
+    if (removedBlock) NumGVNBlocks++;
+    
+    changed |= removedBlock;
+  }
+  
+  unsigned Iteration = 0;
+  
+  while (shouldContinue) {
+    DEBUG(cerr << "GVN iteration: " << Iteration << "\n");
+    shouldContinue = iterateOnFunction(F);
+    changed |= shouldContinue;
+    ++Iteration;
+  }
+  
+  if (EnablePRE) {
+    bool PREChanged = true;
+    while (PREChanged) {
+      PREChanged = performPRE(F);
+      changed |= PREChanged;
+    }
+  }
+  // FIXME: Should perform GVN again after PRE does something.  PRE can move
+  // computations into blocks where they become fully redundant.  Note that
+  // we can't do this until PRE's critical edge splitting updates memdep.
+  // Actually, when this happens, we should just fully integrate PRE into GVN.
+
+  cleanupGlobalSets();
+
+  return changed;
+}
+
+
+bool GVN::processBlock(BasicBlock* BB) {
+  // FIXME: Kill off toErase by doing erasing eagerly in a helper function (and
+  // incrementing BI before processing an instruction).
+  SmallVector<Instruction*, 8> toErase;
+  bool changed_function = false;
+  
+  for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
+       BI != BE;) {
+    changed_function |= processInstruction(BI, toErase);
+    if (toErase.empty()) {
+      ++BI;
+      continue;
+    }
+    
+    // If we need some instructions deleted, do it now.
+    NumGVNInstr += toErase.size();
+    
+    // Avoid iterator invalidation.
+    bool AtStart = BI == BB->begin();
+    if (!AtStart)
+      --BI;
+
+    for (SmallVector<Instruction*, 8>::iterator I = toErase.begin(),
+         E = toErase.end(); I != E; ++I) {
+      DEBUG(cerr << "GVN removed: " << **I);
+      MD->removeInstruction(*I);
+      (*I)->eraseFromParent();
+      DEBUG(verifyRemoved(*I));
+    }
+    toErase.clear();
+
+    if (AtStart)
+      BI = BB->begin();
+    else
+      ++BI;
+  }
+  
+  return changed_function;
+}
+
+/// performPRE - Perform a purely local form of PRE that looks for diamond
+/// control flow patterns and attempts to perform simple PRE at the join point.
+bool GVN::performPRE(Function& F) {
+  bool Changed = false;
+  SmallVector<std::pair<TerminatorInst*, unsigned>, 4> toSplit;
+  DenseMap<BasicBlock*, Value*> predMap;
+  for (df_iterator<BasicBlock*> DI = df_begin(&F.getEntryBlock()),
+       DE = df_end(&F.getEntryBlock()); DI != DE; ++DI) {
+    BasicBlock* CurrentBlock = *DI;
+    
+    // Nothing to PRE in the entry block.
+    if (CurrentBlock == &F.getEntryBlock()) continue;
+    
+    for (BasicBlock::iterator BI = CurrentBlock->begin(),
+         BE = CurrentBlock->end(); BI != BE; ) {
+      Instruction *CurInst = BI++;
+      
+      if (isa<AllocationInst>(CurInst) || isa<TerminatorInst>(CurInst) ||
+          isa<PHINode>(CurInst) || (CurInst->getType() == Type::VoidTy) ||
+          CurInst->mayReadFromMemory() || CurInst->mayHaveSideEffects() ||
+          isa<DbgInfoIntrinsic>(CurInst))
+        continue;
+
+      uint32_t valno = VN.lookup(CurInst);
+      
+      // Look for the predecessors for PRE opportunities.  We're
+      // only trying to solve the basic diamond case, where
+      // a value is computed in the successor and one predecessor,
+      // but not the other.  We also explicitly disallow cases
+      // where the successor is its own predecessor, because they're
+      // more complicated to get right.
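+      // Shape being matched (hypothetical blocks and values):
+      //
+      //     [ P1: %t1 = add %a, %b ]      [ P2 ]
+      //                   \               /
+      //         [ CurrentBlock: %t2 = add %a, %b ]
+      //
+      // %t2 is partially redundant: insert the add at the end of P2, then
+      // join the two copies with a phi in CurrentBlock and delete %t2.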
+      unsigned numWith = 0;
+      unsigned numWithout = 0;
+      BasicBlock* PREPred = 0;
+      predMap.clear();
+
+      for (pred_iterator PI = pred_begin(CurrentBlock),
+           PE = pred_end(CurrentBlock); PI != PE; ++PI) {
+        // We're not interested in PRE where the block is its
+        // own predecessor, or in blocks with predecessors
+        // that are not reachable.
+        if (*PI == CurrentBlock) {
+          numWithout = 2;
+          break;
+        } else if (!localAvail.count(*PI)) {
+          numWithout = 2;
+          break;
+        }
+        
+        DenseMap<uint32_t, Value*>::iterator predV =
+                                            localAvail[*PI]->table.find(valno);
+        if (predV == localAvail[*PI]->table.end()) {
+          PREPred = *PI;
+          numWithout++;
+        } else if (predV->second == CurInst) {
+          numWithout = 2;
+        } else {
+          predMap[*PI] = predV->second;
+          numWith++;
+        }
+      }
+      
+      // Don't do PRE when it might increase code size, i.e. when
+      // we would need to insert instructions in more than one pred.
+      if (numWithout != 1 || numWith == 0)
+        continue;
+      
+      // We can't do PRE safely on a critical edge, so instead we schedule
+      // the edge to be split and perform the PRE the next time we iterate
+      // on the function.
+      unsigned succNum = 0;
+      for (unsigned i = 0, e = PREPred->getTerminator()->getNumSuccessors();
+           i != e; ++i)
+        if (PREPred->getTerminator()->getSuccessor(i) == CurrentBlock) {
+          succNum = i;
+          break;
+        }
+        
+      if (isCriticalEdge(PREPred->getTerminator(), succNum)) {
+        toSplit.push_back(std::make_pair(PREPred->getTerminator(), succNum));
+        continue;
+      }
+      
+      // Instantiate the expression in the predecessor that lacked it.
+      // Because we are going top-down through the block, all value numbers
+      // will be available in the predecessor by the time we need them.  Any
+      // that weren't originally present will have been instantiated earlier
+      // in this loop.
+      Instruction* PREInstr = CurInst->clone();
+      bool success = true;
+      for (unsigned i = 0, e = CurInst->getNumOperands(); i != e; ++i) {
+        Value *Op = PREInstr->getOperand(i);
+        if (isa<Argument>(Op) || isa<Constant>(Op) || isa<GlobalValue>(Op))
+          continue;
+        
+        if (Value *V = lookupNumber(PREPred, VN.lookup(Op))) {
+          PREInstr->setOperand(i, V);
+        } else {
+          success = false;
+          break;
+        }
+      }
+      
+      // Fail out if we encounter an operand that is not available in
+      // the PRE predecessor.  This is typically because of loads which
+      // are not value numbered precisely.
+      if (!success) {
+        DEBUG(verifyRemoved(PREInstr));
+        delete PREInstr;
+        continue;
+      }
+      
+      PREInstr->insertBefore(PREPred->getTerminator());
+      PREInstr->setName(CurInst->getName() + ".pre");
+      predMap[PREPred] = PREInstr;
+      VN.add(PREInstr, valno);
+      NumGVNPRE++;
+      
+      // Update the availability map to include the new instruction.
+      localAvail[PREPred]->table.insert(std::make_pair(valno, PREInstr));
+      
+      // Create a PHI to make the value available in this block.
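+      // With hypothetical names (%t2 the original, %t2.pre the copy inserted
+      // into pred %P2, %t1 the value already available out of %P1) the result
+      // is:
+      //   %t2.pre-phi = phi [ %t2.pre, %P2 ], [ %t1, %P1 ]
+      // and all uses of %t2 are rewritten to the phi.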
+ PHINode* Phi = PHINode::Create(CurInst->getType(), + CurInst->getName() + ".pre-phi", + CurrentBlock->begin()); + for (pred_iterator PI = pred_begin(CurrentBlock), + PE = pred_end(CurrentBlock); PI != PE; ++PI) + Phi->addIncoming(predMap[*PI], *PI); + + VN.add(Phi, valno); + localAvail[CurrentBlock]->table[valno] = Phi; + + CurInst->replaceAllUsesWith(Phi); + if (isa<PointerType>(Phi->getType())) + MD->invalidateCachedPointerInfo(Phi); + VN.erase(CurInst); + + DEBUG(cerr << "GVN PRE removed: " << *CurInst); + MD->removeInstruction(CurInst); + CurInst->eraseFromParent(); + DEBUG(verifyRemoved(CurInst)); + Changed = true; + } + } + + for (SmallVector<std::pair<TerminatorInst*, unsigned>, 4>::iterator + I = toSplit.begin(), E = toSplit.end(); I != E; ++I) + SplitCriticalEdge(I->first, I->second, this); + + return Changed || toSplit.size(); +} + +/// iterateOnFunction - Executes one iteration of GVN +bool GVN::iterateOnFunction(Function &F) { + cleanupGlobalSets(); + + for (df_iterator<DomTreeNode*> DI = df_begin(DT->getRootNode()), + DE = df_end(DT->getRootNode()); DI != DE; ++DI) { + if (DI->getIDom()) + localAvail[DI->getBlock()] = + new ValueNumberScope(localAvail[DI->getIDom()->getBlock()]); + else + localAvail[DI->getBlock()] = new ValueNumberScope(0); + } + + // Top-down walk of the dominator tree + bool changed = false; +#if 0 + // Needed for value numbering with phi construction to work. + ReversePostOrderTraversal<Function*> RPOT(&F); + for (ReversePostOrderTraversal<Function*>::rpo_iterator RI = RPOT.begin(), + RE = RPOT.end(); RI != RE; ++RI) + changed |= processBlock(*RI); +#else + for (df_iterator<DomTreeNode*> DI = df_begin(DT->getRootNode()), + DE = df_end(DT->getRootNode()); DI != DE; ++DI) + changed |= processBlock(DI->getBlock()); +#endif + + return changed; +} + +void GVN::cleanupGlobalSets() { + VN.clear(); + phiMap.clear(); + + for (DenseMap<BasicBlock*, ValueNumberScope*>::iterator + I = localAvail.begin(), E = localAvail.end(); I != E; ++I) + delete I->second; + localAvail.clear(); +} + +/// verifyRemoved - Verify that the specified instruction does not occur in our +/// internal data structures. +void GVN::verifyRemoved(const Instruction *Inst) const { + VN.verifyRemoved(Inst); + + // Walk through the PHI map to make sure the instruction isn't hiding in there + // somewhere. + for (PhiMapType::iterator + I = phiMap.begin(), E = phiMap.end(); I != E; ++I) { + assert(I->first != Inst && "Inst is still a key in PHI map!"); + + for (SmallPtrSet<Instruction*, 4>::iterator + II = I->second.begin(), IE = I->second.end(); II != IE; ++II) { + assert(*II != Inst && "Inst is still a value in PHI map!"); + } + } + + // Walk through the value number scope to make sure the instruction isn't + // ferreted away in it. 
+  for (DenseMap<BasicBlock*, ValueNumberScope*>::iterator
+       I = localAvail.begin(), E = localAvail.end(); I != E; ++I) {
+    const ValueNumberScope *VNS = I->second;
+    
+    while (VNS) {
+      for (DenseMap<uint32_t, Value*>::iterator
+           II = VNS->table.begin(), IE = VNS->table.end(); II != IE; ++II) {
+        assert(II->second != Inst && "Inst still in value numbering scope!");
+      }
+      
+      VNS = VNS->parent;
+    }
+  }
+}
diff --git a/lib/Transforms/Scalar/GVNPRE.cpp b/lib/Transforms/Scalar/GVNPRE.cpp
new file mode 100644
index 0000000..e3b0937
--- /dev/null
+++ b/lib/Transforms/Scalar/GVNPRE.cpp
@@ -0,0 +1,1885 @@
+//===- GVNPRE.cpp - Eliminate redundant values and expressions -----------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass performs a hybrid of global value numbering and partial redundancy
+// elimination, known as GVN-PRE.  It performs partial redundancy elimination
+// on values, rather than lexical expressions, allowing a more comprehensive
+// view of the optimization.  It replaces redundant values with uses of earlier
+// occurrences of the same value.  While this is beneficial in that it
+// eliminates unneeded computation, it also increases register pressure by
+// creating large live ranges, and should be used with caution on platforms
+// that are very sensitive to register pressure.
+//
+// Note that this pass does the value numbering itself; it does not use the
+// ValueNumbering analysis passes.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "gvnpre"
+#include "llvm/Value.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Instructions.h"
+#include "llvm/Function.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Analysis/Dominators.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
+#include "llvm/Support/CFG.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include <algorithm>
+#include <deque>
+#include <map>
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+//                         ValueTable Class
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// This class holds the mapping between values and value numbers.  It is used
+/// as an efficient mechanism to determine the expression-wise equivalence of
+/// two values.
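+// As a minimal sketch of the numbering scheme below (the values here are
+// illustrative, not part of the pass): two instructions receive the same
+// number exactly when their (opcode, type, operand-number) tuples compare
+// equal, e.g.
+//
+//   %a = add i32 %x, %y      ; VN(%a) == VN(%b), because both map to the
+//   %b = add i32 %x, %y      ; tuple (ADD, i32, VN(%x), VN(%y))
+//   %c = add i32 %y, %x      ; distinct: operands in a different order
+//
+// Opaque values (arguments, loads, calls, ...) get fresh numbers instead.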
+ +struct Expression { + enum ExpressionOpcode { ADD, SUB, MUL, UDIV, SDIV, FDIV, UREM, SREM, + FREM, SHL, LSHR, ASHR, AND, OR, XOR, ICMPEQ, + ICMPNE, ICMPUGT, ICMPUGE, ICMPULT, ICMPULE, + ICMPSGT, ICMPSGE, ICMPSLT, ICMPSLE, FCMPOEQ, + FCMPOGT, FCMPOGE, FCMPOLT, FCMPOLE, FCMPONE, + FCMPORD, FCMPUNO, FCMPUEQ, FCMPUGT, FCMPUGE, + FCMPULT, FCMPULE, FCMPUNE, EXTRACT, INSERT, + SHUFFLE, SELECT, TRUNC, ZEXT, SEXT, FPTOUI, + FPTOSI, UITOFP, SITOFP, FPTRUNC, FPEXT, + PTRTOINT, INTTOPTR, BITCAST, GEP, EMPTY, + TOMBSTONE }; + + ExpressionOpcode opcode; + const Type* type; + uint32_t firstVN; + uint32_t secondVN; + uint32_t thirdVN; + SmallVector<uint32_t, 4> varargs; + + Expression() { } + explicit Expression(ExpressionOpcode o) : opcode(o) { } + + bool operator==(const Expression &other) const { + if (opcode != other.opcode) + return false; + else if (opcode == EMPTY || opcode == TOMBSTONE) + return true; + else if (type != other.type) + return false; + else if (firstVN != other.firstVN) + return false; + else if (secondVN != other.secondVN) + return false; + else if (thirdVN != other.thirdVN) + return false; + else { + if (varargs.size() != other.varargs.size()) + return false; + + for (size_t i = 0; i < varargs.size(); ++i) + if (varargs[i] != other.varargs[i]) + return false; + + return true; + } + } + + bool operator!=(const Expression &other) const { + if (opcode != other.opcode) + return true; + else if (opcode == EMPTY || opcode == TOMBSTONE) + return false; + else if (type != other.type) + return true; + else if (firstVN != other.firstVN) + return true; + else if (secondVN != other.secondVN) + return true; + else if (thirdVN != other.thirdVN) + return true; + else { + if (varargs.size() != other.varargs.size()) + return true; + + for (size_t i = 0; i < varargs.size(); ++i) + if (varargs[i] != other.varargs[i]) + return true; + + return false; + } + } +}; + +} + +namespace { + class VISIBILITY_HIDDEN ValueTable { + private: + DenseMap<Value*, uint32_t> valueNumbering; + DenseMap<Expression, uint32_t> expressionNumbering; + + uint32_t nextValueNumber; + + Expression::ExpressionOpcode getOpcode(BinaryOperator* BO); + Expression::ExpressionOpcode getOpcode(CmpInst* C); + Expression::ExpressionOpcode getOpcode(CastInst* C); + Expression create_expression(BinaryOperator* BO); + Expression create_expression(CmpInst* C); + Expression create_expression(ShuffleVectorInst* V); + Expression create_expression(ExtractElementInst* C); + Expression create_expression(InsertElementInst* V); + Expression create_expression(SelectInst* V); + Expression create_expression(CastInst* C); + Expression create_expression(GetElementPtrInst* G); + public: + ValueTable() { nextValueNumber = 1; } + uint32_t lookup_or_add(Value* V); + uint32_t lookup(Value* V) const; + void add(Value* V, uint32_t num); + void clear(); + void erase(Value* v); + unsigned size(); + }; +} + +namespace llvm { +template <> struct DenseMapInfo<Expression> { + static inline Expression getEmptyKey() { + return Expression(Expression::EMPTY); + } + + static inline Expression getTombstoneKey() { + return Expression(Expression::TOMBSTONE); + } + + static unsigned getHashValue(const Expression e) { + unsigned hash = e.opcode; + + hash = e.firstVN + hash * 37; + hash = e.secondVN + hash * 37; + hash = e.thirdVN + hash * 37; + + hash = ((unsigned)((uintptr_t)e.type >> 4) ^ + (unsigned)((uintptr_t)e.type >> 9)) + + hash * 37; + + for (SmallVector<uint32_t, 4>::const_iterator I = e.varargs.begin(), + E = e.varargs.end(); I != E; ++I) + hash = *I + hash * 
37; + + return hash; + } + static bool isEqual(const Expression &LHS, const Expression &RHS) { + return LHS == RHS; + } + static bool isPod() { return true; } +}; +} + +//===----------------------------------------------------------------------===// +// ValueTable Internal Functions +//===----------------------------------------------------------------------===// +Expression::ExpressionOpcode + ValueTable::getOpcode(BinaryOperator* BO) { + switch(BO->getOpcode()) { + case Instruction::Add: + return Expression::ADD; + case Instruction::Sub: + return Expression::SUB; + case Instruction::Mul: + return Expression::MUL; + case Instruction::UDiv: + return Expression::UDIV; + case Instruction::SDiv: + return Expression::SDIV; + case Instruction::FDiv: + return Expression::FDIV; + case Instruction::URem: + return Expression::UREM; + case Instruction::SRem: + return Expression::SREM; + case Instruction::FRem: + return Expression::FREM; + case Instruction::Shl: + return Expression::SHL; + case Instruction::LShr: + return Expression::LSHR; + case Instruction::AShr: + return Expression::ASHR; + case Instruction::And: + return Expression::AND; + case Instruction::Or: + return Expression::OR; + case Instruction::Xor: + return Expression::XOR; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Binary operator with unknown opcode?"); + return Expression::ADD; + } +} + +Expression::ExpressionOpcode ValueTable::getOpcode(CmpInst* C) { + if (C->getOpcode() == Instruction::ICmp) { + switch (C->getPredicate()) { + case ICmpInst::ICMP_EQ: + return Expression::ICMPEQ; + case ICmpInst::ICMP_NE: + return Expression::ICMPNE; + case ICmpInst::ICMP_UGT: + return Expression::ICMPUGT; + case ICmpInst::ICMP_UGE: + return Expression::ICMPUGE; + case ICmpInst::ICMP_ULT: + return Expression::ICMPULT; + case ICmpInst::ICMP_ULE: + return Expression::ICMPULE; + case ICmpInst::ICMP_SGT: + return Expression::ICMPSGT; + case ICmpInst::ICMP_SGE: + return Expression::ICMPSGE; + case ICmpInst::ICMP_SLT: + return Expression::ICMPSLT; + case ICmpInst::ICMP_SLE: + return Expression::ICMPSLE; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Comparison with unknown predicate?"); + return Expression::ICMPEQ; + } + } else { + switch (C->getPredicate()) { + case FCmpInst::FCMP_OEQ: + return Expression::FCMPOEQ; + case FCmpInst::FCMP_OGT: + return Expression::FCMPOGT; + case FCmpInst::FCMP_OGE: + return Expression::FCMPOGE; + case FCmpInst::FCMP_OLT: + return Expression::FCMPOLT; + case FCmpInst::FCMP_OLE: + return Expression::FCMPOLE; + case FCmpInst::FCMP_ONE: + return Expression::FCMPONE; + case FCmpInst::FCMP_ORD: + return Expression::FCMPORD; + case FCmpInst::FCMP_UNO: + return Expression::FCMPUNO; + case FCmpInst::FCMP_UEQ: + return Expression::FCMPUEQ; + case FCmpInst::FCMP_UGT: + return Expression::FCMPUGT; + case FCmpInst::FCMP_UGE: + return Expression::FCMPUGE; + case FCmpInst::FCMP_ULT: + return Expression::FCMPULT; + case FCmpInst::FCMP_ULE: + return Expression::FCMPULE; + case FCmpInst::FCMP_UNE: + return Expression::FCMPUNE; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Comparison with unknown predicate?"); + return Expression::FCMPOEQ; + } + } +} + +Expression::ExpressionOpcode + ValueTable::getOpcode(CastInst* C) { + switch(C->getOpcode()) { + case Instruction::Trunc: + return Expression::TRUNC; + case Instruction::ZExt: + return Expression::ZEXT; + case Instruction::SExt: + return Expression::SEXT; + case Instruction::FPToUI: + return Expression::FPTOUI; + case Instruction::FPToSI: + return 
Expression::FPTOSI; + case Instruction::UIToFP: + return Expression::UITOFP; + case Instruction::SIToFP: + return Expression::SITOFP; + case Instruction::FPTrunc: + return Expression::FPTRUNC; + case Instruction::FPExt: + return Expression::FPEXT; + case Instruction::PtrToInt: + return Expression::PTRTOINT; + case Instruction::IntToPtr: + return Expression::INTTOPTR; + case Instruction::BitCast: + return Expression::BITCAST; + + // THIS SHOULD NEVER HAPPEN + default: + assert(0 && "Cast operator with unknown opcode?"); + return Expression::BITCAST; + } +} + +Expression ValueTable::create_expression(BinaryOperator* BO) { + Expression e; + + e.firstVN = lookup_or_add(BO->getOperand(0)); + e.secondVN = lookup_or_add(BO->getOperand(1)); + e.thirdVN = 0; + e.type = BO->getType(); + e.opcode = getOpcode(BO); + + return e; +} + +Expression ValueTable::create_expression(CmpInst* C) { + Expression e; + + e.firstVN = lookup_or_add(C->getOperand(0)); + e.secondVN = lookup_or_add(C->getOperand(1)); + e.thirdVN = 0; + e.type = C->getType(); + e.opcode = getOpcode(C); + + return e; +} + +Expression ValueTable::create_expression(CastInst* C) { + Expression e; + + e.firstVN = lookup_or_add(C->getOperand(0)); + e.secondVN = 0; + e.thirdVN = 0; + e.type = C->getType(); + e.opcode = getOpcode(C); + + return e; +} + +Expression ValueTable::create_expression(ShuffleVectorInst* S) { + Expression e; + + e.firstVN = lookup_or_add(S->getOperand(0)); + e.secondVN = lookup_or_add(S->getOperand(1)); + e.thirdVN = lookup_or_add(S->getOperand(2)); + e.type = S->getType(); + e.opcode = Expression::SHUFFLE; + + return e; +} + +Expression ValueTable::create_expression(ExtractElementInst* E) { + Expression e; + + e.firstVN = lookup_or_add(E->getOperand(0)); + e.secondVN = lookup_or_add(E->getOperand(1)); + e.thirdVN = 0; + e.type = E->getType(); + e.opcode = Expression::EXTRACT; + + return e; +} + +Expression ValueTable::create_expression(InsertElementInst* I) { + Expression e; + + e.firstVN = lookup_or_add(I->getOperand(0)); + e.secondVN = lookup_or_add(I->getOperand(1)); + e.thirdVN = lookup_or_add(I->getOperand(2)); + e.type = I->getType(); + e.opcode = Expression::INSERT; + + return e; +} + +Expression ValueTable::create_expression(SelectInst* I) { + Expression e; + + e.firstVN = lookup_or_add(I->getCondition()); + e.secondVN = lookup_or_add(I->getTrueValue()); + e.thirdVN = lookup_or_add(I->getFalseValue()); + e.type = I->getType(); + e.opcode = Expression::SELECT; + + return e; +} + +Expression ValueTable::create_expression(GetElementPtrInst* G) { + Expression e; + + e.firstVN = lookup_or_add(G->getPointerOperand()); + e.secondVN = 0; + e.thirdVN = 0; + e.type = G->getType(); + e.opcode = Expression::GEP; + + for (GetElementPtrInst::op_iterator I = G->idx_begin(), E = G->idx_end(); + I != E; ++I) + e.varargs.push_back(lookup_or_add(*I)); + + return e; +} + +//===----------------------------------------------------------------------===// +// ValueTable External Functions +//===----------------------------------------------------------------------===// + +/// lookup_or_add - Returns the value number for the specified value, assigning +/// it a new number if it did not have one before. 
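+/// For example (hypothetical values): the first call on %a = add i32 %x, %y
+/// numbers %x, %y, and %a; a later call on %b = add i32 %x, %y builds the
+/// same Expression key and therefore returns the number already given to %a.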
+uint32_t ValueTable::lookup_or_add(Value* V) { + DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + return VI->second; + + + if (BinaryOperator* BO = dyn_cast<BinaryOperator>(V)) { + Expression e = create_expression(BO); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (CmpInst* C = dyn_cast<CmpInst>(V)) { + Expression e = create_expression(C); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (ShuffleVectorInst* U = dyn_cast<ShuffleVectorInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (ExtractElementInst* U = dyn_cast<ExtractElementInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (InsertElementInst* U = dyn_cast<InsertElementInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (SelectInst* U = dyn_cast<SelectInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (CastInst* U = dyn_cast<CastInst>(V)) { + Expression e = create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else if (GetElementPtrInst* U = dyn_cast<GetElementPtrInst>(V)) { + Expression e = 
create_expression(U); + + DenseMap<Expression, uint32_t>::iterator EI = expressionNumbering.find(e); + if (EI != expressionNumbering.end()) { + valueNumbering.insert(std::make_pair(V, EI->second)); + return EI->second; + } else { + expressionNumbering.insert(std::make_pair(e, nextValueNumber)); + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + + return nextValueNumber++; + } + } else { + valueNumbering.insert(std::make_pair(V, nextValueNumber)); + return nextValueNumber++; + } +} + +/// lookup - Returns the value number of the specified value. Fails if +/// the value has not yet been numbered. +uint32_t ValueTable::lookup(Value* V) const { + DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + return VI->second; + else + assert(0 && "Value not numbered?"); + + return 0; +} + +/// add - Add the specified value with the given value number, removing +/// its old number, if any +void ValueTable::add(Value* V, uint32_t num) { + DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + valueNumbering.erase(VI); + valueNumbering.insert(std::make_pair(V, num)); +} + +/// clear - Remove all entries from the ValueTable +void ValueTable::clear() { + valueNumbering.clear(); + expressionNumbering.clear(); + nextValueNumber = 1; +} + +/// erase - Remove a value from the value numbering +void ValueTable::erase(Value* V) { + valueNumbering.erase(V); +} + +/// size - Return the number of assigned value numbers +unsigned ValueTable::size() { + // NOTE: zero is never assigned + return nextValueNumber; +} + +namespace { + +//===----------------------------------------------------------------------===// +// ValueNumberedSet Class +//===----------------------------------------------------------------------===// + +class ValueNumberedSet { + private: + SmallPtrSet<Value*, 8> contents; + BitVector numbers; + public: + ValueNumberedSet() { numbers.resize(1); } + ValueNumberedSet(const ValueNumberedSet& other) { + numbers = other.numbers; + contents = other.contents; + } + + typedef SmallPtrSet<Value*, 8>::iterator iterator; + + iterator begin() { return contents.begin(); } + iterator end() { return contents.end(); } + + bool insert(Value* v) { return contents.insert(v); } + void insert(iterator I, iterator E) { contents.insert(I, E); } + void erase(Value* v) { contents.erase(v); } + unsigned count(Value* v) { return contents.count(v); } + size_t size() { return contents.size(); } + + void set(unsigned i) { + if (i >= numbers.size()) + numbers.resize(i+1); + + numbers.set(i); + } + + void operator=(const ValueNumberedSet& other) { + contents = other.contents; + numbers = other.numbers; + } + + void reset(unsigned i) { + if (i < numbers.size()) + numbers.reset(i); + } + + bool test(unsigned i) { + if (i >= numbers.size()) + return false; + + return numbers.test(i); + } + + void clear() { + contents.clear(); + numbers.clear(); + } +}; + +} + +//===----------------------------------------------------------------------===// +// GVNPRE Pass +//===----------------------------------------------------------------------===// + +namespace { + + class VISIBILITY_HIDDEN GVNPRE : public FunctionPass { + bool runOnFunction(Function &F); + public: + static char ID; // Pass identification, replacement for typeid + GVNPRE() : FunctionPass(&ID) {} + + private: + ValueTable VN; + SmallVector<Instruction*, 8> createdExpressions; + + DenseMap<BasicBlock*, ValueNumberedSet> availableOut; + DenseMap<BasicBlock*, ValueNumberedSet> 
anticipatedIn;
+    DenseMap<BasicBlock*, ValueNumberedSet> generatedPhis;
+
+    // This transformation requires dominator and postdominator info
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.setPreservesCFG();
+      AU.addRequiredID(BreakCriticalEdgesID);
+      AU.addRequired<UnifyFunctionExitNodes>();
+      AU.addRequired<DominatorTree>();
+    }
+  
+    // Helper functions
+    // FIXME: eliminate or document these better
+    void dump(ValueNumberedSet& s) const;
+    void clean(ValueNumberedSet& set);
+    Value* find_leader(ValueNumberedSet& vals, uint32_t v);
+    Value* phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ);
+    void phi_translate_set(ValueNumberedSet& anticIn, BasicBlock* pred,
+                           BasicBlock* succ, ValueNumberedSet& out);
+    
+    void topo_sort(ValueNumberedSet& set,
+                   SmallVector<Value*, 8>& vec);
+    
+    void cleanup();
+    bool elimination();
+    
+    void val_insert(ValueNumberedSet& s, Value* v);
+    void val_replace(ValueNumberedSet& s, Value* v);
+    bool dependsOnInvoke(Value* V);
+    void buildsets_availout(BasicBlock::iterator I,
+                            ValueNumberedSet& currAvail,
+                            ValueNumberedSet& currPhis,
+                            ValueNumberedSet& currExps,
+                            SmallPtrSet<Value*, 16>& currTemps);
+    bool buildsets_anticout(BasicBlock* BB,
+                            ValueNumberedSet& anticOut,
+                            SmallPtrSet<BasicBlock*, 8>& visited);
+    unsigned buildsets_anticin(BasicBlock* BB,
+                               ValueNumberedSet& anticOut,
+                               ValueNumberedSet& currExps,
+                               SmallPtrSet<Value*, 16>& currTemps,
+                               SmallPtrSet<BasicBlock*, 8>& visited);
+    void buildsets(Function& F);
+    
+    void insertion_pre(Value* e, BasicBlock* BB,
+                       DenseMap<BasicBlock*, Value*>& avail,
+                       std::map<BasicBlock*, ValueNumberedSet>& new_set);
+    unsigned insertion_mergepoint(SmallVector<Value*, 8>& workList,
+                                  df_iterator<DomTreeNode*>& D,
+                                  std::map<BasicBlock*, ValueNumberedSet>& new_set);
+    bool insertion(Function& F);
+    
+  };
+  
+  char GVNPRE::ID = 0;
+  
+}
+
+// createGVNPREPass - The public interface to this file...
+FunctionPass *llvm::createGVNPREPass() { return new GVNPRE(); }
+
+static RegisterPass<GVNPRE> X("gvnpre",
+                      "Global Value Numbering/Partial Redundancy Elimination");
+
+
+STATISTIC(NumInsertedVals, "Number of values inserted");
+STATISTIC(NumInsertedPhis, "Number of PHI nodes inserted");
+STATISTIC(NumEliminated, "Number of redundant instructions eliminated");
+
+/// find_leader - Given a set and a value number, return the first
+/// element of the set with that value number, or 0 if no such element
+/// is present.
+Value* GVNPRE::find_leader(ValueNumberedSet& vals, uint32_t v) {
+  if (!vals.test(v))
+    return 0;
+  
+  for (ValueNumberedSet::iterator I = vals.begin(), E = vals.end();
+       I != E; ++I)
+    if (v == VN.lookup(*I))
+      return *I;
+  
+  assert(0 && "No leader found, but present bit is set?");
+  return 0;
+}
+
+/// val_insert - Insert a value into a set only if there is not a value
+/// with the same value number already in the set.
+void GVNPRE::val_insert(ValueNumberedSet& s, Value* v) {
+  uint32_t num = VN.lookup(v);
+  if (!s.test(num))
+    s.insert(v);
+}
+
+/// val_replace - Insert a value into a set, replacing any values already in
+/// the set that have the same value number.
+void GVNPRE::val_replace(ValueNumberedSet& s, Value* v) {
+  if (s.count(v)) return;
+  
+  uint32_t num = VN.lookup(v);
+  Value* leader = find_leader(s, num);
+  if (leader != 0)
+    s.erase(leader);
+  s.insert(v);
+  s.set(num);
+}
+
+/// phi_translate - Given a value, its parent block, and a predecessor of its
+/// parent, translate the value into a form legal for the predecessor.  This
+/// means translating its operands (and recursively, their operands) through
+/// any phi nodes in the parent into values available in the predecessor.
+Value* GVNPRE::phi_translate(Value* V, BasicBlock* pred, BasicBlock* succ) {
+  if (V == 0)
+    return 0;
+  
+  // Unary Operations
+  if (CastInst* U = dyn_cast<CastInst>(V)) {
+    Value* newOp1 = 0;
+    if (isa<Instruction>(U->getOperand(0)))
+      newOp1 = phi_translate(U->getOperand(0), pred, succ);
+    else
+      newOp1 = U->getOperand(0);
+    
+    if (newOp1 == 0)
+      return 0;
+    
+    if (newOp1 != U->getOperand(0)) {
+      Instruction* newVal = 0;
+      if (CastInst* C = dyn_cast<CastInst>(U))
+        newVal = CastInst::Create(C->getOpcode(),
+                                  newOp1, C->getType(),
+                                  C->getName()+".expr");
+      
+      uint32_t v = VN.lookup_or_add(newVal);
+      
+      Value* leader = find_leader(availableOut[pred], v);
+      if (leader == 0) {
+        createdExpressions.push_back(newVal);
+        return newVal;
+      } else {
+        VN.erase(newVal);
+        delete newVal;
+        return leader;
+      }
+    }
+  
+  // Binary Operations
+  } else if (isa<BinaryOperator>(V) || isa<CmpInst>(V) ||
+             isa<ExtractElementInst>(V)) {
+    User* U = cast<User>(V);
+    
+    Value* newOp1 = 0;
+    if (isa<Instruction>(U->getOperand(0)))
+      newOp1 = phi_translate(U->getOperand(0), pred, succ);
+    else
+      newOp1 = U->getOperand(0);
+    
+    if (newOp1 == 0)
+      return 0;
+    
+    Value* newOp2 = 0;
+    if (isa<Instruction>(U->getOperand(1)))
+      newOp2 = phi_translate(U->getOperand(1), pred, succ);
+    else
+      newOp2 = U->getOperand(1);
+    
+    if (newOp2 == 0)
+      return 0;
+    
+    if (newOp1 != U->getOperand(0) || newOp2 != U->getOperand(1)) {
+      Instruction* newVal = 0;
+      if (BinaryOperator* BO = dyn_cast<BinaryOperator>(U))
+        newVal = BinaryOperator::Create(BO->getOpcode(),
+                                        newOp1, newOp2,
+                                        BO->getName()+".expr");
+      else if (CmpInst* C = dyn_cast<CmpInst>(U))
+        newVal = CmpInst::Create(C->getOpcode(),
+                                 C->getPredicate(),
+                                 newOp1, newOp2,
+                                 C->getName()+".expr");
+      else if (ExtractElementInst* E = dyn_cast<ExtractElementInst>(U))
+        newVal = new ExtractElementInst(newOp1, newOp2, E->getName()+".expr");
+      
+      uint32_t v = VN.lookup_or_add(newVal);
+      
+      Value* leader = find_leader(availableOut[pred], v);
+      if (leader == 0) {
+        createdExpressions.push_back(newVal);
+        return newVal;
+      } else {
+        VN.erase(newVal);
+        delete newVal;
+        return leader;
+      }
+    }
+  
+  // Ternary Operations
+  } else if (isa<ShuffleVectorInst>(V) || isa<InsertElementInst>(V) ||
+             isa<SelectInst>(V)) {
+    User* U = cast<User>(V);
+    
+    Value* newOp1 = 0;
+    if (isa<Instruction>(U->getOperand(0)))
+      newOp1 = phi_translate(U->getOperand(0), pred, succ);
+    else
+      newOp1 = U->getOperand(0);
+    
+    if (newOp1 == 0)
+      return 0;
+    
+    Value* newOp2 = 0;
+    if (isa<Instruction>(U->getOperand(1)))
+      newOp2 = phi_translate(U->getOperand(1), pred, succ);
+    else
+      newOp2 = U->getOperand(1);
+    
+    if (newOp2 == 0)
+      return 0;
+    
+    Value* newOp3 = 0;
+    if (isa<Instruction>(U->getOperand(2)))
+      newOp3 = phi_translate(U->getOperand(2), pred, succ);
+    else
+      newOp3 = U->getOperand(2);
+    
+    if (newOp3 == 0)
+      return 0;
+    
+    if (newOp1 != U->getOperand(0) ||
+        newOp2 != U->getOperand(1) ||
+        newOp3 != U->getOperand(2)) {
+      Instruction* newVal = 0;
+      if (ShuffleVectorInst* S = dyn_cast<ShuffleVectorInst>(U))
+        newVal = new ShuffleVectorInst(newOp1, newOp2, newOp3,
+                                       S->getName() + ".expr");
+      else if (InsertElementInst* I = dyn_cast<InsertElementInst>(U))
+        newVal = InsertElementInst::Create(newOp1, newOp2, newOp3,
+                                           I->getName() + ".expr");
+      else if (SelectInst* I = dyn_cast<SelectInst>(U))
+        newVal = SelectInst::Create(newOp1, newOp2, newOp3,
+                                    I->getName() + ".expr");
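+      // Note that newVal is deliberately created detached from any basic
+      // block: if an equivalent leader already exists in the predecessor it
+      // is deleted again just below; otherwise it is kept on
+      // createdExpressions, purely as a placeholder for the translated
+      // expression, and freed later by cleanup().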
+
+      uint32_t v = VN.lookup_or_add(newVal);
+      
+      Value* leader = find_leader(availableOut[pred], v);
+      if (leader == 0) {
+        createdExpressions.push_back(newVal);
+        return newVal;
+      } else {
+        VN.erase(newVal);
+        delete newVal;
+        return leader;
+      }
+    }
+  
+  // Varargs operators
+  } else if (GetElementPtrInst* U = dyn_cast<GetElementPtrInst>(V)) {
+    Value* newOp1 = 0;
+    if (isa<Instruction>(U->getPointerOperand()))
+      newOp1 = phi_translate(U->getPointerOperand(), pred, succ);
+    else
+      newOp1 = U->getPointerOperand();
+    
+    if (newOp1 == 0)
+      return 0;
+    
+    bool changed_idx = false;
+    SmallVector<Value*, 4> newIdx;
+    for (GetElementPtrInst::op_iterator I = U->idx_begin(), E = U->idx_end();
+         I != E; ++I)
+      if (isa<Instruction>(*I)) {
+        Value* newVal = phi_translate(*I, pred, succ);
+        newIdx.push_back(newVal);
+        if (newVal != *I)
+          changed_idx = true;
+      } else {
+        newIdx.push_back(*I);
+      }
+    
+    if (newOp1 != U->getPointerOperand() || changed_idx) {
+      Instruction* newVal =
+          GetElementPtrInst::Create(newOp1,
+                                    newIdx.begin(), newIdx.end(),
+                                    U->getName()+".expr");
+      
+      uint32_t v = VN.lookup_or_add(newVal);
+      
+      Value* leader = find_leader(availableOut[pred], v);
+      if (leader == 0) {
+        createdExpressions.push_back(newVal);
+        return newVal;
+      } else {
+        VN.erase(newVal);
+        delete newVal;
+        return leader;
+      }
+    }
+  
+  // PHI Nodes
+  } else if (PHINode* P = dyn_cast<PHINode>(V)) {
+    if (P->getParent() == succ)
+      return P->getIncomingValueForBlock(pred);
+  }
+  
+  return V;
+}
+
+/// phi_translate_set - Perform phi translation on every element of a set.
+void GVNPRE::phi_translate_set(ValueNumberedSet& anticIn,
+                               BasicBlock* pred, BasicBlock* succ,
+                               ValueNumberedSet& out) {
+  for (ValueNumberedSet::iterator I = anticIn.begin(),
+       E = anticIn.end(); I != E; ++I) {
+    Value* V = phi_translate(*I, pred, succ);
+    if (V != 0 && !out.test(VN.lookup_or_add(V))) {
+      out.insert(V);
+      out.set(VN.lookup(V));
+    }
+  }
+}
+
+/// dependsOnInvoke - Test if a value has a phi node as an operand, any of
+/// whose inputs is an invoke instruction.  If this is true, we cannot safely
+/// PRE the instruction or anything that depends on it.
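+/// (Roughly: an invoke is its block's terminator, so its result only becomes
+/// available on the normal edge out of that block; there is no insertion
+/// point inside the predecessor at which a hoisted expression could use the
+/// invoke's value.)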
+bool GVNPRE::dependsOnInvoke(Value* V) { + if (PHINode* p = dyn_cast<PHINode>(V)) { + for (PHINode::op_iterator I = p->op_begin(), E = p->op_end(); I != E; ++I) + if (isa<InvokeInst>(*I)) + return true; + return false; + } else { + return false; + } +} + +/// clean - Remove all non-opaque values from the set whose operands are not +/// themselves in the set, as well as all values that depend on invokes (see +/// above) +void GVNPRE::clean(ValueNumberedSet& set) { + SmallVector<Value*, 8> worklist; + worklist.reserve(set.size()); + topo_sort(set, worklist); + + for (unsigned i = 0; i < worklist.size(); ++i) { + Value* v = worklist[i]; + + // Handle unary ops + if (CastInst* U = dyn_cast<CastInst>(v)) { + bool lhsValid = !isa<Instruction>(U->getOperand(0)); + lhsValid |= set.test(VN.lookup(U->getOperand(0))); + if (lhsValid) + lhsValid = !dependsOnInvoke(U->getOperand(0)); + + if (!lhsValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + + // Handle binary ops + } else if (isa<BinaryOperator>(v) || isa<CmpInst>(v) || + isa<ExtractElementInst>(v)) { + User* U = cast<User>(v); + + bool lhsValid = !isa<Instruction>(U->getOperand(0)); + lhsValid |= set.test(VN.lookup(U->getOperand(0))); + if (lhsValid) + lhsValid = !dependsOnInvoke(U->getOperand(0)); + + bool rhsValid = !isa<Instruction>(U->getOperand(1)); + rhsValid |= set.test(VN.lookup(U->getOperand(1))); + if (rhsValid) + rhsValid = !dependsOnInvoke(U->getOperand(1)); + + if (!lhsValid || !rhsValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + + // Handle ternary ops + } else if (isa<ShuffleVectorInst>(v) || isa<InsertElementInst>(v) || + isa<SelectInst>(v)) { + User* U = cast<User>(v); + + bool lhsValid = !isa<Instruction>(U->getOperand(0)); + lhsValid |= set.test(VN.lookup(U->getOperand(0))); + if (lhsValid) + lhsValid = !dependsOnInvoke(U->getOperand(0)); + + bool rhsValid = !isa<Instruction>(U->getOperand(1)); + rhsValid |= set.test(VN.lookup(U->getOperand(1))); + if (rhsValid) + rhsValid = !dependsOnInvoke(U->getOperand(1)); + + bool thirdValid = !isa<Instruction>(U->getOperand(2)); + thirdValid |= set.test(VN.lookup(U->getOperand(2))); + if (thirdValid) + thirdValid = !dependsOnInvoke(U->getOperand(2)); + + if (!lhsValid || !rhsValid || !thirdValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + + // Handle varargs ops + } else if (GetElementPtrInst* U = dyn_cast<GetElementPtrInst>(v)) { + bool ptrValid = !isa<Instruction>(U->getPointerOperand()); + ptrValid |= set.test(VN.lookup(U->getPointerOperand())); + if (ptrValid) + ptrValid = !dependsOnInvoke(U->getPointerOperand()); + + bool varValid = true; + for (GetElementPtrInst::op_iterator I = U->idx_begin(), E = U->idx_end(); + I != E; ++I) + if (varValid) { + varValid &= !isa<Instruction>(*I) || set.test(VN.lookup(*I)); + varValid &= !dependsOnInvoke(*I); + } + + if (!ptrValid || !varValid) { + set.erase(U); + set.reset(VN.lookup(U)); + } + } + } +} + +/// topo_sort - Given a set of values, sort them by topological +/// order into the provided vector. 
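+/// "Topological" here means operands precede their users: for the
+/// hypothetical pair %a = add i32 %x, 1 and %b = mul i32 %a, 2, %a is emitted
+/// into vec before %b.  The walk below is an explicit-stack DFS over leaders
+/// rather than a recursive one.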
+void GVNPRE::topo_sort(ValueNumberedSet& set, SmallVector<Value*, 8>& vec) { + SmallPtrSet<Value*, 16> visited; + SmallVector<Value*, 8> stack; + for (ValueNumberedSet::iterator I = set.begin(), E = set.end(); + I != E; ++I) { + if (visited.count(*I) == 0) + stack.push_back(*I); + + while (!stack.empty()) { + Value* e = stack.back(); + + // Handle unary ops + if (CastInst* U = dyn_cast<CastInst>(e)) { + Value* l = find_leader(set, VN.lookup(U->getOperand(0))); + + if (l != 0 && isa<Instruction>(l) && + visited.count(l) == 0) + stack.push_back(l); + else { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + + // Handle binary ops + } else if (isa<BinaryOperator>(e) || isa<CmpInst>(e) || + isa<ExtractElementInst>(e)) { + User* U = cast<User>(e); + Value* l = find_leader(set, VN.lookup(U->getOperand(0))); + Value* r = find_leader(set, VN.lookup(U->getOperand(1))); + + if (l != 0 && isa<Instruction>(l) && + visited.count(l) == 0) + stack.push_back(l); + else if (r != 0 && isa<Instruction>(r) && + visited.count(r) == 0) + stack.push_back(r); + else { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + + // Handle ternary ops + } else if (isa<InsertElementInst>(e) || isa<ShuffleVectorInst>(e) || + isa<SelectInst>(e)) { + User* U = cast<User>(e); + Value* l = find_leader(set, VN.lookup(U->getOperand(0))); + Value* r = find_leader(set, VN.lookup(U->getOperand(1))); + Value* m = find_leader(set, VN.lookup(U->getOperand(2))); + + if (l != 0 && isa<Instruction>(l) && + visited.count(l) == 0) + stack.push_back(l); + else if (r != 0 && isa<Instruction>(r) && + visited.count(r) == 0) + stack.push_back(r); + else if (m != 0 && isa<Instruction>(m) && + visited.count(m) == 0) + stack.push_back(m); + else { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + + // Handle vararg ops + } else if (GetElementPtrInst* U = dyn_cast<GetElementPtrInst>(e)) { + Value* p = find_leader(set, VN.lookup(U->getPointerOperand())); + + if (p != 0 && isa<Instruction>(p) && + visited.count(p) == 0) + stack.push_back(p); + else { + bool push_va = false; + for (GetElementPtrInst::op_iterator I = U->idx_begin(), + E = U->idx_end(); I != E; ++I) { + Value * v = find_leader(set, VN.lookup(*I)); + if (v != 0 && isa<Instruction>(v) && visited.count(v) == 0) { + stack.push_back(v); + push_va = true; + } + } + + if (!push_va) { + vec.push_back(e); + visited.insert(e); + stack.pop_back(); + } + } + + // Handle opaque ops + } else { + visited.insert(e); + vec.push_back(e); + stack.pop_back(); + } + } + + stack.clear(); + } +} + +/// dump - Dump a set of values to standard error +void GVNPRE::dump(ValueNumberedSet& s) const { + DOUT << "{ "; + for (ValueNumberedSet::iterator I = s.begin(), E = s.end(); + I != E; ++I) { + DOUT << "" << VN.lookup(*I) << ": "; + DEBUG((*I)->dump()); + } + DOUT << "}\n\n"; +} + +/// elimination - Phase 3 of the main algorithm. Perform full redundancy +/// elimination by walking the dominator tree and removing any instruction that +/// is dominated by another instruction with the same value number. 
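+/// E.g. (hypothetical): if %a = add i32 %x, %y sits in a block that dominates
+/// the block of %b = add i32 %x, %y, then VN(%a) == VN(%b), %a is the leader
+/// in AVAIL_OUT, and %b's uses are rewritten to %a before %b is erased.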
+bool GVNPRE::elimination() {
+  bool changed_function = false;
+  
+  SmallVector<std::pair<Instruction*, Value*>, 8> replace;
+  SmallVector<Instruction*, 8> erase;
+  
+  DominatorTree& DT = getAnalysis<DominatorTree>();
+  
+  for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()),
+       E = df_end(DT.getRootNode()); DI != E; ++DI) {
+    BasicBlock* BB = DI->getBlock();
+    
+    for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
+         BI != BE; ++BI) {
+
+      if (isa<BinaryOperator>(BI) || isa<CmpInst>(BI) ||
+          isa<ShuffleVectorInst>(BI) || isa<InsertElementInst>(BI) ||
+          isa<ExtractElementInst>(BI) || isa<SelectInst>(BI) ||
+          isa<CastInst>(BI) || isa<GetElementPtrInst>(BI)) {
+        
+        if (availableOut[BB].test(VN.lookup(BI)) &&
+            !availableOut[BB].count(BI)) {
+          Value *leader = find_leader(availableOut[BB], VN.lookup(BI));
+          if (Instruction* Instr = dyn_cast<Instruction>(leader))
+            if (Instr->getParent() != 0 && Instr != BI) {
+              replace.push_back(std::make_pair(BI, leader));
+              erase.push_back(BI);
+              ++NumEliminated;
+            }
+        }
+      }
+    }
+  }
+  
+  while (!replace.empty()) {
+    std::pair<Instruction*, Value*> rep = replace.back();
+    replace.pop_back();
+    rep.first->replaceAllUsesWith(rep.second);
+    changed_function = true;
+  }
+    
+  for (SmallVector<Instruction*, 8>::iterator I = erase.begin(),
+       E = erase.end(); I != E; ++I)
+    (*I)->eraseFromParent();
+  
+  return changed_function;
+}
+
+/// cleanup - Delete any extraneous values that were created to represent
+/// expressions without leaders.
+void GVNPRE::cleanup() {
+  while (!createdExpressions.empty()) {
+    Instruction* I = createdExpressions.back();
+    createdExpressions.pop_back();
+    
+    delete I;
+  }
+}
+
+/// buildsets_availout - When calculating availability, handle an instruction
+/// by inserting it into the appropriate sets.
+void GVNPRE::buildsets_availout(BasicBlock::iterator I,
+                                ValueNumberedSet& currAvail,
+                                ValueNumberedSet& currPhis,
+                                ValueNumberedSet& currExps,
+                                SmallPtrSet<Value*, 16>& currTemps) {
+  // Handle PHI nodes
+  if (PHINode* p = dyn_cast<PHINode>(I)) {
+    unsigned num = VN.lookup_or_add(p);
+    
+    currPhis.insert(p);
+    currPhis.set(num);
+  
+  // Handle unary ops
+  } else if (CastInst* U = dyn_cast<CastInst>(I)) {
+    Value* leftValue = U->getOperand(0);
+    
+    unsigned num = VN.lookup_or_add(U);
+      
+    if (isa<Instruction>(leftValue))
+      if (!currExps.test(VN.lookup(leftValue))) {
+        currExps.insert(leftValue);
+        currExps.set(VN.lookup(leftValue));
+      }
+    
+    if (!currExps.test(num)) {
+      currExps.insert(U);
+      currExps.set(num);
+    }
+  
+  // Handle binary ops
+  } else if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
+             isa<ExtractElementInst>(I)) {
+    User* U = cast<User>(I);
+    Value* leftValue = U->getOperand(0);
+    Value* rightValue = U->getOperand(1);
+    
+    unsigned num = VN.lookup_or_add(U);
+    
+    if (isa<Instruction>(leftValue))
+      if (!currExps.test(VN.lookup(leftValue))) {
+        currExps.insert(leftValue);
+        currExps.set(VN.lookup(leftValue));
+      }
+    
+    if (isa<Instruction>(rightValue))
+      if (!currExps.test(VN.lookup(rightValue))) {
+        currExps.insert(rightValue);
+        currExps.set(VN.lookup(rightValue));
+      }
+    
+    if (!currExps.test(num)) {
+      currExps.insert(U);
+      currExps.set(num);
+    }
+    
+  // Handle ternary ops
+  } else if (isa<InsertElementInst>(I) || isa<ShuffleVectorInst>(I) ||
+             isa<SelectInst>(I)) {
+    User* U = cast<User>(I);
+    Value* leftValue = U->getOperand(0);
+    Value* rightValue = U->getOperand(1);
+    Value* thirdValue = U->getOperand(2);
+    
+    unsigned num = VN.lookup_or_add(U);
+    
+    if (isa<Instruction>(leftValue))
+      if 
(!currExps.test(VN.lookup(leftValue))) {
+        currExps.insert(leftValue);
+        currExps.set(VN.lookup(leftValue));
+      }
+    if (isa<Instruction>(rightValue))
+      if (!currExps.test(VN.lookup(rightValue))) {
+        currExps.insert(rightValue);
+        currExps.set(VN.lookup(rightValue));
+      }
+    if (isa<Instruction>(thirdValue))
+      if (!currExps.test(VN.lookup(thirdValue))) {
+        currExps.insert(thirdValue);
+        currExps.set(VN.lookup(thirdValue));
+      }
+    
+    if (!currExps.test(num)) {
+      currExps.insert(U);
+      currExps.set(num);
+    }
+  
+  // Handle vararg ops
+  } else if (GetElementPtrInst* U = dyn_cast<GetElementPtrInst>(I)) {
+    Value* ptrValue = U->getPointerOperand();
+    
+    unsigned num = VN.lookup_or_add(U);
+    
+    if (isa<Instruction>(ptrValue))
+      if (!currExps.test(VN.lookup(ptrValue))) {
+        currExps.insert(ptrValue);
+        currExps.set(VN.lookup(ptrValue));
+      }
+    
+    for (GetElementPtrInst::op_iterator OI = U->idx_begin(), OE = U->idx_end();
+         OI != OE; ++OI)
+      if (isa<Instruction>(*OI) && !currExps.test(VN.lookup(*OI))) {
+        currExps.insert(*OI);
+        currExps.set(VN.lookup(*OI));
+      }
+    
+    if (!currExps.test(VN.lookup(U))) {
+      currExps.insert(U);
+      currExps.set(num);
+    }
+  
+  // Handle opaque ops
+  } else if (!I->isTerminator()) {
+    VN.lookup_or_add(I);
+    
+    currTemps.insert(I);
+  }
+    
+  if (!I->isTerminator())
+    if (!currAvail.test(VN.lookup(I))) {
+      currAvail.insert(I);
+      currAvail.set(VN.lookup(I));
+    }
+}
+
+/// buildsets_anticout - When walking the postdom tree, calculate the ANTIC_OUT
+/// set as a function of the ANTIC_IN set of the block's predecessors.
+bool GVNPRE::buildsets_anticout(BasicBlock* BB,
+                                ValueNumberedSet& anticOut,
+                                SmallPtrSet<BasicBlock*, 8>& visited) {
+  if (BB->getTerminator()->getNumSuccessors() == 1) {
+    if (BB->getTerminator()->getSuccessor(0) != BB &&
+        visited.count(BB->getTerminator()->getSuccessor(0)) == 0) {
+      return true;
+    }
+    else {
+      phi_translate_set(anticipatedIn[BB->getTerminator()->getSuccessor(0)],
+                        BB, BB->getTerminator()->getSuccessor(0), anticOut);
+    }
+  } else if (BB->getTerminator()->getNumSuccessors() > 1) {
+    BasicBlock* first = BB->getTerminator()->getSuccessor(0);
+    for (ValueNumberedSet::iterator I = anticipatedIn[first].begin(),
+         E = anticipatedIn[first].end(); I != E; ++I) {
+      anticOut.insert(*I);
+      anticOut.set(VN.lookup(*I));
+    }
+    
+    for (unsigned i = 1; i < BB->getTerminator()->getNumSuccessors(); ++i) {
+      BasicBlock* currSucc = BB->getTerminator()->getSuccessor(i);
+      ValueNumberedSet& succAnticIn = anticipatedIn[currSucc];
+      
+      SmallVector<Value*, 16> temp;
+      
+      for (ValueNumberedSet::iterator I = anticOut.begin(),
+           E = anticOut.end(); I != E; ++I)
+        if (!succAnticIn.test(VN.lookup(*I)))
+          temp.push_back(*I);
+
+      for (SmallVector<Value*, 16>::iterator I = temp.begin(), E = temp.end();
+           I != E; ++I) {
+        anticOut.erase(*I);
+        anticOut.reset(VN.lookup(*I));
+      }
+    }
+  }
+  
+  return false;
+}
+
+/// buildsets_anticin - Walk the postdom tree, calculating ANTIC_OUT for
+/// each block. 
ANTIC_IN is then a function of ANTIC_OUT and the GEN +/// sets populated in buildsets_availout +unsigned GVNPRE::buildsets_anticin(BasicBlock* BB, + ValueNumberedSet& anticOut, + ValueNumberedSet& currExps, + SmallPtrSet<Value*, 16>& currTemps, + SmallPtrSet<BasicBlock*, 8>& visited) { + ValueNumberedSet& anticIn = anticipatedIn[BB]; + unsigned old = anticIn.size(); + + bool defer = buildsets_anticout(BB, anticOut, visited); + if (defer) + return 0; + + anticIn.clear(); + + for (ValueNumberedSet::iterator I = anticOut.begin(), + E = anticOut.end(); I != E; ++I) { + anticIn.insert(*I); + anticIn.set(VN.lookup(*I)); + } + for (ValueNumberedSet::iterator I = currExps.begin(), + E = currExps.end(); I != E; ++I) { + if (!anticIn.test(VN.lookup(*I))) { + anticIn.insert(*I); + anticIn.set(VN.lookup(*I)); + } + } + + for (SmallPtrSet<Value*, 16>::iterator I = currTemps.begin(), + E = currTemps.end(); I != E; ++I) { + anticIn.erase(*I); + anticIn.reset(VN.lookup(*I)); + } + + clean(anticIn); + anticOut.clear(); + + if (old != anticIn.size()) + return 2; + else + return 1; +} + +/// buildsets - Phase 1 of the main algorithm. Construct the AVAIL_OUT +/// and the ANTIC_IN sets. +void GVNPRE::buildsets(Function& F) { + DenseMap<BasicBlock*, ValueNumberedSet> generatedExpressions; + DenseMap<BasicBlock*, SmallPtrSet<Value*, 16> > generatedTemporaries; + + DominatorTree &DT = getAnalysis<DominatorTree>(); + + // Phase 1, Part 1: calculate AVAIL_OUT + + // Top-down walk of the dominator tree + for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()), + E = df_end(DT.getRootNode()); DI != E; ++DI) { + + // Get the sets to update for this block + ValueNumberedSet& currExps = generatedExpressions[DI->getBlock()]; + ValueNumberedSet& currPhis = generatedPhis[DI->getBlock()]; + SmallPtrSet<Value*, 16>& currTemps = generatedTemporaries[DI->getBlock()]; + ValueNumberedSet& currAvail = availableOut[DI->getBlock()]; + + BasicBlock* BB = DI->getBlock(); + + // A block inherits AVAIL_OUT from its dominator + if (DI->getIDom() != 0) + currAvail = availableOut[DI->getIDom()->getBlock()]; + + for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); + BI != BE; ++BI) + buildsets_availout(BI, currAvail, currPhis, currExps, + currTemps); + + } + + // Phase 1, Part 2: calculate ANTIC_IN + + SmallPtrSet<BasicBlock*, 8> visited; + SmallPtrSet<BasicBlock*, 4> block_changed; + for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ++FI) + block_changed.insert(FI); + + bool changed = true; + unsigned iterations = 0; + + while (changed) { + changed = false; + ValueNumberedSet anticOut; + + // Postorder walk of the CFG + for (po_iterator<BasicBlock*> BBI = po_begin(&F.getEntryBlock()), + BBE = po_end(&F.getEntryBlock()); BBI != BBE; ++BBI) { + BasicBlock* BB = *BBI; + + if (block_changed.count(BB) != 0) { + unsigned ret = buildsets_anticin(BB, anticOut,generatedExpressions[BB], + generatedTemporaries[BB], visited); + + if (ret == 0) { + changed = true; + continue; + } else { + visited.insert(BB); + + if (ret == 2) + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); + PI != PE; ++PI) { + block_changed.insert(*PI); + } + else + block_changed.erase(BB); + + changed |= (ret == 2); + } + } + } + + iterations++; + } +} + +/// insertion_pre - When a partial redundancy has been identified, eliminate it +/// by inserting appropriate values into the predecessors and a phi node in +/// the main block +void GVNPRE::insertion_pre(Value* e, BasicBlock* BB, + DenseMap<BasicBlock*, Value*>& avail, + 
std::map<BasicBlock*, ValueNumberedSet>& new_sets) { + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { + Value* e2 = avail[*PI]; + if (!availableOut[*PI].test(VN.lookup(e2))) { + User* U = cast<User>(e2); + + Value* s1 = 0; + if (isa<BinaryOperator>(U->getOperand(0)) || + isa<CmpInst>(U->getOperand(0)) || + isa<ShuffleVectorInst>(U->getOperand(0)) || + isa<ExtractElementInst>(U->getOperand(0)) || + isa<InsertElementInst>(U->getOperand(0)) || + isa<SelectInst>(U->getOperand(0)) || + isa<CastInst>(U->getOperand(0)) || + isa<GetElementPtrInst>(U->getOperand(0))) + s1 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(0))); + else + s1 = U->getOperand(0); + + Value* s2 = 0; + + if (isa<BinaryOperator>(U) || + isa<CmpInst>(U) || + isa<ShuffleVectorInst>(U) || + isa<ExtractElementInst>(U) || + isa<InsertElementInst>(U) || + isa<SelectInst>(U)) { + if (isa<BinaryOperator>(U->getOperand(1)) || + isa<CmpInst>(U->getOperand(1)) || + isa<ShuffleVectorInst>(U->getOperand(1)) || + isa<ExtractElementInst>(U->getOperand(1)) || + isa<InsertElementInst>(U->getOperand(1)) || + isa<SelectInst>(U->getOperand(1)) || + isa<CastInst>(U->getOperand(1)) || + isa<GetElementPtrInst>(U->getOperand(1))) { + s2 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(1))); + } else { + s2 = U->getOperand(1); + } + } + + // Ternary Operators + Value* s3 = 0; + if (isa<ShuffleVectorInst>(U) || + isa<InsertElementInst>(U) || + isa<SelectInst>(U)) { + if (isa<BinaryOperator>(U->getOperand(2)) || + isa<CmpInst>(U->getOperand(2)) || + isa<ShuffleVectorInst>(U->getOperand(2)) || + isa<ExtractElementInst>(U->getOperand(2)) || + isa<InsertElementInst>(U->getOperand(2)) || + isa<SelectInst>(U->getOperand(2)) || + isa<CastInst>(U->getOperand(2)) || + isa<GetElementPtrInst>(U->getOperand(2))) { + s3 = find_leader(availableOut[*PI], VN.lookup(U->getOperand(2))); + } else { + s3 = U->getOperand(2); + } + } + + // Vararg operators + SmallVector<Value*, 4> sVarargs; + if (GetElementPtrInst* G = dyn_cast<GetElementPtrInst>(U)) { + for (GetElementPtrInst::op_iterator OI = G->idx_begin(), + OE = G->idx_end(); OI != OE; ++OI) { + if (isa<BinaryOperator>(*OI) || + isa<CmpInst>(*OI) || + isa<ShuffleVectorInst>(*OI) || + isa<ExtractElementInst>(*OI) || + isa<InsertElementInst>(*OI) || + isa<SelectInst>(*OI) || + isa<CastInst>(*OI) || + isa<GetElementPtrInst>(*OI)) { + sVarargs.push_back(find_leader(availableOut[*PI], + VN.lookup(*OI))); + } else { + sVarargs.push_back(*OI); + } + } + } + + Value* newVal = 0; + if (BinaryOperator* BO = dyn_cast<BinaryOperator>(U)) + newVal = BinaryOperator::Create(BO->getOpcode(), s1, s2, + BO->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (CmpInst* C = dyn_cast<CmpInst>(U)) + newVal = CmpInst::Create(C->getOpcode(), C->getPredicate(), s1, s2, + C->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (ShuffleVectorInst* S = dyn_cast<ShuffleVectorInst>(U)) + newVal = new ShuffleVectorInst(s1, s2, s3, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (InsertElementInst* S = dyn_cast<InsertElementInst>(U)) + newVal = InsertElementInst::Create(s1, s2, s3, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (ExtractElementInst* S = dyn_cast<ExtractElementInst>(U)) + newVal = new ExtractElementInst(s1, s2, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (SelectInst* S = dyn_cast<SelectInst>(U)) + newVal = SelectInst::Create(s1, s2, s3, S->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (CastInst* C = 
dyn_cast<CastInst>(U)) + newVal = CastInst::Create(C->getOpcode(), s1, C->getType(), + C->getName()+".gvnpre", + (*PI)->getTerminator()); + else if (GetElementPtrInst* G = dyn_cast<GetElementPtrInst>(U)) + newVal = GetElementPtrInst::Create(s1, sVarargs.begin(), sVarargs.end(), + G->getName()+".gvnpre", + (*PI)->getTerminator()); + + VN.add(newVal, VN.lookup(U)); + + ValueNumberedSet& predAvail = availableOut[*PI]; + val_replace(predAvail, newVal); + val_replace(new_sets[*PI], newVal); + predAvail.set(VN.lookup(newVal)); + + DenseMap<BasicBlock*, Value*>::iterator av = avail.find(*PI); + if (av != avail.end()) + avail.erase(av); + avail.insert(std::make_pair(*PI, newVal)); + + ++NumInsertedVals; + } + } + + PHINode* p = 0; + + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { + if (p == 0) + p = PHINode::Create(avail[*PI]->getType(), "gvnpre-join", BB->begin()); + + p->addIncoming(avail[*PI], *PI); + } + + VN.add(p, VN.lookup(e)); + val_replace(availableOut[BB], p); + availableOut[BB].set(VN.lookup(e)); + generatedPhis[BB].insert(p); + generatedPhis[BB].set(VN.lookup(e)); + new_sets[BB].insert(p); + new_sets[BB].set(VN.lookup(e)); + + ++NumInsertedPhis; +} + +/// insertion_mergepoint - When walking the dom tree, check at each merge +/// block for the possibility of a partial redundancy. If present, eliminate it +unsigned GVNPRE::insertion_mergepoint(SmallVector<Value*, 8>& workList, + df_iterator<DomTreeNode*>& D, + std::map<BasicBlock*, ValueNumberedSet >& new_sets) { + bool changed_function = false; + bool new_stuff = false; + + BasicBlock* BB = D->getBlock(); + for (unsigned i = 0; i < workList.size(); ++i) { + Value* e = workList[i]; + + if (isa<BinaryOperator>(e) || isa<CmpInst>(e) || + isa<ExtractElementInst>(e) || isa<InsertElementInst>(e) || + isa<ShuffleVectorInst>(e) || isa<SelectInst>(e) || isa<CastInst>(e) || + isa<GetElementPtrInst>(e)) { + if (availableOut[D->getIDom()->getBlock()].test(VN.lookup(e))) + continue; + + DenseMap<BasicBlock*, Value*> avail; + bool by_some = false; + bool all_same = true; + Value * first_s = 0; + + for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; + ++PI) { + Value *e2 = phi_translate(e, *PI, BB); + Value *e3 = find_leader(availableOut[*PI], VN.lookup(e2)); + + if (e3 == 0) { + DenseMap<BasicBlock*, Value*>::iterator av = avail.find(*PI); + if (av != avail.end()) + avail.erase(av); + avail.insert(std::make_pair(*PI, e2)); + all_same = false; + } else { + DenseMap<BasicBlock*, Value*>::iterator av = avail.find(*PI); + if (av != avail.end()) + avail.erase(av); + avail.insert(std::make_pair(*PI, e3)); + + by_some = true; + if (first_s == 0) + first_s = e3; + else if (first_s != e3) + all_same = false; + } + } + + if (by_some && !all_same && + !generatedPhis[BB].test(VN.lookup(e))) { + insertion_pre(e, BB, avail, new_sets); + + changed_function = true; + new_stuff = true; + } + } + } + + unsigned retval = 0; + if (changed_function) + retval += 1; + if (new_stuff) + retval += 2; + + return retval; +} + +/// insert - Phase 2 of the main algorithm. Walk the dominator tree looking for +/// merge points. When one is found, check for a partial redundancy. If one is +/// present, eliminate it. Repeat this walk until no changes are made. 
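+///
+/// An illustrative sketch (ours, not part of the original comment): given
+///
+///   B1: %t = add i32 %a, %b      B2: (no add here)
+///          \                     /
+///           B3: %u = add i32 %a, %b
+///
+/// the add is partially redundant at B3. Insertion places a copy of the add
+/// at the end of B2 and joins the two values with a phi in B3, after which
+/// the add in B3 is fully redundant and the elimination phase can remove it.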
+bool GVNPRE::insertion(Function& F) { + bool changed_function = false; + + DominatorTree &DT = getAnalysis<DominatorTree>(); + + std::map<BasicBlock*, ValueNumberedSet> new_sets; + bool new_stuff = true; + while (new_stuff) { + new_stuff = false; + for (df_iterator<DomTreeNode*> DI = df_begin(DT.getRootNode()), + E = df_end(DT.getRootNode()); DI != E; ++DI) { + BasicBlock* BB = DI->getBlock(); + + if (BB == 0) + continue; + + ValueNumberedSet& availOut = availableOut[BB]; + ValueNumberedSet& anticIn = anticipatedIn[BB]; + + // Replace leaders with leaders inherited from dominator + if (DI->getIDom() != 0) { + ValueNumberedSet& dom_set = new_sets[DI->getIDom()->getBlock()]; + for (ValueNumberedSet::iterator I = dom_set.begin(), + E = dom_set.end(); I != E; ++I) { + val_replace(new_sets[BB], *I); + val_replace(availOut, *I); + } + } + + // If there is more than one predecessor... + if (pred_begin(BB) != pred_end(BB) && ++pred_begin(BB) != pred_end(BB)) { + SmallVector<Value*, 8> workList; + workList.reserve(anticIn.size()); + topo_sort(anticIn, workList); + + unsigned result = insertion_mergepoint(workList, DI, new_sets); + if (result & 1) + changed_function = true; + if (result & 2) + new_stuff = true; + } + } + } + + return changed_function; +} + +// GVNPRE::runOnFunction - This is the main transformation entry point for a +// function. +// +bool GVNPRE::runOnFunction(Function &F) { + // Clean out global sets from any previous functions + VN.clear(); + createdExpressions.clear(); + availableOut.clear(); + anticipatedIn.clear(); + generatedPhis.clear(); + + bool changed_function = false; + + // Phase 1: BuildSets + // This phase calculates the AVAIL_OUT and ANTIC_IN sets + buildsets(F); + + // Phase 2: Insert + // This phase inserts values to make partially redundant values + // fully redundant + changed_function |= insertion(F); + + // Phase 3: Eliminate + // This phase performs trivial full redundancy elimination + changed_function |= elimination(); + + // Phase 4: Cleanup + // This phase cleans up values that were created solely + // as leaders for expressions + cleanup(); + + return changed_function; +} diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp new file mode 100644 index 0000000..ca7aa7b --- /dev/null +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -0,0 +1,880 @@ +//===- IndVarSimplify.cpp - Induction Variable Elimination ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This transformation analyzes and transforms the induction variables (and +// computations derived from them) into simpler forms suitable for subsequent +// analysis and transformation. +// +// This transformation makes the following changes to each loop with an +// identifiable induction variable: +// 1. All loops are transformed to have a SINGLE canonical induction variable +// which starts at zero and steps by one. +// 2. The canonical induction variable is guaranteed to be the first PHI node +// in the loop header block. +// 3. Any pointer arithmetic recurrences are raised to use array subscripts. +// +// If the trip count of a loop is computable, this pass also makes the following +// changes: +// 1. The exit condition for the loop is canonicalized to compare the +// induction value against the exit value. 
This turns loops like: +// 'for (i = 7; i*i < 1000; ++i)' into 'for (i = 0; i != 25; ++i)' +// 2. Any use outside of the loop of an expression derived from the indvar +// is changed to compute the derived value outside of the loop, eliminating +// the dependence on the exit value of the induction variable. If the only +// purpose of the loop is to compute the exit value of some derived +// expression, this transformation will make the loop dead. +// +// This transformation should be followed by strength reduction after all of the +// desired loop transformations have been performed. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "indvars" +#include "llvm/Transforms/Scalar.h" +#include "llvm/BasicBlock.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/IVUsers.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +using namespace llvm; + +STATISTIC(NumRemoved , "Number of aux indvars removed"); +STATISTIC(NumInserted, "Number of canonical indvars added"); +STATISTIC(NumReplaced, "Number of exit values replaced"); +STATISTIC(NumLFTR , "Number of loop exit tests replaced"); + +namespace { + class VISIBILITY_HIDDEN IndVarSimplify : public LoopPass { + IVUsers *IU; + LoopInfo *LI; + ScalarEvolution *SE; + bool Changed; + public: + + static char ID; // Pass identification, replacement for typeid + IndVarSimplify() : LoopPass(&ID) {} + + virtual bool runOnLoop(Loop *L, LPPassManager &LPM); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<DominatorTree>(); + AU.addRequired<ScalarEvolution>(); + AU.addRequiredID(LCSSAID); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired<LoopInfo>(); + AU.addRequired<IVUsers>(); + AU.addPreserved<ScalarEvolution>(); + AU.addPreservedID(LoopSimplifyID); + AU.addPreserved<IVUsers>(); + AU.addPreservedID(LCSSAID); + AU.setPreservesCFG(); + } + + private: + + void RewriteNonIntegerIVs(Loop *L); + + ICmpInst *LinearFunctionTestReplace(Loop *L, SCEVHandle BackedgeTakenCount, + Value *IndVar, + BasicBlock *ExitingBlock, + BranchInst *BI, + SCEVExpander &Rewriter); + void RewriteLoopExitValues(Loop *L, const SCEV *BackedgeTakenCount); + + void RewriteIVExpressions(Loop *L, const Type *LargestType, + SCEVExpander &Rewriter); + + void SinkUnusedInvariants(Loop *L, SCEVExpander &Rewriter); + + void FixUsesBeforeDefs(Loop *L, SCEVExpander &Rewriter); + + void HandleFloatingPointIV(Loop *L, PHINode *PH); + }; +} + +char IndVarSimplify::ID = 0; +static RegisterPass<IndVarSimplify> +X("indvars", "Canonicalize Induction Variables"); + +Pass *llvm::createIndVarSimplifyPass() { + return new IndVarSimplify(); +} + +/// LinearFunctionTestReplace - This method rewrites the exit condition of the +/// loop to be a canonical != comparison against the incremented loop induction +/// variable. This pass is able to rewrite the exit tests of any loop where the +/// SCEV analysis can determine a loop-invariant trip count of the loop, which +/// is actually a much broader range than just linear tests. 
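+///
+/// For illustration (a sketch of ours, with assumed value names): an exit
+/// test such as
+///   %cmp = icmp slt i32 %i, %n
+///   br i1 %cmp, label %loop, label %exit
+/// is rewritten, once SCEV provides a loop-invariant trip count %tc, into
+///   %exitcond = icmp ne i32 %indvar.next, %tc
+/// so every rewritten loop exits via a canonical != comparison, matching the
+/// "exitcond" name used in the code below.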
+ICmpInst *IndVarSimplify::LinearFunctionTestReplace(Loop *L, + SCEVHandle BackedgeTakenCount, + Value *IndVar, + BasicBlock *ExitingBlock, + BranchInst *BI, + SCEVExpander &Rewriter) { + // If the exiting block is not the same as the backedge block, we must compare + // against the preincremented value, otherwise we prefer to compare against + // the post-incremented value. + Value *CmpIndVar; + SCEVHandle RHS = BackedgeTakenCount; + if (ExitingBlock == L->getLoopLatch()) { + // Add one to the "backedge-taken" count to get the trip count. + // If this addition may overflow, we have to be more pessimistic and + // cast the induction variable before doing the add. + SCEVHandle Zero = SE->getIntegerSCEV(0, BackedgeTakenCount->getType()); + SCEVHandle N = + SE->getAddExpr(BackedgeTakenCount, + SE->getIntegerSCEV(1, BackedgeTakenCount->getType())); + if ((isa<SCEVConstant>(N) && !N->isZero()) || + SE->isLoopGuardedByCond(L, ICmpInst::ICMP_NE, N, Zero)) { + // No overflow. Cast the sum. + RHS = SE->getTruncateOrZeroExtend(N, IndVar->getType()); + } else { + // Potential overflow. Cast before doing the add. + RHS = SE->getTruncateOrZeroExtend(BackedgeTakenCount, + IndVar->getType()); + RHS = SE->getAddExpr(RHS, + SE->getIntegerSCEV(1, IndVar->getType())); + } + + // The BackedgeTaken expression contains the number of times that the + // backedge branches to the loop header. This is one less than the + // number of times the loop executes, so use the incremented indvar. + CmpIndVar = L->getCanonicalInductionVariableIncrement(); + } else { + // We have to use the preincremented value... + RHS = SE->getTruncateOrZeroExtend(BackedgeTakenCount, + IndVar->getType()); + CmpIndVar = IndVar; + } + + // Expand the code for the iteration count into the preheader of the loop. + BasicBlock *Preheader = L->getLoopPreheader(); + Value *ExitCnt = Rewriter.expandCodeFor(RHS, CmpIndVar->getType(), + Preheader->getTerminator()); + + // Insert a new icmp_ne or icmp_eq instruction before the branch. + ICmpInst::Predicate Opcode; + if (L->contains(BI->getSuccessor(0))) + Opcode = ICmpInst::ICMP_NE; + else + Opcode = ICmpInst::ICMP_EQ; + + DOUT << "INDVARS: Rewriting loop exit condition to:\n" + << " LHS:" << *CmpIndVar // includes a newline + << " op:\t" + << (Opcode == ICmpInst::ICMP_NE ? "!=" : "==") << "\n" + << " RHS:\t" << *RHS << "\n"; + + ICmpInst *Cond = new ICmpInst(Opcode, CmpIndVar, ExitCnt, "exitcond", BI); + + Instruction *OrigCond = cast<Instruction>(BI->getCondition()); + // It's tempting to use replaceAllUsesWith here to fully replace the old + // comparison, but that's not immediately safe, since users of the old + // comparison may not be dominated by the new comparison. Instead, just + // update the branch to use the new comparison; in the common case this + // will make old comparison dead. + BI->setCondition(Cond); + RecursivelyDeleteTriviallyDeadInstructions(OrigCond); + + ++NumLFTR; + Changed = true; + return Cond; +} + +/// RewriteLoopExitValues - Check to see if this loop has a computable +/// loop-invariant execution count. If so, this means that we can compute the +/// final value of any expressions that are recurrent in the loop, and +/// substitute the exit values from the loop into any instructions outside of +/// the loop that use the final values of the current expressions. 
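+///
+/// A small example of the idea (illustrative only): in
+///   sum = 0;
+///   for (i = 0; i != 10; ++i) sum += 2;
+///   use(sum);
+/// the value of 'sum' at the exit is the loop-invariant constant 20, so
+/// use(sum) can be rewritten to use(20); if nothing else reads 'sum', the
+/// loop itself becomes dead.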
+///
+/// This is mostly redundant with the regular IndVarSimplify activities that
+/// happen later, except that it's more powerful in some cases, because it's
+/// able to brute-force evaluate arbitrary instructions as long as they have
+/// constant operands at the beginning of the loop.
+void IndVarSimplify::RewriteLoopExitValues(Loop *L,
+                                           const SCEV *BackedgeTakenCount) {
+  // Verify that the input to the pass is already in LCSSA form.
+  assert(L->isLCSSAForm());
+
+  BasicBlock *Preheader = L->getLoopPreheader();
+
+  // Scan all of the instructions in the loop, looking at those that have
+  // extra-loop users and which are recurrences.
+  SCEVExpander Rewriter(*SE);
+
+  // We insert the code into the preheader of the loop if the loop contains
+  // multiple exit blocks, or in the exit block if there is exactly one.
+  BasicBlock *BlockToInsertInto;
+  SmallVector<BasicBlock*, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+  if (ExitBlocks.size() == 1)
+    BlockToInsertInto = ExitBlocks[0];
+  else
+    BlockToInsertInto = Preheader;
+  BasicBlock::iterator InsertPt = BlockToInsertInto->getFirstNonPHI();
+
+  std::map<Instruction*, Value*> ExitValues;
+
+  // Find all values that are computed inside the loop, but used outside of it.
+  // Because of LCSSA, these values will only occur in LCSSA PHI Nodes.  Scan
+  // the exit blocks of the loop to find them.
+  for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
+    BasicBlock *ExitBB = ExitBlocks[i];
+
+    // If there are no PHI nodes in this exit block, then no values defined
+    // inside the loop are used on this path, skip it.
+    PHINode *PN = dyn_cast<PHINode>(ExitBB->begin());
+    if (!PN) continue;
+
+    unsigned NumPreds = PN->getNumIncomingValues();
+
+    // Iterate over all of the PHI nodes.
+    BasicBlock::iterator BBI = ExitBB->begin();
+    while ((PN = dyn_cast<PHINode>(BBI++))) {
+      if (PN->use_empty())
+        continue; // dead use, don't replace it
+      // Iterate over all of the values in all the PHI nodes.
+      for (unsigned i = 0; i != NumPreds; ++i) {
+        // If the value being merged in is not integer or is not defined
+        // in the loop, skip it.
+        Value *InVal = PN->getIncomingValue(i);
+        if (!isa<Instruction>(InVal) ||
+            // SCEV only supports integer expressions for now.
+            (!isa<IntegerType>(InVal->getType()) &&
+             !isa<PointerType>(InVal->getType())))
+          continue;
+
+        // If this pred is for a subloop, not L itself, skip it.
+        if (LI->getLoopFor(PN->getIncomingBlock(i)) != L)
+          continue; // The Block is in a subloop, skip it.
+
+        // Check that InVal is defined in the loop.
+        Instruction *Inst = cast<Instruction>(InVal);
+        if (!L->contains(Inst->getParent()))
+          continue;
+
+        // Okay, this instruction has a user outside of the current loop
+        // and varies predictably *inside* the loop.  Evaluate the value it
+        // contains when the loop exits, if possible.
+        SCEVHandle ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop());
+        if (!ExitValue->isLoopInvariant(L))
+          continue;
+
+        Changed = true;
+        ++NumReplaced;
+
+        // See if we already computed the exit value for the instruction;
+        // if so, just reuse it.
+        Value *&ExitVal = ExitValues[Inst];
+        if (!ExitVal)
+          ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), InsertPt);
+
+        DOUT << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal
+             << "  LoopVal = " << *Inst << "\n";
+
+        PN->setIncomingValue(i, ExitVal);
+
+        // If this instruction is dead now, delete it.
+        RecursivelyDeleteTriviallyDeadInstructions(Inst);
+
+        // See if this is a single-entry LCSSA PHI node.
+        // If so, we can (and have to) remove the PHI entirely.  This is safe,
+        // because the NewVal won't be variant in the loop, so we don't need
+        // an LCSSA phi node anymore.
+        if (NumPreds == 1) {
+          PN->replaceAllUsesWith(ExitVal);
+          RecursivelyDeleteTriviallyDeadInstructions(PN);
+          break;
+        }
+      }
+    }
+  }
+}
+
+void IndVarSimplify::RewriteNonIntegerIVs(Loop *L) {
+  // First step.  Check to see if there are any floating-point recurrences.
+  // If there are, change them into integer recurrences, permitting analysis
+  // by the SCEV routines.
+  //
+  BasicBlock *Header = L->getHeader();
+
+  SmallVector<WeakVH, 8> PHIs;
+  for (BasicBlock::iterator I = Header->begin();
+       PHINode *PN = dyn_cast<PHINode>(I); ++I)
+    PHIs.push_back(PN);
+
+  for (unsigned i = 0, e = PHIs.size(); i != e; ++i)
+    if (PHINode *PN = dyn_cast_or_null<PHINode>(PHIs[i]))
+      HandleFloatingPointIV(L, PN);
+
+  // If the loop previously had a floating-point IV, ScalarEvolution
+  // may not have been able to compute a trip count.  Now that we've done some
+  // rewriting, the trip count may be computable.
+  if (Changed)
+    SE->forgetLoopBackedgeTakenCount(L);
+}
+
+bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
+  IU = &getAnalysis<IVUsers>();
+  LI = &getAnalysis<LoopInfo>();
+  SE = &getAnalysis<ScalarEvolution>();
+  Changed = false;
+
+  // If there are any floating-point recurrences, attempt to
+  // transform them to use integer recurrences.
+  RewriteNonIntegerIVs(L);
+
+  BasicBlock *Header = L->getHeader();
+  BasicBlock *ExitingBlock = L->getExitingBlock(); // may be null
+  SCEVHandle BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+
+  // Check to see if this loop has a computable loop-invariant execution count.
+  // If so, this means that we can compute the final value of any expressions
+  // that are recurrent in the loop, and substitute the exit values from the
+  // loop into any instructions outside of the loop that use the final values
+  // of the current expressions.
+  //
+  if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount))
+    RewriteLoopExitValues(L, BackedgeTakenCount);
+
+  // Compute the type of the largest recurrence expression, and decide whether
+  // a canonical induction variable should be inserted.
+  const Type *LargestType = 0;
+  bool NeedCannIV = false;
+  if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
+    LargestType = BackedgeTakenCount->getType();
+    LargestType = SE->getEffectiveSCEVType(LargestType);
+    // If we have a known trip count and a single exit block, we'll be
+    // rewriting the loop exit test condition below, which requires a
+    // canonical induction variable.
+    if (ExitingBlock)
+      NeedCannIV = true;
+  }
+  for (unsigned i = 0, e = IU->StrideOrder.size(); i != e; ++i) {
+    SCEVHandle Stride = IU->StrideOrder[i];
+    const Type *Ty = SE->getEffectiveSCEVType(Stride->getType());
+    if (!LargestType ||
+        SE->getTypeSizeInBits(Ty) >
+          SE->getTypeSizeInBits(LargestType))
+      LargestType = Ty;
+
+    std::map<SCEVHandle, IVUsersOfOneStride *>::iterator SI =
+      IU->IVUsesByStride.find(IU->StrideOrder[i]);
+    assert(SI != IU->IVUsesByStride.end() && "Stride doesn't exist!");
+
+    if (!SI->second->Users.empty())
+      NeedCannIV = true;
+  }
+
+  // Create a rewriter object which we'll use to transform the code with.
+  SCEVExpander Rewriter(*SE);
+
+  // Now that we know the largest of the induction variable expressions
+  // in this loop, insert a canonical induction variable of the largest size.
+ Value *IndVar = 0; + if (NeedCannIV) { + IndVar = Rewriter.getOrInsertCanonicalInductionVariable(L,LargestType); + ++NumInserted; + Changed = true; + DOUT << "INDVARS: New CanIV: " << *IndVar; + } + + // If we have a trip count expression, rewrite the loop's exit condition + // using it. We can currently only handle loops with a single exit. + ICmpInst *NewICmp = 0; + if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && ExitingBlock) { + assert(NeedCannIV && + "LinearFunctionTestReplace requires a canonical induction variable"); + // Can't rewrite non-branch yet. + if (BranchInst *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator())) + NewICmp = LinearFunctionTestReplace(L, BackedgeTakenCount, IndVar, + ExitingBlock, BI, Rewriter); + } + + Rewriter.setInsertionPoint(Header->getFirstNonPHI()); + + // Rewrite IV-derived expressions. Clears the rewriter cache. + RewriteIVExpressions(L, LargestType, Rewriter); + + // The Rewriter may only be used for isInsertedInstruction queries from this + // point on. + + // Loop-invariant instructions in the preheader that aren't used in the + // loop may be sunk below the loop to reduce register pressure. + SinkUnusedInvariants(L, Rewriter); + + // Reorder instructions to avoid use-before-def conditions. + FixUsesBeforeDefs(L, Rewriter); + + // For completeness, inform IVUsers of the IV use in the newly-created + // loop exit test instruction. + if (NewICmp) + IU->AddUsersIfInteresting(cast<Instruction>(NewICmp->getOperand(0))); + + // Clean up dead instructions. + DeleteDeadPHIs(L->getHeader()); + // Check a post-condition. + assert(L->isLCSSAForm() && "Indvars did not leave the loop in lcssa form!"); + return Changed; +} + +void IndVarSimplify::RewriteIVExpressions(Loop *L, const Type *LargestType, + SCEVExpander &Rewriter) { + SmallVector<WeakVH, 16> DeadInsts; + + // Rewrite all induction variable expressions in terms of the canonical + // induction variable. + // + // If there were induction variables of other sizes or offsets, manually + // add the offsets to the primary induction variable and cast, avoiding + // the need for the code evaluation methods to insert induction variables + // of different sizes. + for (unsigned i = 0, e = IU->StrideOrder.size(); i != e; ++i) { + SCEVHandle Stride = IU->StrideOrder[i]; + + std::map<SCEVHandle, IVUsersOfOneStride *>::iterator SI = + IU->IVUsesByStride.find(IU->StrideOrder[i]); + assert(SI != IU->IVUsesByStride.end() && "Stride doesn't exist!"); + ilist<IVStrideUse> &List = SI->second->Users; + for (ilist<IVStrideUse>::iterator UI = List.begin(), + E = List.end(); UI != E; ++UI) { + SCEVHandle Offset = UI->getOffset(); + Value *Op = UI->getOperandValToReplace(); + Instruction *User = UI->getUser(); + bool isSigned = UI->isSigned(); + + // Compute the final addrec to expand into code. + SCEVHandle AR = IU->getReplacementExpr(*UI); + + // FIXME: It is an extremely bad idea to indvar substitute anything more + // complex than affine induction variables. Doing so will put expensive + // polynomial evaluations inside of the loop, and the str reduction pass + // currently can only reduce affine polynomials. For now just disable + // indvar subst on anything more complex than an affine addrec, unless + // it can be expanded to a trivial value. 
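+      // (Illustrative example of ours, not from the original comment: a
+      // non-affine addrec such as {0,+,{0,+,1}} takes the values 0, 0, 1, 3,
+      // 6, ... — expanding it would insert a multiply/add chain that is
+      // re-evaluated on every iteration of the loop.)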
+ if (!Stride->isLoopInvariant(L) && + !isa<SCEVConstant>(AR) && + L->contains(User->getParent())) + continue; + + Value *NewVal = 0; + if (AR->isLoopInvariant(L)) { + BasicBlock::iterator I = Rewriter.getInsertionPoint(); + // Expand loop-invariant values in the loop preheader. They will + // be sunk to the exit block later, if possible. + NewVal = + Rewriter.expandCodeFor(AR, LargestType, + L->getLoopPreheader()->getTerminator()); + Rewriter.setInsertionPoint(I); + ++NumReplaced; + } else { + const Type *IVTy = Offset->getType(); + const Type *UseTy = Op->getType(); + + // Promote the Offset and Stride up to the canonical induction + // variable's bit width. + SCEVHandle PromotedOffset = Offset; + SCEVHandle PromotedStride = Stride; + if (SE->getTypeSizeInBits(IVTy) != SE->getTypeSizeInBits(LargestType)) { + // It doesn't matter for correctness whether zero or sign extension + // is used here, since the value is truncated away below, but if the + // value is signed, sign extension is more likely to be folded. + if (isSigned) { + PromotedOffset = SE->getSignExtendExpr(PromotedOffset, LargestType); + PromotedStride = SE->getSignExtendExpr(PromotedStride, LargestType); + } else { + PromotedOffset = SE->getZeroExtendExpr(PromotedOffset, LargestType); + // If the stride is obviously negative, use sign extension to + // produce things like x-1 instead of x+255. + if (isa<SCEVConstant>(PromotedStride) && + cast<SCEVConstant>(PromotedStride) + ->getValue()->getValue().isNegative()) + PromotedStride = SE->getSignExtendExpr(PromotedStride, + LargestType); + else + PromotedStride = SE->getZeroExtendExpr(PromotedStride, + LargestType); + } + } + + // Create the SCEV representing the offset from the canonical + // induction variable, still in the canonical induction variable's + // type, so that all expanded arithmetic is done in the same type. + SCEVHandle NewAR = SE->getAddRecExpr(SE->getIntegerSCEV(0, LargestType), + PromotedStride, L); + // Add the PromotedOffset as a separate step, because it may not be + // loop-invariant. + NewAR = SE->getAddExpr(NewAR, PromotedOffset); + + // Expand the addrec into instructions. + Value *V = Rewriter.expandCodeFor(NewAR); + + // Insert an explicit cast if necessary to truncate the value + // down to the original stride type. This is done outside of + // SCEVExpander because in SCEV expressions, a truncate of an + // addrec is always folded. + if (LargestType != IVTy) { + if (SE->getTypeSizeInBits(IVTy) != SE->getTypeSizeInBits(LargestType)) + NewAR = SE->getTruncateExpr(NewAR, IVTy); + if (Rewriter.isInsertedExpression(NewAR)) + V = Rewriter.expandCodeFor(NewAR); + else { + V = Rewriter.InsertCastOfTo(CastInst::getCastOpcode(V, false, + IVTy, false), + V, IVTy); + assert(!isa<SExtInst>(V) && !isa<ZExtInst>(V) && + "LargestType wasn't actually the largest type!"); + // Force the rewriter to use this trunc whenever this addrec + // appears so that it doesn't insert new phi nodes or + // arithmetic in a different type. + Rewriter.addInsertedValue(V, NewAR); + } + } + + DOUT << "INDVARS: Made offset-and-trunc IV for offset " + << *IVTy << " " << *Offset << ": "; + DEBUG(WriteAsOperand(*DOUT, V, false)); + DOUT << "\n"; + + // Now expand it into actual Instructions and patch it into place. + NewVal = Rewriter.expandCodeFor(AR, UseTy); + } + + // Patch the new value into place. 
+      if (Op->hasName())
+        NewVal->takeName(Op);
+      User->replaceUsesOfWith(Op, NewVal);
+      UI->setOperandValToReplace(NewVal);
+      DOUT << "INDVARS: Rewrote IV '" << *AR << "' " << *Op
+           << " into = " << *NewVal << "\n";
+      ++NumRemoved;
+      Changed = true;
+
+      // The old value may be dead now.
+      DeadInsts.push_back(Op);
+    }
+  }
+
+  // Clear the rewriter cache, because values that are in the rewriter's cache
+  // can be deleted in the loop below, causing the AssertingVH in the cache to
+  // trigger.
+  Rewriter.clear();
+  // Now that we're done iterating through lists, clean up any instructions
+  // which are now dead.
+  while (!DeadInsts.empty()) {
+    Instruction *Inst = dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val());
+    if (Inst)
+      RecursivelyDeleteTriviallyDeadInstructions(Inst);
+  }
+}
+
+/// If there's a single exit block, sink any loop-invariant values that
+/// were defined in the preheader but not used inside the loop into the
+/// exit block to reduce register pressure in the loop.
+void IndVarSimplify::SinkUnusedInvariants(Loop *L, SCEVExpander &Rewriter) {
+  BasicBlock *ExitBlock = L->getExitBlock();
+  if (!ExitBlock) return;
+
+  Instruction *NonPHI = ExitBlock->getFirstNonPHI();
+  BasicBlock *Preheader = L->getLoopPreheader();
+  BasicBlock::iterator I = Preheader->getTerminator();
+  while (I != Preheader->begin()) {
+    --I;
+    // New instructions were inserted at the end of the preheader.  Only
+    // consider those new instructions.
+    if (!Rewriter.isInsertedInstruction(I))
+      break;
+    // Determine if there is a use in or before the loop (direct or
+    // otherwise).
+    bool UsedInLoop = false;
+    for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
+         UI != UE; ++UI) {
+      BasicBlock *UseBB = cast<Instruction>(UI)->getParent();
+      if (PHINode *P = dyn_cast<PHINode>(UI)) {
+        unsigned i =
+          PHINode::getIncomingValueNumForOperand(UI.getOperandNo());
+        UseBB = P->getIncomingBlock(i);
+      }
+      if (UseBB == Preheader || L->contains(UseBB)) {
+        UsedInLoop = true;
+        break;
+      }
+    }
+    // If there is, the def must remain in the preheader.
+    if (UsedInLoop)
+      continue;
+    // Otherwise, sink it to the exit block.
+    Instruction *ToMove = I;
+    bool Done = false;
+    if (I != Preheader->begin())
+      --I;
+    else
+      Done = true;
+    ToMove->moveBefore(NonPHI);
+    if (Done)
+      break;
+  }
+}
+
+/// Re-schedule the inserted instructions to put defs before uses.  This
+/// fixes problems that arise when SCEV expressions contain loop-variant
+/// values unrelated to the induction variable which are defined inside the
+/// loop.  FIXME: It would be better to insert instructions in the right
+/// place so that this step isn't needed.
+void IndVarSimplify::FixUsesBeforeDefs(Loop *L, SCEVExpander &Rewriter) {
+  // Visit all the blocks in the loop in pre-order dom-tree dfs order.
+  DominatorTree *DT = &getAnalysis<DominatorTree>();
+  std::map<Instruction *, unsigned> NumPredsLeft;
+  SmallVector<DomTreeNode *, 16> Worklist;
+  Worklist.push_back(DT->getNode(L->getHeader()));
+  do {
+    DomTreeNode *Node = Worklist.pop_back_val();
+    for (DomTreeNode::iterator I = Node->begin(), E = Node->end(); I != E; ++I)
+      if (L->contains((*I)->getBlock()))
+        Worklist.push_back(*I);
+    BasicBlock *BB = Node->getBlock();
+    // Visit all the instructions in the block top down.
+    for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
+      // Count the number of operands that aren't properly dominating.
+      unsigned NumPreds = 0;
+      if (Rewriter.isInsertedInstruction(I) && !isa<PHINode>(I))
+        for (User::op_iterator OI = I->op_begin(), OE = I->op_end();
+             OI != OE; ++OI)
+          if (Instruction *Inst = dyn_cast<Instruction>(OI))
+            if (L->contains(Inst->getParent()) && !NumPredsLeft.count(Inst))
+              ++NumPreds;
+      NumPredsLeft[I] = NumPreds;
+      // Notify uses of the position of this instruction, and move the
+      // users (and their dependents, recursively) into place after this
+      // instruction if it is their last outstanding operand.
+      for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
+           UI != UE; ++UI) {
+        Instruction *Inst = cast<Instruction>(UI);
+        std::map<Instruction *, unsigned>::iterator Z = NumPredsLeft.find(Inst);
+        if (Z != NumPredsLeft.end() && Z->second != 0 && --Z->second == 0) {
+          SmallVector<Instruction *, 4> UseWorkList;
+          UseWorkList.push_back(Inst);
+          BasicBlock::iterator InsertPt = I;
+          if (InvokeInst *II = dyn_cast<InvokeInst>(InsertPt))
+            InsertPt = II->getNormalDest()->begin();
+          else
+            ++InsertPt;
+          while (isa<PHINode>(InsertPt)) ++InsertPt;
+          do {
+            Instruction *Use = UseWorkList.pop_back_val();
+            Use->moveBefore(InsertPt);
+            NumPredsLeft.erase(Use);
+            for (Value::use_iterator IUI = Use->use_begin(),
+                 IUE = Use->use_end(); IUI != IUE; ++IUI) {
+              Instruction *IUIInst = cast<Instruction>(IUI);
+              if (L->contains(IUIInst->getParent()) &&
+                  Rewriter.isInsertedInstruction(IUIInst) &&
+                  !isa<PHINode>(IUIInst))
+                UseWorkList.push_back(IUIInst);
+            }
+          } while (!UseWorkList.empty());
+        }
+      }
+    }
+  } while (!Worklist.empty());
+}
+
+/// Return true if it is OK to use SIToFPInst for an induction variable
+/// with the given initial and exit values.
+static bool useSIToFPInst(ConstantFP &InitV, ConstantFP &ExitV,
+                          uint64_t intIV, uint64_t intEV) {
+
+  if (InitV.getValueAPF().isNegative() || ExitV.getValueAPF().isNegative())
+    return true;
+
+  // If the iteration range can be handled by SIToFPInst then use it.
+  APInt Max = APInt::getSignedMaxValue(32);
+  if (Max.getZExtValue() > static_cast<uint64_t>(abs64(intEV - intIV)))
+    return true;
+
+  return false;
+}
+
+/// convertToInt - Convert APF to an integer, if possible.
+static bool convertToInt(const APFloat &APF, uint64_t *intVal) {
+
+  bool isExact = false;
+  if (&APF.getSemantics() == &APFloat::PPCDoubleDouble)
+    return false;
+  if (APF.convertToInteger(intVal, 32, APF.isNegative(),
+                           APFloat::rmTowardZero, &isExact)
+      != APFloat::opOK)
+    return false;
+  if (!isExact)
+    return false;
+  return true;
+}
+
+/// HandleFloatingPointIV - If the loop has a floating-point induction
+/// variable, insert a corresponding integer induction variable if possible.
+/// For example,
+/// for(double i = 0; i < 10000; ++i)
+///   bar(i)
+/// is converted into
+/// for(int i = 0; i < 10000; ++i)
+///   bar((double)i);
+///
+void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PH) {
+
+  unsigned IncomingEdge = L->contains(PH->getIncomingBlock(0));
+  unsigned BackEdge = IncomingEdge^1;
+
+  // Check incoming value.
+  ConstantFP *InitValue = dyn_cast<ConstantFP>(PH->getIncomingValue(IncomingEdge));
+  if (!InitValue) return;
+  uint64_t newInitValue = Type::Int32Ty->getPrimitiveSizeInBits();
+  if (!convertToInt(InitValue->getValueAPF(), &newInitValue))
+    return;
+
+  // Check the IV increment.  Reject this PH if the increment operation is not
+  // an add or if the increment value cannot be represented by an integer.
+  BinaryOperator *Incr =
+    dyn_cast<BinaryOperator>(PH->getIncomingValue(BackEdge));
+  if (!Incr) return;
+  if (Incr->getOpcode() != Instruction::Add) return;
+  ConstantFP *IncrValue = NULL;
+  unsigned IncrVIndex = 1;
+  if (Incr->getOperand(1) == PH)
+    IncrVIndex = 0;
+  IncrValue = dyn_cast<ConstantFP>(Incr->getOperand(IncrVIndex));
+  if (!IncrValue) return;
+  uint64_t newIncrValue = Type::Int32Ty->getPrimitiveSizeInBits();
+  if (!convertToInt(IncrValue->getValueAPF(), &newIncrValue))
+    return;
+
+  // Check Incr uses.  One user is PH and the other user is the exit condition
+  // used by the conditional terminator.
+  Value::use_iterator IncrUse = Incr->use_begin();
+  Instruction *U1 = cast<Instruction>(IncrUse++);
+  if (IncrUse == Incr->use_end()) return;
+  Instruction *U2 = cast<Instruction>(IncrUse++);
+  if (IncrUse != Incr->use_end()) return;
+
+  // Find the exit condition.
+  FCmpInst *EC = dyn_cast<FCmpInst>(U1);
+  if (!EC)
+    EC = dyn_cast<FCmpInst>(U2);
+  if (!EC) return;
+
+  if (BranchInst *BI = dyn_cast<BranchInst>(EC->getParent()->getTerminator())) {
+    if (!BI->isConditional()) return;
+    if (BI->getCondition() != EC) return;
+  }
+
+  // Find the exit value.  If the exit value cannot be represented as an
+  // integer, do not handle this floating-point PH.
+  ConstantFP *EV = NULL;
+  unsigned EVIndex = 1;
+  if (EC->getOperand(1) == Incr)
+    EVIndex = 0;
+  EV = dyn_cast<ConstantFP>(EC->getOperand(EVIndex));
+  if (!EV) return;
+  uint64_t intEV = Type::Int32Ty->getPrimitiveSizeInBits();
+  if (!convertToInt(EV->getValueAPF(), &intEV))
+    return;
+
+  // Find the new predicate for the integer comparison.
+  CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE;
+  switch (EC->getPredicate()) {
+  case CmpInst::FCMP_OEQ:
+  case CmpInst::FCMP_UEQ:
+    NewPred = CmpInst::ICMP_EQ;
+    break;
+  case CmpInst::FCMP_OGT:
+  case CmpInst::FCMP_UGT:
+    NewPred = CmpInst::ICMP_UGT;
+    break;
+  case CmpInst::FCMP_OGE:
+  case CmpInst::FCMP_UGE:
+    NewPred = CmpInst::ICMP_UGE;
+    break;
+  case CmpInst::FCMP_OLT:
+  case CmpInst::FCMP_ULT:
+    NewPred = CmpInst::ICMP_ULT;
+    break;
+  case CmpInst::FCMP_OLE:
+  case CmpInst::FCMP_ULE:
+    NewPred = CmpInst::ICMP_ULE;
+    break;
+  default:
+    break;
+  }
+  if (NewPred == CmpInst::BAD_ICMP_PREDICATE) return;
+
+  // Insert the new integer induction variable.
+  PHINode *NewPHI = PHINode::Create(Type::Int32Ty,
+                                    PH->getName()+".int", PH);
+  NewPHI->addIncoming(ConstantInt::get(Type::Int32Ty, newInitValue),
+                      PH->getIncomingBlock(IncomingEdge));
+
+  Value *NewAdd = BinaryOperator::CreateAdd(NewPHI,
+                                            ConstantInt::get(Type::Int32Ty,
+                                                             newIncrValue),
+                                            Incr->getName()+".int", Incr);
+  NewPHI->addIncoming(NewAdd, PH->getIncomingBlock(BackEdge));
+
+  // The back edge is edge 1 of NewPHI, whatever it may have been in the
+  // original PHI.
+  ConstantInt *NewEV = ConstantInt::get(Type::Int32Ty, intEV);
+  Value *LHS = (EVIndex == 1 ? NewPHI->getIncomingValue(1) : NewEV);
+  Value *RHS = (EVIndex == 1 ? NewEV : NewPHI->getIncomingValue(1));
+  ICmpInst *NewEC = new ICmpInst(NewPred, LHS, RHS, EC->getNameStart(),
+                                 EC->getParent()->getTerminator());
+
+  // In the following deletions, PH may become dead and may be deleted.
+  // Use a WeakVH to observe whether this happens.
+  WeakVH WeakPH = PH;
+
+  // Delete the old floating-point exit comparison instruction.
+  NewEC->takeName(EC);
+  EC->replaceAllUsesWith(NewEC);
+  RecursivelyDeleteTriviallyDeadInstructions(EC);
+
+  // Delete the old floating-point increment instruction.
+ Incr->replaceAllUsesWith(UndefValue::get(Incr->getType())); + RecursivelyDeleteTriviallyDeadInstructions(Incr); + + // Replace floating induction variable, if it isn't already deleted. + // Give SIToFPInst preference over UIToFPInst because it is faster on + // platforms that are widely used. + if (WeakPH && !PH->use_empty()) { + if (useSIToFPInst(*InitValue, *EV, newInitValue, intEV)) { + SIToFPInst *Conv = new SIToFPInst(NewPHI, PH->getType(), "indvar.conv", + PH->getParent()->getFirstNonPHI()); + PH->replaceAllUsesWith(Conv); + } else { + UIToFPInst *Conv = new UIToFPInst(NewPHI, PH->getType(), "indvar.conv", + PH->getParent()->getFirstNonPHI()); + PH->replaceAllUsesWith(Conv); + } + RecursivelyDeleteTriviallyDeadInstructions(PH); + } + + // Add a new IVUsers entry for the newly-created integer PHI. + IU->AddUsersIfInteresting(NewPHI); +} diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp new file mode 100644 index 0000000..e6f854f --- /dev/null +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -0,0 +1,12919 @@ +//===- InstructionCombining.cpp - Combine multiple instructions -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// InstructionCombining - Combine instructions to form fewer, simple +// instructions. This pass does not modify the CFG. This pass is where +// algebraic simplification happens. +// +// This pass combines things like: +// %Y = add i32 %X, 1 +// %Z = add i32 %Y, 1 +// into: +// %Z = add i32 %X, 2 +// +// This is a simple worklist driven algorithm. +// +// This pass guarantees that the following canonicalizations are performed on +// the program: +// 1. If a binary operator has a constant operand, it is moved to the RHS +// 2. Bitwise operators with constant operands are always grouped so that +// shifts are performed first, then or's, then and's, then xor's. +// 3. Compare instructions are converted from <,>,<=,>= to ==,!= if possible +// 4. All cmp instructions on boolean values are replaced with logical ops +// 5. add X, X is represented as (X*2) => (X << 1) +// 6. Multiplies with a power-of-two constant argument are transformed into +// shifts. +// ... etc. 
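+//
+// As a further illustration of canonicalization #6 (an example of ours, not
+// from the original header):
+//    %r = mul i32 %x, 8
+// becomes
+//    %r = shl i32 %x, 3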
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Scalar.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/DerivedTypes.h" +#include "llvm/GlobalVariable.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/ConstantRange.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/PatternMatch.h" +#include "llvm/Support/Compiler.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include <algorithm> +#include <climits> +#include <sstream> +using namespace llvm; +using namespace llvm::PatternMatch; + +STATISTIC(NumCombined , "Number of insts combined"); +STATISTIC(NumConstProp, "Number of constant folds"); +STATISTIC(NumDeadInst , "Number of dead inst eliminated"); +STATISTIC(NumDeadStore, "Number of dead stores eliminated"); +STATISTIC(NumSunkInst , "Number of instructions sunk"); + +namespace { + class VISIBILITY_HIDDEN InstCombiner + : public FunctionPass, + public InstVisitor<InstCombiner, Instruction*> { + // Worklist of all of the instructions that need to be simplified. + SmallVector<Instruction*, 256> Worklist; + DenseMap<Instruction*, unsigned> WorklistMap; + TargetData *TD; + bool MustPreserveLCSSA; + public: + static char ID; // Pass identification, replacement for typeid + InstCombiner() : FunctionPass(&ID) {} + + /// AddToWorkList - Add the specified instruction to the worklist if it + /// isn't already in it. + void AddToWorkList(Instruction *I) { + if (WorklistMap.insert(std::make_pair(I, Worklist.size())).second) + Worklist.push_back(I); + } + + // RemoveFromWorkList - remove I from the worklist if it exists. + void RemoveFromWorkList(Instruction *I) { + DenseMap<Instruction*, unsigned>::iterator It = WorklistMap.find(I); + if (It == WorklistMap.end()) return; // Not in worklist. + + // Don't bother moving everything down, just null out the slot. + Worklist[It->second] = 0; + + WorklistMap.erase(It); + } + + Instruction *RemoveOneFromWorkList() { + Instruction *I = Worklist.back(); + Worklist.pop_back(); + WorklistMap.erase(I); + return I; + } + + + /// AddUsersToWorkList - When an instruction is simplified, add all users of + /// the instruction to the work lists because they might get more simplified + /// now. + /// + void AddUsersToWorkList(Value &I) { + for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); + UI != UE; ++UI) + AddToWorkList(cast<Instruction>(*UI)); + } + + /// AddUsesToWorkList - When an instruction is simplified, add operands to + /// the work lists because they might get more simplified now. + /// + void AddUsesToWorkList(Instruction &I) { + for (User::op_iterator i = I.op_begin(), e = I.op_end(); i != e; ++i) + if (Instruction *Op = dyn_cast<Instruction>(*i)) + AddToWorkList(Op); + } + + /// AddSoonDeadInstToWorklist - The specified instruction is about to become + /// dead. Add all of its operands to the worklist, turning them into + /// undef's to reduce the number of uses of those instructions. 
+ /// + /// Return the specified operand before it is turned into an undef. + /// + Value *AddSoonDeadInstToWorklist(Instruction &I, unsigned op) { + Value *R = I.getOperand(op); + + for (User::op_iterator i = I.op_begin(), e = I.op_end(); i != e; ++i) + if (Instruction *Op = dyn_cast<Instruction>(*i)) { + AddToWorkList(Op); + // Set the operand to undef to drop the use. + *i = UndefValue::get(Op->getType()); + } + + return R; + } + + public: + virtual bool runOnFunction(Function &F); + + bool DoOneIteration(Function &F, unsigned ItNum); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<TargetData>(); + AU.addPreservedID(LCSSAID); + AU.setPreservesCFG(); + } + + TargetData &getTargetData() const { return *TD; } + + // Visitation implementation - Implement instruction combining for different + // instruction types. The semantics are as follows: + // Return Value: + // null - No change was made + // I - Change was made, I is still valid, I may be dead though + // otherwise - Change was made, replace I with returned instruction + // + Instruction *visitAdd(BinaryOperator &I); + Instruction *visitSub(BinaryOperator &I); + Instruction *visitMul(BinaryOperator &I); + Instruction *visitURem(BinaryOperator &I); + Instruction *visitSRem(BinaryOperator &I); + Instruction *visitFRem(BinaryOperator &I); + bool SimplifyDivRemOfSelect(BinaryOperator &I); + Instruction *commonRemTransforms(BinaryOperator &I); + Instruction *commonIRemTransforms(BinaryOperator &I); + Instruction *commonDivTransforms(BinaryOperator &I); + Instruction *commonIDivTransforms(BinaryOperator &I); + Instruction *visitUDiv(BinaryOperator &I); + Instruction *visitSDiv(BinaryOperator &I); + Instruction *visitFDiv(BinaryOperator &I); + Instruction *FoldAndOfICmps(Instruction &I, ICmpInst *LHS, ICmpInst *RHS); + Instruction *visitAnd(BinaryOperator &I); + Instruction *FoldOrOfICmps(Instruction &I, ICmpInst *LHS, ICmpInst *RHS); + Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, + Value *A, Value *B, Value *C); + Instruction *visitOr (BinaryOperator &I); + Instruction *visitXor(BinaryOperator &I); + Instruction *visitShl(BinaryOperator &I); + Instruction *visitAShr(BinaryOperator &I); + Instruction *visitLShr(BinaryOperator &I); + Instruction *commonShiftTransforms(BinaryOperator &I); + Instruction *FoldFCmp_IntToFP_Cst(FCmpInst &I, Instruction *LHSI, + Constant *RHSC); + Instruction *visitFCmpInst(FCmpInst &I); + Instruction *visitICmpInst(ICmpInst &I); + Instruction *visitICmpInstWithCastAndCast(ICmpInst &ICI); + Instruction *visitICmpInstWithInstAndIntCst(ICmpInst &ICI, + Instruction *LHS, + ConstantInt *RHS); + Instruction *FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, + ConstantInt *DivRHS); + + Instruction *FoldGEPICmp(User *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, Instruction &I); + Instruction *FoldShiftByConstant(Value *Op0, ConstantInt *Op1, + BinaryOperator &I); + Instruction *commonCastTransforms(CastInst &CI); + Instruction *commonIntCastTransforms(CastInst &CI); + Instruction *commonPointerCastTransforms(CastInst &CI); + Instruction *visitTrunc(TruncInst &CI); + Instruction *visitZExt(ZExtInst &CI); + Instruction *visitSExt(SExtInst &CI); + Instruction *visitFPTrunc(FPTruncInst &CI); + Instruction *visitFPExt(CastInst &CI); + Instruction *visitFPToUI(FPToUIInst &FI); + Instruction *visitFPToSI(FPToSIInst &FI); + Instruction *visitUIToFP(CastInst &CI); + Instruction *visitSIToFP(CastInst &CI); + Instruction *visitPtrToInt(PtrToIntInst &CI); + Instruction 
*visitIntToPtr(IntToPtrInst &CI);
+    Instruction *visitBitCast(BitCastInst &CI);
+    Instruction *FoldSelectOpOp(SelectInst &SI, Instruction *TI,
+                                Instruction *FI);
+    Instruction *FoldSelectIntoOp(SelectInst &SI, Value*, Value*);
+    Instruction *visitSelectInst(SelectInst &SI);
+    Instruction *visitSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
+    Instruction *visitCallInst(CallInst &CI);
+    Instruction *visitInvokeInst(InvokeInst &II);
+    Instruction *visitPHINode(PHINode &PN);
+    Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP);
+    Instruction *visitAllocationInst(AllocationInst &AI);
+    Instruction *visitFreeInst(FreeInst &FI);
+    Instruction *visitLoadInst(LoadInst &LI);
+    Instruction *visitStoreInst(StoreInst &SI);
+    Instruction *visitBranchInst(BranchInst &BI);
+    Instruction *visitSwitchInst(SwitchInst &SI);
+    Instruction *visitInsertElementInst(InsertElementInst &IE);
+    Instruction *visitExtractElementInst(ExtractElementInst &EI);
+    Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI);
+    Instruction *visitExtractValueInst(ExtractValueInst &EV);
+
+    // visitInstruction - Specify what to return for unhandled instructions...
+    Instruction *visitInstruction(Instruction &I) { return 0; }
+
+  private:
+    Instruction *visitCallSite(CallSite CS);
+    bool transformConstExprCastCall(CallSite CS);
+    Instruction *transformCallThroughTrampoline(CallSite CS);
+    Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI,
+                                   bool DoXform = true);
+    bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS);
+    DbgDeclareInst *hasOneUsePlusDeclare(Value *V);
+
+  public:
+    // InsertNewInstBefore - insert an instruction New before instruction Old
+    // in the program.  Add the new instruction to the worklist.
+    //
+    Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) {
+      assert(New && New->getParent() == 0 &&
+             "New instruction already inserted into a basic block!");
+      BasicBlock *BB = Old.getParent();
+      BB->getInstList().insert(&Old, New);  // Insert inst
+      AddToWorkList(New);
+      return New;
+    }
+
+    /// InsertCastBefore - Insert a cast of V to TY before the instruction POS.
+    /// This also adds the cast to the worklist.  Finally, this returns the
+    /// cast.
+    Value *InsertCastBefore(Instruction::CastOps opc, Value *V, const Type *Ty,
+                            Instruction &Pos) {
+      if (V->getType() == Ty) return V;
+
+      if (Constant *CV = dyn_cast<Constant>(V))
+        return ConstantExpr::getCast(opc, CV, Ty);
+
+      Instruction *C = CastInst::Create(opc, V, Ty, V->getName(), &Pos);
+      AddToWorkList(C);
+      return C;
+    }
+
+    Value *InsertBitCastBefore(Value *V, const Type *Ty, Instruction &Pos) {
+      return InsertCastBefore(Instruction::BitCast, V, Ty, Pos);
+    }
+
+    // ReplaceInstUsesWith - This method is to be used when an instruction is
+    // found to be dead, replaceable with another preexisting expression.  Here
+    // we add all uses of I to the worklist, replace all uses of I with the new
+    // value, then return I, so that the inst combiner will know that I was
+    // modified.
+    //
+    Instruction *ReplaceInstUsesWith(Instruction &I, Value *V) {
+      AddUsersToWorkList(I);         // Add all modified instrs to worklist
+      if (&I != V) {
+        I.replaceAllUsesWith(V);
+        return &I;
+      } else {
+        // If we are replacing the instruction with itself, this must be in a
+        // segment of unreachable code, so just clobber the instruction.
+ I.replaceAllUsesWith(UndefValue::get(I.getType())); + return &I; + } + } + + // EraseInstFromFunction - When dealing with an instruction that has side + // effects or produces a void value, we can't rely on DCE to delete the + // instruction. Instead, visit methods should return the value returned by + // this function. + Instruction *EraseInstFromFunction(Instruction &I) { + assert(I.use_empty() && "Cannot erase instruction that is used!"); + AddUsesToWorkList(I); + RemoveFromWorkList(&I); + I.eraseFromParent(); + return 0; // Don't do anything with FI + } + + void ComputeMaskedBits(Value *V, const APInt &Mask, APInt &KnownZero, + APInt &KnownOne, unsigned Depth = 0) const { + return llvm::ComputeMaskedBits(V, Mask, KnownZero, KnownOne, TD, Depth); + } + + bool MaskedValueIsZero(Value *V, const APInt &Mask, + unsigned Depth = 0) const { + return llvm::MaskedValueIsZero(V, Mask, TD, Depth); + } + unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0) const { + return llvm::ComputeNumSignBits(Op, TD, Depth); + } + + private: + + /// SimplifyCommutative - This performs a few simplifications for + /// commutative operators. + bool SimplifyCommutative(BinaryOperator &I); + + /// SimplifyCompare - This reorders the operands of a CmpInst to get them in + /// most-complex to least-complex order. + bool SimplifyCompare(CmpInst &I); + + /// SimplifyDemandedUseBits - Attempts to replace V with a simpler value + /// based on the demanded bits. + Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, + APInt& KnownZero, APInt& KnownOne, + unsigned Depth); + bool SimplifyDemandedBits(Use &U, APInt DemandedMask, + APInt& KnownZero, APInt& KnownOne, + unsigned Depth=0); + + /// SimplifyDemandedInstructionBits - Inst is an integer instruction that + /// SimplifyDemandedBits knows about. See if the instruction has any + /// properties that allow us to simplify its operands. + bool SimplifyDemandedInstructionBits(Instruction &Inst); + + Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, + APInt& UndefElts, unsigned Depth = 0); + + // FoldOpIntoPhi - Given a binary operator or cast instruction which has a + // PHI node as operand #0, see if we can fold the instruction into the PHI + // (which is only possible if all operands to the PHI are constants). + Instruction *FoldOpIntoPhi(Instruction &I); + + // FoldPHIArgOpIntoPHI - If all operands to a PHI node are the same "unary" + // operator and they all are only used by the PHI, PHI together their + // inputs, and do the operation once, to the result of the PHI. 
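+    // For example (an illustrative sketch of ours): a phi of two casts,
+    //   %p = phi i32 [ %a.trunc, %bb1 ], [ %b.trunc, %bb2 ]
+    // where each trunc is only used by the phi, becomes a phi of %a and %b
+    // followed by a single trunc of the merged value.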
+    Instruction *FoldPHIArgOpIntoPHI(PHINode &PN);
+    Instruction *FoldPHIArgBinOpIntoPHI(PHINode &PN);
+    Instruction *FoldPHIArgGEPIntoPHI(PHINode &PN);
+
+    Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS,
+                          ConstantInt *AndRHS, BinaryOperator &TheAnd);
+
+    Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask,
+                              bool isSub, Instruction &I);
+    Instruction *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi,
+                                 bool isSigned, bool Inside, Instruction &IB);
+    Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocationInst &AI);
+    Instruction *MatchBSwap(BinaryOperator &I);
+    bool SimplifyStoreAtEndOfBlock(StoreInst &SI);
+    Instruction *SimplifyMemTransfer(MemIntrinsic *MI);
+    Instruction *SimplifyMemSet(MemSetInst *MI);
+
+    Value *EvaluateInDifferentType(Value *V, const Type *Ty, bool isSigned);
+
+    bool CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
+                                    unsigned CastOpc, int &NumCastsRemoved);
+    unsigned GetOrEnforceKnownAlignment(Value *V,
+                                        unsigned PrefAlign = 0);
+
+  };
+}
+
+char InstCombiner::ID = 0;
+static RegisterPass<InstCombiner>
+X("instcombine", "Combine redundant instructions");
+
+// getComplexity:  Assign a complexity or rank value to LLVM Values...
+//   0 -> undef, 1 -> Const, 2 -> Other, 3 -> Arg, 3 -> Unary, 4 -> OtherInst
+static unsigned getComplexity(Value *V) {
+  if (isa<Instruction>(V)) {
+    if (BinaryOperator::isNeg(V) || BinaryOperator::isNot(V))
+      return 3;
+    return 4;
+  }
+  if (isa<Argument>(V)) return 3;
+  return isa<Constant>(V) ? (isa<UndefValue>(V) ? 0 : 1) : 2;
+}
+
+// isOnlyUse - Return true if this instruction will be deleted if we stop using
+// it.
+static bool isOnlyUse(Value *V) {
+  return V->hasOneUse() || isa<Constant>(V);
+}
+
+// getPromotedType - Return the specified type promoted as it would be to pass
+// through a va_arg area...
+static const Type *getPromotedType(const Type *Ty) {
+  if (const IntegerType* ITy = dyn_cast<IntegerType>(Ty)) {
+    if (ITy->getBitWidth() < 32)
+      return Type::Int32Ty;
+  }
+  return Ty;
+}
+
+/// getBitCastOperand - If the specified operand is a CastInst, a constant
+/// expression bitcast, or a GetElementPtrInst with all zero indices, return
+/// the operand value, otherwise return null.
+static Value *getBitCastOperand(Value *V) {
+  if (BitCastInst *I = dyn_cast<BitCastInst>(V))
+    // BitCastInst?
+    return I->getOperand(0);
+  else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
+    // GetElementPtrInst?
+    if (GEP->hasAllZeroIndices())
+      return GEP->getOperand(0);
+  } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+    if (CE->getOpcode() == Instruction::BitCast)
+      // BitCast ConstantExp?
+      return CE->getOperand(0);
+    else if (CE->getOpcode() == Instruction::GetElementPtr) {
+      // GetElementPtr ConstantExp?
+      for (User::op_iterator I = CE->op_begin() + 1, E = CE->op_end();
+           I != E; ++I) {
+        ConstantInt *CI = dyn_cast<ConstantInt>(I);
+        if (!CI || !CI->isZero())
+          // Any non-zero indices? Not cast-like.
+          return 0;
+      }
+      // All-zero indices? This is just like casting.
+      return CE->getOperand(0);
+    }
+  }
+  return 0;
+}
+
+/// This function is a wrapper around CastInst::isEliminableCastPair.  It
+/// simply extracts arguments and returns what that function returns.
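+///
+/// For instance (illustrative only): a zext i8 -> i32 followed by a
+/// trunc i32 -> i16 can be folded to a single zext i8 -> i16, while an
+/// inttoptr or ptrtoint result is kept only when the integer type matches
+/// the target's pointer size, as the check below enforces.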
+static Instruction::CastOps +isEliminableCastPair( + const CastInst *CI, ///< The first cast instruction + unsigned opcode, ///< The opcode of the second cast instruction + const Type *DstTy, ///< The target type for the second cast instruction + TargetData *TD ///< The target data for pointer size +) { + + const Type *SrcTy = CI->getOperand(0)->getType(); // A from above + const Type *MidTy = CI->getType(); // B from above + + // Get the opcodes of the two Cast instructions + Instruction::CastOps firstOp = Instruction::CastOps(CI->getOpcode()); + Instruction::CastOps secondOp = Instruction::CastOps(opcode); + + unsigned Res = CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, + DstTy, TD->getIntPtrType()); + + // We don't want to form an inttoptr or ptrtoint that converts to an integer + // type that differs from the pointer size. + if ((Res == Instruction::IntToPtr && SrcTy != TD->getIntPtrType()) || + (Res == Instruction::PtrToInt && DstTy != TD->getIntPtrType())) + Res = 0; + + return Instruction::CastOps(Res); +} + +/// ValueRequiresCast - Return true if the cast from "V to Ty" actually results +/// in any code being generated. It does not require codegen if V is simple +/// enough or if the cast can be folded into other casts. +static bool ValueRequiresCast(Instruction::CastOps opcode, const Value *V, + const Type *Ty, TargetData *TD) { + if (V->getType() == Ty || isa<Constant>(V)) return false; + + // If this is another cast that can be eliminated, it isn't codegen either. + if (const CastInst *CI = dyn_cast<CastInst>(V)) + if (isEliminableCastPair(CI, opcode, Ty, TD)) + return false; + return true; +} + +// SimplifyCommutative - This performs a few simplifications for commutative +// operators: +// +// 1. Order operands such that they are listed from right (least complex) to +// left (most complex). This puts constants before unary operators before +// binary operators. +// +// 2. Transform: (op (op V, C1), C2) ==> (op V, (op C1, C2)) +// 3. 
+//    Transform: (op (op V1, C1), (op V2, C2)) ==> (op (op V1, V2), (op C1,C2))
+//
+bool InstCombiner::SimplifyCommutative(BinaryOperator &I) {
+  bool Changed = false;
+  if (getComplexity(I.getOperand(0)) < getComplexity(I.getOperand(1)))
+    Changed = !I.swapOperands();
+
+  if (!I.isAssociative()) return Changed;
+  Instruction::BinaryOps Opcode = I.getOpcode();
+  if (BinaryOperator *Op = dyn_cast<BinaryOperator>(I.getOperand(0)))
+    if (Op->getOpcode() == Opcode && isa<Constant>(Op->getOperand(1))) {
+      if (isa<Constant>(I.getOperand(1))) {
+        Constant *Folded = ConstantExpr::get(I.getOpcode(),
+                                             cast<Constant>(I.getOperand(1)),
+                                             cast<Constant>(Op->getOperand(1)));
+        I.setOperand(0, Op->getOperand(0));
+        I.setOperand(1, Folded);
+        return true;
+      } else if (BinaryOperator *Op1=dyn_cast<BinaryOperator>(I.getOperand(1)))
+        if (Op1->getOpcode() == Opcode && isa<Constant>(Op1->getOperand(1)) &&
+            isOnlyUse(Op) && isOnlyUse(Op1)) {
+          Constant *C1 = cast<Constant>(Op->getOperand(1));
+          Constant *C2 = cast<Constant>(Op1->getOperand(1));
+
+          // Fold (op (op V1, C1), (op V2, C2)) ==> (op (op V1, V2), (op C1,C2))
+          Constant *Folded = ConstantExpr::get(I.getOpcode(), C1, C2);
+          Instruction *New = BinaryOperator::Create(Opcode, Op->getOperand(0),
+                                                    Op1->getOperand(0),
+                                                    Op1->getName(), &I);
+          AddToWorkList(New);
+          I.setOperand(0, New);
+          I.setOperand(1, Folded);
+          return true;
+        }
+    }
+  return Changed;
+}
+
+/// SimplifyCompare - For a CmpInst this function just orders the operands
+/// so that they are listed from right (least complex) to left (most complex).
+/// This puts constants before unary operators before binary operators.
+bool InstCombiner::SimplifyCompare(CmpInst &I) {
+  if (getComplexity(I.getOperand(0)) >= getComplexity(I.getOperand(1)))
+    return false;
+  I.swapOperands();
+  // Compare instructions are not associative so there's nothing else we can do.
+  return true;
+}
+
+// dyn_castNegVal - Given a 'sub' instruction, return the RHS of the instruction
+// if the LHS is a constant zero (which is the 'negate' form).
+//
+static inline Value *dyn_castNegVal(Value *V) {
+  if (BinaryOperator::isNeg(V))
+    return BinaryOperator::getNegArgument(V);
+
+  // Constants can be considered to be negated values if they can be folded.
+  if (ConstantInt *C = dyn_cast<ConstantInt>(V))
+    return ConstantExpr::getNeg(C);
+
+  if (ConstantVector *C = dyn_cast<ConstantVector>(V))
+    if (C->getType()->getElementType()->isInteger())
+      return ConstantExpr::getNeg(C);
+
+  return 0;
+}
+
+static inline Value *dyn_castNotVal(Value *V) {
+  if (BinaryOperator::isNot(V))
+    return BinaryOperator::getNotArgument(V);
+
+  // Constants can be considered to be not'ed values...
+  if (ConstantInt *C = dyn_cast<ConstantInt>(V))
+    return ConstantInt::get(~C->getValue());
+  return 0;
+}
+
+// dyn_castFoldableMul - If this value is a multiply that can be folded into
+// other computations (because it has a constant operand), return the
+// non-constant operand of the multiply, and set CST to point to the multiplier.
+// Otherwise, return null.
+//
+static inline Value *dyn_castFoldableMul(Value *V, ConstantInt *&CST) {
+  if (V->hasOneUse() && V->getType()->isInteger())
+    if (Instruction *I = dyn_cast<Instruction>(V)) {
+      if (I->getOpcode() == Instruction::Mul)
+        if ((CST = dyn_cast<ConstantInt>(I->getOperand(1))))
+          return I->getOperand(0);
+      if (I->getOpcode() == Instruction::Shl)
+        if ((CST = dyn_cast<ConstantInt>(I->getOperand(1)))) {
+          // The multiplier is really 1 << CST.
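// Illustrative aside (editor's standalone sketch, not part of the imported
// source): the fold here treats "X << C" as "X * (1 << C)". A quick check of
// that identity for a 32-bit unsigned type, where both sides wrap mod 2^32:
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t c = 0; c < 32; ++c)
    for (uint32_t x = 0; x < 1000; x += 37)
      assert((x << c) == x * (uint32_t(1) << c)); // shl == mul by power of two
  return 0;
}
// End of aside.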
+          uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
+          uint32_t CSTVal = CST->getLimitedValue(BitWidth);
+          CST = ConstantInt::get(APInt(BitWidth, 1).shl(CSTVal));
+          return I->getOperand(0);
+        }
+    }
+  return 0;
+}
+
+/// dyn_castGetElementPtr - If this is a getelementptr instruction or constant
+/// expression, return it.
+static User *dyn_castGetElementPtr(Value *V) {
+  if (isa<GetElementPtrInst>(V)) return cast<User>(V);
+  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
+    if (CE->getOpcode() == Instruction::GetElementPtr)
+      return cast<User>(V);
+  return 0;
+}
+
+/// getOpcode - If this is an Instruction or a ConstantExpr, return the
+/// opcode value. Otherwise return UserOp1.
+static unsigned getOpcode(const Value *V) {
+  if (const Instruction *I = dyn_cast<Instruction>(V))
+    return I->getOpcode();
+  if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
+    return CE->getOpcode();
+  // Use UserOp1 to mean there's no opcode.
+  return Instruction::UserOp1;
+}
+
+/// AddOne - Add one to a ConstantInt
+static ConstantInt *AddOne(ConstantInt *C) {
+  APInt Val(C->getValue());
+  return ConstantInt::get(++Val);
+}
+/// SubOne - Subtract one from a ConstantInt
+static ConstantInt *SubOne(ConstantInt *C) {
+  APInt Val(C->getValue());
+  return ConstantInt::get(--Val);
+}
+/// Add - Add two ConstantInts together
+static ConstantInt *Add(ConstantInt *C1, ConstantInt *C2) {
+  return ConstantInt::get(C1->getValue() + C2->getValue());
+}
+/// And - Bitwise AND two ConstantInts together
+static ConstantInt *And(ConstantInt *C1, ConstantInt *C2) {
+  return ConstantInt::get(C1->getValue() & C2->getValue());
+}
+/// Subtract - Subtract one ConstantInt from another
+static ConstantInt *Subtract(ConstantInt *C1, ConstantInt *C2) {
+  return ConstantInt::get(C1->getValue() - C2->getValue());
+}
+/// Multiply - Multiply two ConstantInts together
+static ConstantInt *Multiply(ConstantInt *C1, ConstantInt *C2) {
+  return ConstantInt::get(C1->getValue() * C2->getValue());
+}
+/// MultiplyOverflows - True if the multiply cannot be expressed in an int
+/// this size.
+static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) {
+  uint32_t W = C1->getBitWidth();
+  APInt LHSExt = C1->getValue(), RHSExt = C2->getValue();
+  if (sign) {
+    LHSExt.sext(W * 2);
+    RHSExt.sext(W * 2);
+  } else {
+    LHSExt.zext(W * 2);
+    RHSExt.zext(W * 2);
+  }
+
+  APInt MulExt = LHSExt * RHSExt;
+
+  if (sign) {
+    APInt Min = APInt::getSignedMinValue(W).sext(W * 2);
+    APInt Max = APInt::getSignedMaxValue(W).sext(W * 2);
+    return MulExt.slt(Min) || MulExt.sgt(Max);
+  } else
+    return MulExt.ugt(APInt::getLowBitsSet(W * 2, W));
+}
+
+
+/// ShrinkDemandedConstant - Check to see if the specified operand of the
+/// specified instruction is a constant integer. If so, check to see if there
+/// are any bits set in the constant that are not demanded. If so, shrink the
+/// constant and return true.
+static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
+                                   APInt Demanded) {
+  assert(I && "No instruction?");
+  assert(OpNo < I->getNumOperands() && "Operand index too large");
+
+  // If the operand is not a constant integer, nothing to do.
+  ConstantInt *OpC = dyn_cast<ConstantInt>(I->getOperand(OpNo));
+  if (!OpC) return false;
+
+  // If there are no bits set that aren't demanded, nothing to do.
+  Demanded.zextOrTrunc(OpC->getValue().getBitWidth());
+  if ((~Demanded & OpC->getValue()) == 0)
+    return false;
+
+  // This instruction is producing bits that are not demanded. Shrink the RHS.
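// Illustrative aside (editor's standalone sketch, not part of the imported
// source): the shrink step here just clears the constant bits that nothing
// demands. For example, with only the low nibble demanded, "x & 0xFF" can use
// the smaller constant 0x0F without changing any demanded bit:
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t DemandedBits = 0x0F; // the caller only reads the low nibble
  for (uint32_t x = 0; x < 4096; ++x)
    assert(((x & 0xFF) & DemandedBits) == ((x & 0x0F) & DemandedBits));
  return 0;
}
// End of aside.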
+ Demanded &= OpC->getValue(); + I->setOperand(OpNo, ConstantInt::get(Demanded)); + return true; +} + +// ComputeSignedMinMaxValuesFromKnownBits - Given a signed integer type and a +// set of known zero and one bits, compute the maximum and minimum values that +// could have the specified known zero and known one bits, returning them in +// min/max. +static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, + const APInt& KnownOne, + APInt& Min, APInt& Max) { + assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && + KnownZero.getBitWidth() == Min.getBitWidth() && + KnownZero.getBitWidth() == Max.getBitWidth() && + "KnownZero, KnownOne and Min, Max must have equal bitwidth."); + APInt UnknownBits = ~(KnownZero|KnownOne); + + // The minimum value is when all unknown bits are zeros, EXCEPT for the sign + // bit if it is unknown. + Min = KnownOne; + Max = KnownOne|UnknownBits; + + if (UnknownBits.isNegative()) { // Sign bit is unknown + Min.set(Min.getBitWidth()-1); + Max.clear(Max.getBitWidth()-1); + } +} + +// ComputeUnsignedMinMaxValuesFromKnownBits - Given an unsigned integer type and +// a set of known zero and one bits, compute the maximum and minimum values that +// could have the specified known zero and known one bits, returning them in +// min/max. +static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, + const APInt &KnownOne, + APInt &Min, APInt &Max) { + assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && + KnownZero.getBitWidth() == Min.getBitWidth() && + KnownZero.getBitWidth() == Max.getBitWidth() && + "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); + APInt UnknownBits = ~(KnownZero|KnownOne); + + // The minimum value is when the unknown bits are all zeros. + Min = KnownOne; + // The maximum value is when the unknown bits are all ones. + Max = KnownOne|UnknownBits; +} + +/// SimplifyDemandedInstructionBits - Inst is an integer instruction that +/// SimplifyDemandedBits knows about. See if the instruction has any +/// properties that allow us to simplify its operands. +bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { + unsigned BitWidth = cast<IntegerType>(Inst.getType())->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); + + Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, + KnownZero, KnownOne, 0); + if (V == 0) return false; + if (V == &Inst) return true; + ReplaceInstUsesWith(Inst, V); + return true; +} + +/// SimplifyDemandedBits - This form of SimplifyDemandedBits simplifies the +/// specified instruction operand if possible, updating it in place. It returns +/// true if it made any change and false otherwise. +bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, + APInt &KnownZero, APInt &KnownOne, + unsigned Depth) { + Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, + KnownZero, KnownOne, Depth); + if (NewVal == 0) return false; + U.set(NewVal); + return true; +} + + +/// SimplifyDemandedUseBits - This function attempts to replace V with a simpler +/// value based on the demanded bits. When this function is called, it is known +/// that only the bits set in DemandedMask of the result of V are ever used +/// downstream. Consequently, depending on the mask and V, it may be possible +/// to replace V with a constant or one of its operands. In such cases, this +/// function does the replacement and returns true. 
+/// In all other cases, it returns null after analyzing the expression,
+/// setting in KnownOne all the bits that are known to be one and in KnownZero
+/// all the bits that are known to be zero in the expression. These are
+/// provided to potentially allow the caller (which might recursively be
+/// SimplifyDemandedBits itself) to simplify the expression. KnownOne and
+/// KnownZero always follow the invariant that KnownOne & KnownZero == 0. That
+/// is, a bit can't be both 1 and 0. Note that the bits in KnownOne and
+/// KnownZero may only be accurate for those bits set in DemandedMask. Note
+/// also that the bitwidth of V, DemandedMask, KnownZero and KnownOne must all
+/// be the same.
+///
+/// This returns null if it did not change anything and it permits no
+/// simplification. This returns V itself if it did some simplification of V's
+/// operands based on the information about what bits are demanded. This returns
+/// some other non-null value if it found out that V is equal to another value
+/// in the context where the specified bits are demanded, but not for all users.
+Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
+                                             APInt &KnownZero, APInt &KnownOne,
+                                             unsigned Depth) {
+  assert(V != 0 && "Null pointer of Value???");
+  assert(Depth <= 6 && "Limit Search Depth");
+  uint32_t BitWidth = DemandedMask.getBitWidth();
+  const Type *VTy = V->getType();
+  assert((TD || !isa<PointerType>(VTy)) &&
+         "SimplifyDemandedBits needs to know bit widths!");
+  assert((!TD || TD->getTypeSizeInBits(VTy) == BitWidth) &&
+         (!isa<IntegerType>(VTy) ||
+          VTy->getPrimitiveSizeInBits() == BitWidth) &&
+         KnownZero.getBitWidth() == BitWidth &&
+         KnownOne.getBitWidth() == BitWidth &&
+         "Value *V, DemandedMask, KnownZero and KnownOne \
+          must have same BitWidth");
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
+    // We know all of the bits for a constant!
+    KnownOne = CI->getValue() & DemandedMask;
+    KnownZero = ~KnownOne & DemandedMask;
+    return 0;
+  }
+  if (isa<ConstantPointerNull>(V)) {
+    // We know all of the bits for a constant!
+    KnownOne.clear();
+    KnownZero = DemandedMask;
+    return 0;
+  }
+
+  KnownZero.clear();
+  KnownOne.clear();
+  if (DemandedMask == 0) {   // Not demanding any bits from V.
+    if (isa<UndefValue>(V))
+      return 0;
+    return UndefValue::get(VTy);
+  }
+
+  if (Depth == 6)        // Limit search depth.
+    return 0;
+
+  APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
+  APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne;
+
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) {
+    ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth);
+    return 0;        // Only analyze instructions.
+  }
+
+  // If there are multiple uses of this value and we aren't at the root, then
+  // we can't do any simplifications of the operands, because DemandedMask
+  // only reflects the bits demanded by *one* of the users.
+  if (Depth != 0 && !I->hasOneUse()) {
+    // Despite the fact that we can't simplify this instruction in every
+    // user's context, we can at least compute the knownzero/knownone bits,
+    // and we can do simplifications that apply to *just* the one user if we
+    // know that this instruction has a simpler value in that context.
+    if (I->getOpcode() == Instruction::And) {
+      // If either the LHS or the RHS are Zero, the result is zero.
+ ComputeMaskedBits(I->getOperand(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1); + ComputeMaskedBits(I->getOperand(0), DemandedMask & ~RHSKnownZero, + LHSKnownZero, LHSKnownOne, Depth+1); + + // If all of the demanded bits are known 1 on one side, return the other. + // These bits cannot contribute to the result of the 'and' in this + // context. + if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == + (DemandedMask & ~LHSKnownZero)) + return I->getOperand(0); + if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == + (DemandedMask & ~RHSKnownZero)) + return I->getOperand(1); + + // If all of the demanded bits in the inputs are known zeros, return zero. + if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) + return Constant::getNullValue(VTy); + + } else if (I->getOpcode() == Instruction::Or) { + // We can simplify (X|Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + + // If either the LHS or the RHS are One, the result is One. + ComputeMaskedBits(I->getOperand(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1); + ComputeMaskedBits(I->getOperand(0), DemandedMask & ~RHSKnownOne, + LHSKnownZero, LHSKnownOne, Depth+1); + + // If all of the demanded bits are known zero on one side, return the + // other. These bits cannot contribute to the result of the 'or' in this + // context. + if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == + (DemandedMask & ~LHSKnownOne)) + return I->getOperand(0); + if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == + (DemandedMask & ~RHSKnownOne)) + return I->getOperand(1); + + // If all of the potentially set bits on one side are known to be set on + // the other side, just use the 'other' side. + if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == + (DemandedMask & (~RHSKnownZero))) + return I->getOperand(0); + if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == + (DemandedMask & (~LHSKnownZero))) + return I->getOperand(1); + } + + // Compute the KnownZero/KnownOne bits to simplify things downstream. + ComputeMaskedBits(I, DemandedMask, KnownZero, KnownOne, Depth); + return 0; + } + + // If this is the root being simplified, allow it to have multiple uses, + // just set the DemandedMask to all bits so that we can try to simplify the + // operands. This allows visitTruncInst (for example) to simplify the + // operand of a trunc without duplicating all the logic below. + if (Depth == 0 && !V->hasOneUse()) + DemandedMask = APInt::getAllOnesValue(BitWidth); + + switch (I->getOpcode()) { + default: + ComputeMaskedBits(I, DemandedMask, RHSKnownZero, RHSKnownOne, Depth); + break; + case Instruction::And: + // If either the LHS or the RHS are Zero, the result is zero. + if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1) || + SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero, + LHSKnownZero, LHSKnownOne, Depth+1)) + return I; + assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); + assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + + // If all of the demanded bits are known 1 on one side, return the other. + // These bits cannot contribute to the result of the 'and'. + if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == + (DemandedMask & ~LHSKnownZero)) + return I->getOperand(0); + if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == + (DemandedMask & ~RHSKnownZero)) + return I->getOperand(1); + + // If all of the demanded bits in the inputs are known zeros, return zero. 
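// Illustrative aside (editor's standalone sketch, not part of the imported
// source): the 'and' rules above in concrete form. If every demanded bit is
// known to be one on the RHS, the 'and' is transparent for those bits, so the
// LHS can be used directly:
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Demanded = 0x0F; // caller only reads the low nibble
  const uint32_t RHS = 0xFF;      // known-one bits cover all demanded bits
  for (uint32_t x = 0; x < 4096; ++x)
    assert(((x & RHS) & Demanded) == (x & Demanded)); // 'and' is a no-op here
  return 0;
}
// End of aside.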
+ if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) + return Constant::getNullValue(VTy); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero)) + return I; + + // Output known-1 bits are only known if set in both the LHS & RHS. + RHSKnownOne &= LHSKnownOne; + // Output known-0 are known to be clear if zero in either the LHS | RHS. + RHSKnownZero |= LHSKnownZero; + break; + case Instruction::Or: + // If either the LHS or the RHS are One, the result is One. + if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1) || + SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne, + LHSKnownZero, LHSKnownOne, Depth+1)) + return I; + assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); + assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'or'. + if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == + (DemandedMask & ~LHSKnownOne)) + return I->getOperand(0); + if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == + (DemandedMask & ~RHSKnownOne)) + return I->getOperand(1); + + // If all of the potentially set bits on one side are known to be set on + // the other side, just use the 'other' side. + if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == + (DemandedMask & (~RHSKnownZero))) + return I->getOperand(0); + if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == + (DemandedMask & (~LHSKnownZero))) + return I->getOperand(1); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant(I, 1, DemandedMask)) + return I; + + // Output known-0 bits are only known if clear in both the LHS & RHS. + RHSKnownZero &= LHSKnownZero; + // Output known-1 are known to be set if set in either the LHS | RHS. + RHSKnownOne |= LHSKnownOne; + break; + case Instruction::Xor: { + if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, + RHSKnownZero, RHSKnownOne, Depth+1) || + SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, + LHSKnownZero, LHSKnownOne, Depth+1)) + return I; + assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); + assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'xor'. + if ((DemandedMask & RHSKnownZero) == DemandedMask) + return I->getOperand(0); + if ((DemandedMask & LHSKnownZero) == DemandedMask) + return I->getOperand(1); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt KnownZeroOut = (RHSKnownZero & LHSKnownZero) | + (RHSKnownOne & LHSKnownOne); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + APInt KnownOneOut = (RHSKnownZero & LHSKnownOne) | + (RHSKnownOne & LHSKnownZero); + + // If all of the demanded bits are known to be zero on one side or the + // other, turn this into an *inclusive* or. + // e.g. 
+    //      (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
+    if ((DemandedMask & ~RHSKnownZero & ~LHSKnownZero) == 0) {
+      Instruction *Or =
+        BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
+                                 I->getName());
+      return InsertNewInstBefore(Or, *I);
+    }
+
+    // If all of the demanded bits on one side are known, and all of the set
+    // bits on that side are also known to be set on the other side, turn this
+    // into an AND, as we know the bits will be cleared.
+    //    e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
+    if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) {
+      // all known
+      if ((RHSKnownOne & LHSKnownOne) == RHSKnownOne) {
+        Constant *AndC = ConstantInt::get(~RHSKnownOne & DemandedMask);
+        Instruction *And =
+          BinaryOperator::CreateAnd(I->getOperand(0), AndC, "tmp");
+        return InsertNewInstBefore(And, *I);
+      }
+    }
+
+    // If the RHS is a constant, see if we can simplify it.
+    // FIXME: for XOR, we prefer to force bits to 1 if they will make a -1.
+    if (ShrinkDemandedConstant(I, 1, DemandedMask))
+      return I;
+
+    RHSKnownZero = KnownZeroOut;
+    RHSKnownOne  = KnownOneOut;
+    break;
+  }
+  case Instruction::Select:
+    if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask,
+                             RHSKnownZero, RHSKnownOne, Depth+1) ||
+        SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
+                             LHSKnownZero, LHSKnownOne, Depth+1))
+      return I;
+    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
+    assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
+
+    // If the operands are constants, see if we can simplify them.
+    if (ShrinkDemandedConstant(I, 1, DemandedMask) ||
+        ShrinkDemandedConstant(I, 2, DemandedMask))
+      return I;
+
+    // Only known if known in both the LHS and RHS.
+    RHSKnownOne &= LHSKnownOne;
+    RHSKnownZero &= LHSKnownZero;
+    break;
+  case Instruction::Trunc: {
+    unsigned truncBf = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
+    DemandedMask.zext(truncBf);
+    RHSKnownZero.zext(truncBf);
+    RHSKnownOne.zext(truncBf);
+    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
+                             RHSKnownZero, RHSKnownOne, Depth+1))
+      return I;
+    DemandedMask.trunc(BitWidth);
+    RHSKnownZero.trunc(BitWidth);
+    RHSKnownOne.trunc(BitWidth);
+    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
+    break;
+  }
+  case Instruction::BitCast:
+    if (!I->getOperand(0)->getType()->isInteger())
+      return 0;  // vector->int or fp->int?
+    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
+                             RHSKnownZero, RHSKnownOne, Depth+1))
+      return I;
+    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
+    break;
+  case Instruction::ZExt: {
+    // Compute the bits in the result that are not present in the input.
+    unsigned SrcBitWidth = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
+
+    DemandedMask.trunc(SrcBitWidth);
+    RHSKnownZero.trunc(SrcBitWidth);
+    RHSKnownOne.trunc(SrcBitWidth);
+    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
+                             RHSKnownZero, RHSKnownOne, Depth+1))
+      return I;
+    DemandedMask.zext(BitWidth);
+    RHSKnownZero.zext(BitWidth);
+    RHSKnownOne.zext(BitWidth);
+    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
+    // The top bits are known to be zero.
+    RHSKnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth);
+    break;
+  }
+  case Instruction::SExt: {
+    // Compute the bits in the result that are not present in the input.
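// Illustrative aside (editor's standalone sketch, not part of the imported
// source): why a sign extension whose extended bits are never demanded can be
// turned into a zero extension. The two extensions differ only in the bits
// above the source width; the low source bits always agree:
#include <cassert>
#include <cstdint>

int main() {
  for (int v = -128; v < 128; ++v) {
    uint16_t SExt = (uint16_t)(int16_t)(int8_t)v; // sign-extend i8 -> i16
    uint16_t ZExt = (uint16_t)(uint8_t)v;         // zero-extend i8 -> i16
    assert((SExt & 0x00FF) == (ZExt & 0x00FF));   // low byte always matches
  }
  return 0;
}
// End of aside.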
+ unsigned SrcBitWidth =I->getOperand(0)->getType()->getPrimitiveSizeInBits(); + + APInt InputDemandedBits = DemandedMask & + APInt::getLowBitsSet(BitWidth, SrcBitWidth); + + APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth)); + // If any of the sign extended bits are demanded, we know that the sign + // bit is demanded. + if ((NewBits & DemandedMask) != 0) + InputDemandedBits.set(SrcBitWidth-1); + + InputDemandedBits.trunc(SrcBitWidth); + RHSKnownZero.trunc(SrcBitWidth); + RHSKnownOne.trunc(SrcBitWidth); + if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits, + RHSKnownZero, RHSKnownOne, Depth+1)) + return I; + InputDemandedBits.zext(BitWidth); + RHSKnownZero.zext(BitWidth); + RHSKnownOne.zext(BitWidth); + assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); + + // If the sign bit of the input is known set or clear, then we know the + // top bits of the result. + + // If the input sign bit is known zero, or if the NewBits are not demanded + // convert this into a zero extension. + if (RHSKnownZero[SrcBitWidth-1] || (NewBits & ~DemandedMask) == NewBits) { + // Convert to ZExt cast + CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName()); + return InsertNewInstBefore(NewCast, *I); + } else if (RHSKnownOne[SrcBitWidth-1]) { // Input sign bit known set + RHSKnownOne |= NewBits; + } + break; + } + case Instruction::Add: { + // Figure out what the input bits are. If the top bits of the and result + // are not demanded, then the add doesn't demand them from its input + // either. + unsigned NLZ = DemandedMask.countLeadingZeros(); + + // If there is a constant on the RHS, there are a variety of xformations + // we can do. + if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) { + // If null, this should be simplified elsewhere. Some of the xforms here + // won't work if the RHS is zero. + if (RHS->isZero()) + break; + + // If the top bit of the output is demanded, demand everything from the + // input. Otherwise, we demand all the input bits except NLZ top bits. + APInt InDemandedBits(APInt::getLowBitsSet(BitWidth, BitWidth - NLZ)); + + // Find information about known zero/one bits in the input. + if (SimplifyDemandedBits(I->getOperandUse(0), InDemandedBits, + LHSKnownZero, LHSKnownOne, Depth+1)) + return I; + + // If the RHS of the add has bits set that can't affect the input, reduce + // the constant. + if (ShrinkDemandedConstant(I, 1, InDemandedBits)) + return I; + + // Avoid excess work. + if (LHSKnownZero == 0 && LHSKnownOne == 0) + break; + + // Turn it into OR if input bits are zero. + if ((LHSKnownZero & RHS->getValue()) == RHS->getValue()) { + Instruction *Or = + BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1), + I->getName()); + return InsertNewInstBefore(Or, *I); + } + + // We can say something about the output known-zero and known-one bits, + // depending on potential carries from the input constant and the + // unknowns. For example if the LHS is known to have at most the 0x0F0F0 + // bits set and the RHS constant is 0x01001, then we know we have a known + // one mask of 0x00001 and a known zero mask of 0xE0F0E. + + // To compute this, we first compute the potential carry bits. These are + // the bits which may be modified. I'm not aware of a better way to do + // this scan. + const APInt &RHSVal = RHS->getValue(); + APInt CarryBits((~LHSKnownZero + RHSVal) ^ (~LHSKnownZero ^ RHSVal)); + + // Now that we know which bits have carries, compute the known-1/0 sets. 
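// Illustrative aside (editor's standalone sketch, not part of the imported
// source): a check of the carry-bit formula and the 20-bit worked example in
// the comment above (LHS has at most the 0x0F0F0 bits set, RHS constant is
// 0x01001), emulating a 20-bit integer with masking:
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Width20 = 0xFFFFF;            // emulate a 20-bit APInt
  uint32_t LHSKnownZero = ~0x0F0F0u & Width20; // bits known zero in the LHS
  uint32_t LHSKnownOne  = 0;                   // no bits known one in the LHS
  uint32_t RHSVal       = 0x01001;             // the add's constant operand

  // Bits that may be modified by a carry rippling through the add.
  uint32_t MaxLHS = ~LHSKnownZero & Width20;
  uint32_t CarryBits = ((MaxLHS + RHSVal) ^ (MaxLHS ^ RHSVal)) & Width20;

  uint32_t KnownOne  = ((LHSKnownZero & RHSVal) | (LHSKnownOne & ~RHSVal))
                       & ~CarryBits & Width20;
  uint32_t KnownZero = LHSKnownZero & ~RHSVal & ~CarryBits & Width20;

  assert(KnownOne  == 0x00001);  // matches the comment's known-one mask
  assert(KnownZero == 0xE0F0E);  // matches the comment's known-zero mask
  return 0;
}
// End of aside.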
+ + // Bits are known one if they are known zero in one operand and one in the + // other, and there is no input carry. + RHSKnownOne = ((LHSKnownZero & RHSVal) | + (LHSKnownOne & ~RHSVal)) & ~CarryBits; + + // Bits are known zero if they are known zero in both operands and there + // is no input carry. + RHSKnownZero = LHSKnownZero & ~RHSVal & ~CarryBits; + } else { + // If the high-bits of this ADD are not demanded, then it does not demand + // the high bits of its LHS or RHS. + if (DemandedMask[BitWidth-1] == 0) { + // Right fill the mask of bits for this ADD to demand the most + // significant bit and all those below it. + APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); + if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1) || + SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1)) + return I; + } + } + break; + } + case Instruction::Sub: + // If the high-bits of this SUB are not demanded, then it does not demand + // the high bits of its LHS or RHS. + if (DemandedMask[BitWidth-1] == 0) { + // Right fill the mask of bits for this SUB to demand the most + // significant bit and all those below it. + uint32_t NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); + if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1) || + SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps, + LHSKnownZero, LHSKnownOne, Depth+1)) + return I; + } + // Otherwise just hand the sub off to ComputeMaskedBits to fill in + // the known zeros and ones. + ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth); + break; + case Instruction::Shl: + if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); + if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, + RHSKnownZero, RHSKnownOne, Depth+1)) + return I; + assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); + RHSKnownZero <<= ShiftAmt; + RHSKnownOne <<= ShiftAmt; + // low bits known zero. + if (ShiftAmt) + RHSKnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + } + break; + case Instruction::LShr: + // For a logical shift right + if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + + // Unsigned shift right. + APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); + if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, + RHSKnownZero, RHSKnownOne, Depth+1)) + return I; + assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); + RHSKnownZero = APIntOps::lshr(RHSKnownZero, ShiftAmt); + RHSKnownOne = APIntOps::lshr(RHSKnownOne, ShiftAmt); + if (ShiftAmt) { + // Compute the new bits that are at the top now. + APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); + RHSKnownZero |= HighBits; // high bits known zero. + } + } + break; + case Instruction::AShr: + // If this is an arithmetic shift right and only the low-bit is set, we can + // always convert this into a logical shr, even if the shift amount is + // variable. The low bit of the shift cannot be an input sign bit unless + // the shift amount is >= the size of the datatype, which is undefined. + if (DemandedMask == 1) { + // Perform the logical shift right. 
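// Illustrative aside (editor's standalone sketch, not part of the imported
// source): when only bit 0 of the result is demanded, arithmetic and logical
// right shifts agree, since they differ only in the bits filled in from the
// top. This sketch assumes the usual arithmetic behaviour of signed >> on the
// host:
#include <cassert>
#include <cstdint>

int main() {
  for (int n = 0; n < 8; ++n)
    for (int v = 0; v < 256; ++v) {
      uint8_t AShr = (uint8_t)((int8_t)(uint8_t)v >> n); // sign-filling shift
      uint8_t LShr = (uint8_t)((uint8_t)v >> n);         // zero-filling shift
      assert((AShr & 1) == (LShr & 1)); // the demanded low bit matches
    }
  return 0;
}
// End of aside.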
+      Instruction *NewVal = BinaryOperator::CreateLShr(
+                        I->getOperand(0), I->getOperand(1), I->getName());
+      return InsertNewInstBefore(NewVal, *I);
+    }
+
+    // If the sign bit is the only bit demanded by this ashr, then there is no
+    // need to do it, the shift doesn't change the high bit.
+    if (DemandedMask.isSignBit())
+      return I->getOperand(0);
+
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth);
+
+      // Signed shift right.
+      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
+      // If any of the "high bits" are demanded, we should set the sign bit as
+      // demanded.
+      if (DemandedMask.countLeadingZeros() <= ShiftAmt)
+        DemandedMaskIn.set(BitWidth-1);
+      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
+                               RHSKnownZero, RHSKnownOne, Depth+1))
+        return I;
+      assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
+      // Compute the new bits that are at the top now.
+      APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
+      RHSKnownZero = APIntOps::lshr(RHSKnownZero, ShiftAmt);
+      RHSKnownOne  = APIntOps::lshr(RHSKnownOne, ShiftAmt);
+
+      // Handle the sign bits.
+      APInt SignBit(APInt::getSignBit(BitWidth));
+      // Adjust to where it is now in the mask.
+      SignBit = APIntOps::lshr(SignBit, ShiftAmt);
+
+      // If the input sign bit is known to be zero, or if none of the top bits
+      // are demanded, turn this into an unsigned shift right.
+      if (BitWidth <= ShiftAmt || RHSKnownZero[BitWidth-ShiftAmt-1] ||
+          (HighBits & ~DemandedMask) == HighBits) {
+        // Perform the logical shift right.
+        Instruction *NewVal = BinaryOperator::CreateLShr(
+                          I->getOperand(0), SA, I->getName());
+        return InsertNewInstBefore(NewVal, *I);
+      } else if ((RHSKnownOne & SignBit) != 0) { // New bits are known one.
+        RHSKnownOne |= HighBits;
+      }
+    }
+    break;
+  case Instruction::SRem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue().abs();
+      if (RA.isPowerOf2()) {
+        // Only bits strictly below the power of two are unchanged by srem.
+        if (DemandedMask.ult(RA))    // srem won't affect demanded bits
+          return I->getOperand(0);
+
+        APInt LowBits = RA - 1;
+        APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
+        if (SimplifyDemandedBits(I->getOperandUse(0), Mask2,
+                                 LHSKnownZero, LHSKnownOne, Depth+1))
+          return I;
+
+        if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits))
+          LHSKnownZero |= ~LowBits;
+
+        KnownZero |= LHSKnownZero & DemandedMask;
+
+        assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
+      }
+    }
+    break;
+  case Instruction::URem: {
+    APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);
+    APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+    if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes,
+                             KnownZero2, KnownOne2, Depth+1) ||
+        SimplifyDemandedBits(I->getOperandUse(1), AllOnes,
+                             KnownZero2, KnownOne2, Depth+1))
+      return I;
+
+    // KnownZero2 accumulated the known bits of both operands above.
+    unsigned Leaders = KnownZero2.countLeadingOnes();
+    KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
+    break;
+  }
+  case Instruction::Call:
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+      switch (II->getIntrinsicID()) {
+      default: break;
+      case Intrinsic::bswap: {
+        // If the only bits demanded come from one byte of the bswap result,
+        // just shift the input byte into position to eliminate the bswap.
+        unsigned NLZ = DemandedMask.countLeadingZeros();
+        unsigned NTZ = DemandedMask.countTrailingZeros();
+
+        // Round NTZ down to the next byte.
+        // If we have 11 trailing zeros, then we need all the bits down to
+        // bit 8. Likewise, round NLZ. If we have 14 leading zeros, round to 8.
+        NLZ &= ~7;
+        NTZ &= ~7;
+        // If we need exactly one byte, we can do this transformation.
+        if (BitWidth-NLZ-NTZ == 8) {
+          unsigned ResultBit = NTZ;
+          unsigned InputBit = BitWidth-NTZ-8;
+
+          // Replace this with either a left or right shift to get the byte into
+          // the right place.
+          Instruction *NewVal;
+          if (InputBit > ResultBit)
+            NewVal = BinaryOperator::CreateLShr(I->getOperand(1),
+                    ConstantInt::get(I->getType(), InputBit-ResultBit));
+          else
+            NewVal = BinaryOperator::CreateShl(I->getOperand(1),
+                    ConstantInt::get(I->getType(), ResultBit-InputBit));
+          NewVal->takeName(I);
+          return InsertNewInstBefore(NewVal, *I);
+        }
+
+        // TODO: Could compute known zero/one bits based on the input.
+        break;
+      }
+      }
+    }
+    ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth);
+    break;
+  }
+
+  // If the client is only demanding bits that we know, return the known
+  // constant.
+  if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) {
+    Constant *C = ConstantInt::get(RHSKnownOne);
+    if (isa<PointerType>(V->getType()))
+      C = ConstantExpr::getIntToPtr(C, V->getType());
+    return C;
+  }
+  return 0;
+}
+
+
+/// SimplifyDemandedVectorElts - The specified value produces a vector with
+/// any number of elements. DemandedElts contains the set of elements that are
+/// actually used by the caller. This method analyzes which elements of the
+/// operand are undef and returns that information in UndefElts.
+///
+/// If the information about the demanded elements can be used to simplify the
+/// operation, the operation is simplified and the resulting value is
+/// returned. This returns null if no change was made.
+Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
+                                                APInt& UndefElts,
+                                                unsigned Depth) {
+  unsigned VWidth = cast<VectorType>(V->getType())->getNumElements();
+  APInt EltMask(APInt::getAllOnesValue(VWidth));
+  assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
+
+  if (isa<UndefValue>(V)) {
+    // If the entire vector is undefined, just return this info.
+    UndefElts = EltMask;
+    return 0;
+  } else if (DemandedElts == 0) { // If nothing is demanded, provide undef.
+    UndefElts = EltMask;
+    return UndefValue::get(V->getType());
+  }
+
+  UndefElts = 0;
+  if (ConstantVector *CP = dyn_cast<ConstantVector>(V)) {
+    const Type *EltTy = cast<VectorType>(V->getType())->getElementType();
+    Constant *Undef = UndefValue::get(EltTy);
+
+    std::vector<Constant*> Elts;
+    for (unsigned i = 0; i != VWidth; ++i)
+      if (!DemandedElts[i]) {   // If not demanded, set to undef.
+        Elts.push_back(Undef);
+        UndefElts.set(i);
+      } else if (isa<UndefValue>(CP->getOperand(i))) {   // Already undef.
+        Elts.push_back(Undef);
+        UndefElts.set(i);
+      } else {                               // Otherwise, defined.
+        Elts.push_back(CP->getOperand(i));
+      }
+
+    // If we changed the constant, return it.
+    Constant *NewCP = ConstantVector::get(Elts);
+    return NewCP != CP ? NewCP : 0;
+  } else if (isa<ConstantAggregateZero>(V)) {
+    // Simplify the CAZ to a ConstantVector where the non-demanded elements are
+    // set to undef.
+
+    // Check if this is identity. If so, return 0 since we are not simplifying
+    // anything.
+    if (DemandedElts == ((1ULL << VWidth) - 1))
+      return 0;
+
+    const Type *EltTy = cast<VectorType>(V->getType())->getElementType();
+    Constant *Zero = Constant::getNullValue(EltTy);
+    Constant *Undef = UndefValue::get(EltTy);
+    std::vector<Constant*> Elts;
+    for (unsigned i = 0; i != VWidth; ++i) {
+      Constant *Elt = DemandedElts[i] ? Zero : Undef;
+      Elts.push_back(Elt);
+    }
+    UndefElts = DemandedElts ^ EltMask;
+    return ConstantVector::get(Elts);
+  }
+
+  // Limit search depth.
+  if (Depth == 10)
+    return 0;
+
+  // If multiple users are using the root value, proceed with
+  // simplification conservatively assuming that all elements
+  // are needed.
+  if (!V->hasOneUse()) {
+    // Quit if we find multiple users of a non-root value though.
+    // They'll be handled when it's their turn to be visited by
+    // the main instcombine process.
+    if (Depth != 0)
+      // TODO: Just compute the UndefElts information recursively.
+      return 0;
+
+    // Conservatively assume that all elements are needed.
+    DemandedElts = EltMask;
+  }
+
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return 0;        // Only analyze instructions.
+
+  bool MadeChange = false;
+  APInt UndefElts2(VWidth, 0);
+  Value *TmpV;
+  switch (I->getOpcode()) {
+  default: break;
+
+  case Instruction::InsertElement: {
+    // If this is a variable index, we don't know which element it overwrites.
+    // Demand exactly the same input as we produce.
+    ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2));
+    if (Idx == 0) {
+      // Note that we can't propagate undef elt info, because we don't know
+      // which elt is getting updated.
+      TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts,
+                                        UndefElts2, Depth+1);
+      if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
+      break;
+    }
+
+    // If this is inserting an element that isn't demanded, remove this
+    // insertelement.
+    unsigned IdxNo = Idx->getZExtValue();
+    if (IdxNo >= VWidth || !DemandedElts[IdxNo])
+      return AddSoonDeadInstToWorklist(*I, 0);
+
+    // Otherwise, the element inserted overwrites whatever was there, so the
+    // input demanded set is simpler than the output set.
+    APInt DemandedElts2 = DemandedElts;
+    DemandedElts2.clear(IdxNo);
+    TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts2,
+                                      UndefElts, Depth+1);
+    if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
+
+    // The inserted element is defined.
+ UndefElts.clear(IdxNo); + break; + } + case Instruction::ShuffleVector: { + ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); + uint64_t LHSVWidth = + cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements(); + APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0); + for (unsigned i = 0; i < VWidth; i++) { + if (DemandedElts[i]) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal != -1u) { + assert(MaskVal < LHSVWidth * 2 && + "shufflevector mask index out of range!"); + if (MaskVal < LHSVWidth) + LeftDemanded.set(MaskVal); + else + RightDemanded.set(MaskVal - LHSVWidth); + } + } + } + + APInt UndefElts4(LHSVWidth, 0); + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), LeftDemanded, + UndefElts4, Depth+1); + if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + + APInt UndefElts3(LHSVWidth, 0); + TmpV = SimplifyDemandedVectorElts(I->getOperand(1), RightDemanded, + UndefElts3, Depth+1); + if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } + + bool NewUndefElts = false; + for (unsigned i = 0; i < VWidth; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal == -1u) { + UndefElts.set(i); + } else if (MaskVal < LHSVWidth) { + if (UndefElts4[MaskVal]) { + NewUndefElts = true; + UndefElts.set(i); + } + } else { + if (UndefElts3[MaskVal - LHSVWidth]) { + NewUndefElts = true; + UndefElts.set(i); + } + } + } + + if (NewUndefElts) { + // Add additional discovered undefs. + std::vector<Constant*> Elts; + for (unsigned i = 0; i < VWidth; ++i) { + if (UndefElts[i]) + Elts.push_back(UndefValue::get(Type::Int32Ty)); + else + Elts.push_back(ConstantInt::get(Type::Int32Ty, + Shuffle->getMaskValue(i))); + } + I->setOperand(2, ConstantVector::get(Elts)); + MadeChange = true; + } + break; + } + case Instruction::BitCast: { + // Vector->vector casts only. + const VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType()); + if (!VTy) break; + unsigned InVWidth = VTy->getNumElements(); + APInt InputDemandedElts(InVWidth, 0); + unsigned Ratio; + + if (VWidth == InVWidth) { + // If we are converting from <4 x i32> -> <4 x f32>, we demand the same + // elements as are demanded of us. + Ratio = 1; + InputDemandedElts = DemandedElts; + } else if (VWidth > InVWidth) { + // Untested so far. + break; + + // If there are more elements in the result than there are in the source, + // then an input element is live if any of the corresponding output + // elements are live. + Ratio = VWidth/InVWidth; + for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) { + if (DemandedElts[OutIdx]) + InputDemandedElts.set(OutIdx/Ratio); + } + } else { + // Untested so far. + break; + + // If there are more elements in the source than there are in the result, + // then an input element is live if the corresponding output element is + // live. + Ratio = InVWidth/VWidth; + for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx) + if (DemandedElts[InIdx/Ratio]) + InputDemandedElts.set(InIdx); + } + + // div/rem demand all inputs, because they don't want divide by zero. + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), InputDemandedElts, + UndefElts2, Depth+1); + if (TmpV) { + I->setOperand(0, TmpV); + MadeChange = true; + } + + UndefElts = UndefElts2; + if (VWidth > InVWidth) { + assert(0 && "Unimp"); + // If there are more elements in the result than there are in the source, + // then an output element is undef if the corresponding input element is + // undef. 
+ for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) + if (UndefElts2[OutIdx/Ratio]) + UndefElts.set(OutIdx); + } else if (VWidth < InVWidth) { + assert(0 && "Unimp"); + // If there are more elements in the source than there are in the result, + // then a result element is undef if all of the corresponding input + // elements are undef. + UndefElts = ~0ULL >> (64-VWidth); // Start out all undef. + for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx) + if (!UndefElts2[InIdx]) // Not undef? + UndefElts.clear(InIdx/Ratio); // Clear undef bit. + } + break; + } + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + // div/rem demand all inputs, because they don't want divide by zero. + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, + UndefElts, Depth+1); + if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(I->getOperand(1), DemandedElts, + UndefElts2, Depth+1); + if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } + + // Output elements are undefined if both are undefined. Consider things + // like undef&0. The result is known zero, not undef. + UndefElts &= UndefElts2; + break; + + case Instruction::Call: { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); + if (!II) break; + switch (II->getIntrinsicID()) { + default: break; + + // Binary vector operations that work column-wise. A dest element is a + // function of the corresponding input elements from the two inputs. + case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_min_ss: + case Intrinsic::x86_sse_max_ss: + case Intrinsic::x86_sse2_sub_sd: + case Intrinsic::x86_sse2_mul_sd: + case Intrinsic::x86_sse2_min_sd: + case Intrinsic::x86_sse2_max_sd: + TmpV = SimplifyDemandedVectorElts(II->getOperand(1), DemandedElts, + UndefElts, Depth+1); + if (TmpV) { II->setOperand(1, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(II->getOperand(2), DemandedElts, + UndefElts2, Depth+1); + if (TmpV) { II->setOperand(2, TmpV); MadeChange = true; } + + // If only the low elt is demanded and this is a scalarizable intrinsic, + // scalarize it now. + if (DemandedElts == 1) { + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse2_sub_sd: + case Intrinsic::x86_sse2_mul_sd: + // TODO: Lower MIN/MAX/ABS/etc + Value *LHS = II->getOperand(1); + Value *RHS = II->getOperand(2); + // Extract the element as scalars. + LHS = InsertNewInstBefore(new ExtractElementInst(LHS, 0U,"tmp"), *II); + RHS = InsertNewInstBefore(new ExtractElementInst(RHS, 0U,"tmp"), *II); + + switch (II->getIntrinsicID()) { + default: assert(0 && "Case stmts out of sync!"); + case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse2_sub_sd: + TmpV = InsertNewInstBefore(BinaryOperator::CreateSub(LHS, RHS, + II->getName()), *II); + break; + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse2_mul_sd: + TmpV = InsertNewInstBefore(BinaryOperator::CreateMul(LHS, RHS, + II->getName()), *II); + break; + } + + Instruction *New = + InsertElementInst::Create(UndefValue::get(II->getType()), TmpV, 0U, + II->getName()); + InsertNewInstBefore(New, *II); + AddSoonDeadInstToWorklist(*II, 0); + return New; + } + } + + // Output elements are undefined if both are undefined. Consider things + // like undef&0. The result is known zero, not undef. 
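// Illustrative aside (editor's standalone sketch, not part of the imported
// source): why undef lanes can be intersected rather than unioned here.
// Whatever concrete value an undef lane takes, AND with a zero lane still
// yields zero, so such a result lane is defined; only lanes undef in *both*
// inputs must stay undef:
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t undef = 0; undef < 256; ++undef) // any value the lane may take
    assert((undef & 0u) == 0u);                  // undef & 0 is always 0
  return 0;
}
// End of aside.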
+ UndefElts &= UndefElts2; + break; + } + break; + } + } + return MadeChange ? I : 0; +} + + +/// AssociativeOpt - Perform an optimization on an associative operator. This +/// function is designed to check a chain of associative operators for a +/// potential to apply a certain optimization. Since the optimization may be +/// applicable if the expression was reassociated, this checks the chain, then +/// reassociates the expression as necessary to expose the optimization +/// opportunity. This makes use of a special Functor, which must define +/// 'shouldApply' and 'apply' methods. +/// +template<typename Functor> +static Instruction *AssociativeOpt(BinaryOperator &Root, const Functor &F) { + unsigned Opcode = Root.getOpcode(); + Value *LHS = Root.getOperand(0); + + // Quick check, see if the immediate LHS matches... + if (F.shouldApply(LHS)) + return F.apply(Root); + + // Otherwise, if the LHS is not of the same opcode as the root, return. + Instruction *LHSI = dyn_cast<Instruction>(LHS); + while (LHSI && LHSI->getOpcode() == Opcode && LHSI->hasOneUse()) { + // Should we apply this transform to the RHS? + bool ShouldApply = F.shouldApply(LHSI->getOperand(1)); + + // If not to the RHS, check to see if we should apply to the LHS... + if (!ShouldApply && F.shouldApply(LHSI->getOperand(0))) { + cast<BinaryOperator>(LHSI)->swapOperands(); // Make the LHS the RHS + ShouldApply = true; + } + + // If the functor wants to apply the optimization to the RHS of LHSI, + // reassociate the expression from ((? op A) op B) to (? op (A op B)) + if (ShouldApply) { + // Now all of the instructions are in the current basic block, go ahead + // and perform the reassociation. + Instruction *TmpLHSI = cast<Instruction>(Root.getOperand(0)); + + // First move the selected RHS to the LHS of the root... + Root.setOperand(0, LHSI->getOperand(1)); + + // Make what used to be the LHS of the root be the user of the root... + Value *ExtraOperand = TmpLHSI->getOperand(1); + if (&Root == TmpLHSI) { + Root.replaceAllUsesWith(Constant::getNullValue(TmpLHSI->getType())); + return 0; + } + Root.replaceAllUsesWith(TmpLHSI); // Users now use TmpLHSI + TmpLHSI->setOperand(1, &Root); // TmpLHSI now uses the root + BasicBlock::iterator ARI = &Root; ++ARI; + TmpLHSI->moveBefore(ARI); // Move TmpLHSI to after Root + ARI = Root; + + // Now propagate the ExtraOperand down the chain of instructions until we + // get to LHSI. + while (TmpLHSI != LHSI) { + Instruction *NextLHSI = cast<Instruction>(TmpLHSI->getOperand(0)); + // Move the instruction to immediately before the chain we are + // constructing to avoid breaking dominance properties. + NextLHSI->moveBefore(ARI); + ARI = NextLHSI; + + Value *NextOp = NextLHSI->getOperand(1); + NextLHSI->setOperand(1, ExtraOperand); + TmpLHSI = NextLHSI; + ExtraOperand = NextOp; + } + + // Now that the instructions are reassociated, have the functor perform + // the transformation... 
+ return F.apply(Root); + } + + LHSI = dyn_cast<Instruction>(LHSI->getOperand(0)); + } + return 0; +} + +namespace { + +// AddRHS - Implements: X + X --> X << 1 +struct AddRHS { + Value *RHS; + AddRHS(Value *rhs) : RHS(rhs) {} + bool shouldApply(Value *LHS) const { return LHS == RHS; } + Instruction *apply(BinaryOperator &Add) const { + return BinaryOperator::CreateShl(Add.getOperand(0), + ConstantInt::get(Add.getType(), 1)); + } +}; + +// AddMaskingAnd - Implements (A & C1)+(B & C2) --> (A & C1)|(B & C2) +// iff C1&C2 == 0 +struct AddMaskingAnd { + Constant *C2; + AddMaskingAnd(Constant *c) : C2(c) {} + bool shouldApply(Value *LHS) const { + ConstantInt *C1; + return match(LHS, m_And(m_Value(), m_ConstantInt(C1))) && + ConstantExpr::getAnd(C1, C2)->isNullValue(); + } + Instruction *apply(BinaryOperator &Add) const { + return BinaryOperator::CreateOr(Add.getOperand(0), Add.getOperand(1)); + } +}; + +} + +static Value *FoldOperationIntoSelectOperand(Instruction &I, Value *SO, + InstCombiner *IC) { + if (CastInst *CI = dyn_cast<CastInst>(&I)) { + return IC->InsertCastBefore(CI->getOpcode(), SO, I.getType(), I); + } + + // Figure out if the constant is the left or the right argument. + bool ConstIsRHS = isa<Constant>(I.getOperand(1)); + Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS)); + + if (Constant *SOC = dyn_cast<Constant>(SO)) { + if (ConstIsRHS) + return ConstantExpr::get(I.getOpcode(), SOC, ConstOperand); + return ConstantExpr::get(I.getOpcode(), ConstOperand, SOC); + } + + Value *Op0 = SO, *Op1 = ConstOperand; + if (!ConstIsRHS) + std::swap(Op0, Op1); + Instruction *New; + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(&I)) + New = BinaryOperator::Create(BO->getOpcode(), Op0, Op1,SO->getName()+".op"); + else if (CmpInst *CI = dyn_cast<CmpInst>(&I)) + New = CmpInst::Create(CI->getOpcode(), CI->getPredicate(), Op0, Op1, + SO->getName()+".cmp"); + else { + assert(0 && "Unknown binary instruction type!"); + abort(); + } + return IC->InsertNewInstBefore(New, I); +} + +// FoldOpIntoSelect - Given an instruction with a select as one operand and a +// constant as the other operand, try to fold the binary operator into the +// select arguments. This also works for Cast instructions, which obviously do +// not have a second operand. +static Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI, + InstCombiner *IC) { + // Don't modify shared select instructions + if (!SI->hasOneUse()) return 0; + Value *TV = SI->getOperand(1); + Value *FV = SI->getOperand(2); + + if (isa<Constant>(TV) || isa<Constant>(FV)) { + // Bool selects with constant operands can be folded to logical ops. + if (SI->getType() == Type::Int1Ty) return 0; + + Value *SelectTrueVal = FoldOperationIntoSelectOperand(Op, TV, IC); + Value *SelectFalseVal = FoldOperationIntoSelectOperand(Op, FV, IC); + + return SelectInst::Create(SI->getCondition(), SelectTrueVal, + SelectFalseVal); + } + return 0; +} + + +/// FoldOpIntoPhi - Given a binary operator or cast instruction which has a PHI +/// node as operand #0, see if we can fold the instruction into the PHI (which +/// is only possible if all operands to the PHI are constants). +Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { + PHINode *PN = cast<PHINode>(I.getOperand(0)); + unsigned NumPHIValues = PN->getNumIncomingValues(); + if (!PN->hasOneUse() || NumPHIValues == 0) return 0; + + // Check to see if all of the operands of the PHI are constants. If there is + // one non-constant value, remember the BB it is. 
+  // If there is more than one, or if *it* is a PHI, bail out.
+  BasicBlock *NonConstBB = 0;
+  for (unsigned i = 0; i != NumPHIValues; ++i)
+    if (!isa<Constant>(PN->getIncomingValue(i))) {
+      if (NonConstBB) return 0;  // More than one non-const value.
+      if (isa<PHINode>(PN->getIncomingValue(i))) return 0;  // Itself a phi.
+      NonConstBB = PN->getIncomingBlock(i);
+
+      // If the incoming non-constant value is in I's block, we have an infinite
+      // loop.
+      if (NonConstBB == I.getParent())
+        return 0;
+    }
+
+  // If there is exactly one non-constant value, we can insert a copy of the
+  // operation in that block. However, if this is a critical edge, we would be
+  // inserting the computation on some other paths (e.g. inside a loop). Only
+  // do this if the pred block is unconditionally branching into the phi block.
+  if (NonConstBB) {
+    BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator());
+    if (!BI || !BI->isUnconditional()) return 0;
+  }
+
+  // Okay, we can do the transformation: create the new PHI node.
+  PHINode *NewPN = PHINode::Create(I.getType(), "");
+  NewPN->reserveOperandSpace(PN->getNumOperands()/2);
+  InsertNewInstBefore(NewPN, *PN);
+  NewPN->takeName(PN);
+
+  // Next, add all of the operands to the PHI.
+  if (I.getNumOperands() == 2) {
+    Constant *C = cast<Constant>(I.getOperand(1));
+    for (unsigned i = 0; i != NumPHIValues; ++i) {
+      Value *InV = 0;
+      if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) {
+        if (CmpInst *CI = dyn_cast<CmpInst>(&I))
+          InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C);
+        else
+          InV = ConstantExpr::get(I.getOpcode(), InC, C);
+      } else {
+        assert(PN->getIncomingBlock(i) == NonConstBB);
+        if (BinaryOperator *BO = dyn_cast<BinaryOperator>(&I))
+          InV = BinaryOperator::Create(BO->getOpcode(),
+                                       PN->getIncomingValue(i), C, "phitmp",
+                                       NonConstBB->getTerminator());
+        else if (CmpInst *CI = dyn_cast<CmpInst>(&I))
+          InV = CmpInst::Create(CI->getOpcode(),
+                                CI->getPredicate(),
+                                PN->getIncomingValue(i), C, "phitmp",
+                                NonConstBB->getTerminator());
+        else
+          assert(0 && "Unknown binop!");
+
+        AddToWorkList(cast<Instruction>(InV));
+      }
+      NewPN->addIncoming(InV, PN->getIncomingBlock(i));
+    }
+  } else {
+    CastInst *CI = cast<CastInst>(&I);
+    const Type *RetTy = CI->getType();
+    for (unsigned i = 0; i != NumPHIValues; ++i) {
+      Value *InV;
+      if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) {
+        InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy);
+      } else {
+        assert(PN->getIncomingBlock(i) == NonConstBB);
+        InV = CastInst::Create(CI->getOpcode(), PN->getIncomingValue(i),
+                               I.getType(), "phitmp",
+                               NonConstBB->getTerminator());
+        AddToWorkList(cast<Instruction>(InV));
+      }
+      NewPN->addIncoming(InV, PN->getIncomingBlock(i));
+    }
+  }
+  return ReplaceInstUsesWith(I, NewPN);
+}
+
+
+/// WillNotOverflowSignedAdd - Return true if we can prove that:
+///    (sext (add LHS, RHS))  === (add (sext LHS), (sext RHS))
+/// This basically requires proving that the add in the original type would not
+/// overflow to change the sign bit or have a carry out.
+bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
+  // There are different heuristics we can use for this. Here are some simple
+  // ones.
+
+  // Add has the property that adding any two 2's complement numbers can only
+  // have one carry bit which can change a sign. As such, if LHS and RHS each
+  // have at least two sign bits, we know that the addition of the two values
+  // will sign extend fine.
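// Illustrative aside (editor's standalone sketch, not part of the imported
// source): the "two sign bits" rule in concrete i8 terms. A value with at
// least two sign bits lies in [-64, 63], and any two such values sum without
// signed overflow, so the add and its sign extension commute:
#include <cassert>

int main() {
  for (int a = -64; a <= 63; ++a)
    for (int b = -64; b <= 63; ++b) {
      int s = a + b;                  // exact sum, computed in a wider type
      assert(s >= -128 && s <= 127);  // still representable in i8
    }
  return 0;
}
// End of aside.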
+ if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1) + return true; + + + // If one of the operands only has one non-zero bit, and if the other operand + // has a known-zero bit in a more significant place than it (not including the + // sign bit) the ripple may go up to and fill the zero, but won't change the + // sign. For example, (X & ~4) + 1. + + // TODO: Implement. + + return false; +} + + +Instruction *InstCombiner::visitAdd(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + + if (Constant *RHSC = dyn_cast<Constant>(RHS)) { + // X + undef -> undef + if (isa<UndefValue>(RHS)) + return ReplaceInstUsesWith(I, RHS); + + // X + 0 --> X + if (!I.getType()->isFPOrFPVector()) { // NOTE: -0 + +0 = +0. + if (RHSC->isNullValue()) + return ReplaceInstUsesWith(I, LHS); + } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHSC)) { + if (CFP->isExactlyValue(ConstantFP::getNegativeZero + (I.getType())->getValueAPF())) + return ReplaceInstUsesWith(I, LHS); + } + + if (ConstantInt *CI = dyn_cast<ConstantInt>(RHSC)) { + // X + (signbit) --> X ^ signbit + const APInt& Val = CI->getValue(); + uint32_t BitWidth = Val.getBitWidth(); + if (Val == APInt::getSignBit(BitWidth)) + return BinaryOperator::CreateXor(LHS, RHS); + + // See if SimplifyDemandedBits can simplify this. This handles stuff like + // (X & 254)+1 -> (X&254)|1 + if (!isa<VectorType>(I.getType()) && SimplifyDemandedInstructionBits(I)) + return &I; + + // zext(i1) - 1 -> select i1, 0, -1 + if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) + if (CI->isAllOnesValue() && + ZI->getOperand(0)->getType() == Type::Int1Ty) + return SelectInst::Create(ZI->getOperand(0), + Constant::getNullValue(I.getType()), + ConstantInt::getAllOnesValue(I.getType())); + } + + if (isa<PHINode>(LHS)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + + ConstantInt *XorRHS = 0; + Value *XorLHS = 0; + if (isa<ConstantInt>(RHSC) && + match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { + uint32_t TySizeBits = I.getType()->getPrimitiveSizeInBits(); + const APInt& RHSVal = cast<ConstantInt>(RHSC)->getValue(); + + uint32_t Size = TySizeBits / 2; + APInt C0080Val(APInt(TySizeBits, 1ULL).shl(Size - 1)); + APInt CFF80Val(-C0080Val); + do { + if (TySizeBits > Size) { + // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. + // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext. + if ((RHSVal == CFF80Val && XorRHS->getValue() == C0080Val) || + (RHSVal == C0080Val && XorRHS->getValue() == CFF80Val)) { + // This is a sign extend if the top bits are known zero. + if (!MaskedValueIsZero(XorLHS, + APInt::getHighBitsSet(TySizeBits, TySizeBits - Size))) + Size = 0; // Not a sign ext, but can't be any others either. + break; + } + } + Size >>= 1; + C0080Val = APIntOps::lshr(C0080Val, Size); + CFF80Val = APIntOps::ashr(CFF80Val, Size); + } while (Size >= 1); + + // FIXME: This shouldn't be necessary. When the backends can handle types + // with funny bit widths then this switch statement should be removed. It + // is just here to get the size of the "middle" type back up to something + // that the back ends can handle. 
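+    // Illustrative aside (editorial note, not from the imported LLVM code):
+    // the xor/add pair recognized above is the classic branch-free sign
+    // extension.  For an 8-bit payload whose top 24 bits are known zero,
+    // ((x ^ 0x80) + 0xFFFFFF80) == sext(trunc(x)).  A standalone C++ check:
+    //
+    //   #include <cassert>
+    //   #include <cstdint>
+    //   int main() {
+    //     for (uint32_t x = 0; x < 256; ++x)  // top bits known zero
+    //       assert(((x ^ 0x80u) + 0xFFFFFF80u) ==
+    //              uint32_t(int32_t(int8_t(uint8_t(x)))));
+    //   }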
+ const Type *MiddleType = 0; + switch (Size) { + default: break; + case 32: MiddleType = Type::Int32Ty; break; + case 16: MiddleType = Type::Int16Ty; break; + case 8: MiddleType = Type::Int8Ty; break; + } + if (MiddleType) { + Instruction *NewTrunc = new TruncInst(XorLHS, MiddleType, "sext"); + InsertNewInstBefore(NewTrunc, I); + return new SExtInst(NewTrunc, I.getType(), I.getName()); + } + } + } + + if (I.getType() == Type::Int1Ty) + return BinaryOperator::CreateXor(LHS, RHS); + + // X + X --> X << 1 + if (I.getType()->isInteger()) { + if (Instruction *Result = AssociativeOpt(I, AddRHS(RHS))) return Result; + + if (Instruction *RHSI = dyn_cast<Instruction>(RHS)) { + if (RHSI->getOpcode() == Instruction::Sub) + if (LHS == RHSI->getOperand(1)) // A + (B - A) --> B + return ReplaceInstUsesWith(I, RHSI->getOperand(0)); + } + if (Instruction *LHSI = dyn_cast<Instruction>(LHS)) { + if (LHSI->getOpcode() == Instruction::Sub) + if (RHS == LHSI->getOperand(1)) // (B - A) + A --> B + return ReplaceInstUsesWith(I, LHSI->getOperand(0)); + } + } + + // -A + B --> B - A + // -A + -B --> -(A + B) + if (Value *LHSV = dyn_castNegVal(LHS)) { + if (LHS->getType()->isIntOrIntVector()) { + if (Value *RHSV = dyn_castNegVal(RHS)) { + Instruction *NewAdd = BinaryOperator::CreateAdd(LHSV, RHSV, "sum"); + InsertNewInstBefore(NewAdd, I); + return BinaryOperator::CreateNeg(NewAdd); + } + } + + return BinaryOperator::CreateSub(RHS, LHSV); + } + + // A + -B --> A - B + if (!isa<Constant>(RHS)) + if (Value *V = dyn_castNegVal(RHS)) + return BinaryOperator::CreateSub(LHS, V); + + + ConstantInt *C2; + if (Value *X = dyn_castFoldableMul(LHS, C2)) { + if (X == RHS) // X*C + X --> X * (C+1) + return BinaryOperator::CreateMul(RHS, AddOne(C2)); + + // X*C1 + X*C2 --> X * (C1+C2) + ConstantInt *C1; + if (X == dyn_castFoldableMul(RHS, C1)) + return BinaryOperator::CreateMul(X, Add(C1, C2)); + } + + // X + X*C --> X * (C+1) + if (dyn_castFoldableMul(RHS, C2) == LHS) + return BinaryOperator::CreateMul(LHS, AddOne(C2)); + + // X + ~X --> -1 since ~X = -X-1 + if (dyn_castNotVal(LHS) == RHS || dyn_castNotVal(RHS) == LHS) + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + + + // (A & C1)+(B & C2) --> (A & C1)|(B & C2) iff C1&C2 == 0 + if (match(RHS, m_And(m_Value(), m_ConstantInt(C2)))) + if (Instruction *R = AssociativeOpt(I, AddMaskingAnd(C2))) + return R; + + // A+B --> A|B iff A and B have no bits set in common. + if (const IntegerType *IT = dyn_cast<IntegerType>(I.getType())) { + APInt Mask = APInt::getAllOnesValue(IT->getBitWidth()); + APInt LHSKnownOne(IT->getBitWidth(), 0); + APInt LHSKnownZero(IT->getBitWidth(), 0); + ComputeMaskedBits(LHS, Mask, LHSKnownZero, LHSKnownOne); + if (LHSKnownZero != 0) { + APInt RHSKnownOne(IT->getBitWidth(), 0); + APInt RHSKnownZero(IT->getBitWidth(), 0); + ComputeMaskedBits(RHS, Mask, RHSKnownZero, RHSKnownOne); + + // No bits in common -> bitwise or. 
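+      // Illustrative aside (editorial note, not from the imported LLVM
+      // code): with no set bits in common the add can never carry, so it is
+      // exactly the bitwise or.  A standalone C++ check:
+      //
+      //   #include <cassert>
+      //   #include <cstdint>
+      //   int main() {
+      //     uint32_t A = 0x00F0, B = 0xA001;  // A & B == 0
+      //     assert(A + B == (A | B));         // both are 0xA0F1
+      //   }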
+ if ((LHSKnownZero|RHSKnownZero).isAllOnesValue()) + return BinaryOperator::CreateOr(LHS, RHS); + } + } + + // W*X + Y*Z --> W * (X+Z) iff W == Y + if (I.getType()->isIntOrIntVector()) { + Value *W, *X, *Y, *Z; + if (match(LHS, m_Mul(m_Value(W), m_Value(X))) && + match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) { + if (W != Y) { + if (W == Z) { + std::swap(Y, Z); + } else if (Y == X) { + std::swap(W, X); + } else if (X == Z) { + std::swap(Y, Z); + std::swap(W, X); + } + } + + if (W == Y) { + Value *NewAdd = InsertNewInstBefore(BinaryOperator::CreateAdd(X, Z, + LHS->getName()), I); + return BinaryOperator::CreateMul(W, NewAdd); + } + } + } + + if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) { + Value *X = 0; + if (match(LHS, m_Not(m_Value(X)))) // ~X + C --> (C-1) - X + return BinaryOperator::CreateSub(SubOne(CRHS), X); + + // (X & FF00) + xx00 -> (X+xx00) & FF00 + if (LHS->hasOneUse() && match(LHS, m_And(m_Value(X), m_ConstantInt(C2)))) { + Constant *Anded = And(CRHS, C2); + if (Anded == CRHS) { + // See if all bits from the first bit set in the Add RHS up are included + // in the mask. First, get the rightmost bit. + const APInt& AddRHSV = CRHS->getValue(); + + // Form a mask of all bits from the lowest bit added through the top. + APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); + + // See if the and mask includes all of these bits. + APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); + + if (AddRHSHighBits == AddRHSHighBitsAnd) { + // Okay, the xform is safe. Insert the new add pronto. + Value *NewAdd = InsertNewInstBefore(BinaryOperator::CreateAdd(X, CRHS, + LHS->getName()), I); + return BinaryOperator::CreateAnd(NewAdd, C2); + } + } + } + + // Try to fold constant add into select arguments. + if (SelectInst *SI = dyn_cast<SelectInst>(LHS)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + } + + // add (cast *A to intptrtype) B -> + // cast (GEP (cast *A to sbyte*) B) --> intptrtype + { + CastInst *CI = dyn_cast<CastInst>(LHS); + Value *Other = RHS; + if (!CI) { + CI = dyn_cast<CastInst>(RHS); + Other = LHS; + } + if (CI && CI->getType()->isSized() && + (CI->getType()->getPrimitiveSizeInBits() == + TD->getIntPtrType()->getPrimitiveSizeInBits()) + && isa<PointerType>(CI->getOperand(0)->getType())) { + unsigned AS = + cast<PointerType>(CI->getOperand(0)->getType())->getAddressSpace(); + Value *I2 = InsertBitCastBefore(CI->getOperand(0), + PointerType::get(Type::Int8Ty, AS), I); + I2 = InsertNewInstBefore(GetElementPtrInst::Create(I2, Other, "ctg2"), I); + return new PtrToIntInst(I2, CI->getType()); + } + } + + // add (select X 0 (sub n A)) A --> select X A n + { + SelectInst *SI = dyn_cast<SelectInst>(LHS); + Value *A = RHS; + if (!SI) { + SI = dyn_cast<SelectInst>(RHS); + A = LHS; + } + if (SI && SI->hasOneUse()) { + Value *TV = SI->getTrueValue(); + Value *FV = SI->getFalseValue(); + Value *N; + + // Can we fold the add into the argument of the select? + // We check both true and false select arguments for a matching subtract. + if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A)))) + // Fold the add into the true select value. + return SelectInst::Create(SI->getCondition(), N, A); + if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A)))) + // Fold the add into the false select value. + return SelectInst::Create(SI->getCondition(), A, N); + } + } + + // Check for X+0.0. Simplify it to X if we know X is not -0.0. 
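+  // Illustrative aside (editorial note, not from the imported LLVM code):
+  // the "not -0.0" guard matters because IEEE 754 defines -0.0 + 0.0 as
+  // +0.0, so dropping the add would flip the sign of a -0.0 result:
+  //
+  //   #include <cassert>
+  //   #include <cmath>
+  //   int main() {
+  //     double x = -0.0;
+  //     assert(std::signbit(x));         // x is -0.0
+  //     assert(!std::signbit(x + 0.0));  // but x + 0.0 is +0.0
+  //   }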
+  if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
+    if (CFP->getValueAPF().isPosZero() && CannotBeNegativeZero(LHS))
+      return ReplaceInstUsesWith(I, LHS);
+
+  // Check for (add (sext x), y), see if we can merge this into an
+  // integer add followed by a sext.
+  if (SExtInst *LHSConv = dyn_cast<SExtInst>(LHS)) {
+    // (add (sext x), cst) --> (sext (add x, cst'))
+    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) {
+      Constant *CI =
+        ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
+      if (LHSConv->hasOneUse() &&
+          ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+        // Insert the new, smaller add.
+        Instruction *NewAdd = BinaryOperator::CreateAdd(LHSConv->getOperand(0),
+                                                        CI, "addconv");
+        InsertNewInstBefore(NewAdd, I);
+        return new SExtInst(NewAdd, I.getType());
+      }
+    }
+
+    // (add (sext x), (sext y)) --> (sext (add int x, y))
+    if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) {
+      // Only do this if x/y have the same type, if at least one of them has a
+      // single use (so we don't increase the number of sexts), and if the
+      // integer add will not overflow.
+      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
+          (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0),
+                                   RHSConv->getOperand(0))) {
+        // Insert the new integer add.
+        Instruction *NewAdd = BinaryOperator::CreateAdd(LHSConv->getOperand(0),
+                                                        RHSConv->getOperand(0),
+                                                        "addconv");
+        InsertNewInstBefore(NewAdd, I);
+        return new SExtInst(NewAdd, I.getType());
+      }
+    }
+  }
+
+  // Check for (add double (sitofp x), y), see if we can merge this into an
+  // integer add followed by a promotion.
+  if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
+    // (add double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
+    // ... if the constant fits in the integer value.  This is useful for things
+    // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
+    // requires a constant pool load, and generally allows the add to be better
+    // instcombined.
+    if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) {
+      Constant *CI =
+        ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
+      if (LHSConv->hasOneUse() &&
+          ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+        // Insert the new integer add.
+        Instruction *NewAdd = BinaryOperator::CreateAdd(LHSConv->getOperand(0),
+                                                        CI, "addconv");
+        InsertNewInstBefore(NewAdd, I);
+        return new SIToFPInst(NewAdd, I.getType());
+      }
+    }
+
+    // (add double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
+    if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
+      // Only do this if x/y have the same type, if at least one of them has a
+      // single use (so we don't increase the number of int->fp conversions),
+      // and if the integer add will not overflow.
+      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
+          (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0),
+                                   RHSConv->getOperand(0))) {
+        // Insert the new integer add.
+        Instruction *NewAdd = BinaryOperator::CreateAdd(LHSConv->getOperand(0),
+                                                        RHSConv->getOperand(0),
+                                                        "addconv");
+        InsertNewInstBefore(NewAdd, I);
+        return new SIToFPInst(NewAdd, I.getType());
+      }
+    }
+  }
+
+  return Changed ?
&I : 0; +} + +Instruction *InstCombiner::visitSub(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Op0 == Op1 && // sub X, X -> 0 + !I.getType()->isFPOrFPVector()) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // If this is a 'B = x-(-A)', change to B = x+A... + if (Value *V = dyn_castNegVal(Op1)) + return BinaryOperator::CreateAdd(Op0, V); + + if (isa<UndefValue>(Op0)) + return ReplaceInstUsesWith(I, Op0); // undef - X -> undef + if (isa<UndefValue>(Op1)) + return ReplaceInstUsesWith(I, Op1); // X - undef -> undef + + if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) { + // Replace (-1 - A) with (~A)... + if (C->isAllOnesValue()) + return BinaryOperator::CreateNot(Op1); + + // C - ~X == X + (1+C) + Value *X = 0; + if (match(Op1, m_Not(m_Value(X)))) + return BinaryOperator::CreateAdd(X, AddOne(C)); + + // -(X >>u 31) -> (X >>s 31) + // -(X >>s 31) -> (X >>u 31) + if (C->isZero()) { + if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op1)) { + if (SI->getOpcode() == Instruction::LShr) { + if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) { + // Check to see if we are shifting out everything but the sign bit. + if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == + SI->getType()->getPrimitiveSizeInBits()-1) { + // Ok, the transformation is safe. Insert AShr. + return BinaryOperator::Create(Instruction::AShr, + SI->getOperand(0), CU, SI->getName()); + } + } + } + else if (SI->getOpcode() == Instruction::AShr) { + if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) { + // Check to see if we are shifting out everything but the sign bit. + if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == + SI->getType()->getPrimitiveSizeInBits()-1) { + // Ok, the transformation is safe. Insert LShr. + return BinaryOperator::CreateLShr( + SI->getOperand(0), CU, SI->getName()); + } + } + } + } + } + + // Try to fold constant sub into select arguments. + if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + } + + if (I.getType() == Type::Int1Ty) + return BinaryOperator::CreateXor(Op0, Op1); + + if (BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1)) { + if (Op1I->getOpcode() == Instruction::Add && + !Op0->getType()->isFPOrFPVector()) { + if (Op1I->getOperand(0) == Op0) // X-(X+Y) == -Y + return BinaryOperator::CreateNeg(Op1I->getOperand(1), I.getName()); + else if (Op1I->getOperand(1) == Op0) // X-(Y+X) == -Y + return BinaryOperator::CreateNeg(Op1I->getOperand(0), I.getName()); + else if (ConstantInt *CI1 = dyn_cast<ConstantInt>(I.getOperand(0))) { + if (ConstantInt *CI2 = dyn_cast<ConstantInt>(Op1I->getOperand(1))) + // C1-(X+C2) --> (C1-C2)-X + return BinaryOperator::CreateSub(Subtract(CI1, CI2), + Op1I->getOperand(0)); + } + } + + if (Op1I->hasOneUse()) { + // Replace (x - (y - z)) with (x + (z - y)) if the (y - z) subexpression + // is not used by anyone else... + // + if (Op1I->getOpcode() == Instruction::Sub && + !Op1I->getType()->isFPOrFPVector()) { + // Swap the two operands of the subexpr... + Value *IIOp0 = Op1I->getOperand(0), *IIOp1 = Op1I->getOperand(1); + Op1I->setOperand(0, IIOp1); + Op1I->setOperand(1, IIOp0); + + // Create the new top level add instruction... + return BinaryOperator::CreateAdd(Op0, Op1); + } + + // Replace (A - (A & B)) with (A & ~B) if this is the only use of (A&B)... 
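+      // Illustrative aside (editorial note, not from the imported LLVM
+      // code): A & B only contains bits of A, so the subtraction never
+      // borrows and simply clears those bits.  A standalone C++ check:
+      //
+      //   #include <cassert>
+      //   #include <cstdint>
+      //   int main() {
+      //     uint32_t A = 0xDEAD, B = 0x0F0F;
+      //     assert(A - (A & B) == (A & ~B));  // both are 0xD0A0
+      //   }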
+ // + if (Op1I->getOpcode() == Instruction::And && + (Op1I->getOperand(0) == Op0 || Op1I->getOperand(1) == Op0)) { + Value *OtherOp = Op1I->getOperand(Op1I->getOperand(0) == Op0); + + Value *NewNot = + InsertNewInstBefore(BinaryOperator::CreateNot(OtherOp, "B.not"), I); + return BinaryOperator::CreateAnd(Op0, NewNot); + } + + // 0 - (X sdiv C) -> (X sdiv -C) + if (Op1I->getOpcode() == Instruction::SDiv) + if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0)) + if (CSI->isZero()) + if (Constant *DivRHS = dyn_cast<Constant>(Op1I->getOperand(1))) + return BinaryOperator::CreateSDiv(Op1I->getOperand(0), + ConstantExpr::getNeg(DivRHS)); + + // X - X*C --> X * (1-C) + ConstantInt *C2 = 0; + if (dyn_castFoldableMul(Op1I, C2) == Op0) { + Constant *CP1 = Subtract(ConstantInt::get(I.getType(), 1), C2); + return BinaryOperator::CreateMul(Op0, CP1); + } + } + } + + if (!Op0->getType()->isFPOrFPVector()) + if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { + if (Op0I->getOpcode() == Instruction::Add) { + if (Op0I->getOperand(0) == Op1) // (Y+X)-Y == X + return ReplaceInstUsesWith(I, Op0I->getOperand(1)); + else if (Op0I->getOperand(1) == Op1) // (X+Y)-Y == X + return ReplaceInstUsesWith(I, Op0I->getOperand(0)); + } else if (Op0I->getOpcode() == Instruction::Sub) { + if (Op0I->getOperand(0) == Op1) // (X-Y)-X == -Y + return BinaryOperator::CreateNeg(Op0I->getOperand(1), I.getName()); + } + } + + ConstantInt *C1; + if (Value *X = dyn_castFoldableMul(Op0, C1)) { + if (X == Op1) // X*C - X --> X * (C-1) + return BinaryOperator::CreateMul(Op1, SubOne(C1)); + + ConstantInt *C2; // X*C1 - X*C2 -> X * (C1-C2) + if (X == dyn_castFoldableMul(Op1, C2)) + return BinaryOperator::CreateMul(X, Subtract(C1, C2)); + } + return 0; +} + +/// isSignBitCheck - Given an exploded icmp instruction, return true if the +/// comparison only checks the sign bit. If it only checks the sign bit, set +/// TrueIfSigned if the result of the comparison is true when the input value is +/// signed. +static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS, + bool &TrueIfSigned) { + switch (pred) { + case ICmpInst::ICMP_SLT: // True if LHS s< 0 + TrueIfSigned = true; + return RHS->isZero(); + case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1 + TrueIfSigned = true; + return RHS->isAllOnesValue(); + case ICmpInst::ICMP_SGT: // True if LHS s> -1 + TrueIfSigned = false; + return RHS->isAllOnesValue(); + case ICmpInst::ICMP_UGT: + // True if LHS u> RHS and RHS == high-bit-mask - 1 + TrueIfSigned = true; + return RHS->getValue() == + APInt::getSignedMaxValue(RHS->getType()->getPrimitiveSizeInBits()); + case ICmpInst::ICMP_UGE: + // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) + TrueIfSigned = true; + return RHS->getValue().isSignBit(); + default: + return false; + } +} + +Instruction *InstCombiner::visitMul(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *Op0 = I.getOperand(0); + + if (isa<UndefValue>(I.getOperand(1))) // undef * X -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // Simplify mul instructions with a constant RHS... 
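+  // Illustrative aside (editorial note, not from the imported LLVM code):
+  // one of the folds below turns a multiply by a power of two into a shift.
+  // A standalone C++ check of the identity:
+  //
+  //   #include <cassert>
+  //   #include <cstdint>
+  //   int main() {
+  //     uint32_t x = 12345;
+  //     assert(x * 8 == x << 3);  // mul X, 2^C  -->  shl X, C
+  //   }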
+ if (Constant *Op1 = dyn_cast<Constant>(I.getOperand(1))) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + + // ((X << C1)*C2) == (X * (C2 << C1)) + if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op0)) + if (SI->getOpcode() == Instruction::Shl) + if (Constant *ShOp = dyn_cast<Constant>(SI->getOperand(1))) + return BinaryOperator::CreateMul(SI->getOperand(0), + ConstantExpr::getShl(CI, ShOp)); + + if (CI->isZero()) + return ReplaceInstUsesWith(I, Op1); // X * 0 == 0 + if (CI->equalsInt(1)) // X * 1 == X + return ReplaceInstUsesWith(I, Op0); + if (CI->isAllOnesValue()) // X * -1 == 0 - X + return BinaryOperator::CreateNeg(Op0, I.getName()); + + const APInt& Val = cast<ConstantInt>(CI)->getValue(); + if (Val.isPowerOf2()) { // Replace X*(2^C) with X << C + return BinaryOperator::CreateShl(Op0, + ConstantInt::get(Op0->getType(), Val.logBase2())); + } + } else if (ConstantFP *Op1F = dyn_cast<ConstantFP>(Op1)) { + if (Op1F->isNullValue()) + return ReplaceInstUsesWith(I, Op1); + + // "In IEEE floating point, x*1 is not equivalent to x for nans. However, + // ANSI says we can drop signals, so we can do this anyway." (from GCC) + if (Op1F->isExactlyValue(1.0)) + return ReplaceInstUsesWith(I, Op0); // Eliminate 'mul double %X, 1.0' + } else if (isa<VectorType>(Op1->getType())) { + if (isa<ConstantAggregateZero>(Op1)) + return ReplaceInstUsesWith(I, Op1); + + if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1)) { + if (Op1V->isAllOnesValue()) // X * -1 == 0 - X + return BinaryOperator::CreateNeg(Op0, I.getName()); + + // As above, vector X*splat(1.0) -> X in all defined cases. + if (Constant *Splat = Op1V->getSplatValue()) { + if (ConstantFP *F = dyn_cast<ConstantFP>(Splat)) + if (F->isExactlyValue(1.0)) + return ReplaceInstUsesWith(I, Op0); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Splat)) + if (CI->equalsInt(1)) + return ReplaceInstUsesWith(I, Op0); + } + } + } + + if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) + if (Op0I->getOpcode() == Instruction::Add && Op0I->hasOneUse() && + isa<ConstantInt>(Op0I->getOperand(1)) && isa<ConstantInt>(Op1)) { + // Canonicalize (X+C1)*C2 -> X*C2+C1*C2. + Instruction *Add = BinaryOperator::CreateMul(Op0I->getOperand(0), + Op1, "tmp"); + InsertNewInstBefore(Add, I); + Value *C1C2 = ConstantExpr::getMul(Op1, + cast<Constant>(Op0I->getOperand(1))); + return BinaryOperator::CreateAdd(Add, C1C2); + + } + + // Try to fold constant mul into select arguments. 
+ if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + + if (isa<PHINode>(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + if (Value *Op0v = dyn_castNegVal(Op0)) // -X * -Y = X*Y + if (Value *Op1v = dyn_castNegVal(I.getOperand(1))) + return BinaryOperator::CreateMul(Op0v, Op1v); + + // (X / Y) * Y = X - (X % Y) + // (X / Y) * -Y = (X % Y) - X + { + Value *Op1 = I.getOperand(1); + BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0); + if (!BO || + (BO->getOpcode() != Instruction::UDiv && + BO->getOpcode() != Instruction::SDiv)) { + Op1 = Op0; + BO = dyn_cast<BinaryOperator>(I.getOperand(1)); + } + Value *Neg = dyn_castNegVal(Op1); + if (BO && BO->hasOneUse() && + (BO->getOperand(1) == Op1 || BO->getOperand(1) == Neg) && + (BO->getOpcode() == Instruction::UDiv || + BO->getOpcode() == Instruction::SDiv)) { + Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1); + + Instruction *Rem; + if (BO->getOpcode() == Instruction::UDiv) + Rem = BinaryOperator::CreateURem(Op0BO, Op1BO); + else + Rem = BinaryOperator::CreateSRem(Op0BO, Op1BO); + + InsertNewInstBefore(Rem, I); + Rem->takeName(BO); + + if (Op1BO == Op1) + return BinaryOperator::CreateSub(Op0BO, Rem); + else + return BinaryOperator::CreateSub(Rem, Op0BO); + } + } + + if (I.getType() == Type::Int1Ty) + return BinaryOperator::CreateAnd(Op0, I.getOperand(1)); + + // If one of the operands of the multiply is a cast from a boolean value, then + // we know the bool is either zero or one, so this is a 'masking' multiply. + // See if we can simplify things based on how the boolean was originally + // formed. + CastInst *BoolCast = 0; + if (ZExtInst *CI = dyn_cast<ZExtInst>(Op0)) + if (CI->getOperand(0)->getType() == Type::Int1Ty) + BoolCast = CI; + if (!BoolCast) + if (ZExtInst *CI = dyn_cast<ZExtInst>(I.getOperand(1))) + if (CI->getOperand(0)->getType() == Type::Int1Ty) + BoolCast = CI; + if (BoolCast) { + if (ICmpInst *SCI = dyn_cast<ICmpInst>(BoolCast->getOperand(0))) { + Value *SCIOp0 = SCI->getOperand(0), *SCIOp1 = SCI->getOperand(1); + const Type *SCOpTy = SCIOp0->getType(); + bool TIS = false; + + // If the icmp is true iff the sign bit of X is set, then convert this + // multiply into a shift/and combination. + if (isa<ConstantInt>(SCIOp1) && + isSignBitCheck(SCI->getPredicate(), cast<ConstantInt>(SCIOp1), TIS) && + TIS) { + // Shift the X value right to turn it into "all signbits". + Constant *Amt = ConstantInt::get(SCIOp0->getType(), + SCOpTy->getPrimitiveSizeInBits()-1); + Value *V = + InsertNewInstBefore( + BinaryOperator::Create(Instruction::AShr, SCIOp0, Amt, + BoolCast->getOperand(0)->getName()+ + ".mask"), I); + + // If the multiply type is not the same as the source type, sign extend + // or truncate to the multiply type. + if (I.getType() != V->getType()) { + uint32_t SrcBits = V->getType()->getPrimitiveSizeInBits(); + uint32_t DstBits = I.getType()->getPrimitiveSizeInBits(); + Instruction::CastOps opcode = + (SrcBits == DstBits ? Instruction::BitCast : + (SrcBits < DstBits ? Instruction::SExt : Instruction::Trunc)); + V = InsertCastBefore(opcode, V, I.getType(), I); + } + + Value *OtherOp = Op0 == BoolCast ? I.getOperand(1) : Op0; + return BinaryOperator::CreateAnd(V, OtherOp); + } + } + } + + return Changed ? &I : 0; +} + +/// SimplifyDivRemOfSelect - Try to fold a divide or remainder of a select +/// instruction. 
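+/// For example, "X udiv (Cond ? 0 : Y)" can only execute without faulting
+/// when the select produces Y, so the divisor may be replaced by Y outright.
+/// (Editorial illustration; this sentence is not in the imported comment.)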
+bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) {
+  SelectInst *SI = cast<SelectInst>(I.getOperand(1));
+
+  // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y
+  int NonNullOperand = -1;
+  if (Constant *ST = dyn_cast<Constant>(SI->getOperand(1)))
+    if (ST->isNullValue())
+      NonNullOperand = 2;
+  // div/rem X, (Cond ? Y : 0) -> div/rem X, Y
+  if (Constant *ST = dyn_cast<Constant>(SI->getOperand(2)))
+    if (ST->isNullValue())
+      NonNullOperand = 1;
+
+  if (NonNullOperand == -1)
+    return false;
+
+  Value *SelectCond = SI->getOperand(0);
+
+  // Change the div/rem to use 'Y' instead of the select.
+  I.setOperand(1, SI->getOperand(NonNullOperand));
+
+  // Okay, we know we can replace the operand of the div/rem with 'Y' without
+  // a problem.  However, the select, or the condition of the select may have
+  // multiple uses.  Based on our knowledge that the operand must be non-zero,
+  // propagate the known value for the select into other uses of it, and
+  // propagate a known value of the condition into its other users.
+
+  // If the select and the condition have no other uses, there is nothing to
+  // propagate; exit early.
+  if (SI->use_empty() && SelectCond->hasOneUse())
+    return true;
+
+  // Scan the current block backward, looking for other uses of SI.
+  BasicBlock::iterator BBI = &I, BBFront = I.getParent()->begin();
+
+  while (BBI != BBFront) {
+    --BBI;
+    // If we found a call to a function, we can't assume it will return, so
+    // information from below it cannot be propagated above it.
+    if (isa<CallInst>(BBI) && !isa<IntrinsicInst>(BBI))
+      break;
+
+    // Replace uses of the select or its condition with the known values.
+    for (Instruction::op_iterator I = BBI->op_begin(), E = BBI->op_end();
+         I != E; ++I) {
+      if (*I == SI) {
+        *I = SI->getOperand(NonNullOperand);
+        AddToWorkList(BBI);
+      } else if (*I == SelectCond) {
+        *I = NonNullOperand == 1 ? ConstantInt::getTrue() :
+                                   ConstantInt::getFalse();
+        AddToWorkList(BBI);
+      }
+    }
+
+    // If we passed the instruction, quit looking for it.
+    if (&*BBI == SI)
+      SI = 0;
+    if (&*BBI == SelectCond)
+      SelectCond = 0;
+
+    // If we ran out of things to eliminate, break out of the loop.
+    if (SelectCond == 0 && SI == 0)
+      break;
+
+  }
+  return true;
+}
+
+
+/// This function implements the transforms on div instructions that work
+/// regardless of the kind of div instruction it is (udiv, sdiv, or fdiv). It
+/// is used by the visitors to those instructions.
+/// @brief Transforms common to all three div instructions
+Instruction *InstCombiner::commonDivTransforms(BinaryOperator &I) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+
+  // undef / X -> 0        for integer.
+  // undef / X -> undef    for FP (the undef could be a snan).
+  if (isa<UndefValue>(Op0)) {
+    if (Op0->getType()->isFPOrFPVector())
+      return ReplaceInstUsesWith(I, Op0);
+    return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+  }
+
+  // X / undef -> undef
+  if (isa<UndefValue>(Op1))
+    return ReplaceInstUsesWith(I, Op1);
+
+  return 0;
+}
+
+/// This function implements the transforms common to both integer division
+/// instructions (udiv and sdiv). It is called by the visitors to those integer
+/// division instructions.
+/// @brief Common integer divide transforms +Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // (sdiv X, X) --> 1 (udiv X, X) --> 1 + if (Op0 == Op1) { + if (const VectorType *Ty = dyn_cast<VectorType>(I.getType())) { + ConstantInt *CI = ConstantInt::get(Ty->getElementType(), 1); + std::vector<Constant*> Elts(Ty->getNumElements(), CI); + return ReplaceInstUsesWith(I, ConstantVector::get(Elts)); + } + + ConstantInt *CI = ConstantInt::get(I.getType(), 1); + return ReplaceInstUsesWith(I, CI); + } + + if (Instruction *Common = commonDivTransforms(I)) + return Common; + + // Handle cases involving: [su]div X, (select Cond, Y, Z) + // This does not apply for fdiv. + if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + return &I; + + if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { + // div X, 1 == X + if (RHS->equalsInt(1)) + return ReplaceInstUsesWith(I, Op0); + + // (X / C1) / C2 -> X / (C1*C2) + if (Instruction *LHS = dyn_cast<Instruction>(Op0)) + if (Instruction::BinaryOps(LHS->getOpcode()) == I.getOpcode()) + if (ConstantInt *LHSRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) { + if (MultiplyOverflows(RHS, LHSRHS, I.getOpcode()==Instruction::SDiv)) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + else + return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0), + Multiply(RHS, LHSRHS)); + } + + if (!RHS->isZero()) { // avoid X udiv 0 + if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa<PHINode>(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + } + + // 0 / X == 0, we don't need to preserve faults! + if (ConstantInt *LHS = dyn_cast<ConstantInt>(Op0)) + if (LHS->equalsInt(0)) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // It can't be division by zero, hence it must be division by one. + if (I.getType() == Type::Int1Ty) + return ReplaceInstUsesWith(I, Op0); + + if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1)) { + if (ConstantInt *X = cast_or_null<ConstantInt>(Op1V->getSplatValue())) + // div X, 1 == X + if (X->isOne()) + return ReplaceInstUsesWith(I, Op0); + } + + return 0; +} + +Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Handle the integer div common cases + if (Instruction *Common = commonIDivTransforms(I)) + return Common; + + if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) { + // X udiv C^2 -> X >> C + // Check to see if this is an unsigned division with an exact power of 2, + // if so, convert to a right shift. 
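+    // Editorial note (not from the imported LLVM code): "C^2" above is
+    // shorthand for "a power of two".  For unsigned values the division is
+    // an exact logical shift, e.g.:
+    //
+    //   #include <cassert>
+    //   #include <cstdint>
+    //   int main() {
+    //     uint32_t x = 1000003;
+    //     assert(x / 16 == x >> 4);  // udiv X, 2^C  -->  lshr X, C
+    //   }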
+ if (C->getValue().isPowerOf2()) // 0 not included in isPowerOf2 + return BinaryOperator::CreateLShr(Op0, + ConstantInt::get(Op0->getType(), C->getValue().logBase2())); + + // X udiv C, where C >= signbit + if (C->getValue().isNegative()) { + Value *IC = InsertNewInstBefore(new ICmpInst(ICmpInst::ICMP_ULT, Op0, C), + I); + return SelectInst::Create(IC, Constant::getNullValue(I.getType()), + ConstantInt::get(I.getType(), 1)); + } + } + + // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) + if (BinaryOperator *RHSI = dyn_cast<BinaryOperator>(I.getOperand(1))) { + if (RHSI->getOpcode() == Instruction::Shl && + isa<ConstantInt>(RHSI->getOperand(0))) { + const APInt& C1 = cast<ConstantInt>(RHSI->getOperand(0))->getValue(); + if (C1.isPowerOf2()) { + Value *N = RHSI->getOperand(1); + const Type *NTy = N->getType(); + if (uint32_t C2 = C1.logBase2()) { + Constant *C2V = ConstantInt::get(NTy, C2); + N = InsertNewInstBefore(BinaryOperator::CreateAdd(N, C2V, "tmp"), I); + } + return BinaryOperator::CreateLShr(Op0, N); + } + } + } + + // udiv X, (Select Cond, C1, C2) --> Select Cond, (shr X, C1), (shr X, C2) + // where C1&C2 are powers of two. + if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) + if (ConstantInt *STO = dyn_cast<ConstantInt>(SI->getOperand(1))) + if (ConstantInt *SFO = dyn_cast<ConstantInt>(SI->getOperand(2))) { + const APInt &TVA = STO->getValue(), &FVA = SFO->getValue(); + if (TVA.isPowerOf2() && FVA.isPowerOf2()) { + // Compute the shift amounts + uint32_t TSA = TVA.logBase2(), FSA = FVA.logBase2(); + // Construct the "on true" case of the select + Constant *TC = ConstantInt::get(Op0->getType(), TSA); + Instruction *TSI = BinaryOperator::CreateLShr( + Op0, TC, SI->getName()+".t"); + TSI = InsertNewInstBefore(TSI, I); + + // Construct the "on false" case of the select + Constant *FC = ConstantInt::get(Op0->getType(), FSA); + Instruction *FSI = BinaryOperator::CreateLShr( + Op0, FC, SI->getName()+".f"); + FSI = InsertNewInstBefore(FSI, I); + + // construct the select instruction and return it. + return SelectInst::Create(SI->getOperand(0), TSI, FSI, SI->getName()); + } + } + return 0; +} + +Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Handle the integer div common cases + if (Instruction *Common = commonIDivTransforms(I)) + return Common; + + if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { + // sdiv X, -1 == -X + if (RHS->isAllOnesValue()) + return BinaryOperator::CreateNeg(Op0); + } + + // If the sign bits of both operands are zero (i.e. we can prove they are + // unsigned inputs), turn this into a udiv. + if (I.getType()->isInteger()) { + APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); + if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set + return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + } + } + + return 0; +} + +Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { + return commonDivTransforms(I); +} + +/// This function implements the transforms on rem instructions that work +/// regardless of the kind of rem instruction it is (urem, srem, or frem). It +/// is used by the visitors to those instructions. 
+/// @brief Transforms common to all three rem instructions +Instruction *InstCombiner::commonRemTransforms(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (isa<UndefValue>(Op0)) { // undef % X -> 0 + if (I.getType()->isFPOrFPVector()) + return ReplaceInstUsesWith(I, Op0); // X % undef -> undef (could be SNaN) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + } + if (isa<UndefValue>(Op1)) + return ReplaceInstUsesWith(I, Op1); // X % undef -> undef + + // Handle cases involving: rem X, (select Cond, Y, Z) + if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + return &I; + + return 0; +} + +/// This function implements the transforms common to both integer remainder +/// instructions (urem and srem). It is called by the visitors to those integer +/// remainder instructions. +/// @brief Common integer remainder transforms +Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Instruction *common = commonRemTransforms(I)) + return common; + + // 0 % X == 0 for integer, we don't need to preserve faults! + if (Constant *LHS = dyn_cast<Constant>(Op0)) + if (LHS->isNullValue()) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { + // X % 0 == undef, we don't need to preserve faults! + if (RHS->equalsInt(0)) + return ReplaceInstUsesWith(I, UndefValue::get(I.getType())); + + if (RHS->equalsInt(1)) // X % 1 == 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) { + if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) { + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + } else if (isa<PHINode>(Op0I)) { + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + // See if we can fold away this rem instruction. + if (SimplifyDemandedInstructionBits(I)) + return &I; + } + } + + return 0; +} + +Instruction *InstCombiner::visitURem(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (Instruction *common = commonIRemTransforms(I)) + return common; + + if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { + // X urem C^2 -> X and C + // Check to see if this is an unsigned remainder with an exact power of 2, + // if so, convert to a bitwise and. + if (ConstantInt *C = dyn_cast<ConstantInt>(RHS)) + if (C->getValue().isPowerOf2()) + return BinaryOperator::CreateAnd(Op0, SubOne(C)); + } + + if (Instruction *RHSI = dyn_cast<Instruction>(I.getOperand(1))) { + // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1) + if (RHSI->getOpcode() == Instruction::Shl && + isa<ConstantInt>(RHSI->getOperand(0))) { + if (cast<ConstantInt>(RHSI->getOperand(0))->getValue().isPowerOf2()) { + Constant *N1 = ConstantInt::getAllOnesValue(I.getType()); + Value *Add = InsertNewInstBefore(BinaryOperator::CreateAdd(RHSI, N1, + "tmp"), I); + return BinaryOperator::CreateAnd(Op0, Add); + } + } + } + + // urem X, (select Cond, 2^C1, 2^C2) --> select Cond, (and X, C1), (and X, C2) + // where C1&C2 are powers of two. + if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) { + if (ConstantInt *STO = dyn_cast<ConstantInt>(SI->getOperand(1))) + if (ConstantInt *SFO = dyn_cast<ConstantInt>(SI->getOperand(2))) { + // STO == 0 and SFO == 0 handled above. 
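+        // Illustrative aside (editorial note, not from the imported LLVM
+        // code): each select arm is reduced with the same urem-by-power-of-
+        // two rule used earlier in this function:
+        //
+        //   #include <cassert>
+        //   #include <cstdint>
+        //   int main() {
+        //     uint32_t x = 1000003;
+        //     assert(x % 8 == (x & 7));  // urem X, 2^C  -->  and X, 2^C-1
+        //   }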
+ if ((STO->getValue().isPowerOf2()) && + (SFO->getValue().isPowerOf2())) { + Value *TrueAnd = InsertNewInstBefore( + BinaryOperator::CreateAnd(Op0, SubOne(STO), SI->getName()+".t"), I); + Value *FalseAnd = InsertNewInstBefore( + BinaryOperator::CreateAnd(Op0, SubOne(SFO), SI->getName()+".f"), I); + return SelectInst::Create(SI->getOperand(0), TrueAnd, FalseAnd); + } + } + } + + return 0; +} + +Instruction *InstCombiner::visitSRem(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Handle the integer rem common cases + if (Instruction *common = commonIRemTransforms(I)) + return common; + + if (Value *RHSNeg = dyn_castNegVal(Op1)) + if (!isa<Constant>(RHSNeg) || + (isa<ConstantInt>(RHSNeg) && + cast<ConstantInt>(RHSNeg)->getValue().isStrictlyPositive())) { + // X % -Y -> X % Y + AddUsesToWorkList(I); + I.setOperand(1, RHSNeg); + return &I; + } + + // If the sign bits of both operands are zero (i.e. we can prove they are + // unsigned inputs), turn this into a urem. + if (I.getType()->isInteger()) { + APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); + if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + // X srem Y -> X urem Y, iff X and Y don't have sign bit set + return BinaryOperator::CreateURem(Op0, Op1, I.getName()); + } + } + + // If it's a constant vector, flip any negative values positive. + if (ConstantVector *RHSV = dyn_cast<ConstantVector>(Op1)) { + unsigned VWidth = RHSV->getNumOperands(); + + bool hasNegative = false; + for (unsigned i = 0; !hasNegative && i != VWidth; ++i) + if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i))) + if (RHS->getValue().isNegative()) + hasNegative = true; + + if (hasNegative) { + std::vector<Constant *> Elts(VWidth); + for (unsigned i = 0; i != VWidth; ++i) { + if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i))) { + if (RHS->getValue().isNegative()) + Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS)); + else + Elts[i] = RHS; + } + } + + Constant *NewRHSV = ConstantVector::get(Elts); + if (NewRHSV != RHSV) { + AddUsesToWorkList(I); + I.setOperand(1, NewRHSV); + return &I; + } + } + } + + return 0; +} + +Instruction *InstCombiner::visitFRem(BinaryOperator &I) { + return commonRemTransforms(I); +} + +// isOneBitSet - Return true if there is exactly one bit set in the specified +// constant. +static bool isOneBitSet(const ConstantInt *CI) { + return CI->getValue().isPowerOf2(); +} + +// isHighOnes - Return true if the constant is of the form 1+0+. +// This is the same as lowones(~X). +static bool isHighOnes(const ConstantInt *CI) { + return (~CI->getValue() + 1).isPowerOf2(); +} + +/// getICmpCode - Encode a icmp predicate into a three bit mask. These bits +/// are carefully arranged to allow folding of expressions such as: +/// +/// (A < B) | (A > B) --> (A != B) +/// +/// Note that this is only valid if the first and second predicates have the +/// same sign. 
It is illegal to do: (A u< B) | (A s> B)
+///
+/// Three bits are used to represent the condition, as follows:
+///   bit 0:  A > B
+///   bit 1:  A == B
+///   bit 2:  A < B
+///
+/// <=>  Value  Definition
+/// 000     0   Always false
+/// 001     1   A >  B
+/// 010     2   A == B
+/// 011     3   A >= B
+/// 100     4   A <  B
+/// 101     5   A != B
+/// 110     6   A <= B
+/// 111     7   Always true
+///
+static unsigned getICmpCode(const ICmpInst *ICI) {
+  switch (ICI->getPredicate()) {
+    // False -> 0
+  case ICmpInst::ICMP_UGT: return 1;  // 001
+  case ICmpInst::ICMP_SGT: return 1;  // 001
+  case ICmpInst::ICMP_EQ:  return 2;  // 010
+  case ICmpInst::ICMP_UGE: return 3;  // 011
+  case ICmpInst::ICMP_SGE: return 3;  // 011
+  case ICmpInst::ICMP_ULT: return 4;  // 100
+  case ICmpInst::ICMP_SLT: return 4;  // 100
+  case ICmpInst::ICMP_NE:  return 5;  // 101
+  case ICmpInst::ICMP_ULE: return 6;  // 110
+  case ICmpInst::ICMP_SLE: return 6;  // 110
+    // True -> 7
+  default:
+    assert(0 && "Invalid ICmp predicate!");
+    return 0;
+  }
+}
+
+/// getFCmpCode - Similar to getICmpCode but for FCmpInst. This encodes a fcmp
+/// predicate into a three bit mask. It also returns whether it is an ordered
+/// predicate by reference.
+static unsigned getFCmpCode(FCmpInst::Predicate CC, bool &isOrdered) {
+  isOrdered = false;
+  switch (CC) {
+  case FCmpInst::FCMP_ORD: isOrdered = true; return 0;  // 000
+  case FCmpInst::FCMP_UNO:                   return 0;  // 000
+  case FCmpInst::FCMP_OGT: isOrdered = true; return 1;  // 001
+  case FCmpInst::FCMP_UGT:                   return 1;  // 001
+  case FCmpInst::FCMP_OEQ: isOrdered = true; return 2;  // 010
+  case FCmpInst::FCMP_UEQ:                   return 2;  // 010
+  case FCmpInst::FCMP_OGE: isOrdered = true; return 3;  // 011
+  case FCmpInst::FCMP_UGE:                   return 3;  // 011
+  case FCmpInst::FCMP_OLT: isOrdered = true; return 4;  // 100
+  case FCmpInst::FCMP_ULT:                   return 4;  // 100
+  case FCmpInst::FCMP_ONE: isOrdered = true; return 5;  // 101
+  case FCmpInst::FCMP_UNE:                   return 5;  // 101
+  case FCmpInst::FCMP_OLE: isOrdered = true; return 6;  // 110
+  case FCmpInst::FCMP_ULE:                   return 6;  // 110
+    // True -> 7
+  default:
+    // Not expecting FCMP_FALSE and FCMP_TRUE.
+    assert(0 && "Unexpected FCmp predicate!");
+    return 0;
+  }
+}
+
+/// getICmpValue - This is the complement of getICmpCode, which turns an
+/// opcode and two operands into either a constant true or false, or a brand
+/// new ICmp instruction. The sign is passed in to determine which kind
+/// of predicate to use in the new icmp instruction.
+static Value *getICmpValue(bool sign, unsigned code, Value *LHS, Value *RHS) {
+  switch (code) {
+  default: assert(0 && "Illegal ICmp code!");
+  case 0: return ConstantInt::getFalse();
+  case 1:
+    if (sign)
+      return new ICmpInst(ICmpInst::ICMP_SGT, LHS, RHS);
+    else
+      return new ICmpInst(ICmpInst::ICMP_UGT, LHS, RHS);
+  case 2: return new ICmpInst(ICmpInst::ICMP_EQ,  LHS, RHS);
+  case 3:
+    if (sign)
+      return new ICmpInst(ICmpInst::ICMP_SGE, LHS, RHS);
+    else
+      return new ICmpInst(ICmpInst::ICMP_UGE, LHS, RHS);
+  case 4:
+    if (sign)
+      return new ICmpInst(ICmpInst::ICMP_SLT, LHS, RHS);
+    else
+      return new ICmpInst(ICmpInst::ICMP_ULT, LHS, RHS);
+  case 5: return new ICmpInst(ICmpInst::ICMP_NE,  LHS, RHS);
+  case 6:
+    if (sign)
+      return new ICmpInst(ICmpInst::ICMP_SLE, LHS, RHS);
+    else
+      return new ICmpInst(ICmpInst::ICMP_ULE, LHS, RHS);
+  case 7: return ConstantInt::getTrue();
+  }
+}
+
+/// getFCmpValue - This is the complement of getFCmpCode, which turns an
+/// opcode and two operands into either an FCmp instruction or a constant
+/// true value. isordered is passed in to determine which kind of predicate
+/// to use in the new fcmp instruction.
+static Value *getFCmpValue(bool isordered, unsigned code, + Value *LHS, Value *RHS) { + switch (code) { + default: assert(0 && "Illegal FCmp code!"); + case 0: + if (isordered) + return new FCmpInst(FCmpInst::FCMP_ORD, LHS, RHS); + else + return new FCmpInst(FCmpInst::FCMP_UNO, LHS, RHS); + case 1: + if (isordered) + return new FCmpInst(FCmpInst::FCMP_OGT, LHS, RHS); + else + return new FCmpInst(FCmpInst::FCMP_UGT, LHS, RHS); + case 2: + if (isordered) + return new FCmpInst(FCmpInst::FCMP_OEQ, LHS, RHS); + else + return new FCmpInst(FCmpInst::FCMP_UEQ, LHS, RHS); + case 3: + if (isordered) + return new FCmpInst(FCmpInst::FCMP_OGE, LHS, RHS); + else + return new FCmpInst(FCmpInst::FCMP_UGE, LHS, RHS); + case 4: + if (isordered) + return new FCmpInst(FCmpInst::FCMP_OLT, LHS, RHS); + else + return new FCmpInst(FCmpInst::FCMP_ULT, LHS, RHS); + case 5: + if (isordered) + return new FCmpInst(FCmpInst::FCMP_ONE, LHS, RHS); + else + return new FCmpInst(FCmpInst::FCMP_UNE, LHS, RHS); + case 6: + if (isordered) + return new FCmpInst(FCmpInst::FCMP_OLE, LHS, RHS); + else + return new FCmpInst(FCmpInst::FCMP_ULE, LHS, RHS); + case 7: return ConstantInt::getTrue(); + } +} + +/// PredicatesFoldable - Return true if both predicates match sign or if at +/// least one of them is an equality comparison (which is signless). +static bool PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) { + return (ICmpInst::isSignedPredicate(p1) == ICmpInst::isSignedPredicate(p2)) || + (ICmpInst::isSignedPredicate(p1) && ICmpInst::isEquality(p2)) || + (ICmpInst::isSignedPredicate(p2) && ICmpInst::isEquality(p1)); +} + +namespace { +// FoldICmpLogical - Implements (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) +struct FoldICmpLogical { + InstCombiner &IC; + Value *LHS, *RHS; + ICmpInst::Predicate pred; + FoldICmpLogical(InstCombiner &ic, ICmpInst *ICI) + : IC(ic), LHS(ICI->getOperand(0)), RHS(ICI->getOperand(1)), + pred(ICI->getPredicate()) {} + bool shouldApply(Value *V) const { + if (ICmpInst *ICI = dyn_cast<ICmpInst>(V)) + if (PredicatesFoldable(pred, ICI->getPredicate())) + return ((ICI->getOperand(0) == LHS && ICI->getOperand(1) == RHS) || + (ICI->getOperand(0) == RHS && ICI->getOperand(1) == LHS)); + return false; + } + Instruction *apply(Instruction &Log) const { + ICmpInst *ICI = cast<ICmpInst>(Log.getOperand(0)); + if (ICI->getOperand(0) != LHS) { + assert(ICI->getOperand(1) == LHS); + ICI->swapOperands(); // Swap the LHS and RHS of the ICmp + } + + ICmpInst *RHSICI = cast<ICmpInst>(Log.getOperand(1)); + unsigned LHSCode = getICmpCode(ICI); + unsigned RHSCode = getICmpCode(RHSICI); + unsigned Code; + switch (Log.getOpcode()) { + case Instruction::And: Code = LHSCode & RHSCode; break; + case Instruction::Or: Code = LHSCode | RHSCode; break; + case Instruction::Xor: Code = LHSCode ^ RHSCode; break; + default: assert(0 && "Illegal logical opcode!"); return 0; + } + + bool isSigned = ICmpInst::isSignedPredicate(RHSICI->getPredicate()) || + ICmpInst::isSignedPredicate(ICI->getPredicate()); + + Value *RV = getICmpValue(isSigned, Code, LHS, RHS); + if (Instruction *I = dyn_cast<Instruction>(RV)) + return I; + // Otherwise, it's a constant boolean value... + return IC.ReplaceInstUsesWith(Log, RV); + } +}; +} // end anonymous namespace + +// OptAndOp - This handles expressions of the form ((val OP C1) & C2). Where +// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is +// guaranteed to be a binary operator. 
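+// Illustrative aside (editorial note, not from the imported LLVM code): the
+// Xor case below relies on 'and' distributing over 'xor'.  A standalone C++
+// check of the identity:
+//
+//   #include <cassert>
+//   #include <cstdint>
+//   int main() {
+//     for (uint32_t x = 0; x < 256; ++x)
+//       assert(((x ^ 5u) & 12u) == ((x & 12u) ^ (5u & 12u)));
+//   }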
+Instruction *InstCombiner::OptAndOp(Instruction *Op, + ConstantInt *OpRHS, + ConstantInt *AndRHS, + BinaryOperator &TheAnd) { + Value *X = Op->getOperand(0); + Constant *Together = 0; + if (!Op->isShift()) + Together = And(AndRHS, OpRHS); + + switch (Op->getOpcode()) { + case Instruction::Xor: + if (Op->hasOneUse()) { + // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) + Instruction *And = BinaryOperator::CreateAnd(X, AndRHS); + InsertNewInstBefore(And, TheAnd); + And->takeName(Op); + return BinaryOperator::CreateXor(And, Together); + } + break; + case Instruction::Or: + if (Together == AndRHS) // (X | C) & C --> C + return ReplaceInstUsesWith(TheAnd, AndRHS); + + if (Op->hasOneUse() && Together != OpRHS) { + // (X | C1) & C2 --> (X | (C1&C2)) & C2 + Instruction *Or = BinaryOperator::CreateOr(X, Together); + InsertNewInstBefore(Or, TheAnd); + Or->takeName(Op); + return BinaryOperator::CreateAnd(Or, AndRHS); + } + break; + case Instruction::Add: + if (Op->hasOneUse()) { + // Adding a one to a single bit bit-field should be turned into an XOR + // of the bit. First thing to check is to see if this AND is with a + // single bit constant. + const APInt& AndRHSV = cast<ConstantInt>(AndRHS)->getValue(); + + // If there is only one bit set... + if (isOneBitSet(cast<ConstantInt>(AndRHS))) { + // Ok, at this point, we know that we are masking the result of the + // ADD down to exactly one bit. If the constant we are adding has + // no bits set below this bit, then we can eliminate the ADD. + const APInt& AddRHS = cast<ConstantInt>(OpRHS)->getValue(); + + // Check to see if any bits below the one bit set in AndRHSV are set. + if ((AddRHS & (AndRHSV-1)) == 0) { + // If not, the only thing that can effect the output of the AND is + // the bit specified by AndRHSV. If that bit is set, the effect of + // the XOR is to toggle the bit. If it is clear, then the ADD has + // no effect. + if ((AddRHS & AndRHSV) == 0) { // Bit is not set, noop + TheAnd.setOperand(0, X); + return &TheAnd; + } else { + // Pull the XOR out of the AND. + Instruction *NewAnd = BinaryOperator::CreateAnd(X, AndRHS); + InsertNewInstBefore(NewAnd, TheAnd); + NewAnd->takeName(Op); + return BinaryOperator::CreateXor(NewAnd, AndRHS); + } + } + } + } + break; + + case Instruction::Shl: { + // We know that the AND will not produce any of the bits shifted in, so if + // the anded constant includes them, clear them now! + // + uint32_t BitWidth = AndRHS->getType()->getBitWidth(); + uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); + APInt ShlMask(APInt::getHighBitsSet(BitWidth, BitWidth-OpRHSVal)); + ConstantInt *CI = ConstantInt::get(AndRHS->getValue() & ShlMask); + + if (CI->getValue() == ShlMask) { + // Masking out bits that the shift already masks + return ReplaceInstUsesWith(TheAnd, Op); // No need for the and. + } else if (CI != AndRHS) { // Reducing bits set in and. + TheAnd.setOperand(1, CI); + return &TheAnd; + } + break; + } + case Instruction::LShr: + { + // We know that the AND will not produce any of the bits shifted in, so if + // the anded constant includes them, clear them now! This only applies to + // unsigned shifts, because a signed shr may bring in set bits! + // + uint32_t BitWidth = AndRHS->getType()->getBitWidth(); + uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); + APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); + ConstantInt *CI = ConstantInt::get(AndRHS->getValue() & ShrMask); + + if (CI->getValue() == ShrMask) { + // Masking out bits that the shift already masks. 
+      return ReplaceInstUsesWith(TheAnd, Op);
+    } else if (CI != AndRHS) {
+      TheAnd.setOperand(1, CI);  // Reduce bits set in and cst.
+      return &TheAnd;
+    }
+    break;
+  }
+  case Instruction::AShr:
+    // Signed shr.
+    // See if this is shifting in some sign extension, then masking it out
+    // with an and.
+    if (Op->hasOneUse()) {
+      uint32_t BitWidth = AndRHS->getType()->getBitWidth();
+      uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth);
+      APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal));
+      Constant *C = ConstantInt::get(AndRHS->getValue() & ShrMask);
+      if (C == AndRHS) {          // Masking out bits shifted in.
+        // (Val ashr C1) & C2 -> (Val lshr C1) & C2
+        // Make the argument unsigned.
+        Value *ShVal = Op->getOperand(0);
+        ShVal = InsertNewInstBefore(
+            BinaryOperator::CreateLShr(ShVal, OpRHS,
+                                       Op->getName()), TheAnd);
+        return BinaryOperator::CreateAnd(ShVal, AndRHS, TheAnd.getName());
+      }
+    }
+    break;
+  }
+  return 0;
+}
+
+
+/// InsertRangeTest - Emit a computation of: (V >= Lo && V < Hi) if Inside is
+/// true, otherwise (V < Lo || V >= Hi).  In practice, we emit the more
+/// efficient (V-Lo) <u Hi-Lo.  This method expects that Lo <= Hi.  isSigned
+/// indicates whether to treat V, Lo, and Hi as signed or not.  IB is the
+/// location to insert new instructions.
+Instruction *InstCombiner::InsertRangeTest(Value *V, Constant *Lo, Constant *Hi,
+                                           bool isSigned, bool Inside,
+                                           Instruction &IB) {
+  assert(cast<ConstantInt>(ConstantExpr::getICmp((isSigned ?
+            ICmpInst::ICMP_SLE:ICmpInst::ICMP_ULE), Lo, Hi))->getZExtValue() &&
+         "Lo is not <= Hi in range emission code!");
+
+  if (Inside) {
+    if (Lo == Hi)  // Trivially false.
+      return new ICmpInst(ICmpInst::ICMP_NE, V, V);
+
+    // V >= Min && V < Hi --> V < Hi
+    if (cast<ConstantInt>(Lo)->isMinValue(isSigned)) {
+      ICmpInst::Predicate pred = (isSigned ?
+        ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT);
+      return new ICmpInst(pred, V, Hi);
+    }
+
+    // Emit V-Lo <u Hi-Lo
+    Constant *NegLo = ConstantExpr::getNeg(Lo);
+    Instruction *Add = BinaryOperator::CreateAdd(V, NegLo, V->getName()+".off");
+    InsertNewInstBefore(Add, IB);
+    Constant *UpperBound = ConstantExpr::getAdd(NegLo, Hi);
+    return new ICmpInst(ICmpInst::ICMP_ULT, Add, UpperBound);
+  }
+
+  if (Lo == Hi)  // Trivially true.
+    return new ICmpInst(ICmpInst::ICMP_EQ, V, V);
+
+  // V < Min || V >= Hi -> V > Hi-1
+  Hi = SubOne(cast<ConstantInt>(Hi));
+  if (cast<ConstantInt>(Lo)->isMinValue(isSigned)) {
+    ICmpInst::Predicate pred = (isSigned ?
+        ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
+    return new ICmpInst(pred, V, Hi);
+  }
+
+  // Emit V-Lo >u Hi-1-Lo
+  // Note that Hi has already had one subtracted from it, above.
+  ConstantInt *NegLo = cast<ConstantInt>(ConstantExpr::getNeg(Lo));
+  Instruction *Add = BinaryOperator::CreateAdd(V, NegLo, V->getName()+".off");
+  InsertNewInstBefore(Add, IB);
+  Constant *LowerBound = ConstantExpr::getAdd(NegLo, Hi);
+  return new ICmpInst(ICmpInst::ICMP_UGT, Add, LowerBound);
+}
+
+// isRunOfOnes - Returns true iff Val consists of one contiguous run of 1s with
+// any number of 0s on either side.  The 1s are allowed to wrap from LSB to
+// MSB, so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs.  0x0F0F0000 is
+// not, since all 1s are not contiguous.
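+// Editorial illustration (not from the imported comment): for a 32-bit
+// Val == 0x0000FFF0 the run occupies bits 4..15, and MB/ME appear to come
+// back as 5 and 16 (1-based, inclusive).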
+static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) { + const APInt& V = Val->getValue(); + uint32_t BitWidth = Val->getType()->getBitWidth(); + if (!APIntOps::isShiftedMask(BitWidth, V)) return false; + + // look for the first zero bit after the run of ones + MB = BitWidth - ((V - 1) ^ V).countLeadingZeros(); + // look for the first non-zero bit + ME = V.getActiveBits(); + return true; +} + +/// FoldLogicalPlusAnd - This is part of an expression (LHS +/- RHS) & Mask, +/// where isSub determines whether the operator is a sub. If we can fold one of +/// the following xforms: +/// +/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask +/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// +/// return (A +/- B). +/// +Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, + ConstantInt *Mask, bool isSub, + Instruction &I) { + Instruction *LHSI = dyn_cast<Instruction>(LHS); + if (!LHSI || LHSI->getNumOperands() != 2 || + !isa<ConstantInt>(LHSI->getOperand(1))) return 0; + + ConstantInt *N = cast<ConstantInt>(LHSI->getOperand(1)); + + switch (LHSI->getOpcode()) { + default: return 0; + case Instruction::And: + if (And(N, Mask) == Mask) { + // If the AndRHS is a power of two minus one (0+1+), this is simple. + if ((Mask->getValue().countLeadingZeros() + + Mask->getValue().countPopulation()) == + Mask->getValue().getBitWidth()) + break; + + // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+ + // part, we don't need any explicit masks to take them out of A. If that + // is all N is, ignore it. + uint32_t MB = 0, ME = 0; + if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive + uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth(); + APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1)); + if (MaskedValueIsZero(RHS, Mask)) + break; + } + } + return 0; + case Instruction::Or: + case Instruction::Xor: + // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0 + if ((Mask->getValue().countLeadingZeros() + + Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth() + && And(N, Mask)->isZero()) + break; + return 0; + } + + Instruction *New; + if (isSub) + New = BinaryOperator::CreateSub(LHSI->getOperand(0), RHS, "fold"); + else + New = BinaryOperator::CreateAdd(LHSI->getOperand(0), RHS, "fold"); + return InsertNewInstBefore(New, I); +} + +/// FoldAndOfICmps - Fold (icmp)&(icmp) if possible. +Instruction *InstCombiner::FoldAndOfICmps(Instruction &I, + ICmpInst *LHS, ICmpInst *RHS) { + Value *Val, *Val2; + ConstantInt *LHSCst, *RHSCst; + ICmpInst::Predicate LHSCC, RHSCC; + + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). + if (!match(LHS, m_ICmp(LHSCC, m_Value(Val), m_ConstantInt(LHSCst))) || + !match(RHS, m_ICmp(RHSCC, m_Value(Val2), m_ConstantInt(RHSCst)))) + return 0; + + // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) + // where C is a power of 2 + if (LHSCst == RHSCst && LHSCC == RHSCC && LHSCC == ICmpInst::ICMP_ULT && + LHSCst->getValue().isPowerOf2()) { + Instruction *NewOr = BinaryOperator::CreateOr(Val, Val2); + InsertNewInstBefore(NewOr, I); + return new ICmpInst(LHSCC, NewOr, LHSCst); + } + + // From here on, we only handle: + // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. + if (Val != Val2) return 0; + + // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. 
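+  // Illustrative aside (editorial note, not from the imported LLVM code): a
+  // later case in this function folds two strict compares into one unsigned
+  // range test, e.g. (X u> 13 & X u< 15) --> ((X - 14) u< 1):
+  //
+  //   #include <cassert>
+  //   #include <cstdint>
+  //   int main() {
+  //     for (uint32_t x = 0; x < 32; ++x)
+  //       assert(((x > 13) && (x < 15)) == ((x - 14u) < 1u));
+  //   }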
+ if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE ||
+ RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE ||
+ LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE ||
+ RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE)
+ return 0;
+
+ // We can't fold (ugt x, C) & (sgt x, C2).
+ if (!PredicatesFoldable(LHSCC, RHSCC))
+ return 0;
+
+ // Ensure that the larger constant is on the RHS.
+ bool ShouldSwap;
+ if (ICmpInst::isSignedPredicate(LHSCC) ||
+ (ICmpInst::isEquality(LHSCC) &&
+ ICmpInst::isSignedPredicate(RHSCC)))
+ ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue());
+ else
+ ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue());
+
+ if (ShouldSwap) {
+ std::swap(LHS, RHS);
+ std::swap(LHSCst, RHSCst);
+ std::swap(LHSCC, RHSCC);
+ }
+
+ // At this point, we know we have two icmp instructions
+ // comparing a value against two constants and and'ing the result
+ // together. Because of the above check, we know that we only have
+ // icmp eq, icmp ne, icmp [su]lt, and icmp [su]gt here. We also know
+ // (from the FoldICmpLogical check above), that the two constants
+ // are not equal and that the larger constant is on the RHS.
+ assert(LHSCst != RHSCst && "Compares not folded above?");
+
+ switch (LHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ:
+ switch (RHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X == 13 & X == 15) -> false
+ case ICmpInst::ICMP_UGT: // (X == 13 & X > 15) -> false
+ case ICmpInst::ICMP_SGT: // (X == 13 & X > 15) -> false
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13
+ case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13
+ case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13
+ return ReplaceInstUsesWith(I, LHS);
+ }
+ case ICmpInst::ICMP_NE:
+ switch (RHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_ULT:
+ if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13
+ return new ICmpInst(ICmpInst::ICMP_ULT, Val, LHSCst);
+ break; // (X != 13 & X u< 15) -> no change
+ case ICmpInst::ICMP_SLT:
+ if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13
+ return new ICmpInst(ICmpInst::ICMP_SLT, Val, LHSCst);
+ break; // (X != 13 & X s< 15) -> no change
+ case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15
+ case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15
+ case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15
+ return ReplaceInstUsesWith(I, RHS);
+ case ICmpInst::ICMP_NE:
+ if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1
+ Constant *AddCST = ConstantExpr::getNeg(LHSCst);
+ Instruction *Add = BinaryOperator::CreateAdd(Val, AddCST,
+ Val->getName()+".off");
+ InsertNewInstBefore(Add, I);
+ return new ICmpInst(ICmpInst::ICMP_UGT, Add,
+ ConstantInt::get(Add->getType(), 1));
+ }
+ break; // (X != 13 & X != 15) -> no change
+ }
+ break;
+ case ICmpInst::ICMP_ULT:
+ switch (RHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false
+ case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ case ICmpInst::ICMP_SGT: // (X u< 13 & X s> 15) -> no change
+ break;
+ case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13
+ case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13
+ return ReplaceInstUsesWith(I, LHS);
+ case ICmpInst::ICMP_SLT: //
(X u< 13 & X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SLT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s< 13 & X == 15) -> false + case ICmpInst::ICMP_SGT: // (X s< 13 & X s> 15) -> false + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + case ICmpInst::ICMP_UGT: // (X s< 13 & X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13 + case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_ULT: // (X s< 13 & X u< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_UGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_SGT: // (X u> 13 & X s> 15) -> no change + break; + case ICmpInst::ICMP_NE: + if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14 + return new ICmpInst(LHSCC, Val, RHSCst); + break; // (X u> 13 & X != 15) -> no change + case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 + return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, false, true, I); + case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_UGT: // (X s> 13 & X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: + if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14 + return new ICmpInst(LHSCC, Val, RHSCst); + break; // (X s> 13 & X != 15) -> no change + case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 + return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, true, true, I); + case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change + break; + } + break; + } + + return 0; +} + + +Instruction *InstCombiner::visitAnd(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (isa<UndefValue>(Op1)) // X & undef -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // and X, X = X + if (Op0 == Op1) + return ReplaceInstUsesWith(I, Op1); + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + if (!isa<VectorType>(I.getType())) { + if (SimplifyDemandedInstructionBits(I)) + return &I; + } else { + if (ConstantVector *CP = dyn_cast<ConstantVector>(Op1)) { + if (CP->isAllOnesValue()) // X & <-1,-1> -> X + return ReplaceInstUsesWith(I, I.getOperand(0)); + } else if (isa<ConstantAggregateZero>(Op1)) { + return ReplaceInstUsesWith(I, Op1); // X & <0,0> -> <0,0> + } + } + + if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { + const APInt& AndRHSMask = AndRHS->getValue(); + APInt NotAndRHS(~AndRHSMask); + + // Optimize a variety of ((val OP C1) & C2) combinations... + if (isa<BinaryOperator>(Op0)) { + Instruction *Op0I = cast<Instruction>(Op0); + Value *Op0LHS = Op0I->getOperand(0); + Value *Op0RHS = Op0I->getOperand(1); + switch (Op0I->getOpcode()) { + case Instruction::Xor: + case Instruction::Or: + // If the mask is only needed on one incoming arm, push it up. 
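+ // Sketch of the intent (names and mask illustrative): ((X ^ Y) & 0xFF)
+ // becomes X ^ (Y & 0xFF) when X is known to have no bits outside 0xFF,
+ // since masking X would then change nothing.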
+ if (Op0I->hasOneUse()) { + if (MaskedValueIsZero(Op0LHS, NotAndRHS)) { + // Not masking anything out for the LHS, move to RHS. + Instruction *NewRHS = BinaryOperator::CreateAnd(Op0RHS, AndRHS, + Op0RHS->getName()+".masked"); + InsertNewInstBefore(NewRHS, I); + return BinaryOperator::Create( + cast<BinaryOperator>(Op0I)->getOpcode(), Op0LHS, NewRHS); + } + if (!isa<Constant>(Op0RHS) && + MaskedValueIsZero(Op0RHS, NotAndRHS)) { + // Not masking anything out for the RHS, move to LHS. + Instruction *NewLHS = BinaryOperator::CreateAnd(Op0LHS, AndRHS, + Op0LHS->getName()+".masked"); + InsertNewInstBefore(NewLHS, I); + return BinaryOperator::Create( + cast<BinaryOperator>(Op0I)->getOpcode(), NewLHS, Op0RHS); + } + } + + break; + case Instruction::Add: + // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS. + // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 + // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 + if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I)) + return BinaryOperator::CreateAnd(V, AndRHS); + if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I)) + return BinaryOperator::CreateAnd(V, AndRHS); // Add commutes + break; + + case Instruction::Sub: + // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS. + // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 + // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 + if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I)) + return BinaryOperator::CreateAnd(V, AndRHS); + + // (A - N) & AndRHS -> -N & AndRHS iff A&AndRHS==0 and AndRHS + // has 1's for all bits that the subtraction with A might affect. + if (Op0I->hasOneUse()) { + uint32_t BitWidth = AndRHSMask.getBitWidth(); + uint32_t Zeros = AndRHSMask.countLeadingZeros(); + APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros); + + ConstantInt *A = dyn_cast<ConstantInt>(Op0LHS); + if (!(A && A->isZero()) && // avoid infinite recursion. + MaskedValueIsZero(Op0LHS, Mask)) { + Instruction *NewNeg = BinaryOperator::CreateNeg(Op0RHS); + InsertNewInstBefore(NewNeg, I); + return BinaryOperator::CreateAnd(NewNeg, AndRHS); + } + } + break; + + case Instruction::Shl: + case Instruction::LShr: + // (1 << x) & 1 --> zext(x == 0) + // (1 >> x) & 1 --> zext(x == 0) + if (AndRHSMask == 1 && Op0LHS == AndRHS) { + Instruction *NewICmp = new ICmpInst(ICmpInst::ICMP_EQ, Op0RHS, + Constant::getNullValue(I.getType())); + InsertNewInstBefore(NewICmp, I); + return new ZExtInst(NewICmp, I.getType()); + } + break; + } + + if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) + if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I)) + return Res; + } else if (CastInst *CI = dyn_cast<CastInst>(Op0)) { + // If this is an integer truncation or change from signed-to-unsigned, and + // if the source is an and/or with immediate, transform it. This + // frequently occurs for bitfield accesses. + if (Instruction *CastOp = dyn_cast<Instruction>(CI->getOperand(0))) { + if ((isa<TruncInst>(CI) || isa<BitCastInst>(CI)) && + CastOp->getNumOperands() == 2) + if (ConstantInt *AndCI = dyn_cast<ConstantInt>(CastOp->getOperand(1))) { + if (CastOp->getOpcode() == Instruction::And) { + // Change: and (cast (and X, C1) to T), C2 + // into : and (cast X to T), trunc_or_bitcast(C1)&C2 + // This will fold the two constants together, which may allow + // other simplifications. 
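+ // Worked instance with hypothetical constants:
+ // and (trunc (and i32 %X, 60) to i8), 15
+ // --> and (trunc i32 %X to i8), 12, because trunc(60) & 15 == 12.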
+ Instruction *NewCast = CastInst::CreateTruncOrBitCast( + CastOp->getOperand(0), I.getType(), + CastOp->getName()+".shrunk"); + NewCast = InsertNewInstBefore(NewCast, I); + // trunc_or_bitcast(C1)&C2 + Constant *C3 = ConstantExpr::getTruncOrBitCast(AndCI,I.getType()); + C3 = ConstantExpr::getAnd(C3, AndRHS); + return BinaryOperator::CreateAnd(NewCast, C3); + } else if (CastOp->getOpcode() == Instruction::Or) { + // Change: and (cast (or X, C1) to T), C2 + // into : trunc(C1)&C2 iff trunc(C1)&C2 == C2 + Constant *C3 = ConstantExpr::getTruncOrBitCast(AndCI,I.getType()); + if (ConstantExpr::getAnd(C3, AndRHS) == AndRHS) // trunc(C1)&C2 + return ReplaceInstUsesWith(I, AndRHS); + } + } + } + } + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa<PHINode>(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } + + Value *Op0NotVal = dyn_castNotVal(Op0); + Value *Op1NotVal = dyn_castNotVal(Op1); + + if (Op0NotVal == Op1 || Op1NotVal == Op0) // A & ~A == ~A & A == 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // (~A & ~B) == (~(A | B)) - De Morgan's Law + if (Op0NotVal && Op1NotVal && isOnlyUse(Op0) && isOnlyUse(Op1)) { + Instruction *Or = BinaryOperator::CreateOr(Op0NotVal, Op1NotVal, + I.getName()+".demorgan"); + InsertNewInstBefore(Or, I); + return BinaryOperator::CreateNot(Or); + } + + { + Value *A = 0, *B = 0, *C = 0, *D = 0; + if (match(Op0, m_Or(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) // (A | ?) & A --> A + return ReplaceInstUsesWith(I, Op1); + + // (A|B) & ~(A&B) -> A^B + if (match(Op1, m_Not(m_And(m_Value(C), m_Value(D))))) { + if ((A == C && B == D) || (A == D && B == C)) + return BinaryOperator::CreateXor(A, B); + } + } + + if (match(Op1, m_Or(m_Value(A), m_Value(B)))) { + if (A == Op0 || B == Op0) // A & (A | ?) 
--> A + return ReplaceInstUsesWith(I, Op0); + + // ~(A&B) & (A|B) -> A^B + if (match(Op0, m_Not(m_And(m_Value(C), m_Value(D))))) { + if ((A == C && B == D) || (A == D && B == C)) + return BinaryOperator::CreateXor(A, B); + } + } + + if (Op0->hasOneUse() && + match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1) { // (A^B)&A -> A&(A^B) + I.swapOperands(); // Simplify below + std::swap(Op0, Op1); + } else if (B == Op1) { // (A^B)&B -> B&(B^A) + cast<BinaryOperator>(Op0)->swapOperands(); + I.swapOperands(); // Simplify below + std::swap(Op0, Op1); + } + } + + if (Op1->hasOneUse() && + match(Op1, m_Xor(m_Value(A), m_Value(B)))) { + if (B == Op0) { // B&(A^B) -> B&(B^A) + cast<BinaryOperator>(Op1)->swapOperands(); + std::swap(A, B); + } + if (A == Op0) { // A&(A^B) -> A & ~B + Instruction *NotB = BinaryOperator::CreateNot(B, "tmp"); + InsertNewInstBefore(NotB, I); + return BinaryOperator::CreateAnd(A, NotB); + } + } + + // (A&((~A)|B)) -> A&B + if (match(Op0, m_Or(m_Not(m_Specific(Op1)), m_Value(A))) || + match(Op0, m_Or(m_Value(A), m_Not(m_Specific(Op1))))) + return BinaryOperator::CreateAnd(A, Op1); + if (match(Op1, m_Or(m_Not(m_Specific(Op0)), m_Value(A))) || + match(Op1, m_Or(m_Value(A), m_Not(m_Specific(Op0))))) + return BinaryOperator::CreateAnd(A, Op0); + } + + if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1)) { + // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) + if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS))) + return R; + + if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0)) + if (Instruction *Res = FoldAndOfICmps(I, LHS, RHS)) + return Res; + } + + // fold (and (cast A), (cast B)) -> (cast (and A, B)) + if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) + if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) + if (Op0C->getOpcode() == Op1C->getOpcode()) { // same cast kind ? + const Type *SrcTy = Op0C->getOperand(0)->getType(); + if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && + // Only do this if the casts both really cause code to be generated. + ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), + I.getType(), TD) && + ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), + I.getType(), TD)) { + Instruction *NewOp = BinaryOperator::CreateAnd(Op0C->getOperand(0), + Op1C->getOperand(0), + I.getName()); + InsertNewInstBefore(NewOp, I); + return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); + } + } + + // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts. + if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) { + if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0)) + if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() && + SI0->getOperand(1) == SI1->getOperand(1) && + (SI0->hasOneUse() || SI1->hasOneUse())) { + Instruction *NewOp = + InsertNewInstBefore(BinaryOperator::CreateAnd(SI0->getOperand(0), + SI1->getOperand(0), + SI0->getName()), I); + return BinaryOperator::Create(SI1->getOpcode(), NewOp, + SI1->getOperand(1)); + } + } + + // If and'ing two fcmp, try combine them into one. + if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) { + if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) { + if (LHS->getPredicate() == FCmpInst::FCMP_ORD && + RHS->getPredicate() == FCmpInst::FCMP_ORD) { + // (fcmp ord x, c) & (fcmp ord y, c) -> (fcmp ord x, y) + if (ConstantFP *LHSC = dyn_cast<ConstantFP>(LHS->getOperand(1))) + if (ConstantFP *RHSC = dyn_cast<ConstantFP>(RHS->getOperand(1))) { + // If either of the constants are nans, then the whole thing returns + // false. 
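+ // e.g. (fcmp ord double %x, 1.0) & (fcmp ord double %y, 2.0) becomes
+ // fcmp ord %x, %y, while a NaN constant on either side collapses the
+ // whole expression to false (constants illustrative).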
+ if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN())
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ return new FCmpInst(FCmpInst::FCMP_ORD, LHS->getOperand(0),
+ RHS->getOperand(0));
+ }
+ } else {
+ Value *Op0LHS, *Op0RHS, *Op1LHS, *Op1RHS;
+ FCmpInst::Predicate Op0CC, Op1CC;
+ if (match(Op0, m_FCmp(Op0CC, m_Value(Op0LHS), m_Value(Op0RHS))) &&
+ match(Op1, m_FCmp(Op1CC, m_Value(Op1LHS), m_Value(Op1RHS)))) {
+ if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) {
+ // Swap RHS operands to match LHS.
+ Op1CC = FCmpInst::getSwappedPredicate(Op1CC);
+ std::swap(Op1LHS, Op1RHS);
+ }
+ if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) {
+ // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y).
+ if (Op0CC == Op1CC)
+ return new FCmpInst((FCmpInst::Predicate)Op0CC, Op0LHS, Op0RHS);
+ else if (Op0CC == FCmpInst::FCMP_FALSE ||
+ Op1CC == FCmpInst::FCMP_FALSE)
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ else if (Op0CC == FCmpInst::FCMP_TRUE)
+ return ReplaceInstUsesWith(I, Op1);
+ else if (Op1CC == FCmpInst::FCMP_TRUE)
+ return ReplaceInstUsesWith(I, Op0);
+ bool Op0Ordered;
+ bool Op1Ordered;
+ unsigned Op0Pred = getFCmpCode(Op0CC, Op0Ordered);
+ unsigned Op1Pred = getFCmpCode(Op1CC, Op1Ordered);
+ if (Op1Pred == 0) {
+ std::swap(Op0, Op1);
+ std::swap(Op0Pred, Op1Pred);
+ std::swap(Op0Ordered, Op1Ordered);
+ }
+ if (Op0Pred == 0) {
+ // uno && ueq -> uno && (uno || eq) -> ueq
+ // ord && olt -> ord && (ord && lt) -> olt
+ if (Op0Ordered == Op1Ordered)
+ return ReplaceInstUsesWith(I, Op1);
+ // uno && oeq -> uno && (ord && eq) -> false
+ // uno && ord -> false
+ if (!Op0Ordered)
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ // ord && ueq -> ord && (uno || eq) -> oeq
+ return cast<Instruction>(getFCmpValue(true, Op1Pred,
+ Op0LHS, Op0RHS));
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return Changed ? &I : 0;
+}
+
+/// CollectBSwapParts - Analyze the specified subexpression and see if it is
+/// capable of providing pieces of a bswap. The subexpression provides pieces
+/// of a bswap if it is proven that each of the non-zero bytes in the output of
+/// the expression came from the corresponding "byte swapped" byte in some other
+/// value. For example, if the current subexpression is "(shl i32 %X, 24)" then
+/// we know that the expression deposits the low byte of %X into the high byte
+/// of the bswap result and that all other bytes are zero. If this expression
+/// is accepted, the high byte of ByteValues is set to X to indicate a correct
+/// match.
+///
+/// This function returns true if the match was unsuccessful and false if it
+/// succeeded. On entry to the function the "OverallLeftShift" is a signed
+/// integer value indicating the number of bytes that the subexpression is
+/// later shifted. For example, if the expression is later right shifted by
+/// 16 bits, the OverallLeftShift value would be -2 on entry. This is used to
+/// specify which byte of ByteValues is actually being set.
+///
+/// Similarly, ByteMask is a bitmask where a bit is clear if its corresponding
+/// byte is masked to zero by a user. For example, in (X & 255), X will be
+/// processed with a bytemask of 1. Because bytemask is 32-bits, this limits
+/// this function to working on up to 32-byte (256 bit) values. ByteMask is
+/// always in the local (OverallLeftShift) coordinate space.
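+/// As a concrete illustration: for a plain i32 bswap, ByteValues has four
+/// entries and the caller starts with ByteMask = 0xF, i.e. all four result
+/// bytes still demanded.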
+/// +static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask, + SmallVector<Value*, 8> &ByteValues) { + if (Instruction *I = dyn_cast<Instruction>(V)) { + // If this is an or instruction, it may be an inner node of the bswap. + if (I->getOpcode() == Instruction::Or) { + return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, + ByteValues) || + CollectBSwapParts(I->getOperand(1), OverallLeftShift, ByteMask, + ByteValues); + } + + // If this is a logical shift by a constant multiple of 8, recurse with + // OverallLeftShift and ByteMask adjusted. + if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) { + unsigned ShAmt = + cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U); + // Ensure the shift amount is defined and of a byte value. + if ((ShAmt & 7) || (ShAmt > 8*ByteValues.size())) + return true; + + unsigned ByteShift = ShAmt >> 3; + if (I->getOpcode() == Instruction::Shl) { + // X << 2 -> collect(X, +2) + OverallLeftShift += ByteShift; + ByteMask >>= ByteShift; + } else { + // X >>u 2 -> collect(X, -2) + OverallLeftShift -= ByteShift; + ByteMask <<= ByteShift; + ByteMask &= (~0U >> (32-ByteValues.size())); + } + + if (OverallLeftShift >= (int)ByteValues.size()) return true; + if (OverallLeftShift <= -(int)ByteValues.size()) return true; + + return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, + ByteValues); + } + + // If this is a logical 'and' with a mask that clears bytes, clear the + // corresponding bytes in ByteMask. + if (I->getOpcode() == Instruction::And && + isa<ConstantInt>(I->getOperand(1))) { + // Scan every byte of the and mask, seeing if the byte is either 0 or 255. + unsigned NumBytes = ByteValues.size(); + APInt Byte(I->getType()->getPrimitiveSizeInBits(), 255); + const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue(); + + for (unsigned i = 0; i != NumBytes; ++i, Byte <<= 8) { + // If this byte is masked out by a later operation, we don't care what + // the and mask is. + if ((ByteMask & (1 << i)) == 0) + continue; + + // If the AndMask is all zeros for this byte, clear the bit. + APInt MaskB = AndMask & Byte; + if (MaskB == 0) { + ByteMask &= ~(1U << i); + continue; + } + + // If the AndMask is not all ones for this byte, it's not a bytezap. + if (MaskB != Byte) + return true; + + // Otherwise, this byte is kept. + } + + return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, + ByteValues); + } + } + + // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be + // the input value to the bswap. Some observations: 1) if more than one byte + // is demanded from this input, then it could not be successfully assembled + // into a byteswap. At least one of the two bytes would not be aligned with + // their ultimate destination. + if (!isPowerOf2_32(ByteMask)) return true; + unsigned InputByteNo = CountTrailingZeros_32(ByteMask); + + // 2) The input and ultimate destinations must line up: if byte 3 of an i32 + // is demanded, it needs to go into byte 0 of the result. This means that the + // byte needs to be shifted until it lands in the right byte bucket. The + // shift amount depends on the position: if the byte is coming from the high + // part of the value (e.g. byte 3) then it must be shifted right. If from the + // low part, it must be shifted left. 
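+ // Concretely, for an i32 (illustrative): input byte 0 must land in result
+ // byte 3 and input byte 3 in result byte 0; the mirror check below rejects
+ // any other placement.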
+ unsigned DestByteNo = InputByteNo + OverallLeftShift;
+ // Whether the byte comes from the high half or the low half of the value,
+ // the required mirror condition is the same, so one check suffices.
+ if (ByteValues.size()-1-DestByteNo != InputByteNo)
+ return true;
+
+ // If the destination byte value is already defined, the values are or'd
+ // together, which isn't a bswap (unless it's an or of the same bits).
+ if (ByteValues[DestByteNo] && ByteValues[DestByteNo] != V)
+ return true;
+ ByteValues[DestByteNo] = V;
+ return false;
+}
+
+/// MatchBSwap - Given an OR instruction, check to see if this is a bswap idiom.
+/// If so, insert the new bswap intrinsic and return it.
+Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) {
+ const IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
+ if (!ITy || ITy->getBitWidth() % 16 ||
+ // ByteMask only allows up to 32-byte values.
+ ITy->getBitWidth() > 32*8)
+ return 0; // Can only bswap pairs of bytes. Can't do vectors.
+
+ /// ByteValues - For each byte of the result, we keep track of which value
+ /// defines each byte.
+ SmallVector<Value*, 8> ByteValues;
+ ByteValues.resize(ITy->getBitWidth()/8);
+
+ // Try to find all the pieces corresponding to the bswap.
+ uint32_t ByteMask = ~0U >> (32-ByteValues.size());
+ if (CollectBSwapParts(&I, 0, ByteMask, ByteValues))
+ return 0;
+
+ // Check to see if all of the bytes come from the same value.
+ Value *V = ByteValues[0];
+ if (V == 0) return 0; // Didn't find a byte? Must be zero.
+
+ // Check to make sure that all of the bytes come from the same value.
+ for (unsigned i = 1, e = ByteValues.size(); i != e; ++i)
+ if (ByteValues[i] != V)
+ return 0;
+ const Type *Tys[] = { ITy };
+ Module *M = I.getParent()->getParent()->getParent();
+ Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, Tys, 1);
+ return CallInst::Create(F, V);
+}
+
+/// MatchSelectFromAndOr - We have an expression of the form (A&C)|(B&D). Check
+/// if A is (cond?-1:0) and either B or D is ~(cond?-1:0) or (cond?0:-1), then
+/// we can simplify this expression to "cond ? C : D" (or "cond ? C : B").
+static Instruction *MatchSelectFromAndOr(Value *A, Value *B,
+ Value *C, Value *D) {
+ // If A is not a select of -1/0, this cannot match.
+ Value *Cond = 0;
+ if (!match(A, m_SelectCst<-1, 0>(m_Value(Cond))))
+ return 0;
+
+ // ((cond?-1:0)&C) | (B&(cond?0:-1)) -> cond ? C : B.
+ if (match(D, m_SelectCst<0, -1>(m_Specific(Cond))))
+ return SelectInst::Create(Cond, C, B);
+ if (match(D, m_Not(m_SelectCst<-1, 0>(m_Specific(Cond)))))
+ return SelectInst::Create(Cond, C, B);
+ // ((cond?-1:0)&C) | ((cond?0:-1)&D) -> cond ? C : D.
+ if (match(B, m_SelectCst<0, -1>(m_Specific(Cond))))
+ return SelectInst::Create(Cond, C, D);
+ if (match(B, m_Not(m_SelectCst<-1, 0>(m_Specific(Cond)))))
+ return SelectInst::Create(Cond, C, D);
+ return 0;
+}
+
+/// FoldOrOfICmps - Fold (icmp)|(icmp) if possible.
+Instruction *InstCombiner::FoldOrOfICmps(Instruction &I,
+ ICmpInst *LHS, ICmpInst *RHS) {
+ Value *Val, *Val2;
+ ConstantInt *LHSCst, *RHSCst;
+ ICmpInst::Predicate LHSCC, RHSCC;
+
+ // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
+ if (!match(LHS, m_ICmp(LHSCC, m_Value(Val), m_ConstantInt(LHSCst))) ||
+ !match(RHS, m_ICmp(RHSCC, m_Value(Val2), m_ConstantInt(RHSCst))))
+ return 0;
+
+ // From here on, we only handle:
+ // (icmp1 A, C1) | (icmp2 A, C2) --> something simpler.
+ if (Val != Val2) return 0;
+
+ // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere.
+ if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE ||
+ RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE ||
+ LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE ||
+ RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE)
+ return 0;
+
+ // We can't fold (ugt x, C) | (sgt x, C2).
+ if (!PredicatesFoldable(LHSCC, RHSCC))
+ return 0;
+
+ // Ensure that the larger constant is on the RHS.
+ bool ShouldSwap;
+ if (ICmpInst::isSignedPredicate(LHSCC) ||
+ (ICmpInst::isEquality(LHSCC) &&
+ ICmpInst::isSignedPredicate(RHSCC)))
+ ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue());
+ else
+ ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue());
+
+ if (ShouldSwap) {
+ std::swap(LHS, RHS);
+ std::swap(LHSCst, RHSCst);
+ std::swap(LHSCC, RHSCC);
+ }
+
+ // At this point, we know we have two icmp instructions
+ // comparing a value against two constants and or'ing the result
+ // together. Because of the above check, we know that we only have
+ // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the
+ // FoldICmpLogical check above), that the two constants are not
+ // equal.
+ assert(LHSCst != RHSCst && "Compares not folded above?");
+
+ switch (LHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ:
+ switch (RHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ:
+ if (LHSCst == SubOne(RHSCst)) { // (X == 13 | X == 14) -> X-13 <u 2
+ Constant *AddCST = ConstantExpr::getNeg(LHSCst);
+ Instruction *Add = BinaryOperator::CreateAdd(Val, AddCST,
+ Val->getName()+".off");
+ InsertNewInstBefore(Add, I);
+ AddCST = Subtract(AddOne(RHSCst), LHSCst);
+ return new ICmpInst(ICmpInst::ICMP_ULT, Add, AddCST);
+ }
+ break; // (X == 13 | X == 15) -> no change
+ case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change
+ case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change
+ break;
+ case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15
+ case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15
+ case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15
+ return ReplaceInstUsesWith(I, RHS);
+ }
+ break;
+ case ICmpInst::ICMP_NE:
+ switch (RHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13
+ case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13
+ case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13
+ return ReplaceInstUsesWith(I, LHS);
+ case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true
+ case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true
+ case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true
+ return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+ }
+ break;
+ case ICmpInst::ICMP_ULT:
+ switch (RHSCC) {
+ default: assert(0 && "Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change
+ break;
+ case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2
+ // If RHSCst is [us]MAXINT, it is always false. Not handling
+ // this can cause overflow.
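+ // e.g. (X u< 13 | X u> 0xFFFFFFFF) for i32: the second compare can never
+ // be true, so the whole 'or' is just the LHS (constants illustrative).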
+ if (RHSCst->isMaxValue(false)) + return ReplaceInstUsesWith(I, LHS); + return InsertRangeTest(Val, LHSCst, AddOne(RHSCst), false, false, I); + case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_SLT: // (X u< 13 | X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SLT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change + break; + case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 + // If RHSCst is [us]MAXINT, it is always false. Not handling + // this can cause overflow. + if (RHSCst->isMaxValue(true)) + return ReplaceInstUsesWith(I, LHS); + return InsertRangeTest(Val, LHSCst, AddOne(RHSCst), true, false, I); + case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15 + return ReplaceInstUsesWith(I, RHS); + case ICmpInst::ICMP_ULT: // (X s< 13 | X u< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_UGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13 + case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_SGT: // (X u> 13 | X s> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true + case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + case ICmpInst::ICMP_SLT: // (X u> 13 | X s< 15) -> no change + break; + } + break; + case ICmpInst::ICMP_SGT: + switch (RHSCC) { + default: assert(0 && "Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13 + case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13 + return ReplaceInstUsesWith(I, LHS); + case ICmpInst::ICMP_UGT: // (X s> 13 | X u> 15) -> no change + break; + case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true + case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + case ICmpInst::ICMP_ULT: // (X s> 13 | X u< 15) -> no change + break; + } + break; + } + return 0; +} + +/// FoldOrWithConstants - This helper function folds: +/// +/// ((A | B) & C1) | (B & C2) +/// +/// into: +/// +/// (A & C1) | B +/// +/// when the XOR of the two constants is "all ones" (-1). +Instruction *InstCombiner::FoldOrWithConstants(BinaryOperator &I, Value *Op, + Value *A, Value *B, Value *C) { + ConstantInt *CI1 = dyn_cast<ConstantInt>(C); + if (!CI1) return 0; + + Value *V1 = 0; + ConstantInt *CI2 = 0; + if (!match(Op, m_And(m_Value(V1), m_ConstantInt(CI2)))) return 0; + + APInt Xor = CI1->getValue() ^ CI2->getValue(); + if (!Xor.isAllOnesValue()) return 0; + + if (V1 == A || V1 == B) { + Instruction *NewOp = + InsertNewInstBefore(BinaryOperator::CreateAnd((V1 == A) ? 
B : A, CI1), I);
+ return BinaryOperator::CreateOr(NewOp, V1);
+ }
+
+ return 0;
+}
+
+Instruction *InstCombiner::visitOr(BinaryOperator &I) {
+ bool Changed = SimplifyCommutative(I);
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+
+ if (isa<UndefValue>(Op1)) // X | undef -> -1
+ return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType()));
+
+ // or X, X = X
+ if (Op0 == Op1)
+ return ReplaceInstUsesWith(I, Op0);
+
+ // See if we can simplify any instructions used by the instruction whose sole
+ // purpose is to compute bits we don't care about.
+ if (!isa<VectorType>(I.getType())) {
+ if (SimplifyDemandedInstructionBits(I))
+ return &I;
+ } else if (isa<ConstantAggregateZero>(Op1)) {
+ return ReplaceInstUsesWith(I, Op0); // X | <0,0> -> X
+ } else if (ConstantVector *CP = dyn_cast<ConstantVector>(Op1)) {
+ if (CP->isAllOnesValue()) // X | <-1,-1> -> <-1,-1>
+ return ReplaceInstUsesWith(I, I.getOperand(1));
+ }
+
+ // or X, -1 == -1
+ if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
+ ConstantInt *C1 = 0; Value *X = 0;
+ // (X & C1) | C2 --> (X | C2) & (C1|C2)
+ if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) && isOnlyUse(Op0)) {
+ Instruction *Or = BinaryOperator::CreateOr(X, RHS);
+ InsertNewInstBefore(Or, I);
+ Or->takeName(Op0);
+ return BinaryOperator::CreateAnd(Or,
+ ConstantInt::get(RHS->getValue() | C1->getValue()));
+ }
+
+ // (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2)
+ if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) && isOnlyUse(Op0)) {
+ Instruction *Or = BinaryOperator::CreateOr(X, RHS);
+ InsertNewInstBefore(Or, I);
+ Or->takeName(Op0);
+ return BinaryOperator::CreateXor(Or,
+ ConstantInt::get(C1->getValue() & ~RHS->getValue()));
+ }
+
+ // Try to fold constant and into select arguments.
+ if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
+ if (Instruction *R = FoldOpIntoSelect(I, SI, this))
+ return R;
+ if (isa<PHINode>(Op0))
+ if (Instruction *NV = FoldOpIntoPhi(I))
+ return NV;
+ }
+
+ Value *A = 0, *B = 0;
+ ConstantInt *C1 = 0, *C2 = 0;
+
+ if (match(Op0, m_And(m_Value(A), m_Value(B))))
+ if (A == Op1 || B == Op1) // (A & ?) | A --> A
+ return ReplaceInstUsesWith(I, Op1);
+ if (match(Op1, m_And(m_Value(A), m_Value(B))))
+ if (A == Op0 || B == Op0) // A | (A & ?) --> A
+ return ReplaceInstUsesWith(I, Op0);
+
+ // (A | B) | C and A | (B | C) -> bswap if possible.
+ // (A >> B) | (C << D) and (A << B) | (C >> D) -> bswap if possible.
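+ // The classic i32 idiom this aims to recognize looks like (illustrative):
+ // (X << 24) | ((X & 0xFF00) << 8) | ((X >> 8) & 0xFF00) | (X >> 24)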
+ if (match(Op0, m_Or(m_Value(), m_Value())) || + match(Op1, m_Or(m_Value(), m_Value())) || + (match(Op0, m_Shift(m_Value(), m_Value())) && + match(Op1, m_Shift(m_Value(), m_Value())))) { + if (Instruction *BSwap = MatchBSwap(I)) + return BSwap; + } + + // (X^C)|Y -> (X|Y)^C iff Y&C == 0 + if (Op0->hasOneUse() && match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && + MaskedValueIsZero(Op1, C1->getValue())) { + Instruction *NOr = BinaryOperator::CreateOr(A, Op1); + InsertNewInstBefore(NOr, I); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, C1); + } + + // Y|(X^C) -> (X|Y)^C iff Y&C == 0 + if (Op1->hasOneUse() && match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && + MaskedValueIsZero(Op0, C1->getValue())) { + Instruction *NOr = BinaryOperator::CreateOr(A, Op0); + InsertNewInstBefore(NOr, I); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, C1); + } + + // (A & C)|(B & D) + Value *C = 0, *D = 0; + if (match(Op0, m_And(m_Value(A), m_Value(C))) && + match(Op1, m_And(m_Value(B), m_Value(D)))) { + Value *V1 = 0, *V2 = 0, *V3 = 0; + C1 = dyn_cast<ConstantInt>(C); + C2 = dyn_cast<ConstantInt>(D); + if (C1 && C2) { // (A & C1)|(B & C2) + // If we have: ((V + N) & C1) | (V & C2) + // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0 + // replace with V+N. + if (C1->getValue() == ~C2->getValue()) { + if ((C2->getValue() & (C2->getValue()+1)) == 0 && // C2 == 0+1+ + match(A, m_Add(m_Value(V1), m_Value(V2)))) { + // Add commutes, try both ways. + if (V1 == B && MaskedValueIsZero(V2, C2->getValue())) + return ReplaceInstUsesWith(I, A); + if (V2 == B && MaskedValueIsZero(V1, C2->getValue())) + return ReplaceInstUsesWith(I, A); + } + // Or commutes, try both ways. + if ((C1->getValue() & (C1->getValue()+1)) == 0 && + match(B, m_Add(m_Value(V1), m_Value(V2)))) { + // Add commutes, try both ways. + if (V1 == A && MaskedValueIsZero(V2, C1->getValue())) + return ReplaceInstUsesWith(I, B); + if (V2 == A && MaskedValueIsZero(V1, C1->getValue())) + return ReplaceInstUsesWith(I, B); + } + } + V1 = 0; V2 = 0; V3 = 0; + } + + // Check to see if we have any common things being and'ed. If so, find the + // terms for V1 & (V2|V3). + if (isOnlyUse(Op0) || isOnlyUse(Op1)) { + if (A == B) // (A & C)|(A & D) == A & (C|D) + V1 = A, V2 = C, V3 = D; + else if (A == D) // (A & C)|(B & A) == A & (B|C) + V1 = A, V2 = B, V3 = C; + else if (C == B) // (A & C)|(C & D) == C & (A|D) + V1 = C, V2 = A, V3 = D; + else if (C == D) // (A & C)|(B & C) == C & (A|B) + V1 = C, V2 = A, V3 = B; + + if (V1) { + Value *Or = + InsertNewInstBefore(BinaryOperator::CreateOr(V2, V3, "tmp"), I); + return BinaryOperator::CreateAnd(V1, Or); + } + } + + // (A & (C0?-1:0)) | (B & ~(C0?-1:0)) -> C0 ? 
A : B, and commuted variants + if (Instruction *Match = MatchSelectFromAndOr(A, B, C, D)) + return Match; + if (Instruction *Match = MatchSelectFromAndOr(B, A, D, C)) + return Match; + if (Instruction *Match = MatchSelectFromAndOr(C, B, A, D)) + return Match; + if (Instruction *Match = MatchSelectFromAndOr(D, A, B, C)) + return Match; + + // ((A&~B)|(~A&B)) -> A^B + if ((match(C, m_Not(m_Specific(D))) && + match(B, m_Not(m_Specific(A))))) + return BinaryOperator::CreateXor(A, D); + // ((~B&A)|(~A&B)) -> A^B + if ((match(A, m_Not(m_Specific(D))) && + match(B, m_Not(m_Specific(C))))) + return BinaryOperator::CreateXor(C, D); + // ((A&~B)|(B&~A)) -> A^B + if ((match(C, m_Not(m_Specific(B))) && + match(D, m_Not(m_Specific(A))))) + return BinaryOperator::CreateXor(A, B); + // ((~B&A)|(B&~A)) -> A^B + if ((match(A, m_Not(m_Specific(B))) && + match(D, m_Not(m_Specific(C))))) + return BinaryOperator::CreateXor(C, B); + } + + // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. + if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) { + if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0)) + if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() && + SI0->getOperand(1) == SI1->getOperand(1) && + (SI0->hasOneUse() || SI1->hasOneUse())) { + Instruction *NewOp = + InsertNewInstBefore(BinaryOperator::CreateOr(SI0->getOperand(0), + SI1->getOperand(0), + SI0->getName()), I); + return BinaryOperator::Create(SI1->getOpcode(), NewOp, + SI1->getOperand(1)); + } + } + + // ((A|B)&1)|(B&-2) -> (A&1) | B + if (match(Op0, m_And(m_Or(m_Value(A), m_Value(B)), m_Value(C))) || + match(Op0, m_And(m_Value(C), m_Or(m_Value(A), m_Value(B))))) { + Instruction *Ret = FoldOrWithConstants(I, Op1, A, B, C); + if (Ret) return Ret; + } + // (B&-2)|((A|B)&1) -> (A&1) | B + if (match(Op1, m_And(m_Or(m_Value(A), m_Value(B)), m_Value(C))) || + match(Op1, m_And(m_Value(C), m_Or(m_Value(A), m_Value(B))))) { + Instruction *Ret = FoldOrWithConstants(I, Op0, A, B, C); + if (Ret) return Ret; + } + + if (match(Op0, m_Not(m_Value(A)))) { // ~A | Op1 + if (A == Op1) // ~A | A == -1 + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + } else { + A = 0; + } + // Note, A is still live here! + if (match(Op1, m_Not(m_Value(B)))) { // Op0 | ~B + if (Op0 == B) + return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); + + // (~A | ~B) == (~(A & B)) - De Morgan's Law + if (A && isOnlyUse(Op0) && isOnlyUse(Op1)) { + Value *And = InsertNewInstBefore(BinaryOperator::CreateAnd(A, B, + I.getName()+".demorgan"), I); + return BinaryOperator::CreateNot(And); + } + } + + // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) + if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) { + if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS))) + return R; + + if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) + if (Instruction *Res = FoldOrOfICmps(I, LHS, RHS)) + return Res; + } + + // fold (or (cast A), (cast B)) -> (cast (or A, B)) + if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { + if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) + if (Op0C->getOpcode() == Op1C->getOpcode()) {// same cast kind ? + if (!isa<ICmpInst>(Op0C->getOperand(0)) || + !isa<ICmpInst>(Op1C->getOperand(0))) { + const Type *SrcTy = Op0C->getOperand(0)->getType(); + if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && + // Only do this if the casts both really cause code to be + // generated. 
+ ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), + I.getType(), TD) && + ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), + I.getType(), TD)) { + Instruction *NewOp = BinaryOperator::CreateOr(Op0C->getOperand(0), + Op1C->getOperand(0), + I.getName()); + InsertNewInstBefore(NewOp, I); + return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); + } + } + } + } + + + // (fcmp uno x, c) | (fcmp uno y, c) -> (fcmp uno x, y) + if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) { + if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) { + if (LHS->getPredicate() == FCmpInst::FCMP_UNO && + RHS->getPredicate() == FCmpInst::FCMP_UNO && + LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType()) { + if (ConstantFP *LHSC = dyn_cast<ConstantFP>(LHS->getOperand(1))) + if (ConstantFP *RHSC = dyn_cast<ConstantFP>(RHS->getOperand(1))) { + // If either of the constants are nans, then the whole thing returns + // true. + if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN()) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + + // Otherwise, no need to compare the two constants, compare the + // rest. + return new FCmpInst(FCmpInst::FCMP_UNO, LHS->getOperand(0), + RHS->getOperand(0)); + } + } else { + Value *Op0LHS, *Op0RHS, *Op1LHS, *Op1RHS; + FCmpInst::Predicate Op0CC, Op1CC; + if (match(Op0, m_FCmp(Op0CC, m_Value(Op0LHS), m_Value(Op0RHS))) && + match(Op1, m_FCmp(Op1CC, m_Value(Op1LHS), m_Value(Op1RHS)))) { + if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { + // Swap RHS operands to match LHS. + Op1CC = FCmpInst::getSwappedPredicate(Op1CC); + std::swap(Op1LHS, Op1RHS); + } + if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) { + // Simplify (fcmp cc0 x, y) | (fcmp cc1 x, y). + if (Op0CC == Op1CC) + return new FCmpInst((FCmpInst::Predicate)Op0CC, Op0LHS, Op0RHS); + else if (Op0CC == FCmpInst::FCMP_TRUE || + Op1CC == FCmpInst::FCMP_TRUE) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + else if (Op0CC == FCmpInst::FCMP_FALSE) + return ReplaceInstUsesWith(I, Op1); + else if (Op1CC == FCmpInst::FCMP_FALSE) + return ReplaceInstUsesWith(I, Op0); + bool Op0Ordered; + bool Op1Ordered; + unsigned Op0Pred = getFCmpCode(Op0CC, Op0Ordered); + unsigned Op1Pred = getFCmpCode(Op1CC, Op1Ordered); + if (Op0Ordered == Op1Ordered) { + // If both are ordered or unordered, return a new fcmp with + // or'ed predicates. + Value *RV = getFCmpValue(Op0Ordered, Op0Pred|Op1Pred, + Op0LHS, Op0RHS); + if (Instruction *I = dyn_cast<Instruction>(RV)) + return I; + // Otherwise, it's a constant boolean value... + return ReplaceInstUsesWith(I, RV); + } + } + } + } + } + } + + return Changed ? &I : 0; +} + +namespace { + +// XorSelf - Implements: X ^ X --> 0 +struct XorSelf { + Value *RHS; + XorSelf(Value *rhs) : RHS(rhs) {} + bool shouldApply(Value *LHS) const { return LHS == RHS; } + Instruction *apply(BinaryOperator &Xor) const { + return &Xor; + } +}; + +} + +Instruction *InstCombiner::visitXor(BinaryOperator &I) { + bool Changed = SimplifyCommutative(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + if (isa<UndefValue>(Op1)) { + if (isa<UndefValue>(Op0)) + // Handle undef ^ undef -> 0 special case. This is a common + // idiom (misuse). + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + return ReplaceInstUsesWith(I, Op1); // X ^ undef -> undef + } + + // xor X, X = 0, even if X is nested in a sequence of Xor's. 
+ if (Instruction *Result = AssociativeOpt(I, XorSelf(Op1))) {
+ assert(Result == &I && "AssociativeOpt didn't work?");
+ Result=Result; // Keeps Result "used" in NDEBUG builds, where the assert
+ // compiles away.
+ return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+ }
+
+ // See if we can simplify any instructions used by the instruction whose sole
+ // purpose is to compute bits we don't care about.
+ if (!isa<VectorType>(I.getType())) {
+ if (SimplifyDemandedInstructionBits(I))
+ return &I;
+ } else if (isa<ConstantAggregateZero>(Op1)) {
+ return ReplaceInstUsesWith(I, Op0); // X ^ <0,0> -> X
+ }
+
+ // Is this a ~ operation?
+ if (Value *NotOp = dyn_castNotVal(&I)) {
+ // ~(~X & Y) --> (X | ~Y) - De Morgan's Law
+ // ~(~X | Y) --> (X & ~Y) - De Morgan's Law
+ if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(NotOp)) {
+ if (Op0I->getOpcode() == Instruction::And ||
+ Op0I->getOpcode() == Instruction::Or) {
+ if (dyn_castNotVal(Op0I->getOperand(1))) Op0I->swapOperands();
+ if (Value *Op0NotVal = dyn_castNotVal(Op0I->getOperand(0))) {
+ Instruction *NotY =
+ BinaryOperator::CreateNot(Op0I->getOperand(1),
+ Op0I->getOperand(1)->getName()+".not");
+ InsertNewInstBefore(NotY, I);
+ if (Op0I->getOpcode() == Instruction::And)
+ return BinaryOperator::CreateOr(Op0NotVal, NotY);
+ else
+ return BinaryOperator::CreateAnd(Op0NotVal, NotY);
+ }
+ }
+ }
+ }
+
+
+ if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
+ if (RHS == ConstantInt::getTrue() && Op0->hasOneUse()) {
+ // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B
+ if (ICmpInst *ICI = dyn_cast<ICmpInst>(Op0))
+ return new ICmpInst(ICI->getInversePredicate(),
+ ICI->getOperand(0), ICI->getOperand(1));
+
+ if (FCmpInst *FCI = dyn_cast<FCmpInst>(Op0))
+ return new FCmpInst(FCI->getInversePredicate(),
+ FCI->getOperand(0), FCI->getOperand(1));
+ }
+
+ // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp).
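+ // e.g. xor (zext i1 (icmp eq %a, %b) to i32), 1
+ // --> zext i1 (icmp ne %a, %b) to i32 (types illustrative).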
+ if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
+ if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) {
+ if (CI->hasOneUse() && Op0C->hasOneUse()) {
+ Instruction::CastOps Opcode = Op0C->getOpcode();
+ if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) {
+ if (RHS == ConstantExpr::getCast(Opcode, ConstantInt::getTrue(),
+ Op0C->getDestTy())) {
+ Instruction *NewCI = InsertNewInstBefore(CmpInst::Create(
+ CI->getOpcode(), CI->getInversePredicate(),
+ CI->getOperand(0), CI->getOperand(1)), I);
+ NewCI->takeName(CI);
+ return CastInst::Create(Opcode, NewCI, Op0C->getType());
+ }
+ }
+ }
+ }
+ }
+
+ if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
+ // ~(c-X) == X-c-1 == X+(-c-1)
+ if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue())
+ if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) {
+ Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C);
+ Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C,
+ ConstantInt::get(I.getType(), 1));
+ return BinaryOperator::CreateAdd(Op0I->getOperand(1), ConstantRHS);
+ }
+
+ if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
+ if (Op0I->getOpcode() == Instruction::Add) {
+ // ~(X+c) --> (-c-1)-X
+ if (RHS->isAllOnesValue()) {
+ Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI);
+ return BinaryOperator::CreateSub(
+ ConstantExpr::getSub(NegOp0CI,
+ ConstantInt::get(I.getType(), 1)),
+ Op0I->getOperand(0));
+ } else if (RHS->getValue().isSignBit()) {
+ // (X + C) ^ signbit -> (X + C + signbit)
+ Constant *C = ConstantInt::get(RHS->getValue() + Op0CI->getValue());
+ return BinaryOperator::CreateAdd(Op0I->getOperand(0), C);
+
+ }
+ } else if (Op0I->getOpcode() == Instruction::Or) {
+ // (X|C1)^C2 -> X^(C1|C2) iff X&C1 == 0
+ if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue())) {
+ Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS);
+ // Anything in both C1 and C2 is known to be zero, remove it from
+ // NewRHS.
+ Constant *CommonBits = And(Op0CI, RHS);
+ NewRHS = ConstantExpr::getAnd(NewRHS,
+ ConstantExpr::getNot(CommonBits));
+ AddToWorkList(Op0I);
+ I.setOperand(0, Op0I->getOperand(0));
+ I.setOperand(1, NewRHS);
+ return &I;
+ }
+ }
+ }
+ }
+
+ // Try to fold constant and into select arguments.
+ if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
+ if (Instruction *R = FoldOpIntoSelect(I, SI, this))
+ return R;
+ if (isa<PHINode>(Op0))
+ if (Instruction *NV = FoldOpIntoPhi(I))
+ return NV;
+ }
+
+ if (Value *X = dyn_castNotVal(Op0)) // ~A ^ A == -1
+ if (X == Op1)
+ return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType()));
+
+ if (Value *X = dyn_castNotVal(Op1)) // A ^ ~A == -1
+ if (X == Op0)
+ return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType()));
+
+
+ BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1);
+ if (Op1I) {
+ Value *A, *B;
+ if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) {
+ if (A == Op0) { // B^(B|A) == (A|B)^B
+ Op1I->swapOperands();
+ I.swapOperands();
+ std::swap(Op0, Op1);
+ } else if (B == Op0) { // B^(A|B) == (A|B)^B
+ I.swapOperands(); // Simplified below.
+ std::swap(Op0, Op1);
+ }
+ } else if (match(Op1I, m_Xor(m_Specific(Op0), m_Value(B)))) {
+ return ReplaceInstUsesWith(I, B); // A^(A^B) == B
+ } else if (match(Op1I, m_Xor(m_Value(A), m_Specific(Op0)))) {
+ return ReplaceInstUsesWith(I, A); // A^(B^A) == B
+ } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) && Op1I->hasOneUse()){
+ if (A == Op0) { // A^(A&B) -> A^(B&A)
+ Op1I->swapOperands();
+ std::swap(A, B);
+ }
+ if (B == Op0) { // A^(B&A) -> (B&A)^A
+ I.swapOperands(); // Simplified below.
+ std::swap(Op0, Op1);
+ }
+ }
+ }
+
+ BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0);
+ if (Op0I) {
+ Value *A, *B;
+ if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && Op0I->hasOneUse()) {
+ if (A == Op1) // (B|A)^B == (A|B)^B
+ std::swap(A, B);
+ if (B == Op1) { // (A|B)^B == A & ~B
+ Instruction *NotB =
+ InsertNewInstBefore(BinaryOperator::CreateNot(Op1, "tmp"), I);
+ return BinaryOperator::CreateAnd(A, NotB);
+ }
+ } else if (match(Op0I, m_Xor(m_Specific(Op1), m_Value(B)))) {
+ return ReplaceInstUsesWith(I, B); // (A^B)^A == B
+ } else if (match(Op0I, m_Xor(m_Value(A), m_Specific(Op1)))) {
+ return ReplaceInstUsesWith(I, A); // (B^A)^A == B
+ } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) && Op0I->hasOneUse()){
+ if (A == Op1) // (A&B)^A -> (B&A)^A
+ std::swap(A, B);
+ if (B == Op1 && // (B&A)^A == ~B & A
+ !isa<ConstantInt>(Op1)) { // Canonical form is (B&C)^C
+ Instruction *N =
+ InsertNewInstBefore(BinaryOperator::CreateNot(A, "tmp"), I);
+ return BinaryOperator::CreateAnd(N, Op1);
+ }
+ }
+ }
+
+ // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts.
+ if (Op0I && Op1I && Op0I->isShift() &&
+ Op0I->getOpcode() == Op1I->getOpcode() &&
+ Op0I->getOperand(1) == Op1I->getOperand(1) &&
+ (Op0I->hasOneUse() || Op1I->hasOneUse())) {
+ Instruction *NewOp =
+ InsertNewInstBefore(BinaryOperator::CreateXor(Op0I->getOperand(0),
+ Op1I->getOperand(0),
+ Op0I->getName()), I);
+ return BinaryOperator::Create(Op1I->getOpcode(), NewOp,
+ Op1I->getOperand(1));
+ }
+
+ if (Op0I && Op1I) {
+ Value *A, *B, *C, *D;
+ // (A & B)^(A | B) -> A ^ B
+ if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1I, m_Or(m_Value(C), m_Value(D)))) {
+ if ((A == C && B == D) || (A == D && B == C))
+ return BinaryOperator::CreateXor(A, B);
+ }
+ // (A | B)^(A & B) -> A ^ B
+ if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
+ match(Op1I, m_And(m_Value(C), m_Value(D)))) {
+ if ((A == C && B == D) || (A == D && B == C))
+ return BinaryOperator::CreateXor(A, B);
+ }
+
+ // (A & B)^(C & D)
+ if ((Op0I->hasOneUse() || Op1I->hasOneUse()) &&
+ match(Op0I, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1I, m_And(m_Value(C), m_Value(D)))) {
+ // (X & Y)^(X & Z) -> (Y^Z) & X
+ Value *X = 0, *Y = 0, *Z = 0;
+ if (A == C)
+ X = A, Y = B, Z = D;
+ else if (A == D)
+ X = A, Y = B, Z = C;
+ else if (B == C)
+ X = B, Y = A, Z = D;
+ else if (B == D)
+ X = B, Y = A, Z = C;
+
+ if (X) {
+ Instruction *NewOp =
+ InsertNewInstBefore(BinaryOperator::CreateXor(Y, Z, Op0->getName()), I);
+ return BinaryOperator::CreateAnd(NewOp, X);
+ }
+ }
+ }
+
+ // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B)
+ if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
+ if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS)))
+ return R;
+
+ // fold (xor (cast A), (cast B)) -> (cast (xor A, B))
+ if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
+ if (CastInst *Op1C = dyn_cast<CastInst>(Op1))
+ if (Op0C->getOpcode() == Op1C->getOpcode()) { // same cast kind?
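+ // e.g. xor (zext i8 %a to i32), (zext i8 %b to i32)
+ // --> zext (xor i8 %a, %b) to i32, provided both casts really
+ // generate code (widths illustrative).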
+ const Type *SrcTy = Op0C->getOperand(0)->getType(); + if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isInteger() && + // Only do this if the casts both really cause code to be generated. + ValueRequiresCast(Op0C->getOpcode(), Op0C->getOperand(0), + I.getType(), TD) && + ValueRequiresCast(Op1C->getOpcode(), Op1C->getOperand(0), + I.getType(), TD)) { + Instruction *NewOp = BinaryOperator::CreateXor(Op0C->getOperand(0), + Op1C->getOperand(0), + I.getName()); + InsertNewInstBefore(NewOp, I); + return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); + } + } + } + + return Changed ? &I : 0; +} + +/// AddWithOverflow - Compute Result = In1+In2, returning true if the result +/// overflowed for this type. +static bool AddWithOverflow(ConstantInt *&Result, ConstantInt *In1, + ConstantInt *In2, bool IsSigned = false) { + Result = cast<ConstantInt>(Add(In1, In2)); + + if (IsSigned) + if (In2->getValue().isNegative()) + return Result->getValue().sgt(In1->getValue()); + else + return Result->getValue().slt(In1->getValue()); + else + return Result->getValue().ult(In1->getValue()); +} + +/// SubWithOverflow - Compute Result = In1-In2, returning true if the result +/// overflowed for this type. +static bool SubWithOverflow(ConstantInt *&Result, ConstantInt *In1, + ConstantInt *In2, bool IsSigned = false) { + Result = cast<ConstantInt>(Subtract(In1, In2)); + + if (IsSigned) + if (In2->getValue().isNegative()) + return Result->getValue().slt(In1->getValue()); + else + return Result->getValue().sgt(In1->getValue()); + else + return Result->getValue().ugt(In1->getValue()); +} + +/// EmitGEPOffset - Given a getelementptr instruction/constantexpr, emit the +/// code necessary to compute the offset from the base pointer (without adding +/// in the base pointer). Return the result as a signed integer of intptr size. +static Value *EmitGEPOffset(User *GEP, Instruction &I, InstCombiner &IC) { + TargetData &TD = IC.getTargetData(); + gep_type_iterator GTI = gep_type_begin(GEP); + const Type *IntPtrTy = TD.getIntPtrType(); + Value *Result = Constant::getNullValue(IntPtrTy); + + // Build a mask for high order bits. + unsigned IntPtrWidth = TD.getPointerSizeInBits(); + uint64_t PtrSizeMask = ~0ULL >> (64-IntPtrWidth); + + for (User::op_iterator i = GEP->op_begin() + 1, e = GEP->op_end(); i != e; + ++i, ++GTI) { + Value *Op = *i; + uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType()) & PtrSizeMask; + if (ConstantInt *OpC = dyn_cast<ConstantInt>(Op)) { + if (OpC->isZero()) continue; + + // Handle a struct index, which adds its field offset to the pointer. + if (const StructType *STy = dyn_cast<StructType>(*GTI)) { + Size = TD.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); + + if (ConstantInt *RC = dyn_cast<ConstantInt>(Result)) + Result = ConstantInt::get(RC->getValue() + APInt(IntPtrWidth, Size)); + else + Result = IC.InsertNewInstBefore( + BinaryOperator::CreateAdd(Result, + ConstantInt::get(IntPtrTy, Size), + GEP->getName()+".offs"), I); + continue; + } + + Constant *Scale = ConstantInt::get(IntPtrTy, Size); + Constant *OC = ConstantExpr::getIntegerCast(OpC, IntPtrTy, true /*SExt*/); + Scale = ConstantExpr::getMul(OC, Scale); + if (Constant *RC = dyn_cast<Constant>(Result)) + Result = ConstantExpr::getAdd(RC, Scale); + else { + // Emit an add instruction. + Result = IC.InsertNewInstBefore( + BinaryOperator::CreateAdd(Result, Scale, + GEP->getName()+".offs"), I); + } + continue; + } + // Convert to correct type. 
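+ // e.g. (hypothetical) a variable index %i into [10 x i32] is sign-extended
+ // to pointer width and contributes %i * 4 to the running offset, assuming
+ // a 4-byte i32.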
+ if (Op->getType() != IntPtrTy) {
+ if (Constant *OpC = dyn_cast<Constant>(Op))
+ Op = ConstantExpr::getIntegerCast(OpC, IntPtrTy, true);
+ else
+ Op = IC.InsertNewInstBefore(CastInst::CreateIntegerCast(Op, IntPtrTy,
+ true,
+ Op->getName()+".c"), I);
+ }
+ if (Size != 1) {
+ Constant *Scale = ConstantInt::get(IntPtrTy, Size);
+ if (Constant *OpC = dyn_cast<Constant>(Op))
+ Op = ConstantExpr::getMul(OpC, Scale);
+ else // We'll let instcombine(mul) convert this to a shl if possible.
+ Op = IC.InsertNewInstBefore(BinaryOperator::CreateMul(Op, Scale,
+ GEP->getName()+".idx"), I);
+ }
+
+ // Emit an add instruction.
+ if (isa<Constant>(Op) && isa<Constant>(Result))
+ Result = ConstantExpr::getAdd(cast<Constant>(Op),
+ cast<Constant>(Result));
+ else
+ Result = IC.InsertNewInstBefore(BinaryOperator::CreateAdd(Op, Result,
+ GEP->getName()+".offs"), I);
+ }
+ return Result;
+}
+
+
+/// EvaluateGEPOffsetExpression - Return a value that can be used to compare
+/// the *offset* implied by GEP to zero. For example, if we have &A[i], we want
+/// to return 'i' for "icmp ne i, 0". Note that, in general, indices can be
+/// complex, and scales are involved. The above expression would also be legal
+/// to codegen as "icmp ne (i*4), 0" (assuming A is a pointer to i32). This
+/// latter form is less amenable to optimization though, and we are allowed to
+/// generate the first by knowing that pointer arithmetic doesn't overflow.
+///
+/// If we can't emit an optimized form for this expression, this returns null.
+///
+static Value *EvaluateGEPOffsetExpression(User *GEP, Instruction &I,
+ InstCombiner &IC) {
+ TargetData &TD = IC.getTargetData();
+ gep_type_iterator GTI = gep_type_begin(GEP);
+
+ // Check to see if this gep only has a single variable index. If so, and if
+ // any constant indices are a multiple of its scale, then we can compute this
+ // in terms of the scale of the variable index. For example, if the GEP
+ // implies an offset of "12 + i*4", then we can codegen this as "3 + i",
+ // because the expression will cross zero at the same point.
+ unsigned i, e = GEP->getNumOperands();
+ int64_t Offset = 0;
+ for (i = 1; i != e; ++i, ++GTI) {
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
+ // Compute the aggregate offset of constant indices.
+ if (CI->isZero()) continue;
+
+ // Handle a struct index, which adds its field offset to the pointer.
+ if (const StructType *STy = dyn_cast<StructType>(*GTI)) {
+ Offset += TD.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
+ } else {
+ uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType());
+ Offset += Size*CI->getSExtValue();
+ }
+ } else {
+ // Found our variable index.
+ break;
+ }
+ }
+
+ // If there are no variable indices, we must have a constant offset, just
+ // evaluate it the general way.
+ if (i == e) return 0;
+
+ Value *VariableIdx = GEP->getOperand(i);
+ // Determine the scale factor of the variable element. For example, this is
+ // 4 if the variable index is into an array of i32.
+ uint64_t VariableScale = TD.getTypeAllocSize(GTI.getIndexedType());
+
+ // Verify that there are no other variable indices. If so, emit the hard way.
+ for (++i, ++GTI; i != e; ++i, ++GTI) {
+ ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i));
+ if (!CI) return 0;
+
+ // Compute the aggregate offset of constant indices.
+ if (CI->isZero()) continue;
+
+ // Handle a struct index, which adds its field offset to the pointer.
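+ // e.g. a constant index 2 into {i8, i32, i32} adds that field's layout
+ // offset (8 bytes under a typical 4-byte alignment; illustrative).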
+ if (const StructType *STy = dyn_cast<StructType>(*GTI)) { + Offset += TD.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); + } else { + uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType()); + Offset += Size*CI->getSExtValue(); + } + } + + // Okay, we know we have a single variable index, which must be a + // pointer/array/vector index. If there is no offset, life is simple, return + // the index. + unsigned IntPtrWidth = TD.getPointerSizeInBits(); + if (Offset == 0) { + // Cast to intptrty in case a truncation occurs. If an extension is needed, + // we don't need to bother extending: the extension won't affect where the + // computation crosses zero. + if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) + VariableIdx = new TruncInst(VariableIdx, TD.getIntPtrType(), + VariableIdx->getNameStart(), &I); + return VariableIdx; + } + + // Otherwise, there is an index. The computation we will do will be modulo + // the pointer size, so get it. + uint64_t PtrSizeMask = ~0ULL >> (64-IntPtrWidth); + + Offset &= PtrSizeMask; + VariableScale &= PtrSizeMask; + + // To do this transformation, any constant index must be a multiple of the + // variable scale factor. For example, we can evaluate "12 + 4*i" as "3 + i", + // but we can't evaluate "10 + 3*i" in terms of i. Check that the offset is a + // multiple of the variable scale. + int64_t NewOffs = Offset / (int64_t)VariableScale; + if (Offset != NewOffs*(int64_t)VariableScale) + return 0; + + // Okay, we can do this evaluation. Start by converting the index to intptr. + const Type *IntPtrTy = TD.getIntPtrType(); + if (VariableIdx->getType() != IntPtrTy) + VariableIdx = CastInst::CreateIntegerCast(VariableIdx, IntPtrTy, + true /*SExt*/, + VariableIdx->getNameStart(), &I); + Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs); + return BinaryOperator::CreateAdd(VariableIdx, OffsetVal, "offset", &I); +} + + +/// FoldGEPICmp - Fold comparisons between a GEP instruction and something +/// else. At this point we know that the GEP is on the LHS of the comparison. +Instruction *InstCombiner::FoldGEPICmp(User *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, + Instruction &I) { + assert(dyn_castGetElementPtr(GEPLHS) && "LHS is not a getelementptr!"); + + // Look through bitcasts. + if (BitCastInst *BCI = dyn_cast<BitCastInst>(RHS)) + RHS = BCI->getOperand(0); + + Value *PtrBase = GEPLHS->getOperand(0); + if (PtrBase == RHS) { + // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). + // This transformation (ignoring the base and scales) is valid because we + // know pointers can't overflow. See if we can output an optimized form. + Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, I, *this); + + // If not, synthesize the offset the hard way. + if (Offset == 0) + Offset = EmitGEPOffset(GEPLHS, I, *this); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, + Constant::getNullValue(Offset->getType())); + } else if (User *GEPRHS = dyn_castGetElementPtr(RHS)) { + // If the base pointers are different, but the indices are the same, just + // compare the base pointer. 
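+    // For example, icmp eq (gep %A, i64 %i), (gep %B, i64 %i) can become
+    // icmp eq %A, %B: applying the same offset to both bases preserves the
+    // comparison.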
+ if (PtrBase != GEPRHS->getOperand(0)) { + bool IndicesTheSame = GEPLHS->getNumOperands()==GEPRHS->getNumOperands(); + IndicesTheSame &= GEPLHS->getOperand(0)->getType() == + GEPRHS->getOperand(0)->getType(); + if (IndicesTheSame) + for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) + if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { + IndicesTheSame = false; + break; + } + + // If all indices are the same, just compare the base pointers. + if (IndicesTheSame) + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), + GEPLHS->getOperand(0), GEPRHS->getOperand(0)); + + // Otherwise, the base pointers are different and the indices are + // different, bail out. + return 0; + } + + // If one of the GEPs has all zero indices, recurse. + bool AllZeros = true; + for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) + if (!isa<Constant>(GEPLHS->getOperand(i)) || + !cast<Constant>(GEPLHS->getOperand(i))->isNullValue()) { + AllZeros = false; + break; + } + if (AllZeros) + return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0), + ICmpInst::getSwappedPredicate(Cond), I); + + // If the other GEP has all zero indices, recurse. + AllZeros = true; + for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) + if (!isa<Constant>(GEPRHS->getOperand(i)) || + !cast<Constant>(GEPRHS->getOperand(i))->isNullValue()) { + AllZeros = false; + break; + } + if (AllZeros) + return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); + + if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands()) { + // If the GEPs only differ by one index, compare it. + unsigned NumDifferences = 0; // Keep track of # differences. + unsigned DiffOperand = 0; // The operand that differs. + for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) + if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { + if (GEPLHS->getOperand(i)->getType()->getPrimitiveSizeInBits() != + GEPRHS->getOperand(i)->getType()->getPrimitiveSizeInBits()) { + // Irreconcilable differences. + NumDifferences = 2; + break; + } else { + if (NumDifferences++) break; + DiffOperand = i; + } + } + + if (NumDifferences == 0) // SAME GEP? + return ReplaceInstUsesWith(I, // No comparison is needed here. + ConstantInt::get(Type::Int1Ty, + ICmpInst::isTrueWhenEqual(Cond))); + + else if (NumDifferences == 1) { + Value *LHSV = GEPLHS->getOperand(DiffOperand); + Value *RHSV = GEPRHS->getOperand(DiffOperand); + // Make sure we do a signed comparison here. + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), LHSV, RHSV); + } + } + + // Only lower this if the icmp is the only user of the GEP or if we expect + // the result to fold to a constant! + if ((isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && + (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) { + // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) + Value *L = EmitGEPOffset(GEPLHS, I, *this); + Value *R = EmitGEPOffset(GEPRHS, I, *this); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); + } + } + return 0; +} + +/// FoldFCmp_IntToFP_Cst - Fold fcmp ([us]itofp x, cst) if possible. +/// +Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, + Instruction *LHSI, + Constant *RHSC) { + if (!isa<ConstantFP>(RHSC)) return 0; + const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); + + // Get the width of the mantissa. We don't want to hack on conversions that + // might lose information from the integer, e.g. "i64 -> float" + int MantissaWidth = LHSI->getType()->getFPMantissaWidth(); + if (MantissaWidth == -1) return 0; // Unknown. 
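+
+  // For reference: getFPMantissaWidth() is 24 for float and 53 for double
+  // (counting the implicit leading bit), so e.g. every i16 value is exactly
+  // representable as a float, while an i64 -> float conversion can round.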
+
+  // Check to see that the input is converted from an integer type that is
+  // small enough that the conversion preserves all bits.
+  // TODO: check here for "known" sign bits.  This would allow us to handle
+  // (fptosi (x >>s 62) to float) if x is i64, for example.
+  unsigned InputSize = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits();
+
+  // If this is a uitofp instruction, we need an extra bit to hold the sign.
+  bool LHSUnsigned = isa<UIToFPInst>(LHSI);
+  if (LHSUnsigned)
+    ++InputSize;
+
+  // If the conversion would lose info, don't hack on this.
+  if ((int)InputSize > MantissaWidth)
+    return 0;
+
+  // Otherwise, we can potentially simplify the comparison.  We know that it
+  // will always come through as an integer value and we know the constant is
+  // not a NaN (it would have been previously simplified).
+  assert(!RHS.isNaN() && "NaN comparison not already folded!");
+
+  ICmpInst::Predicate Pred;
+  switch (I.getPredicate()) {
+  default: assert(0 && "Unexpected predicate!");
+  case FCmpInst::FCMP_UEQ:
+  case FCmpInst::FCMP_OEQ:
+    Pred = ICmpInst::ICMP_EQ;
+    break;
+  case FCmpInst::FCMP_UGT:
+  case FCmpInst::FCMP_OGT:
+    Pred = LHSUnsigned ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_SGT;
+    break;
+  case FCmpInst::FCMP_UGE:
+  case FCmpInst::FCMP_OGE:
+    Pred = LHSUnsigned ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_SGE;
+    break;
+  case FCmpInst::FCMP_ULT:
+  case FCmpInst::FCMP_OLT:
+    Pred = LHSUnsigned ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_SLT;
+    break;
+  case FCmpInst::FCMP_ULE:
+  case FCmpInst::FCMP_OLE:
+    Pred = LHSUnsigned ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
+    break;
+  case FCmpInst::FCMP_UNE:
+  case FCmpInst::FCMP_ONE:
+    Pred = ICmpInst::ICMP_NE;
+    break;
+  case FCmpInst::FCMP_ORD:
+    return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+  case FCmpInst::FCMP_UNO:
+    return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+  }
+
+  const IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType());
+
+  // Now we know that the APFloat is a normal number, zero or inf.
+
+  // See if the FP constant is too large for the integer.  For example,
+  // comparing an i8 to 300.0.
+  unsigned IntWidth = IntTy->getPrimitiveSizeInBits();
+
+  if (!LHSUnsigned) {
+    // If the RHS value is > SignedMax, fold the comparison.  This handles +INF
+    // and large values.
+    APFloat SMax(RHS.getSemantics(), APFloat::fcZero, false);
+    SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true,
+                          APFloat::rmNearestTiesToEven);
+    if (SMax.compare(RHS) == APFloat::cmpLessThan) {  // smax < 13123.0
+      if (Pred == ICmpInst::ICMP_NE  || Pred == ICmpInst::ICMP_SLT ||
+          Pred == ICmpInst::ICMP_SLE)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+    }
+  } else {
+    // If the RHS value is > UnsignedMax, fold the comparison.  This handles
+    // +INF and large values.
+    APFloat UMax(RHS.getSemantics(), APFloat::fcZero, false);
+    UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false,
+                          APFloat::rmNearestTiesToEven);
+    if (UMax.compare(RHS) == APFloat::cmpLessThan) {  // umax < 13123.0
+      if (Pred == ICmpInst::ICMP_NE  || Pred == ICmpInst::ICMP_ULT ||
+          Pred == ICmpInst::ICMP_ULE)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+    }
+  }
+
+  if (!LHSUnsigned) {
+    // See if the RHS value is < SignedMin.
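+    // For example, for an i8 operand SignedMin is -128, so a comparison such
+    // as fcmp olt (sitofp i8 %x to float), -200.0 can never be true, while
+    // the sgt/sge/ne forms are always true.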
+ APFloat SMin(RHS.getSemantics(), APFloat::fcZero, false); + SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true, + APFloat::rmNearestTiesToEven); + if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // smin > 12312.0 + if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT || + Pred == ICmpInst::ICMP_SGE) + return ReplaceInstUsesWith(I,ConstantInt::getTrue()); + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + } + } + + // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or + // [0, UMAX], but it may still be fractional. See if it is fractional by + // casting the FP value to the integer value and back, checking for equality. + // Don't do this for zero, because -0.0 is not fractional. + Constant *RHSInt = LHSUnsigned + ? ConstantExpr::getFPToUI(RHSC, IntTy) + : ConstantExpr::getFPToSI(RHSC, IntTy); + if (!RHS.isZero()) { + bool Equal = LHSUnsigned + ? ConstantExpr::getUIToFP(RHSInt, RHSC->getType()) == RHSC + : ConstantExpr::getSIToFP(RHSInt, RHSC->getType()) == RHSC; + if (!Equal) { + // If we had a comparison against a fractional value, we have to adjust + // the compare predicate and sometimes the value. RHSC is rounded towards + // zero at this point. + switch (Pred) { + default: assert(0 && "Unexpected integer comparison!"); + case ICmpInst::ICMP_NE: // (float)int != 4.4 --> true + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + case ICmpInst::ICMP_EQ: // (float)int == 4.4 --> false + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + case ICmpInst::ICMP_ULE: + // (float)int <= 4.4 --> int <= 4 + // (float)int <= -4.4 --> false + if (RHS.isNegative()) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + break; + case ICmpInst::ICMP_SLE: + // (float)int <= 4.4 --> int <= 4 + // (float)int <= -4.4 --> int < -4 + if (RHS.isNegative()) + Pred = ICmpInst::ICMP_SLT; + break; + case ICmpInst::ICMP_ULT: + // (float)int < -4.4 --> false + // (float)int < 4.4 --> int <= 4 + if (RHS.isNegative()) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + Pred = ICmpInst::ICMP_ULE; + break; + case ICmpInst::ICMP_SLT: + // (float)int < -4.4 --> int < -4 + // (float)int < 4.4 --> int <= 4 + if (!RHS.isNegative()) + Pred = ICmpInst::ICMP_SLE; + break; + case ICmpInst::ICMP_UGT: + // (float)int > 4.4 --> int > 4 + // (float)int > -4.4 --> true + if (RHS.isNegative()) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + break; + case ICmpInst::ICMP_SGT: + // (float)int > 4.4 --> int > 4 + // (float)int > -4.4 --> int >= -4 + if (RHS.isNegative()) + Pred = ICmpInst::ICMP_SGE; + break; + case ICmpInst::ICMP_UGE: + // (float)int >= -4.4 --> true + // (float)int >= 4.4 --> int > 4 + if (!RHS.isNegative()) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + Pred = ICmpInst::ICMP_UGT; + break; + case ICmpInst::ICMP_SGE: + // (float)int >= -4.4 --> int >= -4 + // (float)int >= 4.4 --> int > 4 + if (!RHS.isNegative()) + Pred = ICmpInst::ICMP_SGT; + break; + } + } + } + + // Lower this FP comparison into an appropriate integer version of the + // comparison. + return new ICmpInst(Pred, LHSI->getOperand(0), RHSInt); +} + +Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { + bool Changed = SimplifyCompare(I); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Fold trivial predicates. 
+ if (I.getPredicate() == FCmpInst::FCMP_FALSE) + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + if (I.getPredicate() == FCmpInst::FCMP_TRUE) + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + + // Simplify 'fcmp pred X, X' + if (Op0 == Op1) { + switch (I.getPredicate()) { + default: assert(0 && "Unknown predicate!"); + case FCmpInst::FCMP_UEQ: // True if unordered or equal + case FCmpInst::FCMP_UGE: // True if unordered, greater than, or equal + case FCmpInst::FCMP_ULE: // True if unordered, less than, or equal + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + case FCmpInst::FCMP_OGT: // True if ordered and greater than + case FCmpInst::FCMP_OLT: // True if ordered and less than + case FCmpInst::FCMP_ONE: // True if ordered and operands are unequal + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + + case FCmpInst::FCMP_UNO: // True if unordered: isnan(X) | isnan(Y) + case FCmpInst::FCMP_ULT: // True if unordered or less than + case FCmpInst::FCMP_UGT: // True if unordered or greater than + case FCmpInst::FCMP_UNE: // True if unordered or not equal + // Canonicalize these to be 'fcmp uno %X, 0.0'. + I.setPredicate(FCmpInst::FCMP_UNO); + I.setOperand(1, Constant::getNullValue(Op0->getType())); + return &I; + + case FCmpInst::FCMP_ORD: // True if ordered (no nans) + case FCmpInst::FCMP_OEQ: // True if ordered and equal + case FCmpInst::FCMP_OGE: // True if ordered and greater than or equal + case FCmpInst::FCMP_OLE: // True if ordered and less than or equal + // Canonicalize these to be 'fcmp ord %X, 0.0'. + I.setPredicate(FCmpInst::FCMP_ORD); + I.setOperand(1, Constant::getNullValue(Op0->getType())); + return &I; + } + } + + if (isa<UndefValue>(Op1)) // fcmp pred X, undef -> undef + return ReplaceInstUsesWith(I, UndefValue::get(Type::Int1Ty)); + + // Handle fcmp with constant RHS + if (Constant *RHSC = dyn_cast<Constant>(Op1)) { + // If the constant is a nan, see if we can fold the comparison based on it. + if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHSC)) { + if (CFP->getValueAPF().isNaN()) { + if (FCmpInst::isOrdered(I.getPredicate())) // True if ordered and... + return ReplaceInstUsesWith(I, ConstantInt::getFalse()); + assert(FCmpInst::isUnordered(I.getPredicate()) && + "Comparison must be either ordered or unordered!"); + // True if unordered. + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + } + } + + if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) + switch (LHSI->getOpcode()) { + case Instruction::PHI: + // Only fold fcmp into the PHI if the phi and fcmp are in the same + // block. If in the same block, we're encouraging jump threading. If + // not, we are just pessimizing the code by making an i1 phi. + if (LHSI->getParent() == I.getParent()) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + break; + case Instruction::SIToFP: + case Instruction::UIToFP: + if (Instruction *NV = FoldFCmp_IntToFP_Cst(I, LHSI, RHSC)) + return NV; + break; + case Instruction::Select: + // If either operand of the select is a constant, we can fold the + // comparison into the select arms, which will cause one to be + // constant folded and the select turned into a bitwise or. + Value *Op1 = 0, *Op2 = 0; + if (LHSI->hasOneUse()) { + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { + // Fold the known value into the constant operand. + Op1 = ConstantExpr::getCompare(I.getPredicate(), C, RHSC); + // Insert a new FCmp of the other select operand. 
+            Op2 = InsertNewInstBefore(new FCmpInst(I.getPredicate(),
+                                                   LHSI->getOperand(2), RHSC,
+                                                   I.getName()), I);
+          } else if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) {
+            // Fold the known value into the constant operand.
+            Op2 = ConstantExpr::getCompare(I.getPredicate(), C, RHSC);
+            // Insert a new FCmp of the other select operand.
+            Op1 = InsertNewInstBefore(new FCmpInst(I.getPredicate(),
+                                                   LHSI->getOperand(1), RHSC,
+                                                   I.getName()), I);
+          }
+        }
+
+        if (Op1)
+          return SelectInst::Create(LHSI->getOperand(0), Op1, Op2);
+        break;
+      }
+  }
+
+  return Changed ? &I : 0;
+}
+
+Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
+  bool Changed = SimplifyCompare(I);
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  const Type *Ty = Op0->getType();
+
+  // icmp X, X
+  if (Op0 == Op1)
+    return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty,
+                                                   I.isTrueWhenEqual()));
+
+  if (isa<UndefValue>(Op1))                  // X icmp undef -> undef
+    return ReplaceInstUsesWith(I, UndefValue::get(Type::Int1Ty));
+
+  // icmp <global/alloca*/null>, <global/alloca*/null> - Global/Stack value
+  // addresses never equal each other!  We already know that Op0 != Op1.
+  if ((isa<GlobalValue>(Op0) || isa<AllocaInst>(Op0) ||
+       isa<ConstantPointerNull>(Op0)) &&
+      (isa<GlobalValue>(Op1) || isa<AllocaInst>(Op1) ||
+       isa<ConstantPointerNull>(Op1)))
+    return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty,
+                                                   !I.isTrueWhenEqual()));
+
+  // icmp's with boolean values can always be turned into bitwise operations
+  if (Ty == Type::Int1Ty) {
+    switch (I.getPredicate()) {
+    default: assert(0 && "Invalid icmp instruction!");
+    case ICmpInst::ICMP_EQ: {                // icmp eq i1 A, B -> ~(A^B)
+      Instruction *Xor = BinaryOperator::CreateXor(Op0, Op1, I.getName()+"tmp");
+      InsertNewInstBefore(Xor, I);
+      return BinaryOperator::CreateNot(Xor);
+    }
+    case ICmpInst::ICMP_NE:                  // icmp ne i1 A, B -> A^B
+      return BinaryOperator::CreateXor(Op0, Op1);
+
+    case ICmpInst::ICMP_UGT:
+      std::swap(Op0, Op1);                   // Change icmp ugt -> icmp ult
+      // FALL THROUGH
+    case ICmpInst::ICMP_ULT: {               // icmp ult i1 A, B -> ~A & B
+      Instruction *Not = BinaryOperator::CreateNot(Op0, I.getName()+"tmp");
+      InsertNewInstBefore(Not, I);
+      return BinaryOperator::CreateAnd(Not, Op1);
+    }
+    case ICmpInst::ICMP_SGT:
+      std::swap(Op0, Op1);                   // Change icmp sgt -> icmp slt
+      // FALL THROUGH
+    case ICmpInst::ICMP_SLT: {               // icmp slt i1 A, B -> A & ~B
+      Instruction *Not = BinaryOperator::CreateNot(Op1, I.getName()+"tmp");
+      InsertNewInstBefore(Not, I);
+      return BinaryOperator::CreateAnd(Not, Op0);
+    }
+    case ICmpInst::ICMP_UGE:
+      std::swap(Op0, Op1);                   // Change icmp uge -> icmp ule
+      // FALL THROUGH
+    case ICmpInst::ICMP_ULE: {               // icmp ule i1 A, B -> ~A | B
+      Instruction *Not = BinaryOperator::CreateNot(Op0, I.getName()+"tmp");
+      InsertNewInstBefore(Not, I);
+      return BinaryOperator::CreateOr(Not, Op1);
+    }
+    case ICmpInst::ICMP_SGE:
+      std::swap(Op0, Op1);                   // Change icmp sge -> icmp sle
+      // FALL THROUGH
+    case ICmpInst::ICMP_SLE: {               // icmp sle i1 A, B -> A | ~B
+      Instruction *Not = BinaryOperator::CreateNot(Op1, I.getName()+"tmp");
+      InsertNewInstBefore(Not, I);
+      return BinaryOperator::CreateOr(Not, Op0);
+    }
+    }
+  }
+
+  unsigned BitWidth = 0;
+  if (TD)
+    BitWidth = TD->getTypeSizeInBits(Ty);
+  else if (isa<IntegerType>(Ty))
+    BitWidth = Ty->getPrimitiveSizeInBits();
+
+  bool isSignBit = false;
+
+  // See if we are doing a comparison with a constant.
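+  // For example, icmp slt %x, 0 only tests the sign bit of %x, so the
+  // demanded-bits query below can ask for just that one bit and let the
+  // rest of %x simplify independently.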
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + Value *A = 0, *B = 0; + + // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) + if (I.isEquality() && CI->isNullValue() && + match(Op0, m_Sub(m_Value(A), m_Value(B)))) { + // (icmp cond A B) if cond is equality + return new ICmpInst(I.getPredicate(), A, B); + } + + // If we have an icmp le or icmp ge instruction, turn it into the + // appropriate icmp lt or icmp gt instruction. This allows us to rely on + // them being folded in the code below. + switch (I.getPredicate()) { + default: break; + case ICmpInst::ICMP_ULE: + if (CI->isMaxValue(false)) // A <=u MAX -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + return new ICmpInst(ICmpInst::ICMP_ULT, Op0, AddOne(CI)); + case ICmpInst::ICMP_SLE: + if (CI->isMaxValue(true)) // A <=s MAX -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, AddOne(CI)); + case ICmpInst::ICMP_UGE: + if (CI->isMinValue(false)) // A >=u MIN -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + return new ICmpInst( ICmpInst::ICMP_UGT, Op0, SubOne(CI)); + case ICmpInst::ICMP_SGE: + if (CI->isMinValue(true)) // A >=s MIN -> TRUE + return ReplaceInstUsesWith(I, ConstantInt::getTrue()); + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI)); + } + + // If this comparison is a normal comparison, it demands all + // bits, if it is a sign bit comparison, it only demands the sign bit. + bool UnusedBit; + isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); + } + + // See if we can fold the comparison based on range information we can get + // by checking whether bits are known to be zero or one in the input. + if (BitWidth != 0) { + APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); + APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); + + if (SimplifyDemandedBits(I.getOperandUse(0), + isSignBit ? APInt::getSignBit(BitWidth) + : APInt::getAllOnesValue(BitWidth), + Op0KnownZero, Op0KnownOne, 0)) + return &I; + if (SimplifyDemandedBits(I.getOperandUse(1), + APInt::getAllOnesValue(BitWidth), + Op1KnownZero, Op1KnownOne, 0)) + return &I; + + // Given the known and unknown bits, compute a range that the LHS could be + // in. Compute the Min, Max and RHS values based on the known bits. For the + // EQ and NE we use unsigned values. + APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); + APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); + if (ICmpInst::isSignedPredicate(I.getPredicate())) { + ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, + Op0Min, Op0Max); + ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, + Op1Min, Op1Max); + } else { + ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, + Op0Min, Op0Max); + ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, + Op1Min, Op1Max); + } + + // If Min and Max are known to be the same, then SimplifyDemandedBits + // figured out that the LHS is a constant. Just constant fold this now so + // that code below can assume that Min != Max. + if (!isa<Constant>(Op0) && Op0Min == Op0Max) + return new ICmpInst(I.getPredicate(), ConstantInt::get(Op0Min), Op1); + if (!isa<Constant>(Op1) && Op1Min == Op1Max) + return new ICmpInst(I.getPredicate(), Op0, ConstantInt::get(Op1Min)); + + // Based on the range information we know about the LHS, see if we can + // simplify this comparison. For example, (x&4) < 8 is always true. 
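+    // Spelling out the (x&4) < 8 example: known bits give (x&4) the range
+    // [0, 4], and the constant 8 has the range [8, 8], so Op0Max (4) is ult
+    // Op1Min (8) and the ICMP_ULT case below folds the compare to true.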
+    switch (I.getPredicate()) {
+    default: assert(0 && "Unknown icmp opcode!");
+    case ICmpInst::ICMP_EQ:
+      if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    case ICmpInst::ICMP_NE:
+      if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      break;
+    case ICmpInst::ICMP_ULT:
+      if (Op0Max.ult(Op1Min))          // A <u B -> true if max(A) < min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Min.uge(Op1Max))          // A <u B -> false if min(A) >= max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      if (Op1Min == Op0Max)            // A <u B -> A != B if max(A) == min(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Max == Op0Min+1)        // A <u C -> A == C-1 if min(A)+1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+
+        // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
+        if (CI->isMinValue(true))
+          return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
+                              ConstantInt::getAllOnesValue(Op0->getType()));
+      }
+      break;
+    case ICmpInst::ICMP_UGT:
+      if (Op0Min.ugt(Op1Max))          // A >u B -> true if min(A) > max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Max.ule(Op1Min))          // A >u B -> false if max(A) <= min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+
+      if (Op1Max == Op0Min)            // A >u B -> A != B if min(A) == max(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Min == Op0Max-1)        // A >u C -> A == C+1 if max(A)-1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+
+        // (x >u 2147483647) -> (x <s 0) -> true if sign bit set
+        if (CI->isMaxValue(true))
+          return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
+                              ConstantInt::getNullValue(Op0->getType()));
+      }
+      break;
+    case ICmpInst::ICMP_SLT:
+      if (Op0Max.slt(Op1Min))          // A <s B -> true if max(A) < min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Min.sge(Op1Max))          // A <s B -> false if min(A) >= max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      if (Op1Min == Op0Max)            // A <s B -> A != B if max(A) == min(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Max == Op0Min+1)        // A <s C -> A == C-1 if min(A)+1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+      }
+      break;
+    case ICmpInst::ICMP_SGT:
+      if (Op0Min.sgt(Op1Max))          // A >s B -> true if min(A) > max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Max.sle(Op1Min))          // A >s B -> false if max(A) <= min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+
+      if (Op1Max == Op0Min)            // A >s B -> A != B if min(A) == max(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Min == Op0Max-1)        // A >s C -> A == C+1 if max(A)-1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+      }
+      break;
+    case ICmpInst::ICMP_SGE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
+      if (Op0Min.sge(Op1Max))          // A >=s B -> true if min(A) >= max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Max.slt(Op1Min))          // A >=s B -> false if max(A) < min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    case ICmpInst::ICMP_SLE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
+      if (Op0Max.sle(Op1Min))          // A <=s B -> true if max(A) <= min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Min.sgt(Op1Max))          // A <=s B -> false if min(A) > max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    case ICmpInst::ICMP_UGE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
+      if (Op0Min.uge(Op1Max))          // A >=u B -> true if min(A) >= max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Max.ult(Op1Min))          // A >=u B -> false if max(A) < min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    case ICmpInst::ICMP_ULE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
+      if (Op0Max.ule(Op1Min))          // A <=u B -> true if max(A) <= min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Min.ugt(Op1Max))          // A <=u B -> false if min(A) > max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    }
+
+    // Turn a signed comparison into an unsigned one if both operands
+    // are known to have the same sign.
+    if (I.isSignedPredicate() &&
+        ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
+         (Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
+      return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
+  }
+
+  // Test if the ICmpInst instruction is used exclusively by a select as
+  // part of a minimum or maximum operation. If so, refrain from doing
+  // any other folding. This helps out other analyses which understand
+  // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
+  // and CodeGen. And in this case, at least one of the comparison
+  // operands has at least one user besides the compare (the select),
+  // which would often largely negate the benefit of folding anyway.
+  if (I.hasOneUse())
+    if (SelectInst *SI = dyn_cast<SelectInst>(*I.use_begin()))
+      if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
+          (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
+        return 0;
+
+  // See if we are doing a comparison between a constant and an instruction
+  // that can be folded into the comparison.
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+    // Since the RHS is a ConstantInt (CI), if the left hand side is an
+    // instruction, see if that instruction also has constants so that the
+    // instruction can be folded into the icmp
+    if (Instruction *LHSI = dyn_cast<Instruction>(Op0))
+      if (Instruction *Res = visitICmpInstWithInstAndIntCst(I, LHSI, CI))
+        return Res;
+  }
+
+  // Handle icmp with constant (but not simple integer constant) RHS
+  if (Constant *RHSC = dyn_cast<Constant>(Op1)) {
+    if (Instruction *LHSI = dyn_cast<Instruction>(Op0))
+      switch (LHSI->getOpcode()) {
+      case Instruction::GetElementPtr:
+        if (RHSC->isNullValue()) {
+          // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null
+          bool isAllZeros = true;
+          for (unsigned i = 1, e = LHSI->getNumOperands(); i != e; ++i)
+            if (!isa<Constant>(LHSI->getOperand(i)) ||
+                !cast<Constant>(LHSI->getOperand(i))->isNullValue()) {
+              isAllZeros = false;
+              break;
+            }
+          if (isAllZeros)
+            return new ICmpInst(I.getPredicate(), LHSI->getOperand(0),
+                    Constant::getNullValue(LHSI->getOperand(0)->getType()));
+        }
+        break;
+
+      case Instruction::PHI:
+        // Only fold icmp into the PHI if the phi and icmp are in the same
+        // block.  If in the same block, we're encouraging jump threading.  If
+        // not, we are just pessimizing the code by making an i1 phi.
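+        // For example, with %p = phi i32 [ 0, %bb1 ], [ %v, %bb2 ] in this
+        // block, icmp eq %p, 0 can become a phi of i1 values, exposing the
+        // constant edge from %bb1 to jump threading.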
+ if (LHSI->getParent() == I.getParent()) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + break; + case Instruction::Select: { + // If either operand of the select is a constant, we can fold the + // comparison into the select arms, which will cause one to be + // constant folded and the select turned into a bitwise or. + Value *Op1 = 0, *Op2 = 0; + if (LHSI->hasOneUse()) { + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { + // Fold the known value into the constant operand. + Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + // Insert a new ICmp of the other select operand. + Op2 = InsertNewInstBefore(new ICmpInst(I.getPredicate(), + LHSI->getOperand(2), RHSC, + I.getName()), I); + } else if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { + // Fold the known value into the constant operand. + Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + // Insert a new ICmp of the other select operand. + Op1 = InsertNewInstBefore(new ICmpInst(I.getPredicate(), + LHSI->getOperand(1), RHSC, + I.getName()), I); + } + } + + if (Op1) + return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); + break; + } + case Instruction::Malloc: + // If we have (malloc != null), and if the malloc has a single use, we + // can assume it is successful and remove the malloc. + if (LHSI->hasOneUse() && isa<ConstantPointerNull>(RHSC)) { + AddToWorkList(LHSI); + return ReplaceInstUsesWith(I, ConstantInt::get(Type::Int1Ty, + !I.isTrueWhenEqual())); + } + break; + } + } + + // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now. + if (User *GEP = dyn_castGetElementPtr(Op0)) + if (Instruction *NI = FoldGEPICmp(GEP, Op1, I.getPredicate(), I)) + return NI; + if (User *GEP = dyn_castGetElementPtr(Op1)) + if (Instruction *NI = FoldGEPICmp(GEP, Op0, + ICmpInst::getSwappedPredicate(I.getPredicate()), I)) + return NI; + + // Test to see if the operands of the icmp are casted versions of other + // values. If the ptr->ptr cast can be stripped off both arguments, we do so + // now. + if (BitCastInst *CI = dyn_cast<BitCastInst>(Op0)) { + if (isa<PointerType>(Op0->getType()) && + (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { + // We keep moving the cast from the left operand over to the right + // operand, where it can often be eliminated completely. + Op0 = CI->getOperand(0); + + // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast + // so eliminate it as well. + if (BitCastInst *CI2 = dyn_cast<BitCastInst>(Op1)) + Op1 = CI2->getOperand(0); + + // If Op1 is a constant, we can fold the cast into the constant. + if (Op0->getType() != Op1->getType()) { + if (Constant *Op1C = dyn_cast<Constant>(Op1)) { + Op1 = ConstantExpr::getBitCast(Op1C, Op0->getType()); + } else { + // Otherwise, cast the RHS right before the icmp + Op1 = InsertBitCastBefore(Op1, Op0->getType(), I); + } + } + return new ICmpInst(I.getPredicate(), Op0, Op1); + } + } + + if (isa<CastInst>(Op0)) { + // Handle the special case of: icmp (cast bool to X), <cst> + // This comes up when you have code like + // int X = A < B; + // if (X) ... + // For generality, we handle any zero-extension of any operand comparison + // with a constant or another cast from the same type. + if (isa<ConstantInt>(Op1) || isa<CastInst>(Op1)) + if (Instruction *R = visitICmpInstWithCastAndCast(I)) + return R; + } + + // See if it's the same type of instruction on the left and right. 
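+  // For example, icmp eq (add %a, %x), (add %b, %x) can drop the common
+  // operand and become icmp eq %a, %b, as handled below.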
+ if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { + if (BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1)) { + if (Op0I->getOpcode() == Op1I->getOpcode() && Op0I->hasOneUse() && + Op1I->hasOneUse() && Op0I->getOperand(1) == Op1I->getOperand(1)) { + switch (Op0I->getOpcode()) { + default: break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b + return new ICmpInst(I.getPredicate(), Op0I->getOperand(0), + Op1I->getOperand(0)); + // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { + if (CI->getValue().isSignBit()) { + ICmpInst::Predicate Pred = I.isSignedPredicate() + ? I.getUnsignedPredicate() + : I.getSignedPredicate(); + return new ICmpInst(Pred, Op0I->getOperand(0), + Op1I->getOperand(0)); + } + + if (CI->getValue().isMaxSignedValue()) { + ICmpInst::Predicate Pred = I.isSignedPredicate() + ? I.getUnsignedPredicate() + : I.getSignedPredicate(); + Pred = I.getSwappedPredicate(Pred); + return new ICmpInst(Pred, Op0I->getOperand(0), + Op1I->getOperand(0)); + } + } + break; + case Instruction::Mul: + if (!I.isEquality()) + break; + + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { + // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask + // Mask = -1 >> count-trailing-zeros(Cst). + if (!CI->isZero() && !CI->isOne()) { + const APInt &AP = CI->getValue(); + ConstantInt *Mask = ConstantInt::get( + APInt::getLowBitsSet(AP.getBitWidth(), + AP.getBitWidth() - + AP.countTrailingZeros())); + Instruction *And1 = BinaryOperator::CreateAnd(Op0I->getOperand(0), + Mask); + Instruction *And2 = BinaryOperator::CreateAnd(Op1I->getOperand(0), + Mask); + InsertNewInstBefore(And1, I); + InsertNewInstBefore(And2, I); + return new ICmpInst(I.getPredicate(), And1, And2); + } + } + break; + } + } + } + } + + // ~x < ~y --> y < x + { Value *A, *B; + if (match(Op0, m_Not(m_Value(A))) && + match(Op1, m_Not(m_Value(B)))) + return new ICmpInst(I.getPredicate(), B, A); + } + + if (I.isEquality()) { + Value *A, *B, *C, *D; + + // -x == -y --> x == y + if (match(Op0, m_Neg(m_Value(A))) && + match(Op1, m_Neg(m_Value(B)))) + return new ICmpInst(I.getPredicate(), A, B); + + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 + Value *OtherVal = A == Op1 ? B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); + } + + if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { + // A^c1 == C^c2 --> A == C^(c1^c2) + ConstantInt *C1, *C2; + if (match(B, m_ConstantInt(C1)) && + match(D, m_ConstantInt(C2)) && Op1->hasOneUse()) { + Constant *NC = ConstantInt::get(C1->getValue() ^ C2->getValue()); + Instruction *Xor = BinaryOperator::CreateXor(C, NC, "tmp"); + return new ICmpInst(I.getPredicate(), A, + InsertNewInstBefore(Xor, I)); + } + + // A^B == A^D -> B == D + if (A == C) return new ICmpInst(I.getPredicate(), B, D); + if (A == D) return new ICmpInst(I.getPredicate(), B, C); + if (B == C) return new ICmpInst(I.getPredicate(), A, D); + if (B == D) return new ICmpInst(I.getPredicate(), A, C); + } + } + + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + (A == Op0 || B == Op0)) { + // A == (A^B) -> B == 0 + Value *OtherVal = A == Op0 ? 
B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); + } + + // (A-B) == A -> B == 0 + if (match(Op0, m_Sub(m_Specific(Op1), m_Value(B)))) + return new ICmpInst(I.getPredicate(), B, + Constant::getNullValue(B->getType())); + + // A == (A-B) -> B == 0 + if (match(Op1, m_Sub(m_Specific(Op0), m_Value(B)))) + return new ICmpInst(I.getPredicate(), B, + Constant::getNullValue(B->getType())); + + // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 + if (Op0->hasOneUse() && Op1->hasOneUse() && + match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_And(m_Value(C), m_Value(D)))) { + Value *X = 0, *Y = 0, *Z = 0; + + if (A == C) { + X = B; Y = D; Z = A; + } else if (A == D) { + X = B; Y = C; Z = A; + } else if (B == C) { + X = A; Y = D; Z = B; + } else if (B == D) { + X = A; Y = C; Z = B; + } + + if (X) { // Build (X^Y) & Z + Op1 = InsertNewInstBefore(BinaryOperator::CreateXor(X, Y, "tmp"), I); + Op1 = InsertNewInstBefore(BinaryOperator::CreateAnd(Op1, Z, "tmp"), I); + I.setOperand(0, Op1); + I.setOperand(1, Constant::getNullValue(Op1->getType())); + return &I; + } + } + } + return Changed ? &I : 0; +} + + +/// FoldICmpDivCst - Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS +/// and CmpRHS are both known to be integer constants. +Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, + ConstantInt *DivRHS) { + ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1)); + const APInt &CmpRHSV = CmpRHS->getValue(); + + // FIXME: If the operand types don't match the type of the divide + // then don't attempt this transform. The code below doesn't have the + // logic to deal with a signed divide and an unsigned compare (and + // vice versa). This is because (x /s C1) <s C2 produces different + // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even + // (x /u C1) <u C2. Simply casting the operands and result won't + // work. :( The if statement below tests that condition and bails + // if it finds it. + bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; + if (!ICI.isEquality() && DivIsSigned != ICI.isSignedPredicate()) + return 0; + if (DivRHS->isZero()) + return 0; // The ProdOV computation fails on divide by zero. + if (DivIsSigned && DivRHS->isAllOnesValue()) + return 0; // The overflow computation also screws up here + if (DivRHS->isOne()) + return 0; // Not worth bothering, and eliminates some funny cases + // with INT_MIN. + + // Compute Prod = CI * DivRHS. We are essentially solving an equation + // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and + // C2 (CI). By solving for X we can turn this into a range check + // instead of computing a divide. + ConstantInt *Prod = Multiply(CmpRHS, DivRHS); + + // Determine if the product overflows by seeing if the product is + // not equal to the divide. Make sure we do the same kind of divide + // as in the LHS instruction that we're folding. + bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) : + ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + + // Get the ICmp opcode + ICmpInst::Predicate Pred = ICI.getPredicate(); + + // Figure out the interval that is being checked. For example, a comparison + // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). + // Compute this interval based on the constants involved and the signedness of + // the compare/divide. This computes a half-open interval, keeping track of + // whether either value in the interval overflows. 
After analysis, each
+  // overflow variable is set to 0 if its corresponding bound variable is
+  // valid, -1 if it overflowed off the bottom end, or +1 if it overflowed off
+  // the top end.
+  int LoOverflow = 0, HiOverflow = 0;
+  ConstantInt *LoBound = 0, *HiBound = 0;
+
+  if (!DivIsSigned) {  // udiv
+    // e.g. X/5 op 3  --> [15, 20)
+    LoBound = Prod;
+    HiOverflow = LoOverflow = ProdOV;
+    if (!HiOverflow)
+      HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false);
+  } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0.
+    if (CmpRHSV == 0) {       // (X / pos) op 0
+      // Can't overflow.  e.g.  X/2 op 0 --> [-1, 2)
+      LoBound = cast<ConstantInt>(ConstantExpr::getNeg(SubOne(DivRHS)));
+      HiBound = DivRHS;
+    } else if (CmpRHSV.isStrictlyPositive()) {   // (X / pos) op pos
+      LoBound = Prod;     // e.g.   X/5 op 3 --> [15, 20)
+      HiOverflow = LoOverflow = ProdOV;
+      if (!HiOverflow)
+        HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true);
+    } else {                       // (X / pos) op neg
+      // e.g. X/5 op -3  --> [-15-4, -15+1) --> [-19, -14)
+      HiBound = AddOne(Prod);
+      LoOverflow = HiOverflow = ProdOV ? -1 : 0;
+      if (!LoOverflow) {
+        ConstantInt* DivNeg = cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
+        LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg,
+                                     true) ? -1 : 0;
+      }
+    }
+  } else if (DivRHS->getValue().isNegative()) { // Divisor is < 0.
+    if (CmpRHSV == 0) {       // (X / neg) op 0
+      // e.g. X/-5 op 0  --> [-4, 5)
+      LoBound = AddOne(DivRHS);
+      HiBound = cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
+      if (HiBound == DivRHS) {     // -INTMIN = INTMIN
+        HiOverflow = 1;            // [INTMIN+1, overflow)
+        HiBound = 0;               // e.g. X/INTMIN = 0 --> X > INTMIN
+      }
+    } else if (CmpRHSV.isStrictlyPositive()) {   // (X / neg) op pos
+      // e.g. X/-5 op 3  --> [-19, -14)
+      HiBound = AddOne(Prod);
+      HiOverflow = LoOverflow = ProdOV ? -1 : 0;
+      if (!LoOverflow)
+        LoOverflow = AddWithOverflow(LoBound, HiBound, DivRHS, true) ? -1 : 0;
+    } else {                       // (X / neg) op neg
+      LoBound = Prod;       // e.g. X/-5 op -3  --> [15, 20)
+      LoOverflow = HiOverflow = ProdOV;
+      if (!HiOverflow)
+        HiOverflow = SubWithOverflow(HiBound, Prod, DivRHS, true);
+    }
+
+    // Dividing by a negative swaps the condition.  LT <-> GT
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  Value *X = DivI->getOperand(0);
+  switch (Pred) {
+  default: assert(0 && "Unhandled icmp opcode!");
+  case ICmpInst::ICMP_EQ:
+    if (LoOverflow && HiOverflow)
+      return ReplaceInstUsesWith(ICI, ConstantInt::getFalse());
+    else if (HiOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
+                          ICmpInst::ICMP_UGE, X, LoBound);
+    else if (LoOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
+                          ICmpInst::ICMP_ULT, X, HiBound);
+    else
+      return InsertRangeTest(X, LoBound, HiBound, DivIsSigned, true, ICI);
+  case ICmpInst::ICMP_NE:
+    if (LoOverflow && HiOverflow)
+      return ReplaceInstUsesWith(ICI, ConstantInt::getTrue());
+    else if (HiOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
+                          ICmpInst::ICMP_ULT, X, LoBound);
+    else if (LoOverflow)
+      return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
+                          ICmpInst::ICMP_UGE, X, HiBound);
+    else
+      return InsertRangeTest(X, LoBound, HiBound, DivIsSigned, false, ICI);
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_SLT:
+    if (LoOverflow == +1)   // Low bound is greater than input range.
+      return ReplaceInstUsesWith(ICI, ConstantInt::getTrue());
+    if (LoOverflow == -1)   // Low bound is less than input range.
+ return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + return new ICmpInst(Pred, X, LoBound); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + else if (HiOverflow == -1) // High bound less than input range. + return ReplaceInstUsesWith(ICI, ConstantInt::getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); + else + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + } +} + + +/// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". +/// +Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, + Instruction *LHSI, + ConstantInt *RHS) { + const APInt &RHSV = RHS->getValue(); + + switch (LHSI->getOpcode()) { + case Instruction::Trunc: + if (ICI.isEquality() && LHSI->hasOneUse()) { + // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all + // of the high bits truncated out of x are known. + unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), + SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); + APInt Mask(APInt::getHighBitsSet(SrcBits, SrcBits-DstBits)); + APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); + ComputeMaskedBits(LHSI->getOperand(0), Mask, KnownZero, KnownOne); + + // If all the high bits are known, we can do this xform. + if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { + // Pull in the high bits from known-ones set. + APInt NewRHS(RHS->getValue()); + NewRHS.zext(SrcBits); + NewRHS |= KnownOne; + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + ConstantInt::get(NewRHS)); + } + } + break; + + case Instruction::Xor: // (icmp pred (xor X, XorCST), CI) + if (ConstantInt *XorCST = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { + // If this is a comparison that tests the signbit (X < 0) or (x > -1), + // fold the xor. + if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) || + (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) { + Value *CompareVal = LHSI->getOperand(0); + + // If the sign bit of the XorCST is not set, there is no change to + // the operation, just stop using the Xor. + if (!XorCST->getValue().isNegative()) { + ICI.setOperand(0, CompareVal); + AddToWorkList(LHSI); + return &ICI; + } + + // Was the old condition true if the operand is positive? + bool isTrueIfPositive = ICI.getPredicate() == ICmpInst::ICMP_SGT; + + // If so, the new one isn't. + isTrueIfPositive ^= true; + + if (isTrueIfPositive) + return new ICmpInst(ICmpInst::ICMP_SGT, CompareVal, SubOne(RHS)); + else + return new ICmpInst(ICmpInst::ICMP_SLT, CompareVal, AddOne(RHS)); + } + + if (LHSI->hasOneUse()) { + // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit)) + if (!ICI.isEquality() && XorCST->getValue().isSignBit()) { + const APInt &SignBit = XorCST->getValue(); + ICmpInst::Predicate Pred = ICI.isSignedPredicate() + ? ICI.getUnsignedPredicate() + : ICI.getSignedPredicate(); + return new ICmpInst(Pred, LHSI->getOperand(0), + ConstantInt::get(RHSV ^ SignBit)); + } + + // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) + if (!ICI.isEquality() && XorCST->getValue().isMaxSignedValue()) { + const APInt &NotSignBit = XorCST->getValue(); + ICmpInst::Predicate Pred = ICI.isSignedPredicate() + ? 
ICI.getUnsignedPredicate() + : ICI.getSignedPredicate(); + Pred = ICI.getSwappedPredicate(Pred); + return new ICmpInst(Pred, LHSI->getOperand(0), + ConstantInt::get(RHSV ^ NotSignBit)); + } + } + } + break; + case Instruction::And: // (icmp pred (and X, AndCST), RHS) + if (LHSI->hasOneUse() && isa<ConstantInt>(LHSI->getOperand(1)) && + LHSI->getOperand(0)->hasOneUse()) { + ConstantInt *AndCST = cast<ConstantInt>(LHSI->getOperand(1)); + + // If the LHS is an AND of a truncating cast, we can widen the + // and/compare to be the input width without changing the value + // produced, eliminating a cast. + if (TruncInst *Cast = dyn_cast<TruncInst>(LHSI->getOperand(0))) { + // We can do this transformation if either the AND constant does not + // have its sign bit set or if it is an equality comparison. + // Extending a relational comparison when we're checking the sign + // bit would not work. + if (Cast->hasOneUse() && + (ICI.isEquality() || + (AndCST->getValue().isNonNegative() && RHSV.isNonNegative()))) { + uint32_t BitWidth = + cast<IntegerType>(Cast->getOperand(0)->getType())->getBitWidth(); + APInt NewCST = AndCST->getValue(); + NewCST.zext(BitWidth); + APInt NewCI = RHSV; + NewCI.zext(BitWidth); + Instruction *NewAnd = + BinaryOperator::CreateAnd(Cast->getOperand(0), + ConstantInt::get(NewCST),LHSI->getName()); + InsertNewInstBefore(NewAnd, ICI); + return new ICmpInst(ICI.getPredicate(), NewAnd, + ConstantInt::get(NewCI)); + } + } + + // If this is: (X >> C1) & C2 != C3 (where any shift and any compare + // could exist), turn it into (X & (C2 << C1)) != (C3 << C1). This + // happens a LOT in code produced by the C front-end, for bitfield + // access. + BinaryOperator *Shift = dyn_cast<BinaryOperator>(LHSI->getOperand(0)); + if (Shift && !Shift->isShift()) + Shift = 0; + + ConstantInt *ShAmt; + ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : 0; + const Type *Ty = Shift ? Shift->getType() : 0; // Type of the shift. + const Type *AndTy = AndCST->getType(); // Type of the and. + + // We can fold this as long as we can't shift unknown bits + // into the mask. This can only happen with signed shift + // rights, as they sign-extend. + if (ShAmt) { + bool CanFold = Shift->isLogicalShift(); + if (!CanFold) { + // To test for the bad case of the signed shr, see if any + // of the bits shifted in could be tested after the mask. + uint32_t TyBits = Ty->getPrimitiveSizeInBits(); + int ShAmtVal = TyBits - ShAmt->getLimitedValue(TyBits); + + uint32_t BitWidth = AndTy->getPrimitiveSizeInBits(); + if ((APInt::getHighBitsSet(BitWidth, BitWidth-ShAmtVal) & + AndCST->getValue()) == 0) + CanFold = true; + } + + if (CanFold) { + Constant *NewCst; + if (Shift->getOpcode() == Instruction::Shl) + NewCst = ConstantExpr::getLShr(RHS, ShAmt); + else + NewCst = ConstantExpr::getShl(RHS, ShAmt); + + // Check to see if we are shifting out any of the bits being + // compared. + if (ConstantExpr::get(Shift->getOpcode(), NewCst, ShAmt) != RHS) { + // If we shifted bits out, the fold is not going to work out. + // As a special case, check to see if this means that the + // result is always true or false now. 
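+            // For example, ((X << 3) & 8) == 4 can never hold: shifting the 4
+            // back (4 lshr 3 == 0, 0 shl 3 == 0 != 4) shows the compared bit
+            // is always zero, so eq folds to false and ne folds to true.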
+            if (ICI.getPredicate() == ICmpInst::ICMP_EQ)
+              return ReplaceInstUsesWith(ICI, ConstantInt::getFalse());
+            if (ICI.getPredicate() == ICmpInst::ICMP_NE)
+              return ReplaceInstUsesWith(ICI, ConstantInt::getTrue());
+          } else {
+            ICI.setOperand(1, NewCst);
+            Constant *NewAndCST;
+            if (Shift->getOpcode() == Instruction::Shl)
+              NewAndCST = ConstantExpr::getLShr(AndCST, ShAmt);
+            else
+              NewAndCST = ConstantExpr::getShl(AndCST, ShAmt);
+            LHSI->setOperand(1, NewAndCST);
+            LHSI->setOperand(0, Shift->getOperand(0));
+            AddToWorkList(Shift); // Shift is dead.
+            AddUsesToWorkList(ICI);
+            return &ICI;
+          }
+        }
+      }
+
+      // Turn ((X >> Y) & C) == 0  into  (X & (C << Y)) == 0.  The latter is
+      // preferable because it allows the C<<Y expression to be hoisted out
+      // of a loop if Y is invariant and X is not.
+      if (Shift && Shift->hasOneUse() && RHSV == 0 &&
+          ICI.isEquality() && !Shift->isArithmeticShift() &&
+          !isa<Constant>(Shift->getOperand(0))) {
+        // Compute C << Y.
+        Value *NS;
+        if (Shift->getOpcode() == Instruction::LShr) {
+          NS = BinaryOperator::CreateShl(AndCST,
+                                         Shift->getOperand(1), "tmp");
+        } else {
+          // Insert a logical shift.
+          NS = BinaryOperator::CreateLShr(AndCST,
+                                          Shift->getOperand(1), "tmp");
+        }
+        InsertNewInstBefore(cast<Instruction>(NS), ICI);
+
+        // Compute X & (C << Y).
+        Instruction *NewAnd =
+          BinaryOperator::CreateAnd(Shift->getOperand(0), NS, LHSI->getName());
+        InsertNewInstBefore(NewAnd, ICI);
+
+        ICI.setOperand(0, NewAnd);
+        return &ICI;
+      }
+    }
+    break;
+
+  case Instruction::Shl: {       // (icmp pred (shl X, ShAmt), CI)
+    ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1));
+    if (!ShAmt) break;
+
+    uint32_t TypeBits = RHSV.getBitWidth();
+
+    // Check that the shift amount is in range.  If not, don't perform
+    // undefined shifts.  When the shift is visited it will be
+    // simplified.
+    if (ShAmt->uge(TypeBits))
+      break;
+
+    if (ICI.isEquality()) {
+      // If we are comparing against bits always shifted out, the
+      // comparison cannot succeed.
+      Constant *Comp =
+        ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), ShAmt);
+      if (Comp != RHS) { // Comparing against a bit that we know is zero.
+        bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE;
+        Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE);
+        return ReplaceInstUsesWith(ICI, Cst);
+      }
+
+      if (LHSI->hasOneUse()) {
+        // Otherwise strength reduce the shift into an and.
+        uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
+        Constant *Mask =
+          ConstantInt::get(APInt::getLowBitsSet(TypeBits, TypeBits-ShAmtVal));
+
+        Instruction *AndI =
+          BinaryOperator::CreateAnd(LHSI->getOperand(0),
+                                    Mask, LHSI->getName()+".mask");
+        Value *And = InsertNewInstBefore(AndI, ICI);
+        return new ICmpInst(ICI.getPredicate(), And,
+                            ConstantInt::get(RHSV.lshr(ShAmtVal)));
+      }
+    }
+
+    // Otherwise, if this is a comparison of the sign bit, simplify to and/test.
+    bool TrueIfSigned = false;
+    if (LHSI->hasOneUse() &&
+        isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) {
+      // (X << 31) <s 0  --> (X&1) != 0
+      Constant *Mask = ConstantInt::get(APInt(TypeBits, 1) <<
+                                        (TypeBits-ShAmt->getZExtValue()-1));
+      Instruction *AndI =
+        BinaryOperator::CreateAnd(LHSI->getOperand(0),
+                                  Mask, LHSI->getName()+".mask");
+      Value *And = InsertNewInstBefore(AndI, ICI);
+
+      return new ICmpInst(TrueIfSigned ?
ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(And->getType())); + } + break; + } + + case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) + case Instruction::AShr: { + // Only handle equality comparisons of shift-by-constant. + ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1)); + if (!ShAmt || !ICI.isEquality()) break; + + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. + uint32_t TypeBits = RHSV.getBitWidth(); + if (ShAmt->uge(TypeBits)) + break; + + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + APInt Comp = RHSV << ShAmtVal; + if (LHSI->getOpcode() == Instruction::LShr) + Comp = Comp.lshr(ShAmtVal); + else + Comp = Comp.ashr(ShAmtVal); + + if (Comp != RHSV) { // Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::Int1Ty, IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + // Otherwise, check to see if the bits shifted out are known to be zero. + // If so, we can compare against the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + if (LHSI->hasOneUse() && + MaskedValueIsZero(LHSI->getOperand(0), + APInt::getLowBitsSet(Comp.getBitWidth(), ShAmtVal))) { + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + ConstantExpr::getShl(RHS, ShAmt)); + } + + if (LHSI->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(Val); + + Instruction *AndI = + BinaryOperator::CreateAnd(LHSI->getOperand(0), + Mask, LHSI->getName()+".mask"); + Value *And = InsertNewInstBefore(AndI, ICI); + return new ICmpInst(ICI.getPredicate(), And, + ConstantExpr::getShl(RHS, ShAmt)); + } + break; + } + + case Instruction::SDiv: + case Instruction::UDiv: + // Fold: icmp pred ([us]div X, C1), C2 -> range test + // Fold this div into the comparison, producing a range check. + // Determine, based on the divide type, what the range is being + // checked. If there is an overflow on the low or high side, remember + // it, otherwise compute the range [low, hi) bounding the new value. + // See: InsertRangeTest above for the kinds of replacements possible. 
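+    // For example, icmp eq (udiv i32 %x, 5), 3 holds exactly when %x is in
+    // [15, 20), so it can be rewritten as a single range test of %x against
+    // that interval.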
+ if (ConstantInt *DivRHS = dyn_cast<ConstantInt>(LHSI->getOperand(1))) + if (Instruction *R = FoldICmpDivCst(ICI, cast<BinaryOperator>(LHSI), + DivRHS)) + return R; + break; + + case Instruction::Add: + // Fold: icmp pred (add, X, C1), C2 + + if (!ICI.isEquality()) { + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(1)); + if (!LHSC) break; + const APInt &LHSV = LHSC->getValue(); + + ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), RHSV) + .subtract(LHSV); + + if (ICI.isSignedPredicate()) { + if (CR.getLower().isSignBit()) { + return new ICmpInst(ICmpInst::ICMP_SLT, LHSI->getOperand(0), + ConstantInt::get(CR.getUpper())); + } else if (CR.getUpper().isSignBit()) { + return new ICmpInst(ICmpInst::ICMP_SGE, LHSI->getOperand(0), + ConstantInt::get(CR.getLower())); + } + } else { + if (CR.getLower().isMinValue()) { + return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), + ConstantInt::get(CR.getUpper())); + } else if (CR.getUpper().isMinValue()) { + return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), + ConstantInt::get(CR.getLower())); + } + } + } + break; + } + + // Simplify icmp_eq and icmp_ne instructions with integer constant RHS. + if (ICI.isEquality()) { + bool isICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + + // If the first operand is (add|sub|and|or|xor|rem) with a constant, and + // the second operand is a constant, simplify a bit. + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(LHSI)) { + switch (BO->getOpcode()) { + case Instruction::SRem: + // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. + if (RHSV == 0 && isa<ConstantInt>(BO->getOperand(1)) &&BO->hasOneUse()){ + const APInt &V = cast<ConstantInt>(BO->getOperand(1))->getValue(); + if (V.sgt(APInt(V.getBitWidth(), 1)) && V.isPowerOf2()) { + Instruction *NewRem = + BinaryOperator::CreateURem(BO->getOperand(0), BO->getOperand(1), + BO->getName()); + InsertNewInstBefore(NewRem, ICI); + return new ICmpInst(ICI.getPredicate(), NewRem, + Constant::getNullValue(BO->getType())); + } + } + break; + case Instruction::Add: + // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. + if (ConstantInt *BOp1C = dyn_cast<ConstantInt>(BO->getOperand(1))) { + if (BO->hasOneUse()) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + Subtract(RHS, BOp1C)); + } else if (RHSV == 0) { + // Replace ((add A, B) != 0) with (A != -B) if A or B is + // efficiently invertible, or if the add has just this one use. + Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); + + if (Value *NegVal = dyn_castNegVal(BOp1)) + return new ICmpInst(ICI.getPredicate(), BOp0, NegVal); + else if (Value *NegVal = dyn_castNegVal(BOp0)) + return new ICmpInst(ICI.getPredicate(), NegVal, BOp1); + else if (BO->hasOneUse()) { + Instruction *Neg = BinaryOperator::CreateNeg(BOp1); + InsertNewInstBefore(Neg, ICI); + Neg->takeName(BO); + return new ICmpInst(ICI.getPredicate(), BOp0, Neg); + } + } + break; + case Instruction::Xor: + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. 
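+        // For example, icmp eq (xor %x, 5), 9 becomes icmp eq %x, 12,
+        // since 5 xor 9 == 12.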
+ if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + ConstantExpr::getXor(RHS, BOC)); + + // FALLTHROUGH + case Instruction::Sub: + // Replace (([sub|xor] A, B) != 0) with (A != B) + if (RHSV == 0) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + BO->getOperand(1)); + break; + + case Instruction::Or: + // If bits are being or'd in that are not present in the constant we + // are comparing against, then the comparison could never succeed! + if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { + Constant *NotCI = ConstantExpr::getNot(RHS); + if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) + return ReplaceInstUsesWith(ICI, ConstantInt::get(Type::Int1Ty, + isICMP_NE)); + } + break; + + case Instruction::And: + if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { + // If bits are being compared against that are and'd out, then the + // comparison can never succeed! + if ((RHSV & ~BOC->getValue()) != 0) + return ReplaceInstUsesWith(ICI, ConstantInt::get(Type::Int1Ty, + isICMP_NE)); + + // If we have ((X & C) == C), turn it into ((X & C) != 0). + if (RHS == BOC && RHSV.isPowerOf2()) + return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : + ICmpInst::ICMP_NE, LHSI, + Constant::getNullValue(RHS->getType())); + + // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 + if (BOC->getValue().isSignBit()) { + Value *X = BO->getOperand(0); + Constant *Zero = Constant::getNullValue(X->getType()); + ICmpInst::Predicate pred = isICMP_NE ? + ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(pred, X, Zero); + } + + // ((X & ~7) == 0) --> X < 8 + if (RHSV == 0 && isHighOnes(BOC)) { + Value *X = BO->getOperand(0); + Constant *NegX = ConstantExpr::getNeg(BOC); + ICmpInst::Predicate pred = isICMP_NE ? + ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(pred, X, NegX); + } + } + default: break; + } + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHSI)) { + // Handle icmp {eq|ne} <intrinsic>, intcst. + if (II->getIntrinsicID() == Intrinsic::bswap) { + AddToWorkList(II); + ICI.setOperand(0, II->getOperand(1)); + ICI.setOperand(1, ConstantInt::get(RHSV.byteSwap())); + return &ICI; + } + } + } + return 0; +} + +/// visitICmpInstWithCastAndCast - Handle icmp (cast x to y), (cast/cst). +/// We only handle extending casts so far. +/// +Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { + const CastInst *LHSCI = cast<CastInst>(ICI.getOperand(0)); + Value *LHSCIOp = LHSCI->getOperand(0); + const Type *SrcTy = LHSCIOp->getType(); + const Type *DestTy = LHSCI->getType(); + Value *RHSCIOp; + + // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the + // integer type is the same size as the pointer type. + if (LHSCI->getOpcode() == Instruction::PtrToInt && + getTargetData().getPointerSizeInBits() == + cast<IntegerType>(DestTy)->getBitWidth()) { + Value *RHSOp = 0; + if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) { + RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } else if (PtrToIntInst *RHSC = dyn_cast<PtrToIntInst>(ICI.getOperand(1))) { + RHSOp = RHSC->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (LHSCIOp->getType() != RHSOp->getType()) + RHSOp = InsertBitCastBefore(RHSOp, LHSCIOp->getType(), ICI); + } + + if (RHSOp) + return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSOp); + } + + // The code below only handles extension cast instructions, so far. + // Enforce this. 
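(Backing up briefly: two of the 'and' equality folds above, the isHighOnes rewrite of ((X & ~7) == 0) into (X <u 8) and the power-of-two-mask rewrite of ((X & C) == C) into ((X & C) != 0), can be checked exhaustively over i8 in plain C++.)

  #include <cassert>
  #include <cstdint>
  int main() {
    for (unsigned x = 0; x < 256; ++x) {
      uint8_t v = uint8_t(x);
      assert(((v & 0xF8u) == 0) == (v < 8));       // ~7 in i8 is 0xF8
      assert(((v & 4u) == 4u) == ((v & 4u) != 0)); // C = 4, a power of two
    }
    return 0;
  }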
+ if (LHSCI->getOpcode() != Instruction::ZExt && + LHSCI->getOpcode() != Instruction::SExt) + return 0; + + bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; + bool isSignedCmp = ICI.isSignedPredicate(); + + if (CastInst *CI = dyn_cast<CastInst>(ICI.getOperand(1))) { + // Not an extension from the same type? + RHSCIOp = CI->getOperand(0); + if (RHSCIOp->getType() != LHSCIOp->getType()) + return 0; + + // If the signedness of the two casts doesn't agree (i.e. one is a sext + // and the other is a zext), then we can't handle this. + if (CI->getOpcode() != LHSCI->getOpcode()) + return 0; + + // Deal with equality cases early. + if (ICI.isEquality()) + return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSCIOp); + + // A signed comparison of sign extended values simplifies into a + // signed comparison. + if (isSignedCmp && isSignedExt) + return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSCIOp); + + // The other three cases all fold into an unsigned comparison. + return new ICmpInst(ICI.getUnsignedPredicate(), LHSCIOp, RHSCIOp); + } + + // If we aren't dealing with a constant on the RHS, exit early + ConstantInt *CI = dyn_cast<ConstantInt>(ICI.getOperand(1)); + if (!CI) + return 0; + + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DestTy. + Constant *Res1 = ConstantExpr::getTrunc(CI, SrcTy); + Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy); + + // If the re-extended constant didn't change... + if (Res2 == CI) { + // Make sure that sign of the Cmp and the sign of the Cast are the same. + // For example, we might have: + // %A = sext short %X to uint + // %B = icmp ugt uint %A, 1330 + // It is incorrect to transform this into + // %B = icmp ugt short %X, 1330 + // because %A may have negative value. + // + // However, we allow this when the compare is EQ/NE, because they are + // signless. + if (isSignedExt == isSignedCmp || ICI.isEquality()) + return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1); + return 0; + } + + // The re-extended constant changed so the constant cannot be represented + // in the shorter type. Consequently, we cannot emit a simple comparison. + + // First, handle some easy cases. We know the result cannot be equal at this + // point so handle the ICI.isEquality() cases + if (ICI.getPredicate() == ICmpInst::ICMP_EQ) + return ReplaceInstUsesWith(ICI, ConstantInt::getFalse()); + if (ICI.getPredicate() == ICmpInst::ICMP_NE) + return ReplaceInstUsesWith(ICI, ConstantInt::getTrue()); + + // Evaluate the comparison for LT (we invert for GT below). LE and GE cases + // should have been folded away previously and not enter in here. + Value *Result; + if (isSignedCmp) { + // We're performing a signed comparison. + if (cast<ConstantInt>(CI)->getValue().isNegative()) + Result = ConstantInt::getFalse(); // X < (small) --> false + else + Result = ConstantInt::getTrue(); // X < (large) --> true + } else { + // We're performing an unsigned comparison. + if (isSignedExt) { + // We're performing an unsigned comp with a sign extended value. + // This is true if the input is >= 0. [aka >s -1] + Constant *NegOne = ConstantInt::getAllOnesValue(SrcTy); + Result = InsertNewInstBefore(new ICmpInst(ICmpInst::ICMP_SGT, LHSCIOp, + NegOne, ICI.getName()), ICI); + } else { + // Unsigned extend & unsigned compare -> always true. + Result = ConstantInt::getTrue(); + } + } + + // Finally, return the value computed. 
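(The two cast-pair compare folds above reduce to small host-integer facts: sign extension preserves signed order, and a signed compare of zero-extended values is an unsigned compare of the originals. A brute-force check over all i8 pairs.)

  #include <cassert>
  #include <cstdint>
  int main() {
    for (int x = -128; x < 128; ++x)
      for (int y = -128; y < 128; ++y) {
        int8_t a = int8_t(x), b = int8_t(y);        // sext pair, signed cmp
        assert((int32_t(a) < int32_t(b)) == (a < b));
        uint8_t ua = uint8_t(x), ub = uint8_t(y);   // zext pair, signed cmp
        assert((int32_t(ua) < int32_t(ub)) == (ua < ub));
      }
    return 0;
  }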
+ if (ICI.getPredicate() == ICmpInst::ICMP_ULT || + ICI.getPredicate() == ICmpInst::ICMP_SLT) + return ReplaceInstUsesWith(ICI, Result); + + assert((ICI.getPredicate()==ICmpInst::ICMP_UGT || + ICI.getPredicate()==ICmpInst::ICMP_SGT) && + "ICmp should be folded!"); + if (Constant *CI = dyn_cast<Constant>(Result)) + return ReplaceInstUsesWith(ICI, ConstantExpr::getNot(CI)); + return BinaryOperator::CreateNot(Result); +} + +Instruction *InstCombiner::visitShl(BinaryOperator &I) { + return commonShiftTransforms(I); +} + +Instruction *InstCombiner::visitLShr(BinaryOperator &I) { + return commonShiftTransforms(I); +} + +Instruction *InstCombiner::visitAShr(BinaryOperator &I) { + if (Instruction *R = commonShiftTransforms(I)) + return R; + + Value *Op0 = I.getOperand(0); + + // ashr int -1, X = -1 (for any arithmetic shift rights of ~0) + if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0)) + if (CSI->isAllOnesValue()) + return ReplaceInstUsesWith(I, CSI); + + // See if we can turn a signed shr into an unsigned shr. + if (!isa<VectorType>(I.getType())) { + if (MaskedValueIsZero(Op0, + APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()))) + return BinaryOperator::CreateLShr(Op0, I.getOperand(1)); + + // Arithmetic shifting an all-sign-bit value is a no-op. + unsigned NumSignBits = ComputeNumSignBits(Op0); + if (NumSignBits == Op0->getType()->getPrimitiveSizeInBits()) + return ReplaceInstUsesWith(I, Op0); + } + + return 0; +} + +Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { + assert(I.getOperand(1)->getType() == I.getOperand(0)->getType()); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // shl X, 0 == X and shr X, 0 == X + // shl 0, X == 0 and shr 0, X == 0 + if (Op1 == Constant::getNullValue(Op1->getType()) || + Op0 == Constant::getNullValue(Op0->getType())) + return ReplaceInstUsesWith(I, Op0); + + if (isa<UndefValue>(Op0)) { + if (I.getOpcode() == Instruction::AShr) // undef >>s X -> undef + return ReplaceInstUsesWith(I, Op0); + else // undef << X -> 0, undef >>u X -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + } + if (isa<UndefValue>(Op1)) { + if (I.getOpcode() == Instruction::AShr) // X >>s undef -> X + return ReplaceInstUsesWith(I, Op0); + else // X << undef, X >>u undef -> 0 + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + } + + // See if we can fold away this shift. + if (!isa<VectorType>(I.getType()) && SimplifyDemandedInstructionBits(I)) + return &I; + + // Try to fold constant and into select arguments. + if (isa<Constant>(Op0)) + if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + + if (ConstantInt *CUI = dyn_cast<ConstantInt>(Op1)) + if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) + return Res; + return 0; +} + +Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, ConstantInt *Op1, + BinaryOperator &I) { + bool isLeftShift = I.getOpcode() == Instruction::Shl; + + // See if we can simplify any instructions used by the instruction whose sole + // purpose is to compute bits we don't care about. + uint32_t TypeBits = Op0->getType()->getPrimitiveSizeInBits(); + + // shl uint X, 32 = 0 and shr ubyte Y, 9 = 0, ... just don't eliminate shr + // of a signed value. 
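(The MaskedValueIsZero fold in visitAShr above relies on ashr and lshr agreeing whenever the sign bit is known zero; a standalone spot check at i16.)

  #include <cassert>
  #include <cstdint>
  int main() {
    for (uint32_t x = 0; x < 0x8000u; x += 7) {    // i16 values, sign bit clear
      int16_t s = int16_t(x);
      for (int k = 0; k < 16; ++k)                 // ashr == lshr here
        assert(uint16_t(s >> k) == uint16_t(uint16_t(x) >> k));
    }
    return 0;
  }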
+ // + if (Op1->uge(TypeBits)) { + if (I.getOpcode() != Instruction::AShr) + return ReplaceInstUsesWith(I, Constant::getNullValue(Op0->getType())); + else { + I.setOperand(1, ConstantInt::get(I.getType(), TypeBits-1)); + return &I; + } + } + + // ((X*C1) << C2) == (X * (C1 << C2)) + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0)) + if (BO->getOpcode() == Instruction::Mul && isLeftShift) + if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) + return BinaryOperator::CreateMul(BO->getOperand(0), + ConstantExpr::getShl(BOOp, Op1)); + + // Try to fold constant and into select arguments. + if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI, this)) + return R; + if (isa<PHINode>(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + + // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) + if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) { + Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0)); + // If 'shift2' is an ashr, we would have to get the sign bit into a funny + // place. Don't try to do this transformation in this case. Also, we + // require that the input operand is a shift-by-constant so that we have + // confidence that the shifts will get folded together. We could do this + // xform in more cases, but it is unlikely to be profitable. + if (TrOp && I.isLogicalShift() && TrOp->isShift() && + isa<ConstantInt>(TrOp->getOperand(1))) { + // Okay, we'll do this xform. Make the shift of shift. + Constant *ShAmt = ConstantExpr::getZExt(Op1, TrOp->getType()); + Instruction *NSh = BinaryOperator::Create(I.getOpcode(), TrOp, ShAmt, + I.getName()); + InsertNewInstBefore(NSh, I); // (shift2 (shift1 & 0x00FF), c2) + + // For logical shifts, the truncation has the effect of making the high + // part of the register be zeros. Emulate this by inserting an AND to + // clear the top bits as needed. This 'and' will usually be zapped by + // other xforms later if dead. + unsigned SrcSize = TrOp->getType()->getPrimitiveSizeInBits(); + unsigned DstSize = TI->getType()->getPrimitiveSizeInBits(); + APInt MaskV(APInt::getLowBitsSet(SrcSize, DstSize)); + + // The mask we constructed says what the trunc would do if occurring + // between the shifts. We want to know the effect *after* the second + // shift. We know that it is a logical shift by a constant, so adjust the + // mask as appropriate. + if (I.getOpcode() == Instruction::Shl) + MaskV <<= Op1->getZExtValue(); + else { + assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); + MaskV = MaskV.lshr(Op1->getZExtValue()); + } + + Instruction *And = BinaryOperator::CreateAnd(NSh, ConstantInt::get(MaskV), + TI->getName()); + InsertNewInstBefore(And, I); // shift1 & 0x00FF + + // Return the value truncated to the interesting size. + return new TruncInst(And, I.getType()); + } + } + + if (Op0->hasOneUse()) { + if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) { + // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) + Value *V1, *V2; + ConstantInt *CC; + switch (Op0BO->getOpcode()) { + default: break; + case Instruction::Add: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + // These operators commute. 
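(The ((X*C1) << C2) rewrite near the top of FoldShiftByConstant is plain modular arithmetic, since both sides wrap mod 2^32; a spot check with an arbitrary C1 and C2.)

  #include <cassert>
  #include <cstdint>
  int main() {
    const uint32_t C1 = 0x01020304u, C2 = 5;
    for (uint32_t x = 0; x < 1000000u; x += 97)
      assert(((x * C1) << C2) == (x * (C1 << C2)));
    return 0;
  }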
+ // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C) + if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() && + match(Op0BO->getOperand(1), m_Shr(m_Value(V1), m_Specific(Op1)))){ + Instruction *YS = BinaryOperator::CreateShl( + Op0BO->getOperand(0), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *X = + BinaryOperator::Create(Op0BO->getOpcode(), YS, V1, + Op0BO->getOperand(1)->getName()); + InsertNewInstBefore(X, I); // (X + (Y << C)) + uint32_t Op1Val = Op1->getLimitedValue(TypeBits); + return BinaryOperator::CreateAnd(X, ConstantInt::get( + APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); + } + + // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) + Value *Op0BOOp1 = Op0BO->getOperand(1); + if (isLeftShift && Op0BOOp1->hasOneUse() && + match(Op0BOOp1, + m_And(m_Shr(m_Value(V1), m_Specific(Op1)), + m_ConstantInt(CC))) && + cast<BinaryOperator>(Op0BOOp1)->getOperand(0)->hasOneUse()) { + Instruction *YS = BinaryOperator::CreateShl( + Op0BO->getOperand(0), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *XM = + BinaryOperator::CreateAnd(V1, ConstantExpr::getShl(CC, Op1), + V1->getName()+".mask"); + InsertNewInstBefore(XM, I); // X & (CC << C) + + return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); + } + } + + // FALL THROUGH. + case Instruction::Sub: { + // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) + if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && + match(Op0BO->getOperand(0), m_Shr(m_Value(V1), m_Specific(Op1)))){ + Instruction *YS = BinaryOperator::CreateShl( + Op0BO->getOperand(1), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *X = + BinaryOperator::Create(Op0BO->getOpcode(), V1, YS, + Op0BO->getOperand(0)->getName()); + InsertNewInstBefore(X, I); // (X + (Y << C)) + uint32_t Op1Val = Op1->getLimitedValue(TypeBits); + return BinaryOperator::CreateAnd(X, ConstantInt::get( + APInt::getHighBitsSet(TypeBits, TypeBits-Op1Val))); + } + + // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) + if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && + match(Op0BO->getOperand(0), + m_And(m_Shr(m_Value(V1), m_Value(V2)), + m_ConstantInt(CC))) && V2 == Op1 && + cast<BinaryOperator>(Op0BO->getOperand(0)) + ->getOperand(0)->hasOneUse()) { + Instruction *YS = BinaryOperator::CreateShl( + Op0BO->getOperand(1), Op1, + Op0BO->getName()); + InsertNewInstBefore(YS, I); // (Y << C) + Instruction *XM = + BinaryOperator::CreateAnd(V1, ConstantExpr::getShl(CC, Op1), + V1->getName()+".mask"); + InsertNewInstBefore(XM, I); // X & (CC << C) + + return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); + } + + break; + } + } + + + // If the operand is an bitwise operator with a constant RHS, and the + // shift is the only use, we can pull it out of the shift. + if (ConstantInt *Op0C = dyn_cast<ConstantInt>(Op0BO->getOperand(1))) { + bool isValid = true; // Valid only for And, Or, Xor + bool highBitSet = false; // Transform if high bit of constant set? + + switch (Op0BO->getOpcode()) { + default: isValid = false; break; // Do not perform transform! + case Instruction::Add: + isValid = isLeftShift; + break; + case Instruction::Or: + case Instruction::Xor: + highBitSet = false; + break; + case Instruction::And: + highBitSet = true; + break; + } + + // If this is a signed shift right, and the high bit is modified + // by the logical operation, do not perform the transformation. 
+ // The highBitSet boolean indicates the value of the high bit of + // the constant which would cause it to be modified for this + // operation. + // + if (isValid && I.getOpcode() == Instruction::AShr) + isValid = Op0C->getValue()[TypeBits-1] == highBitSet; + + if (isValid) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), Op0C, Op1); + + Instruction *NewShift = + BinaryOperator::Create(I.getOpcode(), Op0BO->getOperand(0), Op1); + InsertNewInstBefore(NewShift, I); + NewShift->takeName(Op0BO); + + return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, + NewRHS); + } + } + } + } + + // Find out if this is a shift of a shift by a constant. + BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); + if (ShiftOp && !ShiftOp->isShift()) + ShiftOp = 0; + + if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { + ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); + uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); + uint32_t ShiftAmt2 = Op1->getLimitedValue(TypeBits); + assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); + if (ShiftAmt1 == 0) return 0; // Will be simplified in the future. + Value *X = ShiftOp->getOperand(0); + + uint32_t AmtSum = ShiftAmt1+ShiftAmt2; // Fold into one big shift. + + const IntegerType *Ty = cast<IntegerType>(I.getType()); + + // Check for (X << c1) << c2 and (X >> c1) >> c2 + if (I.getOpcode() == ShiftOp->getOpcode()) { + // If this is oversized composite shift, then unsigned shifts get 0, ashr + // saturates. + if (AmtSum >= TypeBits) { + if (I.getOpcode() != Instruction::AShr) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr. + } + + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, AmtSum)); + } else if (ShiftOp->getOpcode() == Instruction::LShr && + I.getOpcode() == Instruction::AShr) { + if (AmtSum >= TypeBits) + return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // ((X >>u C1) >>s C2) -> (X >>u (C1+C2)) since C1 != 0. + return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); + } else if (ShiftOp->getOpcode() == Instruction::AShr && + I.getOpcode() == Instruction::LShr) { + // ((X >>s C1) >>u C2) -> ((X >>s (C1+C2)) & mask) since C1 != 0. + if (AmtSum >= TypeBits) + AmtSum = TypeBits-1; + + Instruction *Shift = + BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd(Shift, ConstantInt::get(Mask)); + } + + // Okay, if we get here, one shift must be left, and the other shift must be + // right. See if the amounts are equal. + if (ShiftAmt1 == ShiftAmt2) { + // If we have ((X >>? C) << C), turn this into X & (-1 << C). + if (I.getOpcode() == Instruction::Shl) { + APInt Mask(APInt::getHighBitsSet(TypeBits, TypeBits - ShiftAmt1)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Mask)); + } + // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). + if (I.getOpcode() == Instruction::LShr) { + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Mask)); + } + // We can simplify ((X << C) >>s C) into a trunc + sext. + // NOTE: we could do this for any C, but that would make 'unusual' integer + // types. For now, just stick to ones well-supported by the code + // generators. 
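(The equal-shift-amount folds just above are pure mask applications; both identities check out for every i32 shift amount.)

  #include <cassert>
  #include <cstdint>
  int main() {
    for (uint32_t x = 0; x < 1000000u; x += 31)
      for (uint32_t c = 0; c < 32; ++c) {
        assert(((x >> c) << c) == (x & (0xFFFFFFFFu << c)));  // (X >>u C) << C
        assert(((x << c) >> c) == (x & (0xFFFFFFFFu >> c)));  // (X << C) >>u C
      }
    return 0;
  }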
+ const Type *SExtType = 0; + switch (Ty->getBitWidth() - ShiftAmt1) { + case 1 : + case 8 : + case 16 : + case 32 : + case 64 : + case 128: + SExtType = IntegerType::get(Ty->getBitWidth() - ShiftAmt1); + break; + default: break; + } + if (SExtType) { + Instruction *NewTrunc = new TruncInst(X, SExtType, "sext"); + InsertNewInstBefore(NewTrunc, I); + return new SExtInst(NewTrunc, Ty); + } + // Otherwise, we can't handle it yet. + } else if (ShiftAmt1 < ShiftAmt2) { + uint32_t ShiftDiff = ShiftAmt2-ShiftAmt1; + + // (X >>? C1) << C2 --> X << (C2-C1) & (-1 << C2) + if (I.getOpcode() == Instruction::Shl) { + assert(ShiftOp->getOpcode() == Instruction::LShr || + ShiftOp->getOpcode() == Instruction::AShr); + Instruction *Shift = + BinaryOperator::CreateShl(X, ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getHighBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd(Shift, ConstantInt::get(Mask)); + } + + // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr) { + assert(ShiftOp->getOpcode() == Instruction::Shl); + Instruction *Shift = + BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd(Shift, ConstantInt::get(Mask)); + } + + // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. + } else { + assert(ShiftAmt2 < ShiftAmt1); + uint32_t ShiftDiff = ShiftAmt1-ShiftAmt2; + + // (X >>? C1) << C2 --> X >>? (C1-C2) & (-1 << C2) + if (I.getOpcode() == Instruction::Shl) { + assert(ShiftOp->getOpcode() == Instruction::LShr || + ShiftOp->getOpcode() == Instruction::AShr); + Instruction *Shift = + BinaryOperator::Create(ShiftOp->getOpcode(), X, + ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getHighBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd(Shift, ConstantInt::get(Mask)); + } + + // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr) { + assert(ShiftOp->getOpcode() == Instruction::Shl); + Instruction *Shift = + BinaryOperator::CreateShl(X, ConstantInt::get(Ty, ShiftDiff)); + InsertNewInstBefore(Shift, I); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd(Shift, ConstantInt::get(Mask)); + } + + // We can't handle (X << C1) >>a C2, it shifts arbitrary bits in. + } + } + return 0; +} + + +/// DecomposeSimpleLinearExpr - Analyze 'Val', seeing if it is a simple linear +/// expression. If so, decompose it, returning some value X, such that Val is +/// X*Scale+Offset. +/// +static Value *DecomposeSimpleLinearExpr(Value *Val, unsigned &Scale, + int &Offset) { + assert(Val->getType() == Type::Int32Ty && "Unexpected allocation size type!"); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Val)) { + Offset = CI->getZExtValue(); + Scale = 0; + return ConstantInt::get(Type::Int32Ty, 0); + } else if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) { + if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) { + if (I->getOpcode() == Instruction::Shl) { + // This is a value scaled by '1 << the shift amt'. + Scale = 1U << RHS->getZExtValue(); + Offset = 0; + return I->getOperand(0); + } else if (I->getOpcode() == Instruction::Mul) { + // This value is scaled by 'RHS'. 
+ Scale = RHS->getZExtValue(); + Offset = 0; + return I->getOperand(0); + } else if (I->getOpcode() == Instruction::Add) { + // We have X+C. Check to see if we really have (X*C2)+C1, + // where C1 is divisible by C2. + unsigned SubScale; + Value *SubVal = + DecomposeSimpleLinearExpr(I->getOperand(0), SubScale, Offset); + Offset += RHS->getZExtValue(); + Scale = SubScale; + return SubVal; + } + } + } + + // Otherwise, we can't look past this. + Scale = 1; + Offset = 0; + return Val; +} + + +/// PromoteCastOfAllocation - If we find a cast of an allocation instruction, +/// try to eliminate the cast by moving the type information into the alloc. +Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, + AllocationInst &AI) { + const PointerType *PTy = cast<PointerType>(CI.getType()); + + // Remove any uses of AI that are dead. + assert(!CI.use_empty() && "Dead instructions should be removed earlier!"); + + for (Value::use_iterator UI = AI.use_begin(), E = AI.use_end(); UI != E; ) { + Instruction *User = cast<Instruction>(*UI++); + if (isInstructionTriviallyDead(User)) { + while (UI != E && *UI == User) + ++UI; // If this instruction uses AI more than once, don't break UI. + + ++NumDeadInst; + DOUT << "IC: DCE: " << *User; + EraseInstFromFunction(*User); + } + } + + // Get the type really allocated and the type casted to. + const Type *AllocElTy = AI.getAllocatedType(); + const Type *CastElTy = PTy->getElementType(); + if (!AllocElTy->isSized() || !CastElTy->isSized()) return 0; + + unsigned AllocElTyAlign = TD->getABITypeAlignment(AllocElTy); + unsigned CastElTyAlign = TD->getABITypeAlignment(CastElTy); + if (CastElTyAlign < AllocElTyAlign) return 0; + + // If the allocation has multiple uses, only promote it if we are strictly + // increasing the alignment of the resultant allocation. If we keep it the + // same, we open the door to infinite loops of various kinds. (A reference + // from a dbg.declare doesn't count as a use for this purpose.) + if (!AI.hasOneUse() && !hasOneUsePlusDeclare(&AI) && + CastElTyAlign == AllocElTyAlign) return 0; + + uint64_t AllocElTySize = TD->getTypeAllocSize(AllocElTy); + uint64_t CastElTySize = TD->getTypeAllocSize(CastElTy); + if (CastElTySize == 0 || AllocElTySize == 0) return 0; + + // See if we can satisfy the modulus by pulling a scale out of the array + // size argument. + unsigned ArraySizeScale; + int ArrayOffset; + Value *NumElements = // See if the array size is a decomposable linear expr. + DecomposeSimpleLinearExpr(AI.getOperand(0), ArraySizeScale, ArrayOffset); + + // If we can now satisfy the modulus, by using a non-1 scale, we really can + // do the xform. 
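(A worked instance of the rescaling the code below performs, with made-up sizes: a hypothetical "alloca i32, i32 (n + 3)" bitcast to i16*. When both modulus checks pass, the rewritten alloca covers exactly the same number of bytes.)

  #include <cassert>
  int main() {
    const unsigned AllocElTySize = 4, CastElTySize = 2;  // i32 viewed as i16
    const unsigned ArraySizeScale = 1;                   // count is 1*n + 3
    const int ArrayOffset = 3;
    assert((AllocElTySize * ArraySizeScale) % CastElTySize == 0);
    assert((AllocElTySize * ArrayOffset) % CastElTySize == 0);
    unsigned Scale = (AllocElTySize * ArraySizeScale) / CastElTySize;  // 2
    int Offset = (AllocElTySize * ArrayOffset) / CastElTySize;         // 6
    for (unsigned n = 0; n < 16; ++n)      // i.e. "alloca i16, i32 (2*n + 6)"
      assert(CastElTySize * (Scale * n + Offset) ==
             AllocElTySize * (ArraySizeScale * n + ArrayOffset));
    return 0;
  }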
+ if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 || + (AllocElTySize*ArrayOffset ) % CastElTySize != 0) return 0; + + unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize; + Value *Amt = 0; + if (Scale == 1) { + Amt = NumElements; + } else { + // If the allocation size is constant, form a constant mul expression + Amt = ConstantInt::get(Type::Int32Ty, Scale); + if (isa<ConstantInt>(NumElements)) + Amt = Multiply(cast<ConstantInt>(NumElements), cast<ConstantInt>(Amt)); + // otherwise multiply the amount and the number of elements + else { + Instruction *Tmp = BinaryOperator::CreateMul(Amt, NumElements, "tmp"); + Amt = InsertNewInstBefore(Tmp, AI); + } + } + + if (int Offset = (AllocElTySize*ArrayOffset)/CastElTySize) { + Value *Off = ConstantInt::get(Type::Int32Ty, Offset, true); + Instruction *Tmp = BinaryOperator::CreateAdd(Amt, Off, "tmp"); + Amt = InsertNewInstBefore(Tmp, AI); + } + + AllocationInst *New; + if (isa<MallocInst>(AI)) + New = new MallocInst(CastElTy, Amt, AI.getAlignment()); + else + New = new AllocaInst(CastElTy, Amt, AI.getAlignment()); + InsertNewInstBefore(New, AI); + New->takeName(&AI); + + // If the allocation has one real use plus a dbg.declare, just remove the + // declare. + if (DbgDeclareInst *DI = hasOneUsePlusDeclare(&AI)) { + EraseInstFromFunction(*DI); + } + // If the allocation has multiple real uses, insert a cast and change all + // things that used it to use the new cast. This will also hack on CI, but it + // will die soon. + else if (!AI.hasOneUse()) { + AddUsesToWorkList(AI); + // New is the allocation instruction, pointer typed. AI is the original + // allocation instruction, also pointer typed. Thus, cast to use is BitCast. + CastInst *NewCast = new BitCastInst(New, AI.getType(), "tmpcast"); + InsertNewInstBefore(NewCast, AI); + AI.replaceAllUsesWith(NewCast); + } + return ReplaceInstUsesWith(CI, New); +} + +/// CanEvaluateInDifferentType - Return true if we can take the specified value +/// and return it as type Ty without inserting any new casts and without +/// changing the computed value. This is used by code that tries to decide +/// whether promoting or shrinking integer operations to wider or smaller types +/// will allow us to eliminate a truncate or extend. +/// +/// This is a truncation operation if Ty is smaller than V->getType(), or an +/// extension operation if Ty is larger. +/// +/// If CastOpc is a truncation, then Ty will be a type smaller than V. We +/// should return true if trunc(V) can be computed by computing V in the smaller +/// type. If V is an instruction, then trunc(inst(x,y)) can be computed as +/// inst(trunc(x),trunc(y)), which only makes sense if x and y can be +/// efficiently truncated. +/// +/// If CastOpc is a sext or zext, we are asking if the low bits of the value can +/// bit computed in a larger type, which is then and'd or sext_in_reg'd to get +/// the final result. +bool InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty, + unsigned CastOpc, + int &NumCastsRemoved){ + // We can always evaluate constants in another type. + if (isa<ConstantInt>(V)) + return true; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I) return false; + + const IntegerType *OrigTy = cast<IntegerType>(V->getType()); + + // If this is an extension or truncate, we can often eliminate it. + if (isa<TruncInst>(I) || isa<ZExtInst>(I) || isa<SExtInst>(I)) { + // If this is a cast from the destination type, we can trivially eliminate + // it, and this will remove a cast overall. 
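(The premise here, that trunc(inst(x, y)) equals inst(trunc(x), trunc(y)) for the arithmetic and bitwise opcodes whitelisted below, is ordinary modular arithmetic; a spot check at i32 to i8.)

  #include <cassert>
  #include <cstdint>
  int main() {
    for (uint32_t x = 0; x < 1000; x += 13)
      for (uint32_t y = 0; y < 1000; y += 17) {
        assert(uint8_t(x + y) == uint8_t(uint8_t(x) + uint8_t(y)));
        assert(uint8_t(x * y) == uint8_t(uint8_t(x) * uint8_t(y)));
        assert(uint8_t(x & y) == uint8_t(uint8_t(x) & uint8_t(y)));
      }
    return 0;
  }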
+ if (I->getOperand(0)->getType() == Ty) { + // If the first operand is itself a cast, and is eliminable, do not count + // this as an eliminable cast. We would prefer to eliminate those two + // casts first. + if (!isa<CastInst>(I->getOperand(0)) && I->hasOneUse()) + ++NumCastsRemoved; + return true; + } + } + + // We can't extend or shrink something that has multiple uses: doing so would + // require duplicating the instruction in general, which isn't profitable. + if (!I->hasOneUse()) return false; + + unsigned Opc = I->getOpcode(); + switch (Opc) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // These operators can all arbitrarily be extended or truncated. + return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc, + NumCastsRemoved) && + CanEvaluateInDifferentType(I->getOperand(1), Ty, CastOpc, + NumCastsRemoved); + + case Instruction::Shl: + // If we are truncating the result of this SHL, and if it's a shift of a + // constant amount, we can always perform a SHL in a smaller type. + if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + uint32_t BitWidth = Ty->getBitWidth(); + if (BitWidth < OrigTy->getBitWidth() && + CI->getLimitedValue(BitWidth) < BitWidth) + return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc, + NumCastsRemoved); + } + break; + case Instruction::LShr: + // If this is a truncate of a logical shr, we can truncate it to a smaller + // lshr iff we know that the bits we would otherwise be shifting in are + // already zeros. + if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + uint32_t OrigBitWidth = OrigTy->getBitWidth(); + uint32_t BitWidth = Ty->getBitWidth(); + if (BitWidth < OrigBitWidth && + MaskedValueIsZero(I->getOperand(0), + APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && + CI->getLimitedValue(BitWidth) < BitWidth) { + return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc, + NumCastsRemoved); + } + } + break; + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::Trunc: + // If this is the same kind of case as our original (e.g. zext+zext), we + // can safely replace it. Note that replacing it does not reduce the number + // of casts in the input. + if (Opc == CastOpc) + return true; + + // sext (zext ty1), ty2 -> zext ty2 + if (CastOpc == Instruction::SExt && Opc == Instruction::ZExt) + return true; + break; + case Instruction::Select: { + SelectInst *SI = cast<SelectInst>(I); + return CanEvaluateInDifferentType(SI->getTrueValue(), Ty, CastOpc, + NumCastsRemoved) && + CanEvaluateInDifferentType(SI->getFalseValue(), Ty, CastOpc, + NumCastsRemoved); + } + case Instruction::PHI: { + // We can change a phi if we can change all operands. + PHINode *PN = cast<PHINode>(I); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (!CanEvaluateInDifferentType(PN->getIncomingValue(i), Ty, CastOpc, + NumCastsRemoved)) + return false; + return true; + } + default: + // TODO: Can handle more cases here. + break; + } + + return false; +} + +/// EvaluateInDifferentType - Given an expression that +/// CanEvaluateInDifferentType returns true for, actually insert the code to +/// evaluate the expression. +Value *InstCombiner::EvaluateInDifferentType(Value *V, const Type *Ty, + bool isSigned) { + if (Constant *C = dyn_cast<Constant>(V)) + return ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); + + // Otherwise, it must be an instruction. 
+ Instruction *I = cast<Instruction>(V); + Instruction *Res = 0; + unsigned Opc = I->getOpcode(); + switch (Opc) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::AShr: + case Instruction::LShr: + case Instruction::Shl: { + Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned); + Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned); + Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS); + break; + } + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + // If the source type of the cast is the type we're trying for then we can + // just return the source. There's no need to insert it because it is not + // new. + if (I->getOperand(0)->getType() == Ty) + return I->getOperand(0); + + // Otherwise, must be the same type of cast, so just reinsert a new one. + Res = CastInst::Create(cast<CastInst>(I)->getOpcode(), I->getOperand(0), + Ty); + break; + case Instruction::Select: { + Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned); + Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned); + Res = SelectInst::Create(I->getOperand(0), True, False); + break; + } + case Instruction::PHI: { + PHINode *OPN = cast<PHINode>(I); + PHINode *NPN = PHINode::Create(Ty); + for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) { + Value *V =EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned); + NPN->addIncoming(V, OPN->getIncomingBlock(i)); + } + Res = NPN; + break; + } + default: + // TODO: Can handle more cases here. + assert(0 && "Unreachable!"); + break; + } + + Res->takeName(I); + return InsertNewInstBefore(Res, *I); +} + +/// @brief Implement the transforms common to all CastInst visitors. +Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { + Value *Src = CI.getOperand(0); + + // Many cases of "cast of a cast" are eliminable. If it's eliminable we just + // eliminate it now. + if (CastInst *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast + if (Instruction::CastOps opc = + isEliminableCastPair(CSrc, CI.getOpcode(), CI.getType(), TD)) { + // The first cast (CSrc) is eliminable so we need to fix up or replace + // the second cast (CI). CSrc will then have a good chance of being dead. + return CastInst::Create(opc, CSrc->getOperand(0), CI.getType()); + } + } + + // If we are casting a select then fold the cast into the select + if (SelectInst *SI = dyn_cast<SelectInst>(Src)) + if (Instruction *NV = FoldOpIntoSelect(CI, SI, this)) + return NV; + + // If we are casting a PHI then fold the cast into the PHI + if (isa<PHINode>(Src)) + if (Instruction *NV = FoldOpIntoPhi(CI)) + return NV; + + return 0; +} + +/// FindElementAtOffset - Given a type and a constant offset, determine whether +/// or not there is a sequence of GEP indices into the type that will land us at +/// the specified offset. If so, fill them into NewIndices and return the +/// resultant element type, otherwise return null. +static const Type *FindElementAtOffset(const Type *Ty, int64_t Offset, + SmallVectorImpl<Value*> &NewIndices, + const TargetData *TD) { + if (!Ty->isSized()) return 0; + + // Start with the index over the outer type. 
Note that the type size + // might be zero (even if the offset isn't zero) if the indexed type + // is something like [0 x {int, int}] + const Type *IntPtrTy = TD->getIntPtrType(); + int64_t FirstIdx = 0; + if (int64_t TySize = TD->getTypeAllocSize(Ty)) { + FirstIdx = Offset/TySize; + Offset -= FirstIdx*TySize; + + // Handle hosts where % returns negative instead of values [0..TySize). + if (Offset < 0) { + --FirstIdx; + Offset += TySize; + assert(Offset >= 0); + } + assert((uint64_t)Offset < (uint64_t)TySize && "Out of range offset"); + } + + NewIndices.push_back(ConstantInt::get(IntPtrTy, FirstIdx)); + + // Index into the types. If we fail, set OrigBase to null. + while (Offset) { + // Indexing into tail padding between struct/array elements. + if (uint64_t(Offset*8) >= TD->getTypeSizeInBits(Ty)) + return 0; + + if (const StructType *STy = dyn_cast<StructType>(Ty)) { + const StructLayout *SL = TD->getStructLayout(STy); + assert(Offset < (int64_t)SL->getSizeInBytes() && + "Offset must stay within the indexed type"); + + unsigned Elt = SL->getElementContainingOffset(Offset); + NewIndices.push_back(ConstantInt::get(Type::Int32Ty, Elt)); + + Offset -= SL->getElementOffset(Elt); + Ty = STy->getElementType(Elt); + } else if (const ArrayType *AT = dyn_cast<ArrayType>(Ty)) { + uint64_t EltSize = TD->getTypeAllocSize(AT->getElementType()); + assert(EltSize && "Cannot index into a zero-sized array"); + NewIndices.push_back(ConstantInt::get(IntPtrTy,Offset/EltSize)); + Offset %= EltSize; + Ty = AT->getElementType(); + } else { + // Otherwise, we can't index into the middle of this atomic type, bail. + return 0; + } + } + + return Ty; +} + +/// @brief Implement the transforms for cast of pointer (bitcast/ptrtoint) +Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { + Value *Src = CI.getOperand(0); + + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) { + // If casting the result of a getelementptr instruction with no offset, turn + // this into a cast of the original pointer! + if (GEP->hasAllZeroIndices()) { + // Changing the cast operand is usually not a good idea but it is safe + // here because the pointer operand is being replaced with another + // pointer operand so the opcode doesn't need to change. + AddToWorkList(GEP); + CI.setOperand(0, GEP->getOperand(0)); + return &CI; + } + + // If the GEP has a single use, and the base pointer is a bitcast, and the + // GEP computes a constant offset, see if we can convert these three + // instructions into fewer. This typically happens with unions and other + // non-type-safe code. + if (GEP->hasOneUse() && isa<BitCastInst>(GEP->getOperand(0))) { + if (GEP->hasAllConstantIndices()) { + // We are guaranteed to get a constant from EmitGEPOffset. + ConstantInt *OffsetV = cast<ConstantInt>(EmitGEPOffset(GEP, CI, *this)); + int64_t Offset = OffsetV->getSExtValue(); + + // Get the base pointer input of the bitcast, and the type it points to. + Value *OrigBase = cast<BitCastInst>(GEP->getOperand(0))->getOperand(0); + const Type *GEPIdxTy = + cast<PointerType>(OrigBase->getType())->getElementType(); + SmallVector<Value*, 8> NewIndices; + if (FindElementAtOffset(GEPIdxTy, Offset, NewIndices, TD)) { + // If we were able to index down into an element, create the GEP + // and bitcast the result. This eliminates one bitcast, potentially + // two. 
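(The index arithmetic in FindElementAtOffset above can be mirrored with a host struct; a worked instance with a hypothetical two-field struct, where the assumed no-padding layout is asserted rather than taken on faith.)

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  struct Pair { int32_t a; int32_t b; };         // hypothetical indexed type
  int main() {
    assert(sizeof(Pair) == 8);                   // assumed layout, no padding
    int64_t Offset = 20;                         // byte offset into a Pair[]
    int64_t FirstIdx = Offset / 8;               // outer index: 2
    Offset -= FirstIdx * 8;                      // 4 bytes remain
    assert(FirstIdx == 2 && Offset == 4);
    assert(offsetof(Pair, b) == 4);              // lands on field 'b'
    Pair p[4];
    assert((char*)&p[2].b - (char*)p == 20);     // i.e. GEP indices {2, 1}
    return 0;
  }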
+          Instruction *NGEP = GetElementPtrInst::Create(OrigBase,
+                                                        NewIndices.begin(),
+                                                        NewIndices.end(), "");
+          InsertNewInstBefore(NGEP, CI);
+          NGEP->takeName(GEP);
+
+          if (isa<BitCastInst>(CI))
+            return new BitCastInst(NGEP, CI.getType());
+          assert(isa<PtrToIntInst>(CI));
+          return new PtrToIntInst(NGEP, CI.getType());
+        }
+      }
+    }
+  }
+
+  return commonCastTransforms(CI);
+}
+
+/// isSafeIntegerType - Return true if this is a basic integer type, not a crazy
+/// type like i42. We don't want to introduce operations on random non-legal
+/// integer types where they don't already exist in the code. In the future,
+/// we should consider making this based on target-data, so that 32-bit targets
+/// won't get i64 operations etc.
+static bool isSafeIntegerType(const Type *Ty) {
+  switch (Ty->getPrimitiveSizeInBits()) {
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Only the TRUNC, ZEXT, SEXT, and BITCAST casts can have both their operand
+/// and result as integer types. This function implements the common transforms
+/// for all those cases.
+/// @brief Implement the transforms common to CastInst with integer operands
+Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
+  if (Instruction *Result = commonCastTransforms(CI))
+    return Result;
+
+  Value *Src = CI.getOperand(0);
+  const Type *SrcTy = Src->getType();
+  const Type *DestTy = CI.getType();
+  uint32_t SrcBitSize = SrcTy->getPrimitiveSizeInBits();
+  uint32_t DestBitSize = DestTy->getPrimitiveSizeInBits();
+
+  // See if we can simplify any instructions used by the LHS whose sole
+  // purpose is to compute bits we don't care about.
+  if (SimplifyDemandedInstructionBits(CI))
+    return &CI;
+
+  // If the source isn't an instruction or has more than one use then we
+  // can't do anything more.
+  Instruction *SrcI = dyn_cast<Instruction>(Src);
+  if (!SrcI || !Src->hasOneUse())
+    return 0;
+
+  // Attempt to propagate the cast into the instruction for int->int casts.
+  int NumCastsRemoved = 0;
+  if (!isa<BitCastInst>(CI) &&
+      // Only do this if the dest type is a simple type, don't convert the
+      // expression tree to something weird like i93 unless the source is also
+      // strange.
+      (isSafeIntegerType(DestTy) || !isSafeIntegerType(SrcI->getType())) &&
+      CanEvaluateInDifferentType(SrcI, cast<IntegerType>(DestTy),
+                                 CI.getOpcode(), NumCastsRemoved)) {
+    // If this cast is a truncate, evaluating in a different type always
+    // eliminates the cast, so it is always a win. If this is a zero-extension,
+    // we need to do an AND to maintain the clear top-part of the computation,
+    // so we require that the input have eliminated at least one cast. If this
+    // is a sign extension, we insert two new casts (to do the extension) so we
+    // require that two casts have been eliminated.
+    bool DoXForm = false;
+    bool JustReplace = false;
+    switch (CI.getOpcode()) {
+    default:
+      // All the others use floating point so we shouldn't actually
+      // get here because of the check above.
+      assert(0 && "Unknown cast type");
+    case Instruction::Trunc:
+      DoXForm = true;
+      break;
+    case Instruction::ZExt: {
+      DoXForm = NumCastsRemoved >= 1;
+      if (!DoXForm && 0) {
+        // If it's unnecessary to issue an AND to clear the high bits, it's
+        // always profitable to do this xform.
+ Value *TryRes = EvaluateInDifferentType(SrcI, DestTy, false); + APInt Mask(APInt::getBitsSet(DestBitSize, SrcBitSize, DestBitSize)); + if (MaskedValueIsZero(TryRes, Mask)) + return ReplaceInstUsesWith(CI, TryRes); + + if (Instruction *TryI = dyn_cast<Instruction>(TryRes)) + if (TryI->use_empty()) + EraseInstFromFunction(*TryI); + } + break; + } + case Instruction::SExt: { + DoXForm = NumCastsRemoved >= 2; + if (!DoXForm && !isa<TruncInst>(SrcI) && 0) { + // If we do not have to emit the truncate + sext pair, then it's always + // profitable to do this xform. + // + // It's not safe to eliminate the trunc + sext pair if one of the + // eliminated cast is a truncate. e.g. + // t2 = trunc i32 t1 to i16 + // t3 = sext i16 t2 to i32 + // != + // i32 t1 + Value *TryRes = EvaluateInDifferentType(SrcI, DestTy, true); + unsigned NumSignBits = ComputeNumSignBits(TryRes); + if (NumSignBits > (DestBitSize - SrcBitSize)) + return ReplaceInstUsesWith(CI, TryRes); + + if (Instruction *TryI = dyn_cast<Instruction>(TryRes)) + if (TryI->use_empty()) + EraseInstFromFunction(*TryI); + } + break; + } + } + + if (DoXForm) { + DOUT << "ICE: EvaluateInDifferentType converting expression type to avoid" + << " cast: " << CI; + Value *Res = EvaluateInDifferentType(SrcI, DestTy, + CI.getOpcode() == Instruction::SExt); + if (JustReplace) + // Just replace this cast with the result. + return ReplaceInstUsesWith(CI, Res); + + assert(Res->getType() == DestTy); + switch (CI.getOpcode()) { + default: assert(0 && "Unknown cast type!"); + case Instruction::Trunc: + case Instruction::BitCast: + // Just replace this cast with the result. + return ReplaceInstUsesWith(CI, Res); + case Instruction::ZExt: { + assert(SrcBitSize < DestBitSize && "Not a zext?"); + + // If the high bits are already zero, just replace this cast with the + // result. + APInt Mask(APInt::getBitsSet(DestBitSize, SrcBitSize, DestBitSize)); + if (MaskedValueIsZero(Res, Mask)) + return ReplaceInstUsesWith(CI, Res); + + // We need to emit an AND to clear the high bits. + Constant *C = ConstantInt::get(APInt::getLowBitsSet(DestBitSize, + SrcBitSize)); + return BinaryOperator::CreateAnd(Res, C); + } + case Instruction::SExt: { + // If the high bits are already filled with sign bit, just replace this + // cast with the result. + unsigned NumSignBits = ComputeNumSignBits(Res); + if (NumSignBits > (DestBitSize - SrcBitSize)) + return ReplaceInstUsesWith(CI, Res); + + // We need to emit a cast to truncate, then a cast to sext. + return CastInst::Create(Instruction::SExt, + InsertCastBefore(Instruction::Trunc, Res, Src->getType(), + CI), DestTy); + } + } + } + } + + Value *Op0 = SrcI->getNumOperands() > 0 ? SrcI->getOperand(0) : 0; + Value *Op1 = SrcI->getNumOperands() > 1 ? SrcI->getOperand(1) : 0; + + switch (SrcI->getOpcode()) { + case Instruction::Add: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // If we are discarding information, rewrite. + if (DestBitSize <= SrcBitSize && DestBitSize != 1) { + // Don't insert two casts if they cannot be eliminated. We allow + // two casts to be inserted if the sizes are the same. This could + // only be converting signedness, which is a noop. 
+ if (DestBitSize == SrcBitSize || + !ValueRequiresCast(CI.getOpcode(), Op1, DestTy,TD) || + !ValueRequiresCast(CI.getOpcode(), Op0, DestTy, TD)) { + Instruction::CastOps opcode = CI.getOpcode(); + Value *Op0c = InsertCastBefore(opcode, Op0, DestTy, *SrcI); + Value *Op1c = InsertCastBefore(opcode, Op1, DestTy, *SrcI); + return BinaryOperator::Create( + cast<BinaryOperator>(SrcI)->getOpcode(), Op0c, Op1c); + } + } + + // cast (xor bool X, true) to int --> xor (cast bool X to int), 1 + if (isa<ZExtInst>(CI) && SrcBitSize == 1 && + SrcI->getOpcode() == Instruction::Xor && + Op1 == ConstantInt::getTrue() && + (!Op0->hasOneUse() || !isa<CmpInst>(Op0))) { + Value *New = InsertCastBefore(Instruction::ZExt, Op0, DestTy, CI); + return BinaryOperator::CreateXor(New, ConstantInt::get(CI.getType(), 1)); + } + break; + case Instruction::SDiv: + case Instruction::UDiv: + case Instruction::SRem: + case Instruction::URem: + // If we are just changing the sign, rewrite. + if (DestBitSize == SrcBitSize) { + // Don't insert two casts if they cannot be eliminated. We allow + // two casts to be inserted if the sizes are the same. This could + // only be converting signedness, which is a noop. + if (!ValueRequiresCast(CI.getOpcode(), Op1, DestTy, TD) || + !ValueRequiresCast(CI.getOpcode(), Op0, DestTy, TD)) { + Value *Op0c = InsertCastBefore(Instruction::BitCast, + Op0, DestTy, *SrcI); + Value *Op1c = InsertCastBefore(Instruction::BitCast, + Op1, DestTy, *SrcI); + return BinaryOperator::Create( + cast<BinaryOperator>(SrcI)->getOpcode(), Op0c, Op1c); + } + } + break; + + case Instruction::Shl: + // Allow changing the sign of the source operand. Do not allow + // changing the size of the shift, UNLESS the shift amount is a + // constant. We must not change variable sized shifts to a smaller + // size, because it is undefined to shift more bits out than exist + // in the value. + if (DestBitSize == SrcBitSize || + (DestBitSize < SrcBitSize && isa<Constant>(Op1))) { + Instruction::CastOps opcode = (DestBitSize == SrcBitSize ? + Instruction::BitCast : Instruction::Trunc); + Value *Op0c = InsertCastBefore(opcode, Op0, DestTy, *SrcI); + Value *Op1c = InsertCastBefore(opcode, Op1, DestTy, *SrcI); + return BinaryOperator::CreateShl(Op0c, Op1c); + } + break; + case Instruction::AShr: + // If this is a signed shr, and if all bits shifted in are about to be + // truncated off, turn it into an unsigned shr to allow greater + // simplifications. + if (DestBitSize < SrcBitSize && + isa<ConstantInt>(Op1)) { + uint32_t ShiftAmt = cast<ConstantInt>(Op1)->getLimitedValue(SrcBitSize); + if (SrcBitSize > ShiftAmt && SrcBitSize-ShiftAmt >= DestBitSize) { + // Insert the new logical shift right. + return BinaryOperator::CreateLShr(Op0, Op1); + } + } + break; + } + return 0; +} + +Instruction *InstCombiner::visitTrunc(TruncInst &CI) { + if (Instruction *Result = commonIntCastTransforms(CI)) + return Result; + + Value *Src = CI.getOperand(0); + const Type *Ty = CI.getType(); + uint32_t DestBitWidth = Ty->getPrimitiveSizeInBits(); + uint32_t SrcBitWidth = cast<IntegerType>(Src->getType())->getBitWidth(); + + // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0) + if (DestBitWidth == 1) { + Constant *One = ConstantInt::get(Src->getType(), 1); + Src = InsertNewInstBefore(BinaryOperator::CreateAnd(Src, One, "tmp"), CI); + Value *Zero = Constant::getNullValue(Src->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, Src, Zero); + } + + // Optimize trunc(lshr(), c) to pull the shift through the truncate. 
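(The fold matched just below is legal only under the MaskedValueIsZero precondition, that the bits the narrowed shift would pull in are already zero; a standalone check at i32 to i8 with a shift of 3.)

  #include <cassert>
  #include <cstdint>
  int main() {
    const uint32_t ShAmt = 3;
    const uint32_t Mask = ((1u << ShAmt) - 1) << 8;  // bits shifted into low 8
    for (uint32_t x = 0; x < 256; ++x) {             // high bits already zero
      assert((x & Mask) == 0);
      assert(uint8_t(x >> ShAmt) == uint8_t(uint8_t(x) >> ShAmt));
    }
    return 0;
  }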
+ ConstantInt *ShAmtV = 0; + Value *ShiftOp = 0; + if (Src->hasOneUse() && + match(Src, m_LShr(m_Value(ShiftOp), m_ConstantInt(ShAmtV)))) { + uint32_t ShAmt = ShAmtV->getLimitedValue(SrcBitWidth); + + // Get a mask for the bits shifting in. + APInt Mask(APInt::getLowBitsSet(SrcBitWidth, ShAmt).shl(DestBitWidth)); + if (MaskedValueIsZero(ShiftOp, Mask)) { + if (ShAmt >= DestBitWidth) // All zeros. + return ReplaceInstUsesWith(CI, Constant::getNullValue(Ty)); + + // Okay, we can shrink this. Truncate the input, then return a new + // shift. + Value *V1 = InsertCastBefore(Instruction::Trunc, ShiftOp, Ty, CI); + Value *V2 = ConstantExpr::getTrunc(ShAmtV, Ty); + return BinaryOperator::CreateLShr(V1, V2); + } + } + + return 0; +} + +/// transformZExtICmp - Transform (zext icmp) to bitwise / integer operations +/// in order to eliminate the icmp. +Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, + bool DoXform) { + // If we are just checking for a icmp eq of a single bit and zext'ing it + // to an integer, then shift the bit to the appropriate place and then + // cast to integer to avoid the comparison. + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(ICI->getOperand(1))) { + const APInt &Op1CV = Op1C->getValue(); + + // zext (x <s 0) to i32 --> x>>u31 true if signbit set. + // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. + if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || + (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())) { + if (!DoXform) return ICI; + + Value *In = ICI->getOperand(0); + Value *Sh = ConstantInt::get(In->getType(), + In->getType()->getPrimitiveSizeInBits()-1); + In = InsertNewInstBefore(BinaryOperator::CreateLShr(In, Sh, + In->getName()+".lobit"), + CI); + if (In->getType() != CI.getType()) + In = CastInst::CreateIntegerCast(In, CI.getType(), + false/*ZExt*/, "tmp", &CI); + + if (ICI->getPredicate() == ICmpInst::ICMP_SGT) { + Constant *One = ConstantInt::get(In->getType(), 1); + In = InsertNewInstBefore(BinaryOperator::CreateXor(In, One, + In->getName()+".not"), + CI); + } + + return ReplaceInstUsesWith(CI, In); + } + + + + // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. + // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. + // zext (X == 1) to i32 --> X iff X has only the low bit set. + // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 0) to i32 --> X iff X has only the low bit set. + // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. + // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. + // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. + if ((Op1CV == 0 || Op1CV.isPowerOf2()) && + // This only works for EQ and NE + ICI->isEquality()) { + // If Op1C some other power of two, convert: + uint32_t BitWidth = Op1C->getType()->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + APInt TypeMask(APInt::getAllOnesValue(BitWidth)); + ComputeMaskedBits(ICI->getOperand(0), TypeMask, KnownZero, KnownOne); + + APInt KnownZeroMask(~KnownZero); + if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? 
+ if (!DoXform) return ICI; + + bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; + if (Op1CV != 0 && (Op1CV != KnownZeroMask)) { + // (X&4) == 2 --> false + // (X&4) != 2 --> true + Constant *Res = ConstantInt::get(Type::Int1Ty, isNE); + Res = ConstantExpr::getZExt(Res, CI.getType()); + return ReplaceInstUsesWith(CI, Res); + } + + uint32_t ShiftAmt = KnownZeroMask.logBase2(); + Value *In = ICI->getOperand(0); + if (ShiftAmt) { + // Perform a logical shr by shiftamt. + // Insert the shift to put the result in the low bit. + In = InsertNewInstBefore(BinaryOperator::CreateLShr(In, + ConstantInt::get(In->getType(), ShiftAmt), + In->getName()+".lobit"), CI); + } + + if ((Op1CV != 0) == isNE) { // Toggle the low bit. + Constant *One = ConstantInt::get(In->getType(), 1); + In = BinaryOperator::CreateXor(In, One, "tmp"); + InsertNewInstBefore(cast<Instruction>(In), CI); + } + + if (CI.getType() == In->getType()) + return ReplaceInstUsesWith(CI, In); + else + return CastInst::CreateIntegerCast(In, CI.getType(), false/*ZExt*/); + } + } + } + + return 0; +} + +Instruction *InstCombiner::visitZExt(ZExtInst &CI) { + // If one of the common conversion will work .. + if (Instruction *Result = commonIntCastTransforms(CI)) + return Result; + + Value *Src = CI.getOperand(0); + + // If this is a TRUNC followed by a ZEXT then we are dealing with integral + // types and if the sizes are just right we can convert this into a logical + // 'and' which will be much cheaper than the pair of casts. + if (TruncInst *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast + // Get the sizes of the types involved. We know that the intermediate type + // will be smaller than A or C, but don't know the relation between A and C. + Value *A = CSrc->getOperand(0); + unsigned SrcSize = A->getType()->getPrimitiveSizeInBits(); + unsigned MidSize = CSrc->getType()->getPrimitiveSizeInBits(); + unsigned DstSize = CI.getType()->getPrimitiveSizeInBits(); + // If we're actually extending zero bits, then if + // SrcSize < DstSize: zext(a & mask) + // SrcSize == DstSize: a & mask + // SrcSize > DstSize: trunc(a) & mask + if (SrcSize < DstSize) { + APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); + Constant *AndConst = ConstantInt::get(AndValue); + Instruction *And = + BinaryOperator::CreateAnd(A, AndConst, CSrc->getName()+".mask"); + InsertNewInstBefore(And, CI); + return new ZExtInst(And, CI.getType()); + } else if (SrcSize == DstSize) { + APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); + return BinaryOperator::CreateAnd(A, ConstantInt::get(AndValue)); + } else if (SrcSize > DstSize) { + Instruction *Trunc = new TruncInst(A, CI.getType(), "tmp"); + InsertNewInstBefore(Trunc, CI); + APInt AndValue(APInt::getLowBitsSet(DstSize, MidSize)); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(AndValue)); + } + } + + if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) + return transformZExtICmp(ICI, CI); + + BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src); + if (SrcI && SrcI->getOpcode() == Instruction::Or) { + // zext (or icmp, icmp) --> or (zext icmp), (zext icmp) if at least one + // of the (zext icmp) will be transformed. 
+ ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0)); + ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1)); + if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() && + (transformZExtICmp(LHS, CI, false) || + transformZExtICmp(RHS, CI, false))) { + Value *LCast = InsertCastBefore(Instruction::ZExt, LHS, CI.getType(), CI); + Value *RCast = InsertCastBefore(Instruction::ZExt, RHS, CI.getType(), CI); + return BinaryOperator::Create(Instruction::Or, LCast, RCast); + } + } + + return 0; +} + +Instruction *InstCombiner::visitSExt(SExtInst &CI) { + if (Instruction *I = commonIntCastTransforms(CI)) + return I; + + Value *Src = CI.getOperand(0); + + // Canonicalize sign-extend from i1 to a select. + if (Src->getType() == Type::Int1Ty) + return SelectInst::Create(Src, + ConstantInt::getAllOnesValue(CI.getType()), + Constant::getNullValue(CI.getType())); + + // See if the value being truncated is already sign extended. If so, just + // eliminate the trunc/sext pair. + if (getOpcode(Src) == Instruction::Trunc) { + Value *Op = cast<User>(Src)->getOperand(0); + unsigned OpBits = cast<IntegerType>(Op->getType())->getBitWidth(); + unsigned MidBits = cast<IntegerType>(Src->getType())->getBitWidth(); + unsigned DestBits = cast<IntegerType>(CI.getType())->getBitWidth(); + unsigned NumSignBits = ComputeNumSignBits(Op); + + if (OpBits == DestBits) { + // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign + // bits, it is already ready. + if (NumSignBits > DestBits-MidBits) + return ReplaceInstUsesWith(CI, Op); + } else if (OpBits < DestBits) { + // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign + // bits, just sext from i32. + if (NumSignBits > OpBits-MidBits) + return new SExtInst(Op, CI.getType(), "tmp"); + } else { + // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign + // bits, just truncate to i32. + if (NumSignBits > OpBits-MidBits) + return new TruncInst(Op, CI.getType(), "tmp"); + } + } + + // If the input is a shl/ashr pair of a same constant, then this is a sign + // extension from a smaller value. If we could trust arbitrary bitwidth + // integers, we could turn this into a truncate to the smaller bit and then + // use a sext for the whole extension. Since we don't, look deeper and check + // for a truncate. If the source and dest are the same type, eliminate the + // trunc and extend and just do shifts. For example, turn: + // %a = trunc i32 %i to i8 + // %b = shl i8 %a, 6 + // %c = ashr i8 %b, 6 + // %d = sext i8 %c to i32 + // into: + // %a = shl i32 %i, 30 + // %d = ashr i32 %a, 30 + Value *A = 0; + ConstantInt *BA = 0, *CA = 0; + if (match(Src, m_AShr(m_Shl(m_Value(A), m_ConstantInt(BA)), + m_ConstantInt(CA))) && + BA == CA && isa<TruncInst>(A)) { + Value *I = cast<TruncInst>(A)->getOperand(0); + if (I->getType() == CI.getType()) { + unsigned MidSize = Src->getType()->getPrimitiveSizeInBits(); + unsigned SrcDstSize = CI.getType()->getPrimitiveSizeInBits(); + unsigned ShAmt = CA->getZExtValue()+SrcDstSize-MidSize; + Constant *ShAmtV = ConstantInt::get(CI.getType(), ShAmt); + I = InsertNewInstBefore(BinaryOperator::CreateShl(I, ShAmtV, + CI.getName()), CI); + return BinaryOperator::CreateAShr(I, ShAmtV); + } + } + + return 0; +} + +/// FitsInFPType - Return a Constant* for the specified FP constant if it fits +/// in the specified FP type without changing its value. 
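(The "fits without changing its value" test defined below amounts to a convert-and-convert-back round trip; the same check expressed with host floating point.)

  #include <cassert>
  int main() {
    double exact = 2.0, inexact = 0.1;
    assert((double)(float)exact == exact);       // fits in float unchanged
    assert((double)(float)inexact != inexact);   // narrowing perturbs 0.1
    return 0;
  }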
+static Constant *FitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { + bool losesInfo; + APFloat F = CFP->getValueAPF(); + (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo); + if (!losesInfo) + return ConstantFP::get(F); + return 0; +} + +/// LookThroughFPExtensions - If this is an fp extension instruction, look +/// through it until we get the source value. +static Value *LookThroughFPExtensions(Value *V) { + if (Instruction *I = dyn_cast<Instruction>(V)) + if (I->getOpcode() == Instruction::FPExt) + return LookThroughFPExtensions(I->getOperand(0)); + + // If this value is a constant, return the constant in the smallest FP type + // that can accurately represent it. This allows us to turn + // (float)((double)X+2.0) into x+2.0f. + if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { + if (CFP->getType() == Type::PPC_FP128Ty) + return V; // No constant folding of this. + // See if the value can be truncated to float and then reextended. + if (Value *V = FitsInFPType(CFP, APFloat::IEEEsingle)) + return V; + if (CFP->getType() == Type::DoubleTy) + return V; // Won't shrink. + if (Value *V = FitsInFPType(CFP, APFloat::IEEEdouble)) + return V; + // Don't try to shrink to various long double types. + } + + return V; +} + +Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { + if (Instruction *I = commonCastTransforms(CI)) + return I; + + // If we have fptrunc(add (fpextend x), (fpextend y)), where x and y are + // smaller than the destination type, we can eliminate the truncate by doing + // the add as the smaller type. This applies to add/sub/mul/div as well as + // many builtins (sqrt, etc). + BinaryOperator *OpI = dyn_cast<BinaryOperator>(CI.getOperand(0)); + if (OpI && OpI->hasOneUse()) { + switch (OpI->getOpcode()) { + default: break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::FDiv: + case Instruction::FRem: + const Type *SrcTy = OpI->getType(); + Value *LHSTrunc = LookThroughFPExtensions(OpI->getOperand(0)); + Value *RHSTrunc = LookThroughFPExtensions(OpI->getOperand(1)); + if (LHSTrunc->getType() != SrcTy && + RHSTrunc->getType() != SrcTy) { + unsigned DstSize = CI.getType()->getPrimitiveSizeInBits(); + // If the source types were both smaller than the destination type of + // the cast, do this xform. + if (LHSTrunc->getType()->getPrimitiveSizeInBits() <= DstSize && + RHSTrunc->getType()->getPrimitiveSizeInBits() <= DstSize) { + LHSTrunc = InsertCastBefore(Instruction::FPExt, LHSTrunc, + CI.getType(), CI); + RHSTrunc = InsertCastBefore(Instruction::FPExt, RHSTrunc, + CI.getType(), CI); + return BinaryOperator::Create(OpI->getOpcode(), LHSTrunc, RHSTrunc); + } + } + break; + } + } + return 0; +} + +Instruction *InstCombiner::visitFPExt(CastInst &CI) { + return commonCastTransforms(CI); +} + +Instruction *InstCombiner::visitFPToUI(FPToUIInst &FI) { + Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0)); + if (OpI == 0) + return commonCastTransforms(FI); + + // fptoui(uitofp(X)) --> X + // fptoui(sitofp(X)) --> X + // This is safe if the intermediate type has enough bits in its mantissa to + // accurately represent all values of X. For example, do not do this with + // i64->float->i64. This is also safe for sitofp case, because any negative + // 'X' value would cause an undefined result for the fptoui. 
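+  // For example, float has a 24-bit mantissa, so fptoui(uitofp(i16 X)) is
+  // exact, while fptoui(uitofp(i32 X)) is not: i32 values above 2^24 may be
+  // rounded when converted to float.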
+  if ((isa<UIToFPInst>(OpI) || isa<SIToFPInst>(OpI)) &&
+      OpI->getOperand(0)->getType() == FI.getType() &&
+      (int)FI.getType()->getPrimitiveSizeInBits() < /*extra bit for sign */
+                    OpI->getType()->getFPMantissaWidth())
+    return ReplaceInstUsesWith(FI, OpI->getOperand(0));
+
+  return commonCastTransforms(FI);
+}
+
+Instruction *InstCombiner::visitFPToSI(FPToSIInst &FI) {
+  Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0));
+  if (OpI == 0)
+    return commonCastTransforms(FI);
+
+  // fptosi(sitofp(X)) --> X
+  // fptosi(uitofp(X)) --> X
+  // This is safe if the intermediate type has enough bits in its mantissa to
+  // accurately represent all values of X.  For example, do not do this with
+  // i64->float->i64.  This is also safe for the uitofp case, because any 'X'
+  // with the sign bit set would make the result of the fptosi undefined
+  // anyway.
+  if ((isa<UIToFPInst>(OpI) || isa<SIToFPInst>(OpI)) &&
+      OpI->getOperand(0)->getType() == FI.getType() &&
+      (int)FI.getType()->getPrimitiveSizeInBits() <=
+                    OpI->getType()->getFPMantissaWidth())
+    return ReplaceInstUsesWith(FI, OpI->getOperand(0));
+
+  return commonCastTransforms(FI);
+}
+
+Instruction *InstCombiner::visitUIToFP(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitSIToFP(CastInst &CI) {
+  return commonCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
+  // If the destination integer type is smaller than the intptr_t type for
+  // this target, do a ptrtoint to intptr_t then do a trunc.  This allows the
+  // trunc to be exposed to other transforms.  Don't do this for extending
+  // ptrtoint's, because we don't know if the target sign or zero extends its
+  // pointers.
+  if (CI.getType()->getPrimitiveSizeInBits() < TD->getPointerSizeInBits()) {
+    Value *P = InsertNewInstBefore(new PtrToIntInst(CI.getOperand(0),
+                                                    TD->getIntPtrType(),
+                                                    "tmp"), CI);
+    return new TruncInst(P, CI.getType());
+  }
+
+  return commonPointerCastTransforms(CI);
+}
+
+Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) {
+  // If the source integer type is larger than the intptr_t type for
+  // this target, do a trunc to the intptr_t type, then inttoptr of it.  This
+  // allows the trunc to be exposed to other transforms.  Don't do this for
+  // extending inttoptr's, because we don't know if the target sign or zero
+  // extends integers to pointers.
+  if (CI.getOperand(0)->getType()->getPrimitiveSizeInBits() >
+      TD->getPointerSizeInBits()) {
+    Value *P = InsertNewInstBefore(new TruncInst(CI.getOperand(0),
+                                                 TD->getIntPtrType(),
+                                                 "tmp"), CI);
+    return new IntToPtrInst(P, CI.getType());
+  }
+
+  if (Instruction *I = commonCastTransforms(CI))
+    return I;
+
+  const Type *DestPointee = cast<PointerType>(CI.getType())->getElementType();
+  if (!DestPointee->isSized()) return 0;
+
+  // If this is inttoptr(add (ptrtoint x), cst), try to turn this into a GEP.
+  ConstantInt *Cst;
+  Value *X;
+  if (match(CI.getOperand(0), m_Add(m_Cast<PtrToIntInst>(m_Value(X)),
+                                    m_ConstantInt(Cst)))) {
+    // If the source and destination operands have the same type, see if this
+    // is a single-index GEP.
+    if (X->getType() == CI.getType()) {
+      // Get the size of the pointee type.
+      uint64_t Size = TD->getTypeAllocSize(DestPointee);
+
+      // Convert the constant to intptr type.
+      APInt Offset = Cst->getValue();
+      Offset.sextOrTrunc(TD->getPointerSizeInBits());
+
+      // If Offset is evenly divisible by Size, we can do this xform.
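+      // For example, with 32-bit pointers and i32* %p (so Size == 4):
+      //   %s = ptrtoint i32* %p to i32
+      //   %a = add i32 %s, 8
+      //   %q = inttoptr i32 %a to i32*
+      // becomes:
+      //   %q = getelementptr i32* %p, i32 2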
+ if (Size && !APIntOps::srem(Offset, APInt(Offset.getBitWidth(), Size))){ + Offset = APIntOps::sdiv(Offset, APInt(Offset.getBitWidth(), Size)); + return GetElementPtrInst::Create(X, ConstantInt::get(Offset)); + } + } + // TODO: Could handle other cases, e.g. where add is indexing into field of + // struct etc. + } else if (CI.getOperand(0)->hasOneUse() && + match(CI.getOperand(0), m_Add(m_Value(X), m_ConstantInt(Cst)))) { + // Otherwise, if this is inttoptr(add x, cst), try to turn this into an + // "inttoptr+GEP" instead of "add+intptr". + + // Get the size of the pointee type. + uint64_t Size = TD->getTypeAllocSize(DestPointee); + + // Convert the constant to intptr type. + APInt Offset = Cst->getValue(); + Offset.sextOrTrunc(TD->getPointerSizeInBits()); + + // If Offset is evenly divisible by Size, we can do this xform. + if (Size && !APIntOps::srem(Offset, APInt(Offset.getBitWidth(), Size))){ + Offset = APIntOps::sdiv(Offset, APInt(Offset.getBitWidth(), Size)); + + Instruction *P = InsertNewInstBefore(new IntToPtrInst(X, CI.getType(), + "tmp"), CI); + return GetElementPtrInst::Create(P, ConstantInt::get(Offset), "tmp"); + } + } + return 0; +} + +Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { + // If the operands are integer typed then apply the integer transforms, + // otherwise just apply the common ones. + Value *Src = CI.getOperand(0); + const Type *SrcTy = Src->getType(); + const Type *DestTy = CI.getType(); + + if (SrcTy->isInteger() && DestTy->isInteger()) { + if (Instruction *Result = commonIntCastTransforms(CI)) + return Result; + } else if (isa<PointerType>(SrcTy)) { + if (Instruction *I = commonPointerCastTransforms(CI)) + return I; + } else { + if (Instruction *Result = commonCastTransforms(CI)) + return Result; + } + + + // Get rid of casts from one type to the same type. These are useless and can + // be replaced by the operand. + if (DestTy == Src->getType()) + return ReplaceInstUsesWith(CI, Src); + + if (const PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) { + const PointerType *SrcPTy = cast<PointerType>(SrcTy); + const Type *DstElTy = DstPTy->getElementType(); + const Type *SrcElTy = SrcPTy->getElementType(); + + // If the address spaces don't match, don't eliminate the bitcast, which is + // required for changing types. + if (SrcPTy->getAddressSpace() != DstPTy->getAddressSpace()) + return 0; + + // If we are casting a malloc or alloca to a pointer to a type of the same + // size, rewrite the allocation instruction to allocate the "right" type. + if (AllocationInst *AI = dyn_cast<AllocationInst>(Src)) + if (Instruction *V = PromoteCastOfAllocation(CI, *AI)) + return V; + + // If the source and destination are pointers, and this cast is equivalent + // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. + // This can enhance SROA and other transforms that want type-safe pointers. + Constant *ZeroUInt = Constant::getNullValue(Type::Int32Ty); + unsigned NumZeros = 0; + while (SrcElTy != DstElTy && + isa<CompositeType>(SrcElTy) && !isa<PointerType>(SrcElTy) && + SrcElTy->getNumContainedTypes() /* not "{}" */) { + SrcElTy = cast<CompositeType>(SrcElTy)->getTypeAtIndex(ZeroUInt); + ++NumZeros; + } + + // If we found a path from the src to dest, create the getelementptr now. 
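+    // For example:
+    //   %q = bitcast { i32, float }* %p to i32*
+    // becomes:
+    //   %q = getelementptr { i32, float }* %p, i32 0, i32 0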
+ if (SrcElTy == DstElTy) { + SmallVector<Value*, 8> Idxs(NumZeros+1, ZeroUInt); + return GetElementPtrInst::Create(Src, Idxs.begin(), Idxs.end(), "", + ((Instruction*) NULL)); + } + } + + if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(Src)) { + if (SVI->hasOneUse()) { + // Okay, we have (bitconvert (shuffle ..)). Check to see if this is + // a bitconvert to a vector with the same # elts. + if (isa<VectorType>(DestTy) && + cast<VectorType>(DestTy)->getNumElements() == + SVI->getType()->getNumElements() && + SVI->getType()->getNumElements() == + cast<VectorType>(SVI->getOperand(0)->getType())->getNumElements()) { + CastInst *Tmp; + // If either of the operands is a cast from CI.getType(), then + // evaluating the shuffle in the casted destination's type will allow + // us to eliminate at least one cast. + if (((Tmp = dyn_cast<CastInst>(SVI->getOperand(0))) && + Tmp->getOperand(0)->getType() == DestTy) || + ((Tmp = dyn_cast<CastInst>(SVI->getOperand(1))) && + Tmp->getOperand(0)->getType() == DestTy)) { + Value *LHS = InsertCastBefore(Instruction::BitCast, + SVI->getOperand(0), DestTy, CI); + Value *RHS = InsertCastBefore(Instruction::BitCast, + SVI->getOperand(1), DestTy, CI); + // Return a new shuffle vector. Use the same element ID's, as we + // know the vector types match #elts. + return new ShuffleVectorInst(LHS, RHS, SVI->getOperand(2)); + } + } + } + } + return 0; +} + +/// GetSelectFoldableOperands - We want to turn code that looks like this: +/// %C = or %A, %B +/// %D = select %cond, %C, %A +/// into: +/// %C = select %cond, %B, 0 +/// %D = or %A, %C +/// +/// Assuming that the specified instruction is an operand to the select, return +/// a bitmask indicating which operands of this instruction are foldable if they +/// equal the other incoming value of the select. +/// +static unsigned GetSelectFoldableOperands(Instruction *I) { + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + return 3; // Can fold through either operand. + case Instruction::Sub: // Can only fold on the amount subtracted. + case Instruction::Shl: // Can only fold on the shift amount. + case Instruction::LShr: + case Instruction::AShr: + return 1; + default: + return 0; // Cannot fold + } +} + +/// GetSelectFoldableConstant - For the same transformation as the previous +/// function, return the identity constant that goes into the select. +static Constant *GetSelectFoldableConstant(Instruction *I) { + switch (I->getOpcode()) { + default: assert(0 && "This cannot happen!"); abort(); + case Instruction::Add: + case Instruction::Sub: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + return Constant::getNullValue(I->getType()); + case Instruction::And: + return Constant::getAllOnesValue(I->getType()); + case Instruction::Mul: + return ConstantInt::get(I->getType(), 1); + } +} + +/// FoldSelectOpOp - Here we have (select c, TI, FI), and we know that TI and FI +/// have the same opcode and only one use each. Try to simplify this. +Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, + Instruction *FI) { + if (TI->getNumOperands() == 1) { + // If this is a non-volatile load or a cast from the same type, + // merge. + if (TI->isCast()) { + if (TI->getOperand(0)->getType() != FI->getOperand(0)->getType()) + return 0; + } else { + return 0; // unknown unary op. 
+ } + + // Fold this by inserting a select from the input values. + SelectInst *NewSI = SelectInst::Create(SI.getCondition(), TI->getOperand(0), + FI->getOperand(0), SI.getName()+".v"); + InsertNewInstBefore(NewSI, SI); + return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI, + TI->getType()); + } + + // Only handle binary operators here. + if (!isa<BinaryOperator>(TI)) + return 0; + + // Figure out if the operations have any operands in common. + Value *MatchOp, *OtherOpT, *OtherOpF; + bool MatchIsOpZero; + if (TI->getOperand(0) == FI->getOperand(0)) { + MatchOp = TI->getOperand(0); + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = true; + } else if (TI->getOperand(1) == FI->getOperand(1)) { + MatchOp = TI->getOperand(1); + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = false; + } else if (!TI->isCommutative()) { + return 0; + } else if (TI->getOperand(0) == FI->getOperand(1)) { + MatchOp = TI->getOperand(0); + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = true; + } else if (TI->getOperand(1) == FI->getOperand(0)) { + MatchOp = TI->getOperand(1); + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = true; + } else { + return 0; + } + + // If we reach here, they do have operations in common. + SelectInst *NewSI = SelectInst::Create(SI.getCondition(), OtherOpT, + OtherOpF, SI.getName()+".v"); + InsertNewInstBefore(NewSI, SI); + + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TI)) { + if (MatchIsOpZero) + return BinaryOperator::Create(BO->getOpcode(), MatchOp, NewSI); + else + return BinaryOperator::Create(BO->getOpcode(), NewSI, MatchOp); + } + assert(0 && "Shouldn't get here"); + return 0; +} + +static bool isSelect01(Constant *C1, Constant *C2) { + ConstantInt *C1I = dyn_cast<ConstantInt>(C1); + if (!C1I) + return false; + ConstantInt *C2I = dyn_cast<ConstantInt>(C2); + if (!C2I) + return false; + return (C1I->isZero() || C1I->isOne()) && (C2I->isZero() || C2I->isOne()); +} + +/// FoldSelectIntoOp - Try fold the select into one of the operands to +/// facilitate further optimization. +Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, + Value *FalseVal) { + // See the comment above GetSelectFoldableOperands for a description of the + // transformation we are doing here. + if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) { + if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && + !isa<Constant>(FalseVal)) { + if (unsigned SFO = GetSelectFoldableOperands(TVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { + OpToFold = 1; + } else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) { + OpToFold = 2; + } + + if (OpToFold) { + Constant *C = GetSelectFoldableConstant(TVI); + Value *OOp = TVI->getOperand(2-OpToFold); + // Avoid creating select between 2 constants unless it's selecting + // between 0 and 1. 
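+          // For example:
+          //   %a = add i32 %x, 4
+          //   %s = select i1 %c, i32 %a, i32 %x
+          // becomes:
+          //   %t = select i1 %c, i32 4, i32 0
+          //   %s = add i32 %x, %t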
+ if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + Instruction *NewSel = SelectInst::Create(SI.getCondition(), OOp, C); + InsertNewInstBefore(NewSel, SI); + NewSel->takeName(TVI); + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TVI)) + return BinaryOperator::Create(BO->getOpcode(), FalseVal, NewSel); + assert(0 && "Unknown instruction!!"); + } + } + } + } + } + + if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) { + if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && + !isa<Constant>(TrueVal)) { + if (unsigned SFO = GetSelectFoldableOperands(FVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { + OpToFold = 1; + } else if ((SFO & 2) && TrueVal == FVI->getOperand(1)) { + OpToFold = 2; + } + + if (OpToFold) { + Constant *C = GetSelectFoldableConstant(FVI); + Value *OOp = FVI->getOperand(2-OpToFold); + // Avoid creating select between 2 constants unless it's selecting + // between 0 and 1. + if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + Instruction *NewSel = SelectInst::Create(SI.getCondition(), C, OOp); + InsertNewInstBefore(NewSel, SI); + NewSel->takeName(FVI); + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FVI)) + return BinaryOperator::Create(BO->getOpcode(), TrueVal, NewSel); + assert(0 && "Unknown instruction!!"); + } + } + } + } + } + + return 0; +} + +/// visitSelectInstWithICmp - Visit a SelectInst that has an +/// ICmpInst as its first operand. +/// +Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, + ICmpInst *ICI) { + bool Changed = false; + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + // Check cases where the comparison is with a constant that + // can be adjusted to fit the min/max idiom. We may edit ICI in + // place here, so make sure the select is the only user. + if (ICI->hasOneUse()) + if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) { + switch (Pred) { + default: break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: { + // X < MIN ? T : F --> F + if (CI->isMinValue(Pred == ICmpInst::ICMP_SLT)) + return ReplaceInstUsesWith(SI, FalseVal); + // X < C ? X : C-1 --> X > C-1 ? C-1 : X + Constant *AdjustedRHS = SubOne(CI); + if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + ICI->setPredicate(Pred); + ICI->setOperand(1, CmpRHS); + SI.setOperand(1, TrueVal); + SI.setOperand(2, FalseVal); + Changed = true; + } + break; + } + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: { + // X > MAX ? T : F --> F + if (CI->isMaxValue(Pred == ICmpInst::ICMP_SGT)) + return ReplaceInstUsesWith(SI, FalseVal); + // X > C ? X : C+1 --> X < C+1 ? C+1 : X + Constant *AdjustedRHS = AddOne(CI); + if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + ICI->setPredicate(Pred); + ICI->setOperand(1, CmpRHS); + SI.setOperand(1, TrueVal); + SI.setOperand(2, FalseVal); + Changed = true; + } + break; + } + } + + // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if signed + // (x >s -1) ? 
-1 : 0 -> ashr x, 31 -> all ones if not signed
+      CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+      if (match(TrueVal, m_ConstantInt<-1>()) &&
+          match(FalseVal, m_ConstantInt<0>()))
+        Pred = ICI->getPredicate();
+      else if (match(TrueVal, m_ConstantInt<0>()) &&
+               match(FalseVal, m_ConstantInt<-1>()))
+        Pred = CmpInst::getInversePredicate(ICI->getPredicate());
+
+      if (Pred != CmpInst::BAD_ICMP_PREDICATE) {
+        // The select is really a sign-bit test: it produces all-ones or zero
+        // depending only on the sign of the compared value, so it can be
+        // lowered to an arithmetic shift right (plus a 'not' for the inverted
+        // predicate), avoiding the comparison entirely.
+        const APInt &Op1CV = CI->getValue();
+
+        // sext (x <s  0) to i32 --> x>>s31       true if signbit set.
+        // sext (x >s -1) to i32 --> (x>>s31)^-1  true if signbit clear.
+        if ((Pred == ICmpInst::ICMP_SLT && Op1CV == 0) ||
+            (Pred == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) {
+          Value *In = ICI->getOperand(0);
+          Value *Sh = ConstantInt::get(In->getType(),
+                                 In->getType()->getPrimitiveSizeInBits()-1);
+          In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh,
+                                                   In->getName()+".lobit"),
+                                   *ICI);
+          if (In->getType() != SI.getType())
+            In = CastInst::CreateIntegerCast(In, SI.getType(),
+                                             true/*SExt*/, "tmp", ICI);
+
+          if (Pred == ICmpInst::ICMP_SGT)
+            In = InsertNewInstBefore(BinaryOperator::CreateNot(In,
+                                       In->getName()+".not"), *ICI);
+
+          return ReplaceInstUsesWith(SI, In);
+        }
+      }
+    }
+
+  if (CmpLHS == TrueVal && CmpRHS == FalseVal) {
+    // Transform (X == Y) ? X : Y  -> Y
+    if (Pred == ICmpInst::ICMP_EQ)
+      return ReplaceInstUsesWith(SI, FalseVal);
+    // Transform (X != Y) ? X : Y  -> X
+    if (Pred == ICmpInst::ICMP_NE)
+      return ReplaceInstUsesWith(SI, TrueVal);
+    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
+
+  } else if (CmpLHS == FalseVal && CmpRHS == TrueVal) {
+    // Transform (X == Y) ? Y : X  -> X
+    if (Pred == ICmpInst::ICMP_EQ)
+      return ReplaceInstUsesWith(SI, FalseVal);
+    // Transform (X != Y) ? Y : X  -> Y
+    if (Pred == ICmpInst::ICMP_NE)
+      return ReplaceInstUsesWith(SI, TrueVal);
+    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
+  }
+
+  /// NOTE: if we wanted to, this is where to detect integer ABS
+
+  return Changed ? &SI : 0;
+}
+
+Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  // select true, X, Y  -> X
+  // select false, X, Y -> Y
+  if (ConstantInt *C = dyn_cast<ConstantInt>(CondVal))
+    return ReplaceInstUsesWith(SI, C->getZExtValue() ?
TrueVal : FalseVal);
+
+  // select C, X, X -> X
+  if (TrueVal == FalseVal)
+    return ReplaceInstUsesWith(SI, TrueVal);
+
+  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X
+    return ReplaceInstUsesWith(SI, FalseVal);
+  if (isa<UndefValue>(FalseVal))  // select C, X, undef -> X
+    return ReplaceInstUsesWith(SI, TrueVal);
+  if (isa<UndefValue>(CondVal)) { // select undef, X, Y -> X or Y
+    if (isa<Constant>(TrueVal))
+      return ReplaceInstUsesWith(SI, TrueVal);
+    else
+      return ReplaceInstUsesWith(SI, FalseVal);
+  }
+
+  if (SI.getType() == Type::Int1Ty) {
+    if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) {
+      if (C->getZExtValue()) {
+        // Change: A = select B, true, C --> A = or B, C
+        return BinaryOperator::CreateOr(CondVal, FalseVal);
+      } else {
+        // Change: A = select B, false, C --> A = and !B, C
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
+                                             "not."+CondVal->getName()), SI);
+        return BinaryOperator::CreateAnd(NotCond, FalseVal);
+      }
+    } else if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
+      if (C->getZExtValue() == false) {
+        // Change: A = select B, C, false --> A = and B, C
+        return BinaryOperator::CreateAnd(CondVal, TrueVal);
+      } else {
+        // Change: A = select B, C, true --> A = or !B, C
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
+                                             "not."+CondVal->getName()), SI);
+        return BinaryOperator::CreateOr(NotCond, TrueVal);
+      }
+    }
+
+    // select a, b, a  -> a&b
+    // select a, a, b  -> a|b
+    if (CondVal == TrueVal)
+      return BinaryOperator::CreateOr(CondVal, FalseVal);
+    else if (CondVal == FalseVal)
+      return BinaryOperator::CreateAnd(CondVal, TrueVal);
+  }
+
+  // Selecting between two integer constants?
+  if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal))
+    if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) {
+      // select C, 1, 0 -> zext C to int
+      if (FalseValC->isZero() && TrueValC->getValue() == 1) {
+        return CastInst::Create(Instruction::ZExt, CondVal, SI.getType());
+      } else if (TrueValC->isZero() && FalseValC->getValue() == 1) {
+        // select C, 0, 1 -> zext !C to int
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
+                                               "not."+CondVal->getName()), SI);
+        return CastInst::Create(Instruction::ZExt, NotCond, SI.getType());
+      }
+
+      if (ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition())) {
+
+        // (x <s 0) ? -1 : 0 -> ashr x, 31
+        if (TrueValC->isAllOnesValue() && FalseValC->isZero())
+          if (ConstantInt *CmpCst = dyn_cast<ConstantInt>(IC->getOperand(1))) {
+            if (IC->getPredicate() == ICmpInst::ICMP_SLT && CmpCst->isZero()) {
+              // The comparison constant and the result are not necessarily the
+              // same width.  Make an all-ones value by inserting an AShr.
+              Value *X = IC->getOperand(0);
+              uint32_t Bits = X->getType()->getPrimitiveSizeInBits();
+              Constant *ShAmt = ConstantInt::get(X->getType(), Bits-1);
+              Instruction *SRA = BinaryOperator::Create(Instruction::AShr, X,
+                                                        ShAmt, "ones");
+              InsertNewInstBefore(SRA, SI);
+
+              // Then cast to the appropriate width.
+              return CastInst::CreateIntegerCast(SRA, SI.getType(), true);
+            }
+          }
+
+
+        // If one of the constants is zero (we know they can't both be) and we
+        // have an icmp instruction with zero, and we have an 'and' with the
+        // non-constant value, eliminate this whole mess.  This corresponds to
+        // cases like this: ((X & 27) ?
27 : 0) + if (TrueValC->isZero() || FalseValC->isZero()) + if (IC->isEquality() && isa<ConstantInt>(IC->getOperand(1)) && + cast<Constant>(IC->getOperand(1))->isNullValue()) + if (Instruction *ICA = dyn_cast<Instruction>(IC->getOperand(0))) + if (ICA->getOpcode() == Instruction::And && + isa<ConstantInt>(ICA->getOperand(1)) && + (ICA->getOperand(1) == TrueValC || + ICA->getOperand(1) == FalseValC) && + isOneBitSet(cast<ConstantInt>(ICA->getOperand(1)))) { + // Okay, now we know that everything is set up, we just don't + // know whether we have a icmp_ne or icmp_eq and whether the + // true or false val is the zero. + bool ShouldNotVal = !TrueValC->isZero(); + ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE; + Value *V = ICA; + if (ShouldNotVal) + V = InsertNewInstBefore(BinaryOperator::Create( + Instruction::Xor, V, ICA->getOperand(1)), SI); + return ReplaceInstUsesWith(SI, V); + } + } + } + + // See if we are selecting two values based on a comparison of the two values. + if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { + if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { + // Transform (X == Y) ? X : Y -> Y + if (FCI->getPredicate() == FCmpInst::FCMP_OEQ) { + // This is not safe in general for floating point: + // consider X== -0, Y== +0. + // It becomes safe if either operand is a nonzero constant. + ConstantFP *CFPt, *CFPf; + if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && + !CFPt->getValueAPF().isZero()) || + ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && + !CFPf->getValueAPF().isZero())) + return ReplaceInstUsesWith(SI, FalseVal); + } + // Transform (X != Y) ? X : Y -> X + if (FCI->getPredicate() == FCmpInst::FCMP_ONE) + return ReplaceInstUsesWith(SI, TrueVal); + // NOTE: if we wanted to, this is where to detect MIN/MAX + + } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ + // Transform (X == Y) ? Y : X -> X + if (FCI->getPredicate() == FCmpInst::FCMP_OEQ) { + // This is not safe in general for floating point: + // consider X== -0, Y== +0. + // It becomes safe if either operand is a nonzero constant. + ConstantFP *CFPt, *CFPf; + if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && + !CFPt->getValueAPF().isZero()) || + ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && + !CFPf->getValueAPF().isZero())) + return ReplaceInstUsesWith(SI, FalseVal); + } + // Transform (X != Y) ? Y : X -> Y + if (FCI->getPredicate() == FCmpInst::FCMP_ONE) + return ReplaceInstUsesWith(SI, TrueVal); + // NOTE: if we wanted to, this is where to detect MIN/MAX + } + // NOTE: if we wanted to, this is where to detect ABS + } + + // See if we are selecting two values based on a comparison of the two values. + if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal)) + if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) + return Result; + + if (Instruction *TI = dyn_cast<Instruction>(TrueVal)) + if (Instruction *FI = dyn_cast<Instruction>(FalseVal)) + if (TI->hasOneUse() && FI->hasOneUse()) { + Instruction *AddOp = 0, *SubOp = 0; + + // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) + if (TI->getOpcode() == FI->getOpcode()) + if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + return IV; + + // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is + // even legal for FP. 
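+        // For example:
+        //   %a = add i32 %x, %y
+        //   %b = sub i32 %x, %y
+        //   %s = select i1 %c, i32 %a, i32 %b
+        // becomes:
+        //   %n = sub i32 0, %y
+        //   %t = select i1 %c, i32 %y, i32 %n
+        //   %s = add i32 %x, %t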
+ if (TI->getOpcode() == Instruction::Sub && + FI->getOpcode() == Instruction::Add) { + AddOp = FI; SubOp = TI; + } else if (FI->getOpcode() == Instruction::Sub && + TI->getOpcode() == Instruction::Add) { + AddOp = TI; SubOp = FI; + } + + if (AddOp) { + Value *OtherAddOp = 0; + if (SubOp->getOperand(0) == AddOp->getOperand(0)) { + OtherAddOp = AddOp->getOperand(1); + } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { + OtherAddOp = AddOp->getOperand(0); + } + + if (OtherAddOp) { + // So at this point we know we have (Y -> OtherAddOp): + // select C, (add X, Y), (sub X, Z) + Value *NegVal; // Compute -Z + if (Constant *C = dyn_cast<Constant>(SubOp->getOperand(1))) { + NegVal = ConstantExpr::getNeg(C); + } else { + NegVal = InsertNewInstBefore( + BinaryOperator::CreateNeg(SubOp->getOperand(1), "tmp"), SI); + } + + Value *NewTrueOp = OtherAddOp; + Value *NewFalseOp = NegVal; + if (AddOp != TI) + std::swap(NewTrueOp, NewFalseOp); + Instruction *NewSel = + SelectInst::Create(CondVal, NewTrueOp, + NewFalseOp, SI.getName() + ".p"); + + NewSel = InsertNewInstBefore(NewSel, SI); + return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); + } + } + } + + // See if we can fold the select into one of our operands. + if (SI.getType()->isInteger()) { + Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal); + if (FoldI) + return FoldI; + } + + if (BinaryOperator::isNot(CondVal)) { + SI.setOperand(0, BinaryOperator::getNotArgument(CondVal)); + SI.setOperand(1, FalseVal); + SI.setOperand(2, TrueVal); + return &SI; + } + + return 0; +} + +/// EnforceKnownAlignment - If the specified pointer points to an object that +/// we control, modify the object's alignment to PrefAlign. This isn't +/// often possible though. If alignment is important, a more reliable approach +/// is to simply align all global variables and allocation instructions to +/// their preferred alignment from the beginning. +/// +static unsigned EnforceKnownAlignment(Value *V, + unsigned Align, unsigned PrefAlign) { + + User *U = dyn_cast<User>(V); + if (!U) return Align; + + switch (getOpcode(U)) { + default: break; + case Instruction::BitCast: + return EnforceKnownAlignment(U->getOperand(0), Align, PrefAlign); + case Instruction::GetElementPtr: { + // If all indexes are zero, it is just the alignment of the base pointer. + bool AllZeroOperands = true; + for (User::op_iterator i = U->op_begin() + 1, e = U->op_end(); i != e; ++i) + if (!isa<Constant>(*i) || + !cast<Constant>(*i)->isNullValue()) { + AllZeroOperands = false; + break; + } + + if (AllZeroOperands) { + // Treat this like a bitcast. + return EnforceKnownAlignment(U->getOperand(0), Align, PrefAlign); + } + break; + } + } + + if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + // If there is a large requested alignment and we can, bump up the alignment + // of the global. + if (!GV->isDeclaration()) { + if (GV->getAlignment() >= PrefAlign) + Align = GV->getAlignment(); + else { + GV->setAlignment(PrefAlign); + Align = PrefAlign; + } + } + } else if (AllocationInst *AI = dyn_cast<AllocationInst>(V)) { + // If there is a requested alignment and if this is an alloca, round up. We + // don't do this for malloc, because some systems can't respect the request. 
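+  // For example, an alloca with align 4 can simply be given align 16 here,
+  // after which the caller may rely on 16-byte alignment.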
+    if (isa<AllocaInst>(AI)) {
+      if (AI->getAlignment() >= PrefAlign)
+        Align = AI->getAlignment();
+      else {
+        AI->setAlignment(PrefAlign);
+        Align = PrefAlign;
+      }
+    }
+  }
+
+  return Align;
+}
+
+/// GetOrEnforceKnownAlignment - If the specified pointer has an alignment that
+/// we can determine, return it; otherwise return the conservative minimum, 1.
+/// If PrefAlign is specified, and it is more than the alignment of the
+/// ultimate object, see if we can increase the alignment of the ultimate
+/// object, making this check succeed.
+unsigned InstCombiner::GetOrEnforceKnownAlignment(Value *V,
+                                                  unsigned PrefAlign) {
+  unsigned BitWidth = TD ? TD->getTypeSizeInBits(V->getType()) :
+                      sizeof(PrefAlign) * CHAR_BIT;
+  APInt Mask = APInt::getAllOnesValue(BitWidth);
+  APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+  ComputeMaskedBits(V, Mask, KnownZero, KnownOne);
+  unsigned TrailZ = KnownZero.countTrailingOnes();
+  unsigned Align = 1u << std::min(BitWidth - 1, TrailZ);
+
+  if (PrefAlign > Align)
+    Align = EnforceKnownAlignment(V, Align, PrefAlign);
+
+  // Return the alignment we could prove, possibly raised above.
+  return Align;
+}
+
+Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
+  unsigned DstAlign = GetOrEnforceKnownAlignment(MI->getOperand(1));
+  unsigned SrcAlign = GetOrEnforceKnownAlignment(MI->getOperand(2));
+  unsigned MinAlign = std::min(DstAlign, SrcAlign);
+  unsigned CopyAlign = MI->getAlignment();
+
+  if (CopyAlign < MinAlign) {
+    MI->setAlignment(MinAlign);
+    return MI;
+  }
+
+  // If MemCpyInst length is 1/2/4/8 bytes then replace memcpy with
+  // load/store.
+  ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getOperand(3));
+  if (MemOpLength == 0) return 0;
+
+  // Source and destination pointer types are always "i8*" for intrinsic.  See
+  // if the size is something we can handle with a single primitive load/store.
+  // A single load+store correctly handles overlapping memory in the memmove
+  // case.
+  unsigned Size = MemOpLength->getZExtValue();
+  if (Size == 0) return MI;  // Delete this mem transfer.
+
+  if (Size > 8 || (Size&(Size-1)))
+    return 0;  // If not 1/2/4/8 bytes, exit.
+
+  // Use an integer load+store unless we can find something better.
+  Type *NewPtrTy = PointerType::getUnqual(IntegerType::get(Size<<3));
+
+  // Memcpy forces the use of i8* for the source and destination.  That means
+  // that if you're using memcpy to move one double around, you'll get a cast
+  // from double* to i8*.  We'd much rather use a double load+store than an
+  // i64 load+store here, because this improves the odds that the source or
+  // dest address will be promotable.  See if we can find a better type than
+  // the integer datatype.
+  if (Value *Op = getBitCastOperand(MI->getOperand(1))) {
+    const Type *SrcETy = cast<PointerType>(Op->getType())->getElementType();
+    if (SrcETy->isSized() && TD->getTypeStoreSize(SrcETy) == Size) {
+      // The SrcETy might be something like {{{double}}} or [1 x double].  Rip
+      // down through these levels if so.
+      while (!SrcETy->isSingleValueType()) {
+        if (const StructType *STy = dyn_cast<StructType>(SrcETy)) {
+          if (STy->getNumElements() == 1)
+            SrcETy = STy->getElementType(0);
+          else
+            break;
+        } else if (const ArrayType *ATy = dyn_cast<ArrayType>(SrcETy)) {
+          if (ATy->getNumElements() == 1)
+            SrcETy = ATy->getElementType();
+          else
+            break;
+        } else
+          break;
+      }
+
+      if (SrcETy->isSingleValueType())
+        NewPtrTy = PointerType::getUnqual(SrcETy);
+    }
+  }
+
+
+  // If the memcpy/memmove provides better alignment info than we can
+  // infer, use it.
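+  // For example (illustrative; %src.cast and %dst.cast name the bitcasts
+  // created below), an 8-byte memcpy whose pointers really refer to doubles
+  // becomes:
+  //   %v = load double* %src.cast
+  //   store double %v, double* %dst.cast
+  // and the length of the original call is then zeroed so it gets deleted.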
+ SrcAlign = std::max(SrcAlign, CopyAlign); + DstAlign = std::max(DstAlign, CopyAlign); + + Value *Src = InsertBitCastBefore(MI->getOperand(2), NewPtrTy, *MI); + Value *Dest = InsertBitCastBefore(MI->getOperand(1), NewPtrTy, *MI); + Instruction *L = new LoadInst(Src, "tmp", false, SrcAlign); + InsertNewInstBefore(L, *MI); + InsertNewInstBefore(new StoreInst(L, Dest, false, DstAlign), *MI); + + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setOperand(3, Constant::getNullValue(MemOpLength->getType())); + return MI; +} + +Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { + unsigned Alignment = GetOrEnforceKnownAlignment(MI->getDest()); + if (MI->getAlignment() < Alignment) { + MI->setAlignment(Alignment); + return MI; + } + + // Extract the length and alignment and fill if they are constant. + ConstantInt *LenC = dyn_cast<ConstantInt>(MI->getLength()); + ConstantInt *FillC = dyn_cast<ConstantInt>(MI->getValue()); + if (!LenC || !FillC || FillC->getType() != Type::Int8Ty) + return 0; + uint64_t Len = LenC->getZExtValue(); + Alignment = MI->getAlignment(); + + // If the length is zero, this is a no-op + if (Len == 0) return MI; // memset(d,c,0,a) -> noop + + // memset(s,c,n) -> store s, c (for n=1,2,4,8) + if (Len <= 8 && isPowerOf2_32((uint32_t)Len)) { + const Type *ITy = IntegerType::get(Len*8); // n=1 -> i8. + + Value *Dest = MI->getDest(); + Dest = InsertBitCastBefore(Dest, PointerType::getUnqual(ITy), *MI); + + // Alignment 0 is identity for alignment 1 for memset, but not store. + if (Alignment == 0) Alignment = 1; + + // Extract the fill value and store. + uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; + InsertNewInstBefore(new StoreInst(ConstantInt::get(ITy, Fill), Dest, false, + Alignment), *MI); + + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(LenC->getType())); + return MI; + } + + return 0; +} + + +/// visitCallInst - CallInst simplification. This mostly only handles folding +/// of intrinsic instructions. For normal calls, it allows visitCallSite to do +/// the heavy lifting. +/// +Instruction *InstCombiner::visitCallInst(CallInst &CI) { + // If the caller function is nounwind, mark the call as nounwind, even if the + // callee isn't. + if (CI.getParent()->getParent()->doesNotThrow() && + !CI.doesNotThrow()) { + CI.setDoesNotThrow(); + return &CI; + } + + + + IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI); + if (!II) return visitCallSite(&CI); + + // Intrinsics cannot occur in an invoke, so handle them here instead of in + // visitCallSite. + if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) { + bool Changed = false; + + // memmove/cpy/set of zero bytes is a noop. + if (Constant *NumBytes = dyn_cast<Constant>(MI->getLength())) { + if (NumBytes->isNullValue()) return EraseInstFromFunction(CI); + + if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) + if (CI->getZExtValue() == 1) { + // Replace the instruction with just byte operations. We would + // transform other cases to loads/stores, but we don't know if + // alignment is sufficient. + } + } + + // If we have a memmove and the source operation is a constant global, + // then the source and dest pointers can't alias, so we can change this + // into a call to memcpy. 
+    if (MemMoveInst *MMI = dyn_cast<MemMoveInst>(MI)) {
+      if (GlobalVariable *GVSrc = dyn_cast<GlobalVariable>(MMI->getSource()))
+        if (GVSrc->isConstant()) {
+          Module *M = CI.getParent()->getParent()->getParent();
+          Intrinsic::ID MemCpyID = Intrinsic::memcpy;
+          const Type *Tys[1];
+          Tys[0] = CI.getOperand(3)->getType();
+          CI.setOperand(0,
+                        Intrinsic::getDeclaration(M, MemCpyID, Tys, 1));
+          Changed = true;
+        }
+
+      // memmove(x,x,size) -> noop.
+      if (MMI->getSource() == MMI->getDest())
+        return EraseInstFromFunction(CI);
+    }
+
+    // If we can determine a pointer alignment that is bigger than currently
+    // set, update the alignment.
+    if (isa<MemTransferInst>(MI)) {
+      if (Instruction *I = SimplifyMemTransfer(MI))
+        return I;
+    } else if (MemSetInst *MSI = dyn_cast<MemSetInst>(MI)) {
+      if (Instruction *I = SimplifyMemSet(MSI))
+        return I;
+    }
+
+    if (Changed) return II;
+  }
+
+  switch (II->getIntrinsicID()) {
+  default: break;
+  case Intrinsic::bswap:
+    // bswap(bswap(x)) -> x
+    if (IntrinsicInst *Operand = dyn_cast<IntrinsicInst>(II->getOperand(1)))
+      if (Operand->getIntrinsicID() == Intrinsic::bswap)
+        return ReplaceInstUsesWith(CI, Operand->getOperand(1));
+    break;
+  case Intrinsic::ppc_altivec_lvx:
+  case Intrinsic::ppc_altivec_lvxl:
+  case Intrinsic::x86_sse_loadu_ps:
+  case Intrinsic::x86_sse2_loadu_pd:
+  case Intrinsic::x86_sse2_loadu_dq:
+    // Turn PPC lvx     -> load if the pointer is known aligned.
+    // Turn X86 loadups -> load if the pointer is known aligned.
+    if (GetOrEnforceKnownAlignment(II->getOperand(1), 16) >= 16) {
+      Value *Ptr = InsertBitCastBefore(II->getOperand(1),
+                                       PointerType::getUnqual(II->getType()),
+                                       CI);
+      return new LoadInst(Ptr);
+    }
+    break;
+  case Intrinsic::ppc_altivec_stvx:
+  case Intrinsic::ppc_altivec_stvxl:
+    // Turn stvx -> store if the pointer is known aligned.
+    if (GetOrEnforceKnownAlignment(II->getOperand(2), 16) >= 16) {
+      const Type *OpPtrTy =
+        PointerType::getUnqual(II->getOperand(1)->getType());
+      Value *Ptr = InsertBitCastBefore(II->getOperand(2), OpPtrTy, CI);
+      return new StoreInst(II->getOperand(1), Ptr);
+    }
+    break;
+  case Intrinsic::x86_sse_storeu_ps:
+  case Intrinsic::x86_sse2_storeu_pd:
+  case Intrinsic::x86_sse2_storeu_dq:
+    // Turn X86 storeu -> store if the pointer is known aligned.
+    if (GetOrEnforceKnownAlignment(II->getOperand(1), 16) >= 16) {
+      const Type *OpPtrTy =
+        PointerType::getUnqual(II->getOperand(2)->getType());
+      Value *Ptr = InsertBitCastBefore(II->getOperand(1), OpPtrTy, CI);
+      return new StoreInst(II->getOperand(2), Ptr);
+    }
+    break;
+
+  case Intrinsic::x86_sse_cvttss2si: {
+    // This intrinsic only demands the 0th element of its input vector.  If
+    // we can simplify the input based on that, do so now.
+    unsigned VWidth =
+      cast<VectorType>(II->getOperand(1)->getType())->getNumElements();
+    APInt DemandedElts(VWidth, 1);
+    APInt UndefElts(VWidth, 0);
+    if (Value *V = SimplifyDemandedVectorElts(II->getOperand(1), DemandedElts,
+                                              UndefElts)) {
+      II->setOperand(1, V);
+      return II;
+    }
+    break;
+  }
+
+  case Intrinsic::ppc_altivec_vperm:
+    // Turn vperm(V1,V2,mask) -> shuffle(V1,V2,mask) if mask is a constant.
+    if (ConstantVector *Mask = dyn_cast<ConstantVector>(II->getOperand(3))) {
+      assert(Mask->getNumOperands() == 16 && "Bad type for intrinsic!");
+
+      // Check that all of the elements are integer constants or undefs.
+ bool AllEltsOk = true; + for (unsigned i = 0; i != 16; ++i) { + if (!isa<ConstantInt>(Mask->getOperand(i)) && + !isa<UndefValue>(Mask->getOperand(i))) { + AllEltsOk = false; + break; + } + } + + if (AllEltsOk) { + // Cast the input vectors to byte vectors. + Value *Op0 =InsertBitCastBefore(II->getOperand(1),Mask->getType(),CI); + Value *Op1 =InsertBitCastBefore(II->getOperand(2),Mask->getType(),CI); + Value *Result = UndefValue::get(Op0->getType()); + + // Only extract each element once. + Value *ExtractedElts[32]; + memset(ExtractedElts, 0, sizeof(ExtractedElts)); + + for (unsigned i = 0; i != 16; ++i) { + if (isa<UndefValue>(Mask->getOperand(i))) + continue; + unsigned Idx=cast<ConstantInt>(Mask->getOperand(i))->getZExtValue(); + Idx &= 31; // Match the hardware behavior. + + if (ExtractedElts[Idx] == 0) { + Instruction *Elt = + new ExtractElementInst(Idx < 16 ? Op0 : Op1, Idx&15, "tmp"); + InsertNewInstBefore(Elt, CI); + ExtractedElts[Idx] = Elt; + } + + // Insert this value into the result vector. + Result = InsertElementInst::Create(Result, ExtractedElts[Idx], + i, "tmp"); + InsertNewInstBefore(cast<Instruction>(Result), CI); + } + return CastInst::Create(Instruction::BitCast, Result, CI.getType()); + } + } + break; + + case Intrinsic::stackrestore: { + // If the save is right next to the restore, remove the restore. This can + // happen when variable allocas are DCE'd. + if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getOperand(1))) { + if (SS->getIntrinsicID() == Intrinsic::stacksave) { + BasicBlock::iterator BI = SS; + if (&*++BI == II) + return EraseInstFromFunction(CI); + } + } + + // Scan down this block to see if there is another stack restore in the + // same block without an intervening call/alloca. + BasicBlock::iterator BI = II; + TerminatorInst *TI = II->getParent()->getTerminator(); + bool CannotRemove = false; + for (++BI; &*BI != TI; ++BI) { + if (isa<AllocaInst>(BI)) { + CannotRemove = true; + break; + } + if (CallInst *BCI = dyn_cast<CallInst>(BI)) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BCI)) { + // If there is a stackrestore below this one, remove this one. + if (II->getIntrinsicID() == Intrinsic::stackrestore) + return EraseInstFromFunction(CI); + // Otherwise, ignore the intrinsic. + } else { + // If we found a non-intrinsic call, we can't remove the stack + // restore. + CannotRemove = true; + break; + } + } + } + + // If the stack restore is in a return/unwind block and if there are no + // allocas or calls between the restore and the return, nuke the restore. + if (!CannotRemove && (isa<ReturnInst>(TI) || isa<UnwindInst>(TI))) + return EraseInstFromFunction(CI); + break; + } + } + + return visitCallSite(II); +} + +// InvokeInst simplification +// +Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { + return visitCallSite(&II); +} + +/// isSafeToEliminateVarargsCast - If this cast does not affect the value +/// passed through the varargs area, we can eliminate the use of the cast. +static bool isSafeToEliminateVarargsCast(const CallSite CS, + const CastInst * const CI, + const TargetData * const TD, + const int ix) { + if (!CI->isLosslessCast()) + return false; + + // The size of ByVal arguments is derived from the type, so we + // can't change to a type with a different size. If the size were + // passed explicitly we could avoid this check. 
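+  // For example, a lossless bitcast of an i32* argument to i8* can be
+  // stripped, but a ByVal pointer may only keep its cast if the pointee
+  // types have the same allocation size, since ByVal copies size-of-type
+  // bytes.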
+ if (!CS.paramHasAttr(ix, Attribute::ByVal)) + return true; + + const Type* SrcTy = + cast<PointerType>(CI->getOperand(0)->getType())->getElementType(); + const Type* DstTy = cast<PointerType>(CI->getType())->getElementType(); + if (!SrcTy->isSized() || !DstTy->isSized()) + return false; + if (TD->getTypeAllocSize(SrcTy) != TD->getTypeAllocSize(DstTy)) + return false; + return true; +} + +// visitCallSite - Improvements for call and invoke instructions. +// +Instruction *InstCombiner::visitCallSite(CallSite CS) { + bool Changed = false; + + // If the callee is a constexpr cast of a function, attempt to move the cast + // to the arguments of the call/invoke. + if (transformConstExprCastCall(CS)) return 0; + + Value *Callee = CS.getCalledValue(); + + if (Function *CalleeF = dyn_cast<Function>(Callee)) + if (CalleeF->getCallingConv() != CS.getCallingConv()) { + Instruction *OldCall = CS.getInstruction(); + // If the call and callee calling conventions don't match, this call must + // be unreachable, as the call is undefined. + new StoreInst(ConstantInt::getTrue(), + UndefValue::get(PointerType::getUnqual(Type::Int1Ty)), + OldCall); + if (!OldCall->use_empty()) + OldCall->replaceAllUsesWith(UndefValue::get(OldCall->getType())); + if (isa<CallInst>(OldCall)) // Not worth removing an invoke here. + return EraseInstFromFunction(*OldCall); + return 0; + } + + if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { + // This instruction is not reachable, just remove it. We insert a store to + // undef so that we know that this code is not reachable, despite the fact + // that we can't modify the CFG here. + new StoreInst(ConstantInt::getTrue(), + UndefValue::get(PointerType::getUnqual(Type::Int1Ty)), + CS.getInstruction()); + + if (!CS.getInstruction()->use_empty()) + CS.getInstruction()-> + replaceAllUsesWith(UndefValue::get(CS.getInstruction()->getType())); + + if (InvokeInst *II = dyn_cast<InvokeInst>(CS.getInstruction())) { + // Don't break the CFG, insert a dummy cond branch. + BranchInst::Create(II->getNormalDest(), II->getUnwindDest(), + ConstantInt::getTrue(), II); + } + return EraseInstFromFunction(*CS.getInstruction()); + } + + if (BitCastInst *BC = dyn_cast<BitCastInst>(Callee)) + if (IntrinsicInst *In = dyn_cast<IntrinsicInst>(BC->getOperand(0))) + if (In->getIntrinsicID() == Intrinsic::init_trampoline) + return transformCallThroughTrampoline(CS); + + const PointerType *PTy = cast<PointerType>(Callee->getType()); + const FunctionType *FTy = cast<FunctionType>(PTy->getElementType()); + if (FTy->isVarArg()) { + int ix = FTy->getNumParams() + (isa<InvokeInst>(Callee) ? 3 : 1); + // See if we can optimize any arguments passed through the varargs area of + // the call. + for (CallSite::arg_iterator I = CS.arg_begin()+FTy->getNumParams(), + E = CS.arg_end(); I != E; ++I, ++ix) { + CastInst *CI = dyn_cast<CastInst>(*I); + if (CI && isSafeToEliminateVarargsCast(CS, CI, TD, ix)) { + *I = CI->getOperand(0); + Changed = true; + } + } + } + + if (isa<InlineAsm>(Callee) && !CS.doesNotThrow()) { + // Inline asm calls cannot throw - mark them 'nounwind'. + CS.setDoesNotThrow(); + Changed = true; + } + + return Changed ? CS.getInstruction() : 0; +} + +// transformConstExprCastCall - If the callee is a constexpr cast of a function, +// attempt to move the cast to the arguments of the call/invoke. 
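+// For example (illustrative):
+//   %r = call i32 bitcast (i32 (i8*)* @f to i32 (i32*)*)(i32* %p)
+// becomes:
+//   %q = bitcast i32* %p to i8*
+//   %r = call i32 @f(i8* %q)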
+// +bool InstCombiner::transformConstExprCastCall(CallSite CS) { + if (!isa<ConstantExpr>(CS.getCalledValue())) return false; + ConstantExpr *CE = cast<ConstantExpr>(CS.getCalledValue()); + if (CE->getOpcode() != Instruction::BitCast || + !isa<Function>(CE->getOperand(0))) + return false; + Function *Callee = cast<Function>(CE->getOperand(0)); + Instruction *Caller = CS.getInstruction(); + const AttrListPtr &CallerPAL = CS.getAttributes(); + + // Okay, this is a cast from a function to a different type. Unless doing so + // would cause a type conversion of one of our arguments, change this call to + // be a direct call with arguments casted to the appropriate types. + // + const FunctionType *FT = Callee->getFunctionType(); + const Type *OldRetTy = Caller->getType(); + const Type *NewRetTy = FT->getReturnType(); + + if (isa<StructType>(NewRetTy)) + return false; // TODO: Handle multiple return values. + + // Check to see if we are changing the return type... + if (OldRetTy != NewRetTy) { + if (Callee->isDeclaration() && + // Conversion is ok if changing from one pointer type to another or from + // a pointer to an integer of the same size. + !((isa<PointerType>(OldRetTy) || OldRetTy == TD->getIntPtrType()) && + (isa<PointerType>(NewRetTy) || NewRetTy == TD->getIntPtrType()))) + return false; // Cannot transform this return value. + + if (!Caller->use_empty() && + // void -> non-void is handled specially + NewRetTy != Type::VoidTy && !CastInst::isCastable(NewRetTy, OldRetTy)) + return false; // Cannot transform this return value. + + if (!CallerPAL.isEmpty() && !Caller->use_empty()) { + Attributes RAttrs = CallerPAL.getRetAttributes(); + if (RAttrs & Attribute::typeIncompatible(NewRetTy)) + return false; // Attribute not compatible with transformed value. + } + + // If the callsite is an invoke instruction, and the return value is used by + // a PHI node in a successor, we cannot change the return type of the call + // because there is no place to put the cast instruction (without breaking + // the critical edge). Bail out in this case. + if (!Caller->use_empty()) + if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) + for (Value::use_iterator UI = II->use_begin(), E = II->use_end(); + UI != E; ++UI) + if (PHINode *PN = dyn_cast<PHINode>(*UI)) + if (PN->getParent() == II->getNormalDest() || + PN->getParent() == II->getUnwindDest()) + return false; + } + + unsigned NumActualArgs = unsigned(CS.arg_end()-CS.arg_begin()); + unsigned NumCommonArgs = std::min(FT->getNumParams(), NumActualArgs); + + CallSite::arg_iterator AI = CS.arg_begin(); + for (unsigned i = 0, e = NumCommonArgs; i != e; ++i, ++AI) { + const Type *ParamTy = FT->getParamType(i); + const Type *ActTy = (*AI)->getType(); + + if (!CastInst::isCastable(ActTy, ParamTy)) + return false; // Cannot transform this parameter value. + + if (CallerPAL.getParamAttributes(i + 1) + & Attribute::typeIncompatible(ParamTy)) + return false; // Attribute not compatible with transformed value. + + // Converting from one pointer type to another or between a pointer and an + // integer of the same size is safe even if we do not have a body. + bool isConvertible = ActTy == ParamTy || + ((isa<PointerType>(ParamTy) || ParamTy == TD->getIntPtrType()) && + (isa<PointerType>(ActTy) || ActTy == TD->getIntPtrType())); + if (Callee->isDeclaration() && !isConvertible) return false; + } + + if (FT->getNumParams() < NumActualArgs && !FT->isVarArg() && + Callee->isDeclaration()) + return false; // Do not delete arguments unless we have a function body. 
+ + if (FT->getNumParams() < NumActualArgs && FT->isVarArg() && + !CallerPAL.isEmpty()) + // In this case we have more arguments than the new function type, but we + // won't be dropping them. Check that these extra arguments have attributes + // that are compatible with being a vararg call argument. + for (unsigned i = CallerPAL.getNumSlots(); i; --i) { + if (CallerPAL.getSlot(i - 1).Index <= FT->getNumParams()) + break; + Attributes PAttrs = CallerPAL.getSlot(i - 1).Attrs; + if (PAttrs & Attribute::VarArgsIncompatible) + return false; + } + + // Okay, we decided that this is a safe thing to do: go ahead and start + // inserting cast instructions as necessary... + std::vector<Value*> Args; + Args.reserve(NumActualArgs); + SmallVector<AttributeWithIndex, 8> attrVec; + attrVec.reserve(NumCommonArgs); + + // Get any return attributes. + Attributes RAttrs = CallerPAL.getRetAttributes(); + + // If the return value is not being used, the type may not be compatible + // with the existing attributes. Wipe out any problematic attributes. + RAttrs &= ~Attribute::typeIncompatible(NewRetTy); + + // Add the new return attributes. + if (RAttrs) + attrVec.push_back(AttributeWithIndex::get(0, RAttrs)); + + AI = CS.arg_begin(); + for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) { + const Type *ParamTy = FT->getParamType(i); + if ((*AI)->getType() == ParamTy) { + Args.push_back(*AI); + } else { + Instruction::CastOps opcode = CastInst::getCastOpcode(*AI, + false, ParamTy, false); + CastInst *NewCast = CastInst::Create(opcode, *AI, ParamTy, "tmp"); + Args.push_back(InsertNewInstBefore(NewCast, *Caller)); + } + + // Add any parameter attributes. + if (Attributes PAttrs = CallerPAL.getParamAttributes(i + 1)) + attrVec.push_back(AttributeWithIndex::get(i + 1, PAttrs)); + } + + // If the function takes more arguments than the call was taking, add them + // now... + for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) + Args.push_back(Constant::getNullValue(FT->getParamType(i))); + + // If we are removing arguments to the function, emit an obnoxious warning... + if (FT->getNumParams() < NumActualArgs) { + if (!FT->isVarArg()) { + cerr << "WARNING: While resolving call to function '" + << Callee->getName() << "' arguments were dropped!\n"; + } else { + // Add all of the arguments in their promoted form to the arg list... + for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) { + const Type *PTy = getPromotedType((*AI)->getType()); + if (PTy != (*AI)->getType()) { + // Must promote to pass through va_arg area! + Instruction::CastOps opcode = CastInst::getCastOpcode(*AI, false, + PTy, false); + Instruction *Cast = CastInst::Create(opcode, *AI, PTy, "tmp"); + InsertNewInstBefore(Cast, *Caller); + Args.push_back(Cast); + } else { + Args.push_back(*AI); + } + + // Add any parameter attributes. + if (Attributes PAttrs = CallerPAL.getParamAttributes(i + 1)) + attrVec.push_back(AttributeWithIndex::get(i + 1, PAttrs)); + } + } + } + + if (Attributes FnAttrs = CallerPAL.getFnAttributes()) + attrVec.push_back(AttributeWithIndex::get(~0, FnAttrs)); + + if (NewRetTy == Type::VoidTy) + Caller->setName(""); // Void type should not have a name. 
+
+  const AttrListPtr &NewCallerPAL = AttrListPtr::get(attrVec.begin(),
+                                                     attrVec.end());
+
+  Instruction *NC;
+  if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
+    NC = InvokeInst::Create(Callee, II->getNormalDest(), II->getUnwindDest(),
+                            Args.begin(), Args.end(),
+                            Caller->getName(), Caller);
+    cast<InvokeInst>(NC)->setCallingConv(II->getCallingConv());
+    cast<InvokeInst>(NC)->setAttributes(NewCallerPAL);
+  } else {
+    NC = CallInst::Create(Callee, Args.begin(), Args.end(),
+                          Caller->getName(), Caller);
+    CallInst *CI = cast<CallInst>(Caller);
+    if (CI->isTailCall())
+      cast<CallInst>(NC)->setTailCall();
+    cast<CallInst>(NC)->setCallingConv(CI->getCallingConv());
+    cast<CallInst>(NC)->setAttributes(NewCallerPAL);
+  }
+
+  // Insert a cast of the return type as necessary.
+  Value *NV = NC;
+  if (OldRetTy != NV->getType() && !Caller->use_empty()) {
+    if (NV->getType() != Type::VoidTy) {
+      Instruction::CastOps opcode = CastInst::getCastOpcode(NC, false,
+                                                            OldRetTy, false);
+      NV = NC = CastInst::Create(opcode, NC, OldRetTy, "tmp");
+
+      // If this is an invoke instruction, we should insert it after the first
+      // non-phi instruction in the normal successor block.
+      if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
+        BasicBlock::iterator I = II->getNormalDest()->getFirstNonPHI();
+        InsertNewInstBefore(NC, *I);
+      } else {
+        // Otherwise, it's a call; just insert the cast right after the call
+        // instruction.
+        InsertNewInstBefore(NC, *Caller);
+      }
+      AddUsersToWorkList(*Caller);
+    } else {
+      NV = UndefValue::get(Caller->getType());
+    }
+  }
+
+  if (Caller->getType() != Type::VoidTy && !Caller->use_empty())
+    Caller->replaceAllUsesWith(NV);
+  Caller->eraseFromParent();
+  RemoveFromWorkList(Caller);
+  return true;
+}
+
+// transformCallThroughTrampoline - Turn a call to a function created by the
+// init_trampoline intrinsic into a direct call to the underlying function.
+//
+Instruction *InstCombiner::transformCallThroughTrampoline(CallSite CS) {
+  Value *Callee = CS.getCalledValue();
+  const PointerType *PTy = cast<PointerType>(Callee->getType());
+  const FunctionType *FTy = cast<FunctionType>(PTy->getElementType());
+  const AttrListPtr &Attrs = CS.getAttributes();
+
+  // If the call already has the 'nest' attribute somewhere then give up -
+  // otherwise 'nest' would occur twice after splicing in the chain.
+  if (Attrs.hasAttrSomewhere(Attribute::Nest))
+    return 0;
+
+  IntrinsicInst *Tramp =
+    cast<IntrinsicInst>(cast<BitCastInst>(Callee)->getOperand(0));
+
+  Function *NestF = cast<Function>(Tramp->getOperand(2)->stripPointerCasts());
+  const PointerType *NestFPTy = cast<PointerType>(NestF->getType());
+  const FunctionType *NestFTy = cast<FunctionType>(NestFPTy->getElementType());
+
+  const AttrListPtr &NestAttrs = NestF->getAttributes();
+  if (!NestAttrs.isEmpty()) {
+    unsigned NestIdx = 1;
+    const Type *NestTy = 0;
+    Attributes NestAttr = Attribute::None;
+
+    // Look for a parameter marked with the 'nest' attribute.
+    for (FunctionType::param_iterator I = NestFTy->param_begin(),
+         E = NestFTy->param_end(); I != E; ++NestIdx, ++I)
+      if (NestAttrs.paramHasAttr(NestIdx, Attribute::Nest)) {
+        // Record the parameter type and any other attributes.
+ NestTy = *I; + NestAttr = NestAttrs.getParamAttributes(NestIdx); + break; + } + + if (NestTy) { + Instruction *Caller = CS.getInstruction(); + std::vector<Value*> NewArgs; + NewArgs.reserve(unsigned(CS.arg_end()-CS.arg_begin())+1); + + SmallVector<AttributeWithIndex, 8> NewAttrs; + NewAttrs.reserve(Attrs.getNumSlots() + 1); + + // Insert the nest argument into the call argument list, which may + // mean appending it. Likewise for attributes. + + // Add any result attributes. + if (Attributes Attr = Attrs.getRetAttributes()) + NewAttrs.push_back(AttributeWithIndex::get(0, Attr)); + + { + unsigned Idx = 1; + CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); + do { + if (Idx == NestIdx) { + // Add the chain argument and attributes. + Value *NestVal = Tramp->getOperand(3); + if (NestVal->getType() != NestTy) + NestVal = new BitCastInst(NestVal, NestTy, "nest", Caller); + NewArgs.push_back(NestVal); + NewAttrs.push_back(AttributeWithIndex::get(NestIdx, NestAttr)); + } + + if (I == E) + break; + + // Add the original argument and attributes. + NewArgs.push_back(*I); + if (Attributes Attr = Attrs.getParamAttributes(Idx)) + NewAttrs.push_back + (AttributeWithIndex::get(Idx + (Idx >= NestIdx), Attr)); + + ++Idx, ++I; + } while (1); + } + + // Add any function attributes. + if (Attributes Attr = Attrs.getFnAttributes()) + NewAttrs.push_back(AttributeWithIndex::get(~0, Attr)); + + // The trampoline may have been bitcast to a bogus type (FTy). + // Handle this by synthesizing a new function type, equal to FTy + // with the chain parameter inserted. + + std::vector<const Type*> NewTypes; + NewTypes.reserve(FTy->getNumParams()+1); + + // Insert the chain's type into the list of parameter types, which may + // mean appending it. + { + unsigned Idx = 1; + FunctionType::param_iterator I = FTy->param_begin(), + E = FTy->param_end(); + + do { + if (Idx == NestIdx) + // Add the chain's type. + NewTypes.push_back(NestTy); + + if (I == E) + break; + + // Add the original type. + NewTypes.push_back(*I); + + ++Idx, ++I; + } while (1); + } + + // Replace the trampoline call with a direct call. Let the generic + // code sort out any function type mismatches. + FunctionType *NewFTy = + FunctionType::get(FTy->getReturnType(), NewTypes, FTy->isVarArg()); + Constant *NewCallee = NestF->getType() == PointerType::getUnqual(NewFTy) ? + NestF : ConstantExpr::getBitCast(NestF, PointerType::getUnqual(NewFTy)); + const AttrListPtr &NewPAL = AttrListPtr::get(NewAttrs.begin(),NewAttrs.end()); + + Instruction *NewCaller; + if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { + NewCaller = InvokeInst::Create(NewCallee, + II->getNormalDest(), II->getUnwindDest(), + NewArgs.begin(), NewArgs.end(), + Caller->getName(), Caller); + cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv()); + cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); + } else { + NewCaller = CallInst::Create(NewCallee, NewArgs.begin(), NewArgs.end(), + Caller->getName(), Caller); + if (cast<CallInst>(Caller)->isTailCall()) + cast<CallInst>(NewCaller)->setTailCall(); + cast<CallInst>(NewCaller)-> + setCallingConv(cast<CallInst>(Caller)->getCallingConv()); + cast<CallInst>(NewCaller)->setAttributes(NewPAL); + } + if (Caller->getType() != Type::VoidTy && !Caller->use_empty()) + Caller->replaceAllUsesWith(NewCaller); + Caller->eraseFromParent(); + RemoveFromWorkList(Caller); + return 0; + } + } + + // Replace the trampoline call with a direct call. Since there is no 'nest' + // parameter, there is no need to adjust the argument list. 
+  // Let the generic
+  // code sort out any function type mismatches.
+  Constant *NewCallee =
+    NestF->getType() == PTy ? NestF : ConstantExpr::getBitCast(NestF, PTy);
+  CS.setCalledFunction(NewCallee);
+  return CS.getInstruction();
+}
+
+/// FoldPHIArgBinOpIntoPHI - If we have something like phi [add (a,b), add(c,d)]
+/// and if a/b/c/d and the add's all have a single use, turn this into two phi's
+/// and a single binop.
+Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) {
+  Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0));
+  assert(isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst));
+  unsigned Opc = FirstInst->getOpcode();
+  Value *LHSVal = FirstInst->getOperand(0);
+  Value *RHSVal = FirstInst->getOperand(1);
+
+  const Type *LHSType = LHSVal->getType();
+  const Type *RHSType = RHSVal->getType();
+
+  // Scan to see if all operands are the same opcode, all have one use, and all
+  // kill their operands (i.e. the operands have one use).
+  for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) {
+    Instruction *I = dyn_cast<Instruction>(PN.getIncomingValue(i));
+    if (!I || I->getOpcode() != Opc || !I->hasOneUse() ||
+        // Verify type of the LHS matches so we don't fold cmp's of different
+        // types or GEP's with different index types.
+        I->getOperand(0)->getType() != LHSType ||
+        I->getOperand(1)->getType() != RHSType)
+      return 0;
+
+    // If they are CmpInst instructions, check their predicates
+    if (Opc == Instruction::ICmp || Opc == Instruction::FCmp)
+      if (cast<CmpInst>(I)->getPredicate() !=
+          cast<CmpInst>(FirstInst)->getPredicate())
+        return 0;
+
+    // Keep track of which operand needs a phi node.
+    if (I->getOperand(0) != LHSVal) LHSVal = 0;
+    if (I->getOperand(1) != RHSVal) RHSVal = 0;
+  }
+
+  // Otherwise, this is safe to transform!
+
+  Value *InLHS = FirstInst->getOperand(0);
+  Value *InRHS = FirstInst->getOperand(1);
+  PHINode *NewLHS = 0, *NewRHS = 0;
+  if (LHSVal == 0) {
+    NewLHS = PHINode::Create(LHSType,
+                             FirstInst->getOperand(0)->getName() + ".pn");
+    NewLHS->reserveOperandSpace(PN.getNumOperands()/2);
+    NewLHS->addIncoming(InLHS, PN.getIncomingBlock(0));
+    InsertNewInstBefore(NewLHS, PN);
+    LHSVal = NewLHS;
+  }
+
+  if (RHSVal == 0) {
+    NewRHS = PHINode::Create(RHSType,
+                             FirstInst->getOperand(1)->getName() + ".pn");
+    NewRHS->reserveOperandSpace(PN.getNumOperands()/2);
+    NewRHS->addIncoming(InRHS, PN.getIncomingBlock(0));
+    InsertNewInstBefore(NewRHS, PN);
+    RHSVal = NewRHS;
+  }
+
+  // Add all operands to the new PHIs.
+  if (NewLHS || NewRHS) {
+    for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) {
+      Instruction *InInst = cast<Instruction>(PN.getIncomingValue(i));
+      if (NewLHS) {
+        Value *NewInLHS = InInst->getOperand(0);
+        NewLHS->addIncoming(NewInLHS, PN.getIncomingBlock(i));
+      }
+      if (NewRHS) {
+        Value *NewInRHS = InInst->getOperand(1);
+        NewRHS->addIncoming(NewInRHS, PN.getIncomingBlock(i));
+      }
+    }
+  }
+
+  if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(FirstInst))
+    return BinaryOperator::Create(BinOp->getOpcode(), LHSVal, RHSVal);
+  CmpInst *CIOp = cast<CmpInst>(FirstInst);
+  return CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), LHSVal,
+                         RHSVal);
+}
+
+Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) {
+  GetElementPtrInst *FirstInst =cast<GetElementPtrInst>(PN.getIncomingValue(0));
+
+  SmallVector<Value*, 16> FixedOperands(FirstInst->op_begin(),
+                                        FirstInst->op_end());
+  // This is true if all GEP bases are allocas and if all indices into them are
+  // constants.
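+  // (Editor's illustration: for "phi [gep %a, 0, 1], [gep %a, 0, 2]" where %a
+  // is an alloca and all indices are constant, the flag below stays true and
+  // the fold is skipped; cloning the eventual load into the predecessors
+  // folds better there.)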
+
+  bool AllBasePointersAreAllocas = true;
+
+  // Scan to see if all operands are the same opcode, all have one use, and all
+  // kill their operands (i.e. the operands have one use).
+  for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) {
+    GetElementPtrInst *GEP= dyn_cast<GetElementPtrInst>(PN.getIncomingValue(i));
+    if (!GEP || !GEP->hasOneUse() || GEP->getType() != FirstInst->getType() ||
+        GEP->getNumOperands() != FirstInst->getNumOperands())
+      return 0;
+
+    // Keep track of whether or not all GEPs are of alloca pointers.
+    if (AllBasePointersAreAllocas &&
+        (!isa<AllocaInst>(GEP->getOperand(0)) ||
+         !GEP->hasAllConstantIndices()))
+      AllBasePointersAreAllocas = false;
+
+    // Compare the operand lists.
+    for (unsigned op = 0, e = FirstInst->getNumOperands(); op != e; ++op) {
+      if (FirstInst->getOperand(op) == GEP->getOperand(op))
+        continue;
+
+      // Don't merge two GEPs when two operands differ (introducing phi nodes)
+      // if one of the PHIs has a constant for the index.  The index may be
+      // substantially cheaper to compute for the constants, so making it a
+      // variable index could pessimize the path.  This also handles the case
+      // for struct indices, which must always be constant.
+      if (isa<ConstantInt>(FirstInst->getOperand(op)) ||
+          isa<ConstantInt>(GEP->getOperand(op)))
+        return 0;
+
+      if (FirstInst->getOperand(op)->getType() !=GEP->getOperand(op)->getType())
+        return 0;
+      FixedOperands[op] = 0;  // Needs a PHI.
+    }
+  }
+
+  // If all of the base pointers of the PHI'd GEPs are from allocas, don't
+  // bother doing this transformation.  At best, this will just save a bit of
+  // offset calculation, but all the predecessors will have to materialize the
+  // stack address into a register anyway.  We'd actually rather *clone* the
+  // load up into the predecessors so that we have a load of a gep of an alloca,
+  // which can usually all be folded into the load.
+  if (AllBasePointersAreAllocas)
+    return 0;
+
+  // Otherwise, this is safe to transform.  Insert PHI nodes for each operand
+  // that is variable.
+  SmallVector<PHINode*, 16> OperandPhis(FixedOperands.size());
+
+  bool HasAnyPHIs = false;
+  for (unsigned i = 0, e = FixedOperands.size(); i != e; ++i) {
+    if (FixedOperands[i]) continue;  // operand doesn't need a phi.
+    Value *FirstOp = FirstInst->getOperand(i);
+    PHINode *NewPN = PHINode::Create(FirstOp->getType(),
+                                     FirstOp->getName()+".pn");
+    InsertNewInstBefore(NewPN, PN);
+
+    NewPN->reserveOperandSpace(e);
+    NewPN->addIncoming(FirstOp, PN.getIncomingBlock(0));
+    OperandPhis[i] = NewPN;
+    FixedOperands[i] = NewPN;
+    HasAnyPHIs = true;
+  }
+
+
+  // Add all operands to the new PHIs.
+  if (HasAnyPHIs) {
+    for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) {
+      GetElementPtrInst *InGEP =cast<GetElementPtrInst>(PN.getIncomingValue(i));
+      BasicBlock *InBB = PN.getIncomingBlock(i);
+
+      for (unsigned op = 0, e = OperandPhis.size(); op != e; ++op)
+        if (PHINode *OpPhi = OperandPhis[op])
+          OpPhi->addIncoming(InGEP->getOperand(op), InBB);
+    }
+  }
+
+  Value *Base = FixedOperands[0];
+  return GetElementPtrInst::Create(Base, FixedOperands.begin()+1,
+                                   FixedOperands.end());
+}
+
+
+/// isSafeAndProfitableToSinkLoad - Return true if we know that it is safe to
+/// sink the load out of the block that defines it.  This means that it must be
+/// obvious the value of the load is not changed from the point of the load to
+/// the end of the block it is in.
+///
+/// Finally, it is safe, but not profitable, to sink a load targeting a
+/// non-address-taken alloca.
+/// Doing so will cause us to not promote the alloca
+/// to a register.
+static bool isSafeAndProfitableToSinkLoad(LoadInst *L) {
+  BasicBlock::iterator BBI = L, E = L->getParent()->end();
+
+  for (++BBI; BBI != E; ++BBI)
+    if (BBI->mayWriteToMemory())
+      return false;
+
+  // Check for a non-address-taken alloca.  If the alloca is not already
+  // address-taken, it isn't profitable to do this xform.
+  if (AllocaInst *AI = dyn_cast<AllocaInst>(L->getOperand(0))) {
+    bool isAddressTaken = false;
+    for (Value::use_iterator UI = AI->use_begin(), E = AI->use_end();
+         UI != E; ++UI) {
+      if (isa<LoadInst>(UI)) continue;
+      if (StoreInst *SI = dyn_cast<StoreInst>(*UI)) {
+        // If storing TO the alloca, then the address isn't taken.
+        if (SI->getOperand(1) == AI) continue;
+      }
+      isAddressTaken = true;
+      break;
+    }
+
+    if (!isAddressTaken && AI->isStaticAlloca())
+      return false;
+  }
+
+  // If this load is a load from a GEP with a constant offset from an alloca,
+  // then we don't want to sink it.  In its present form, it will be
+  // load [constant stack offset].  Sinking it will cause us to have to
+  // materialize the stack addresses in each predecessor in a register only to
+  // do a shared load from register in the successor.
+  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(L->getOperand(0)))
+    if (AllocaInst *AI = dyn_cast<AllocaInst>(GEP->getOperand(0)))
+      if (AI->isStaticAlloca() && GEP->hasAllConstantIndices())
+        return false;
+
+  return true;
+}
+
+
+// FoldPHIArgOpIntoPHI - If all operands to a PHI node are the same "unary"
+// operator and they all are only used by the PHI, PHI together their
+// inputs, and do the operation once, to the result of the PHI.
+Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {
+  Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0));
+
+  // Scan the instruction, looking for input operations that can be folded away.
+  // If all input operands to the phi are the same instruction (e.g. a cast from
+  // the same type or "+42") we can pull the operation through the PHI, reducing
+  // code size and simplifying code.
+  Constant *ConstantOp = 0;
+  const Type *CastSrcTy = 0;
+  bool isVolatile = false;
+  if (isa<CastInst>(FirstInst)) {
+    CastSrcTy = FirstInst->getOperand(0)->getType();
+  } else if (isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)) {
+    // Can fold binop, compare or shift here if the RHS is a constant,
+    // otherwise call FoldPHIArgBinOpIntoPHI.
+    ConstantOp = dyn_cast<Constant>(FirstInst->getOperand(1));
+    if (ConstantOp == 0)
+      return FoldPHIArgBinOpIntoPHI(PN);
+  } else if (LoadInst *LI = dyn_cast<LoadInst>(FirstInst)) {
+    isVolatile = LI->isVolatile();
+    // We can't sink the load if the loaded value could be modified between the
+    // load and the PHI.
+    if (LI->getParent() != PN.getIncomingBlock(0) ||
+        !isSafeAndProfitableToSinkLoad(LI))
+      return 0;
+
+    // If the PHI is of volatile loads and the load block has multiple
+    // successors, sinking it would remove a load of the volatile value from
+    // the path through the other successor.
+    if (isVolatile &&
+        LI->getParent()->getTerminator()->getNumSuccessors() != 1)
+      return 0;
+
+  } else if (isa<GetElementPtrInst>(FirstInst)) {
+    return FoldPHIArgGEPIntoPHI(PN);
+  } else {
+    return 0;  // Cannot fold this operation.
+  }
+
+  // Check to see if all arguments are the same operation.
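+  // Editor's sketch of the cast case (hypothetical IR): given
+  //   %p = phi i32 [ %a, %bb1 ], [ %b, %bb2 ]
+  // with %a = trunc i64 %x and %b = trunc i64 %y, the loop below checks that
+  // every incoming value is a one-use trunc from the same source type; the
+  // phi can then be built over %x/%y and truncated once afterwards.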
+ for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { + if (!isa<Instruction>(PN.getIncomingValue(i))) return 0; + Instruction *I = cast<Instruction>(PN.getIncomingValue(i)); + if (!I->hasOneUse() || !I->isSameOperationAs(FirstInst)) + return 0; + if (CastSrcTy) { + if (I->getOperand(0)->getType() != CastSrcTy) + return 0; // Cast operation must match. + } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + // We can't sink the load if the loaded value could be modified between + // the load and the PHI. + if (LI->isVolatile() != isVolatile || + LI->getParent() != PN.getIncomingBlock(i) || + !isSafeAndProfitableToSinkLoad(LI)) + return 0; + + // If the PHI is of volatile loads and the load block has multiple + // successors, sinking it would remove a load of the volatile value from + // the path through the other successor. + if (isVolatile && + LI->getParent()->getTerminator()->getNumSuccessors() != 1) + return 0; + + } else if (I->getOperand(1) != ConstantOp) { + return 0; + } + } + + // Okay, they are all the same operation. Create a new PHI node of the + // correct type, and PHI together all of the LHS's of the instructions. + PHINode *NewPN = PHINode::Create(FirstInst->getOperand(0)->getType(), + PN.getName()+".in"); + NewPN->reserveOperandSpace(PN.getNumOperands()/2); + + Value *InVal = FirstInst->getOperand(0); + NewPN->addIncoming(InVal, PN.getIncomingBlock(0)); + + // Add all operands to the new PHI. + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { + Value *NewInVal = cast<Instruction>(PN.getIncomingValue(i))->getOperand(0); + if (NewInVal != InVal) + InVal = 0; + NewPN->addIncoming(NewInVal, PN.getIncomingBlock(i)); + } + + Value *PhiVal; + if (InVal) { + // The new PHI unions all of the same values together. This is really + // common, so we handle it intelligently here for compile-time speed. + PhiVal = InVal; + delete NewPN; + } else { + InsertNewInstBefore(NewPN, PN); + PhiVal = NewPN; + } + + // Insert and return the new operation. + if (CastInst* FirstCI = dyn_cast<CastInst>(FirstInst)) + return CastInst::Create(FirstCI->getOpcode(), PhiVal, PN.getType()); + if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(FirstInst)) + return BinaryOperator::Create(BinOp->getOpcode(), PhiVal, ConstantOp); + if (CmpInst *CIOp = dyn_cast<CmpInst>(FirstInst)) + return CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), + PhiVal, ConstantOp); + assert(isa<LoadInst>(FirstInst) && "Unknown operation"); + + // If this was a volatile load that we are merging, make sure to loop through + // and mark all the input loads as non-volatile. If we don't do this, we will + // insert a new volatile load and the old ones will not be deletable. + if (isVolatile) + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) + cast<LoadInst>(PN.getIncomingValue(i))->setVolatile(false); + + return new LoadInst(PhiVal, "", isVolatile); +} + +/// DeadPHICycle - Return true if this PHI node is only used by a PHI node cycle +/// that is dead. +static bool DeadPHICycle(PHINode *PN, + SmallPtrSet<PHINode*, 16> &PotentiallyDeadPHIs) { + if (PN->use_empty()) return true; + if (!PN->hasOneUse()) return false; + + // Remember this node, and if we find the cycle, return. + if (!PotentiallyDeadPHIs.insert(PN)) + return true; + + // Don't scan crazily complex things. 
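+  // (Editor's note, illustrative: a chain such as "%x = phi [ %y ]" and
+  // "%y = phi [ %x ]" with no other users is a dead cycle; the bound below
+  // merely caps how far the recursion will chase such chains.)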
+
+  if (PotentiallyDeadPHIs.size() == 16)
+    return false;
+
+  if (PHINode *PU = dyn_cast<PHINode>(PN->use_back()))
+    return DeadPHICycle(PU, PotentiallyDeadPHIs);
+
+  return false;
+}
+
+/// PHIsEqualValue - Return true if this phi node is always equal to
+/// NonPhiInVal.  This happens with mutually cyclic phi nodes like:
+///   z = some value; x = phi (y, z); y = phi (x, z)
+static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal,
+                           SmallPtrSet<PHINode*, 16> &ValueEqualPHIs) {
+  // See if we already saw this PHI node.
+  if (!ValueEqualPHIs.insert(PN))
+    return true;
+
+  // Don't scan crazily complex things.
+  if (ValueEqualPHIs.size() == 16)
+    return false;
+
+  // Scan the operands to see if they are either phi nodes or are equal to
+  // the value.
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    Value *Op = PN->getIncomingValue(i);
+    if (PHINode *OpPN = dyn_cast<PHINode>(Op)) {
+      if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs))
+        return false;
+    } else if (Op != NonPhiInVal)
+      return false;
+  }
+
+  return true;
+}
+
+
+// PHINode simplification
+//
+Instruction *InstCombiner::visitPHINode(PHINode &PN) {
+  // If LCSSA is around, don't mess with Phi nodes
+  if (MustPreserveLCSSA) return 0;
+
+  if (Value *V = PN.hasConstantValue())
+    return ReplaceInstUsesWith(PN, V);
+
+  // If all PHI operands are the same operation, pull them through the PHI,
+  // reducing code size.
+  if (isa<Instruction>(PN.getIncomingValue(0)) &&
+      isa<Instruction>(PN.getIncomingValue(1)) &&
+      cast<Instruction>(PN.getIncomingValue(0))->getOpcode() ==
+      cast<Instruction>(PN.getIncomingValue(1))->getOpcode() &&
+      // FIXME: The hasOneUse check will fail for PHIs that use the value more
+      // than once.
+      PN.getIncomingValue(0)->hasOneUse())
+    if (Instruction *Result = FoldPHIArgOpIntoPHI(PN))
+      return Result;
+
+  // If this is a trivial cycle in the PHI node graph, remove it.  Basically, if
+  // this PHI only has a single use (a PHI), and if that PHI only has one use (a
+  // PHI)... break the cycle.
+  if (PN.hasOneUse()) {
+    Instruction *PHIUser = cast<Instruction>(PN.use_back());
+    if (PHINode *PU = dyn_cast<PHINode>(PHIUser)) {
+      SmallPtrSet<PHINode*, 16> PotentiallyDeadPHIs;
+      PotentiallyDeadPHIs.insert(&PN);
+      if (DeadPHICycle(PU, PotentiallyDeadPHIs))
+        return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType()));
+    }
+
+    // If this phi has a single use, and if that use just computes a value for
+    // the next iteration of a loop, delete the phi.  This occurs with unused
+    // induction variables, e.g. "for (int j = 0; ; ++j);".  Detecting this
+    // common case here is good because the only other things that catch this
+    // are induction variable analysis (sometimes) and ADCE, which is only run
+    // late.
+    if (PHIUser->hasOneUse() &&
+        (isa<BinaryOperator>(PHIUser) || isa<GetElementPtrInst>(PHIUser)) &&
+        PHIUser->use_back() == &PN) {
+      return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType()));
+    }
+  }
+
+  // We sometimes end up with phi cycles that non-obviously end up being the
+  // same value, for example:
+  //   z = some value; x = phi (y, z); y = phi (x, z)
+  // where the phi nodes don't necessarily need to be in the same block.  Do a
+  // quick check to see if the PHI node only contains a single non-phi value, if
+  // so, scan to see if the phi cycle is actually equal to that value.
+  {
+    unsigned InValNo = 0, NumOperandVals = PN.getNumIncomingValues();
+    // Scan for the first non-phi operand.
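+    // (Editor's illustration: in "x = phi(y, z); y = phi(x, z)" the scan below
+    // settles on z as the only non-phi incoming value, and PHIsEqualValue then
+    // confirms the whole cycle always carries z, so x folds to z.)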
+ while (InValNo != NumOperandVals && + isa<PHINode>(PN.getIncomingValue(InValNo))) + ++InValNo; + + if (InValNo != NumOperandVals) { + Value *NonPhiInVal = PN.getOperand(InValNo); + + // Scan the rest of the operands to see if there are any conflicts, if so + // there is no need to recursively scan other phis. + for (++InValNo; InValNo != NumOperandVals; ++InValNo) { + Value *OpVal = PN.getIncomingValue(InValNo); + if (OpVal != NonPhiInVal && !isa<PHINode>(OpVal)) + break; + } + + // If we scanned over all operands, then we have one unique value plus + // phi values. Scan PHI nodes to see if they all merge in each other or + // the value. + if (InValNo == NumOperandVals) { + SmallPtrSet<PHINode*, 16> ValueEqualPHIs; + if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) + return ReplaceInstUsesWith(PN, NonPhiInVal); + } + } + } + return 0; +} + +static Value *InsertCastToIntPtrTy(Value *V, const Type *DTy, + Instruction *InsertPoint, + InstCombiner *IC) { + unsigned PtrSize = DTy->getPrimitiveSizeInBits(); + unsigned VTySize = V->getType()->getPrimitiveSizeInBits(); + // We must cast correctly to the pointer type. Ensure that we + // sign extend the integer value if it is smaller as this is + // used for address computation. + Instruction::CastOps opcode = + (VTySize < PtrSize ? Instruction::SExt : + (VTySize == PtrSize ? Instruction::BitCast : Instruction::Trunc)); + return IC->InsertCastBefore(opcode, V, DTy, *InsertPoint); +} + + +Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { + Value *PtrOp = GEP.getOperand(0); + // Is it 'getelementptr %P, i32 0' or 'getelementptr %P' + // If so, eliminate the noop. + if (GEP.getNumOperands() == 1) + return ReplaceInstUsesWith(GEP, PtrOp); + + if (isa<UndefValue>(GEP.getOperand(0))) + return ReplaceInstUsesWith(GEP, UndefValue::get(GEP.getType())); + + bool HasZeroPointerIndex = false; + if (Constant *C = dyn_cast<Constant>(GEP.getOperand(1))) + HasZeroPointerIndex = C->isNullValue(); + + if (GEP.getNumOperands() == 2 && HasZeroPointerIndex) + return ReplaceInstUsesWith(GEP, PtrOp); + + // Eliminate unneeded casts for indices. + bool MadeChange = false; + + gep_type_iterator GTI = gep_type_begin(GEP); + for (User::op_iterator i = GEP.op_begin() + 1, e = GEP.op_end(); + i != e; ++i, ++GTI) { + if (isa<SequentialType>(*GTI)) { + if (CastInst *CI = dyn_cast<CastInst>(*i)) { + if (CI->getOpcode() == Instruction::ZExt || + CI->getOpcode() == Instruction::SExt) { + const Type *SrcTy = CI->getOperand(0)->getType(); + // We can eliminate a cast from i32 to i64 iff the target + // is a 32-bit pointer target. + if (SrcTy->getPrimitiveSizeInBits() >= TD->getPointerSizeInBits()) { + MadeChange = true; + *i = CI->getOperand(0); + } + } + } + // If we are using a wider index than needed for this platform, shrink it + // to what we need. If narrower, sign-extend it to what we need. + // If the incoming value needs a cast instruction, + // insert it. This explicit cast can make subsequent optimizations more + // obvious. 
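+      // Editor's sketch (assuming 32-bit pointers): an i64 index in
+      //   getelementptr i8* %p, i64 %n
+      // is truncated below to the 32-bit intptr type, while an i16 index
+      // would be sign-extended; both forms compute the same address there.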
+ Value *Op = *i; + if (TD->getTypeSizeInBits(Op->getType()) > TD->getPointerSizeInBits()) { + if (Constant *C = dyn_cast<Constant>(Op)) { + *i = ConstantExpr::getTrunc(C, TD->getIntPtrType()); + MadeChange = true; + } else { + Op = InsertCastBefore(Instruction::Trunc, Op, TD->getIntPtrType(), + GEP); + *i = Op; + MadeChange = true; + } + } else if (TD->getTypeSizeInBits(Op->getType()) < TD->getPointerSizeInBits()) { + if (Constant *C = dyn_cast<Constant>(Op)) { + *i = ConstantExpr::getSExt(C, TD->getIntPtrType()); + MadeChange = true; + } else { + Op = InsertCastBefore(Instruction::SExt, Op, TD->getIntPtrType(), + GEP); + *i = Op; + MadeChange = true; + } + } + } + } + if (MadeChange) return &GEP; + + // Combine Indices - If the source pointer to this getelementptr instruction + // is a getelementptr instruction, combine the indices of the two + // getelementptr instructions into a single instruction. + // + SmallVector<Value*, 8> SrcGEPOperands; + if (User *Src = dyn_castGetElementPtr(PtrOp)) + SrcGEPOperands.append(Src->op_begin(), Src->op_end()); + + if (!SrcGEPOperands.empty()) { + // Note that if our source is a gep chain itself that we wait for that + // chain to be resolved before we perform this transformation. This + // avoids us creating a TON of code in some cases. + // + if (isa<GetElementPtrInst>(SrcGEPOperands[0]) && + cast<Instruction>(SrcGEPOperands[0])->getNumOperands() == 2) + return 0; // Wait until our source is folded to completion. + + SmallVector<Value*, 8> Indices; + + // Find out whether the last index in the source GEP is a sequential idx. + bool EndsWithSequential = false; + for (gep_type_iterator I = gep_type_begin(*cast<User>(PtrOp)), + E = gep_type_end(*cast<User>(PtrOp)); I != E; ++I) + EndsWithSequential = !isa<StructType>(*I); + + // Can we combine the two pointer arithmetics offsets? + if (EndsWithSequential) { + // Replace: gep (gep %P, long B), long A, ... + // With: T = long A+B; gep %P, T, ... + // + Value *Sum, *SO1 = SrcGEPOperands.back(), *GO1 = GEP.getOperand(1); + if (SO1 == Constant::getNullValue(SO1->getType())) { + Sum = GO1; + } else if (GO1 == Constant::getNullValue(GO1->getType())) { + Sum = SO1; + } else { + // If they aren't the same type, convert both to an integer of the + // target's pointer size. + if (SO1->getType() != GO1->getType()) { + if (Constant *SO1C = dyn_cast<Constant>(SO1)) { + SO1 = ConstantExpr::getIntegerCast(SO1C, GO1->getType(), true); + } else if (Constant *GO1C = dyn_cast<Constant>(GO1)) { + GO1 = ConstantExpr::getIntegerCast(GO1C, SO1->getType(), true); + } else { + unsigned PS = TD->getPointerSizeInBits(); + if (TD->getTypeSizeInBits(SO1->getType()) == PS) { + // Convert GO1 to SO1's type. + GO1 = InsertCastToIntPtrTy(GO1, SO1->getType(), &GEP, this); + + } else if (TD->getTypeSizeInBits(GO1->getType()) == PS) { + // Convert SO1 to GO1's type. + SO1 = InsertCastToIntPtrTy(SO1, GO1->getType(), &GEP, this); + } else { + const Type *PT = TD->getIntPtrType(); + SO1 = InsertCastToIntPtrTy(SO1, PT, &GEP, this); + GO1 = InsertCastToIntPtrTy(GO1, PT, &GEP, this); + } + } + } + if (isa<Constant>(SO1) && isa<Constant>(GO1)) + Sum = ConstantExpr::getAdd(cast<Constant>(SO1), cast<Constant>(GO1)); + else { + Sum = BinaryOperator::CreateAdd(SO1, GO1, PtrOp->getName()+".sum"); + InsertNewInstBefore(cast<Instruction>(Sum), GEP); + } + } + + // Recycle the GEP we already have if possible. 
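+      // (Editor's illustration: "gep (gep %P, i64 %B), i64 %A" has just been
+      // reduced to the single index %sum = %A + %B, so when the source GEP had
+      // only two operands this GEP can be rewritten in place as
+      // "gep %P, i64 %sum".)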
+ if (SrcGEPOperands.size() == 2) { + GEP.setOperand(0, SrcGEPOperands[0]); + GEP.setOperand(1, Sum); + return &GEP; + } else { + Indices.insert(Indices.end(), SrcGEPOperands.begin()+1, + SrcGEPOperands.end()-1); + Indices.push_back(Sum); + Indices.insert(Indices.end(), GEP.op_begin()+2, GEP.op_end()); + } + } else if (isa<Constant>(*GEP.idx_begin()) && + cast<Constant>(*GEP.idx_begin())->isNullValue() && + SrcGEPOperands.size() != 1) { + // Otherwise we can do the fold if the first index of the GEP is a zero + Indices.insert(Indices.end(), SrcGEPOperands.begin()+1, + SrcGEPOperands.end()); + Indices.insert(Indices.end(), GEP.idx_begin()+1, GEP.idx_end()); + } + + if (!Indices.empty()) + return GetElementPtrInst::Create(SrcGEPOperands[0], Indices.begin(), + Indices.end(), GEP.getName()); + + } else if (GlobalValue *GV = dyn_cast<GlobalValue>(PtrOp)) { + // GEP of global variable. If all of the indices for this GEP are + // constants, we can promote this to a constexpr instead of an instruction. + + // Scan for nonconstants... + SmallVector<Constant*, 8> Indices; + User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); + for (; I != E && isa<Constant>(*I); ++I) + Indices.push_back(cast<Constant>(*I)); + + if (I == E) { // If they are all constants... + Constant *CE = ConstantExpr::getGetElementPtr(GV, + &Indices[0],Indices.size()); + + // Replace all uses of the GEP with the new constexpr... + return ReplaceInstUsesWith(GEP, CE); + } + } else if (Value *X = getBitCastOperand(PtrOp)) { // Is the operand a cast? + if (!isa<PointerType>(X->getType())) { + // Not interesting. Source pointer must be a cast from pointer. + } else if (HasZeroPointerIndex) { + // transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... + // into : GEP [10 x i8]* X, i32 0, ... + // + // Likewise, transform: GEP (bitcast i8* X to [0 x i8]*), i32 0, ... + // into : GEP i8* X, ... + // + // This occurs when the program declares an array extern like "int X[];" + const PointerType *CPTy = cast<PointerType>(PtrOp->getType()); + const PointerType *XTy = cast<PointerType>(X->getType()); + if (const ArrayType *CATy = + dyn_cast<ArrayType>(CPTy->getElementType())) { + // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? + if (CATy->getElementType() == XTy->getElementType()) { + // -> GEP i8* X, ... + SmallVector<Value*, 8> Indices(GEP.idx_begin()+1, GEP.idx_end()); + return GetElementPtrInst::Create(X, Indices.begin(), Indices.end(), + GEP.getName()); + } else if (const ArrayType *XATy = + dyn_cast<ArrayType>(XTy->getElementType())) { + // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ? + if (CATy->getElementType() == XATy->getElementType()) { + // -> GEP [10 x i8]* X, i32 0, ... + // At this point, we know that the cast source type is a pointer + // to an array of the same type as the destination pointer + // array. Because the array type is never stepped over (there + // is a leading zero) we can fold the cast into this GEP. 
+ GEP.setOperand(0, X); + return &GEP; + } + } + } + } else if (GEP.getNumOperands() == 2) { + // Transform things like: + // %t = getelementptr i32* bitcast ([2 x i32]* %str to i32*), i32 %V + // into: %t1 = getelementptr [2 x i32]* %str, i32 0, i32 %V; bitcast + const Type *SrcElTy = cast<PointerType>(X->getType())->getElementType(); + const Type *ResElTy=cast<PointerType>(PtrOp->getType())->getElementType(); + if (isa<ArrayType>(SrcElTy) && + TD->getTypeAllocSize(cast<ArrayType>(SrcElTy)->getElementType()) == + TD->getTypeAllocSize(ResElTy)) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::Int32Ty); + Idx[1] = GEP.getOperand(1); + Value *V = InsertNewInstBefore( + GetElementPtrInst::Create(X, Idx, Idx + 2, GEP.getName()), GEP); + // V and GEP are both pointer types --> BitCast + return new BitCastInst(V, GEP.getType()); + } + + // Transform things like: + // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp + // (where tmp = 8*tmp2) into: + // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast + + if (isa<ArrayType>(SrcElTy) && ResElTy == Type::Int8Ty) { + uint64_t ArrayEltSize = + TD->getTypeAllocSize(cast<ArrayType>(SrcElTy)->getElementType()); + + // Check to see if "tmp" is a scale by a multiple of ArrayEltSize. We + // allow either a mul, shift, or constant here. + Value *NewIdx = 0; + ConstantInt *Scale = 0; + if (ArrayEltSize == 1) { + NewIdx = GEP.getOperand(1); + Scale = ConstantInt::get(NewIdx->getType(), 1); + } else if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP.getOperand(1))) { + NewIdx = ConstantInt::get(CI->getType(), 1); + Scale = CI; + } else if (Instruction *Inst =dyn_cast<Instruction>(GEP.getOperand(1))){ + if (Inst->getOpcode() == Instruction::Shl && + isa<ConstantInt>(Inst->getOperand(1))) { + ConstantInt *ShAmt = cast<ConstantInt>(Inst->getOperand(1)); + uint32_t ShAmtVal = ShAmt->getLimitedValue(64); + Scale = ConstantInt::get(Inst->getType(), 1ULL << ShAmtVal); + NewIdx = Inst->getOperand(0); + } else if (Inst->getOpcode() == Instruction::Mul && + isa<ConstantInt>(Inst->getOperand(1))) { + Scale = cast<ConstantInt>(Inst->getOperand(1)); + NewIdx = Inst->getOperand(0); + } + } + + // If the index will be to exactly the right offset with the scale taken + // out, perform the transformation. Note, we don't know whether Scale is + // signed or not. We'll use unsigned version of division/modulo + // operation after making sure Scale doesn't have the sign bit set. + if (ArrayEltSize && Scale && Scale->getSExtValue() >= 0LL && + Scale->getZExtValue() % ArrayEltSize == 0) { + Scale = ConstantInt::get(Scale->getType(), + Scale->getZExtValue() / ArrayEltSize); + if (Scale->getZExtValue() != 1) { + Constant *C = ConstantExpr::getIntegerCast(Scale, NewIdx->getType(), + false /*ZExt*/); + Instruction *Sc = BinaryOperator::CreateMul(NewIdx, C, "idxscale"); + NewIdx = InsertNewInstBefore(Sc, GEP); + } + + // Insert the new GEP instruction. + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::Int32Ty); + Idx[1] = NewIdx; + Instruction *NewGEP = + GetElementPtrInst::Create(X, Idx, Idx + 2, GEP.getName()); + NewGEP = InsertNewInstBefore(NewGEP, GEP); + // The NewGEP must be pointer typed, so must the old one -> BitCast + return new BitCastInst(NewGEP, GEP.getType()); + } + } + } + } + + /// See if we can simplify: + /// X = bitcast A to B* + /// Y = gep X, <...constant indices...> + /// into a gep of the original struct. This is important for SROA and alias + /// analysis of unions. If "A" is also a bitcast, wait for A/X to be merged. 
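+  // Editor's sketch (hypothetical types): with %A of type %struct.S*, where
+  // field 1 of %struct.S sits at byte offset 4,
+  //   %X = bitcast %struct.S* %A to i8*
+  //   %Y = getelementptr i8* %X, i32 4
+  // becomes "getelementptr %struct.S* %A, i32 0, i32 1" plus a bitcast of the
+  // result, which SROA and alias analysis can reason about directly.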
+ if (BitCastInst *BCI = dyn_cast<BitCastInst>(PtrOp)) { + if (!isa<BitCastInst>(BCI->getOperand(0)) && GEP.hasAllConstantIndices()) { + // Determine how much the GEP moves the pointer. We are guaranteed to get + // a constant back from EmitGEPOffset. + ConstantInt *OffsetV = cast<ConstantInt>(EmitGEPOffset(&GEP, GEP, *this)); + int64_t Offset = OffsetV->getSExtValue(); + + // If this GEP instruction doesn't move the pointer, just replace the GEP + // with a bitcast of the real input to the dest type. + if (Offset == 0) { + // If the bitcast is of an allocation, and the allocation will be + // converted to match the type of the cast, don't touch this. + if (isa<AllocationInst>(BCI->getOperand(0))) { + // See if the bitcast simplifies, if so, don't nuke this GEP yet. + if (Instruction *I = visitBitCast(*BCI)) { + if (I != BCI) { + I->takeName(BCI); + BCI->getParent()->getInstList().insert(BCI, I); + ReplaceInstUsesWith(*BCI, I); + } + return &GEP; + } + } + return new BitCastInst(BCI->getOperand(0), GEP.getType()); + } + + // Otherwise, if the offset is non-zero, we need to find out if there is a + // field at Offset in 'A's type. If so, we can pull the cast through the + // GEP. + SmallVector<Value*, 8> NewIndices; + const Type *InTy = + cast<PointerType>(BCI->getOperand(0)->getType())->getElementType(); + if (FindElementAtOffset(InTy, Offset, NewIndices, TD)) { + Instruction *NGEP = + GetElementPtrInst::Create(BCI->getOperand(0), NewIndices.begin(), + NewIndices.end()); + if (NGEP->getType() == GEP.getType()) return NGEP; + InsertNewInstBefore(NGEP, GEP); + NGEP->takeName(&GEP); + return new BitCastInst(NGEP, GEP.getType()); + } + } + } + + return 0; +} + +Instruction *InstCombiner::visitAllocationInst(AllocationInst &AI) { + // Convert: malloc Ty, C - where C is a constant != 1 into: malloc [C x Ty], 1 + if (AI.isArrayAllocation()) { // Check C != 1 + if (const ConstantInt *C = dyn_cast<ConstantInt>(AI.getArraySize())) { + const Type *NewTy = + ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); + AllocationInst *New = 0; + + // Create and insert the replacement instruction... + if (isa<MallocInst>(AI)) + New = new MallocInst(NewTy, 0, AI.getAlignment(), AI.getName()); + else { + assert(isa<AllocaInst>(AI) && "Unknown type of allocation inst!"); + New = new AllocaInst(NewTy, 0, AI.getAlignment(), AI.getName()); + } + + InsertNewInstBefore(New, AI); + + // Scan to the end of the allocation instructions, to skip over a block of + // allocas if possible...also skip interleaved debug info + // + BasicBlock::iterator It = New; + while (isa<AllocationInst>(*It) || isa<DbgInfoIntrinsic>(*It)) ++It; + + // Now that I is pointing to the first non-allocation-inst in the block, + // insert our getelementptr instruction... + // + Value *NullIdx = Constant::getNullValue(Type::Int32Ty); + Value *Idx[2]; + Idx[0] = NullIdx; + Idx[1] = NullIdx; + Value *V = GetElementPtrInst::Create(New, Idx, Idx + 2, + New->getName()+".sub", It); + + // Now make everything use the getelementptr instead of the original + // allocation. + return ReplaceInstUsesWith(AI, V); + } else if (isa<UndefValue>(AI.getArraySize())) { + return ReplaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); + } + } + + if (isa<AllocaInst>(AI) && AI.getAllocatedType()->isSized()) { + // If alloca'ing a zero byte object, replace the alloca with a null pointer. + // Note that we only do this for alloca's, because malloc should allocate + // and return a unique pointer, even for a zero byte allocation. 
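+    // (Editor's note, illustrative: "alloca {}" and "alloca [0 x i32]" both
+    // have allocation size 0, so the check below folds them to a null
+    // pointer; a zero byte malloc is deliberately left alone.)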
+ if (TD->getTypeAllocSize(AI.getAllocatedType()) == 0) + return ReplaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); + + // If the alignment is 0 (unspecified), assign it the preferred alignment. + if (AI.getAlignment() == 0) + AI.setAlignment(TD->getPrefTypeAlignment(AI.getAllocatedType())); + } + + return 0; +} + +Instruction *InstCombiner::visitFreeInst(FreeInst &FI) { + Value *Op = FI.getOperand(0); + + // free undef -> unreachable. + if (isa<UndefValue>(Op)) { + // Insert a new store to null because we cannot modify the CFG here. + new StoreInst(ConstantInt::getTrue(), + UndefValue::get(PointerType::getUnqual(Type::Int1Ty)), &FI); + return EraseInstFromFunction(FI); + } + + // If we have 'free null' delete the instruction. This can happen in stl code + // when lots of inlining happens. + if (isa<ConstantPointerNull>(Op)) + return EraseInstFromFunction(FI); + + // Change free <ty>* (cast <ty2>* X to <ty>*) into free <ty2>* X + if (BitCastInst *CI = dyn_cast<BitCastInst>(Op)) { + FI.setOperand(0, CI->getOperand(0)); + return &FI; + } + + // Change free (gep X, 0,0,0,0) into free(X) + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { + if (GEPI->hasAllZeroIndices()) { + AddToWorkList(GEPI); + FI.setOperand(0, GEPI->getOperand(0)); + return &FI; + } + } + + // Change free(malloc) into nothing, if the malloc has a single use. + if (MallocInst *MI = dyn_cast<MallocInst>(Op)) + if (MI->hasOneUse()) { + EraseInstFromFunction(FI); + return EraseInstFromFunction(*MI); + } + + return 0; +} + + +/// InstCombineLoadCast - Fold 'load (cast P)' -> cast (load P)' when possible. +static Instruction *InstCombineLoadCast(InstCombiner &IC, LoadInst &LI, + const TargetData *TD) { + User *CI = cast<User>(LI.getOperand(0)); + Value *CastOp = CI->getOperand(0); + + if (TD) { + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(CI)) { + // Instead of loading constant c string, use corresponding integer value + // directly if string length is small enough. + std::string Str; + if (GetConstantStringInfo(CE->getOperand(0), Str) && !Str.empty()) { + unsigned len = Str.length(); + const Type *Ty = cast<PointerType>(CE->getType())->getElementType(); + unsigned numBits = Ty->getPrimitiveSizeInBits(); + // Replace LI with immediate integer store. + if ((numBits >> 3) == len + 1) { + APInt StrVal(numBits, 0); + APInt SingleChar(numBits, 0); + if (TD->isLittleEndian()) { + for (signed i = len-1; i >= 0; i--) { + SingleChar = (uint64_t) Str[i] & UCHAR_MAX; + StrVal = (StrVal << 8) | SingleChar; + } + } else { + for (unsigned i = 0; i < len; i++) { + SingleChar = (uint64_t) Str[i] & UCHAR_MAX; + StrVal = (StrVal << 8) | SingleChar; + } + // Append NULL at the end. + SingleChar = 0; + StrVal = (StrVal << 8) | SingleChar; + } + Value *NL = ConstantInt::get(StrVal); + return IC.ReplaceInstUsesWith(LI, NL); + } + } + } + } + + const PointerType *DestTy = cast<PointerType>(CI->getType()); + const Type *DestPTy = DestTy->getElementType(); + if (const PointerType *SrcTy = dyn_cast<PointerType>(CastOp->getType())) { + + // If the address spaces don't match, don't eliminate the cast. + if (DestTy->getAddressSpace() != SrcTy->getAddressSpace()) + return 0; + + const Type *SrcPTy = SrcTy->getElementType(); + + if (DestPTy->isInteger() || isa<PointerType>(DestPTy) || + isa<VectorType>(DestPTy)) { + // If the source is an array, the code below will not succeed. Check to + // see if a trivial 'gep P, 0, 0' will help matters. Only do this for + // constants. 
+      if (const ArrayType *ASrcTy = dyn_cast<ArrayType>(SrcPTy))
+        if (Constant *CSrc = dyn_cast<Constant>(CastOp))
+          if (ASrcTy->getNumElements() != 0) {
+            Value *Idxs[2];
+            Idxs[0] = Idxs[1] = Constant::getNullValue(Type::Int32Ty);
+            CastOp = ConstantExpr::getGetElementPtr(CSrc, Idxs, 2);
+            SrcTy = cast<PointerType>(CastOp->getType());
+            SrcPTy = SrcTy->getElementType();
+          }
+
+      if ((SrcPTy->isInteger() || isa<PointerType>(SrcPTy) ||
+           isa<VectorType>(SrcPTy)) &&
+          // Do not allow turning this into a load of an integer, which is then
+          // casted to a pointer, this pessimizes pointer analysis a lot.
+          (isa<PointerType>(SrcPTy) == isa<PointerType>(LI.getType())) &&
+          IC.getTargetData().getTypeSizeInBits(SrcPTy) ==
+          IC.getTargetData().getTypeSizeInBits(DestPTy)) {
+
+        // Okay, we are casting from one integer or pointer type to another of
+        // the same size.  Instead of casting the pointer before the load, cast
+        // the result of the loaded value.
+        Value *NewLoad = IC.InsertNewInstBefore(new LoadInst(CastOp,
+                                                             CI->getName(),
+                                                             LI.isVolatile()),LI);
+        // Now cast the result of the load.
+        return new BitCastInst(NewLoad, LI.getType());
+      }
+    }
+  }
+  return 0;
+}
+
+/// isSafeToLoadUnconditionally - Return true if we know that executing a load
+/// from this value cannot trap.  If it is not obviously safe to load from the
+/// specified pointer, we do a quick local scan of the basic block containing
+/// ScanFrom, to determine if the address is already accessed.
+static bool isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom) {
+  // If it is an alloca it is always safe to load from.
+  if (isa<AllocaInst>(V)) return true;
+
+  // If it is a global variable it is mostly safe to load from.
+  if (const GlobalValue *GV = dyn_cast<GlobalVariable>(V))
+    // Don't try to evaluate aliases.  External weak GV can be null.
+    return !isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage();
+
+  // Otherwise, be a little bit aggressive by scanning the local block where we
+  // want to check to see if the pointer is already being loaded or stored
+  // from/to.  If so, the previous load or store would have already trapped,
+  // so there is no harm doing an extra load (also, CSE will later eliminate
+  // the load entirely).
+  BasicBlock::iterator BBI = ScanFrom, E = ScanFrom->getParent()->begin();
+
+  while (BBI != E) {
+    --BBI;
+
+    // If we see a free or a call (which might do a free) the pointer could be
+    // marked invalid.
+    if (isa<FreeInst>(BBI) ||
+        (isa<CallInst>(BBI) && !isa<DbgInfoIntrinsic>(BBI)))
+      return false;
+
+    if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) {
+      if (LI->getOperand(0) == V) return true;
+    } else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) {
+      if (SI->getOperand(1) == V) return true;
+    }
+
+  }
+  return false;
+}
+
+Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
+  Value *Op = LI.getOperand(0);
+
+  // Attempt to improve the alignment.
+  unsigned KnownAlign =
+    GetOrEnforceKnownAlignment(Op, TD->getPrefTypeAlignment(LI.getType()));
+  if (KnownAlign >
+      (LI.getAlignment() == 0 ? TD->getABITypeAlignment(LI.getType()) :
+                                LI.getAlignment()))
+    LI.setAlignment(KnownAlign);
+
+  // load (cast X) --> cast (load X) iff safe
+  if (isa<CastInst>(Op))
+    if (Instruction *Res = InstCombineLoadCast(*this, LI, TD))
+      return Res;
+
+  // None of the following transforms are legal for volatile loads.
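+  // (Editor's illustration of the cast fold just above: when both types are
+  // 32 bits wide, "load (bitcast i32* %p to float*)" becomes a load of the
+  // original i32 followed by "bitcast i32 %v to float", so the pointer cast
+  // disappears.)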
+  if (LI.isVolatile()) return 0;
+
+  // Do really simple store-to-load forwarding and load CSE, to catch cases
+  // where there are several consecutive memory accesses to the same location,
+  // separated by a few arithmetic operations.
+  BasicBlock::iterator BBI = &LI;
+  if (Value *AvailableVal = FindAvailableLoadedValue(Op, LI.getParent(), BBI,6))
+    return ReplaceInstUsesWith(LI, AvailableVal);
+
+  if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) {
+    const Value *GEPI0 = GEPI->getOperand(0);
+    // TODO: Consider a target hook for valid address spaces for this xform.
+    if (isa<ConstantPointerNull>(GEPI0) &&
+        cast<PointerType>(GEPI0->getType())->getAddressSpace() == 0) {
+      // Insert a new store to null instruction before the load to indicate
+      // that this code is not reachable.  We do this instead of inserting
+      // an unreachable instruction directly because we cannot modify the
+      // CFG.
+      new StoreInst(UndefValue::get(LI.getType()),
+                    Constant::getNullValue(Op->getType()), &LI);
+      return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType()));
+    }
+  }
+
+  if (Constant *C = dyn_cast<Constant>(Op)) {
+    // load null/undef -> undef
+    // TODO: Consider a target hook for valid address spaces for this xform.
+    if (isa<UndefValue>(C) || (C->isNullValue() &&
+        cast<PointerType>(Op->getType())->getAddressSpace() == 0)) {
+      // Insert a new store to null instruction before the load to indicate that
+      // this code is not reachable.  We do this instead of inserting an
+      // unreachable instruction directly because we cannot modify the CFG.
+      new StoreInst(UndefValue::get(LI.getType()),
+                    Constant::getNullValue(Op->getType()), &LI);
+      return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType()));
+    }
+
+    // Instcombine load (constant global) into the value loaded.
+    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Op))
+      if (GV->isConstant() && GV->hasDefinitiveInitializer())
+        return ReplaceInstUsesWith(LI, GV->getInitializer());
+
+    // Instcombine load (constantexpr_GEP global, 0, ...) into the value loaded.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op)) {
+      if (CE->getOpcode() == Instruction::GetElementPtr) {
+        if (GlobalVariable *GV = dyn_cast<GlobalVariable>(CE->getOperand(0)))
+          if (GV->isConstant() && GV->hasDefinitiveInitializer())
+            if (Constant *V =
+                ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE))
+              return ReplaceInstUsesWith(LI, V);
+        if (CE->getOperand(0)->isNullValue()) {
+          // Insert a new store to null instruction before the load to indicate
+          // that this code is not reachable.  We do this instead of inserting
+          // an unreachable instruction directly because we cannot modify the
+          // CFG.
+          new StoreInst(UndefValue::get(LI.getType()),
+                        Constant::getNullValue(Op->getType()), &LI);
+          return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType()));
+        }
+
+      } else if (CE->isCast()) {
+        if (Instruction *Res = InstCombineLoadCast(*this, LI, TD))
+          return Res;
+      }
+    }
+  }
+
+  // If this load comes from anywhere in a constant global, and if the global
+  // is all undef or zero, we know what it loads.
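+  // (Editor's sketch: with "@g = constant [4 x i32] zeroinitializer", a load
+  // whose underlying object is @g folds to 0 below, no matter what GEPs or
+  // casts it goes through, since every byte of the initializer is zero.)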
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Op->getUnderlyingObject())){
+    if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
+      if (GV->getInitializer()->isNullValue())
+        return ReplaceInstUsesWith(LI, Constant::getNullValue(LI.getType()));
+      else if (isa<UndefValue>(GV->getInitializer()))
+        return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType()));
+    }
+  }
+
+  if (Op->hasOneUse()) {
+    // Change select and PHI nodes to select values instead of addresses: this
+    // helps alias analysis out a lot, allows many other simplifications, and
+    // exposes redundancy in the code.
+    //
+    // Note that we cannot do the transformation unless we know that the
+    // introduced loads cannot trap!  Something like this is valid as long as
+    // the condition is always false: load (select bool %C, int* null, int* %G),
+    // but it would not be valid if we transformed it to load from null
+    // unconditionally.
+    //
+    if (SelectInst *SI = dyn_cast<SelectInst>(Op)) {
+      // load (select (Cond, &V1, &V2))  --> select(Cond, load &V1, load &V2).
+      if (isSafeToLoadUnconditionally(SI->getOperand(1), SI) &&
+          isSafeToLoadUnconditionally(SI->getOperand(2), SI)) {
+        Value *V1 = InsertNewInstBefore(new LoadInst(SI->getOperand(1),
+                                     SI->getOperand(1)->getName()+".val"), LI);
+        Value *V2 = InsertNewInstBefore(new LoadInst(SI->getOperand(2),
+                                     SI->getOperand(2)->getName()+".val"), LI);
+        return SelectInst::Create(SI->getCondition(), V1, V2);
+      }
+
+      // load (select (cond, null, P)) -> load P
+      if (Constant *C = dyn_cast<Constant>(SI->getOperand(1)))
+        if (C->isNullValue()) {
+          LI.setOperand(0, SI->getOperand(2));
+          return &LI;
+        }
+
+      // load (select (cond, P, null)) -> load P
+      if (Constant *C = dyn_cast<Constant>(SI->getOperand(2)))
+        if (C->isNullValue()) {
+          LI.setOperand(0, SI->getOperand(1));
+          return &LI;
+        }
+    }
+  }
+  return 0;
+}
+
+/// InstCombineStoreToCast - Fold store V, (cast P) -> store (cast V), P
+/// when possible.  This makes it generally easy to do alias analysis and/or
+/// SROA/mem2reg of the memory object.
+static Instruction *InstCombineStoreToCast(InstCombiner &IC, StoreInst &SI) {
+  User *CI = cast<User>(SI.getOperand(1));
+  Value *CastOp = CI->getOperand(0);
+
+  const Type *DestPTy = cast<PointerType>(CI->getType())->getElementType();
+  const PointerType *SrcTy = dyn_cast<PointerType>(CastOp->getType());
+  if (SrcTy == 0) return 0;
+
+  const Type *SrcPTy = SrcTy->getElementType();
+
+  if (!DestPTy->isInteger() && !isa<PointerType>(DestPTy))
+    return 0;
+
+  /// NewGEPIndices - If SrcPTy is an aggregate type, we can emit a "noop gep"
+  /// to its first element.  This allows us to handle things like:
+  ///   store i32 xxx, (bitcast {foo*, float}* %P to i32*)
+  /// on 32-bit hosts.
+  SmallVector<Value*, 4> NewGEPIndices;
+
+  // If the source is an array, the code below will not succeed.  Check to
+  // see if a trivial 'gep P, 0, 0' will help matters.  Only do this for
+  // constants.
+  if (isa<ArrayType>(SrcPTy) || isa<StructType>(SrcPTy)) {
+    // Index through pointer.
+    Constant *Zero = Constant::getNullValue(Type::Int32Ty);
+    NewGEPIndices.push_back(Zero);
+
+    while (1) {
+      if (const StructType *STy = dyn_cast<StructType>(SrcPTy)) {
+        if (!STy->getNumElements()) /* Struct can be empty {} */
+          break;
+        NewGEPIndices.push_back(Zero);
+        SrcPTy = STy->getElementType(0);
+      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(SrcPTy)) {
+        NewGEPIndices.push_back(Zero);
+        SrcPTy = ATy->getElementType();
+      } else {
+        break;
+      }
+    }
+
+    SrcTy = PointerType::get(SrcPTy, SrcTy->getAddressSpace());
+  }
+
+  if (!SrcPTy->isInteger() && !isa<PointerType>(SrcPTy))
+    return 0;
+
+  // If the pointers point into different address spaces or if they point to
+  // values with different sizes, we can't do the transformation.
+  if (SrcTy->getAddressSpace() !=
+      cast<PointerType>(CI->getType())->getAddressSpace() ||
+      IC.getTargetData().getTypeSizeInBits(SrcPTy) !=
+      IC.getTargetData().getTypeSizeInBits(DestPTy))
+    return 0;
+
+  // Okay, we are casting from one integer or pointer type to another of
+  // the same size.  Instead of casting the pointer before
+  // the store, cast the value to be stored.
+  Value *NewCast;
+  Value *SIOp0 = SI.getOperand(0);
+  Instruction::CastOps opcode = Instruction::BitCast;
+  const Type* CastSrcTy = SIOp0->getType();
+  const Type* CastDstTy = SrcPTy;
+  if (isa<PointerType>(CastDstTy)) {
+    if (CastSrcTy->isInteger())
+      opcode = Instruction::IntToPtr;
+  } else if (isa<IntegerType>(CastDstTy)) {
+    if (isa<PointerType>(SIOp0->getType()))
+      opcode = Instruction::PtrToInt;
+  }
+
+  // SIOp0 is a pointer to aggregate and this is a store to the first field,
+  // emit a GEP to index into its first field.
+  if (!NewGEPIndices.empty()) {
+    if (Constant *C = dyn_cast<Constant>(CastOp))
+      CastOp = ConstantExpr::getGetElementPtr(C, &NewGEPIndices[0],
+                                              NewGEPIndices.size());
+    else
+      CastOp = IC.InsertNewInstBefore(
+              GetElementPtrInst::Create(CastOp, NewGEPIndices.begin(),
+                                        NewGEPIndices.end()), SI);
+  }
+
+  if (Constant *C = dyn_cast<Constant>(SIOp0))
+    NewCast = ConstantExpr::getCast(opcode, C, CastDstTy);
+  else
+    NewCast = IC.InsertNewInstBefore(
+      CastInst::Create(opcode, SIOp0, CastDstTy, SIOp0->getName()+".c"),
+      SI);
+  return new StoreInst(NewCast, CastOp);
+}
+
+/// equivalentAddressValues - Test if A and B will obviously have the same
+/// value.  This includes recognizing that %t0 and %t1 will have the same
+/// value in code like this:
+///   %t0 = getelementptr @a, 0, 3
+///   store i32 0, i32* %t0
+///   %t1 = getelementptr @a, 0, 3
+///   %t2 = load i32* %t1
+///
+static bool equivalentAddressValues(Value *A, Value *B) {
+  // Test if the values are trivially equivalent.
+  if (A == B) return true;
+
+  // Test if the values come from identical arithmetic instructions.
+  if (isa<BinaryOperator>(A) ||
+      isa<CastInst>(A) ||
+      isa<PHINode>(A) ||
+      isa<GetElementPtrInst>(A))
+    if (Instruction *BI = dyn_cast<Instruction>(B))
+      if (cast<Instruction>(A)->isIdenticalTo(BI))
+        return true;
+
+  // Otherwise they may not be equivalent.
+  return false;
+}
+
+// If this instruction has two uses, one of which is a llvm.dbg.declare,
+// return the llvm.dbg.declare.
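+// (Editor's note, illustrative: this matches the pattern -g leaves behind,
+// where an alloca has one real use plus a llvm.dbg.declare, possibly through
+// a bitcast; visitStoreInst below uses it to delete the store and the declare
+// together, so debug info never changes the generated code.)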
+DbgDeclareInst *InstCombiner::hasOneUsePlusDeclare(Value *V) {
+  if (!V->hasNUses(2))
+    return 0;
+  for (Value::use_iterator UI = V->use_begin(), E = V->use_end();
+       UI != E; ++UI) {
+    if (DbgDeclareInst *DI = dyn_cast<DbgDeclareInst>(UI))
+      return DI;
+    if (isa<BitCastInst>(UI) && UI->hasOneUse()) {
+      if (DbgDeclareInst *DI = dyn_cast<DbgDeclareInst>(UI->use_begin()))
+        return DI;
+    }
+  }
+  return 0;
+}
+
+Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
+  Value *Val = SI.getOperand(0);
+  Value *Ptr = SI.getOperand(1);
+
+  if (isa<UndefValue>(Ptr)) {     // store X, undef -> noop (even if volatile)
+    EraseInstFromFunction(SI);
+    ++NumCombined;
+    return 0;
+  }
+
+  // If the RHS is an alloca with a single use, zapify the store, making the
+  // alloca dead.
+  // If the RHS is an alloca with two uses, the other one being a
+  // llvm.dbg.declare, zapify the store and the declare, making the
+  // alloca dead.  We must do this to prevent declare's from affecting
+  // codegen.
+  if (!SI.isVolatile()) {
+    if (Ptr->hasOneUse()) {
+      if (isa<AllocaInst>(Ptr)) {
+        EraseInstFromFunction(SI);
+        ++NumCombined;
+        return 0;
+      }
+      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+        if (isa<AllocaInst>(GEP->getOperand(0))) {
+          if (GEP->getOperand(0)->hasOneUse()) {
+            EraseInstFromFunction(SI);
+            ++NumCombined;
+            return 0;
+          }
+          if (DbgDeclareInst *DI = hasOneUsePlusDeclare(GEP->getOperand(0))) {
+            EraseInstFromFunction(*DI);
+            EraseInstFromFunction(SI);
+            ++NumCombined;
+            return 0;
+          }
+        }
+      }
+    }
+    if (DbgDeclareInst *DI = hasOneUsePlusDeclare(Ptr)) {
+      EraseInstFromFunction(*DI);
+      EraseInstFromFunction(SI);
+      ++NumCombined;
+      return 0;
+    }
+  }
+
+  // Attempt to improve the alignment.
+  unsigned KnownAlign =
+    GetOrEnforceKnownAlignment(Ptr, TD->getPrefTypeAlignment(Val->getType()));
+  if (KnownAlign >
+      (SI.getAlignment() == 0 ? TD->getABITypeAlignment(Val->getType()) :
+                                SI.getAlignment()))
+    SI.setAlignment(KnownAlign);
+
+  // Do really simple DSE, to catch cases where there are several consecutive
+  // stores to the same location, separated by a few arithmetic operations. This
+  // situation often occurs with bitfield accesses.
+  BasicBlock::iterator BBI = &SI;
+  for (unsigned ScanInsts = 6; BBI != SI.getParent()->begin() && ScanInsts;
+       --ScanInsts) {
+    --BBI;
+    // Don't count debug info directives, lest they affect codegen,
+    // and we skip pointer-to-pointer bitcasts, which are NOPs.
+    // It is necessary for correctness to skip those that feed into a
+    // llvm.dbg.declare, as these are not present when debugging is off.
+    if (isa<DbgInfoIntrinsic>(BBI) ||
+        (isa<BitCastInst>(BBI) && isa<PointerType>(BBI->getType()))) {
+      ScanInsts++;
+      continue;
+    }
+
+    if (StoreInst *PrevSI = dyn_cast<StoreInst>(BBI)) {
+      // Prev store isn't volatile, and stores to the same location?
+      if (!PrevSI->isVolatile() &&equivalentAddressValues(PrevSI->getOperand(1),
+                                                          SI.getOperand(1))) {
+        ++NumDeadStore;
+        ++BBI;
+        EraseInstFromFunction(*PrevSI);
+        continue;
+      }
+      break;
+    }
+
+    // If this is a load, we have to stop.  However, if the loaded value is from
+    // the pointer we're loading and is producing the pointer we're storing,
+    // then *this* store is dead (X = load P; store X -> P).
+    if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) {
+      if (LI == Val && equivalentAddressValues(LI->getOperand(0), Ptr) &&
+          !SI.isVolatile()) {
+        EraseInstFromFunction(SI);
+        ++NumCombined;
+        return 0;
+      }
+      // Otherwise, this is a load from some other location.  Stores before it
+      // may not be dead.
+ break; + } + + // Don't skip over loads or things that can modify memory. + if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory()) + break; + } + + + if (SI.isVolatile()) return 0; // Don't hack volatile stores. + + // store X, null -> turns into 'unreachable' in SimplifyCFG + if (isa<ConstantPointerNull>(Ptr)) { + if (!isa<UndefValue>(Val)) { + SI.setOperand(0, UndefValue::get(Val->getType())); + if (Instruction *U = dyn_cast<Instruction>(Val)) + AddToWorkList(U); // Dropped a use. + ++NumCombined; + } + return 0; // Do not modify these! + } + + // store undef, Ptr -> noop + if (isa<UndefValue>(Val)) { + EraseInstFromFunction(SI); + ++NumCombined; + return 0; + } + + // If the pointer destination is a cast, see if we can fold the cast into the + // source instead. + if (isa<CastInst>(Ptr)) + if (Instruction *Res = InstCombineStoreToCast(*this, SI)) + return Res; + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) + if (CE->isCast()) + if (Instruction *Res = InstCombineStoreToCast(*this, SI)) + return Res; + + + // If this store is the last instruction in the basic block (possibly + // excepting debug info instructions and the pointer bitcasts that feed + // into them), and if the block ends with an unconditional branch, try + // to move it to the successor block. + BBI = &SI; + do { + ++BBI; + } while (isa<DbgInfoIntrinsic>(BBI) || + (isa<BitCastInst>(BBI) && isa<PointerType>(BBI->getType()))); + if (BranchInst *BI = dyn_cast<BranchInst>(BBI)) + if (BI->isUnconditional()) + if (SimplifyStoreAtEndOfBlock(SI)) + return 0; // xform done! + + return 0; +} + +/// SimplifyStoreAtEndOfBlock - Turn things like: +/// if () { *P = v1; } else { *P = v2 } +/// into a phi node with a store in the successor. +/// +/// Simplify things like: +/// *P = v1; if () { *P = v2; } +/// into a phi node with a store in the successor. +/// +bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { + BasicBlock *StoreBB = SI.getParent(); + + // Check to see if the successor block has exactly two incoming edges. If + // so, see if the other predecessor contains a store to the same location. + // if so, insert a PHI node (if needed) and move the stores down. + BasicBlock *DestBB = StoreBB->getTerminator()->getSuccessor(0); + + // Determine whether Dest has exactly two predecessors and, if so, compute + // the other predecessor. + pred_iterator PI = pred_begin(DestBB); + BasicBlock *OtherBB = 0; + if (*PI != StoreBB) + OtherBB = *PI; + ++PI; + if (PI == pred_end(DestBB)) + return false; + + if (*PI != StoreBB) { + if (OtherBB) + return false; + OtherBB = *PI; + } + if (++PI != pred_end(DestBB)) + return false; + + // Bail out if all the relevant blocks aren't distinct (this can happen, + // for example, if SI is in an infinite loop) + if (StoreBB == DestBB || OtherBB == DestBB) + return false; + + // Verify that the other block ends in a branch and is not otherwise empty. + BasicBlock::iterator BBI = OtherBB->getTerminator(); + BranchInst *OtherBr = dyn_cast<BranchInst>(BBI); + if (!OtherBr || BBI == OtherBB->begin()) + return false; + + // If the other block ends in an unconditional branch, check for the 'if then + // else' case. there is an instruction before the branch. + StoreInst *OtherStore = 0; + if (OtherBr->isUnconditional()) { + --BBI; + // Skip over debugging info. 
+    while (isa<DbgInfoIntrinsic>(BBI) ||
+           (isa<BitCastInst>(BBI) && isa<PointerType>(BBI->getType()))) {
+      if (BBI == OtherBB->begin())
+        return false;
+      --BBI;
+    }
+    // If this isn't a store, or isn't a store to the same location, bail out.
+    OtherStore = dyn_cast<StoreInst>(BBI);
+    if (!OtherStore || OtherStore->getOperand(1) != SI.getOperand(1))
+      return false;
+  } else {
+    // Otherwise, the other block ended with a conditional branch. If one of
+    // the destinations is StoreBB, then we have the if/then case.
+    if (OtherBr->getSuccessor(0) != StoreBB &&
+        OtherBr->getSuccessor(1) != StoreBB)
+      return false;
+
+    // Okay, we know that OtherBr now goes to Dest and StoreBB, so this is an
+    // if/then triangle.  See if there is a store to the same ptr as SI that
+    // lives in OtherBB.
+    for (;; --BBI) {
+      // Check to see if we find the matching store.
+      if ((OtherStore = dyn_cast<StoreInst>(BBI))) {
+        if (OtherStore->getOperand(1) != SI.getOperand(1))
+          return false;
+        break;
+      }
+      // If we find something that may be using or overwriting the stored
+      // value, or if we run out of instructions, we can't do the xform.
+      if (BBI->mayReadFromMemory() || BBI->mayWriteToMemory() ||
+          BBI == OtherBB->begin())
+        return false;
+    }
+
+    // In order to eliminate the store in OtherBB, we have to
+    // make sure nothing reads or overwrites the stored value in
+    // StoreBB.
+    for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) {
+      // FIXME: This should really be AA driven.
+      if (I->mayReadFromMemory() || I->mayWriteToMemory())
+        return false;
+    }
+  }
+
+  // Insert a PHI node now if we need it.
+  Value *MergedVal = OtherStore->getOperand(0);
+  if (MergedVal != SI.getOperand(0)) {
+    PHINode *PN = PHINode::Create(MergedVal->getType(), "storemerge");
+    PN->reserveOperandSpace(2);
+    PN->addIncoming(SI.getOperand(0), SI.getParent());
+    PN->addIncoming(OtherStore->getOperand(0), OtherBB);
+    MergedVal = InsertNewInstBefore(PN, DestBB->front());
+  }
+
+  // Advance to a place where it is safe to insert the new store and
+  // insert it.
+  BBI = DestBB->getFirstNonPHI();
+  InsertNewInstBefore(new StoreInst(MergedVal, SI.getOperand(1),
+                                    OtherStore->isVolatile()), *BBI);
+
+  // Nuke the old stores.
+  EraseInstFromFunction(SI);
+  EraseInstFromFunction(*OtherStore);
+  ++NumCombined;
+  return true;
+}
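The net effect of the transformation above, expressed as ordinary control flow: instead of storing through the pointer in both predecessors, the value is merged first (the role the "storemerge" PHI plays) and stored once in the common successor. A small runnable model, with function and variable names invented for the example:

#include <cstdio>

// Model of SimplifyStoreAtEndOfBlock: rather than '*P = V1' in one arm and
// '*P = V2' in the other, select the value and store once in the join block.
static void storeMergedModel(bool Cond, int V1, int V2, int *P) {
  int Merged = Cond ? V1 : V2; // plays the role of the "storemerge" PHI
  *P = Merged;                 // single store in the common successor
}

int main() {
  int X = 0;
  storeMergedModel(true, 1, 2, &X);
  std::printf("%d\n", X); // prints 1
  return 0;
}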
+
+Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
+  // Change br (not X), label True, label False to: br X, label False, True
+  Value *X = 0;
+  BasicBlock *TrueDest;
+  BasicBlock *FalseDest;
+  if (match(&BI, m_Br(m_Not(m_Value(X)), TrueDest, FalseDest)) &&
+      !isa<Constant>(X)) {
+    // Swap Destinations and condition...
+    BI.setCondition(X);
+    BI.setSuccessor(0, FalseDest);
+    BI.setSuccessor(1, TrueDest);
+    return &BI;
+  }
+
+  // Canonicalize fcmp_one -> fcmp_oeq
+  FCmpInst::Predicate FPred; Value *Y;
+  if (match(&BI, m_Br(m_FCmp(FPred, m_Value(X), m_Value(Y)),
+                      TrueDest, FalseDest)))
+    if ((FPred == FCmpInst::FCMP_ONE || FPred == FCmpInst::FCMP_OLE ||
+         FPred == FCmpInst::FCMP_OGE) && BI.getCondition()->hasOneUse()) {
+      FCmpInst *I = cast<FCmpInst>(BI.getCondition());
+      FCmpInst::Predicate NewPred = FCmpInst::getInversePredicate(FPred);
+      Instruction *NewSCC = new FCmpInst(NewPred, X, Y, "", I);
+      NewSCC->takeName(I);
+      // Swap Destinations and condition...
+      BI.setCondition(NewSCC);
+      BI.setSuccessor(0, FalseDest);
+      BI.setSuccessor(1, TrueDest);
+      RemoveFromWorkList(I);
+      I->eraseFromParent();
+      AddToWorkList(NewSCC);
+      return &BI;
+    }
+
+  // Canonicalize icmp_ne -> icmp_eq
+  ICmpInst::Predicate IPred;
+  if (match(&BI, m_Br(m_ICmp(IPred, m_Value(X), m_Value(Y)),
+                      TrueDest, FalseDest)))
+    if ((IPred == ICmpInst::ICMP_NE || IPred == ICmpInst::ICMP_ULE ||
+         IPred == ICmpInst::ICMP_SLE || IPred == ICmpInst::ICMP_UGE ||
+         IPred == ICmpInst::ICMP_SGE) && BI.getCondition()->hasOneUse()) {
+      ICmpInst *I = cast<ICmpInst>(BI.getCondition());
+      ICmpInst::Predicate NewPred = ICmpInst::getInversePredicate(IPred);
+      Instruction *NewSCC = new ICmpInst(NewPred, X, Y, "", I);
+      NewSCC->takeName(I);
+      // Swap Destinations and condition...
+      BI.setCondition(NewSCC);
+      BI.setSuccessor(0, FalseDest);
+      BI.setSuccessor(1, TrueDest);
+      RemoveFromWorkList(I);
+      I->eraseFromParent();
+      AddToWorkList(NewSCC);
+      return &BI;
+    }
+
+  return 0;
+}
+
+Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
+  Value *Cond = SI.getCondition();
+  if (Instruction *I = dyn_cast<Instruction>(Cond)) {
+    if (I->getOpcode() == Instruction::Add)
+      if (ConstantInt *AddRHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
+        // change 'switch (X+4) case 1:' into 'switch (X) case -3'
+        for (unsigned i = 2, e = SI.getNumOperands(); i != e; i += 2)
+          SI.setOperand(i,
+                        ConstantExpr::getSub(cast<Constant>(SI.getOperand(i)),
+                                             AddRHS));
+        SI.setOperand(0, I->getOperand(0));
+        AddToWorkList(I);
+        return &SI;
+      }
+  }
+  return 0;
+}
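The case-label adjustment above is plain modular arithmetic: X + C == L exactly when X == L - C, so the add can be folded into every case label. A tiny self-contained check of that identity (names invented for the example):

#include <cassert>

// 'switch (X + 4) case 1:' behaves like 'switch (X) case -3:'.
static int adjustedCase(int CaseValue, int AddRHS) {
  return CaseValue - AddRHS;
}

int main() {
  int X = -3;
  assert((X + 4 == 1) == (X == adjustedCase(1, 4)));
  return 0;
}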
+
+Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) {
+  Value *Agg = EV.getAggregateOperand();
+
+  if (!EV.hasIndices())
+    return ReplaceInstUsesWith(EV, Agg);
+
+  if (Constant *C = dyn_cast<Constant>(Agg)) {
+    if (isa<UndefValue>(C))
+      return ReplaceInstUsesWith(EV, UndefValue::get(EV.getType()));
+
+    if (isa<ConstantAggregateZero>(C))
+      return ReplaceInstUsesWith(EV, Constant::getNullValue(EV.getType()));
+
+    if (isa<ConstantArray>(C) || isa<ConstantStruct>(C)) {
+      // Extract the element indexed by the first index out of the constant
+      Value *V = C->getOperand(*EV.idx_begin());
+      if (EV.getNumIndices() > 1)
+        // Extract the remaining indices out of the constant indexed by the
+        // first index
+        return ExtractValueInst::Create(V, EV.idx_begin() + 1, EV.idx_end());
+      else
+        return ReplaceInstUsesWith(EV, V);
+    }
+    return 0; // Can't handle other constants
+  }
+  if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) {
+    // We're extracting from an insertvalue instruction, compare the indices
+    const unsigned *exti, *exte, *insi, *inse;
+    for (exti = EV.idx_begin(), insi = IV->idx_begin(),
+         exte = EV.idx_end(), inse = IV->idx_end();
+         exti != exte && insi != inse;
+         ++exti, ++insi) {
+      if (*insi != *exti)
+        // The insert and extract both reference distinct elements.
+        // This means the extract is not influenced by the insert, and we can
+        // replace the aggregate operand of the extract with the aggregate
+        // operand of the insert. i.e., replace
+        // %I = insertvalue { i32, { i32 } } %A, { i32 } { i32 42 }, 1
+        // %E = extractvalue { i32, { i32 } } %I, 0
+        // with
+        // %E = extractvalue { i32, { i32 } } %A, 0
+        return ExtractValueInst::Create(IV->getAggregateOperand(),
+                                        EV.idx_begin(), EV.idx_end());
+    }
+    if (exti == exte && insi == inse)
+      // Both iterators are at the end: Index lists are identical. Replace
+      // %B = insertvalue { i32, { i32 } } %A, i32 42, 1, 0
+      // %C = extractvalue { i32, { i32 } } %B, 1, 0
+      // with "i32 42"
+      return ReplaceInstUsesWith(EV, IV->getInsertedValueOperand());
+    if (exti == exte) {
+      // The extract list is a prefix of the insert list. i.e. replace
+      // %I = insertvalue { i32, { i32 } } %A, i32 42, 1, 0
+      // %E = extractvalue { i32, { i32 } } %I, 1
+      // with
+      // %X = extractvalue { i32, { i32 } } %A, 1
+      // %E = insertvalue { i32 } %X, i32 42, 0
+      // by switching the order of the insert and extract (though the
+      // insertvalue should be left in, since it may have other uses).
+      Value *NewEV = InsertNewInstBefore(
+        ExtractValueInst::Create(IV->getAggregateOperand(),
+                                 EV.idx_begin(), EV.idx_end()),
+        EV);
+      return InsertValueInst::Create(NewEV, IV->getInsertedValueOperand(),
+                                     insi, inse);
+    }
+    if (insi == inse)
+      // The insert list is a prefix of the extract list.
+      // We can simply remove the common indices from the extract and make it
+      // operate on the inserted value instead of the insertvalue result.
+      // i.e., replace
+      // %I = insertvalue { i32, { i32 } } %A, { i32 } { i32 42 }, 1
+      // %E = extractvalue { i32, { i32 } } %I, 1, 0
+      // with
+      // %E = extractvalue { i32 } { i32 42 }, 0
+      return ExtractValueInst::Create(IV->getInsertedValueOperand(),
+                                      exti, exte);
+  }
+  // Can't simplify extracts from other values. Note that nested extracts are
+  // already simplified implicitly by the above (extract (extract (insert))
+  // will be translated into extract (insert (extract)) first and then just
+  // the value inserted, if appropriate).
+  return 0;
+}
+
+/// CheapToScalarize - Return true if the value is cheaper to scalarize than it
+/// is to leave as a vector operation.
+static bool CheapToScalarize(Value *V, bool isConstant) {
+  if (isa<ConstantAggregateZero>(V))
+    return true;
+  if (ConstantVector *C = dyn_cast<ConstantVector>(V)) {
+    if (isConstant) return true;
+    // If all elts are the same, we can extract.
+    Constant *Op0 = C->getOperand(0);
+    for (unsigned i = 1; i < C->getNumOperands(); ++i)
+      if (C->getOperand(i) != Op0)
+        return false;
+    return true;
+  }
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return false;
+
+  // Insert element gets simplified to the inserted element or is deleted if
+  // this is a constant idx extract element and it's a constant idx insertelt.
+  if (I->getOpcode() == Instruction::InsertElement && isConstant &&
+      isa<ConstantInt>(I->getOperand(2)))
+    return true;
+  if (I->getOpcode() == Instruction::Load && I->hasOneUse())
+    return true;
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I))
+    if (BO->hasOneUse() &&
+        (CheapToScalarize(BO->getOperand(0), isConstant) ||
+         CheapToScalarize(BO->getOperand(1), isConstant)))
+      return true;
+  if (CmpInst *CI = dyn_cast<CmpInst>(I))
+    if (CI->hasOneUse() &&
+        (CheapToScalarize(CI->getOperand(0), isConstant) ||
+         CheapToScalarize(CI->getOperand(1), isConstant)))
+      return true;
+
+  return false;
+}
+
+/// Read and decode a shufflevector mask.
+///
+/// It turns undef elements into values that are larger than the number of
+/// elements in the input.
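+/// For example, over 4-element inputs the mask <i32 0, i32 undef, i32 3,
+/// i32 5> decodes to {0, 8, 3, 5}: the undef element becomes 2*NElts = 8.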
+static std::vector<unsigned> getShuffleMask(const ShuffleVectorInst *SVI) { + unsigned NElts = SVI->getType()->getNumElements(); + if (isa<ConstantAggregateZero>(SVI->getOperand(2))) + return std::vector<unsigned>(NElts, 0); + if (isa<UndefValue>(SVI->getOperand(2))) + return std::vector<unsigned>(NElts, 2*NElts); + + std::vector<unsigned> Result; + const ConstantVector *CP = cast<ConstantVector>(SVI->getOperand(2)); + for (User::const_op_iterator i = CP->op_begin(), e = CP->op_end(); i!=e; ++i) + if (isa<UndefValue>(*i)) + Result.push_back(NElts*2); // undef -> 8 + else + Result.push_back(cast<ConstantInt>(*i)->getZExtValue()); + return Result; +} + +/// FindScalarElement - Given a vector and an element number, see if the scalar +/// value is already around as a register, for example if it were inserted then +/// extracted from the vector. +static Value *FindScalarElement(Value *V, unsigned EltNo) { + assert(isa<VectorType>(V->getType()) && "Not looking at a vector?"); + const VectorType *PTy = cast<VectorType>(V->getType()); + unsigned Width = PTy->getNumElements(); + if (EltNo >= Width) // Out of range access. + return UndefValue::get(PTy->getElementType()); + + if (isa<UndefValue>(V)) + return UndefValue::get(PTy->getElementType()); + else if (isa<ConstantAggregateZero>(V)) + return Constant::getNullValue(PTy->getElementType()); + else if (ConstantVector *CP = dyn_cast<ConstantVector>(V)) + return CP->getOperand(EltNo); + else if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) { + // If this is an insert to a variable element, we don't know what it is. + if (!isa<ConstantInt>(III->getOperand(2))) + return 0; + unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue(); + + // If this is an insert to the element we are looking for, return the + // inserted value. + if (EltNo == IIElt) + return III->getOperand(1); + + // Otherwise, the insertelement doesn't modify the value, recurse on its + // vector input. + return FindScalarElement(III->getOperand(0), EltNo); + } else if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) { + unsigned LHSWidth = + cast<VectorType>(SVI->getOperand(0)->getType())->getNumElements(); + unsigned InEl = getShuffleMask(SVI)[EltNo]; + if (InEl < LHSWidth) + return FindScalarElement(SVI->getOperand(0), InEl); + else if (InEl < LHSWidth*2) + return FindScalarElement(SVI->getOperand(1), InEl - LHSWidth); + else + return UndefValue::get(PTy->getElementType()); + } + + // Otherwise, we don't know. + return 0; +} + +Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { + // If vector val is undef, replace extract with scalar undef. + if (isa<UndefValue>(EI.getOperand(0))) + return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + + // If vector val is constant 0, replace extract with scalar 0. + if (isa<ConstantAggregateZero>(EI.getOperand(0))) + return ReplaceInstUsesWith(EI, Constant::getNullValue(EI.getType())); + + if (ConstantVector *C = dyn_cast<ConstantVector>(EI.getOperand(0))) { + // If vector val is constant with all elements the same, replace EI with + // that element. When the elements are not identical, we cannot replace yet + // (we do that below, but only when the index is constant). 
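+    // For example:
+    //   extractelement <4 x i32> <i32 7, i32 7, i32 7, i32 7>, i32 %i
+    // becomes "i32 7" no matter what %i is.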
+ Constant *op0 = C->getOperand(0); + for (unsigned i = 1; i < C->getNumOperands(); ++i) + if (C->getOperand(i) != op0) { + op0 = 0; + break; + } + if (op0) + return ReplaceInstUsesWith(EI, op0); + } + + // If extracting a specified index from the vector, see if we can recursively + // find a previously computed scalar that was inserted into the vector. + if (ConstantInt *IdxC = dyn_cast<ConstantInt>(EI.getOperand(1))) { + unsigned IndexVal = IdxC->getZExtValue(); + unsigned VectorWidth = + cast<VectorType>(EI.getOperand(0)->getType())->getNumElements(); + + // If this is extracting an invalid index, turn this into undef, to avoid + // crashing the code below. + if (IndexVal >= VectorWidth) + return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + + // This instruction only demands the single element from the input vector. + // If the input vector has a single use, simplify it based on this use + // property. + if (EI.getOperand(0)->hasOneUse() && VectorWidth != 1) { + APInt UndefElts(VectorWidth, 0); + APInt DemandedMask(VectorWidth, 1 << IndexVal); + if (Value *V = SimplifyDemandedVectorElts(EI.getOperand(0), + DemandedMask, UndefElts)) { + EI.setOperand(0, V); + return &EI; + } + } + + if (Value *Elt = FindScalarElement(EI.getOperand(0), IndexVal)) + return ReplaceInstUsesWith(EI, Elt); + + // If the this extractelement is directly using a bitcast from a vector of + // the same number of elements, see if we can find the source element from + // it. In this case, we will end up needing to bitcast the scalars. + if (BitCastInst *BCI = dyn_cast<BitCastInst>(EI.getOperand(0))) { + if (const VectorType *VT = + dyn_cast<VectorType>(BCI->getOperand(0)->getType())) + if (VT->getNumElements() == VectorWidth) + if (Value *Elt = FindScalarElement(BCI->getOperand(0), IndexVal)) + return new BitCastInst(Elt, EI.getType()); + } + } + + if (Instruction *I = dyn_cast<Instruction>(EI.getOperand(0))) { + if (I->hasOneUse()) { + // Push extractelement into predecessor operation if legal and + // profitable to do so + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { + bool isConstantElt = isa<ConstantInt>(EI.getOperand(1)); + if (CheapToScalarize(BO, isConstantElt)) { + ExtractElementInst *newEI0 = + new ExtractElementInst(BO->getOperand(0), EI.getOperand(1), + EI.getName()+".lhs"); + ExtractElementInst *newEI1 = + new ExtractElementInst(BO->getOperand(1), EI.getOperand(1), + EI.getName()+".rhs"); + InsertNewInstBefore(newEI0, EI); + InsertNewInstBefore(newEI1, EI); + return BinaryOperator::Create(BO->getOpcode(), newEI0, newEI1); + } + } else if (isa<LoadInst>(I)) { + unsigned AS = + cast<PointerType>(I->getOperand(0)->getType())->getAddressSpace(); + Value *Ptr = InsertBitCastBefore(I->getOperand(0), + PointerType::get(EI.getType(), AS),EI); + GetElementPtrInst *GEP = + GetElementPtrInst::Create(Ptr, EI.getOperand(1), I->getName()+".gep"); + InsertNewInstBefore(GEP, EI); + return new LoadInst(GEP); + } + } + if (InsertElementInst *IE = dyn_cast<InsertElementInst>(I)) { + // Extracting the inserted element? + if (IE->getOperand(2) == EI.getOperand(1)) + return ReplaceInstUsesWith(EI, IE->getOperand(1)); + // If the inserted and extracted elements are constants, they must not + // be the same value, extract from the pre-inserted value instead. 
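+      // For example:
+      //   %V = insertelement <4 x i32> %X, i32 %a, i32 1
+      //   %E = extractelement <4 x i32> %V, i32 0
+      // cannot read the inserted element, so it can extract from %X directly.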
+      if (isa<Constant>(IE->getOperand(2)) &&
+          isa<Constant>(EI.getOperand(1))) {
+        AddUsesToWorkList(EI);
+        EI.setOperand(0, IE->getOperand(0));
+        return &EI;
+      }
+    } else if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I)) {
+      // If this is extracting an element from a shufflevector, figure out
+      // where it came from and extract from the appropriate input element
+      // instead.
+      if (ConstantInt *Elt = dyn_cast<ConstantInt>(EI.getOperand(1))) {
+        unsigned SrcIdx = getShuffleMask(SVI)[Elt->getZExtValue()];
+        Value *Src;
+        unsigned LHSWidth =
+          cast<VectorType>(SVI->getOperand(0)->getType())->getNumElements();
+
+        if (SrcIdx < LHSWidth)
+          Src = SVI->getOperand(0);
+        else if (SrcIdx < LHSWidth*2) {
+          SrcIdx -= LHSWidth;
+          Src = SVI->getOperand(1);
+        } else {
+          return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType()));
+        }
+        return new ExtractElementInst(Src, SrcIdx);
+      }
+    }
+  }
+  return 0;
+}
+
+/// CollectSingleShuffleElements - If V is a shuffle of values that ONLY returns
+/// elements from either LHS or RHS, return the shuffle mask and true.
+/// Otherwise, return false.
+static bool CollectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
+                                         std::vector<Constant*> &Mask) {
+  assert(V->getType() == LHS->getType() && V->getType() == RHS->getType() &&
+         "Invalid CollectSingleShuffleElements");
+  unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
+
+  if (isa<UndefValue>(V)) {
+    Mask.assign(NumElts, UndefValue::get(Type::Int32Ty));
+    return true;
+  } else if (V == LHS) {
+    for (unsigned i = 0; i != NumElts; ++i)
+      Mask.push_back(ConstantInt::get(Type::Int32Ty, i));
+    return true;
+  } else if (V == RHS) {
+    for (unsigned i = 0; i != NumElts; ++i)
+      Mask.push_back(ConstantInt::get(Type::Int32Ty, i+NumElts));
+    return true;
+  } else if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(V)) {
+    // If this is an insert of an extract from some other vector, include it.
+    Value *VecOp    = IEI->getOperand(0);
+    Value *ScalarOp = IEI->getOperand(1);
+    Value *IdxOp    = IEI->getOperand(2);
+
+    if (!isa<ConstantInt>(IdxOp))
+      return false;
+    unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue();
+
+    if (isa<UndefValue>(ScalarOp)) {  // inserting undef into vector.
+      // Okay, we can handle this if the vector we are inserting into is
+      // transitively ok.
+      if (CollectSingleShuffleElements(VecOp, LHS, RHS, Mask)) {
+        // If so, update the mask to reflect the inserted undef.
+        Mask[InsertedIdx] = UndefValue::get(Type::Int32Ty);
+        return true;
+      }
+    } else if (ExtractElementInst *EI = dyn_cast<ExtractElementInst>(ScalarOp)){
+      if (isa<ConstantInt>(EI->getOperand(1)) &&
+          EI->getOperand(0)->getType() == V->getType()) {
+        unsigned ExtractedIdx =
+          cast<ConstantInt>(EI->getOperand(1))->getZExtValue();
+
+        // This must be extracting from either LHS or RHS.
+        if (EI->getOperand(0) == LHS || EI->getOperand(0) == RHS) {
+          // Okay, we can handle this if the vector we are inserting into is
+          // transitively ok.
+          if (CollectSingleShuffleElements(VecOp, LHS, RHS, Mask)) {
+            // If so, update the mask to reflect the inserted value.
+            if (EI->getOperand(0) == LHS) {
+              Mask[InsertedIdx % NumElts] =
+                ConstantInt::get(Type::Int32Ty, ExtractedIdx);
+            } else {
+              assert(EI->getOperand(0) == RHS);
+              Mask[InsertedIdx % NumElts] =
+                ConstantInt::get(Type::Int32Ty, ExtractedIdx+NumElts);
+            }
+            return true;
+          }
+        }
+      }
+    }
+  }
+  // TODO: Handle shufflevector here!
+
+  return false;
+}
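The mask entries pushed above follow the shufflevector numbering: indices 0..NumElts-1 select from the LHS vector and NumElts..2*NumElts-1 select from the RHS. A standalone model of how such a mask is applied, fixed at 4-wide vectors for illustration:

#include <array>
#include <cassert>

// Apply a shuffle mask over two 4-element vectors; index i selects
// LHS[i] when i < 4 and RHS[i - 4] otherwise.
static std::array<int, 4> shuffle(const std::array<int, 4> &LHS,
                                  const std::array<int, 4> &RHS,
                                  const std::array<unsigned, 4> &Mask) {
  std::array<int, 4> Out{};
  for (unsigned i = 0; i != 4; ++i)
    Out[i] = Mask[i] < 4 ? LHS[Mask[i]] : RHS[Mask[i] - 4];
  return Out;
}

int main() {
  std::array<int, 4> L{0, 1, 2, 3}, R{10, 11, 12, 13};
  // {0,1,2,3} would be the LHS identity mask; {4,5,6,7} selects all of RHS.
  assert((shuffle(L, R, {0, 5, 2, 7}) == std::array<int, 4>{0, 11, 2, 13}));
  return 0;
}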
+
+/// CollectShuffleElements - We are building a shuffle of V, using RHS as the
+/// RHS of the shuffle instruction, if it is not null.  Return a shuffle mask
+/// that computes V and the LHS value of the shuffle.
+static Value *CollectShuffleElements(Value *V, std::vector<Constant*> &Mask,
+                                     Value *&RHS) {
+  assert(isa<VectorType>(V->getType()) &&
+         (RHS == 0 || V->getType() == RHS->getType()) &&
+         "Invalid shuffle!");
+  unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
+
+  if (isa<UndefValue>(V)) {
+    Mask.assign(NumElts, UndefValue::get(Type::Int32Ty));
+    return V;
+  } else if (isa<ConstantAggregateZero>(V)) {
+    Mask.assign(NumElts, ConstantInt::get(Type::Int32Ty, 0));
+    return V;
+  } else if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(V)) {
+    // If this is an insert of an extract from some other vector, include it.
+    Value *VecOp    = IEI->getOperand(0);
+    Value *ScalarOp = IEI->getOperand(1);
+    Value *IdxOp    = IEI->getOperand(2);
+
+    if (ExtractElementInst *EI = dyn_cast<ExtractElementInst>(ScalarOp)) {
+      if (isa<ConstantInt>(EI->getOperand(1)) && isa<ConstantInt>(IdxOp) &&
+          EI->getOperand(0)->getType() == V->getType()) {
+        unsigned ExtractedIdx =
+          cast<ConstantInt>(EI->getOperand(1))->getZExtValue();
+        unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue();
+
+        // Either the extracted from or inserted into vector must be RHSVec,
+        // otherwise we'd end up with a shuffle of three inputs.
+        if (EI->getOperand(0) == RHS || RHS == 0) {
+          RHS = EI->getOperand(0);
+          Value *V = CollectShuffleElements(VecOp, Mask, RHS);
+          Mask[InsertedIdx % NumElts] =
+            ConstantInt::get(Type::Int32Ty, NumElts+ExtractedIdx);
+          return V;
+        }
+
+        if (VecOp == RHS) {
+          Value *V = CollectShuffleElements(EI->getOperand(0), Mask, RHS);
+          // Everything but the extracted element is replaced with the RHS.
+          for (unsigned i = 0; i != NumElts; ++i) {
+            if (i != InsertedIdx)
+              Mask[i] = ConstantInt::get(Type::Int32Ty, NumElts+i);
+          }
+          return V;
+        }
+
+        // If this insertelement is a chain that comes from exactly these two
+        // vectors, return the vector and the effective shuffle.
+        if (CollectSingleShuffleElements(IEI, EI->getOperand(0), RHS, Mask))
+          return EI->getOperand(0);
+      }
+    }
+  }
+  // TODO: Handle shufflevector here!
+
+  // Otherwise, can't do anything fancy.  Return an identity vector.
+  for (unsigned i = 0; i != NumElts; ++i)
+    Mask.push_back(ConstantInt::get(Type::Int32Ty, i));
+  return V;
+}
+
+Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) {
+  Value *VecOp    = IE.getOperand(0);
+  Value *ScalarOp = IE.getOperand(1);
+  Value *IdxOp    = IE.getOperand(2);
+
+  // If we are inserting an undef value, or inserting into an undefined place,
+  // remove this instruction.
+  if (isa<UndefValue>(ScalarOp) || isa<UndefValue>(IdxOp))
+    return ReplaceInstUsesWith(IE, VecOp);
+
+  // If the inserted element was extracted from some other vector, and if the
+  // indexes are constant, try to turn this into a shufflevector operation.
+  if (ExtractElementInst *EI = dyn_cast<ExtractElementInst>(ScalarOp)) {
+    if (isa<ConstantInt>(EI->getOperand(1)) && isa<ConstantInt>(IdxOp) &&
+        EI->getOperand(0)->getType() == IE.getType()) {
+      unsigned NumVectorElts = IE.getType()->getNumElements();
+      unsigned ExtractedIdx =
+        cast<ConstantInt>(EI->getOperand(1))->getZExtValue();
+      unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue();
+
+      if (ExtractedIdx >= NumVectorElts) // Out of range extract.
+        return ReplaceInstUsesWith(IE, VecOp);
+
+      if (InsertedIdx >= NumVectorElts)  // Out of range insert.
+ return ReplaceInstUsesWith(IE, UndefValue::get(IE.getType())); + + // If we are extracting a value from a vector, then inserting it right + // back into the same place, just use the input vector. + if (EI->getOperand(0) == VecOp && ExtractedIdx == InsertedIdx) + return ReplaceInstUsesWith(IE, VecOp); + + // We could theoretically do this for ANY input. However, doing so could + // turn chains of insertelement instructions into a chain of shufflevector + // instructions, and right now we do not merge shufflevectors. As such, + // only do this in a situation where it is clear that there is benefit. + if (isa<UndefValue>(VecOp) || isa<ConstantAggregateZero>(VecOp)) { + // Turn this into shuffle(EIOp0, VecOp, Mask). The result has all of + // the values of VecOp, except then one read from EIOp0. + // Build a new shuffle mask. + std::vector<Constant*> Mask; + if (isa<UndefValue>(VecOp)) + Mask.assign(NumVectorElts, UndefValue::get(Type::Int32Ty)); + else { + assert(isa<ConstantAggregateZero>(VecOp) && "Unknown thing"); + Mask.assign(NumVectorElts, ConstantInt::get(Type::Int32Ty, + NumVectorElts)); + } + Mask[InsertedIdx] = ConstantInt::get(Type::Int32Ty, ExtractedIdx); + return new ShuffleVectorInst(EI->getOperand(0), VecOp, + ConstantVector::get(Mask)); + } + + // If this insertelement isn't used by some other insertelement, turn it + // (and any insertelements it points to), into one big shuffle. + if (!IE.hasOneUse() || !isa<InsertElementInst>(IE.use_back())) { + std::vector<Constant*> Mask; + Value *RHS = 0; + Value *LHS = CollectShuffleElements(&IE, Mask, RHS); + if (RHS == 0) RHS = UndefValue::get(LHS->getType()); + // We now have a shuffle of LHS, RHS, Mask. + return new ShuffleVectorInst(LHS, RHS, ConstantVector::get(Mask)); + } + } + } + + return 0; +} + + +Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { + Value *LHS = SVI.getOperand(0); + Value *RHS = SVI.getOperand(1); + std::vector<unsigned> Mask = getShuffleMask(&SVI); + + bool MadeChange = false; + + // Undefined shuffle mask -> undefined value. + if (isa<UndefValue>(SVI.getOperand(2))) + return ReplaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); + + unsigned VWidth = cast<VectorType>(SVI.getType())->getNumElements(); + + if (VWidth != cast<VectorType>(LHS->getType())->getNumElements()) + return 0; + + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { + LHS = SVI.getOperand(0); + RHS = SVI.getOperand(1); + MadeChange = true; + } + + // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') + // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). + if (LHS == RHS || isa<UndefValue>(LHS)) { + if (isa<UndefValue>(LHS) && LHS == RHS) { + // shuffle(undef,undef,mask) -> undef. + return ReplaceInstUsesWith(SVI, LHS); + } + + // Remap any references to RHS to use LHS. + std::vector<Constant*> Elts; + for (unsigned i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] >= 2*e) + Elts.push_back(UndefValue::get(Type::Int32Ty)); + else { + if ((Mask[i] >= e && isa<UndefValue>(RHS)) || + (Mask[i] < e && isa<UndefValue>(LHS))) { + Mask[i] = 2*e; // Turn into undef. + Elts.push_back(UndefValue::get(Type::Int32Ty)); + } else { + Mask[i] = Mask[i] % e; // Force to LHS. 
+          Elts.push_back(ConstantInt::get(Type::Int32Ty, Mask[i]));
+        }
+      }
+    }
+    SVI.setOperand(0, SVI.getOperand(1));
+    SVI.setOperand(1, UndefValue::get(RHS->getType()));
+    SVI.setOperand(2, ConstantVector::get(Elts));
+    LHS = SVI.getOperand(0);
+    RHS = SVI.getOperand(1);
+    MadeChange = true;
+  }
+
+  // Analyze the shuffle: is the LHS or the RHS an identity shuffle?
+  bool isLHSID = true, isRHSID = true;
+
+  for (unsigned i = 0, e = Mask.size(); i != e; ++i) {
+    if (Mask[i] >= e*2) continue;  // Ignore undef values.
+    // Is this an identity shuffle of the LHS value?
+    isLHSID &= (Mask[i] == i);
+
+    // Is this an identity shuffle of the RHS value?
+    isRHSID &= (Mask[i]-e == i);
+  }
+
+  // Eliminate identity shuffles.
+  if (isLHSID) return ReplaceInstUsesWith(SVI, LHS);
+  if (isRHSID) return ReplaceInstUsesWith(SVI, RHS);
+
+  // If the LHS is a shufflevector itself, see if we can combine it with this
+  // one without producing an unusual shuffle.  Here we are really conservative:
+  // we are absolutely afraid of producing a shuffle mask not in the input
+  // program, because the code gen may not be smart enough to turn a merged
+  // shuffle into two specific shuffles: it may produce worse code.  As such,
+  // we only merge two shuffles if the result is one of the two input shuffle
+  // masks.  In this case, merging the shuffles just removes one instruction,
+  // which we know is safe.  This is good for things like turning:
+  // (splat(splat)) -> splat.
+  if (ShuffleVectorInst *LHSSVI = dyn_cast<ShuffleVectorInst>(LHS)) {
+    if (isa<UndefValue>(RHS)) {
+      std::vector<unsigned> LHSMask = getShuffleMask(LHSSVI);
+
+      std::vector<unsigned> NewMask;
+      for (unsigned i = 0, e = Mask.size(); i != e; ++i)
+        if (Mask[i] >= 2*e)
+          NewMask.push_back(2*e);
+        else
+          NewMask.push_back(LHSMask[Mask[i]]);
+
+      // If the result mask is equal to the src shuffle or this shuffle mask,
+      // do the replacement.
+      if (NewMask == LHSMask || NewMask == Mask) {
+        unsigned LHSInNElts =
+          cast<VectorType>(LHSSVI->getOperand(0)->getType())->getNumElements();
+        std::vector<Constant*> Elts;
+        for (unsigned i = 0, e = NewMask.size(); i != e; ++i) {
+          if (NewMask[i] >= LHSInNElts*2) {
+            Elts.push_back(UndefValue::get(Type::Int32Ty));
+          } else {
+            Elts.push_back(ConstantInt::get(Type::Int32Ty, NewMask[i]));
+          }
+        }
+        return new ShuffleVectorInst(LHSSVI->getOperand(0),
+                                     LHSSVI->getOperand(1),
+                                     ConstantVector::get(Elts));
+      }
+    }
+  }
+
+  return MadeChange ? &SVI : 0;
+}
+
+/// TryToSinkInstruction - Try to move the specified instruction from its
+/// current block into the beginning of DestBlock, which can only happen if it's
+/// safe to move the instruction past all of the instructions between it and the
+/// end of its block.
+static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) {
+  assert(I->hasOneUse() && "Invariants didn't hold!");
+
+  // Cannot move control-flow-involving instructions, volatile loads, vaarg,
+  // etc.
+  if (isa<PHINode>(I) || I->mayHaveSideEffects() || isa<TerminatorInst>(I))
+    return false;
+
+  // Do not sink alloca instructions out of the entry block.
+  if (isa<AllocaInst>(I) && I->getParent() ==
+        &DestBlock->getParent()->getEntryBlock())
+    return false;
+
+  // We can only sink load instructions if there is nothing between the load and
+  // the end of block that could change the value.
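+  // (Any intervening instruction that mayWriteToMemory() could clobber the
+  //  loaded location, so it blocks the sink.)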
+ if (I->mayReadFromMemory()) { + for (BasicBlock::iterator Scan = I, E = I->getParent()->end(); + Scan != E; ++Scan) + if (Scan->mayWriteToMemory()) + return false; + } + + BasicBlock::iterator InsertPos = DestBlock->getFirstNonPHI(); + + CopyPrecedingStopPoint(I, InsertPos); + I->moveBefore(InsertPos); + ++NumSunkInst; + return true; +} + + +/// AddReachableCodeToWorklist - Walk the function in depth-first order, adding +/// all reachable code to the worklist. +/// +/// This has a couple of tricks to make the code faster and more powerful. In +/// particular, we constant fold and DCE instructions as we go, to avoid adding +/// them to the worklist (this significantly speeds up instcombine on code where +/// many instructions are dead or constant). Additionally, if we find a branch +/// whose condition is a known constant, we only visit the reachable successors. +/// +static void AddReachableCodeToWorklist(BasicBlock *BB, + SmallPtrSet<BasicBlock*, 64> &Visited, + InstCombiner &IC, + const TargetData *TD) { + SmallVector<BasicBlock*, 256> Worklist; + Worklist.push_back(BB); + + while (!Worklist.empty()) { + BB = Worklist.back(); + Worklist.pop_back(); + + // We have now visited this block! If we've already been here, ignore it. + if (!Visited.insert(BB)) continue; + + DbgInfoIntrinsic *DBI_Prev = NULL; + for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { + Instruction *Inst = BBI++; + + // DCE instruction if trivially dead. + if (isInstructionTriviallyDead(Inst)) { + ++NumDeadInst; + DOUT << "IC: DCE: " << *Inst; + Inst->eraseFromParent(); + continue; + } + + // ConstantProp instruction if trivially constant. + if (Constant *C = ConstantFoldInstruction(Inst, TD)) { + DOUT << "IC: ConstFold to: " << *C << " from: " << *Inst; + Inst->replaceAllUsesWith(C); + ++NumConstProp; + Inst->eraseFromParent(); + continue; + } + + // If there are two consecutive llvm.dbg.stoppoint calls then + // it is likely that the optimizer deleted code in between these + // two intrinsics. + DbgInfoIntrinsic *DBI_Next = dyn_cast<DbgInfoIntrinsic>(Inst); + if (DBI_Next) { + if (DBI_Prev + && DBI_Prev->getIntrinsicID() == llvm::Intrinsic::dbg_stoppoint + && DBI_Next->getIntrinsicID() == llvm::Intrinsic::dbg_stoppoint) { + IC.RemoveFromWorkList(DBI_Prev); + DBI_Prev->eraseFromParent(); + } + DBI_Prev = DBI_Next; + } else { + DBI_Prev = 0; + } + + IC.AddToWorkList(Inst); + } + + // Recursively visit successors. If this is a branch or switch on a + // constant, only visit the reachable successor. + TerminatorInst *TI = BB->getTerminator(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (BI->isConditional() && isa<ConstantInt>(BI->getCondition())) { + bool CondVal = cast<ConstantInt>(BI->getCondition())->getZExtValue(); + BasicBlock *ReachableBB = BI->getSuccessor(!CondVal); + Worklist.push_back(ReachableBB); + continue; + } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { + // See if this is an explicit destination. + for (unsigned i = 1, e = SI->getNumSuccessors(); i != e; ++i) + if (SI->getCaseValue(i) == Cond) { + BasicBlock *ReachableBB = SI->getSuccessor(i); + Worklist.push_back(ReachableBB); + continue; + } + + // Otherwise it is the default destination. 
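+          // (successor #0 of a SwitchInst is its default destination)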
+ Worklist.push_back(SI->getSuccessor(0)); + continue; + } + } + + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + Worklist.push_back(TI->getSuccessor(i)); + } +} + +bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) { + bool Changed = false; + TD = &getAnalysis<TargetData>(); + + DEBUG(DOUT << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " + << F.getNameStr() << "\n"); + + { + // Do a depth-first traversal of the function, populate the worklist with + // the reachable instructions. Ignore blocks that are not reachable. Keep + // track of which blocks we visit. + SmallPtrSet<BasicBlock*, 64> Visited; + AddReachableCodeToWorklist(F.begin(), Visited, *this, TD); + + // Do a quick scan over the function. If we find any blocks that are + // unreachable, remove any instructions inside of them. This prevents + // the instcombine code from having to deal with some bad special cases. + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (!Visited.count(BB)) { + Instruction *Term = BB->getTerminator(); + while (Term != BB->begin()) { // Remove instrs bottom-up + BasicBlock::iterator I = Term; --I; + + DOUT << "IC: DCE: " << *I; + // A debug intrinsic shouldn't force another iteration if we weren't + // going to do one without it. + if (!isa<DbgInfoIntrinsic>(I)) { + ++NumDeadInst; + Changed = true; + } + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + } + } + } + + while (!Worklist.empty()) { + Instruction *I = RemoveOneFromWorkList(); + if (I == 0) continue; // skip null values. + + // Check to see if we can DCE the instruction. + if (isInstructionTriviallyDead(I)) { + // Add operands to the worklist. + if (I->getNumOperands() < 4) + AddUsesToWorkList(*I); + ++NumDeadInst; + + DOUT << "IC: DCE: " << *I; + + I->eraseFromParent(); + RemoveFromWorkList(I); + Changed = true; + continue; + } + + // Instruction isn't dead, see if we can constant propagate it. + if (Constant *C = ConstantFoldInstruction(I, TD)) { + DOUT << "IC: ConstFold to: " << *C << " from: " << *I; + + // Add operands to the worklist. + AddUsesToWorkList(*I); + ReplaceInstUsesWith(*I, C); + + ++NumConstProp; + I->eraseFromParent(); + RemoveFromWorkList(I); + Changed = true; + continue; + } + + if (TD && + (I->getType()->getTypeID() == Type::VoidTyID || + I->isTrapping())) { + // See if we can constant fold its operands. + for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i) + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(i)) + if (Constant *NewC = ConstantFoldConstantExpression(CE, TD)) + if (NewC != CE) { + i->set(NewC); + Changed = true; + } + } + + // See if we can trivially sink this instruction to a successor basic block. + if (I->hasOneUse()) { + BasicBlock *BB = I->getParent(); + BasicBlock *UserParent = cast<Instruction>(I->use_back())->getParent(); + if (UserParent != BB) { + bool UserIsSuccessor = false; + // See if the user is one of our successors. + for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) + if (*SI == UserParent) { + UserIsSuccessor = true; + break; + } + + // If the user is one of our immediate successors, and if that successor + // only has us as a predecessors (we'd have to split the critical edge + // otherwise), we can keep going. + if (UserIsSuccessor && !isa<PHINode>(I->use_back()) && + next(pred_begin(UserParent)) == pred_end(UserParent)) + // Okay, the CFG is simple enough, try to sink this instruction. 
+ Changed |= TryToSinkInstruction(I, UserParent); + } + } + + // Now that we have an instruction, try combining it to simplify it... +#ifndef NDEBUG + std::string OrigI; +#endif + DEBUG(std::ostringstream SS; I->print(SS); OrigI = SS.str();); + if (Instruction *Result = visit(*I)) { + ++NumCombined; + // Should we replace the old instruction with a new one? + if (Result != I) { + DOUT << "IC: Old = " << *I + << " New = " << *Result; + + // Everything uses the new instruction now. + I->replaceAllUsesWith(Result); + + // Push the new instruction and any users onto the worklist. + AddToWorkList(Result); + AddUsersToWorkList(*Result); + + // Move the name to the new instruction first. + Result->takeName(I); + + // Insert the new instruction into the basic block... + BasicBlock *InstParent = I->getParent(); + BasicBlock::iterator InsertPos = I; + + if (!isa<PHINode>(Result)) // If combining a PHI, don't insert + while (isa<PHINode>(InsertPos)) // middle of a block of PHIs. + ++InsertPos; + + InstParent->getInstList().insert(InsertPos, Result); + + // Make sure that we reprocess all operands now that we reduced their + // use counts. + AddUsesToWorkList(*I); + + // Instructions can end up on the worklist more than once. Make sure + // we do not process an instruction that has been deleted. + RemoveFromWorkList(I); + + // Erase the old instruction. + InstParent->getInstList().erase(I); + } else { +#ifndef NDEBUG + DOUT << "IC: Mod = " << OrigI + << " New = " << *I; +#endif + + // If the instruction was modified, it's possible that it is now dead. + // if so, remove it. + if (isInstructionTriviallyDead(I)) { + // Make sure we process all operands now that we are reducing their + // use counts. + AddUsesToWorkList(*I); + + // Instructions may end up in the worklist more than once. Erase all + // occurrences of this instruction. + RemoveFromWorkList(I); + I->eraseFromParent(); + } else { + AddToWorkList(I); + AddUsersToWorkList(*I); + } + } + Changed = true; + } + } + + assert(WorklistMap.empty() && "Worklist empty, but map not?"); + + // Do an explicit clear, this shrinks the map if needed. + WorklistMap.clear(); + return Changed; +} + + +bool InstCombiner::runOnFunction(Function &F) { + MustPreserveLCSSA = mustPreserveAnalysisID(LCSSAID); + + bool EverMadeChange = false; + + // Iterate while there is work to do. + unsigned Iteration = 0; + while (DoOneIteration(F, Iteration++)) + EverMadeChange = true; + return EverMadeChange; +} + +FunctionPass *llvm::createInstructionCombiningPass() { + return new InstCombiner(); +} diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp new file mode 100644 index 0000000..c0ca2df --- /dev/null +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -0,0 +1,954 @@ +//===- JumpThreading.cpp - Thread control through conditional blocks ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Jump Threading pass. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "jump-threading" +#include "llvm/Transforms/Scalar.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Target/TargetData.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ValueHandle.h" +using namespace llvm; + +STATISTIC(NumThreads, "Number of jumps threaded"); +STATISTIC(NumFolds, "Number of terminators folded"); + +static cl::opt<unsigned> +Threshold("jump-threading-threshold", + cl::desc("Max block size to duplicate for jump threading"), + cl::init(6), cl::Hidden); + +namespace { + /// This pass performs 'jump threading', which looks at blocks that have + /// multiple predecessors and multiple successors. If one or more of the + /// predecessors of the block can be proven to always jump to one of the + /// successors, we forward the edge from the predecessor to the successor by + /// duplicating the contents of this block. + /// + /// An example of when this can occur is code like this: + /// + /// if () { ... + /// X = 4; + /// } + /// if (X < 3) { + /// + /// In this case, the unconditional branch at the end of the first if can be + /// revectored to the false side of the second if. + /// + class VISIBILITY_HIDDEN JumpThreading : public FunctionPass { + TargetData *TD; +#ifdef NDEBUG + SmallPtrSet<BasicBlock*, 16> LoopHeaders; +#else + SmallSet<AssertingVH<BasicBlock>, 16> LoopHeaders; +#endif + public: + static char ID; // Pass identification + JumpThreading() : FunctionPass(&ID) {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<TargetData>(); + } + + bool runOnFunction(Function &F); + void FindLoopHeaders(Function &F); + + bool ProcessBlock(BasicBlock *BB); + bool ThreadEdge(BasicBlock *BB, BasicBlock *PredBB, BasicBlock *SuccBB, + unsigned JumpThreadCost); + BasicBlock *FactorCommonPHIPreds(PHINode *PN, Constant *CstVal); + bool ProcessBranchOnDuplicateCond(BasicBlock *PredBB, BasicBlock *DestBB); + bool ProcessSwitchOnDuplicateCond(BasicBlock *PredBB, BasicBlock *DestBB); + + bool ProcessJumpOnPHI(PHINode *PN); + bool ProcessBranchOnLogical(Value *V, BasicBlock *BB, bool isAnd); + bool ProcessBranchOnCompare(CmpInst *Cmp, BasicBlock *BB); + + bool SimplifyPartiallyRedundantLoad(LoadInst *LI); + }; +} + +char JumpThreading::ID = 0; +static RegisterPass<JumpThreading> +X("jump-threading", "Jump Threading"); + +// Public interface to the Jump Threading pass +FunctionPass *llvm::createJumpThreadingPass() { return new JumpThreading(); } + +/// runOnFunction - Top level algorithm. +/// +bool JumpThreading::runOnFunction(Function &F) { + DOUT << "Jump threading on function '" << F.getNameStart() << "'\n"; + TD = &getAnalysis<TargetData>(); + + FindLoopHeaders(F); + + bool AnotherIteration = true, EverChanged = false; + while (AnotherIteration) { + AnotherIteration = false; + bool Changed = false; + for (Function::iterator I = F.begin(), E = F.end(); I != E;) { + BasicBlock *BB = I; + while (ProcessBlock(BB)) + Changed = true; + + ++I; + + // If the block is trivially dead, zap it. This eliminates the successor + // edges which simplifies the CFG. 
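+      // (a block with no predecessors is unreachable; the entry block is
+      //  exempt because it is reached on function entry without an edge)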
+ if (pred_begin(BB) == pred_end(BB) && + BB != &BB->getParent()->getEntryBlock()) { + DOUT << " JT: Deleting dead block '" << BB->getNameStart() + << "' with terminator: " << *BB->getTerminator(); + LoopHeaders.erase(BB); + DeleteDeadBlock(BB); + Changed = true; + } + } + AnotherIteration = Changed; + EverChanged |= Changed; + } + + LoopHeaders.clear(); + return EverChanged; +} + +/// FindLoopHeaders - We do not want jump threading to turn proper loop +/// structures into irreducible loops. Doing this breaks up the loop nesting +/// hierarchy and pessimizes later transformations. To prevent this from +/// happening, we first have to find the loop headers. Here we approximate this +/// by finding targets of backedges in the CFG. +/// +/// Note that there definitely are cases when we want to allow threading of +/// edges across a loop header. For example, threading a jump from outside the +/// loop (the preheader) to an exit block of the loop is definitely profitable. +/// It is also almost always profitable to thread backedges from within the loop +/// to exit blocks, and is often profitable to thread backedges to other blocks +/// within the loop (forming a nested loop). This simple analysis is not rich +/// enough to track all of these properties and keep it up-to-date as the CFG +/// mutates, so we don't allow any of these transformations. +/// +void JumpThreading::FindLoopHeaders(Function &F) { + SmallVector<std::pair<const BasicBlock*,const BasicBlock*>, 32> Edges; + FindFunctionBackedges(F, Edges); + + for (unsigned i = 0, e = Edges.size(); i != e; ++i) + LoopHeaders.insert(const_cast<BasicBlock*>(Edges[i].second)); +} + + +/// FactorCommonPHIPreds - If there are multiple preds with the same incoming +/// value for the PHI, factor them together so we get one block to thread for +/// the whole group. +/// This is important for things like "phi i1 [true, true, false, true, x]" +/// where we only need to clone the block for the true blocks once. +/// +BasicBlock *JumpThreading::FactorCommonPHIPreds(PHINode *PN, Constant *CstVal) { + SmallVector<BasicBlock*, 16> CommonPreds; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == CstVal) + CommonPreds.push_back(PN->getIncomingBlock(i)); + + if (CommonPreds.size() == 1) + return CommonPreds[0]; + + DOUT << " Factoring out " << CommonPreds.size() + << " common predecessors.\n"; + return SplitBlockPredecessors(PN->getParent(), + &CommonPreds[0], CommonPreds.size(), + ".thr_comm", this); +} + + +/// getJumpThreadDuplicationCost - Return the cost of duplicating this block to +/// thread across it. +static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB) { + /// Ignore PHI nodes, these will be flattened when duplication happens. + BasicBlock::const_iterator I = BB->getFirstNonPHI(); + + // Sum up the cost of each instruction until we get to the terminator. Don't + // include the terminator because the copy won't include it. + unsigned Size = 0; + for (; !isa<TerminatorInst>(I); ++I) { + // Debugger intrinsics don't incur code size. + if (isa<DbgInfoIntrinsic>(I)) continue; + + // If this is a pointer->pointer bitcast, it is free. + if (isa<BitCastInst>(I) && isa<PointerType>(I->getType())) + continue; + + // All other instructions count for at least one unit. + ++Size; + + // Calls are more expensive. If they are non-intrinsic calls, we model them + // as having cost of 4. 
If they are a non-vector intrinsic, we model them + // as having cost of 2 total, and if they are a vector intrinsic, we model + // them as having cost 1. + if (const CallInst *CI = dyn_cast<CallInst>(I)) { + if (!isa<IntrinsicInst>(CI)) + Size += 3; + else if (isa<VectorType>(CI->getType())) + Size += 1; + } + } + + // Threading through a switch statement is particularly profitable. If this + // block ends in a switch, decrease its cost to make it more likely to happen. + if (isa<SwitchInst>(I)) + Size = Size > 6 ? Size-6 : 0; + + return Size; +} + +/// ProcessBlock - If there are any predecessors whose control can be threaded +/// through to a successor, transform them now. +bool JumpThreading::ProcessBlock(BasicBlock *BB) { + // If this block has a single predecessor, and if that pred has a single + // successor, merge the blocks. This encourages recursive jump threading + // because now the condition in this block can be threaded through + // predecessors of our predecessor block. + if (BasicBlock *SinglePred = BB->getSinglePredecessor()) + if (SinglePred->getTerminator()->getNumSuccessors() == 1 && + SinglePred != BB) { + // If SinglePred was a loop header, BB becomes one. + if (LoopHeaders.erase(SinglePred)) + LoopHeaders.insert(BB); + + // Remember if SinglePred was the entry block of the function. If so, we + // will need to move BB back to the entry position. + bool isEntry = SinglePred == &SinglePred->getParent()->getEntryBlock(); + MergeBasicBlockIntoOnlyPred(BB); + + if (isEntry && BB != &BB->getParent()->getEntryBlock()) + BB->moveBefore(&BB->getParent()->getEntryBlock()); + return true; + } + + // See if this block ends with a branch or switch. If so, see if the + // condition is a phi node. If so, and if an entry of the phi node is a + // constant, we can thread the block. + Value *Condition; + if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) { + // Can't thread an unconditional jump. + if (BI->isUnconditional()) return false; + Condition = BI->getCondition(); + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) + Condition = SI->getCondition(); + else + return false; // Must be an invoke. + + // If the terminator of this block is branching on a constant, simplify the + // terminator to an unconditional branch. This can occur due to threading in + // other blocks. + if (isa<ConstantInt>(Condition)) { + DOUT << " In block '" << BB->getNameStart() + << "' folding terminator: " << *BB->getTerminator(); + ++NumFolds; + ConstantFoldTerminator(BB); + return true; + } + + // If the terminator is branching on an undef, we can pick any of the + // successors to branch to. Since this is arbitrary, we pick the successor + // with the fewest predecessors. This should reduce the in-degree of the + // others. + if (isa<UndefValue>(Condition)) { + TerminatorInst *BBTerm = BB->getTerminator(); + unsigned MinSucc = 0; + BasicBlock *TestBB = BBTerm->getSuccessor(MinSucc); + // Compute the successor with the minimum number of predecessors. + unsigned MinNumPreds = std::distance(pred_begin(TestBB), pred_end(TestBB)); + for (unsigned i = 1, e = BBTerm->getNumSuccessors(); i != e; ++i) { + TestBB = BBTerm->getSuccessor(i); + unsigned NumPreds = std::distance(pred_begin(TestBB), pred_end(TestBB)); + if (NumPreds < MinNumPreds) + MinSucc = i; + } + + // Fold the branch/switch. 
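+    // (first remove this block from the PHI nodes of the successors being
+    //  abandoned, then replace the terminator with an unconditional branch)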
+ for (unsigned i = 0, e = BBTerm->getNumSuccessors(); i != e; ++i) { + if (i == MinSucc) continue; + BBTerm->getSuccessor(i)->removePredecessor(BB); + } + + DOUT << " In block '" << BB->getNameStart() + << "' folding undef terminator: " << *BBTerm; + BranchInst::Create(BBTerm->getSuccessor(MinSucc), BBTerm); + BBTerm->eraseFromParent(); + return true; + } + + Instruction *CondInst = dyn_cast<Instruction>(Condition); + + // If the condition is an instruction defined in another block, see if a + // predecessor has the same condition: + // br COND, BBX, BBY + // BBX: + // br COND, BBZ, BBW + if (!Condition->hasOneUse() && // Multiple uses. + (CondInst == 0 || CondInst->getParent() != BB)) { // Non-local definition. + pred_iterator PI = pred_begin(BB), E = pred_end(BB); + if (isa<BranchInst>(BB->getTerminator())) { + for (; PI != E; ++PI) + if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) + if (PBI->isConditional() && PBI->getCondition() == Condition && + ProcessBranchOnDuplicateCond(*PI, BB)) + return true; + } else { + assert(isa<SwitchInst>(BB->getTerminator()) && "Unknown jump terminator"); + for (; PI != E; ++PI) + if (SwitchInst *PSI = dyn_cast<SwitchInst>((*PI)->getTerminator())) + if (PSI->getCondition() == Condition && + ProcessSwitchOnDuplicateCond(*PI, BB)) + return true; + } + } + + // If there is only a single predecessor of this block, nothing to fold. + if (BB->getSinglePredecessor()) + return false; + + // All the rest of our checks depend on the condition being an instruction. + if (CondInst == 0) + return false; + + // See if this is a phi node in the current block. + if (PHINode *PN = dyn_cast<PHINode>(CondInst)) + if (PN->getParent() == BB) + return ProcessJumpOnPHI(PN); + + // If this is a conditional branch whose condition is and/or of a phi, try to + // simplify it. + if ((CondInst->getOpcode() == Instruction::And || + CondInst->getOpcode() == Instruction::Or) && + isa<BranchInst>(BB->getTerminator()) && + ProcessBranchOnLogical(CondInst, BB, + CondInst->getOpcode() == Instruction::And)) + return true; + + // If we have "br (phi != 42)" and the phi node has any constant values as + // operands, we can thread through this block. + if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst)) + if (isa<PHINode>(CondCmp->getOperand(0)) && + isa<Constant>(CondCmp->getOperand(1)) && + ProcessBranchOnCompare(CondCmp, BB)) + return true; + + // Check for some cases that are worth simplifying. Right now we want to look + // for loads that are used by a switch or by the condition for the branch. If + // we see one, check to see if it's partially redundant. If so, insert a PHI + // which can then be used to thread the values. + // + // This is particularly important because reg2mem inserts loads and stores all + // over the place, and this blocks jump threading if we don't zap them. + Value *SimplifyValue = CondInst; + if (CmpInst *CondCmp = dyn_cast<CmpInst>(SimplifyValue)) + if (isa<Constant>(CondCmp->getOperand(1))) + SimplifyValue = CondCmp->getOperand(0); + + if (LoadInst *LI = dyn_cast<LoadInst>(SimplifyValue)) + if (SimplifyPartiallyRedundantLoad(LI)) + return true; + + // TODO: If we have: "br (X > 0)" and we have a predecessor where we know + // "(X == 4)" thread through this block. + + return false; +} + +/// ProcessBranchOnDuplicateCond - We found a block and a predecessor of that +/// block that jump on exactly the same condition. 
This means that we almost +/// always know the direction of the edge in the DESTBB: +/// PREDBB: +/// br COND, DESTBB, BBY +/// DESTBB: +/// br COND, BBZ, BBW +/// +/// If DESTBB has multiple predecessors, we can't just constant fold the branch +/// in DESTBB, we have to thread over it. +bool JumpThreading::ProcessBranchOnDuplicateCond(BasicBlock *PredBB, + BasicBlock *BB) { + BranchInst *PredBI = cast<BranchInst>(PredBB->getTerminator()); + + // If both successors of PredBB go to DESTBB, we don't know anything. We can + // fold the branch to an unconditional one, which allows other recursive + // simplifications. + bool BranchDir; + if (PredBI->getSuccessor(1) != BB) + BranchDir = true; + else if (PredBI->getSuccessor(0) != BB) + BranchDir = false; + else { + DOUT << " In block '" << PredBB->getNameStart() + << "' folding terminator: " << *PredBB->getTerminator(); + ++NumFolds; + ConstantFoldTerminator(PredBB); + return true; + } + + BranchInst *DestBI = cast<BranchInst>(BB->getTerminator()); + + // If the dest block has one predecessor, just fix the branch condition to a + // constant and fold it. + if (BB->getSinglePredecessor()) { + DOUT << " In block '" << BB->getNameStart() + << "' folding condition to '" << BranchDir << "': " + << *BB->getTerminator(); + ++NumFolds; + DestBI->setCondition(ConstantInt::get(Type::Int1Ty, BranchDir)); + ConstantFoldTerminator(BB); + return true; + } + + // Otherwise we need to thread from PredBB to DestBB's successor which + // involves code duplication. Check to see if it is worth it. + unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB); + if (JumpThreadCost > Threshold) { + DOUT << " Not threading BB '" << BB->getNameStart() + << "' - Cost is too high: " << JumpThreadCost << "\n"; + return false; + } + + // Next, figure out which successor we are threading to. + BasicBlock *SuccBB = DestBI->getSuccessor(!BranchDir); + + // Ok, try to thread it! + return ThreadEdge(BB, PredBB, SuccBB, JumpThreadCost); +} + +/// ProcessSwitchOnDuplicateCond - We found a block and a predecessor of that +/// block that switch on exactly the same condition. This means that we almost +/// always know the direction of the edge in the DESTBB: +/// PREDBB: +/// switch COND [... DESTBB, BBY ... ] +/// DESTBB: +/// switch COND [... BBZ, BBW ] +/// +/// Optimizing switches like this is very important, because simplifycfg builds +/// switches out of repeated 'if' conditions. +bool JumpThreading::ProcessSwitchOnDuplicateCond(BasicBlock *PredBB, + BasicBlock *DestBB) { + // Can't thread edge to self. + if (PredBB == DestBB) + return false; + + + SwitchInst *PredSI = cast<SwitchInst>(PredBB->getTerminator()); + SwitchInst *DestSI = cast<SwitchInst>(DestBB->getTerminator()); + + // There are a variety of optimizations that we can potentially do on these + // blocks: we order them from most to least preferable. + + // If DESTBB *just* contains the switch, then we can forward edges from PREDBB + // directly to their destination. This does not introduce *any* code size + // growth. Skip debug info first. + BasicBlock::iterator BBI = DestBB->begin(); + while (isa<DbgInfoIntrinsic>(BBI)) + BBI++; + + // FIXME: Thread if it just contains a PHI. + if (isa<SwitchInst>(BBI)) { + bool MadeChange = false; + // Ignore the default edge for now. + for (unsigned i = 1, e = DestSI->getNumSuccessors(); i != e; ++i) { + ConstantInt *DestVal = DestSI->getCaseValue(i); + BasicBlock *DestSucc = DestSI->getSuccessor(i); + + // Okay, DestSI has a case for 'DestVal' that goes to 'DestSucc'. 
See if
+      // PredSI has an explicit case for it.  If so, forward.  If it is covered
+      // by the default case, we can't update PredSI.
+      unsigned PredCase = PredSI->findCaseValue(DestVal);
+      if (PredCase == 0) continue;
+
+      // If PredSI doesn't go to DestBB on this value, then it won't reach the
+      // case on this condition.
+      if (PredSI->getSuccessor(PredCase) != DestBB &&
+          DestSI->getSuccessor(i) != DestBB)
+        continue;
+
+      // Otherwise, we're safe to make the change.  Make sure that the edge from
+      // DestSI to DestSucc is not critical and has no PHI nodes.
+      DOUT << "FORWARDING EDGE " << *DestVal << " FROM: " << *PredSI;
+      DOUT << "THROUGH: " << *DestSI;
+
+      // If the destination has PHI nodes, just split the edge for updating
+      // simplicity.
+      if (isa<PHINode>(DestSucc->begin()) && !DestSucc->getSinglePredecessor()) {
+        SplitCriticalEdge(DestSI, i, this);
+        DestSucc = DestSI->getSuccessor(i);
+      }
+      FoldSingleEntryPHINodes(DestSucc);
+      PredSI->setSuccessor(PredCase, DestSucc);
+      MadeChange = true;
+    }
+
+    if (MadeChange)
+      return true;
+  }
+
+  return false;
+}
+
+
+/// SimplifyPartiallyRedundantLoad - If LI is an obviously partially redundant
+/// load instruction, eliminate it by replacing it with a PHI node.  This is an
+/// important optimization that encourages jump threading, and needs to be run
+/// interlaced with other jump threading tasks.
+bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
+  // Don't hack volatile loads.
+  if (LI->isVolatile()) return false;
+
+  // If the load is defined in a block with exactly one predecessor, it can't be
+  // partially redundant.
+  BasicBlock *LoadBB = LI->getParent();
+  if (LoadBB->getSinglePredecessor())
+    return false;
+
+  Value *LoadedPtr = LI->getOperand(0);
+
+  // If the loaded operand is defined in the LoadBB, it can't be available.
+  // FIXME: Could do PHI translation, that would be fun :)
+  if (Instruction *PtrOp = dyn_cast<Instruction>(LoadedPtr))
+    if (PtrOp->getParent() == LoadBB)
+      return false;
+
+  // Scan a few instructions up from the load, to see if it is obviously live at
+  // the entry to its block.
+  BasicBlock::iterator BBIt = LI;
+
+  if (Value *AvailableVal = FindAvailableLoadedValue(LoadedPtr, LoadBB,
+                                                     BBIt, 6)) {
+    // If the value of the load is locally available within the block, just use
+    // it.  This frequently occurs for reg2mem'd allocas.
+    //cerr << "LOAD ELIMINATED:\n" << *BBIt << *LI << "\n";
+
+    // If the returned value is the load itself, replace with an undef.  This
+    // can only happen in dead loops.
+    if (AvailableVal == LI) AvailableVal = UndefValue::get(LI->getType());
+    LI->replaceAllUsesWith(AvailableVal);
+    LI->eraseFromParent();
+    return true;
+  }
+
+  // Otherwise, if we scanned the whole block and got to the top of the block,
+  // we know the block is locally transparent to the load.  If not, something
+  // might clobber its value.
+  if (BBIt != LoadBB->begin())
+    return false;
+
+  SmallPtrSet<BasicBlock*, 8> PredsScanned;
+  typedef SmallVector<std::pair<BasicBlock*, Value*>, 8> AvailablePredsTy;
+  AvailablePredsTy AvailablePreds;
+  BasicBlock *OneUnavailablePred = 0;
+
+  // If we got here, the loaded value is transparent through to the start of the
+  // block.  Check to see if it is available in any of the predecessor blocks.
+  for (pred_iterator PI = pred_begin(LoadBB), PE = pred_end(LoadBB);
+       PI != PE; ++PI) {
+    BasicBlock *PredBB = *PI;
+
+    // If we already scanned this predecessor, skip it.
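+    // (SmallPtrSet::insert returns false if PredBB was already in the set,
+    // which happens when PredBB reaches LoadBB along more than one edge.)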
+ if (!PredsScanned.insert(PredBB)) + continue; + + // Scan the predecessor to see if the value is available in the pred. + BBIt = PredBB->end(); + Value *PredAvailable = FindAvailableLoadedValue(LoadedPtr, PredBB, BBIt, 6); + if (!PredAvailable) { + OneUnavailablePred = PredBB; + continue; + } + + // If so, this load is partially redundant. Remember this info so that we + // can create a PHI node. + AvailablePreds.push_back(std::make_pair(PredBB, PredAvailable)); + } + + // If the loaded value isn't available in any predecessor, it isn't partially + // redundant. + if (AvailablePreds.empty()) return false; + + // Okay, the loaded value is available in at least one (and maybe all!) + // predecessors. If the value is unavailable in more than one unique + // predecessor, we want to insert a merge block for those common predecessors. + // This ensures that we only have to insert one reload, thus not increasing + // code size. + BasicBlock *UnavailablePred = 0; + + // If there is exactly one predecessor where the value is unavailable, the + // already computed 'OneUnavailablePred' block is it. If it ends in an + // unconditional branch, we know that it isn't a critical edge. + if (PredsScanned.size() == AvailablePreds.size()+1 && + OneUnavailablePred->getTerminator()->getNumSuccessors() == 1) { + UnavailablePred = OneUnavailablePred; + } else if (PredsScanned.size() != AvailablePreds.size()) { + // Otherwise, we had multiple unavailable predecessors or we had a critical + // edge from the one. + SmallVector<BasicBlock*, 8> PredsToSplit; + SmallPtrSet<BasicBlock*, 8> AvailablePredSet; + + for (unsigned i = 0, e = AvailablePreds.size(); i != e; ++i) + AvailablePredSet.insert(AvailablePreds[i].first); + + // Add all the unavailable predecessors to the PredsToSplit list. + for (pred_iterator PI = pred_begin(LoadBB), PE = pred_end(LoadBB); + PI != PE; ++PI) + if (!AvailablePredSet.count(*PI)) + PredsToSplit.push_back(*PI); + + // Split them out to their own block. + UnavailablePred = + SplitBlockPredecessors(LoadBB, &PredsToSplit[0], PredsToSplit.size(), + "thread-split", this); + } + + // If the value isn't available in all predecessors, then there will be + // exactly one where it isn't available. Insert a load on that edge and add + // it to the AvailablePreds list. + if (UnavailablePred) { + assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 && + "Can't handle critical edge here!"); + Value *NewVal = new LoadInst(LoadedPtr, LI->getName()+".pr", + UnavailablePred->getTerminator()); + AvailablePreds.push_back(std::make_pair(UnavailablePred, NewVal)); + } + + // Now we know that each predecessor of this block has a value in + // AvailablePreds, sort them for efficient access as we're walking the preds. + array_pod_sort(AvailablePreds.begin(), AvailablePreds.end()); + + // Create a PHI node at the start of the block for the PRE'd load value. + PHINode *PN = PHINode::Create(LI->getType(), "", LoadBB->begin()); + PN->takeName(LI); + + // Insert new entries into the PHI for each predecessor. A single block may + // have multiple entries here. 
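+  // (A single predecessor can reach LoadBB through several edges, e.g. from
+  // multiple cases of a switch; each edge needs its own PHI entry, though all
+  // of them receive the same value.  The lower_bound lookup below relies on
+  // the array_pod_sort of AvailablePreds performed above.)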
+  for (pred_iterator PI = pred_begin(LoadBB), E = pred_end(LoadBB); PI != E;
+       ++PI) {
+    AvailablePredsTy::iterator I =
+      std::lower_bound(AvailablePreds.begin(), AvailablePreds.end(),
+                       std::make_pair(*PI, (Value*)0));
+
+    assert(I != AvailablePreds.end() && I->first == *PI &&
+           "Didn't find entry for predecessor!");
+
+    PN->addIncoming(I->second, I->first);
+  }
+
+  //cerr << "PRE: " << *LI << *PN << "\n";
+
+  LI->replaceAllUsesWith(PN);
+  LI->eraseFromParent();
+
+  return true;
+}
+
+
+/// ProcessJumpOnPHI - We have a conditional branch or switch on a PHI node in
+/// the current block.  See if there are any simplifications we can do based on
+/// inputs to the phi node.
+///
+bool JumpThreading::ProcessJumpOnPHI(PHINode *PN) {
+  // See if the phi node has any constant values.  If so, we can determine where
+  // the corresponding predecessor will branch.
+  ConstantInt *PredCst = 0;
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+    if ((PredCst = dyn_cast<ConstantInt>(PN->getIncomingValue(i))))
+      break;
+
+  // If no incoming value has a constant, we don't know the destination of any
+  // predecessors.
+  if (PredCst == 0)
+    return false;
+
+  // See if the cost of duplicating this block is low enough.
+  BasicBlock *BB = PN->getParent();
+  unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB);
+  if (JumpThreadCost > Threshold) {
+    DOUT << "  Not threading BB '" << BB->getNameStart()
+         << "' - Cost is too high: " << JumpThreadCost << "\n";
+    return false;
+  }
+
+  // If so, we can actually do this threading.  Merge any common predecessors
+  // that will act the same.
+  BasicBlock *PredBB = FactorCommonPHIPreds(PN, PredCst);
+
+  // Next, figure out which successor we are threading to.
+  BasicBlock *SuccBB;
+  if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()))
+    SuccBB = BI->getSuccessor(PredCst == ConstantInt::getFalse());
+  else {
+    SwitchInst *SI = cast<SwitchInst>(BB->getTerminator());
+    SuccBB = SI->getSuccessor(SI->findCaseValue(PredCst));
+  }
+
+  // Ok, try to thread it!
+  return ThreadEdge(BB, PredBB, SuccBB, JumpThreadCost);
+}
+
+/// ProcessBranchOnLogical - PN's basic block contains a conditional branch
+/// whose condition is an AND/OR where one side is PN.  If PN has constant
+/// operands that permit us to evaluate the condition for some operand, thread
+/// through the block.  For example with:
+///   br (and X, phi(Y, Z, false))
+/// the predecessor corresponding to the 'false' will always jump to the false
+/// destination of the branch.
+///
+bool JumpThreading::ProcessBranchOnLogical(Value *V, BasicBlock *BB,
+                                           bool isAnd) {
+  // If this is a binary operator tree of the same AND/OR opcode, check the
+  // LHS/RHS.
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V))
+    if ((isAnd && BO->getOpcode() == Instruction::And) ||
+        (!isAnd && BO->getOpcode() == Instruction::Or)) {
+      if (ProcessBranchOnLogical(BO->getOperand(0), BB, isAnd))
+        return true;
+      if (ProcessBranchOnLogical(BO->getOperand(1), BB, isAnd))
+        return true;
+    }
+
+  // If this isn't a PHI node, we can't handle it.
+  PHINode *PN = dyn_cast<PHINode>(V);
+  if (!PN || PN->getParent() != BB) return false;
+
+  // We can only do the simplification for phi nodes of 'false' with AND or
+  // 'true' with OR.  See if we have any entries in the phi for this.
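+  // PredNo records the index of a usable incoming value; ~0U is a sentinel
+  // meaning that no such entry has been found yet.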
+ unsigned PredNo = ~0U; + ConstantInt *PredCst = ConstantInt::get(Type::Int1Ty, !isAnd); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + if (PN->getIncomingValue(i) == PredCst) { + PredNo = i; + break; + } + } + + // If no match, bail out. + if (PredNo == ~0U) + return false; + + // See if the cost of duplicating this block is low enough. + unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB); + if (JumpThreadCost > Threshold) { + DOUT << " Not threading BB '" << BB->getNameStart() + << "' - Cost is too high: " << JumpThreadCost << "\n"; + return false; + } + + // If so, we can actually do this threading. Merge any common predecessors + // that will act the same. + BasicBlock *PredBB = FactorCommonPHIPreds(PN, PredCst); + + // Next, figure out which successor we are threading to. If this was an AND, + // the constant must be FALSE, and we must be targeting the 'false' block. + // If this is an OR, the constant must be TRUE, and we must be targeting the + // 'true' block. + BasicBlock *SuccBB = BB->getTerminator()->getSuccessor(isAnd); + + // Ok, try to thread it! + return ThreadEdge(BB, PredBB, SuccBB, JumpThreadCost); +} + +/// ProcessBranchOnCompare - We found a branch on a comparison between a phi +/// node and a constant. If the PHI node contains any constants as inputs, we +/// can fold the compare for that edge and thread through it. +bool JumpThreading::ProcessBranchOnCompare(CmpInst *Cmp, BasicBlock *BB) { + PHINode *PN = cast<PHINode>(Cmp->getOperand(0)); + Constant *RHS = cast<Constant>(Cmp->getOperand(1)); + + // If the phi isn't in the current block, an incoming edge to this block + // doesn't control the destination. + if (PN->getParent() != BB) + return false; + + // We can do this simplification if any comparisons fold to true or false. + // See if any do. + Constant *PredCst = 0; + bool TrueDirection = false; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + PredCst = dyn_cast<Constant>(PN->getIncomingValue(i)); + if (PredCst == 0) continue; + + Constant *Res; + if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cmp)) + Res = ConstantExpr::getICmp(ICI->getPredicate(), PredCst, RHS); + else + Res = ConstantExpr::getFCmp(cast<FCmpInst>(Cmp)->getPredicate(), + PredCst, RHS); + // If this folded to a constant expr, we can't do anything. + if (ConstantInt *ResC = dyn_cast<ConstantInt>(Res)) { + TrueDirection = ResC->getZExtValue(); + break; + } + // If this folded to undef, just go the false way. + if (isa<UndefValue>(Res)) { + TrueDirection = false; + break; + } + + // Otherwise, we can't fold this input. + PredCst = 0; + } + + // If no match, bail out. + if (PredCst == 0) + return false; + + // See if the cost of duplicating this block is low enough. + unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB); + if (JumpThreadCost > Threshold) { + DOUT << " Not threading BB '" << BB->getNameStart() + << "' - Cost is too high: " << JumpThreadCost << "\n"; + return false; + } + + // If so, we can actually do this threading. Merge any common predecessors + // that will act the same. + BasicBlock *PredBB = FactorCommonPHIPreds(PN, PredCst); + + // Next, get our successor. + BasicBlock *SuccBB = BB->getTerminator()->getSuccessor(!TrueDirection); + + // Ok, try to thread it! + return ThreadEdge(BB, PredBB, SuccBB, JumpThreadCost); +} + + +/// ThreadEdge - We have decided that it is safe and profitable to thread an +/// edge from PredBB to SuccBB across BB. Transform the IR to reflect this +/// change. 
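+///
+/// For example (purely illustrative), threading PredBB -> SuccBB across BB:
+///   PredBB:
+///     br label %BB
+///   BB:
+///     %cond = phi i1 [ true, %PredBB ], [ %c, %OtherPred ]
+///     br i1 %cond, label %SuccBB, label %ElseBB
+/// duplicates BB into a new block that PredBB branches to instead:
+///   BB.thread:                                      ; preds = %PredBB
+///     br label %SuccBB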
+bool JumpThreading::ThreadEdge(BasicBlock *BB, BasicBlock *PredBB, + BasicBlock *SuccBB, unsigned JumpThreadCost) { + + // If threading to the same block as we come from, we would infinite loop. + if (SuccBB == BB) { + DOUT << " Not threading across BB '" << BB->getNameStart() + << "' - would thread to self!\n"; + return false; + } + + // If threading this would thread across a loop header, don't thread the edge. + // See the comments above FindLoopHeaders for justifications and caveats. + if (LoopHeaders.count(BB)) { + DOUT << " Not threading from '" << PredBB->getNameStart() + << "' across loop header BB '" << BB->getNameStart() + << "' to dest BB '" << SuccBB->getNameStart() + << "' - it might create an irreducible loop!\n"; + return false; + } + + // And finally, do it! + DOUT << " Threading edge from '" << PredBB->getNameStart() << "' to '" + << SuccBB->getNameStart() << "' with cost: " << JumpThreadCost + << ", across block:\n " + << *BB << "\n"; + + // Jump Threading can not update SSA properties correctly if the values + // defined in the duplicated block are used outside of the block itself. For + // this reason, we spill all values that are used outside of BB to the stack. + for (BasicBlock::iterator I = BB->begin(); I != BB->end(); ++I) { + if (!I->isUsedOutsideOfBlock(BB)) + continue; + + // We found a use of I outside of BB. Create a new stack slot to + // break this inter-block usage pattern. + DemoteRegToStack(*I); + } + + // We are going to have to map operands from the original BB block to the new + // copy of the block 'NewBB'. If there are PHI nodes in BB, evaluate them to + // account for entry from PredBB. + DenseMap<Instruction*, Value*> ValueMapping; + + BasicBlock *NewBB = + BasicBlock::Create(BB->getName()+".thread", BB->getParent(), BB); + NewBB->moveAfter(PredBB); + + BasicBlock::iterator BI = BB->begin(); + for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) + ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); + + // Clone the non-phi instructions of BB into NewBB, keeping track of the + // mapping and using it to remap operands in the cloned instructions. + for (; !isa<TerminatorInst>(BI); ++BI) { + Instruction *New = BI->clone(); + New->setName(BI->getNameStart()); + NewBB->getInstList().push_back(New); + ValueMapping[BI] = New; + + // Remap operands to patch up intra-block references. + for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) + if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) + if (Value *Remapped = ValueMapping[Inst]) + New->setOperand(i, Remapped); + } + + // We didn't copy the terminator from BB over to NewBB, because there is now + // an unconditional jump to SuccBB. Insert the unconditional jump. + BranchInst::Create(SuccBB, NewBB); + + // Check to see if SuccBB has PHI nodes. If so, we need to add entries to the + // PHI nodes for NewBB now. + for (BasicBlock::iterator PNI = SuccBB->begin(); isa<PHINode>(PNI); ++PNI) { + PHINode *PN = cast<PHINode>(PNI); + // Ok, we have a PHI node. Figure out what the incoming value was for the + // DestBlock. + Value *IV = PN->getIncomingValueForBlock(BB); + + // Remap the value if necessary. + if (Instruction *Inst = dyn_cast<Instruction>(IV)) + if (Value *MappedIV = ValueMapping[Inst]) + IV = MappedIV; + PN->addIncoming(IV, NewBB); + } + + // Ok, NewBB is good to go. Update the terminator of PredBB to jump to + // NewBB instead of BB. This eliminates predecessors from BB, which requires + // us to simplify any PHI nodes in BB. 
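+  // (removePredecessor below rewrites BB's PHI nodes to drop their entries
+  // for PredBB, since control can no longer flow from PredBB into BB.)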
+ TerminatorInst *PredTerm = PredBB->getTerminator(); + for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) + if (PredTerm->getSuccessor(i) == BB) { + BB->removePredecessor(PredBB); + PredTerm->setSuccessor(i, NewBB); + } + + // At this point, the IR is fully up to date and consistent. Do a quick scan + // over the new instructions and zap any that are constants or dead. This + // frequently happens because of phi translation. + BI = NewBB->begin(); + for (BasicBlock::iterator E = NewBB->end(); BI != E; ) { + Instruction *Inst = BI++; + if (Constant *C = ConstantFoldInstruction(Inst, TD)) { + Inst->replaceAllUsesWith(C); + Inst->eraseFromParent(); + continue; + } + + RecursivelyDeleteTriviallyDeadInstructions(Inst); + } + + // Threaded an edge! + ++NumThreads; + return true; +} diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp new file mode 100644 index 0000000..1021469 --- /dev/null +++ b/lib/Transforms/Scalar/LICM.cpp @@ -0,0 +1,885 @@ +//===-- LICM.cpp - Loop Invariant Code Motion Pass ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs loop invariant code motion, attempting to remove as much +// code from the body of a loop as possible. It does this by either hoisting +// code into the preheader block, or by sinking code to the exit blocks if it is +// safe. This pass also promotes must-aliased memory locations in the loop to +// live in registers, thus hoisting and sinking "invariant" loads and stores. +// +// This pass uses alias analysis for two purposes: +// +// 1. Moving loop invariant loads and calls out of loops. If we can determine +// that a load or call inside of a loop never aliases anything stored to, +// we can hoist it or sink it like any other instruction. +// 2. Scalar Promotion of Memory - If there is a store instruction inside of +// the loop, we try to move the store to happen AFTER the loop instead of +// inside of the loop. This can only happen if a few conditions are true: +// A. The pointer stored through is loop invariant +// B. There are no stores or loads in the loop which _may_ alias the +// pointer. There are no calls in the loop which mod/ref the pointer. +// If these conditions are true, we can promote the loads and stores in the +// loop of the pointer to use a temporary alloca'd variable. We then use +// the mem2reg functionality to construct the appropriate SSA form for the +// variable. 
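+//
+// As a purely illustrative sketch of the promotion described above, a loop
+// like:
+//
+//    for (...) { *P += 1; }  // P loop invariant, no may-aliased accesses
+//
+// is effectively rewritten as:
+//
+//    tmp = *P; for (...) { tmp += 1; } *P = tmp;
+//
+// after which mem2reg turns 'tmp' into an ordinary SSA register.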
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "licm" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include <algorithm> +using namespace llvm; + +STATISTIC(NumSunk , "Number of instructions sunk out of loop"); +STATISTIC(NumHoisted , "Number of instructions hoisted out of loop"); +STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk"); +STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk"); +STATISTIC(NumPromoted , "Number of memory locations promoted to registers"); + +static cl::opt<bool> +DisablePromotion("disable-licm-promotion", cl::Hidden, + cl::desc("Disable memory promotion in LICM pass")); + +// This feature is currently disabled by default because CodeGen is not yet +// capable of rematerializing these constants in PIC mode, so it can lead to +// degraded performance. Compile test/CodeGen/X86/remat-constant.ll with +// -relocation-model=pic to see an example of this. +static cl::opt<bool> +EnableLICMConstantMotion("enable-licm-constant-variables", cl::Hidden, + cl::desc("Enable hoisting/sinking of constant " + "global variables")); + +namespace { + struct VISIBILITY_HIDDEN LICM : public LoopPass { + static char ID; // Pass identification, replacement for typeid + LICM() : LoopPass(&ID) {} + + virtual bool runOnLoop(Loop *L, LPPassManager &LPM); + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... + /// + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired<LoopInfo>(); + AU.addRequired<DominatorTree>(); + AU.addRequired<DominanceFrontier>(); // For scalar promotion (mem2reg) + AU.addRequired<AliasAnalysis>(); + AU.addPreserved<ScalarEvolution>(); + AU.addPreserved<DominanceFrontier>(); + } + + bool doFinalization() { + // Free the values stored in the map + for (std::map<Loop *, AliasSetTracker *>::iterator + I = LoopToAliasMap.begin(), E = LoopToAliasMap.end(); I != E; ++I) + delete I->second; + + LoopToAliasMap.clear(); + return false; + } + + private: + // Various analyses that we use... + AliasAnalysis *AA; // Current AliasAnalysis information + LoopInfo *LI; // Current LoopInfo + DominatorTree *DT; // Dominator Tree for the current Loop... + DominanceFrontier *DF; // Current Dominance Frontier + + // State that is updated as we process loops + bool Changed; // Set to true when we change anything. + BasicBlock *Preheader; // The preheader block of the current loop... + Loop *CurLoop; // The current loop we are working on... + AliasSetTracker *CurAST; // AliasSet information for the current loop... + std::map<Loop *, AliasSetTracker *> LoopToAliasMap; + + /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info. + void cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, Loop *L); + + /// deleteAnalysisValue - Simple Analysis hook. 
Delete value V from alias + /// set. + void deleteAnalysisValue(Value *V, Loop *L); + + /// SinkRegion - Walk the specified region of the CFG (defined by all blocks + /// dominated by the specified block, and that are in the current loop) in + /// reverse depth first order w.r.t the DominatorTree. This allows us to + /// visit uses before definitions, allowing us to sink a loop body in one + /// pass without iteration. + /// + void SinkRegion(DomTreeNode *N); + + /// HoistRegion - Walk the specified region of the CFG (defined by all + /// blocks dominated by the specified block, and that are in the current + /// loop) in depth first order w.r.t the DominatorTree. This allows us to + /// visit definitions before uses, allowing us to hoist a loop body in one + /// pass without iteration. + /// + void HoistRegion(DomTreeNode *N); + + /// inSubLoop - Little predicate that returns true if the specified basic + /// block is in a subloop of the current one, not the current one itself. + /// + bool inSubLoop(BasicBlock *BB) { + assert(CurLoop->contains(BB) && "Only valid if BB is IN the loop"); + for (Loop::iterator I = CurLoop->begin(), E = CurLoop->end(); I != E; ++I) + if ((*I)->contains(BB)) + return true; // A subloop actually contains this block! + return false; + } + + /// isExitBlockDominatedByBlockInLoop - This method checks to see if the + /// specified exit block of the loop is dominated by the specified block + /// that is in the body of the loop. We use these constraints to + /// dramatically limit the amount of the dominator tree that needs to be + /// searched. + bool isExitBlockDominatedByBlockInLoop(BasicBlock *ExitBlock, + BasicBlock *BlockInLoop) const { + // If the block in the loop is the loop header, it must be dominated! + BasicBlock *LoopHeader = CurLoop->getHeader(); + if (BlockInLoop == LoopHeader) + return true; + + DomTreeNode *BlockInLoopNode = DT->getNode(BlockInLoop); + DomTreeNode *IDom = DT->getNode(ExitBlock); + + // Because the exit block is not in the loop, we know we have to get _at + // least_ its immediate dominator. + do { + // Get next Immediate Dominator. + IDom = IDom->getIDom(); + + // If we have got to the header of the loop, then the instructions block + // did not dominate the exit node, so we can't hoist it. + if (IDom->getBlock() == LoopHeader) + return false; + + } while (IDom != BlockInLoopNode); + + return true; + } + + /// sink - When an instruction is found to only be used outside of the loop, + /// this function moves it to the exit blocks and patches up SSA form as + /// needed. + /// + void sink(Instruction &I); + + /// hoist - When an instruction is found to only use loop invariant operands + /// that is safe to hoist, this instruction is called to do the dirty work. + /// + void hoist(Instruction &I); + + /// isSafeToExecuteUnconditionally - Only sink or hoist an instruction if it + /// is not a trapping instruction or if it is a trapping instruction and is + /// guaranteed to execute. + /// + bool isSafeToExecuteUnconditionally(Instruction &I); + + /// pointerInvalidatedByLoop - Return true if the body of this loop may + /// store into the memory location pointed to by V. + /// + bool pointerInvalidatedByLoop(Value *V, unsigned Size) { + // Check to see if any of the basic blocks in CurLoop invalidate *V. 
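+      // (The tracker keeps one alias set per group of may-aliased pointers
+      // referenced in the loop; isMod() is true if some instruction in the
+      // loop may store through a pointer in V's set.)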
+ return CurAST->getAliasSetForPointer(V, Size).isMod(); + } + + bool canSinkOrHoistInst(Instruction &I); + bool isLoopInvariantInst(Instruction &I); + bool isNotUsedInLoop(Instruction &I); + + /// PromoteValuesInLoop - Look at the stores in the loop and promote as many + /// to scalars as we can. + /// + void PromoteValuesInLoop(); + + /// FindPromotableValuesInLoop - Check the current loop for stores to + /// definite pointers, which are not loaded and stored through may aliases. + /// If these are found, create an alloca for the value, add it to the + /// PromotedValues list, and keep track of the mapping from value to + /// alloca... + /// + void FindPromotableValuesInLoop( + std::vector<std::pair<AllocaInst*, Value*> > &PromotedValues, + std::map<Value*, AllocaInst*> &Val2AlMap); + }; +} + +char LICM::ID = 0; +static RegisterPass<LICM> X("licm", "Loop Invariant Code Motion"); + +Pass *llvm::createLICMPass() { return new LICM(); } + +/// Hoist expressions out of the specified loop. Note, alias info for inner +/// loop is not preserved so it is not a good idea to run LICM multiple +/// times on one loop. +/// +bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { + Changed = false; + + // Get our Loop and Alias Analysis information... + LI = &getAnalysis<LoopInfo>(); + AA = &getAnalysis<AliasAnalysis>(); + DF = &getAnalysis<DominanceFrontier>(); + DT = &getAnalysis<DominatorTree>(); + + CurAST = new AliasSetTracker(*AA); + // Collect Alias info from subloops + for (Loop::iterator LoopItr = L->begin(), LoopItrE = L->end(); + LoopItr != LoopItrE; ++LoopItr) { + Loop *InnerL = *LoopItr; + AliasSetTracker *InnerAST = LoopToAliasMap[InnerL]; + assert (InnerAST && "Where is my AST?"); + + // What if InnerLoop was modified by other passes ? + CurAST->add(*InnerAST); + } + + CurLoop = L; + + // Get the preheader block to move instructions into... + Preheader = L->getLoopPreheader(); + assert(Preheader&&"Preheader insertion pass guarantees we have a preheader!"); + + // Loop over the body of this loop, looking for calls, invokes, and stores. + // Because subloops have already been incorporated into AST, we skip blocks in + // subloops. + // + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) { + BasicBlock *BB = *I; + if (LI->getLoopFor(BB) == L) // Ignore blocks in subloops... + CurAST->add(*BB); // Incorporate the specified basic block + } + + // We want to visit all of the instructions in this loop... that are not parts + // of our subloops (they have already had their invariants hoisted out of + // their loop, into this loop, so there is no need to process the BODIES of + // the subloops). + // + // Traverse the body of the loop in depth first order on the dominator tree so + // that we are guaranteed to see definitions before we see uses. This allows + // us to sink instructions in one pass, without iteration. After sinking + // instructions, we perform another pass to hoist them out of the loop. + // + SinkRegion(DT->getNode(L->getHeader())); + HoistRegion(DT->getNode(L->getHeader())); + + // Now that all loop invariants have been removed from the loop, promote any + // memory references to scalars that we can... 
+ if (!DisablePromotion) + PromoteValuesInLoop(); + + // Clear out loops state information for the next iteration + CurLoop = 0; + Preheader = 0; + + LoopToAliasMap[L] = CurAST; + return Changed; +} + +/// SinkRegion - Walk the specified region of the CFG (defined by all blocks +/// dominated by the specified block, and that are in the current loop) in +/// reverse depth first order w.r.t the DominatorTree. This allows us to visit +/// uses before definitions, allowing us to sink a loop body in one pass without +/// iteration. +/// +void LICM::SinkRegion(DomTreeNode *N) { + assert(N != 0 && "Null dominator tree node?"); + BasicBlock *BB = N->getBlock(); + + // If this subregion is not in the top level loop at all, exit. + if (!CurLoop->contains(BB)) return; + + // We are processing blocks in reverse dfo, so process children first... + const std::vector<DomTreeNode*> &Children = N->getChildren(); + for (unsigned i = 0, e = Children.size(); i != e; ++i) + SinkRegion(Children[i]); + + // Only need to process the contents of this block if it is not part of a + // subloop (which would already have been processed). + if (inSubLoop(BB)) return; + + for (BasicBlock::iterator II = BB->end(); II != BB->begin(); ) { + Instruction &I = *--II; + + // Check to see if we can sink this instruction to the exit blocks + // of the loop. We can do this if the all users of the instruction are + // outside of the loop. In this case, it doesn't even matter if the + // operands of the instruction are loop invariant. + // + if (isNotUsedInLoop(I) && canSinkOrHoistInst(I)) { + ++II; + sink(I); + } + } +} + + +/// HoistRegion - Walk the specified region of the CFG (defined by all blocks +/// dominated by the specified block, and that are in the current loop) in depth +/// first order w.r.t the DominatorTree. This allows us to visit definitions +/// before uses, allowing us to hoist a loop body in one pass without iteration. +/// +void LICM::HoistRegion(DomTreeNode *N) { + assert(N != 0 && "Null dominator tree node?"); + BasicBlock *BB = N->getBlock(); + + // If this subregion is not in the top level loop at all, exit. + if (!CurLoop->contains(BB)) return; + + // Only need to process the contents of this block if it is not part of a + // subloop (which would already have been processed). + if (!inSubLoop(BB)) + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E; ) { + Instruction &I = *II++; + + // Try hoisting the instruction out to the preheader. We can only do this + // if all of the operands of the instruction are loop invariant and if it + // is safe to hoist the instruction. + // + if (isLoopInvariantInst(I) && canSinkOrHoistInst(I) && + isSafeToExecuteUnconditionally(I)) + hoist(I); + } + + const std::vector<DomTreeNode*> &Children = N->getChildren(); + for (unsigned i = 0, e = Children.size(); i != e; ++i) + HoistRegion(Children[i]); +} + +/// canSinkOrHoistInst - Return true if the hoister and sinker can handle this +/// instruction. +/// +bool LICM::canSinkOrHoistInst(Instruction &I) { + // Loads have extra constraints we have to verify before we can hoist them. + if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { + if (LI->isVolatile()) + return false; // Don't hoist volatile loads! + + // Loads from constant memory are always safe to move, even if they end up + // in the same alias set as something that ends up being modified. + if (EnableLICMConstantMotion && + AA->pointsToConstantMemory(LI->getOperand(0))) + return true; + + // Don't hoist loads which have may-aliased stores in loop. 
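+    // (For example, a load of *P cannot be moved out of a loop that also
+    // contains a store through a pointer that may alias P, since the store
+    // could change the loaded value between iterations.)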
+ unsigned Size = 0; + if (LI->getType()->isSized()) + Size = AA->getTargetData().getTypeStoreSize(LI->getType()); + return !pointerInvalidatedByLoop(LI->getOperand(0), Size); + } else if (CallInst *CI = dyn_cast<CallInst>(&I)) { + // Handle obvious cases efficiently. + AliasAnalysis::ModRefBehavior Behavior = AA->getModRefBehavior(CI); + if (Behavior == AliasAnalysis::DoesNotAccessMemory) + return true; + else if (Behavior == AliasAnalysis::OnlyReadsMemory) { + // If this call only reads from memory and there are no writes to memory + // in the loop, we can hoist or sink the call as appropriate. + bool FoundMod = false; + for (AliasSetTracker::iterator I = CurAST->begin(), E = CurAST->end(); + I != E; ++I) { + AliasSet &AS = *I; + if (!AS.isForwardingAliasSet() && AS.isMod()) { + FoundMod = true; + break; + } + } + if (!FoundMod) return true; + } + + // FIXME: This should use mod/ref information to see if we can hoist or sink + // the call. + + return false; + } + + // Otherwise these instructions are hoistable/sinkable + return isa<BinaryOperator>(I) || isa<CastInst>(I) || + isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || isa<CmpInst>(I) || + isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) || + isa<ShuffleVectorInst>(I); +} + +/// isNotUsedInLoop - Return true if the only users of this instruction are +/// outside of the loop. If this is true, we can sink the instruction to the +/// exit blocks of the loop. +/// +bool LICM::isNotUsedInLoop(Instruction &I) { + for (Value::use_iterator UI = I.use_begin(), E = I.use_end(); UI != E; ++UI) { + Instruction *User = cast<Instruction>(*UI); + if (PHINode *PN = dyn_cast<PHINode>(User)) { + // PHI node uses occur in predecessor blocks! + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == &I) + if (CurLoop->contains(PN->getIncomingBlock(i))) + return false; + } else if (CurLoop->contains(User->getParent())) { + return false; + } + } + return true; +} + + +/// isLoopInvariantInst - Return true if all operands of this instruction are +/// loop invariant. We also filter out non-hoistable instructions here just for +/// efficiency. +/// +bool LICM::isLoopInvariantInst(Instruction &I) { + // The instruction is loop invariant if all of its operands are loop-invariant + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) + if (!CurLoop->isLoopInvariant(I.getOperand(i))) + return false; + + // If we got this far, the instruction is loop invariant! + return true; +} + +/// sink - When an instruction is found to only be used outside of the loop, +/// this function moves it to the exit blocks and patches up SSA form as needed. +/// This method is guaranteed to remove the original instruction from its +/// position, and may either delete it or move it to outside of the loop. +/// +void LICM::sink(Instruction &I) { + DOUT << "LICM sinking instruction: " << I; + + SmallVector<BasicBlock*, 8> ExitBlocks; + CurLoop->getExitBlocks(ExitBlocks); + + if (isa<LoadInst>(I)) ++NumMovedLoads; + else if (isa<CallInst>(I)) ++NumMovedCalls; + ++NumSunk; + Changed = true; + + // The case where there is only a single exit node of this loop is common + // enough that we handle it as a special (more efficient) case. It is more + // efficient to handle because there are no PHI nodes that need to be placed. + if (ExitBlocks.size() == 1) { + if (!isExitBlockDominatedByBlockInLoop(ExitBlocks[0], I.getParent())) { + // Instruction is not used, just delete it. 
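+      // (Since every path out of the loop goes through the single exit block,
+      // a definition that does not dominate that block cannot have reachable
+      // uses outside the loop; any remaining uses are in unreachable code and
+      // are replaced with undef below.)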
+ CurAST->deleteValue(&I); + if (!I.use_empty()) // If I has users in unreachable blocks, eliminate. + I.replaceAllUsesWith(UndefValue::get(I.getType())); + I.eraseFromParent(); + } else { + // Move the instruction to the start of the exit block, after any PHI + // nodes in it. + I.removeFromParent(); + + BasicBlock::iterator InsertPt = ExitBlocks[0]->getFirstNonPHI(); + ExitBlocks[0]->getInstList().insert(InsertPt, &I); + } + } else if (ExitBlocks.empty()) { + // The instruction is actually dead if there ARE NO exit blocks. + CurAST->deleteValue(&I); + if (!I.use_empty()) // If I has users in unreachable blocks, eliminate. + I.replaceAllUsesWith(UndefValue::get(I.getType())); + I.eraseFromParent(); + } else { + // Otherwise, if we have multiple exits, use the PromoteMem2Reg function to + // do all of the hard work of inserting PHI nodes as necessary. We convert + // the value into a stack object to get it to do this. + + // Firstly, we create a stack object to hold the value... + AllocaInst *AI = 0; + + if (I.getType() != Type::VoidTy) { + AI = new AllocaInst(I.getType(), 0, I.getName(), + I.getParent()->getParent()->getEntryBlock().begin()); + CurAST->add(AI); + } + + // Secondly, insert load instructions for each use of the instruction + // outside of the loop. + while (!I.use_empty()) { + Instruction *U = cast<Instruction>(I.use_back()); + + // If the user is a PHI Node, we actually have to insert load instructions + // in all predecessor blocks, not in the PHI block itself! + if (PHINode *UPN = dyn_cast<PHINode>(U)) { + // Only insert into each predecessor once, so that we don't have + // different incoming values from the same block! + std::map<BasicBlock*, Value*> InsertedBlocks; + for (unsigned i = 0, e = UPN->getNumIncomingValues(); i != e; ++i) + if (UPN->getIncomingValue(i) == &I) { + BasicBlock *Pred = UPN->getIncomingBlock(i); + Value *&PredVal = InsertedBlocks[Pred]; + if (!PredVal) { + // Insert a new load instruction right before the terminator in + // the predecessor block. + PredVal = new LoadInst(AI, "", Pred->getTerminator()); + CurAST->add(cast<LoadInst>(PredVal)); + } + + UPN->setIncomingValue(i, PredVal); + } + + } else { + LoadInst *L = new LoadInst(AI, "", U); + U->replaceUsesOfWith(&I, L); + CurAST->add(L); + } + } + + // Thirdly, insert a copy of the instruction in each exit block of the loop + // that is dominated by the instruction, storing the result into the memory + // location. Be careful not to insert the instruction into any particular + // basic block more than once. + std::set<BasicBlock*> InsertedBlocks; + BasicBlock *InstOrigBB = I.getParent(); + + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { + BasicBlock *ExitBlock = ExitBlocks[i]; + + if (isExitBlockDominatedByBlockInLoop(ExitBlock, InstOrigBB)) { + // If we haven't already processed this exit block, do so now. + if (InsertedBlocks.insert(ExitBlock).second) { + // Insert the code after the last PHI node... + BasicBlock::iterator InsertPt = ExitBlock->getFirstNonPHI(); + + // If this is the first exit block processed, just move the original + // instruction, otherwise clone the original instruction and insert + // the copy. 
+ Instruction *New; + if (InsertedBlocks.size() == 1) { + I.removeFromParent(); + ExitBlock->getInstList().insert(InsertPt, &I); + New = &I; + } else { + New = I.clone(); + CurAST->copyValue(&I, New); + if (!I.getName().empty()) + New->setName(I.getName()+".le"); + ExitBlock->getInstList().insert(InsertPt, New); + } + + // Now that we have inserted the instruction, store it into the alloca + if (AI) new StoreInst(New, AI, InsertPt); + } + } + } + + // If the instruction doesn't dominate any exit blocks, it must be dead. + if (InsertedBlocks.empty()) { + CurAST->deleteValue(&I); + I.eraseFromParent(); + } + + // Finally, promote the fine value to SSA form. + if (AI) { + std::vector<AllocaInst*> Allocas; + Allocas.push_back(AI); + PromoteMemToReg(Allocas, *DT, *DF, CurAST); + } + } +} + +/// hoist - When an instruction is found to only use loop invariant operands +/// that is safe to hoist, this instruction is called to do the dirty work. +/// +void LICM::hoist(Instruction &I) { + DOUT << "LICM hoisting to " << Preheader->getName() << ": " << I; + + // Remove the instruction from its current basic block... but don't delete the + // instruction. + I.removeFromParent(); + + // Insert the new node in Preheader, before the terminator. + Preheader->getInstList().insert(Preheader->getTerminator(), &I); + + if (isa<LoadInst>(I)) ++NumMovedLoads; + else if (isa<CallInst>(I)) ++NumMovedCalls; + ++NumHoisted; + Changed = true; +} + +/// isSafeToExecuteUnconditionally - Only sink or hoist an instruction if it is +/// not a trapping instruction or if it is a trapping instruction and is +/// guaranteed to execute. +/// +bool LICM::isSafeToExecuteUnconditionally(Instruction &Inst) { + // If it is not a trapping instruction, it is always safe to hoist. + if (!Inst.isTrapping()) return true; + + // Otherwise we have to check to make sure that the instruction dominates all + // of the exit blocks. If it doesn't, then there is a path out of the loop + // which does not execute this instruction, so we can't hoist it. + + // If the instruction is in the header block for the loop (which is very + // common), it is always guaranteed to dominate the exit blocks. Since this + // is a common case, and can save some work, check it now. + if (Inst.getParent() == CurLoop->getHeader()) + return true; + + // It's always safe to load from a global or alloca. + if (isa<LoadInst>(Inst)) + if (isa<AllocationInst>(Inst.getOperand(0)) || + isa<GlobalVariable>(Inst.getOperand(0))) + return true; + + // Get the exit blocks for the current loop. + SmallVector<BasicBlock*, 8> ExitBlocks; + CurLoop->getExitBlocks(ExitBlocks); + + // For each exit block, get the DT node and walk up the DT until the + // instruction's basic block is found or we exit the loop. + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) + if (!isExitBlockDominatedByBlockInLoop(ExitBlocks[i], Inst.getParent())) + return false; + + return true; +} + + +/// PromoteValuesInLoop - Try to promote memory values to scalars by sinking +/// stores out of the loop and moving loads to before the loop. We do this by +/// looping over the stores in the loop, looking for stores to Must pointers +/// which are loop invariant. We promote these memory locations to use allocas +/// instead. These allocas can easily be raised to register values by the +/// PromoteMem2Reg functionality. +/// +void LICM::PromoteValuesInLoop() { + // PromotedValues - List of values that are promoted out of the loop. 
Each
+  // value has an alloca instruction for it, and a canonical version of the
+  // pointer.
+  std::vector<std::pair<AllocaInst*, Value*> > PromotedValues;
+  std::map<Value*, AllocaInst*> ValueToAllocaMap; // Map of ptr to alloca
+
+  FindPromotableValuesInLoop(PromotedValues, ValueToAllocaMap);
+  if (ValueToAllocaMap.empty()) return; // If there are no values to promote.
+
+  Changed = true;
+  NumPromoted += PromotedValues.size();
+
+  std::vector<Value*> PointerValueNumbers;
+
+  // Emit a copy from the value into the alloca'd value in the loop preheader
+  TerminatorInst *LoopPredInst = Preheader->getTerminator();
+  for (unsigned i = 0, e = PromotedValues.size(); i != e; ++i) {
+    Value *Ptr = PromotedValues[i].second;
+
+    // If we are promoting a pointer value, update alias information for the
+    // inserted load.
+    Value *LoadValue = 0;
+    if (isa<PointerType>(cast<PointerType>(Ptr->getType())->getElementType())) {
+      // Locate a load or store through the pointer, and assign the same value
+      // to LI as we are loading or storing.  Since we know that the value is
+      // stored in this loop, this will always succeed.
+      for (Value::use_iterator UI = Ptr->use_begin(), E = Ptr->use_end();
+           UI != E; ++UI)
+        if (LoadInst *LI = dyn_cast<LoadInst>(*UI)) {
+          LoadValue = LI;
+          break;
+        } else if (StoreInst *SI = dyn_cast<StoreInst>(*UI)) {
+          if (SI->getOperand(1) == Ptr) {
+            LoadValue = SI->getOperand(0);
+            break;
+          }
+        }
+      assert(LoadValue && "No store through the pointer found!");
+      PointerValueNumbers.push_back(LoadValue); // Remember this for later.
+    }
+
+    // Load from the memory we are promoting.
+    LoadInst *LI = new LoadInst(Ptr, Ptr->getName()+".promoted", LoopPredInst);
+
+    if (LoadValue) CurAST->copyValue(LoadValue, LI);
+
+    // Store into the temporary alloca.
+    new StoreInst(LI, PromotedValues[i].first, LoopPredInst);
+  }
+
+  // Scan the basic blocks in the loop, replacing uses of our pointers with
+  // uses of the allocas in question.
+  //
+  for (Loop::block_iterator I = CurLoop->block_begin(),
+       E = CurLoop->block_end(); I != E; ++I) {
+    BasicBlock *BB = *I;
+    // Rewrite all loads and stores in the block of the pointer...
+    for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E; ++II) {
+      if (LoadInst *L = dyn_cast<LoadInst>(II)) {
+        std::map<Value*, AllocaInst*>::iterator
+          I = ValueToAllocaMap.find(L->getOperand(0));
+        if (I != ValueToAllocaMap.end())
+          L->setOperand(0, I->second); // Rewrite load instruction...
+      } else if (StoreInst *S = dyn_cast<StoreInst>(II)) {
+        std::map<Value*, AllocaInst*>::iterator
+          I = ValueToAllocaMap.find(S->getOperand(1));
+        if (I != ValueToAllocaMap.end())
+          S->setOperand(1, I->second); // Rewrite store instruction...
+      }
+    }
+  }
+
+  // Now that the body of the loop uses the allocas instead of the original
+  // memory locations, insert code to copy the alloca value back into the
+  // original memory location on all exits from the loop.  Note that we only
+  // want to insert one copy of the code in each exit block, though the loop may
+  // exit to the same block more than once.
+  //
+  SmallPtrSet<BasicBlock*, 16> ProcessedBlocks;
+
+  SmallVector<BasicBlock*, 8> ExitBlocks;
+  CurLoop->getExitBlocks(ExitBlocks);
+  for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
+    if (!ProcessedBlocks.insert(ExitBlocks[i]))
+      continue;
+
+    // Copy all of the allocas into their memory locations.
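+    // (That is, for each promoted location emit "tmp = load alloca;
+    // store tmp, originalptr" at the top of the exit block.)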
+ BasicBlock::iterator BI = ExitBlocks[i]->getFirstNonPHI(); + Instruction *InsertPos = BI; + unsigned PVN = 0; + for (unsigned i = 0, e = PromotedValues.size(); i != e; ++i) { + // Load from the alloca. + LoadInst *LI = new LoadInst(PromotedValues[i].first, "", InsertPos); + + // If this is a pointer type, update alias info appropriately. + if (isa<PointerType>(LI->getType())) + CurAST->copyValue(PointerValueNumbers[PVN++], LI); + + // Store into the memory we promoted. + new StoreInst(LI, PromotedValues[i].second, InsertPos); + } + } + + // Now that we have done the deed, use the mem2reg functionality to promote + // all of the new allocas we just created into real SSA registers. + // + std::vector<AllocaInst*> PromotedAllocas; + PromotedAllocas.reserve(PromotedValues.size()); + for (unsigned i = 0, e = PromotedValues.size(); i != e; ++i) + PromotedAllocas.push_back(PromotedValues[i].first); + PromoteMemToReg(PromotedAllocas, *DT, *DF, CurAST); +} + +/// FindPromotableValuesInLoop - Check the current loop for stores to definite +/// pointers, which are not loaded and stored through may aliases and are safe +/// for promotion. If these are found, create an alloca for the value, add it +/// to the PromotedValues list, and keep track of the mapping from value to +/// alloca. +void LICM::FindPromotableValuesInLoop( + std::vector<std::pair<AllocaInst*, Value*> > &PromotedValues, + std::map<Value*, AllocaInst*> &ValueToAllocaMap) { + Instruction *FnStart = CurLoop->getHeader()->getParent()->begin()->begin(); + + // Loop over all of the alias sets in the tracker object. + for (AliasSetTracker::iterator I = CurAST->begin(), E = CurAST->end(); + I != E; ++I) { + AliasSet &AS = *I; + // We can promote this alias set if it has a store, if it is a "Must" alias + // set, if the pointer is loop invariant, and if we are not eliminating any + // volatile loads or stores. + if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || + AS.isVolatile() || !CurLoop->isLoopInvariant(AS.begin()->getValue())) + continue; + + assert(!AS.empty() && + "Must alias set should have at least one pointer element in it!"); + Value *V = AS.begin()->getValue(); + + // Check that all of the pointers in the alias set have the same type. We + // cannot (yet) promote a memory location that is loaded and stored in + // different sizes. + { + bool PointerOk = true; + for (AliasSet::iterator I = AS.begin(), E = AS.end(); I != E; ++I) + if (V->getType() != I->getValue()->getType()) { + PointerOk = false; + break; + } + if (!PointerOk) + continue; + } + + // It isn't safe to promote a load/store from the loop if the load/store is + // conditional. For example, turning: + // + // for () { if (c) *P += 1; } + // + // into: + // + // tmp = *P; for () { if (c) tmp +=1; } *P = tmp; + // + // is not safe, because *P may only be valid to access if 'c' is true. + // + // It is safe to promote P if all uses are direct load/stores and if at + // least one is guaranteed to be executed. + bool GuaranteedToExecute = false; + bool InvalidInst = false; + for (Value::use_iterator UI = V->use_begin(), UE = V->use_end(); + UI != UE; ++UI) { + // Ignore instructions not in this loop. 
+ Instruction *Use = dyn_cast<Instruction>(*UI); + if (!Use || !CurLoop->contains(Use->getParent())) + continue; + + if (!isa<LoadInst>(Use) && !isa<StoreInst>(Use)) { + InvalidInst = true; + break; + } + + if (!GuaranteedToExecute) + GuaranteedToExecute = isSafeToExecuteUnconditionally(*Use); + } + + // If there is an non-load/store instruction in the loop, we can't promote + // it. If there isn't a guaranteed-to-execute instruction, we can't + // promote. + if (InvalidInst || !GuaranteedToExecute) + continue; + + const Type *Ty = cast<PointerType>(V->getType())->getElementType(); + AllocaInst *AI = new AllocaInst(Ty, 0, V->getName()+".tmp", FnStart); + PromotedValues.push_back(std::make_pair(AI, V)); + + // Update the AST and alias analysis. + CurAST->copyValue(V, AI); + + for (AliasSet::iterator I = AS.begin(), E = AS.end(); I != E; ++I) + ValueToAllocaMap.insert(std::make_pair(I->getValue(), AI)); + + DOUT << "LICM: Promoting value: " << *V << "\n"; + } +} + +/// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info. +void LICM::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, Loop *L) { + AliasSetTracker *AST = LoopToAliasMap[L]; + if (!AST) + return; + + AST->copyValue(From, To); +} + +/// deleteAnalysisValue - Simple Analysis hook. Delete value V from alias +/// set. +void LICM::deleteAnalysisValue(Value *V, Loop *L) { + AliasSetTracker *AST = LoopToAliasMap[L]; + if (!AST) + return; + + AST->deleteValue(V); +} diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp new file mode 100644 index 0000000..6512672 --- /dev/null +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -0,0 +1,280 @@ +//===- LoopDeletion.cpp - Dead Loop Deletion Pass ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Dead Loop Deletion Pass. This pass is responsible +// for eliminating loops with non-infinite computable trip counts that have no +// side effects or volatile instructions, and do not contribute to the +// computation of the function's return value. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-delete" + +#include "llvm/Transforms/Scalar.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallVector.h" + +using namespace llvm; + +STATISTIC(NumDeleted, "Number of loops deleted"); + +namespace { + class VISIBILITY_HIDDEN LoopDeletion : public LoopPass { + public: + static char ID; // Pass ID, replacement for typeid + LoopDeletion() : LoopPass(&ID) {} + + // Possibly eliminate loop L if it is dead. 
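+    // (Illustratively, a loop like "for (i = 0; i != n; ++i) t = a[i];" where
+    // 't' is unused afterwards: the trip count is computable, nothing has side
+    // effects, and no value is live out of the loop.)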
+    bool runOnLoop(Loop* L, LPPassManager& LPM);
+
+    bool SingleDominatingExit(Loop* L,
+                              SmallVector<BasicBlock*, 4>& exitingBlocks);
+    bool IsLoopDead(Loop* L, SmallVector<BasicBlock*, 4>& exitingBlocks,
+                    SmallVector<BasicBlock*, 4>& exitBlocks);
+    bool IsLoopInvariantInst(Instruction *I, Loop* L);
+
+    virtual void getAnalysisUsage(AnalysisUsage& AU) const {
+      AU.addRequired<ScalarEvolution>();
+      AU.addRequired<DominatorTree>();
+      AU.addRequired<LoopInfo>();
+      AU.addRequiredID(LoopSimplifyID);
+      AU.addRequiredID(LCSSAID);
+
+      AU.addPreserved<ScalarEvolution>();
+      AU.addPreserved<DominatorTree>();
+      AU.addPreserved<LoopInfo>();
+      AU.addPreservedID(LoopSimplifyID);
+      AU.addPreservedID(LCSSAID);
+      AU.addPreserved<DominanceFrontier>();
+    }
+  };
+}
+
+char LoopDeletion::ID = 0;
+static RegisterPass<LoopDeletion> X("loop-deletion", "Delete dead loops");
+
+Pass* llvm::createLoopDeletionPass() {
+  return new LoopDeletion();
+}
+
+/// SingleDominatingExit - Checks that there is only a single block that
+/// branches out of the loop, and that it also dominates the latch block.  Loops
+/// with multiple or non-latch-dominating exiting blocks could be dead, but we'd
+/// have to do more extensive analysis to make sure, for instance, that the
+/// control flow logic involved was or could be made loop-invariant.
+bool LoopDeletion::SingleDominatingExit(Loop* L,
+                              SmallVector<BasicBlock*, 4>& exitingBlocks) {
+
+  if (exitingBlocks.size() != 1)
+    return false;
+
+  BasicBlock* latch = L->getLoopLatch();
+  if (!latch)
+    return false;
+
+  DominatorTree& DT = getAnalysis<DominatorTree>();
+  return DT.dominates(exitingBlocks[0], latch);
+}
+
+/// IsLoopInvariantInst - Checks if an instruction is invariant with respect to
+/// a loop, which is defined as being true if all of its operands are defined
+/// outside of the loop.  These instructions can be hoisted out of the loop
+/// if their results are needed.  This could be made more aggressive by
+/// recursively checking the operands for invariance, but it's not clear that
+/// it's worth it.
+bool LoopDeletion::IsLoopInvariantInst(Instruction *I, Loop* L) {
+  // PHI nodes are not loop invariant if defined in the loop.
+  if (isa<PHINode>(I) && L->contains(I->getParent()))
+    return false;
+
+  // The instruction is loop invariant if all of its operands are loop-invariant
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
+    if (!L->isLoopInvariant(I->getOperand(i)))
+      return false;
+
+  // If we got this far, the instruction is loop invariant!
+  return true;
+}
+
+/// IsLoopDead - Determine if a loop is dead.  This assumes that we've already
+/// checked for unique exit and exiting blocks, and that the code is in LCSSA
+/// form.
+bool LoopDeletion::IsLoopDead(Loop* L,
+                              SmallVector<BasicBlock*, 4>& exitingBlocks,
+                              SmallVector<BasicBlock*, 4>& exitBlocks) {
+  BasicBlock* exitingBlock = exitingBlocks[0];
+  BasicBlock* exitBlock = exitBlocks[0];
+
+  // Make sure that all PHI entries coming from the loop are loop invariant.
+  // Because the code is in LCSSA form, any values used outside of the loop
+  // must pass through a PHI in the exit block, meaning that this check is
+  // sufficient to guarantee that no loop-variant values are used outside
+  // of the loop.
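+  // (In LCSSA form the exit block looks like, illustratively:
+  //    exit:
+  //      %t.lcssa = phi i32 [ %t, %exiting ]
+  // so checking each incoming value for loop invariance is enough.)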
+  BasicBlock::iterator BI = exitBlock->begin();
+  while (PHINode* P = dyn_cast<PHINode>(BI)) {
+    Value* incoming = P->getIncomingValueForBlock(exitingBlock);
+    if (Instruction* I = dyn_cast<Instruction>(incoming))
+      if (!IsLoopInvariantInst(I, L))
+        return false;
+
+    BI++;
+  }
+
+  // Make sure that no instructions in the loop have potential side-effects.
+  // This includes instructions that could write to memory, and loads that are
+  // marked volatile. This could be made more aggressive by using aliasing
+  // information to identify readonly and readnone calls.
+  for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end();
+       LI != LE; ++LI) {
+    for (BasicBlock::iterator BI = (*LI)->begin(), BE = (*LI)->end();
+         BI != BE; ++BI) {
+      if (BI->mayHaveSideEffects())
+        return false;
+    }
+  }
+
+  return true;
+}
+
+/// runOnLoop - Remove dead loops, by which we mean loops that do not impact
+/// the observable behavior of the program other than finite running time. Note
+/// that we ensure this never removes a loop that might be infinite, as doing
+/// so could change the halting/non-halting nature of a program.
+/// NOTE: This entire process relies pretty heavily on LoopSimplify and LCSSA
+/// in order to make various safety checks work.
+bool LoopDeletion::runOnLoop(Loop* L, LPPassManager& LPM) {
+  // We can only remove the loop if there is a preheader that we can
+  // branch from after removing it.
+  BasicBlock* preheader = L->getLoopPreheader();
+  if (!preheader)
+    return false;
+
+  // We can't remove loops that contain subloops. If the subloops were dead,
+  // they would already have been removed in earlier executions of this pass.
+  if (L->begin() != L->end())
+    return false;
+
+  SmallVector<BasicBlock*, 4> exitingBlocks;
+  L->getExitingBlocks(exitingBlocks);
+
+  SmallVector<BasicBlock*, 4> exitBlocks;
+  L->getUniqueExitBlocks(exitBlocks);
+
+  // We require that the loop only have a single exit block. Otherwise, we'd
+  // be in the situation of needing to be able to solve statically which exit
+  // block will be branched to, or trying to preserve the branching logic in
+  // a loop invariant manner.
+  if (exitBlocks.size() != 1)
+    return false;
+
+  // Loops with multiple exits or exits that don't dominate the latch
+  // are too complicated to handle correctly.
+  if (!SingleDominatingExit(L, exitingBlocks))
+    return false;
+
+  // Finally, we have to check that the loop really is dead.
+  if (!IsLoopDead(L, exitingBlocks, exitBlocks))
+    return false;
+
+  // Don't remove loops for which we can't solve the trip count.
+  // They could be infinite, in which case we'd be changing program behavior.
+  ScalarEvolution& SE = getAnalysis<ScalarEvolution>();
+  SCEVHandle S = SE.getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(S))
+    return false;
+
+  // Now that we know the removal is safe, remove the loop by changing the
+  // branch from the preheader to go to the single exit block.
+  BasicBlock* exitBlock = exitBlocks[0];
+  BasicBlock* exitingBlock = exitingBlocks[0];
+
+  // Because we're deleting a large chunk of code at once, the sequence in
+  // which we remove things is very important to avoid invalidation issues.
+  // Don't mess with this unless you have good reason and know what you're
+  // doing.
+
+  // Move simple loop-invariant expressions out of the loop, since they
+  // might be needed by the exit phis.
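+  // For example, an expression such as "%t = add i32 %a, %b" whose operands
+  // are defined before the loop may still feed an exit phi; hoisting it into
+  // the preheader keeps it alive once the loop body is erased.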
+  for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end();
+       LI != LE; ++LI)
+    for (BasicBlock::iterator BI = (*LI)->begin(), BE = (*LI)->end();
+         BI != BE; ) {
+      Instruction* I = BI++;
+      if (!I->use_empty() && IsLoopInvariantInst(I, L))
+        I->moveBefore(preheader->getTerminator());
+    }
+
+  // Connect the preheader directly to the exit block.
+  TerminatorInst* TI = preheader->getTerminator();
+  TI->replaceUsesOfWith(L->getHeader(), exitBlock);
+
+  // Rewrite phis in the exit block to get their inputs from
+  // the preheader instead of the exiting block.
+  BasicBlock::iterator BI = exitBlock->begin();
+  while (PHINode* P = dyn_cast<PHINode>(BI)) {
+    P->replaceUsesOfWith(exitingBlock, preheader);
+    BI++;
+  }
+
+  // Update the dominator tree and remove the instructions and blocks that will
+  // be deleted from the reference counting scheme.
+  DominatorTree& DT = getAnalysis<DominatorTree>();
+  DominanceFrontier* DF = getAnalysisIfAvailable<DominanceFrontier>();
+  SmallPtrSet<DomTreeNode*, 8> ChildNodes;
+  for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end();
+       LI != LE; ++LI) {
+    // Move all of the block's children to be children of the preheader, which
+    // allows us to remove the domtree entry for the block.
+    ChildNodes.insert(DT[*LI]->begin(), DT[*LI]->end());
+    for (SmallPtrSet<DomTreeNode*, 8>::iterator DI = ChildNodes.begin(),
+         DE = ChildNodes.end(); DI != DE; ++DI) {
+      DT.changeImmediateDominator(*DI, DT[preheader]);
+      if (DF) DF->changeImmediateDominator((*DI)->getBlock(), preheader, &DT);
+    }
+
+    ChildNodes.clear();
+    DT.eraseNode(*LI);
+    if (DF) DF->removeBlock(*LI);
+
+    // Remove the block from the reference counting scheme, so that we can
+    // delete it freely later.
+    (*LI)->dropAllReferences();
+  }
+
+  // Tell ScalarEvolution that the loop is deleted. Do this before
+  // deleting the loop so that ScalarEvolution can look at the loop
+  // to determine what it needs to clean up.
+  SE.forgetLoopBackedgeTakenCount(L);
+
+  // Erase the instructions and the blocks without having to worry
+  // about ordering because we already dropped the references.
+  // NOTE: This iteration is safe because erasing the block does not remove its
+  // entry from the loop's block list. We do that in the next section.
+  for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end();
+       LI != LE; ++LI)
+    (*LI)->eraseFromParent();
+
+  // Finally, remove the blocks from loopinfo. This has to happen late because
+  // otherwise our loop iterators won't work.
+  LoopInfo& loopInfo = getAnalysis<LoopInfo>();
+  SmallPtrSet<BasicBlock*, 8> blocks;
+  blocks.insert(L->block_begin(), L->block_end());
+  for (SmallPtrSet<BasicBlock*,8>::iterator I = blocks.begin(),
+       E = blocks.end(); I != E; ++I)
+    loopInfo.removeBlock(*I);
+
+  // The last step is to inform the loop pass manager that we've
+  // eliminated this loop.
+  LPM.deleteLoopFromQueue(L);
+
+  NumDeleted++;
+
+  return true;
+}
diff --git a/lib/Transforms/Scalar/LoopIndexSplit.cpp b/lib/Transforms/Scalar/LoopIndexSplit.cpp
new file mode 100644
index 0000000..9c78596
--- /dev/null
+++ b/lib/Transforms/Scalar/LoopIndexSplit.cpp
@@ -0,0 +1,1237 @@
+//===- LoopIndexSplit.cpp - Loop Index Splitting Pass ---------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Loop Index Splitting Pass.
+// This pass handles three kinds of loops.
+//
+// [1] A loop may be eliminated if the body is executed exactly once.
+//     For example,
+//
+// for (i = 0; i < N; ++i) {
+//   if (i == X) {
+//     body;
+//   }
+// }
+//
+// is transformed to
+//
+// i = X;
+// body;
+//
+// [2] A loop's iteration space may be shrunk if the loop body is executed
+//     for a proper sub-range of the loop's iteration space. For example,
+//
+// for (i = 0; i < N; ++i) {
+//   if (i > A && i < B) {
+//     ...
+//   }
+// }
+//
+// is transformed to iterate from A to B, if A > 0 and B < N.
+//
+// [3] A loop may be split if the loop body is dominated by a branch.
+//     For example,
+//
+// for (i = LB; i < UB; ++i) { if (i < SV) A; else B; }
+//
+// is transformed into
+//
+// AEV = BSV = SV
+// for (i = LB; i < min(UB, AEV); ++i)
+//   A;
+// for (i = max(LB, BSV); i < UB; ++i)
+//   B;
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "loop-index-split"
+
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/Dominators.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/Statistic.h"
+
+using namespace llvm;
+
+STATISTIC(NumIndexSplit, "Number of loops index split");
+STATISTIC(NumIndexSplitRemoved, "Number of loops eliminated by loop index split");
+STATISTIC(NumRestrictBounds, "Number of loop iteration spaces restricted");
+
+namespace {
+
+  class VISIBILITY_HIDDEN LoopIndexSplit : public LoopPass {
+
+  public:
+    static char ID; // Pass ID, replacement for typeid
+    LoopIndexSplit() : LoopPass(&ID) {}
+
+    // Index split Loop L. Return true if loop is split.
+    bool runOnLoop(Loop *L, LPPassManager &LPM);
+
+    void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.addPreserved<ScalarEvolution>();
+      AU.addRequiredID(LCSSAID);
+      AU.addPreservedID(LCSSAID);
+      AU.addRequired<LoopInfo>();
+      AU.addPreserved<LoopInfo>();
+      AU.addRequiredID(LoopSimplifyID);
+      AU.addPreservedID(LoopSimplifyID);
+      AU.addRequired<DominatorTree>();
+      AU.addRequired<DominanceFrontier>();
+      AU.addPreserved<DominatorTree>();
+      AU.addPreserved<DominanceFrontier>();
+    }
+
+  private:
+    /// processOneIterationLoop -- Eliminate loop if loop body is executed
+    /// only once. For example,
+    /// for (i = 0; i < N; ++i) {
+    ///   if ( i == X) {
+    ///     ...
+    ///   }
+    /// }
+    ///
+    bool processOneIterationLoop();
+
+    // -- Routines used by updateLoopIterationSpace();
+
+    /// updateLoopIterationSpace -- Update loop's iteration space if loop
+    /// body is executed for certain IV range only. For example,
+    ///
+    /// for (i = 0; i < N; ++i) {
+    ///   if ( i > A && i < B) {
+    ///     ...
+    ///   }
+    /// }
+    /// is transformed to iterate from A to B, if A > 0 and B < N.
+    ///
+    bool updateLoopIterationSpace();
+
+    /// restrictLoopBound - Op dominates loop body. Op compares an IV based value
+    /// with a loop invariant value. Update loop's lower and upper bound based on
+    /// the loop invariant value.
+    bool restrictLoopBound(ICmpInst &Op);
+
+    // --- Routines used by splitLoop(). --- /
+
+    bool splitLoop();
+
+    /// removeBlocks - Remove basic block DeadBB and all blocks dominated by
+    /// DeadBB. This routine is used to remove split condition's dead branch,
+    /// dominated by DeadBB.
+    /// LiveBB dominates the split condition's other branch.
+    void removeBlocks(BasicBlock *DeadBB, Loop *LP, BasicBlock *LiveBB);
+
+    /// moveExitCondition - Move exit condition EC into split condition block.
+    void moveExitCondition(BasicBlock *CondBB, BasicBlock *ActiveBB,
+                           BasicBlock *ExitBB, ICmpInst *EC, ICmpInst *SC,
+                           PHINode *IV, Instruction *IVAdd, Loop *LP,
+                           unsigned);
+
+    /// updatePHINodes - The CFG has been changed.
+    /// Before
+    ///   - ExitBB's single predecessor was Latch
+    ///   - Latch's second successor was Header
+    /// Now
+    ///   - ExitBB's single predecessor is Header
+    ///   - Latch's one and only successor is Header
+    ///
+    /// Update ExitBB's PHINodes to reflect this change.
+    void updatePHINodes(BasicBlock *ExitBB, BasicBlock *Latch,
+                        BasicBlock *Header,
+                        PHINode *IV, Instruction *IVIncrement, Loop *LP);
+
+    // --- Utility routines --- /
+
+    /// cleanBlock - A block is considered clean if all non-terminator
+    /// instructions are either PHINodes or IV based values.
+    bool cleanBlock(BasicBlock *BB);
+
+    /// IVisLT - If Op is comparing an IV based value with a loop invariant and
+    /// the IV based value is less than the loop invariant then return the loop
+    /// invariant. Otherwise return NULL.
+    Value * IVisLT(ICmpInst &Op);
+
+    /// IVisLE - If Op is comparing an IV based value with a loop invariant and
+    /// the IV based value is less than or equal to the loop invariant then
+    /// return the loop invariant. Otherwise return NULL.
+    Value * IVisLE(ICmpInst &Op);
+
+    /// IVisGT - If Op is comparing an IV based value with a loop invariant and
+    /// the IV based value is greater than the loop invariant then return the
+    /// loop invariant. Otherwise return NULL.
+    Value * IVisGT(ICmpInst &Op);
+
+    /// IVisGE - If Op is comparing an IV based value with a loop invariant and
+    /// the IV based value is greater than or equal to the loop invariant then
+    /// return the loop invariant. Otherwise return NULL.
+    Value * IVisGE(ICmpInst &Op);
+
+  private:
+
+    // Current Loop information.
+    Loop *L;
+    LPPassManager *LPM;
+    LoopInfo *LI;
+    DominatorTree *DT;
+    DominanceFrontier *DF;
+
+    PHINode *IndVar;
+    ICmpInst *ExitCondition;
+    ICmpInst *SplitCondition;
+    Value *IVStartValue;
+    Value *IVExitValue;
+    Instruction *IVIncrement;
+    SmallPtrSet<Value *, 4> IVBasedValues;
+  };
+}
+
+char LoopIndexSplit::ID = 0;
+static RegisterPass<LoopIndexSplit>
+X("loop-index-split", "Index Split Loops");
+
+Pass *llvm::createLoopIndexSplitPass() {
+  return new LoopIndexSplit();
+}
+
+// Index split Loop L. Return true if loop is split.
+bool LoopIndexSplit::runOnLoop(Loop *IncomingLoop, LPPassManager &LPM_Ref) {
+  L = IncomingLoop;
+  LPM = &LPM_Ref;
+
+  // FIXME - Nested loops make dominator info updates tricky.
+  if (!L->getSubLoops().empty())
+    return false;
+
+  DT = &getAnalysis<DominatorTree>();
+  LI = &getAnalysis<LoopInfo>();
+  DF = &getAnalysis<DominanceFrontier>();
+
+  // Initialize loop data.
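+  // The canonical induction variable, when present, is the header PHI that
+  // starts at zero and increments by one each iteration; without it we cannot
+  // describe the loop's iteration space, so we bail out below.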
+  IndVar = L->getCanonicalInductionVariable();
+  if (!IndVar) return false;
+
+  bool P1InLoop = L->contains(IndVar->getIncomingBlock(1));
+  IVStartValue = IndVar->getIncomingValue(!P1InLoop);
+  IVIncrement = dyn_cast<Instruction>(IndVar->getIncomingValue(P1InLoop));
+  if (!IVIncrement) return false;
+
+  IVBasedValues.clear();
+  IVBasedValues.insert(IndVar);
+  IVBasedValues.insert(IVIncrement);
+  for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
+       I != E; ++I)
+    for (BasicBlock::iterator BI = (*I)->begin(), BE = (*I)->end();
+         BI != BE; ++BI) {
+      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(BI))
+        if (BO != IVIncrement
+            && (BO->getOpcode() == Instruction::Add
+                || BO->getOpcode() == Instruction::Sub))
+          if (IVBasedValues.count(BO->getOperand(0))
+              && L->isLoopInvariant(BO->getOperand(1)))
+            IVBasedValues.insert(BO);
+    }
+
+  // Reject the loop if the loop exit condition is not suitable.
+  BasicBlock *ExitingBlock = L->getExitingBlock();
+  if (!ExitingBlock)
+    return false;
+  BranchInst *EBR = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
+  if (!EBR) return false;
+  ExitCondition = dyn_cast<ICmpInst>(EBR->getCondition());
+  if (!ExitCondition) return false;
+  if (ExitingBlock != L->getLoopLatch()) return false;
+  IVExitValue = ExitCondition->getOperand(1);
+  if (!L->isLoopInvariant(IVExitValue))
+    IVExitValue = ExitCondition->getOperand(0);
+  if (!L->isLoopInvariant(IVExitValue))
+    return false;
+
+  // If the start value is greater than the exit value where the induction
+  // variable increments by 1, then we are potentially dealing with an
+  // infinite loop. Do not index split this loop.
+  if (ConstantInt *SV = dyn_cast<ConstantInt>(IVStartValue))
+    if (ConstantInt *EV = dyn_cast<ConstantInt>(IVExitValue))
+      if (SV->getSExtValue() > EV->getSExtValue())
+        return false;
+
+  if (processOneIterationLoop())
+    return true;
+
+  if (updateLoopIterationSpace())
+    return true;
+
+  if (splitLoop())
+    return true;
+
+  return false;
+}
+
+// --- Helper routines ---
+// isUsedOutsideLoop - Returns true iff V is used outside the loop L.
+static bool isUsedOutsideLoop(Value *V, Loop *L) {
+  for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ++UI)
+    if (!L->contains(cast<Instruction>(*UI)->getParent()))
+      return true;
+  return false;
+}
+
+// Return V+1
+static Value *getPlusOne(Value *V, bool Sign, Instruction *InsertPt) {
+  ConstantInt *One = ConstantInt::get(V->getType(), 1, Sign);
+  return BinaryOperator::CreateAdd(V, One, "lsp", InsertPt);
+}
+
+// Return V-1
+static Value *getMinusOne(Value *V, bool Sign, Instruction *InsertPt) {
+  ConstantInt *One = ConstantInt::get(V->getType(), 1, Sign);
+  return BinaryOperator::CreateSub(V, One, "lsp", InsertPt);
+}
+
+// Return min(V1, V2)
+static Value *getMin(Value *V1, Value *V2, bool Sign, Instruction *InsertPt) {
+
+  Value *C = new ICmpInst(Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
+                          V1, V2, "lsp", InsertPt);
+  return SelectInst::Create(C, V1, V2, "lsp", InsertPt);
+}
+
+// Return max(V1, V2)
+static Value *getMax(Value *V1, Value *V2, bool Sign, Instruction *InsertPt) {
+
+  Value *C = new ICmpInst(Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
+                          V1, V2, "lsp", InsertPt);
+  return SelectInst::Create(C, V2, V1, "lsp", InsertPt);
+}
+
+/// processOneIterationLoop -- Eliminate loop if loop body is executed
+/// only once. For example,
+/// for (i = 0; i < N; ++i) {
+///   if ( i == X) {
+///     ...
+///   }
+/// }
+///
+bool LoopIndexSplit::processOneIterationLoop() {
+  SplitCondition = NULL;
+  BasicBlock *Latch = L->getLoopLatch();
+  BasicBlock *Header = L->getHeader();
+  BranchInst *BR = dyn_cast<BranchInst>(Header->getTerminator());
+  if (!BR) return false;
+  if (!isa<BranchInst>(Latch->getTerminator())) return false;
+  if (BR->isUnconditional()) return false;
+  SplitCondition = dyn_cast<ICmpInst>(BR->getCondition());
+  if (!SplitCondition) return false;
+  if (SplitCondition == ExitCondition) return false;
+  if (SplitCondition->getPredicate() != ICmpInst::ICMP_EQ) return false;
+  if (BR->getOperand(1) != Latch) return false;
+  if (!IVBasedValues.count(SplitCondition->getOperand(0))
+      && !IVBasedValues.count(SplitCondition->getOperand(1)))
+    return false;
+
+  // If the IV is used outside the loop then this loop traversal is required.
+  // FIXME: Calculate and use the last IV value.
+  if (isUsedOutsideLoop(IVIncrement, L))
+    return false;
+
+  // If the BR operands are not the IV or not loop invariants then skip this
+  // loop.
+  Value *OPV = SplitCondition->getOperand(0);
+  Value *SplitValue = SplitCondition->getOperand(1);
+  if (!L->isLoopInvariant(SplitValue))
+    std::swap(OPV, SplitValue);
+  if (!L->isLoopInvariant(SplitValue))
+    return false;
+  Instruction *OPI = dyn_cast<Instruction>(OPV);
+  if (!OPI)
+    return false;
+  if (OPI->getParent() != Header || isUsedOutsideLoop(OPI, L))
+    return false;
+  Value *StartValue = IVStartValue;
+  Value *ExitValue = IVExitValue;
+
+  if (OPV != IndVar) {
+    // If the BR operand is IV based then use this operand to calculate
+    // effective conditions for the loop body.
+    BinaryOperator *BOPV = dyn_cast<BinaryOperator>(OPV);
+    if (!BOPV)
+      return false;
+    if (BOPV->getOpcode() != Instruction::Add)
+      return false;
+    StartValue = BinaryOperator::CreateAdd(OPV, StartValue, "", BR);
+    ExitValue = BinaryOperator::CreateAdd(OPV, ExitValue, "", BR);
+  }
+
+  if (!cleanBlock(Header))
+    return false;
+
+  if (!cleanBlock(Latch))
+    return false;
+
+  // If the merge point for BR is not the loop latch then skip this loop.
+  if (BR->getSuccessor(0) != Latch) {
+    DominanceFrontier::iterator DF0 = DF->find(BR->getSuccessor(0));
+    assert (DF0 != DF->end() && "Unable to find dominance frontier");
+    if (!DF0->second.count(Latch))
+      return false;
+  }
+
+  if (BR->getSuccessor(1) != Latch) {
+    DominanceFrontier::iterator DF1 = DF->find(BR->getSuccessor(1));
+    assert (DF1 != DF->end() && "Unable to find dominance frontier");
+    if (!DF1->second.count(Latch))
+      return false;
+  }
+
+  // Now the current loop L contains a compare instruction that compares the
+  // induction variable, IndVar, against a loop invariant, and the entire
+  // (i.e. meaningful) loop body is dominated by this compare instruction.
+  // In such a case, eliminate the loop structure surrounding this loop body.
+  // For example,
+  //     for (int i = start; i < end; ++i) {
+  //       if ( i == somevalue) {
+  //         loop_body
+  //       }
+  //     }
+  // can be transformed into
+  //     if (somevalue >= start && somevalue < end) {
+  //       i = somevalue;
+  //       loop_body
+  //     }
+
+  // Replace the index variable with the split value in the loop body. The loop
+  // body is executed only when the index variable is equal to the split value.
+  IndVar->replaceAllUsesWith(SplitValue);
+
+  // Replace the split condition in the header.
+  // Transform
+  //   SplitCondition : icmp eq i32 IndVar, SplitValue
+  // into
+  //   c1 = icmp uge i32 SplitValue, StartValue
+  //   c2 = icmp ult i32 SplitValue, ExitValue
+  //   and i32 c1, c2
+  Instruction *C1 = new ICmpInst(ExitCondition->isSignedPredicate() ?
+ ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, + SplitValue, StartValue, "lisplit", BR); + + CmpInst::Predicate C2P = ExitCondition->getPredicate(); + BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); + if (LatchBR->getOperand(0) != Header) + C2P = CmpInst::getInversePredicate(C2P); + Instruction *C2 = new ICmpInst(C2P, SplitValue, ExitValue, "lisplit", BR); + Instruction *NSplitCond = BinaryOperator::CreateAnd(C1, C2, "lisplit", BR); + + SplitCondition->replaceAllUsesWith(NSplitCond); + SplitCondition->eraseFromParent(); + + // Remove Latch to Header edge. + BasicBlock *LatchSucc = NULL; + Header->removePredecessor(Latch); + for (succ_iterator SI = succ_begin(Latch), E = succ_end(Latch); + SI != E; ++SI) { + if (Header != *SI) + LatchSucc = *SI; + } + + // Clean up latch block. + Value *LatchBRCond = LatchBR->getCondition(); + LatchBR->setUnconditionalDest(LatchSucc); + RecursivelyDeleteTriviallyDeadInstructions(LatchBRCond); + + LPM->deleteLoopFromQueue(L); + + // Update Dominator Info. + // Only CFG change done is to remove Latch to Header edge. This + // does not change dominator tree because Latch did not dominate + // Header. + if (DF) { + DominanceFrontier::iterator HeaderDF = DF->find(Header); + if (HeaderDF != DF->end()) + DF->removeFromFrontier(HeaderDF, Header); + + DominanceFrontier::iterator LatchDF = DF->find(Latch); + if (LatchDF != DF->end()) + DF->removeFromFrontier(LatchDF, Header); + } + + ++NumIndexSplitRemoved; + return true; +} + +/// restrictLoopBound - Op dominates loop body. Op compares an IV based value +/// with a loop invariant value. Update loop's lower and upper bound based on +/// the loop invariant value. +bool LoopIndexSplit::restrictLoopBound(ICmpInst &Op) { + bool Sign = Op.isSignedPredicate(); + Instruction *PHTerm = L->getLoopPreheader()->getTerminator(); + + if (IVisGT(*ExitCondition) || IVisGE(*ExitCondition)) { + BranchInst *EBR = + cast<BranchInst>(ExitCondition->getParent()->getTerminator()); + ExitCondition->setPredicate(ExitCondition->getInversePredicate()); + BasicBlock *T = EBR->getSuccessor(0); + EBR->setSuccessor(0, EBR->getSuccessor(1)); + EBR->setSuccessor(1, T); + } + + // New upper and lower bounds. + Value *NLB = NULL; + Value *NUB = NULL; + if (Value *V = IVisLT(Op)) { + // Restrict upper bound. + if (IVisLE(*ExitCondition)) + V = getMinusOne(V, Sign, PHTerm); + NUB = getMin(V, IVExitValue, Sign, PHTerm); + } else if (Value *V = IVisLE(Op)) { + // Restrict upper bound. + if (IVisLT(*ExitCondition)) + V = getPlusOne(V, Sign, PHTerm); + NUB = getMin(V, IVExitValue, Sign, PHTerm); + } else if (Value *V = IVisGT(Op)) { + // Restrict lower bound. + V = getPlusOne(V, Sign, PHTerm); + NLB = getMax(V, IVStartValue, Sign, PHTerm); + } else if (Value *V = IVisGE(Op)) + // Restrict lower bound. + NLB = getMax(V, IVStartValue, Sign, PHTerm); + + if (!NLB && !NUB) + return false; + + if (NLB) { + unsigned i = IndVar->getBasicBlockIndex(L->getLoopPreheader()); + IndVar->setIncomingValue(i, NLB); + } + + if (NUB) { + unsigned i = (ExitCondition->getOperand(0) != IVExitValue); + ExitCondition->setOperand(i, NUB); + } + return true; +} + +/// updateLoopIterationSpace -- Update loop's iteration space if loop +/// body is executed for certain IV range only. For example, +/// +/// for (i = 0; i < N; ++i) { +/// if ( i > A && i < B) { +/// ... +/// } +/// } +/// is transformed to iterators from A to B, if A > 0 and B < N. 
+///
+bool LoopIndexSplit::updateLoopIterationSpace() {
+  SplitCondition = NULL;
+  if (ExitCondition->getPredicate() == ICmpInst::ICMP_NE
+      || ExitCondition->getPredicate() == ICmpInst::ICMP_EQ)
+    return false;
+  BasicBlock *Latch = L->getLoopLatch();
+  BasicBlock *Header = L->getHeader();
+  BranchInst *BR = dyn_cast<BranchInst>(Header->getTerminator());
+  if (!BR) return false;
+  if (!isa<BranchInst>(Latch->getTerminator())) return false;
+  if (BR->isUnconditional()) return false;
+  BinaryOperator *AND = dyn_cast<BinaryOperator>(BR->getCondition());
+  if (!AND) return false;
+  if (AND->getOpcode() != Instruction::And) return false;
+  ICmpInst *Op0 = dyn_cast<ICmpInst>(AND->getOperand(0));
+  ICmpInst *Op1 = dyn_cast<ICmpInst>(AND->getOperand(1));
+  if (!Op0 || !Op1)
+    return false;
+  IVBasedValues.insert(AND);
+  IVBasedValues.insert(Op0);
+  IVBasedValues.insert(Op1);
+  if (!cleanBlock(Header)) return false;
+  BasicBlock *ExitingBlock = ExitCondition->getParent();
+  if (!cleanBlock(ExitingBlock)) return false;
+
+  // If the merge point for BR is not the loop latch then skip this loop.
+  if (BR->getSuccessor(0) != Latch) {
+    DominanceFrontier::iterator DF0 = DF->find(BR->getSuccessor(0));
+    assert (DF0 != DF->end() && "Unable to find dominance frontier");
+    if (!DF0->second.count(Latch))
+      return false;
+  }
+
+  if (BR->getSuccessor(1) != Latch) {
+    DominanceFrontier::iterator DF1 = DF->find(BR->getSuccessor(1));
+    assert (DF1 != DF->end() && "Unable to find dominance frontier");
+    if (!DF1->second.count(Latch))
+      return false;
+  }
+
+  // Verify that the loop exiting block has only two predecessors, where one
+  // pred is the split condition block. The other predecessor will become the
+  // exiting block's dominator after the CFG is updated. TODO: Handle CFGs
+  // where the exiting block has more than two predecessors. This requires
+  // extra work in updating dominator information.
+  BasicBlock *ExitingBBPred = NULL;
+  for (pred_iterator PI = pred_begin(ExitingBlock), PE = pred_end(ExitingBlock);
+       PI != PE; ++PI) {
+    BasicBlock *BB = *PI;
+    if (Header == BB)
+      continue;
+    if (ExitingBBPred)
+      return false;
+    else
+      ExitingBBPred = BB;
+  }
+
+  if (!restrictLoopBound(*Op0))
+    return false;
+
+  if (!restrictLoopBound(*Op1))
+    return false;
+
+  // Update CFG.
+  if (BR->getSuccessor(0) == ExitingBlock)
+    BR->setUnconditionalDest(BR->getSuccessor(1));
+  else
+    BR->setUnconditionalDest(BR->getSuccessor(0));
+
+  AND->eraseFromParent();
+  if (Op0->use_empty())
+    Op0->eraseFromParent();
+  if (Op1->use_empty())
+    Op1->eraseFromParent();
+
+  // Update dominator info. Now ExitingBlock has only one predecessor,
+  // ExitingBBPred, and it is ExitingBlock's immediate dominator.
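+  // (Collapsing BR to a single destination above removed the header's edge
+  // into ExitingBlock, when one existed, which is what leaves ExitingBBPred
+  // as the sole predecessor.)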
+  DT->changeImmediateDominator(ExitingBlock, ExitingBBPred);
+
+  BasicBlock *ExitBlock = ExitingBlock->getTerminator()->getSuccessor(1);
+  if (L->contains(ExitBlock))
+    ExitBlock = ExitingBlock->getTerminator()->getSuccessor(0);
+
+  // If ExitingBlock is a member of a loop basic block's DF list then
+  // replace ExitingBlock with the header and exit block in that DF list.
+  DominanceFrontier::iterator ExitingBlockDF = DF->find(ExitingBlock);
+  for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
+       I != E; ++I) {
+    BasicBlock *BB = *I;
+    if (BB == Header || BB == ExitingBlock)
+      continue;
+    DominanceFrontier::iterator BBDF = DF->find(BB);
+    DominanceFrontier::DomSetType::iterator DomSetI = BBDF->second.begin();
+    DominanceFrontier::DomSetType::iterator DomSetE = BBDF->second.end();
+    while (DomSetI != DomSetE) {
+      DominanceFrontier::DomSetType::iterator CurrentItr = DomSetI;
+      ++DomSetI;
+      BasicBlock *DFBB = *CurrentItr;
+      if (DFBB == ExitingBlock) {
+        BBDF->second.erase(DFBB);
+        for (DominanceFrontier::DomSetType::iterator
+               EBI = ExitingBlockDF->second.begin(),
+               EBE = ExitingBlockDF->second.end(); EBI != EBE; ++EBI)
+          BBDF->second.insert(*EBI);
+      }
+    }
+  }
+  NumRestrictBounds++;
+  return true;
+}
+
+/// removeBlocks - Remove basic block DeadBB and all blocks dominated by DeadBB.
+/// This routine is used to remove the split condition's dead branch, dominated
+/// by DeadBB. LiveBB dominates the split condition's other branch.
+void LoopIndexSplit::removeBlocks(BasicBlock *DeadBB, Loop *LP,
+                                  BasicBlock *LiveBB) {
+
+  // First update DeadBB's dominance frontier.
+  SmallVector<BasicBlock *, 8> FrontierBBs;
+  DominanceFrontier::iterator DeadBBDF = DF->find(DeadBB);
+  if (DeadBBDF != DF->end()) {
+    SmallVector<BasicBlock *, 8> PredBlocks;
+
+    DominanceFrontier::DomSetType DeadBBSet = DeadBBDF->second;
+    for (DominanceFrontier::DomSetType::iterator DeadBBSetI = DeadBBSet.begin(),
+         DeadBBSetE = DeadBBSet.end(); DeadBBSetI != DeadBBSetE; ++DeadBBSetI) {
+      BasicBlock *FrontierBB = *DeadBBSetI;
+      FrontierBBs.push_back(FrontierBB);
+
+      // Remove any PHI incoming edges from blocks dominated by DeadBB.
+      PredBlocks.clear();
+      for (pred_iterator PI = pred_begin(FrontierBB), PE = pred_end(FrontierBB);
+           PI != PE; ++PI) {
+        BasicBlock *P = *PI;
+        if (P == DeadBB || DT->dominates(DeadBB, P))
+          PredBlocks.push_back(P);
+      }
+
+      for (BasicBlock::iterator FBI = FrontierBB->begin(), FBE = FrontierBB->end();
+           FBI != FBE; ++FBI) {
+        if (PHINode *PN = dyn_cast<PHINode>(FBI)) {
+          for (SmallVector<BasicBlock *, 8>::iterator PI = PredBlocks.begin(),
+               PE = PredBlocks.end(); PI != PE; ++PI) {
+            BasicBlock *P = *PI;
+            PN->removeIncomingValue(P);
+          }
+        }
+        else
+          break;
+      }
+    }
+  }
+
+  // Now remove DeadBB and all nodes dominated by DeadBB in df order.
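+  // The depth-first walk below visits DeadBB and every block it dominates;
+  // replacing each block's label uses with undef first means the blocks can
+  // later be erased in any order, without dangling branch or PHI references.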
+  SmallVector<BasicBlock *, 32> WorkList;
+  DomTreeNode *DN = DT->getNode(DeadBB);
+  for (df_iterator<DomTreeNode*> DI = df_begin(DN),
+       E = df_end(DN); DI != E; ++DI) {
+    BasicBlock *BB = DI->getBlock();
+    WorkList.push_back(BB);
+    BB->replaceAllUsesWith(UndefValue::get(Type::LabelTy));
+  }
+
+  while (!WorkList.empty()) {
+    BasicBlock *BB = WorkList.back(); WorkList.pop_back();
+    LPM->deleteSimpleAnalysisValue(BB, LP);
+    for (BasicBlock::iterator BBI = BB->begin(), BBE = BB->end();
+         BBI != BBE; ) {
+      Instruction *I = BBI;
+      ++BBI;
+      I->replaceAllUsesWith(UndefValue::get(I->getType()));
+      LPM->deleteSimpleAnalysisValue(I, LP);
+      I->eraseFromParent();
+    }
+    DT->eraseNode(BB);
+    DF->removeBlock(BB);
+    LI->removeBlock(BB);
+    BB->eraseFromParent();
+  }
+
+  // Update the frontier BBs' dominator info.
+  while (!FrontierBBs.empty()) {
+    BasicBlock *FBB = FrontierBBs.back(); FrontierBBs.pop_back();
+    BasicBlock *NewDominator = FBB->getSinglePredecessor();
+    if (!NewDominator) {
+      pred_iterator PI = pred_begin(FBB), PE = pred_end(FBB);
+      NewDominator = *PI;
+      ++PI;
+      if (NewDominator != LiveBB) {
+        for (; PI != PE; ++PI) {
+          BasicBlock *P = *PI;
+          if (P == LiveBB) {
+            NewDominator = LiveBB;
+            break;
+          }
+          NewDominator = DT->findNearestCommonDominator(NewDominator, P);
+        }
+      }
+    }
+    assert (NewDominator && "Unable to fix dominator info.");
+    DT->changeImmediateDominator(FBB, NewDominator);
+    DF->changeImmediateDominator(FBB, NewDominator, DT);
+  }
+
+}
+
+// moveExitCondition - Move exit condition EC into split condition block CondBB.
+void LoopIndexSplit::moveExitCondition(BasicBlock *CondBB, BasicBlock *ActiveBB,
+                                       BasicBlock *ExitBB, ICmpInst *EC,
+                                       ICmpInst *SC, PHINode *IV,
+                                       Instruction *IVAdd, Loop *LP,
+                                       unsigned ExitValueNum) {
+
+  BasicBlock *ExitingBB = EC->getParent();
+  Instruction *CurrentBR = CondBB->getTerminator();
+
+  // Move the exit condition into the split condition block.
+  EC->moveBefore(CurrentBR);
+  EC->setOperand(ExitValueNum == 0 ? 1 : 0, IV);
+
+  // Move the exiting block's branch into the split condition block. Update
+  // its branch destination.
+  BranchInst *ExitingBR = cast<BranchInst>(ExitingBB->getTerminator());
+  ExitingBR->moveBefore(CurrentBR);
+  BasicBlock *OrigDestBB = NULL;
+  if (ExitingBR->getSuccessor(0) == ExitBB) {
+    OrigDestBB = ExitingBR->getSuccessor(1);
+    ExitingBR->setSuccessor(1, ActiveBB);
+  }
+  else {
+    OrigDestBB = ExitingBR->getSuccessor(0);
+    ExitingBR->setSuccessor(0, ActiveBB);
+  }
+
+  // Remove the split condition and the current split condition branch.
+  SC->eraseFromParent();
+  CurrentBR->eraseFromParent();
+
+  // Connect the exiting block to the original destination.
+  BranchInst::Create(OrigDestBB, ExitingBB);
+
+  // Update PHINodes.
+  updatePHINodes(ExitBB, ExitingBB, CondBB, IV, IVAdd, LP);
+
+  // Fix dominator info.
+  // ExitBB is now dominated by CondBB.
+  DT->changeImmediateDominator(ExitBB, CondBB);
+  DF->changeImmediateDominator(ExitBB, CondBB, DT);
+
+  // Blocks outside the loop may have been in the dominance frontier of blocks
+  // inside the condition; this is now impossible because the blocks inside the
+  // condition no longer dominate the exit. Remove the relevant blocks from
+  // the dominance frontiers.
+  for (Loop::block_iterator I = LP->block_begin(), E = LP->block_end();
+       I != E; ++I) {
+    if (*I == CondBB || !DT->dominates(CondBB, *I)) continue;
+    DominanceFrontier::iterator BBDF = DF->find(*I);
+    DominanceFrontier::DomSetType::iterator DomSetI = BBDF->second.begin();
+    DominanceFrontier::DomSetType::iterator DomSetE = BBDF->second.end();
+    while (DomSetI != DomSetE) {
+      DominanceFrontier::DomSetType::iterator CurrentItr = DomSetI;
+      ++DomSetI;
+      BasicBlock *DFBB = *CurrentItr;
+      if (!LP->contains(DFBB))
+        BBDF->second.erase(DFBB);
+    }
+  }
+}
+
+/// updatePHINodes - The CFG has been changed.
+/// Before
+///   - ExitBB's single predecessor was Latch
+///   - Latch's second successor was Header
+/// Now
+///   - ExitBB's single predecessor is Header
+///   - Latch's one and only successor is Header
+///
+/// Update ExitBB's PHINodes to reflect this change.
+void LoopIndexSplit::updatePHINodes(BasicBlock *ExitBB, BasicBlock *Latch,
+                                    BasicBlock *Header,
+                                    PHINode *IV, Instruction *IVIncrement,
+                                    Loop *LP) {
+
+  for (BasicBlock::iterator BI = ExitBB->begin(), BE = ExitBB->end();
+       BI != BE; ) {
+    PHINode *PN = dyn_cast<PHINode>(BI);
+    ++BI;
+    if (!PN)
+      break;
+
+    Value *V = PN->getIncomingValueForBlock(Latch);
+    if (PHINode *PHV = dyn_cast<PHINode>(V)) {
+      // PHV is in Latch. PHV has one use in the ExitBB PHINode and one use
+      // in Header, which is the new incoming value for PN.
+      Value *NewV = NULL;
+      for (Value::use_iterator UI = PHV->use_begin(), E = PHV->use_end();
+           UI != E; ++UI)
+        if (PHINode *U = dyn_cast<PHINode>(*UI))
+          if (LP->contains(U->getParent())) {
+            NewV = U;
+            break;
+          }
+
+      // Add an incoming value from the header only if PN has any use inside
+      // the loop.
+      if (NewV)
+        PN->addIncoming(NewV, Header);
+
+    } else if (Instruction *PHI = dyn_cast<Instruction>(V)) {
+      // If this instruction is IVIncrement then IV is the new incoming value
+      // from the header; otherwise this instruction must be the incoming
+      // value from the header because the loop is in LCSSA form.
+      if (PHI == IVIncrement)
+        PN->addIncoming(IV, Header);
+      else
+        PN->addIncoming(V, Header);
+    } else
+      // Otherwise this is an incoming value from the header because the loop
+      // is in LCSSA form.
+      PN->addIncoming(V, Header);
+
+    // Remove the incoming value from Latch.
+    PN->removeIncomingValue(Latch);
+  }
+}
+
+bool LoopIndexSplit::splitLoop() {
+  SplitCondition = NULL;
+  if (ExitCondition->getPredicate() == ICmpInst::ICMP_NE
+      || ExitCondition->getPredicate() == ICmpInst::ICMP_EQ)
+    return false;
+  BasicBlock *Header = L->getHeader();
+  BasicBlock *Latch = L->getLoopLatch();
+  BranchInst *SBR = NULL; // Split condition branch.
+  BranchInst *EBR = cast<BranchInst>(ExitCondition->getParent()->getTerminator());
+  // If the exiting block includes loop variant instructions then this
+  // loop may not be split safely.
+  BasicBlock *ExitingBlock = ExitCondition->getParent();
+  if (!cleanBlock(ExitingBlock)) return false;
+
+  for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
+       I != E; ++I) {
+    BranchInst *BR = dyn_cast<BranchInst>((*I)->getTerminator());
+    if (!BR || BR->isUnconditional()) continue;
+    ICmpInst *CI = dyn_cast<ICmpInst>(BR->getCondition());
+    if (!CI || CI == ExitCondition
+        || CI->getPredicate() == ICmpInst::ICMP_NE
+        || CI->getPredicate() == ICmpInst::ICMP_EQ)
+      continue;
+
+    // Unable to handle triangle loops at the moment.
+    // In a triangle loop, the split condition is in the header and one of
+    // the split destinations is the loop latch. If the split condition is EQ
+    // then such loops are already handled in processOneIterationLoop().
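+    // A triangle here means the header's conditional branch jumps directly
+    // to the latch on one side, roughly:
+    //   header -> body -> latch, plus a second edge header -> latch.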
+    if (Header == (*I)
+        && (Latch == BR->getSuccessor(0) || Latch == BR->getSuccessor(1)))
+      continue;
+
+    // If the block does not dominate the latch then this is not a diamond.
+    // Such a loop may not benefit from index split.
+    if (!DT->dominates((*I), Latch))
+      continue;
+
+    // If the split condition branch's successors do not have a single
+    // predecessor, SplitCondBlock, then it is not possible to remove the
+    // inactive branch.
+    if (!BR->getSuccessor(0)->getSinglePredecessor()
+        || !BR->getSuccessor(1)->getSinglePredecessor())
+      return false;
+
+    // If the merge point for BR is not the loop latch then skip this
+    // condition.
+    if (BR->getSuccessor(0) != Latch) {
+      DominanceFrontier::iterator DF0 = DF->find(BR->getSuccessor(0));
+      assert (DF0 != DF->end() && "Unable to find dominance frontier");
+      if (!DF0->second.count(Latch))
+        continue;
+    }
+
+    if (BR->getSuccessor(1) != Latch) {
+      DominanceFrontier::iterator DF1 = DF->find(BR->getSuccessor(1));
+      assert (DF1 != DF->end() && "Unable to find dominance frontier");
+      if (!DF1->second.count(Latch))
+        continue;
+    }
+    SplitCondition = CI;
+    SBR = BR;
+    break;
+  }
+
+  if (!SplitCondition)
+    return false;
+
+  // If the predicate signs do not match then skip.
+  if (ExitCondition->isSignedPredicate() != SplitCondition->isSignedPredicate())
+    return false;
+
+  unsigned EVOpNum = (ExitCondition->getOperand(1) == IVExitValue);
+  unsigned SVOpNum = IVBasedValues.count(SplitCondition->getOperand(0));
+  Value *SplitValue = SplitCondition->getOperand(SVOpNum);
+  if (!L->isLoopInvariant(SplitValue))
+    return false;
+  if (!IVBasedValues.count(SplitCondition->getOperand(!SVOpNum)))
+    return false;
+
+  // Normalize the loop conditions so that it is easier to calculate new loop
+  // bounds.
+  if (IVisGT(*ExitCondition) || IVisGE(*ExitCondition)) {
+    ExitCondition->setPredicate(ExitCondition->getInversePredicate());
+    BasicBlock *T = EBR->getSuccessor(0);
+    EBR->setSuccessor(0, EBR->getSuccessor(1));
+    EBR->setSuccessor(1, T);
+  }
+
+  if (IVisGT(*SplitCondition) || IVisGE(*SplitCondition)) {
+    SplitCondition->setPredicate(SplitCondition->getInversePredicate());
+    BasicBlock *T = SBR->getSuccessor(0);
+    SBR->setSuccessor(0, SBR->getSuccessor(1));
+    SBR->setSuccessor(1, T);
+  }
+
+  //[*] Calculate new loop bounds.
+  Value *AEV = SplitValue;
+  Value *BSV = SplitValue;
+  bool Sign = SplitCondition->isSignedPredicate();
+  Instruction *PHTerm = L->getLoopPreheader()->getTerminator();
+
+  if (IVisLT(*ExitCondition)) {
+    if (IVisLT(*SplitCondition)) {
+      /* Do nothing */
+    }
+    else if (IVisLE(*SplitCondition)) {
+      AEV = getPlusOne(SplitValue, Sign, PHTerm);
+      BSV = getPlusOne(SplitValue, Sign, PHTerm);
+    } else {
+      assert (0 && "Unexpected split condition!");
+    }
+  }
+  else if (IVisLE(*ExitCondition)) {
+    if (IVisLT(*SplitCondition)) {
+      AEV = getMinusOne(SplitValue, Sign, PHTerm);
+    }
+    else if (IVisLE(*SplitCondition)) {
+      BSV = getPlusOne(SplitValue, Sign, PHTerm);
+    } else {
+      assert (0 && "Unexpected split condition!");
+    }
+  } else {
+    assert (0 && "Unexpected exit condition!");
+  }
+  AEV = getMin(AEV, IVExitValue, Sign, PHTerm);
+  BSV = getMax(BSV, IVStartValue, Sign, PHTerm);
+
+  // [*] Clone the loop.
+  DenseMap<const Value *, Value *> ValueMap;
+  Loop *BLoop = CloneLoop(L, LPM, LI, ValueMap, this);
+  Loop *ALoop = L;
+
+  // [*] ALoop's exiting edge enters BLoop's header.
+  //     ALoop's original exit block becomes BLoop's exit block.
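+  // ValueMap, populated by CloneLoop above, maps each original value to its
+  // clone, so BLoop's copy of any ALoop instruction can be looked up directly.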
+  PHINode *B_IndVar = cast<PHINode>(ValueMap[IndVar]);
+  BasicBlock *A_ExitingBlock = ExitCondition->getParent();
+  BranchInst *A_ExitInsn =
+    dyn_cast<BranchInst>(A_ExitingBlock->getTerminator());
+  assert (A_ExitInsn && "Unable to find suitable loop exit branch");
+  BasicBlock *B_ExitBlock = A_ExitInsn->getSuccessor(1);
+  BasicBlock *B_Header = BLoop->getHeader();
+  if (ALoop->contains(B_ExitBlock)) {
+    B_ExitBlock = A_ExitInsn->getSuccessor(0);
+    A_ExitInsn->setSuccessor(0, B_Header);
+  } else
+    A_ExitInsn->setSuccessor(1, B_Header);
+
+  // [*] Update ALoop's exit value using the new exit value.
+  ExitCondition->setOperand(EVOpNum, AEV);
+
+  // [*] Update BLoop's header phi nodes. Remove incoming PHINode values from
+  //     the original loop's preheader. Add incoming PHINode values from
+  //     ALoop's exiting block. Update BLoop header's dominator info.
+
+  // Collect the inverse map of the header PHINodes.
+  DenseMap<Value *, Value *> InverseMap;
+  for (BasicBlock::iterator BI = ALoop->getHeader()->begin(),
+       BE = ALoop->getHeader()->end(); BI != BE; ++BI) {
+    if (PHINode *PN = dyn_cast<PHINode>(BI)) {
+      PHINode *PNClone = cast<PHINode>(ValueMap[PN]);
+      InverseMap[PNClone] = PN;
+    } else
+      break;
+  }
+
+  BasicBlock *A_Preheader = ALoop->getLoopPreheader();
+  for (BasicBlock::iterator BI = B_Header->begin(), BE = B_Header->end();
+       BI != BE; ++BI) {
+    if (PHINode *PN = dyn_cast<PHINode>(BI)) {
+      // Remove the incoming value from the original preheader.
+      PN->removeIncomingValue(A_Preheader);
+
+      // Add an incoming value from A_ExitingBlock.
+      if (PN == B_IndVar)
+        PN->addIncoming(BSV, A_ExitingBlock);
+      else {
+        PHINode *OrigPN = cast<PHINode>(InverseMap[PN]);
+        Value *V2 = NULL;
+        // If the loop header is also the loop exiting block then
+        // OrigPN is the incoming value for the B loop header.
+        if (A_ExitingBlock == ALoop->getHeader())
+          V2 = OrigPN;
+        else
+          V2 = OrigPN->getIncomingValueForBlock(A_ExitingBlock);
+        PN->addIncoming(V2, A_ExitingBlock);
+      }
+    } else
+      break;
+  }
+
+  DT->changeImmediateDominator(B_Header, A_ExitingBlock);
+  DF->changeImmediateDominator(B_Header, A_ExitingBlock, DT);
+
+  // [*] Update BLoop's exit block. Its new predecessor is BLoop's exiting
+  //     block. Remove incoming PHINode values from ALoop's exiting block.
+  //     Add new incoming values from BLoop's exiting block.
+  //     Update BLoop exit block's dominator info.
+  BasicBlock *B_ExitingBlock = cast<BasicBlock>(ValueMap[A_ExitingBlock]);
+  for (BasicBlock::iterator BI = B_ExitBlock->begin(), BE = B_ExitBlock->end();
+       BI != BE; ++BI) {
+    if (PHINode *PN = dyn_cast<PHINode>(BI)) {
+      PN->addIncoming(ValueMap[PN->getIncomingValueForBlock(A_ExitingBlock)],
+                      B_ExitingBlock);
+      PN->removeIncomingValue(A_ExitingBlock);
+    } else
+      break;
+  }
+
+  DT->changeImmediateDominator(B_ExitBlock, B_ExitingBlock);
+  DF->changeImmediateDominator(B_ExitBlock, B_ExitingBlock, DT);
+
+  //[*] Split ALoop's exit edge. This creates a new block which
+  //    serves two purposes. The first is to hold PHINode definitions
+  //    that preserve ALoop's LCSSA form. The second is to act
+  //    as a preheader for BLoop.
+  BasicBlock *A_ExitBlock = SplitEdge(A_ExitingBlock, B_Header, this);
+
+  //[*] Preserve ALoop's LCSSA form. Create new forwarding PHINodes
+  //    in A_ExitBlock to redefine outgoing PHI definitions from ALoop.
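+  // For illustration (names hypothetical): if %v leaves ALoop through
+  // A_ExitingBlock, a forwarding PHI "%v.le = phi [ %v, %A_ExitingBlock ]"
+  // is placed in A_ExitBlock and becomes the definition B_Header's PHIs see.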
+  for (BasicBlock::iterator BI = B_Header->begin(), BE = B_Header->end();
+       BI != BE; ++BI) {
+    if (PHINode *PN = dyn_cast<PHINode>(BI)) {
+      Value *V1 = PN->getIncomingValueForBlock(A_ExitBlock);
+      PHINode *newPHI = PHINode::Create(PN->getType(), PN->getName());
+      newPHI->addIncoming(V1, A_ExitingBlock);
+      A_ExitBlock->getInstList().push_front(newPHI);
+      PN->removeIncomingValue(A_ExitBlock);
+      PN->addIncoming(newPHI, A_ExitBlock);
+    } else
+      break;
+  }
+
+  //[*] Eliminate the split condition's inactive branch from ALoop.
+  BasicBlock *A_SplitCondBlock = SplitCondition->getParent();
+  BranchInst *A_BR = cast<BranchInst>(A_SplitCondBlock->getTerminator());
+  BasicBlock *A_InactiveBranch = NULL;
+  BasicBlock *A_ActiveBranch = NULL;
+  A_ActiveBranch = A_BR->getSuccessor(0);
+  A_InactiveBranch = A_BR->getSuccessor(1);
+  A_BR->setUnconditionalDest(A_ActiveBranch);
+  removeBlocks(A_InactiveBranch, L, A_ActiveBranch);
+
+  //[*] Eliminate the split condition's inactive branch from BLoop.
+  BasicBlock *B_SplitCondBlock = cast<BasicBlock>(ValueMap[A_SplitCondBlock]);
+  BranchInst *B_BR = cast<BranchInst>(B_SplitCondBlock->getTerminator());
+  BasicBlock *B_InactiveBranch = NULL;
+  BasicBlock *B_ActiveBranch = NULL;
+  B_ActiveBranch = B_BR->getSuccessor(1);
+  B_InactiveBranch = B_BR->getSuccessor(0);
+  B_BR->setUnconditionalDest(B_ActiveBranch);
+  removeBlocks(B_InactiveBranch, BLoop, B_ActiveBranch);
+
+  BasicBlock *A_Header = ALoop->getHeader();
+  if (A_ExitingBlock == A_Header)
+    return true;
+
+  //[*] Move the exit condition into the split condition block to avoid
+  //    executing a dead loop iteration.
+  ICmpInst *B_ExitCondition = cast<ICmpInst>(ValueMap[ExitCondition]);
+  Instruction *B_IndVarIncrement = cast<Instruction>(ValueMap[IVIncrement]);
+  ICmpInst *B_SplitCondition = cast<ICmpInst>(ValueMap[SplitCondition]);
+
+  moveExitCondition(A_SplitCondBlock, A_ActiveBranch, A_ExitBlock, ExitCondition,
+                    cast<ICmpInst>(SplitCondition), IndVar, IVIncrement,
+                    ALoop, EVOpNum);
+
+  moveExitCondition(B_SplitCondBlock, B_ActiveBranch,
+                    B_ExitBlock, B_ExitCondition,
+                    B_SplitCondition, B_IndVar, B_IndVarIncrement,
+                    BLoop, EVOpNum);
+
+  NumIndexSplit++;
+  return true;
+}
+
+/// cleanBlock - A block is considered clean if all non-terminator instructions
+/// are either PHINodes or IV based values.
+bool LoopIndexSplit::cleanBlock(BasicBlock *BB) {
+  Instruction *Terminator = BB->getTerminator();
+  for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
+       BI != BE; ++BI) {
+    Instruction *I = BI;
+
+    if (isa<PHINode>(I) || I == Terminator || I == ExitCondition
+        || I == SplitCondition || IVBasedValues.count(I)
+        || isa<DbgInfoIntrinsic>(I))
+      continue;
+
+    if (I->mayHaveSideEffects())
+      return false;
+
+    // If I is used only inside this block then it is OK.
+    bool usedOutsideBB = false;
+    for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
+         UI != UE; ++UI) {
+      Instruction *U = cast<Instruction>(UI);
+      if (U->getParent() != BB)
+        usedOutsideBB = true;
+    }
+    if (!usedOutsideBB)
+      continue;
+
+    // Otherwise we have an instruction that may not allow loop splitting.
+    return false;
+  }
+  return true;
+}
+
+/// IVisLT - If Op is comparing an IV based value with a loop invariant and
+/// the IV based value is less than the loop invariant then return the loop
+/// invariant. Otherwise return NULL.
+Value * LoopIndexSplit::IVisLT(ICmpInst &Op) {
+  ICmpInst::Predicate P = Op.getPredicate();
+  if ((P == ICmpInst::ICMP_SLT || P == ICmpInst::ICMP_ULT)
+      && IVBasedValues.count(Op.getOperand(0))
+      && L->isLoopInvariant(Op.getOperand(1)))
+    return Op.getOperand(1);
+
+  if ((P == ICmpInst::ICMP_SGT || P == ICmpInst::ICMP_UGT)
+      && IVBasedValues.count(Op.getOperand(1))
+      && L->isLoopInvariant(Op.getOperand(0)))
+    return Op.getOperand(0);
+
+  return NULL;
+}
+
+/// IVisLE - If Op is comparing an IV based value with a loop invariant and
+/// the IV based value is less than or equal to the loop invariant then
+/// return the loop invariant. Otherwise return NULL.
+Value * LoopIndexSplit::IVisLE(ICmpInst &Op) {
+  ICmpInst::Predicate P = Op.getPredicate();
+  if ((P == ICmpInst::ICMP_SLE || P == ICmpInst::ICMP_ULE)
+      && IVBasedValues.count(Op.getOperand(0))
+      && L->isLoopInvariant(Op.getOperand(1)))
+    return Op.getOperand(1);
+
+  if ((P == ICmpInst::ICMP_SGE || P == ICmpInst::ICMP_UGE)
+      && IVBasedValues.count(Op.getOperand(1))
+      && L->isLoopInvariant(Op.getOperand(0)))
+    return Op.getOperand(0);
+
+  return NULL;
+}
+
+/// IVisGT - If Op is comparing an IV based value with a loop invariant and
+/// the IV based value is greater than the loop invariant then return the loop
+/// invariant. Otherwise return NULL.
+Value * LoopIndexSplit::IVisGT(ICmpInst &Op) {
+  ICmpInst::Predicate P = Op.getPredicate();
+  if ((P == ICmpInst::ICMP_SGT || P == ICmpInst::ICMP_UGT)
+      && IVBasedValues.count(Op.getOperand(0))
+      && L->isLoopInvariant(Op.getOperand(1)))
+    return Op.getOperand(1);
+
+  if ((P == ICmpInst::ICMP_SLT || P == ICmpInst::ICMP_ULT)
+      && IVBasedValues.count(Op.getOperand(1))
+      && L->isLoopInvariant(Op.getOperand(0)))
+    return Op.getOperand(0);
+
+  return NULL;
+}
+
+/// IVisGE - If Op is comparing an IV based value with a loop invariant and
+/// the IV based value is greater than or equal to the loop invariant then
+/// return the loop invariant. Otherwise return NULL.
+Value * LoopIndexSplit::IVisGE(ICmpInst &Op) {
+  ICmpInst::Predicate P = Op.getPredicate();
+  if ((P == ICmpInst::ICMP_SGE || P == ICmpInst::ICMP_UGE)
+      && IVBasedValues.count(Op.getOperand(0))
+      && L->isLoopInvariant(Op.getOperand(1)))
+    return Op.getOperand(1);
+
+  if ((P == ICmpInst::ICMP_SLE || P == ICmpInst::ICMP_ULE)
+      && IVBasedValues.count(Op.getOperand(1))
+      && L->isLoopInvariant(Op.getOperand(0)))
+    return Op.getOperand(0);
+
+  return NULL;
+}
+
diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp
new file mode 100644
index 0000000..a088230
--- /dev/null
+++ b/lib/Transforms/Scalar/LoopRotation.cpp
@@ -0,0 +1,572 @@
+//===- LoopRotation.cpp - Loop Rotation Pass ------------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Loop Rotation Pass.
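+//
+// Conceptually (a sketch, independent of any particular input), rotation
+// turns
+//   while (cond) { body; }
+// into
+//   if (cond) { do { body; } while (cond); }
+// by duplicating the header's test into the pre-header so the loop's exit
+// test ends up at the bottom of the rotated loop.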
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "loop-rotate"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Function.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/Dominators.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/SmallVector.h"
+using namespace llvm;
+
+#define MAX_HEADER_SIZE 16
+
+STATISTIC(NumRotated, "Number of loops rotated");
+namespace {
+
+  class VISIBILITY_HIDDEN RenameData {
+  public:
+    RenameData(Instruction *O, Value *P, Instruction *H)
+      : Original(O), PreHeader(P), Header(H) { }
+  public:
+    Instruction *Original; // Original instruction
+    Value *PreHeader;      // Original pre-header replacement
+    Instruction *Header;   // New header replacement
+  };
+
+  class VISIBILITY_HIDDEN LoopRotate : public LoopPass {
+
+  public:
+    static char ID; // Pass ID, replacement for typeid
+    LoopRotate() : LoopPass(&ID) {}
+
+    // Rotate Loop L as many times as possible. Return true if
+    // the loop is rotated at least once.
+    bool runOnLoop(Loop *L, LPPassManager &LPM);
+
+    // LCSSA form makes instruction renaming easier.
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.addRequiredID(LoopSimplifyID);
+      AU.addPreservedID(LoopSimplifyID);
+      AU.addRequiredID(LCSSAID);
+      AU.addPreservedID(LCSSAID);
+      AU.addPreserved<ScalarEvolution>();
+      AU.addPreserved<LoopInfo>();
+      AU.addPreserved<DominatorTree>();
+      AU.addPreserved<DominanceFrontier>();
+    }
+
+    // Helper functions
+
+    /// Do the actual work.
+    bool rotateLoop(Loop *L, LPPassManager &LPM);
+
+    /// Initialize local data.
+    void initialize();
+
+    /// Make sure all Exit block PHINodes have the required incoming values.
+    /// If an incoming value is constant or defined outside the loop then
+    /// the PHINode may not have an entry for the original pre-header.
+    void updateExitBlock();
+
+    /// Return true if this instruction is used outside the original header.
+    bool usedOutsideOriginalHeader(Instruction *In);
+
+    /// Find replacement information for an instruction. Return NULL if it is
+    /// not available.
+    const RenameData *findReplacementData(Instruction *I);
+
+    /// After loop rotation, the loop pre-header has multiple successors.
+    /// Insert one forwarding basic block to ensure that the loop pre-header
+    /// has only one successor.
+    void preserveCanonicalLoopForm(LPPassManager &LPM);
+
+  private:
+
+    Loop *L;
+    BasicBlock *OrigHeader;
+    BasicBlock *OrigPreHeader;
+    BasicBlock *OrigLatch;
+    BasicBlock *NewHeader;
+    BasicBlock *Exit;
+    LPPassManager *LPM_Ptr;
+    SmallVector<RenameData, MAX_HEADER_SIZE> LoopHeaderInfo;
+  };
+}
+
+char LoopRotate::ID = 0;
+static RegisterPass<LoopRotate> X("loop-rotate", "Rotate Loops");
+
+Pass *llvm::createLoopRotatePass() { return new LoopRotate(); }
+
+/// Rotate Loop L as many times as possible. Return true if
+/// the loop is rotated at least once.
+bool LoopRotate::runOnLoop(Loop *Lp, LPPassManager &LPM) {
+
+  bool RotatedOneLoop = false;
+  initialize();
+  LPM_Ptr = &LPM;
+
+  // One loop can be rotated multiple times.
+  while (rotateLoop(Lp,LPM)) {
+    RotatedOneLoop = true;
+    initialize();
+  }
+
+  return RotatedOneLoop;
+}
+
+/// Rotate loop LP. Return true if the loop is rotated.
+bool LoopRotate::rotateLoop(Loop *Lp, LPPassManager &LPM) {
+  L = Lp;
+
+  OrigHeader = L->getHeader();
+  OrigPreHeader = L->getLoopPreheader();
+  OrigLatch = L->getLoopLatch();
+
+  // If the loop has only one block then there is not much to rotate.
+  if (L->getBlocks().size() == 1)
+    return false;
+
+  assert(OrigHeader && OrigLatch && OrigPreHeader &&
+         "Loop is not in canonical form");
+
+  // If the loop header is not one of the loop exiting blocks then
+  // either this loop is already rotated or it is not
+  // suitable for loop rotation transformations.
+  if (!L->isLoopExit(OrigHeader))
+    return false;
+
+  BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator());
+  if (!BI)
+    return false;
+  assert(BI->isConditional() && "Branch Instruction is not conditional");
+
+  // Updating PHInodes in loops with multiple exits adds complexity.
+  // Keep it simple, and restrict loop rotation to loops with one exit only.
+  // In future, lift this restriction and support multiple exits if
+  // required.
+  SmallVector<BasicBlock*, 8> ExitBlocks;
+  L->getExitBlocks(ExitBlocks);
+  if (ExitBlocks.size() > 1)
+    return false;
+
+  // Check the size of the original header and reject
+  // the loop if it is very big.
+  unsigned Size = 0;
+
+  // FIXME: Use a common API to estimate size.
+  for (BasicBlock::const_iterator OI = OrigHeader->begin(),
+       OE = OrigHeader->end(); OI != OE; ++OI) {
+    if (isa<PHINode>(OI))
+      continue;  // PHI nodes don't count.
+    if (isa<DbgInfoIntrinsic>(OI))
+      continue;  // Debug intrinsics don't count as size.
+    Size++;
+  }
+
+  if (Size > MAX_HEADER_SIZE)
+    return false;
+
+  // Now, this loop is suitable for rotation.
+
+  // Find the new loop header. NewHeader is the Header's one and only successor
+  // that is inside the loop. The Header's other successor is outside the
+  // loop. Otherwise the loop is not suitable for rotation.
+  Exit = BI->getSuccessor(0);
+  NewHeader = BI->getSuccessor(1);
+  if (L->contains(Exit))
+    std::swap(Exit, NewHeader);
+  assert(NewHeader && "Unable to determine new loop header");
+  assert(L->contains(NewHeader) && !L->contains(Exit) &&
+         "Unable to determine loop header and exit blocks");
+
+  // This code assumes that the new header has exactly one predecessor.
+  // Remove any single-entry PHI nodes in it.
+  assert(NewHeader->getSinglePredecessor() &&
+         "New header doesn't have one pred!");
+  FoldSingleEntryPHINodes(NewHeader);
+
+  // Copy PHI nodes and other instructions from the original header
+  // into the original pre-header. Unlike the original header, the original
+  // pre-header is not a member of the loop.
+  //
+  // The new loop header is the one and only successor of the original header
+  // that is inside the loop. All other original header successors are outside
+  // the loop. Copy PHI nodes from the original header into the new loop
+  // header. Add a second incoming value, from the original loop pre-header,
+  // into these phi nodes. If a value defined in the original header is used
+  // outside the original header then the new loop header will need new phi
+  // nodes with two incoming values: one definition from the original header
+  // and a second definition from the original loop pre-header.
+
+  // Remove the terminator from the original pre-header. The original
+  // pre-header will receive a clone of the original header's terminator as a
+  // new terminator.
+  OrigPreHeader->getInstList().pop_back();
+  BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end();
+  PHINode *PN = 0;
+  for (; (PN = dyn_cast<PHINode>(I)); ++I) {
+    // PHI nodes are not copied into the original pre-header. Instead their
+    // values are directly propagated.
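+    // E.g. for "%iv = phi i32 [ 0, %preheader ], [ %iv.next, %latch ]" the
+    // pre-header copy simply uses the value 0 rather than a cloned PHI.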
+    Value *NPV = PN->getIncomingValueForBlock(OrigPreHeader);
+
+    // Create a new PHI node with two incoming values for NewHeader.
+    // One incoming value is from OrigLatch (through OrigHeader) and the
+    // second incoming value is from the original pre-header.
+    PHINode *NH = PHINode::Create(PN->getType(), PN->getName(),
+                                  NewHeader->begin());
+    NH->addIncoming(PN->getIncomingValueForBlock(OrigLatch), OrigHeader);
+    NH->addIncoming(NPV, OrigPreHeader);
+
+    // "In" can be replaced by NH at various places.
+    LoopHeaderInfo.push_back(RenameData(PN, NPV, NH));
+  }
+
+  // Now, handle non-phi instructions.
+  for (; I != E; ++I) {
+    Instruction *In = I;
+    assert(!isa<PHINode>(In) && "PHINode is not expected here");
+
+    // This is not a PHI instruction. Insert its clone into the original
+    // pre-header. If this instruction is using a value from the same basic
+    // block then update it to use the value from the cloned instruction.
+    Instruction *C = In->clone();
+    C->setName(In->getName());
+    OrigPreHeader->getInstList().push_back(C);
+
+    for (unsigned opi = 0, e = In->getNumOperands(); opi != e; ++opi) {
+      Instruction *OpInsn = dyn_cast<Instruction>(In->getOperand(opi));
+      if (!OpInsn) continue;  // Ignore non-instruction values.
+      if (const RenameData *D = findReplacementData(OpInsn))
+        C->setOperand(opi, D->PreHeader);
+    }
+
+    // If this instruction is used outside this basic block then
+    // create a new PHINode for this instruction.
+    Instruction *NewHeaderReplacement = NULL;
+    if (usedOutsideOriginalHeader(In)) {
+      PHINode *PN = PHINode::Create(In->getType(), In->getName(),
+                                    NewHeader->begin());
+      PN->addIncoming(In, OrigHeader);
+      PN->addIncoming(C, OrigPreHeader);
+      NewHeaderReplacement = PN;
+    }
+    LoopHeaderInfo.push_back(RenameData(In, C, NewHeaderReplacement));
+  }
+
+  // Rename uses of original header instructions to reflect their new
+  // definitions (either from the original pre-header node or from the newly
+  // created new header PHINodes).
+  //
+  // Original header instructions are used in
+  // 1) the original header:
+  //
+  //    If the instruction is used in a non-phi instruction then it is using
+  //    the definition from the original header itself. Do not replace this
+  //    use with the definition from the new header or the original
+  //    pre-header.
+  //
+  //    If the instruction is used in a phi node then it is an incoming
+  //    value. Rename its use to reflect the new definition from the new
+  //    pre-header or the new header.
+  //
+  // 2) inside the loop but not in the original header:
+  //
+  //    Replace this use to reflect the definition from the new header.
+  for (unsigned LHI = 0, LHI_E = LoopHeaderInfo.size(); LHI != LHI_E; ++LHI) {
+    const RenameData &ILoopHeaderInfo = LoopHeaderInfo[LHI];
+
+    if (!ILoopHeaderInfo.Header)
+      continue;
+
+    Instruction *OldPhi = ILoopHeaderInfo.Original;
+    Instruction *NewPhi = ILoopHeaderInfo.Header;
+
+    // Before replacing uses, collect them first, so that the iterator is
+    // not invalidated.
+    SmallVector<Instruction *, 16> AllUses;
+    for (Value::use_iterator UI = OldPhi->use_begin(), UE = OldPhi->use_end();
+         UI != UE; ++UI)
+      AllUses.push_back(cast<Instruction>(UI));
+
+    for (SmallVector<Instruction *, 16>::iterator UI = AllUses.begin(),
+         UE = AllUses.end(); UI != UE; ++UI) {
+      Instruction *U = *UI;
+      BasicBlock *Parent = U->getParent();
+
+      // Used inside the original header.
+      if (Parent == OrigHeader) {
+        // Do not rename uses inside original header non-phi instructions.
+        PHINode *PU = dyn_cast<PHINode>(U);
+        if (!PU)
+          continue;
+
+        // Do not rename uses inside original header phi nodes, if the
+        // incoming value is for the new header.
+        if (PU->getBasicBlockIndex(NewHeader) != -1
+            && PU->getIncomingValueForBlock(NewHeader) == U)
+          continue;
+
+        U->replaceUsesOfWith(OldPhi, NewPhi);
+        continue;
+      }
+
+      // Used inside the loop, but not in the original header.
+      if (L->contains(U->getParent())) {
+        if (U != NewPhi)
+          U->replaceUsesOfWith(OldPhi, NewPhi);
+        continue;
+      }
+
+      // Used inside the Exit block. Since we are in LCSSA form, U must be a
+      // PHINode.
+      if (U->getParent() == Exit) {
+        assert(isa<PHINode>(U) && "Use in Exit Block that is not PHINode");
+
+        PHINode *UPhi = cast<PHINode>(U);
+        // UPhi already has one incoming argument from the original header.
+        // Add a second incoming argument from the new pre-header.
+        UPhi->addIncoming(ILoopHeaderInfo.PreHeader, OrigPreHeader);
+      } else {
+        // Used outside the Exit block. Create a new PHI node in the exit
+        // block to receive the value from the new header and pre-header.
+        PHINode *PN = PHINode::Create(U->getType(), U->getName(),
+                                      Exit->begin());
+        PN->addIncoming(ILoopHeaderInfo.PreHeader, OrigPreHeader);
+        PN->addIncoming(OldPhi, OrigHeader);
+        U->replaceUsesOfWith(OldPhi, PN);
+      }
+    }
+  }
+
+  // Make sure all Exit block PHI nodes have the required incoming values.
+  updateExitBlock();
+
+  // Update the CFG.
+
+  // Remove the incoming branch from the loop pre-header to the original
+  // header. Now the original header is inside the loop.
+  for (BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end();
+       I != E; ++I)
+    if (PHINode *PN = dyn_cast<PHINode>(I))
+      PN->removeIncomingValue(OrigPreHeader);
+
+  // Make NewHeader the new header for the loop.
+  L->moveToHeader(NewHeader);
+
+  preserveCanonicalLoopForm(LPM);
+
+  NumRotated++;
+  return true;
+}
+
+/// Make sure all Exit block PHI nodes have the required incoming values.
+/// If the incoming value is a constant or defined outside the loop, then the
+/// PHI node may not have an entry for the original pre-header.
+void LoopRotate::updateExitBlock() {
+
+  for (BasicBlock::iterator I = Exit->begin(), E = Exit->end();
+       I != E; ++I) {
+
+    PHINode *PN = dyn_cast<PHINode>(I);
+    if (!PN)
+      break;
+
+    // There is already one incoming value from the original pre-header block.
+    if (PN->getBasicBlockIndex(OrigPreHeader) != -1)
+      continue;
+
+    const RenameData *ILoopHeaderInfo;
+    Value *V = PN->getIncomingValueForBlock(OrigHeader);
+    if (isa<Instruction>(V) &&
+        (ILoopHeaderInfo = findReplacementData(cast<Instruction>(V)))) {
+      assert(ILoopHeaderInfo->PreHeader && "Missing New Preheader Instruction");
+      PN->addIncoming(ILoopHeaderInfo->PreHeader, OrigPreHeader);
+    } else {
+      PN->addIncoming(V, OrigPreHeader);
+    }
+  }
+}
+
+/// Initialize local data.
+void LoopRotate::initialize() {
+  L = NULL;
+  OrigHeader = NULL;
+  OrigPreHeader = NULL;
+  NewHeader = NULL;
+  Exit = NULL;
+
+  LoopHeaderInfo.clear();
+}
+
+/// Return true if this instruction is used by any instructions in the loop
+/// that aren't in the original header.
+bool LoopRotate::usedOutsideOriginalHeader(Instruction *In) {
+  for (Value::use_iterator UI = In->use_begin(), UE = In->use_end();
+       UI != UE; ++UI) {
+    BasicBlock *UserBB = cast<Instruction>(UI)->getParent();
+    if (UserBB != OrigHeader && L->contains(UserBB))
+      return true;
+  }
+
+  return false;
+}
+
+/// Find replacement information for an instruction. Return NULL if it is
+/// not available.
+const RenameData *LoopRotate::findReplacementData(Instruction *In) {
+
+  // Since LoopHeaderInfo is small, a linear walk is OK.
+  for (unsigned LHI = 0, LHI_E = LoopHeaderInfo.size(); LHI != LHI_E; ++LHI) {
+    const RenameData &ILoopHeaderInfo = LoopHeaderInfo[LHI];
+    if (ILoopHeaderInfo.Original == In)
+      return &ILoopHeaderInfo;
+  }
+  return NULL;
+}
+
+/// After loop rotation, the loop pre-header has multiple successors.
+/// Insert one forwarding basic block to ensure that the loop pre-header
+/// has only one successor.
+void LoopRotate::preserveCanonicalLoopForm(LPPassManager &LPM) {
+
+  // Right now the original pre-header has two successors, the new header and
+  // the exit block. Insert a new block between the original pre-header and
+  // the new header such that the loop's new pre-header has only one successor.
+  BasicBlock *NewPreHeader = BasicBlock::Create("bb.nph",
+                                                OrigHeader->getParent(),
+                                                NewHeader);
+  LoopInfo &LI = LPM.getAnalysis<LoopInfo>();
+  if (Loop *PL = LI.getLoopFor(OrigPreHeader))
+    PL->addBasicBlockToLoop(NewPreHeader, LI.getBase());
+  BranchInst::Create(NewHeader, NewPreHeader);
+
+  BranchInst *OrigPH_BI = cast<BranchInst>(OrigPreHeader->getTerminator());
+  if (OrigPH_BI->getSuccessor(0) == NewHeader)
+    OrigPH_BI->setSuccessor(0, NewPreHeader);
+  else {
+    assert(OrigPH_BI->getSuccessor(1) == NewHeader &&
+           "Unexpected original pre-header terminator");
+    OrigPH_BI->setSuccessor(1, NewPreHeader);
+  }
+
+  for (BasicBlock::iterator I = NewHeader->begin(), E = NewHeader->end();
+       I != E; ++I) {
+    PHINode *PN = dyn_cast<PHINode>(I);
+    if (!PN)
+      break;
+
+    int index = PN->getBasicBlockIndex(OrigPreHeader);
+    assert(index != -1 && "Expected incoming value from Original PreHeader");
+    PN->setIncomingBlock(index, NewPreHeader);
+    assert(PN->getBasicBlockIndex(OrigPreHeader) == -1 &&
+           "Expected only one incoming value from Original PreHeader");
+  }
+
+  if (DominatorTree *DT = getAnalysisIfAvailable<DominatorTree>()) {
+    DT->addNewBlock(NewPreHeader, OrigPreHeader);
+    DT->changeImmediateDominator(L->getHeader(), NewPreHeader);
+    DT->changeImmediateDominator(Exit, OrigPreHeader);
+    for (Loop::block_iterator BI = L->block_begin(), BE = L->block_end();
+         BI != BE; ++BI) {
+      BasicBlock *B = *BI;
+      if (L->getHeader() != B) {
+        DomTreeNode *Node = DT->getNode(B);
+        if (Node && Node->getBlock() == OrigHeader)
+          DT->changeImmediateDominator(*BI, L->getHeader());
+      }
+    }
+    DT->changeImmediateDominator(OrigHeader, OrigLatch);
+  }
+
+  if (DominanceFrontier *DF = getAnalysisIfAvailable<DominanceFrontier>()) {
+    // The new pre-header's dominance frontier is the Exit block.
+    DominanceFrontier::DomSetType NewPHSet;
+    NewPHSet.insert(Exit);
+    DF->addBasicBlock(NewPreHeader, NewPHSet);
+
+    // The new header's dominance frontier now includes itself and the Exit
+    // block.
+    DominanceFrontier::iterator HeadI = DF->find(L->getHeader());
+    if (HeadI != DF->end()) {
+      DominanceFrontier::DomSetType &HeaderSet = HeadI->second;
+      HeaderSet.clear();
+      HeaderSet.insert(L->getHeader());
+      HeaderSet.insert(Exit);
+    } else {
+      DominanceFrontier::DomSetType HeaderSet;
+      HeaderSet.insert(L->getHeader());
+      HeaderSet.insert(Exit);
+      DF->addBasicBlock(L->getHeader(), HeaderSet);
+    }
+
+    // The original header (the new loop latch)'s dominance frontier is Exit.
+    DominanceFrontier::iterator LatchI = DF->find(L->getLoopLatch());
+    if (LatchI != DF->end()) {
+      DominanceFrontier::DomSetType &LatchSet = LatchI->second;
+      LatchSet.clear();
+      LatchSet.insert(Exit);
+    } else {
+      DominanceFrontier::DomSetType LatchSet;
+      LatchSet.insert(Exit);
+      DF->addBasicBlock(L->getLoopLatch(), LatchSet);
+    }
+
+    // If a loop block dominates the new loop latch then its frontier is
+    // the new header and Exit.
+    BasicBlock *NewLatch = L->getLoopLatch();
+    DominatorTree *DT = getAnalysisIfAvailable<DominatorTree>();
+    for (Loop::block_iterator BI = L->block_begin(), BE = L->block_end();
+         BI != BE; ++BI) {
+      BasicBlock *B = *BI;
+      if (DT->dominates(B, NewLatch)) {
+        DominanceFrontier::iterator BDFI = DF->find(B);
+        if (BDFI != DF->end()) {
+          DominanceFrontier::DomSetType &BSet = BDFI->second;
+          BSet.clear();
+          BSet.insert(L->getHeader());
+          BSet.insert(Exit);
+        } else {
+          DominanceFrontier::DomSetType BSet;
+          BSet.insert(L->getHeader());
+          BSet.insert(Exit);
+          DF->addBasicBlock(B, BSet);
+        }
+      }
+    }
+  }
+
+  // Preserve canonical loop form, which means the Exit block should have
+  // only one predecessor.
+  BasicBlock *NExit = SplitEdge(L->getLoopLatch(), Exit, this);
+
+  // Preserve LCSSA.
+  BasicBlock::iterator I = Exit->begin(), E = Exit->end();
+  PHINode *PN = NULL;
+  for (; (PN = dyn_cast<PHINode>(I)); ++I) {
+    unsigned N = PN->getNumIncomingValues();
+    for (unsigned index = 0; index < N; ++index)
+      if (PN->getIncomingBlock(index) == NExit) {
+        PHINode *NewPN = PHINode::Create(PN->getType(), PN->getName(),
+                                         NExit->begin());
+        NewPN->addIncoming(PN->getIncomingValue(index), L->getLoopLatch());
+        PN->setIncomingValue(index, NewPN);
+        PN->setIncomingBlock(index, NExit);
+        break;
+      }
+  }
+
+  assert(NewHeader && L->getHeader() == NewHeader &&
+         "Invalid loop header after loop rotation");
+  assert(NewPreHeader && L->getLoopPreheader() == NewPreHeader &&
+         "Invalid loop preheader after loop rotation");
+  assert(L->getLoopLatch() &&
+         "Invalid loop latch after loop rotation");
+}
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
new file mode 100644
index 0000000..92270b5
--- /dev/null
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -0,0 +1,2605 @@
+//===- LoopStrengthReduce.cpp - Strength Reduce IVs in Loops --------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation analyzes and transforms the induction variables (and
+// computations derived from them) into forms suitable for efficient execution
+// on the target.
+//
+// This pass performs a strength reduction on array references inside loops
+// that have as one or more of their components the loop induction variable;
+// it rewrites expressions to take advantage of scaled-index addressing modes
+// available on the target, and it performs a variety of other optimizations
+// related to loop induction variables.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "loop-reduce"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Constants.h"
+#include "llvm/Instructions.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Type.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Analysis/Dominators.h"
+#include "llvm/Analysis/IVUsers.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Transforms/Utils/AddrModeMatcher.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/CFG.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ValueHandle.h"
+#include "llvm/Target/TargetLowering.h"
+#include <algorithm>
+using namespace llvm;
+
+STATISTIC(NumReduced,    "Number of IV uses strength reduced");
+STATISTIC(NumInserted,   "Number of PHIs inserted");
+STATISTIC(NumVariable,   "Number of PHIs with variable strides");
+STATISTIC(NumEliminated, "Number of strides eliminated");
+STATISTIC(NumShadow,     "Number of Shadow IVs optimized");
+STATISTIC(NumImmSunk,    "Number of common expr immediates sunk into uses");
+STATISTIC(NumLoopCond,   "Number of loop terminating conds optimized");
+
+static cl::opt<bool> EnableFullLSRMode("enable-full-lsr",
+                                       cl::init(false),
+                                       cl::Hidden);
+
+namespace {
+
+  struct BasedUser;
+
+  /// IVExpr - This structure keeps track of one IV expression inserted during
+  /// StrengthReduceStridedIVUsers. It contains the stride, the common base,
+  /// as well as the PHI node and increment value created for rewrite.
+  struct VISIBILITY_HIDDEN IVExpr {
+    SCEVHandle Stride;
+    SCEVHandle Base;
+    PHINode    *PHI;
+
+    IVExpr(const SCEVHandle &stride, const SCEVHandle &base, PHINode *phi)
+      : Stride(stride), Base(base), PHI(phi) {}
+  };
+
+  /// IVsOfOneStride - This structure keeps track of all IV expressions
+  /// inserted during StrengthReduceStridedIVUsers for a particular stride of
+  /// the IV.
+  struct VISIBILITY_HIDDEN IVsOfOneStride {
+    std::vector<IVExpr> IVs;
+
+    void addIV(const SCEVHandle &Stride, const SCEVHandle &Base, PHINode *PHI) {
+      IVs.push_back(IVExpr(Stride, Base, PHI));
+    }
+  };
+
+  class VISIBILITY_HIDDEN LoopStrengthReduce : public LoopPass {
+    IVUsers *IU;
+    LoopInfo *LI;
+    DominatorTree *DT;
+    ScalarEvolution *SE;
+    bool Changed;
+
+    /// IVsByStride - Keep track of all IVs that have been inserted for a
+    /// particular stride.
+    std::map<SCEVHandle, IVsOfOneStride> IVsByStride;
+
+    /// StrideNoReuse - Keep track of all the strides whose IVs cannot be
+    /// reused (nor should they be rewritten to reuse other strides).
+    SmallSet<SCEVHandle, 4> StrideNoReuse;
+
+    /// DeadInsts - Keep track of instructions we may have made dead, so that
+    /// we can remove them after we are done working.
+    SmallVector<WeakVH, 16> DeadInsts;
+
+    /// TLI - Keep a pointer to a TargetLowering to consult for determining
+    /// transformation profitability.
+    const TargetLowering *TLI;
+
+  public:
+    static char ID; // Pass ID, replacement for typeid
+    explicit LoopStrengthReduce(const TargetLowering *tli = NULL) :
+      LoopPass(&ID), TLI(tli) {
+    }
+
+    bool runOnLoop(Loop *L, LPPassManager &LPM);
+
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      // We split critical edges, so we change the CFG. However, we do update
+      // many analyses if they are around.
+      AU.addPreservedID(LoopSimplifyID);
+      AU.addPreserved<LoopInfo>();
+      AU.addPreserved<DominanceFrontier>();
+      AU.addPreserved<DominatorTree>();
+
+      AU.addRequiredID(LoopSimplifyID);
+      AU.addRequired<LoopInfo>();
+      AU.addRequired<DominatorTree>();
+      AU.addRequired<ScalarEvolution>();
+      AU.addPreserved<ScalarEvolution>();
+      AU.addRequired<IVUsers>();
+      AU.addPreserved<IVUsers>();
+    }
+
+  private:
+    ICmpInst *ChangeCompareStride(Loop *L, ICmpInst *Cond,
+                                  IVStrideUse* &CondUse,
+                                  const SCEVHandle* &CondStride);
+
+    void OptimizeIndvars(Loop *L);
+    void OptimizeLoopCountIV(Loop *L);
+    void OptimizeLoopTermCond(Loop *L);
+
+    /// OptimizeShadowIV - If the IV is used in an int-to-float cast inside
+    /// the loop then try to eliminate the cast operation.
+    void OptimizeShadowIV(Loop *L);
+
+    /// OptimizeSMax - Rewrite the loop's terminating condition
+    /// if it uses an smax computation.
+    ICmpInst *OptimizeSMax(Loop *L, ICmpInst *Cond,
+                           IVStrideUse* &CondUse);
+
+    bool FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse,
+                           const SCEVHandle *&CondStride);
+    bool RequiresTypeConversion(const Type *Ty, const Type *NewTy);
+    SCEVHandle CheckForIVReuse(bool, bool, bool, const SCEVHandle&,
+                               IVExpr&, const Type*,
+                               const std::vector<BasedUser>& UsersToProcess);
+    bool ValidScale(bool, int64_t,
+                    const std::vector<BasedUser>& UsersToProcess);
+    bool ValidOffset(bool, int64_t, int64_t,
+                     const std::vector<BasedUser>& UsersToProcess);
+    SCEVHandle CollectIVUsers(const SCEVHandle &Stride,
+                              IVUsersOfOneStride &Uses,
+                              Loop *L,
+                              bool &AllUsesAreAddresses,
+                              bool &AllUsesAreOutsideLoop,
+                              std::vector<BasedUser> &UsersToProcess);
+    bool ShouldUseFullStrengthReductionMode(
+                                const std::vector<BasedUser> &UsersToProcess,
+                                const Loop *L,
+                                bool AllUsesAreAddresses,
+                                SCEVHandle Stride);
+    void PrepareToStrengthReduceFully(
+                                std::vector<BasedUser> &UsersToProcess,
+                                SCEVHandle Stride,
+                                SCEVHandle CommonExprs,
+                                const Loop *L,
+                                SCEVExpander &PreheaderRewriter);
+    void PrepareToStrengthReduceFromSmallerStride(
+                                std::vector<BasedUser> &UsersToProcess,
+                                Value *CommonBaseV,
+                                const IVExpr &ReuseIV,
+                                Instruction *PreInsertPt);
+    void PrepareToStrengthReduceWithNewPhi(
+                                std::vector<BasedUser> &UsersToProcess,
+                                SCEVHandle Stride,
+                                SCEVHandle CommonExprs,
+                                Value *CommonBaseV,
+                                Instruction *IVIncInsertPt,
+                                const Loop *L,
+                                SCEVExpander &PreheaderRewriter);
+    void StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
+                                      IVUsersOfOneStride &Uses,
+                                      Loop *L);
+    void DeleteTriviallyDeadInstructions();
+  };
+}
+
+char LoopStrengthReduce::ID = 0;
+static RegisterPass<LoopStrengthReduce>
+X("loop-reduce", "Loop Strength Reduction");
+
+Pass *llvm::createLoopStrengthReducePass(const TargetLowering *TLI) {
+  return new LoopStrengthReduce(TLI);
+}
+
+/// DeleteTriviallyDeadInstructions - If any of the instructions in the
+/// specified set are trivially dead, delete them and see if this makes any of
+/// their operands subsequently dead.
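The worklist scheme that the following definition implements can be modeled in a few self-contained lines. The Node struct and the tiny three-node graph here are invented for illustration and are not LLVM types:

    #include <cstdio>
    #include <vector>

    struct Node {
      std::vector<int> ops;  // indices of operand nodes
      int uses = 0;          // how many live nodes use this one
      bool alive = true;
    };

    int main() {
      // n2 uses n1, n1 uses n0; nothing uses n2, so n2 is trivially dead.
      std::vector<Node> g(3);
      g[2].ops = {1};
      g[1].ops = {0};
      g[0].uses = 1;
      g[1].uses = 1;

      std::vector<int> work = {2};
      while (!work.empty()) {
        int n = work.back();
        work.pop_back();
        if (!g[n].alive || g[n].uses != 0)
          continue;                  // not trivially dead (yet)
        g[n].alive = false;          // analogous to eraseFromParent()
        for (int op : g[n].ops)
          if (--g[op].uses == 0)     // an operand may have just become dead
            work.push_back(op);
      }
      for (int i = 0; i < 3; ++i)
        std::printf("n%d is %s\n", i, g[i].alive ? "live" : "dead");
      return 0;
    }

Deleting the root cascades through the whole chain, which is exactly the "see if this makes any of their operands subsequently dead" behavior described above.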
+void LoopStrengthReduce::DeleteTriviallyDeadInstructions() { + if (DeadInsts.empty()) return; + + while (!DeadInsts.empty()) { + Instruction *I = dyn_cast_or_null<Instruction>(DeadInsts.back()); + DeadInsts.pop_back(); + + if (I == 0 || !isInstructionTriviallyDead(I)) + continue; + + for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI) { + if (Instruction *U = dyn_cast<Instruction>(*OI)) { + *OI = 0; + if (U->use_empty()) + DeadInsts.push_back(U); + } + } + + I->eraseFromParent(); + Changed = true; + } +} + +/// containsAddRecFromDifferentLoop - Determine whether expression S involves a +/// subexpression that is an AddRec from a loop other than L. An outer loop +/// of L is OK, but not an inner loop nor a disjoint loop. +static bool containsAddRecFromDifferentLoop(SCEVHandle S, Loop *L) { + // This is very common, put it first. + if (isa<SCEVConstant>(S)) + return false; + if (const SCEVCommutativeExpr *AE = dyn_cast<SCEVCommutativeExpr>(S)) { + for (unsigned int i=0; i< AE->getNumOperands(); i++) + if (containsAddRecFromDifferentLoop(AE->getOperand(i), L)) + return true; + return false; + } + if (const SCEVAddRecExpr *AE = dyn_cast<SCEVAddRecExpr>(S)) { + if (const Loop *newLoop = AE->getLoop()) { + if (newLoop == L) + return false; + // if newLoop is an outer loop of L, this is OK. + if (!LoopInfoBase<BasicBlock>::isNotAlreadyContainedIn(L, newLoop)) + return false; + } + return true; + } + if (const SCEVUDivExpr *DE = dyn_cast<SCEVUDivExpr>(S)) + return containsAddRecFromDifferentLoop(DE->getLHS(), L) || + containsAddRecFromDifferentLoop(DE->getRHS(), L); +#if 0 + // SCEVSDivExpr has been backed out temporarily, but will be back; we'll + // need this when it is. + if (const SCEVSDivExpr *DE = dyn_cast<SCEVSDivExpr>(S)) + return containsAddRecFromDifferentLoop(DE->getLHS(), L) || + containsAddRecFromDifferentLoop(DE->getRHS(), L); +#endif + if (const SCEVCastExpr *CE = dyn_cast<SCEVCastExpr>(S)) + return containsAddRecFromDifferentLoop(CE->getOperand(), L); + return false; +} + +/// isAddressUse - Returns true if the specified instruction is using the +/// specified value as an address. +static bool isAddressUse(Instruction *Inst, Value *OperandVal) { + bool isAddress = isa<LoadInst>(Inst); + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + if (SI->getOperand(1) == OperandVal) + isAddress = true; + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + // Addressing modes can also be folded into prefetches and a variety + // of intrinsics. + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::prefetch: + case Intrinsic::x86_sse2_loadu_dq: + case Intrinsic::x86_sse2_loadu_pd: + case Intrinsic::x86_sse_loadu_ps: + case Intrinsic::x86_sse_storeu_ps: + case Intrinsic::x86_sse2_storeu_pd: + case Intrinsic::x86_sse2_storeu_dq: + case Intrinsic::x86_sse2_storel_dq: + if (II->getOperand(1) == OperandVal) + isAddress = true; + break; + } + } + return isAddress; +} + +/// getAccessType - Return the type of the memory being accessed. +static const Type *getAccessType(const Instruction *Inst) { + const Type *AccessTy = Inst->getType(); + if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) + AccessTy = SI->getOperand(0)->getType(); + else if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + // Addressing modes can also be folded into prefetches and a variety + // of intrinsics. 
+ switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::x86_sse_storeu_ps: + case Intrinsic::x86_sse2_storeu_pd: + case Intrinsic::x86_sse2_storeu_dq: + case Intrinsic::x86_sse2_storel_dq: + AccessTy = II->getOperand(1)->getType(); + break; + } + } + return AccessTy; +} + +namespace { + /// BasedUser - For a particular base value, keep information about how we've + /// partitioned the expression so far. + struct BasedUser { + /// SE - The current ScalarEvolution object. + ScalarEvolution *SE; + + /// Base - The Base value for the PHI node that needs to be inserted for + /// this use. As the use is processed, information gets moved from this + /// field to the Imm field (below). BasedUser values are sorted by this + /// field. + SCEVHandle Base; + + /// Inst - The instruction using the induction variable. + Instruction *Inst; + + /// OperandValToReplace - The operand value of Inst to replace with the + /// EmittedBase. + Value *OperandValToReplace; + + /// isSigned - The stride (and thus also the Base) of this use may be in + /// a narrower type than the use itself (OperandValToReplace->getType()). + /// When this is the case, the isSigned field indicates whether the + /// IV expression should be signed-extended instead of zero-extended to + /// fit the type of the use. + bool isSigned; + + /// Imm - The immediate value that should be added to the base immediately + /// before Inst, because it will be folded into the imm field of the + /// instruction. This is also sometimes used for loop-variant values that + /// must be added inside the loop. + SCEVHandle Imm; + + /// Phi - The induction variable that performs the striding that + /// should be used for this user. + PHINode *Phi; + + // isUseOfPostIncrementedValue - True if this should use the + // post-incremented version of this IV, not the preincremented version. + // This can only be set in special cases, such as the terminating setcc + // instruction for a loop and uses outside the loop that are dominated by + // the loop. + bool isUseOfPostIncrementedValue; + + BasedUser(IVStrideUse &IVSU, ScalarEvolution *se) + : SE(se), Base(IVSU.getOffset()), Inst(IVSU.getUser()), + OperandValToReplace(IVSU.getOperandValToReplace()), + isSigned(IVSU.isSigned()), + Imm(SE->getIntegerSCEV(0, Base->getType())), + isUseOfPostIncrementedValue(IVSU.isUseOfPostIncrementedValue()) {} + + // Once we rewrite the code to insert the new IVs we want, update the + // operands of Inst to use the new expression 'NewBase', with 'Imm' added + // to it. + void RewriteInstructionToUseNewBase(const SCEVHandle &NewBase, + Instruction *InsertPt, + SCEVExpander &Rewriter, Loop *L, Pass *P, + LoopInfo &LI, + SmallVectorImpl<WeakVH> &DeadInsts); + + Value *InsertCodeForBaseAtPosition(const SCEVHandle &NewBase, + const Type *Ty, + SCEVExpander &Rewriter, + Instruction *IP, Loop *L, + LoopInfo &LI); + void dump() const; + }; +} + +void BasedUser::dump() const { + cerr << " Base=" << *Base; + cerr << " Imm=" << *Imm; + cerr << " Inst: " << *Inst; +} + +Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase, + const Type *Ty, + SCEVExpander &Rewriter, + Instruction *IP, Loop *L, + LoopInfo &LI) { + // Figure out where we *really* want to insert this code. In particular, if + // the user is inside of a loop that is nested inside of L, we really don't + // want to insert this expression before the user, we'd rather pull it out as + // many loops as possible. + Instruction *BaseInsertPt = IP; + + // Figure out the most-nested loop that IP is in. 
+  Loop *InsertLoop = LI.getLoopFor(IP->getParent());
+
+  // If InsertLoop is not L, and InsertLoop is nested inside of L, figure out
+  // the preheader of the outer-most loop where NewBase is not loop invariant.
+  if (L->contains(IP->getParent()))
+    while (InsertLoop && NewBase->isLoopInvariant(InsertLoop)) {
+      BaseInsertPt = InsertLoop->getLoopPreheader()->getTerminator();
+      InsertLoop = InsertLoop->getParentLoop();
+    }
+
+  Value *Base = Rewriter.expandCodeFor(NewBase, 0, BaseInsertPt);
+
+  SCEVHandle NewValSCEV = SE->getUnknown(Base);
+
+  // If there is no immediate value, skip the next part.
+  if (!Imm->isZero()) {
+    // If we are inserting the base and imm values in the same block, make sure
+    // to adjust the IP position if insertion reused a result.
+    if (IP == BaseInsertPt)
+      IP = Rewriter.getInsertionPoint();
+
+    // Always emit the immediate (if non-zero) into the same block as the user.
+    NewValSCEV = SE->getAddExpr(NewValSCEV, Imm);
+  }
+
+  if (isSigned)
+    NewValSCEV = SE->getTruncateOrSignExtend(NewValSCEV, Ty);
+  else
+    NewValSCEV = SE->getTruncateOrZeroExtend(NewValSCEV, Ty);
+
+  return Rewriter.expandCodeFor(NewValSCEV, Ty, IP);
+}
+
+
+// Once we rewrite the code to insert the new IVs we want, update the
+// operands of Inst to use the new expression 'NewBase', with 'Imm' added
+// to it. NewBasePt is the last instruction which contributes to the
+// value of NewBase in the case that it's a different instruction from
+// the PHI that NewBase is computed from, or null otherwise.
+//
+void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
+                                               Instruction *NewBasePt,
+                                               SCEVExpander &Rewriter, Loop *L, Pass *P,
+                                               LoopInfo &LI,
+                                               SmallVectorImpl<WeakVH> &DeadInsts) {
+  if (!isa<PHINode>(Inst)) {
+    // By default, insert code at the user instruction.
+    BasicBlock::iterator InsertPt = Inst;
+
+    // However, if the Operand is itself an instruction, the (potentially
+    // complex) inserted code may be shared by many users. Because of this,
+    // we want to emit code for the computation of the operand right before
+    // its old computation. This is usually safe, because we obviously used
+    // to use the computation when it was computed in its current block.
+    // However, in some cases (e.g. use of a post-incremented induction
+    // variable) the NewBase value will be pinned to live somewhere after
+    // the original computation. In this case, we have to back off.
+    //
+    // If this is a use outside the loop (which means after it, since the use
+    // is based on a loop indvar) we use the post-incremented value, so that
+    // we don't artificially make the preinc value live out of the bottom of
+    // the loop.
+    if (!isUseOfPostIncrementedValue && L->contains(Inst->getParent())) {
+      if (NewBasePt && isa<PHINode>(OperandValToReplace)) {
+        InsertPt = NewBasePt;
+        ++InsertPt;
+      } else if (Instruction *OpInst
+                 = dyn_cast<Instruction>(OperandValToReplace)) {
+        InsertPt = OpInst;
+        while (isa<PHINode>(InsertPt)) ++InsertPt;
+      }
+    }
+    Value *NewVal = InsertCodeForBaseAtPosition(NewBase,
+                                                OperandValToReplace->getType(),
+                                                Rewriter, InsertPt, L, LI);
+    // Replace the use of the operand Value with the new Phi we just created.
+    Inst->replaceUsesOfWith(OperandValToReplace, NewVal);
+
+    DOUT << " Replacing with ";
+    DEBUG(WriteAsOperand(*DOUT, NewVal, /*PrintType=*/false));
+    DOUT << ", which has value " << *NewBase << " plus IMM " << *Imm << "\n";
+    return;
+  }
+
+  // PHI nodes are more complex. We have to insert one copy of the NewBase+Imm
+  // expression into each operand block that uses it. Note that PHI nodes can
+  // have multiple entries for the same predecessor. We use a map to make
+  // sure that a PHI node only has a single Value* for each predecessor (which
+  // also prevents us from inserting duplicate code in some blocks).
+  DenseMap<BasicBlock*, Value*> InsertedCode;
+  PHINode *PN = cast<PHINode>(Inst);
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    if (PN->getIncomingValue(i) == OperandValToReplace) {
+      // If the original expression is outside the loop, put the replacement
+      // code in the same place as the original expression,
+      // which need not be an immediate predecessor of this PHI. This way we
+      // need only one copy of it even if it is referenced multiple times in
+      // the PHI. We don't do this when the original expression is inside the
+      // loop because multiple copies sometimes do useful sinking of code in
+      // that case(?).
+      Instruction *OldLoc = dyn_cast<Instruction>(OperandValToReplace);
+      if (L->contains(OldLoc->getParent())) {
+        // If this is a critical edge, split the edge so that we do not insert
+        // the code on all predecessor/successor paths. We do this unless this
+        // is the canonical backedge for this loop, as this can make some
+        // inserted code be in an illegal position.
+        BasicBlock *PHIPred = PN->getIncomingBlock(i);
+        if (e != 1 && PHIPred->getTerminator()->getNumSuccessors() > 1 &&
+            (PN->getParent() != L->getHeader() || !L->contains(PHIPred))) {
+
+          // First step, split the critical edge.
+          SplitCriticalEdge(PHIPred, PN->getParent(), P, false);
+
+          // Next step: move the basic block. In particular, if the PHI node
+          // is outside of the loop, and PredTI is in the loop, we want to
+          // move the block to be immediately before the PHI block, not
+          // immediately after PredTI.
+          if (L->contains(PHIPred) && !L->contains(PN->getParent())) {
+            BasicBlock *NewBB = PN->getIncomingBlock(i);
+            NewBB->moveBefore(PN->getParent());
+          }
+
+          // Splitting the edge can reduce the number of PHI entries we have.
+          e = PN->getNumIncomingValues();
+        }
+      }
+      Value *&Code = InsertedCode[PN->getIncomingBlock(i)];
+      if (!Code) {
+        // Insert the code into the end of the predecessor block.
+        Instruction *InsertPt = (L->contains(OldLoc->getParent())) ?
+          PN->getIncomingBlock(i)->getTerminator() :
+          OldLoc->getParent()->getTerminator();
+        Code = InsertCodeForBaseAtPosition(NewBase, PN->getType(),
+                                           Rewriter, InsertPt, L, LI);
+
+        DOUT << " Changing PHI use to ";
+        DEBUG(WriteAsOperand(*DOUT, Code, /*PrintType=*/false));
+        DOUT << ", which has value " << *NewBase << " plus IMM " << *Imm << "\n";
+      }
+
+      // Replace the use of the operand Value with the new Phi we just created.
+      PN->setIncomingValue(i, Code);
+      Rewriter.clear();
+    }
+  }
+
+  // The PHI node might have become a constant value after SplitCriticalEdge.
+  DeadInsts.push_back(Inst);
+}
+
+
+/// fitsInAddressMode - Return true if V can be subsumed within an addressing
+/// mode, and does not need to be put in a register first.
+static bool fitsInAddressMode(const SCEVHandle &V, const Type *AccessTy,
+                              const TargetLowering *TLI, bool HasBaseReg) {
+  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(V)) {
+    int64_t VC = SC->getValue()->getSExtValue();
+    if (TLI) {
+      TargetLowering::AddrMode AM;
+      AM.BaseOffs = VC;
+      AM.HasBaseReg = HasBaseReg;
+      return TLI->isLegalAddressingMode(AM, AccessTy);
+    } else {
+      // Defaults to PPC. PPC allows a sign-extended 16-bit immediate field.
+ return (VC > -(1 << 16) && VC < (1 << 16)-1); + } + } + + if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) + if (GlobalValue *GV = dyn_cast<GlobalValue>(SU->getValue())) { + if (TLI) { + TargetLowering::AddrMode AM; + AM.BaseGV = GV; + AM.HasBaseReg = HasBaseReg; + return TLI->isLegalAddressingMode(AM, AccessTy); + } else { + // Default: assume global addresses are not legal. + } + } + + return false; +} + +/// MoveLoopVariantsToImmediateField - Move any subexpressions from Val that are +/// loop varying to the Imm operand. +static void MoveLoopVariantsToImmediateField(SCEVHandle &Val, SCEVHandle &Imm, + Loop *L, ScalarEvolution *SE) { + if (Val->isLoopInvariant(L)) return; // Nothing to do. + + if (const SCEVAddExpr *SAE = dyn_cast<SCEVAddExpr>(Val)) { + std::vector<SCEVHandle> NewOps; + NewOps.reserve(SAE->getNumOperands()); + + for (unsigned i = 0; i != SAE->getNumOperands(); ++i) + if (!SAE->getOperand(i)->isLoopInvariant(L)) { + // If this is a loop-variant expression, it must stay in the immediate + // field of the expression. + Imm = SE->getAddExpr(Imm, SAE->getOperand(i)); + } else { + NewOps.push_back(SAE->getOperand(i)); + } + + if (NewOps.empty()) + Val = SE->getIntegerSCEV(0, Val->getType()); + else + Val = SE->getAddExpr(NewOps); + } else if (const SCEVAddRecExpr *SARE = dyn_cast<SCEVAddRecExpr>(Val)) { + // Try to pull immediates out of the start value of nested addrec's. + SCEVHandle Start = SARE->getStart(); + MoveLoopVariantsToImmediateField(Start, Imm, L, SE); + + std::vector<SCEVHandle> Ops(SARE->op_begin(), SARE->op_end()); + Ops[0] = Start; + Val = SE->getAddRecExpr(Ops, SARE->getLoop()); + } else { + // Otherwise, all of Val is variant, move the whole thing over. + Imm = SE->getAddExpr(Imm, Val); + Val = SE->getIntegerSCEV(0, Val->getType()); + } +} + + +/// MoveImmediateValues - Look at Val, and pull out any additions of constants +/// that can fit into the immediate field of instructions in the target. +/// Accumulate these immediate values into the Imm value. +static void MoveImmediateValues(const TargetLowering *TLI, + const Type *AccessTy, + SCEVHandle &Val, SCEVHandle &Imm, + bool isAddress, Loop *L, + ScalarEvolution *SE) { + if (const SCEVAddExpr *SAE = dyn_cast<SCEVAddExpr>(Val)) { + std::vector<SCEVHandle> NewOps; + NewOps.reserve(SAE->getNumOperands()); + + for (unsigned i = 0; i != SAE->getNumOperands(); ++i) { + SCEVHandle NewOp = SAE->getOperand(i); + MoveImmediateValues(TLI, AccessTy, NewOp, Imm, isAddress, L, SE); + + if (!NewOp->isLoopInvariant(L)) { + // If this is a loop-variant expression, it must stay in the immediate + // field of the expression. + Imm = SE->getAddExpr(Imm, NewOp); + } else { + NewOps.push_back(NewOp); + } + } + + if (NewOps.empty()) + Val = SE->getIntegerSCEV(0, Val->getType()); + else + Val = SE->getAddExpr(NewOps); + return; + } else if (const SCEVAddRecExpr *SARE = dyn_cast<SCEVAddRecExpr>(Val)) { + // Try to pull immediates out of the start value of nested addrec's. + SCEVHandle Start = SARE->getStart(); + MoveImmediateValues(TLI, AccessTy, Start, Imm, isAddress, L, SE); + + if (Start != SARE->getStart()) { + std::vector<SCEVHandle> Ops(SARE->op_begin(), SARE->op_end()); + Ops[0] = Start; + Val = SE->getAddRecExpr(Ops, SARE->getLoop()); + } + return; + } else if (const SCEVMulExpr *SME = dyn_cast<SCEVMulExpr>(Val)) { + // Transform "8 * (4 + v)" -> "32 + 8*V" if "32" fits in the immed field. 
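The rewrite named in the comment above is plain distribution: 8 * (4 + v) equals 32 + 8*v, after which the constant 32 is a candidate for an instruction's immediate field. A quick arithmetic check over arbitrary sample values (the range is hypothetical); the code that attempts this rewrite resumes immediately below:

    #include <cassert>

    int main() {
      // The rewrite is value-preserving for any v; 32 can then be folded
      // into the immediate field while 8*v stays in a register.
      for (long v = -1000; v <= 1000; ++v)
        assert(8 * (4 + v) == 32 + 8 * v);
      return 0;
    }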
+ if (isAddress && + fitsInAddressMode(SME->getOperand(0), AccessTy, TLI, false) && + SME->getNumOperands() == 2 && SME->isLoopInvariant(L)) { + + SCEVHandle SubImm = SE->getIntegerSCEV(0, Val->getType()); + SCEVHandle NewOp = SME->getOperand(1); + MoveImmediateValues(TLI, AccessTy, NewOp, SubImm, isAddress, L, SE); + + // If we extracted something out of the subexpressions, see if we can + // simplify this! + if (NewOp != SME->getOperand(1)) { + // Scale SubImm up by "8". If the result is a target constant, we are + // good. + SubImm = SE->getMulExpr(SubImm, SME->getOperand(0)); + if (fitsInAddressMode(SubImm, AccessTy, TLI, false)) { + // Accumulate the immediate. + Imm = SE->getAddExpr(Imm, SubImm); + + // Update what is left of 'Val'. + Val = SE->getMulExpr(SME->getOperand(0), NewOp); + return; + } + } + } + } + + // Loop-variant expressions must stay in the immediate field of the + // expression. + if ((isAddress && fitsInAddressMode(Val, AccessTy, TLI, false)) || + !Val->isLoopInvariant(L)) { + Imm = SE->getAddExpr(Imm, Val); + Val = SE->getIntegerSCEV(0, Val->getType()); + return; + } + + // Otherwise, no immediates to move. +} + +static void MoveImmediateValues(const TargetLowering *TLI, + Instruction *User, + SCEVHandle &Val, SCEVHandle &Imm, + bool isAddress, Loop *L, + ScalarEvolution *SE) { + const Type *AccessTy = getAccessType(User); + MoveImmediateValues(TLI, AccessTy, Val, Imm, isAddress, L, SE); +} + +/// SeparateSubExprs - Decompose Expr into all of the subexpressions that are +/// added together. This is used to reassociate common addition subexprs +/// together for maximal sharing when rewriting bases. +static void SeparateSubExprs(std::vector<SCEVHandle> &SubExprs, + SCEVHandle Expr, + ScalarEvolution *SE) { + if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(Expr)) { + for (unsigned j = 0, e = AE->getNumOperands(); j != e; ++j) + SeparateSubExprs(SubExprs, AE->getOperand(j), SE); + } else if (const SCEVAddRecExpr *SARE = dyn_cast<SCEVAddRecExpr>(Expr)) { + SCEVHandle Zero = SE->getIntegerSCEV(0, Expr->getType()); + if (SARE->getOperand(0) == Zero) { + SubExprs.push_back(Expr); + } else { + // Compute the addrec with zero as its base. + std::vector<SCEVHandle> Ops(SARE->op_begin(), SARE->op_end()); + Ops[0] = Zero; // Start with zero base. + SubExprs.push_back(SE->getAddRecExpr(Ops, SARE->getLoop())); + + + SeparateSubExprs(SubExprs, SARE->getOperand(0), SE); + } + } else if (!Expr->isZero()) { + // Do not add zero. + SubExprs.push_back(Expr); + } +} + +// This is logically local to the following function, but C++ says we have +// to make it file scope. +struct SubExprUseData { unsigned Count; bool notAllUsesAreFree; }; + +/// RemoveCommonExpressionsFromUseBases - Look through all of the Bases of all +/// the Uses, removing any common subexpressions, except that if all such +/// subexpressions can be folded into an addressing mode for all uses inside +/// the loop (this case is referred to as "free" in comments herein) we do +/// not remove anything. This looks for things like (a+b+c) and +/// (a+c+d) and computes the common (a+c) subexpression. The common expression +/// is *removed* from the Bases and returned. +static SCEVHandle +RemoveCommonExpressionsFromUseBases(std::vector<BasedUser> &Uses, + ScalarEvolution *SE, Loop *L, + const TargetLowering *TLI) { + unsigned NumUses = Uses.size(); + + // Only one use? This is a very common case, so we handle it specially and + // cheaply. 
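For the multi-use path handled below, the counting scheme can be modeled compactly: treat each base as a set of named subexpressions, count occurrences across all uses, and keep those seen in every use. A toy model in which strings stand in for SCEV subexpressions (invented for illustration, not LLVM types):

    #include <cstdio>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
      // Two bases: (a+b+c) and (a+c+d). The common part is (a+c).
      std::vector<std::set<std::string>> bases = {{"a", "b", "c"},
                                                  {"a", "c", "d"}};
      std::map<std::string, unsigned> count;
      for (const auto &base : bases)
        for (const auto &expr : base)
          ++count[expr];
      for (const auto &entry : count)
        if (entry.second == bases.size())   // present in every base: a CSE
          std::printf("common: %s\n", entry.first.c_str());
      return 0;
    }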
+  SCEVHandle Zero = SE->getIntegerSCEV(0, Uses[0].Base->getType());
+  SCEVHandle Result = Zero;
+  SCEVHandle FreeResult = Zero;
+  if (NumUses == 1) {
+    // If the use is inside the loop, use its base, regardless of what it is:
+    // it is clearly shared across all the IV's. If the use is outside the
+    // loop (which means after it) we don't want to factor anything *into* the
+    // loop, so just use 0 as the base.
+    if (L->contains(Uses[0].Inst->getParent()))
+      std::swap(Result, Uses[0].Base);
+    return Result;
+  }
+
+  // To find common subexpressions, count how many of Uses use each expression.
+  // If any subexpressions are used Uses.size() times, they are common.
+  // Also track whether all uses of each expression can be moved into an
+  // addressing mode "for free"; such expressions are left within the loop.
+  // struct SubExprUseData { unsigned Count; bool notAllUsesAreFree; };
+  std::map<SCEVHandle, SubExprUseData> SubExpressionUseData;
+
+  // UniqueSubExprs - Keep track of all of the subexpressions we see in the
+  // order we see them.
+  std::vector<SCEVHandle> UniqueSubExprs;
+
+  std::vector<SCEVHandle> SubExprs;
+  unsigned NumUsesInsideLoop = 0;
+  for (unsigned i = 0; i != NumUses; ++i) {
+    // If the user is outside the loop, just ignore it for base computation.
+    // Since the user is outside the loop, it must be *after* the loop (if it
+    // were before, it could not be based on the loop IV). We don't want users
+    // after the loop to affect base computation of values *inside* the loop,
+    // because we can always add their offsets to the result IV after the loop
+    // is done, ensuring we get good code inside the loop.
+    if (!L->contains(Uses[i].Inst->getParent()))
+      continue;
+    NumUsesInsideLoop++;
+
+    // If the base is zero (which is common), return zero now; there are no
+    // CSEs we can find.
+    if (Uses[i].Base == Zero) return Zero;
+
+    // If this use is as an address we may be able to put CSEs in the
+    // addressing mode rather than hoisting them.
+    bool isAddrUse = isAddressUse(Uses[i].Inst, Uses[i].OperandValToReplace);
+    // We may need the AccessTy below, but only when isAddrUse, so compute it
+    // only in that case.
+    const Type *AccessTy = 0;
+    if (isAddrUse)
+      AccessTy = getAccessType(Uses[i].Inst);
+
+    // Split the expression into subexprs.
+    SeparateSubExprs(SubExprs, Uses[i].Base, SE);
+    // Add one to SubExpressionUseData.Count for each subexpr present, and
+    // if the subexpr is not a valid immediate within an addressing mode use,
+    // set SubExpressionUseData.notAllUsesAreFree. We definitely want to
+    // hoist these out of the loop (if they are common to all uses).
+    for (unsigned j = 0, e = SubExprs.size(); j != e; ++j) {
+      if (++SubExpressionUseData[SubExprs[j]].Count == 1)
+        UniqueSubExprs.push_back(SubExprs[j]);
+      if (!isAddrUse || !fitsInAddressMode(SubExprs[j], AccessTy, TLI, false))
+        SubExpressionUseData[SubExprs[j]].notAllUsesAreFree = true;
+    }
+    SubExprs.clear();
+  }
+
+  // Now that we know how many times each is used, build Result. Iterate over
+  // UniqueSubExprs so that we have a stable ordering.
+  for (unsigned i = 0, e = UniqueSubExprs.size(); i != e; ++i) {
+    std::map<SCEVHandle, SubExprUseData>::iterator I =
+      SubExpressionUseData.find(UniqueSubExprs[i]);
+    assert(I != SubExpressionUseData.end() && "Entry not found?");
+    if (I->second.Count == NumUsesInsideLoop) { // Found CSE!
+ if (I->second.notAllUsesAreFree) + Result = SE->getAddExpr(Result, I->first); + else + FreeResult = SE->getAddExpr(FreeResult, I->first); + } else + // Remove non-cse's from SubExpressionUseData. + SubExpressionUseData.erase(I); + } + + if (FreeResult != Zero) { + // We have some subexpressions that can be subsumed into addressing + // modes in every use inside the loop. However, it's possible that + // there are so many of them that the combined FreeResult cannot + // be subsumed, or that the target cannot handle both a FreeResult + // and a Result in the same instruction (for example because it would + // require too many registers). Check this. + for (unsigned i=0; i<NumUses; ++i) { + if (!L->contains(Uses[i].Inst->getParent())) + continue; + // We know this is an addressing mode use; if there are any uses that + // are not, FreeResult would be Zero. + const Type *AccessTy = getAccessType(Uses[i].Inst); + if (!fitsInAddressMode(FreeResult, AccessTy, TLI, Result!=Zero)) { + // FIXME: could split up FreeResult into pieces here, some hoisted + // and some not. There is no obvious advantage to this. + Result = SE->getAddExpr(Result, FreeResult); + FreeResult = Zero; + break; + } + } + } + + // If we found no CSE's, return now. + if (Result == Zero) return Result; + + // If we still have a FreeResult, remove its subexpressions from + // SubExpressionUseData. This means they will remain in the use Bases. + if (FreeResult != Zero) { + SeparateSubExprs(SubExprs, FreeResult, SE); + for (unsigned j = 0, e = SubExprs.size(); j != e; ++j) { + std::map<SCEVHandle, SubExprUseData>::iterator I = + SubExpressionUseData.find(SubExprs[j]); + SubExpressionUseData.erase(I); + } + SubExprs.clear(); + } + + // Otherwise, remove all of the CSE's we found from each of the base values. + for (unsigned i = 0; i != NumUses; ++i) { + // Uses outside the loop don't necessarily include the common base, but + // the final IV value coming into those uses does. Instead of trying to + // remove the pieces of the common base, which might not be there, + // subtract off the base to compensate for this. + if (!L->contains(Uses[i].Inst->getParent())) { + Uses[i].Base = SE->getMinusSCEV(Uses[i].Base, Result); + continue; + } + + // Split the expression into subexprs. + SeparateSubExprs(SubExprs, Uses[i].Base, SE); + + // Remove any common subexpressions. + for (unsigned j = 0, e = SubExprs.size(); j != e; ++j) + if (SubExpressionUseData.count(SubExprs[j])) { + SubExprs.erase(SubExprs.begin()+j); + --j; --e; + } + + // Finally, add the non-shared expressions together. + if (SubExprs.empty()) + Uses[i].Base = Zero; + else + Uses[i].Base = SE->getAddExpr(SubExprs); + SubExprs.clear(); + } + + return Result; +} + +/// ValidScale - Check whether the given Scale is valid for all loads and +/// stores in UsersToProcess. +/// +bool LoopStrengthReduce::ValidScale(bool HasBaseReg, int64_t Scale, + const std::vector<BasedUser>& UsersToProcess) { + if (!TLI) + return true; + + for (unsigned i = 0, e = UsersToProcess.size(); i!=e; ++i) { + // If this is a load or other access, pass the type of the access in. 
+    const Type *AccessTy = Type::VoidTy;
+    if (isAddressUse(UsersToProcess[i].Inst,
+                     UsersToProcess[i].OperandValToReplace))
+      AccessTy = getAccessType(UsersToProcess[i].Inst);
+    else if (isa<PHINode>(UsersToProcess[i].Inst))
+      continue;
+
+    TargetLowering::AddrMode AM;
+    if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(UsersToProcess[i].Imm))
+      AM.BaseOffs = SC->getValue()->getSExtValue();
+    AM.HasBaseReg = HasBaseReg || !UsersToProcess[i].Base->isZero();
+    AM.Scale = Scale;
+
+    // If load[imm+r*scale] is illegal, bail out.
+    if (!TLI->isLegalAddressingMode(AM, AccessTy))
+      return false;
+  }
+  return true;
+}
+
+/// ValidOffset - Check whether the given Offset is valid for all loads and
+/// stores in UsersToProcess.
+///
+bool LoopStrengthReduce::ValidOffset(bool HasBaseReg,
+                                     int64_t Offset,
+                                     int64_t Scale,
+                                     const std::vector<BasedUser>& UsersToProcess) {
+  if (!TLI)
+    return true;
+
+  for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) {
+    // If this is a load or other access, pass the type of the access in.
+    const Type *AccessTy = Type::VoidTy;
+    if (isAddressUse(UsersToProcess[i].Inst,
+                     UsersToProcess[i].OperandValToReplace))
+      AccessTy = getAccessType(UsersToProcess[i].Inst);
+    else if (isa<PHINode>(UsersToProcess[i].Inst))
+      continue;
+
+    TargetLowering::AddrMode AM;
+    if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(UsersToProcess[i].Imm))
+      AM.BaseOffs = SC->getValue()->getSExtValue();
+    AM.BaseOffs = (uint64_t)AM.BaseOffs + (uint64_t)Offset;
+    AM.HasBaseReg = HasBaseReg || !UsersToProcess[i].Base->isZero();
+    AM.Scale = Scale;
+
+    // If load[imm+r*scale] is illegal, bail out.
+    if (!TLI->isLegalAddressingMode(AM, AccessTy))
+      return false;
+  }
+  return true;
+}
+
+/// RequiresTypeConversion - Returns true if converting Ty1 to Ty2 is not
+/// a nop.
+bool LoopStrengthReduce::RequiresTypeConversion(const Type *Ty1,
+                                                const Type *Ty2) {
+  if (Ty1 == Ty2)
+    return false;
+  Ty1 = SE->getEffectiveSCEVType(Ty1);
+  Ty2 = SE->getEffectiveSCEVType(Ty2);
+  if (Ty1 == Ty2)
+    return false;
+  if (Ty1->canLosslesslyBitCastTo(Ty2))
+    return false;
+  if (TLI && TLI->isTruncateFree(Ty1, Ty2))
+    return false;
+  return true;
+}
+
+/// CheckForIVReuse - Returns the multiple if the stride is a multiple
+/// of a previous stride and it is a legal value for the target addressing
+/// mode scale component and optional base reg. This allows the users of
+/// this stride to be rewritten as prev iv * factor. It returns 0 if no
+/// reuse is possible. Factors can be negative on some targets, e.g. ARM.
+///
+/// If all uses are outside the loop, we don't require that all multiplies
+/// be folded into the addressing mode, nor even that the factor be constant;
+/// a multiply (executed once) outside the loop is better than another IV
+/// within. Well, usually.
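At its core, the reuse test defined next is divisibility arithmetic on constant strides; the real code additionally checks that the resulting scale is legal for the target's addressing modes. A sketch with hypothetical stride values:

    #include <cstdio>
    #include <cstdlib>

    int main() {
      long SInt = 8;   // stride of the users being processed (hypothetical)
      long SSInt = 4;  // stride of an IV that already exists (hypothetical)

      // Reuse is possible when the existing stride divides the new one and
      // is no larger in magnitude; the users can then be rewritten as
      // existing_iv * Scale.
      if (std::labs(SInt) >= SSInt && SInt % SSInt == 0)
        std::printf("reuse with scale %ld\n", SInt / SSInt);  // prints 2
      return 0;
    }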
+SCEVHandle LoopStrengthReduce::CheckForIVReuse(bool HasBaseReg, + bool AllUsesAreAddresses, + bool AllUsesAreOutsideLoop, + const SCEVHandle &Stride, + IVExpr &IV, const Type *Ty, + const std::vector<BasedUser>& UsersToProcess) { + if (StrideNoReuse.count(Stride)) + return SE->getIntegerSCEV(0, Stride->getType()); + + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Stride)) { + int64_t SInt = SC->getValue()->getSExtValue(); + for (unsigned NewStride = 0, e = IU->StrideOrder.size(); + NewStride != e; ++NewStride) { + std::map<SCEVHandle, IVsOfOneStride>::iterator SI = + IVsByStride.find(IU->StrideOrder[NewStride]); + if (SI == IVsByStride.end() || !isa<SCEVConstant>(SI->first) || + StrideNoReuse.count(SI->first)) + continue; + int64_t SSInt = cast<SCEVConstant>(SI->first)->getValue()->getSExtValue(); + if (SI->first != Stride && + (unsigned(abs64(SInt)) < SSInt || (SInt % SSInt) != 0)) + continue; + int64_t Scale = SInt / SSInt; + // Check that this stride is valid for all the types used for loads and + // stores; if it can be used for some and not others, we might as well use + // the original stride everywhere, since we have to create the IV for it + // anyway. If the scale is 1, then we don't need to worry about folding + // multiplications. + if (Scale == 1 || + (AllUsesAreAddresses && + ValidScale(HasBaseReg, Scale, UsersToProcess))) { + // Prefer to reuse an IV with a base of zero. + for (std::vector<IVExpr>::iterator II = SI->second.IVs.begin(), + IE = SI->second.IVs.end(); II != IE; ++II) + // Only reuse previous IV if it would not require a type conversion + // and if the base difference can be folded. + if (II->Base->isZero() && + !RequiresTypeConversion(II->Base->getType(), Ty)) { + IV = *II; + return SE->getIntegerSCEV(Scale, Stride->getType()); + } + // Otherwise, settle for an IV with a foldable base. + if (AllUsesAreAddresses) + for (std::vector<IVExpr>::iterator II = SI->second.IVs.begin(), + IE = SI->second.IVs.end(); II != IE; ++II) + // Only reuse previous IV if it would not require a type conversion + // and if the base difference can be folded. + if (SE->getEffectiveSCEVType(II->Base->getType()) == + SE->getEffectiveSCEVType(Ty) && + isa<SCEVConstant>(II->Base)) { + int64_t Base = + cast<SCEVConstant>(II->Base)->getValue()->getSExtValue(); + if (Base > INT32_MIN && Base <= INT32_MAX && + ValidOffset(HasBaseReg, -Base * Scale, + Scale, UsersToProcess)) { + IV = *II; + return SE->getIntegerSCEV(Scale, Stride->getType()); + } + } + } + } + } else if (AllUsesAreOutsideLoop) { + // Accept nonconstant strides here; it is really really right to substitute + // an existing IV if we can. + for (unsigned NewStride = 0, e = IU->StrideOrder.size(); + NewStride != e; ++NewStride) { + std::map<SCEVHandle, IVsOfOneStride>::iterator SI = + IVsByStride.find(IU->StrideOrder[NewStride]); + if (SI == IVsByStride.end() || !isa<SCEVConstant>(SI->first)) + continue; + int64_t SSInt = cast<SCEVConstant>(SI->first)->getValue()->getSExtValue(); + if (SI->first != Stride && SSInt != 1) + continue; + for (std::vector<IVExpr>::iterator II = SI->second.IVs.begin(), + IE = SI->second.IVs.end(); II != IE; ++II) + // Accept nonzero base here. + // Only reuse previous IV if it would not require a type conversion. + if (!RequiresTypeConversion(II->Base->getType(), Ty)) { + IV = *II; + return Stride; + } + } + // Special case, old IV is -1*x and this one is x. Can treat this one as + // -1*old. 
+ for (unsigned NewStride = 0, e = IU->StrideOrder.size(); + NewStride != e; ++NewStride) { + std::map<SCEVHandle, IVsOfOneStride>::iterator SI = + IVsByStride.find(IU->StrideOrder[NewStride]); + if (SI == IVsByStride.end()) + continue; + if (const SCEVMulExpr *ME = dyn_cast<SCEVMulExpr>(SI->first)) + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(ME->getOperand(0))) + if (Stride == ME->getOperand(1) && + SC->getValue()->getSExtValue() == -1LL) + for (std::vector<IVExpr>::iterator II = SI->second.IVs.begin(), + IE = SI->second.IVs.end(); II != IE; ++II) + // Accept nonzero base here. + // Only reuse previous IV if it would not require type conversion. + if (!RequiresTypeConversion(II->Base->getType(), Ty)) { + IV = *II; + return SE->getIntegerSCEV(-1LL, Stride->getType()); + } + } + } + return SE->getIntegerSCEV(0, Stride->getType()); +} + +/// PartitionByIsUseOfPostIncrementedValue - Simple boolean predicate that +/// returns true if Val's isUseOfPostIncrementedValue is true. +static bool PartitionByIsUseOfPostIncrementedValue(const BasedUser &Val) { + return Val.isUseOfPostIncrementedValue; +} + +/// isNonConstantNegative - Return true if the specified scev is negated, but +/// not a constant. +static bool isNonConstantNegative(const SCEVHandle &Expr) { + const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Expr); + if (!Mul) return false; + + // If there is a constant factor, it will be first. + const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0)); + if (!SC) return false; + + // Return true if the value is negative, this matches things like (-42 * V). + return SC->getValue()->getValue().isNegative(); +} + +// CollectIVUsers - Transform our list of users and offsets to a bit more +// complex table. In this new vector, each 'BasedUser' contains 'Base', the base +// of the strided accesses, as well as the old information from Uses. We +// progressively move information from the Base field to the Imm field, until +// we eventually have the full access expression to rewrite the use. +SCEVHandle LoopStrengthReduce::CollectIVUsers(const SCEVHandle &Stride, + IVUsersOfOneStride &Uses, + Loop *L, + bool &AllUsesAreAddresses, + bool &AllUsesAreOutsideLoop, + std::vector<BasedUser> &UsersToProcess) { + // FIXME: Generalize to non-affine IV's. + if (!Stride->isLoopInvariant(L)) + return SE->getIntegerSCEV(0, Stride->getType()); + + UsersToProcess.reserve(Uses.Users.size()); + for (ilist<IVStrideUse>::iterator I = Uses.Users.begin(), + E = Uses.Users.end(); I != E; ++I) { + UsersToProcess.push_back(BasedUser(*I, SE)); + + // Move any loop variant operands from the offset field to the immediate + // field of the use, so that we don't try to use something before it is + // computed. + MoveLoopVariantsToImmediateField(UsersToProcess.back().Base, + UsersToProcess.back().Imm, L, SE); + assert(UsersToProcess.back().Base->isLoopInvariant(L) && + "Base value is not loop invariant!"); + } + + // We now have a whole bunch of uses of like-strided induction variables, but + // they might all have different bases. We want to emit one PHI node for this + // stride which we fold as many common expressions (between the IVs) into as + // possible. Start by identifying the common expressions in the base values + // for the strides (e.g. if we have "A+C+B" and "A+B+D" as our bases, find + // "A+B"), emit it to the preheader, then remove the expression from the + // UsersToProcess base values. 
+  SCEVHandle CommonExprs =
+    RemoveCommonExpressionsFromUseBases(UsersToProcess, SE, L, TLI);
+
+  // Next, figure out what we can represent in the immediate fields of
+  // instructions. If we can represent anything there, move it to the imm
+  // fields of the BasedUsers. We do this so that it increases the commonality
+  // of the remaining uses.
+  unsigned NumPHI = 0;
+  bool HasAddress = false;
+  for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) {
+    // If the user is not in the current loop, this means it is using the exit
+    // value of the IV. Do not put anything in the base; make sure it's all in
+    // the immediate field to allow as much factoring as possible.
+    if (!L->contains(UsersToProcess[i].Inst->getParent())) {
+      UsersToProcess[i].Imm = SE->getAddExpr(UsersToProcess[i].Imm,
+                                             UsersToProcess[i].Base);
+      UsersToProcess[i].Base =
+        SE->getIntegerSCEV(0, UsersToProcess[i].Base->getType());
+    } else {
+      // Not all uses are outside the loop.
+      AllUsesAreOutsideLoop = false;
+
+      // Addressing modes can be folded into loads and stores. Be careful that
+      // the store is through the expression, not of the expression though.
+      bool isPHI = false;
+      bool isAddress = isAddressUse(UsersToProcess[i].Inst,
+                                    UsersToProcess[i].OperandValToReplace);
+      if (isa<PHINode>(UsersToProcess[i].Inst)) {
+        isPHI = true;
+        ++NumPHI;
+      }
+
+      if (isAddress)
+        HasAddress = true;
+
+      // If this use isn't an address, then not all uses are addresses.
+      if (!isAddress && !isPHI)
+        AllUsesAreAddresses = false;
+
+      MoveImmediateValues(TLI, UsersToProcess[i].Inst, UsersToProcess[i].Base,
+                          UsersToProcess[i].Imm, isAddress, L, SE);
+    }
+  }
+
+  // If one of the uses is a PHI node and all other uses are addresses, still
+  // allow IV reuse. Essentially we are trading one constant multiplication
+  // for one fewer IV.
+  if (NumPHI > 1)
+    AllUsesAreAddresses = false;
+
+  // There are no in-loop address uses.
+  if (AllUsesAreAddresses && (!HasAddress && !AllUsesAreOutsideLoop))
+    AllUsesAreAddresses = false;
+
+  return CommonExprs;
+}
+
+/// ShouldUseFullStrengthReductionMode - Test whether full strength-reduction
+/// is valid and profitable for the given set of users of a stride. In
+/// full strength-reduction mode, all addresses at the current stride are
+/// strength-reduced all the way down to pointer arithmetic.
+///
+bool LoopStrengthReduce::ShouldUseFullStrengthReductionMode(
+                                const std::vector<BasedUser> &UsersToProcess,
+                                const Loop *L,
+                                bool AllUsesAreAddresses,
+                                SCEVHandle Stride) {
+  if (!EnableFullLSRMode)
+    return false;
+
+  // The heuristics below aim to avoid increasing register pressure, but
+  // fully strength-reducing all the addresses increases the number of
+  // add instructions, so don't do this when optimizing for size.
+  // TODO: If the loop is large, the savings due to simpler addresses
+  // may outweigh the costs of the extra increment instructions.
+  if (L->getHeader()->getParent()->hasFnAttr(Attribute::OptimizeForSize))
+    return false;
+
+  // TODO: For now, don't do full strength reduction if there could
+  // potentially be greater-stride multiples of the current stride
+  // which could reuse the current stride IV.
+  if (IU->StrideOrder.back() != Stride)
+    return false;
+
+  // Iterate through the uses to find conditions that automatically rule out
+  // full-lsr mode.
+  for (unsigned i = 0, e = UsersToProcess.size(); i != e; ) {
+    const SCEV *Base = UsersToProcess[i].Base;
+    const SCEV *Imm = UsersToProcess[i].Imm;
+    // If any users have a loop-variant component, they can't be fully
+    // strength-reduced.
+    if (Imm && !Imm->isLoopInvariant(L))
+      return false;
+    // If there are two users with the same base and the difference between
+    // the two Imm values can't be folded into the address, full
+    // strength reduction would increase register pressure.
+    do {
+      const SCEV *CurImm = UsersToProcess[i].Imm;
+      if ((CurImm || Imm) && CurImm != Imm) {
+        if (!CurImm) CurImm = SE->getIntegerSCEV(0, Stride->getType());
+        if (!Imm) Imm = SE->getIntegerSCEV(0, Stride->getType());
+        const Instruction *Inst = UsersToProcess[i].Inst;
+        const Type *AccessTy = getAccessType(Inst);
+        SCEVHandle Diff = SE->getMinusSCEV(UsersToProcess[i].Imm, Imm);
+        if (!Diff->isZero() &&
+            (!AllUsesAreAddresses ||
+             !fitsInAddressMode(Diff, AccessTy, TLI, /*HasBaseReg=*/true)))
+          return false;
+      }
+    } while (++i != e && Base == UsersToProcess[i].Base);
+  }
+
+  // If there's exactly one user in this stride, fully strength-reducing it
+  // won't increase register pressure. If it's starting from a non-zero base,
+  // it'll be simpler this way.
+  if (UsersToProcess.size() == 1 && !UsersToProcess[0].Base->isZero())
+    return true;
+
+  // Otherwise, if there are any users in this stride that don't require
+  // a register for their base, full strength-reduction will increase
+  // register pressure.
+  for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i)
+    if (UsersToProcess[i].Base->isZero())
+      return false;
+
+  // Otherwise, go for it.
+  return true;
+}
+
+/// InsertAffinePhi - Create and insert a PHI node for an induction variable
+/// with the specified start and step values in the specified loop.
+///
+/// If the step is a negative non-constant expression, the increment is
+/// emitted as a subtract of the negated step instead of an add.
+///
+/// Return the created PHI node.
+///
+static PHINode *InsertAffinePhi(SCEVHandle Start, SCEVHandle Step,
+                                Instruction *IVIncInsertPt,
+                                const Loop *L,
+                                SCEVExpander &Rewriter) {
+  assert(Start->isLoopInvariant(L) && "New PHI start is not loop invariant!");
+  assert(Step->isLoopInvariant(L) && "New PHI stride is not loop invariant!");
+
+  BasicBlock *Header = L->getHeader();
+  BasicBlock *Preheader = L->getLoopPreheader();
+  BasicBlock *LatchBlock = L->getLoopLatch();
+  const Type *Ty = Start->getType();
+  Ty = Rewriter.SE.getEffectiveSCEVType(Ty);
+
+  PHINode *PN = PHINode::Create(Ty, "lsr.iv", Header->begin());
+  PN->addIncoming(Rewriter.expandCodeFor(Start, Ty, Preheader->getTerminator()),
+                  Preheader);
+
+  // If the stride is negative, insert a sub instead of an add for the
+  // increment.
+  bool isNegative = isNonConstantNegative(Step);
+  SCEVHandle IncAmount = Step;
+  if (isNegative)
+    IncAmount = Rewriter.SE.getNegativeSCEV(Step);
+
+  // Insert an add instruction right before the terminator corresponding
+  // to the back-edge or just before the only use. The location is determined
+  // by the caller and passed in as IVIncInsertPt.
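In source terms, the PHI-plus-increment shape that InsertAffinePhi finishes assembling just below looks like the following; a hedged C++ rendering with invented names:

    #include <cstdio>

    // iv starts at 'start' (the incoming value from the preheader) and is
    // advanced by 'step' on every back edge (the "lsr.iv.next" add).
    long lastIV(long start, long step, int trips) {
      long iv = start;  // corresponds to the "lsr.iv" PHI in the header
      for (int i = 0; i < trips; ++i)
        iv += step;     // the increment inserted at IVIncInsertPt
      return iv;
    }

    int main() {
      std::printf("%ld\n", lastIV(0, 4, 10));  // prints 40
      return 0;
    }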
+ Value *StepV = Rewriter.expandCodeFor(IncAmount, Ty, + Preheader->getTerminator()); + Instruction *IncV; + if (isNegative) { + IncV = BinaryOperator::CreateSub(PN, StepV, "lsr.iv.next", + IVIncInsertPt); + } else { + IncV = BinaryOperator::CreateAdd(PN, StepV, "lsr.iv.next", + IVIncInsertPt); + } + if (!isa<ConstantInt>(StepV)) ++NumVariable; + + PN->addIncoming(IncV, LatchBlock); + + ++NumInserted; + return PN; +} + +static void SortUsersToProcess(std::vector<BasedUser> &UsersToProcess) { + // We want to emit code for users inside the loop first. To do this, we + // rearrange BasedUser so that the entries at the end have + // isUseOfPostIncrementedValue = false, because we pop off the end of the + // vector (so we handle them first). + std::partition(UsersToProcess.begin(), UsersToProcess.end(), + PartitionByIsUseOfPostIncrementedValue); + + // Sort this by base, so that things with the same base are handled + // together. By partitioning first and stable-sorting later, we are + // guaranteed that within each base we will pop off users from within the + // loop before users outside of the loop with a particular base. + // + // We would like to use stable_sort here, but we can't. The problem is that + // SCEVHandle's don't have a deterministic ordering w.r.t to each other, so + // we don't have anything to do a '<' comparison on. Because we think the + // number of uses is small, do a horrible bubble sort which just relies on + // ==. + for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) { + // Get a base value. + SCEVHandle Base = UsersToProcess[i].Base; + + // Compact everything with this base to be consecutive with this one. + for (unsigned j = i+1; j != e; ++j) { + if (UsersToProcess[j].Base == Base) { + std::swap(UsersToProcess[i+1], UsersToProcess[j]); + ++i; + } + } + } +} + +/// PrepareToStrengthReduceFully - Prepare to fully strength-reduce +/// UsersToProcess, meaning lowering addresses all the way down to direct +/// pointer arithmetic. +/// +void +LoopStrengthReduce::PrepareToStrengthReduceFully( + std::vector<BasedUser> &UsersToProcess, + SCEVHandle Stride, + SCEVHandle CommonExprs, + const Loop *L, + SCEVExpander &PreheaderRewriter) { + DOUT << " Fully reducing all users\n"; + + // Rewrite the UsersToProcess records, creating a separate PHI for each + // unique Base value. + Instruction *IVIncInsertPt = L->getLoopLatch()->getTerminator(); + for (unsigned i = 0, e = UsersToProcess.size(); i != e; ) { + // TODO: The uses are grouped by base, but not sorted. We arbitrarily + // pick the first Imm value here to start with, and adjust it for the + // other uses. + SCEVHandle Imm = UsersToProcess[i].Imm; + SCEVHandle Base = UsersToProcess[i].Base; + SCEVHandle Start = SE->getAddExpr(CommonExprs, Base, Imm); + PHINode *Phi = InsertAffinePhi(Start, Stride, IVIncInsertPt, L, + PreheaderRewriter); + // Loop over all the users with the same base. + do { + UsersToProcess[i].Base = SE->getIntegerSCEV(0, Stride->getType()); + UsersToProcess[i].Imm = SE->getMinusSCEV(UsersToProcess[i].Imm, Imm); + UsersToProcess[i].Phi = Phi; + assert(UsersToProcess[i].Imm->isLoopInvariant(L) && + "ShouldUseFullStrengthReductionMode should reject this!"); + } while (++i != e && Base == UsersToProcess[i].Base); + } +} + +/// FindIVIncInsertPt - Return the location to insert the increment instruction. +/// If the only use if a use of postinc value, (must be the loop termination +/// condition), then insert it just before the use. 
+static Instruction *FindIVIncInsertPt(std::vector<BasedUser> &UsersToProcess, + const Loop *L) { + if (UsersToProcess.size() == 1 && + UsersToProcess[0].isUseOfPostIncrementedValue && + L->contains(UsersToProcess[0].Inst->getParent())) + return UsersToProcess[0].Inst; + return L->getLoopLatch()->getTerminator(); +} + +/// PrepareToStrengthReduceWithNewPhi - Insert a new induction variable for the +/// given users to share. +/// +void +LoopStrengthReduce::PrepareToStrengthReduceWithNewPhi( + std::vector<BasedUser> &UsersToProcess, + SCEVHandle Stride, + SCEVHandle CommonExprs, + Value *CommonBaseV, + Instruction *IVIncInsertPt, + const Loop *L, + SCEVExpander &PreheaderRewriter) { + DOUT << " Inserting new PHI:\n"; + + PHINode *Phi = InsertAffinePhi(SE->getUnknown(CommonBaseV), + Stride, IVIncInsertPt, L, + PreheaderRewriter); + + // Remember this in case a later stride is multiple of this. + IVsByStride[Stride].addIV(Stride, CommonExprs, Phi); + + // All the users will share this new IV. + for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) + UsersToProcess[i].Phi = Phi; + + DOUT << " IV="; + DEBUG(WriteAsOperand(*DOUT, Phi, /*PrintType=*/false)); + DOUT << "\n"; +} + +/// PrepareToStrengthReduceFromSmallerStride - Prepare for the given users to +/// reuse an induction variable with a stride that is a factor of the current +/// induction variable. +/// +void +LoopStrengthReduce::PrepareToStrengthReduceFromSmallerStride( + std::vector<BasedUser> &UsersToProcess, + Value *CommonBaseV, + const IVExpr &ReuseIV, + Instruction *PreInsertPt) { + DOUT << " Rewriting in terms of existing IV of STRIDE " << *ReuseIV.Stride + << " and BASE " << *ReuseIV.Base << "\n"; + + // All the users will share the reused IV. + for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) + UsersToProcess[i].Phi = ReuseIV.PHI; + + Constant *C = dyn_cast<Constant>(CommonBaseV); + if (C && + (!C->isNullValue() && + !fitsInAddressMode(SE->getUnknown(CommonBaseV), CommonBaseV->getType(), + TLI, false))) + // We want the common base emitted into the preheader! This is just + // using cast as a copy so BitCast (no-op cast) is appropriate + CommonBaseV = new BitCastInst(CommonBaseV, CommonBaseV->getType(), + "commonbase", PreInsertPt); +} + +static bool IsImmFoldedIntoAddrMode(GlobalValue *GV, int64_t Offset, + const Type *AccessTy, + std::vector<BasedUser> &UsersToProcess, + const TargetLowering *TLI) { + SmallVector<Instruction*, 16> AddrModeInsts; + for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) { + if (UsersToProcess[i].isUseOfPostIncrementedValue) + continue; + ExtAddrMode AddrMode = + AddressingModeMatcher::Match(UsersToProcess[i].OperandValToReplace, + AccessTy, UsersToProcess[i].Inst, + AddrModeInsts, *TLI); + if (GV && GV != AddrMode.BaseGV) + return false; + if (Offset && !AddrMode.BaseOffs) + // FIXME: How to accurate check it's immediate offset is folded. + return false; + AddrModeInsts.clear(); + } + return true; +} + +/// StrengthReduceStridedIVUsers - Strength reduce all of the users of a single +/// stride of IV. All of the users may have different starting values, and this +/// may not be the only stride. +void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride, + IVUsersOfOneStride &Uses, + Loop *L) { + // If all the users are moved to another stride, then there is nothing to do. + if (Uses.Users.empty()) + return; + + // Keep track if every use in UsersToProcess is an address. 
If they all are,
+ // we may be able to rewrite the entire collection of them in terms of a
+ // smaller-stride IV.
+ bool AllUsesAreAddresses = true;
+
+ // Keep track if every use of a single stride is outside the loop. If so,
+ // we want to be more aggressive about reusing a smaller-stride IV; a
+ // multiply outside the loop is better than another IV inside. Well, usually.
+ bool AllUsesAreOutsideLoop = true;
+
+ // Transform our list of users and offsets to a bit more complex table. In
+ // this new vector, each 'BasedUser' contains 'Base', the base of the
+ // strided access, as well as the old information from Uses. We progressively
+ // move information from the Base field to the Imm field, until we eventually
+ // have the full access expression to rewrite the use.
+ std::vector<BasedUser> UsersToProcess;
+ SCEVHandle CommonExprs = CollectIVUsers(Stride, Uses, L, AllUsesAreAddresses,
+ AllUsesAreOutsideLoop,
+ UsersToProcess);
+
+ // Sort the UsersToProcess array so that users with common bases are
+ // next to each other.
+ SortUsersToProcess(UsersToProcess);
+
+ // If we managed to find some expressions in common, we'll need to carry
+ // their value in a register and add it in for each use. This will take up
+ // a register operand, which potentially restricts what stride values are
+ // valid.
+ bool HaveCommonExprs = !CommonExprs->isZero();
+ const Type *ReplacedTy = CommonExprs->getType();
+
+ // If all uses are addresses, consider sinking the immediate part of the
+ // common expression back into uses if they can fit in the immediate fields.
+ if (TLI && HaveCommonExprs && AllUsesAreAddresses) {
+ SCEVHandle NewCommon = CommonExprs;
+ SCEVHandle Imm = SE->getIntegerSCEV(0, ReplacedTy);
+ MoveImmediateValues(TLI, Type::VoidTy, NewCommon, Imm, true, L, SE);
+ if (!Imm->isZero()) {
+ bool DoSink = true;
+
+ // If the immediate part of the common expression is a GV, check if it's
+ // possible to fold it into the target addressing mode.
+ GlobalValue *GV = 0;
+ if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(Imm))
+ GV = dyn_cast<GlobalValue>(SU->getValue());
+ int64_t Offset = 0;
+ if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Imm))
+ Offset = SC->getValue()->getSExtValue();
+ if (GV || Offset)
+ // Pass VoidTy as the AccessTy to be conservative, because
+ // there could be multiple access types among all the uses.
+ DoSink = IsImmFoldedIntoAddrMode(GV, Offset, Type::VoidTy,
+ UsersToProcess, TLI);
+
+ if (DoSink) {
+ DOUT << " Sinking " << *Imm << " back down into uses\n";
+ for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i)
+ UsersToProcess[i].Imm = SE->getAddExpr(UsersToProcess[i].Imm, Imm);
+ CommonExprs = NewCommon;
+ HaveCommonExprs = !CommonExprs->isZero();
+ ++NumImmSunk;
+ }
+ }
+ }
+
+ // Now that we know what we need to do, insert the PHI node itself.
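+ // Three strategies are tried below, in order (a rough sketch of the
+ // decision): (1) fully strength-reduce every use down to pointer
+ // arithmetic; failing that, (2) reuse an existing IV whose stride divides
+ // this one, scaled via the addressing mode; failing that, (3) insert a
+ // fresh PHI for this stride.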
+ // + DOUT << "LSR: Examining IVs of TYPE " << *ReplacedTy << " of STRIDE " + << *Stride << ":\n" + << " Common base: " << *CommonExprs << "\n"; + + SCEVExpander Rewriter(*SE); + SCEVExpander PreheaderRewriter(*SE); + + BasicBlock *Preheader = L->getLoopPreheader(); + Instruction *PreInsertPt = Preheader->getTerminator(); + BasicBlock *LatchBlock = L->getLoopLatch(); + Instruction *IVIncInsertPt = LatchBlock->getTerminator(); + + Value *CommonBaseV = Constant::getNullValue(ReplacedTy); + + SCEVHandle RewriteFactor = SE->getIntegerSCEV(0, ReplacedTy); + IVExpr ReuseIV(SE->getIntegerSCEV(0, Type::Int32Ty), + SE->getIntegerSCEV(0, Type::Int32Ty), + 0); + + /// Choose a strength-reduction strategy and prepare for it by creating + /// the necessary PHIs and adjusting the bookkeeping. + if (ShouldUseFullStrengthReductionMode(UsersToProcess, L, + AllUsesAreAddresses, Stride)) { + PrepareToStrengthReduceFully(UsersToProcess, Stride, CommonExprs, L, + PreheaderRewriter); + } else { + // Emit the initial base value into the loop preheader. + CommonBaseV = PreheaderRewriter.expandCodeFor(CommonExprs, ReplacedTy, + PreInsertPt); + + // If all uses are addresses, check if it is possible to reuse an IV. The + // new IV must have a stride that is a multiple of the old stride; the + // multiple must be a number that can be encoded in the scale field of the + // target addressing mode; and we must have a valid instruction after this + // substitution, including the immediate field, if any. + RewriteFactor = CheckForIVReuse(HaveCommonExprs, AllUsesAreAddresses, + AllUsesAreOutsideLoop, + Stride, ReuseIV, ReplacedTy, + UsersToProcess); + if (!RewriteFactor->isZero()) + PrepareToStrengthReduceFromSmallerStride(UsersToProcess, CommonBaseV, + ReuseIV, PreInsertPt); + else { + IVIncInsertPt = FindIVIncInsertPt(UsersToProcess, L); + PrepareToStrengthReduceWithNewPhi(UsersToProcess, Stride, CommonExprs, + CommonBaseV, IVIncInsertPt, + L, PreheaderRewriter); + } + } + + // Process all the users now, replacing their strided uses with + // strength-reduced forms. This outer loop handles all bases, the inner + // loop handles all users of a particular base. + while (!UsersToProcess.empty()) { + SCEVHandle Base = UsersToProcess.back().Base; + Instruction *Inst = UsersToProcess.back().Inst; + + // Emit the code for Base into the preheader. + Value *BaseV = 0; + if (!Base->isZero()) { + BaseV = PreheaderRewriter.expandCodeFor(Base, 0, PreInsertPt); + + DOUT << " INSERTING code for BASE = " << *Base << ":"; + if (BaseV->hasName()) + DOUT << " Result value name = %" << BaseV->getNameStr(); + DOUT << "\n"; + + // If BaseV is a non-zero constant, make sure that it gets inserted into + // the preheader, instead of being forward substituted into the uses. We + // do this by forcing a BitCast (noop cast) to be inserted into the + // preheader in this case. + if (!fitsInAddressMode(Base, getAccessType(Inst), TLI, false)) { + // We want this constant emitted into the preheader! This is just + // using cast as a copy so BitCast (no-op cast) is appropriate + BaseV = new BitCastInst(BaseV, BaseV->getType(), "preheaderinsert", + PreInsertPt); + } + } + + // Emit the code to add the immediate offset to the Phi value, just before + // the instructions that we identified as using this stride and base. + do { + // FIXME: Use emitted users to emit other users. 
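+ // Rough sketch of the rewrite performed below: with the shared-PHI
+ // strategy, a use of CommonExprs + Base + Imm + i*Stride is rebuilt in
+ // terms of %lsr.iv (which starts at CommonExprs and steps by Stride),
+ // plus the per-base value BaseV and the per-use immediate.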
+ BasedUser &User = UsersToProcess.back(); + + DOUT << " Examining "; + if (User.isUseOfPostIncrementedValue) + DOUT << "postinc"; + else + DOUT << "preinc"; + DOUT << " use "; + DEBUG(WriteAsOperand(*DOUT, UsersToProcess.back().OperandValToReplace, + /*PrintType=*/false)); + DOUT << " in Inst: " << *(User.Inst); + + // If this instruction wants to use the post-incremented value, move it + // after the post-inc and use its value instead of the PHI. + Value *RewriteOp = User.Phi; + if (User.isUseOfPostIncrementedValue) { + RewriteOp = User.Phi->getIncomingValueForBlock(LatchBlock); + // If this user is in the loop, make sure it is the last thing in the + // loop to ensure it is dominated by the increment. In case it's the + // only use of the iv, the increment instruction is already before the + // use. + if (L->contains(User.Inst->getParent()) && User.Inst != IVIncInsertPt) + User.Inst->moveBefore(IVIncInsertPt); + } + + SCEVHandle RewriteExpr = SE->getUnknown(RewriteOp); + + if (SE->getEffectiveSCEVType(RewriteOp->getType()) != + SE->getEffectiveSCEVType(ReplacedTy)) { + assert(SE->getTypeSizeInBits(RewriteOp->getType()) > + SE->getTypeSizeInBits(ReplacedTy) && + "Unexpected widening cast!"); + RewriteExpr = SE->getTruncateExpr(RewriteExpr, ReplacedTy); + } + + // If we had to insert new instructions for RewriteOp, we have to + // consider that they may not have been able to end up immediately + // next to RewriteOp, because non-PHI instructions may never precede + // PHI instructions in a block. In this case, remember where the last + // instruction was inserted so that if we're replacing a different + // PHI node, we can use the later point to expand the final + // RewriteExpr. + Instruction *NewBasePt = dyn_cast<Instruction>(RewriteOp); + if (RewriteOp == User.Phi) NewBasePt = 0; + + // Clear the SCEVExpander's expression map so that we are guaranteed + // to have the code emitted where we expect it. + Rewriter.clear(); + + // If we are reusing the iv, then it must be multiplied by a constant + // factor to take advantage of the addressing mode scale component. + if (!RewriteFactor->isZero()) { + // If we're reusing an IV with a nonzero base (currently this happens + // only when all reuses are outside the loop) subtract that base here. + // The base has been used to initialize the PHI node but we don't want + // it here. + if (!ReuseIV.Base->isZero()) { + SCEVHandle typedBase = ReuseIV.Base; + if (SE->getEffectiveSCEVType(RewriteExpr->getType()) != + SE->getEffectiveSCEVType(ReuseIV.Base->getType())) { + // It's possible the original IV is a larger type than the new IV, + // in which case we have to truncate the Base. We checked in + // RequiresTypeConversion that this is valid. + assert(SE->getTypeSizeInBits(RewriteExpr->getType()) < + SE->getTypeSizeInBits(ReuseIV.Base->getType()) && + "Unexpected lengthening conversion!"); + typedBase = SE->getTruncateExpr(ReuseIV.Base, + RewriteExpr->getType()); + } + RewriteExpr = SE->getMinusSCEV(RewriteExpr, typedBase); + } + + // Multiply old variable, with base removed, by new scale factor. + RewriteExpr = SE->getMulExpr(RewriteFactor, + RewriteExpr); + + // The common base is emitted in the loop preheader. But since we + // are reusing an IV, it has not been used to initialize the PHI node. + // Add it to the expression used to rewrite the uses. + // When this use is outside the loop, we earlier subtracted the + // common base, and are adding it back here. 
Use the same expression + // as before, rather than CommonBaseV, so DAGCombiner will zap it. + if (!CommonExprs->isZero()) { + if (L->contains(User.Inst->getParent())) + RewriteExpr = SE->getAddExpr(RewriteExpr, + SE->getUnknown(CommonBaseV)); + else + RewriteExpr = SE->getAddExpr(RewriteExpr, CommonExprs); + } + } + + // Now that we know what we need to do, insert code before User for the + // immediate and any loop-variant expressions. + if (BaseV) + // Add BaseV to the PHI value if needed. + RewriteExpr = SE->getAddExpr(RewriteExpr, SE->getUnknown(BaseV)); + + User.RewriteInstructionToUseNewBase(RewriteExpr, NewBasePt, + Rewriter, L, this, *LI, + DeadInsts); + + // Mark old value we replaced as possibly dead, so that it is eliminated + // if we just replaced the last use of that value. + DeadInsts.push_back(User.OperandValToReplace); + + UsersToProcess.pop_back(); + ++NumReduced; + + // If there are any more users to process with the same base, process them + // now. We sorted by base above, so we just have to check the last elt. + } while (!UsersToProcess.empty() && UsersToProcess.back().Base == Base); + // TODO: Next, find out which base index is the most common, pull it out. + } + + // IMPORTANT TODO: Figure out how to partition the IV's with this stride, but + // different starting values, into different PHIs. +} + +/// FindIVUserForCond - If Cond has an operand that is an expression of an IV, +/// set the IV user and stride information and return true, otherwise return +/// false. +bool LoopStrengthReduce::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse, + const SCEVHandle *&CondStride) { + for (unsigned Stride = 0, e = IU->StrideOrder.size(); + Stride != e && !CondUse; ++Stride) { + std::map<SCEVHandle, IVUsersOfOneStride *>::iterator SI = + IU->IVUsesByStride.find(IU->StrideOrder[Stride]); + assert(SI != IU->IVUsesByStride.end() && "Stride doesn't exist!"); + + for (ilist<IVStrideUse>::iterator UI = SI->second->Users.begin(), + E = SI->second->Users.end(); UI != E; ++UI) + if (UI->getUser() == Cond) { + // NOTE: we could handle setcc instructions with multiple uses here, but + // InstCombine does it as well for simple uses, it's not clear that it + // occurs enough in real life to handle. + CondUse = UI; + CondStride = &SI->first; + return true; + } + } + return false; +} + +namespace { + // Constant strides come first which in turns are sorted by their absolute + // values. If absolute values are the same, then positive strides comes first. + // e.g. + // 4, -1, X, 1, 2 ==> 1, -1, 2, 4, X + struct StrideCompare { + const ScalarEvolution *SE; + explicit StrideCompare(const ScalarEvolution *se) : SE(se) {} + + bool operator()(const SCEVHandle &LHS, const SCEVHandle &RHS) { + const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS); + const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS); + if (LHSC && RHSC) { + int64_t LV = LHSC->getValue()->getSExtValue(); + int64_t RV = RHSC->getValue()->getSExtValue(); + uint64_t ALV = (LV < 0) ? -LV : LV; + uint64_t ARV = (RV < 0) ? -RV : RV; + if (ALV == ARV) { + if (LV != RV) + return LV > RV; + } else { + return ALV < ARV; + } + + // If it's the same value but different type, sort by bit width so + // that we emit larger induction variables before smaller + // ones, letting the smaller be re-written in terms of larger ones. 
+ return SE->getTypeSizeInBits(RHS->getType()) <
+ SE->getTypeSizeInBits(LHS->getType());
+ }
+ return LHSC && !RHSC;
+ }
+ };
+}
+
+/// ChangeCompareStride - If a loop termination compare instruction is the
+/// only use of its stride, and the comparison is against a constant value,
+/// try to eliminate the stride by moving the compare instruction to another
+/// stride and changing its constant operand accordingly. e.g.
+///
+/// loop:
+/// ...
+/// v1 = v1 + 3
+/// v2 = v2 + 1
+/// if (v2 < 10) goto loop
+/// =>
+/// loop:
+/// ...
+/// v1 = v1 + 3
+/// if (v1 < 30) goto loop
+ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,
+ IVStrideUse* &CondUse,
+ const SCEVHandle* &CondStride) {
+ // If there's only one stride in the loop, there's nothing to do here.
+ if (IU->StrideOrder.size() < 2)
+ return Cond;
+ // If there are other users of the condition's stride, don't bother
+ // trying to change the condition because the stride will still
+ // remain.
+ std::map<SCEVHandle, IVUsersOfOneStride *>::iterator I =
+ IU->IVUsesByStride.find(*CondStride);
+ if (I == IU->IVUsesByStride.end() ||
+ I->second->Users.size() != 1)
+ return Cond;
+ // Only handle constant strides for now.
+ const SCEVConstant *SC = dyn_cast<SCEVConstant>(*CondStride);
+ if (!SC) return Cond;
+
+ ICmpInst::Predicate Predicate = Cond->getPredicate();
+ int64_t CmpSSInt = SC->getValue()->getSExtValue();
+ unsigned BitWidth = SE->getTypeSizeInBits((*CondStride)->getType());
+ uint64_t SignBit = 1ULL << (BitWidth-1);
+ const Type *CmpTy = Cond->getOperand(0)->getType();
+ const Type *NewCmpTy = NULL;
+ unsigned TyBits = SE->getTypeSizeInBits(CmpTy);
+ unsigned NewTyBits = 0;
+ SCEVHandle *NewStride = NULL;
+ Value *NewCmpLHS = NULL;
+ Value *NewCmpRHS = NULL;
+ int64_t Scale = 1;
+ SCEVHandle NewOffset = SE->getIntegerSCEV(0, CmpTy);
+
+ if (ConstantInt *C = dyn_cast<ConstantInt>(Cond->getOperand(1))) {
+ int64_t CmpVal = C->getValue().getSExtValue();
+
+ // Check the stride constant and the comparison constant signs to detect
+ // overflow.
+ if ((CmpVal & SignBit) != (CmpSSInt & SignBit))
+ return Cond;
+
+ // Look for a suitable stride / iv as replacement.
+ for (unsigned i = 0, e = IU->StrideOrder.size(); i != e; ++i) {
+ std::map<SCEVHandle, IVUsersOfOneStride *>::iterator SI =
+ IU->IVUsesByStride.find(IU->StrideOrder[i]);
+ if (!isa<SCEVConstant>(SI->first))
+ continue;
+ int64_t SSInt = cast<SCEVConstant>(SI->first)->getValue()->getSExtValue();
+ if (SSInt == CmpSSInt ||
+ abs64(SSInt) < abs64(CmpSSInt) ||
+ (SSInt % CmpSSInt) != 0)
+ continue;
+
+ Scale = SSInt / CmpSSInt;
+ int64_t NewCmpVal = CmpVal * Scale;
+ APInt Mul = APInt(BitWidth*2, CmpVal, true);
+ Mul = Mul * APInt(BitWidth*2, Scale, true);
+ // Check for overflow.
+ if (!Mul.isSignedIntN(BitWidth))
+ continue;
+ // Check for overflow in the stride's type too.
+ if (!Mul.isSignedIntN(SE->getTypeSizeInBits(SI->first->getType())))
+ continue;
+
+ // Watch out for overflow.
+ if (ICmpInst::isSignedPredicate(Predicate) &&
+ (CmpVal & SignBit) != (NewCmpVal & SignBit))
+ continue;
+
+ if (NewCmpVal == CmpVal)
+ continue;
+ // Pick the best iv to use, trying to avoid a cast.
+ NewCmpLHS = NULL;
+ for (ilist<IVStrideUse>::iterator UI = SI->second->Users.begin(),
+ E = SI->second->Users.end(); UI != E; ++UI) {
+ Value *Op = UI->getOperandValToReplace();
+
+ // If the IVStrideUse implies a cast, check for an actual cast which
+ // can be used to find the original IV expression.
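+ // For example (a sketch): if the compare operates on i32 via
+ // %t = trunc i64 %iv to i32 while the candidate stride is i64, looking
+ // through the trunc recovers %iv, letting the compare be rewritten
+ // against the wider IV directly.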
+ if (SE->getEffectiveSCEVType(Op->getType()) != + SE->getEffectiveSCEVType(SI->first->getType())) { + CastInst *CI = dyn_cast<CastInst>(Op); + // If it's not a simple cast, it's complicated. + if (!CI) + continue; + // If it's a cast from a type other than the stride type, + // it's complicated. + if (CI->getOperand(0)->getType() != SI->first->getType()) + continue; + // Ok, we found the IV expression in the stride's type. + Op = CI->getOperand(0); + } + + NewCmpLHS = Op; + if (NewCmpLHS->getType() == CmpTy) + break; + } + if (!NewCmpLHS) + continue; + + NewCmpTy = NewCmpLHS->getType(); + NewTyBits = SE->getTypeSizeInBits(NewCmpTy); + const Type *NewCmpIntTy = IntegerType::get(NewTyBits); + if (RequiresTypeConversion(NewCmpTy, CmpTy)) { + // Check if it is possible to rewrite it using + // an iv / stride of a smaller integer type. + unsigned Bits = NewTyBits; + if (ICmpInst::isSignedPredicate(Predicate)) + --Bits; + uint64_t Mask = (1ULL << Bits) - 1; + if (((uint64_t)NewCmpVal & Mask) != (uint64_t)NewCmpVal) + continue; + } + + // Don't rewrite if use offset is non-constant and the new type is + // of a different type. + // FIXME: too conservative? + if (NewTyBits != TyBits && !isa<SCEVConstant>(CondUse->getOffset())) + continue; + + bool AllUsesAreAddresses = true; + bool AllUsesAreOutsideLoop = true; + std::vector<BasedUser> UsersToProcess; + SCEVHandle CommonExprs = CollectIVUsers(SI->first, *SI->second, L, + AllUsesAreAddresses, + AllUsesAreOutsideLoop, + UsersToProcess); + // Avoid rewriting the compare instruction with an iv of new stride + // if it's likely the new stride uses will be rewritten using the + // stride of the compare instruction. + if (AllUsesAreAddresses && + ValidScale(!CommonExprs->isZero(), Scale, UsersToProcess)) + continue; + + // Avoid rewriting the compare instruction with an iv which has + // implicit extension or truncation built into it. + // TODO: This is over-conservative. + if (SE->getTypeSizeInBits(CondUse->getOffset()->getType()) != TyBits) + continue; + + // If scale is negative, use swapped predicate unless it's testing + // for equality. + if (Scale < 0 && !Cond->isEquality()) + Predicate = ICmpInst::getSwappedPredicate(Predicate); + + NewStride = &IU->StrideOrder[i]; + if (!isa<PointerType>(NewCmpTy)) + NewCmpRHS = ConstantInt::get(NewCmpTy, NewCmpVal); + else { + ConstantInt *CI = ConstantInt::get(NewCmpIntTy, NewCmpVal); + NewCmpRHS = ConstantExpr::getIntToPtr(CI, NewCmpTy); + } + NewOffset = TyBits == NewTyBits + ? SE->getMulExpr(CondUse->getOffset(), + SE->getConstant(ConstantInt::get(CmpTy, Scale))) + : SE->getConstant(ConstantInt::get(NewCmpIntTy, + cast<SCEVConstant>(CondUse->getOffset())->getValue() + ->getSExtValue()*Scale)); + break; + } + } + + // Forgo this transformation if it the increment happens to be + // unfortunately positioned after the condition, and the condition + // has multiple uses which prevent it from being moved immediately + // before the branch. See + // test/Transforms/LoopStrengthReduce/change-compare-stride-trickiness-*.ll + // for an example of this situation. + if (!Cond->hasOneUse()) { + for (BasicBlock::iterator I = Cond, E = Cond->getParent()->end(); + I != E; ++I) + if (I == NewCmpLHS) + return Cond; + } + + if (NewCmpRHS) { + // Create a new compare instruction using new stride / iv. + ICmpInst *OldCond = Cond; + // Insert new compare instruction. + Cond = new ICmpInst(Predicate, NewCmpLHS, NewCmpRHS, + L->getHeader()->getName() + ".termcond", + OldCond); + + // Remove the old compare instruction. 
The old indvar is probably dead too.
+ DeadInsts.push_back(CondUse->getOperandValToReplace());
+ OldCond->replaceAllUsesWith(Cond);
+ OldCond->eraseFromParent();
+
+ IU->IVUsesByStride[*NewStride]->addUser(NewOffset, Cond, NewCmpLHS, false);
+ CondUse = &IU->IVUsesByStride[*NewStride]->Users.back();
+ CondStride = NewStride;
+ ++NumEliminated;
+ Changed = true;
+ }
+
+ return Cond;
+}
+
+/// OptimizeSMax - Rewrite the loop's terminating condition if it uses
+/// an smax computation.
+///
+/// This is a narrow solution to a specific, but acute, problem. For loops
+/// like this:
+///
+/// i = 0;
+/// do {
+/// p[i] = 0.0;
+/// } while (++i < n);
+///
+/// where the comparison is signed, the trip count isn't just 'n', because
+/// 'n' could be negative. And unfortunately this can come up even for loops
+/// where the user didn't use a C do-while loop. For example, seemingly
+/// well-behaved top-test loops will commonly be lowered like this:
+///
+/// if (n > 0) {
+/// i = 0;
+/// do {
+/// p[i] = 0.0;
+/// } while (++i < n);
+/// }
+///
+/// and then it's possible for subsequent optimization to obscure the if
+/// test in such a way that indvars can't find it.
+///
+/// When indvars can't find the if test in loops like this, it creates a
+/// signed-max expression, which allows it to give the loop a canonical
+/// induction variable:
+///
+/// i = 0;
+/// smax = n < 1 ? 1 : n;
+/// do {
+/// p[i] = 0.0;
+/// } while (++i != smax);
+///
+/// Canonical induction variables are necessary because the loop passes
+/// are designed around them. The most obvious example of this is the
+/// LoopInfo analysis, which doesn't remember trip count values. It
+/// expects to be able to rediscover the trip count each time it is
+/// needed, and it does this using a simple analysis that only succeeds if
+/// the loop has a canonical induction variable.
+///
+/// However, when it comes time to generate code, the maximum operation
+/// can be quite costly, especially if it's inside of an outer loop.
+///
+/// This function solves this problem by detecting loops of this form and
+/// rewriting their conditions from ICMP_NE back to ICMP_SLT, and deleting
+/// the instructions for the maximum computation.
+///
+ICmpInst *LoopStrengthReduce::OptimizeSMax(Loop *L, ICmpInst *Cond,
+ IVStrideUse* &CondUse) {
+ // Check that the loop matches the pattern we're looking for.
+ if (Cond->getPredicate() != CmpInst::ICMP_EQ &&
+ Cond->getPredicate() != CmpInst::ICMP_NE)
+ return Cond;
+
+ SelectInst *Sel = dyn_cast<SelectInst>(Cond->getOperand(1));
+ if (!Sel || !Sel->hasOneUse()) return Cond;
+
+ SCEVHandle BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+ if (isa<SCEVCouldNotCompute>(BackedgeTakenCount))
+ return Cond;
+ SCEVHandle One = SE->getIntegerSCEV(1, BackedgeTakenCount->getType());
+
+ // Add one to the backedge-taken count to get the trip count.
+ SCEVHandle IterationCount = SE->getAddExpr(BackedgeTakenCount, One);
+
+ // Check for a max calculation that matches the pattern.
+ const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(IterationCount);
+ if (!SMax || SMax != SE->getSCEV(Sel)) return Cond;
+
+ SCEVHandle SMaxLHS = SMax->getOperand(0);
+ SCEVHandle SMaxRHS = SMax->getOperand(1);
+ if (!SMaxLHS || SMaxLHS != One) return Cond;
+
+ // Check the relevant induction variable for conformance to
+ // the pattern.
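+ // Concretely (a sketch of the check below): the IV must be the affine
+ // addrec {1,+,1}, i.e. it starts at 1 and steps by 1, matching the
+ // smax(1, n) trip count form shown above.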
+ SCEVHandle IV = SE->getSCEV(Cond->getOperand(0));
+ const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV);
+ if (!AR || !AR->isAffine() ||
+ AR->getStart() != One ||
+ AR->getStepRecurrence(*SE) != One)
+ return Cond;
+
+ assert(AR->getLoop() == L &&
+ "Loop condition operand is an addrec in a different loop!");
+
+ // Check the right operand of the select, and remember it, as it will
+ // be used in the new comparison instruction.
+ Value *NewRHS = 0;
+ if (SE->getSCEV(Sel->getOperand(1)) == SMaxRHS)
+ NewRHS = Sel->getOperand(1);
+ else if (SE->getSCEV(Sel->getOperand(2)) == SMaxRHS)
+ NewRHS = Sel->getOperand(2);
+ if (!NewRHS) return Cond;
+
+ // Ok, everything looks ok to change the condition into an SLT or SGE and
+ // delete the max calculation.
+ ICmpInst *NewCond =
+ new ICmpInst(Cond->getPredicate() == CmpInst::ICMP_NE ?
+ CmpInst::ICMP_SLT :
+ CmpInst::ICMP_SGE,
+ Cond->getOperand(0), NewRHS, "scmp", Cond);
+
+ // Delete the max calculation instructions.
+ Cond->replaceAllUsesWith(NewCond);
+ CondUse->setUser(NewCond);
+ Instruction *Cmp = cast<Instruction>(Sel->getOperand(0));
+ Cond->eraseFromParent();
+ Sel->eraseFromParent();
+ if (Cmp->use_empty())
+ Cmp->eraseFromParent();
+ return NewCond;
+}
+
+/// OptimizeShadowIV - If IV is used in an int-to-float cast
+/// inside the loop, then try to eliminate the cast operation.
+void LoopStrengthReduce::OptimizeShadowIV(Loop *L) {
+
+ SCEVHandle BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+ if (isa<SCEVCouldNotCompute>(BackedgeTakenCount))
+ return;
+
+ for (unsigned Stride = 0, e = IU->StrideOrder.size(); Stride != e;
+ ++Stride) {
+ std::map<SCEVHandle, IVUsersOfOneStride *>::iterator SI =
+ IU->IVUsesByStride.find(IU->StrideOrder[Stride]);
+ assert(SI != IU->IVUsesByStride.end() && "Stride doesn't exist!");
+ if (!isa<SCEVConstant>(SI->first))
+ continue;
+
+ for (ilist<IVStrideUse>::iterator UI = SI->second->Users.begin(),
+ E = SI->second->Users.end(); UI != E; /* empty */) {
+ ilist<IVStrideUse>::iterator CandidateUI = UI;
+ ++UI;
+ Instruction *ShadowUse = CandidateUI->getUser();
+ const Type *DestTy = NULL;
+
+ /* If the shadow use is an int->float cast then insert a second IV
+ to eliminate this cast.
+
+ for (unsigned i = 0; i < n; ++i)
+ foo((double)i);
+
+ is transformed into
+
+ double d = 0.0;
+ for (unsigned i = 0; i < n; ++i, ++d)
+ foo(d);
+ */
+ if (UIToFPInst *UCast = dyn_cast<UIToFPInst>(CandidateUI->getUser()))
+ DestTy = UCast->getDestTy();
+ else if (SIToFPInst *SCast = dyn_cast<SIToFPInst>(CandidateUI->getUser()))
+ DestTy = SCast->getDestTy();
+ if (!DestTy) continue;
+
+ if (TLI) {
+ // If target does not support DestTy natively then do not apply
+ // this transformation.
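+ // For instance (hypothetical target): if DestTy is double but the
+ // target only has single-precision FP hardware, the shadow FP IV would
+ // be legalized into libcalls and cost more than the cast it removes.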
+ MVT DVT = TLI->getValueType(DestTy); + if (!TLI->isTypeLegal(DVT)) continue; + } + + PHINode *PH = dyn_cast<PHINode>(ShadowUse->getOperand(0)); + if (!PH) continue; + if (PH->getNumIncomingValues() != 2) continue; + + const Type *SrcTy = PH->getType(); + int Mantissa = DestTy->getFPMantissaWidth(); + if (Mantissa == -1) continue; + if ((int)SE->getTypeSizeInBits(SrcTy) > Mantissa) + continue; + + unsigned Entry, Latch; + if (PH->getIncomingBlock(0) == L->getLoopPreheader()) { + Entry = 0; + Latch = 1; + } else { + Entry = 1; + Latch = 0; + } + + ConstantInt *Init = dyn_cast<ConstantInt>(PH->getIncomingValue(Entry)); + if (!Init) continue; + ConstantFP *NewInit = ConstantFP::get(DestTy, Init->getZExtValue()); + + BinaryOperator *Incr = + dyn_cast<BinaryOperator>(PH->getIncomingValue(Latch)); + if (!Incr) continue; + if (Incr->getOpcode() != Instruction::Add + && Incr->getOpcode() != Instruction::Sub) + continue; + + /* Initialize new IV, double d = 0.0 in above example. */ + ConstantInt *C = NULL; + if (Incr->getOperand(0) == PH) + C = dyn_cast<ConstantInt>(Incr->getOperand(1)); + else if (Incr->getOperand(1) == PH) + C = dyn_cast<ConstantInt>(Incr->getOperand(0)); + else + continue; + + if (!C) continue; + + /* Add new PHINode. */ + PHINode *NewPH = PHINode::Create(DestTy, "IV.S.", PH); + + /* create new increment. '++d' in above example. */ + ConstantFP *CFP = ConstantFP::get(DestTy, C->getZExtValue()); + BinaryOperator *NewIncr = + BinaryOperator::Create(Incr->getOpcode(), + NewPH, CFP, "IV.S.next.", Incr); + + NewPH->addIncoming(NewInit, PH->getIncomingBlock(Entry)); + NewPH->addIncoming(NewIncr, PH->getIncomingBlock(Latch)); + + /* Remove cast operation */ + ShadowUse->replaceAllUsesWith(NewPH); + ShadowUse->eraseFromParent(); + NumShadow++; + break; + } + } +} + +// OptimizeIndvars - Now that IVUsesByStride is set up with all of the indvar +// uses in the loop, look to see if we can eliminate some, in favor of using +// common indvars for the different uses. +void LoopStrengthReduce::OptimizeIndvars(Loop *L) { + // TODO: implement optzns here. + + OptimizeShadowIV(L); +} + +/// OptimizeLoopTermCond - Change loop terminating condition to use the +/// postinc iv when possible. +void LoopStrengthReduce::OptimizeLoopTermCond(Loop *L) { + // Finally, get the terminating condition for the loop if possible. If we + // can, we want to change it to use a post-incremented version of its + // induction variable, to allow coalescing the live ranges for the IV into + // one register value. + BasicBlock *LatchBlock = L->getLoopLatch(); + BasicBlock *ExitBlock = L->getExitingBlock(); + if (!ExitBlock) + // Multiple exits, just look at the exit in the latch block if there is one. + ExitBlock = LatchBlock; + BranchInst *TermBr = dyn_cast<BranchInst>(ExitBlock->getTerminator()); + if (!TermBr) + return; + if (TermBr->isUnconditional() || !isa<ICmpInst>(TermBr->getCondition())) + return; + + // Search IVUsesByStride to find Cond's IVUse if there is one. + IVStrideUse *CondUse = 0; + const SCEVHandle *CondStride = 0; + ICmpInst *Cond = cast<ICmpInst>(TermBr->getCondition()); + if (!FindIVUserForCond(Cond, CondUse, CondStride)) + return; // setcc doesn't use the IV. + + if (ExitBlock != LatchBlock) { + if (!Cond->hasOneUse()) + // See below, we don't want the condition to be cloned. + return; + + // If exiting block is the latch block, we know it's safe and profitable to + // transform the icmp to use post-inc iv. 
Otherwise do so only if it would + // not reuse another iv and its iv would be reused by other uses. We are + // optimizing for the case where the icmp is the only use of the iv. + IVUsersOfOneStride &StrideUses = *IU->IVUsesByStride[*CondStride]; + for (ilist<IVStrideUse>::iterator I = StrideUses.Users.begin(), + E = StrideUses.Users.end(); I != E; ++I) { + if (I->getUser() == Cond) + continue; + if (!I->isUseOfPostIncrementedValue()) + return; + } + + // FIXME: This is expensive, and worse still ChangeCompareStride does a + // similar check. Can we perform all the icmp related transformations after + // StrengthReduceStridedIVUsers? + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(*CondStride)) { + int64_t SInt = SC->getValue()->getSExtValue(); + for (unsigned NewStride = 0, ee = IU->StrideOrder.size(); NewStride != ee; + ++NewStride) { + std::map<SCEVHandle, IVUsersOfOneStride *>::iterator SI = + IU->IVUsesByStride.find(IU->StrideOrder[NewStride]); + if (!isa<SCEVConstant>(SI->first) || SI->first == *CondStride) + continue; + int64_t SSInt = + cast<SCEVConstant>(SI->first)->getValue()->getSExtValue(); + if (SSInt == SInt) + return; // This can definitely be reused. + if (unsigned(abs64(SSInt)) < SInt || (SSInt % SInt) != 0) + continue; + int64_t Scale = SSInt / SInt; + bool AllUsesAreAddresses = true; + bool AllUsesAreOutsideLoop = true; + std::vector<BasedUser> UsersToProcess; + SCEVHandle CommonExprs = CollectIVUsers(SI->first, *SI->second, L, + AllUsesAreAddresses, + AllUsesAreOutsideLoop, + UsersToProcess); + // Avoid rewriting the compare instruction with an iv of new stride + // if it's likely the new stride uses will be rewritten using the + // stride of the compare instruction. + if (AllUsesAreAddresses && + ValidScale(!CommonExprs->isZero(), Scale, UsersToProcess)) + return; + } + } + + StrideNoReuse.insert(*CondStride); + } + + // If the trip count is computed in terms of an smax (due to ScalarEvolution + // being unable to find a sufficient guard, for example), change the loop + // comparison to use SLT instead of NE. + Cond = OptimizeSMax(L, Cond, CondUse); + + // If possible, change stride and operands of the compare instruction to + // eliminate one stride. + if (ExitBlock == LatchBlock) + Cond = ChangeCompareStride(L, Cond, CondUse, CondStride); + + // It's possible for the setcc instruction to be anywhere in the loop, and + // possible for it to have multiple users. If it is not immediately before + // the latch block branch, move it. + if (&*++BasicBlock::iterator(Cond) != (Instruction*)TermBr) { + if (Cond->hasOneUse()) { // Condition has a single use, just move it. + Cond->moveBefore(TermBr); + } else { + // Otherwise, clone the terminating condition and insert into the loopend. + Cond = cast<ICmpInst>(Cond->clone()); + Cond->setName(L->getHeader()->getName() + ".termcond"); + LatchBlock->getInstList().insert(TermBr, Cond); + + // Clone the IVUse, as the old use still exists! + IU->IVUsesByStride[*CondStride]->addUser(CondUse->getOffset(), Cond, + CondUse->getOperandValToReplace(), + false); + CondUse = &IU->IVUsesByStride[*CondStride]->Users.back(); + } + } + + // If we get to here, we know that we can transform the setcc instruction to + // use the post-incremented version of the IV, allowing us to coalesce the + // live ranges for the IV correctly. 
+ CondUse->setOffset(SE->getMinusSCEV(CondUse->getOffset(), *CondStride)); + CondUse->setIsUseOfPostIncrementedValue(true); + Changed = true; + + ++NumLoopCond; +} + +// OptimizeLoopCountIV - If, after all sharing of IVs, the IV used for deciding +// when to exit the loop is used only for that purpose, try to rearrange things +// so it counts down to a test against zero. +void LoopStrengthReduce::OptimizeLoopCountIV(Loop *L) { + + // If the number of times the loop is executed isn't computable, give up. + SCEVHandle BackedgeTakenCount = SE->getBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) + return; + + // Get the terminating condition for the loop if possible (this isn't + // necessarily in the latch, or a block that's a predecessor of the header). + SmallVector<BasicBlock*, 8> ExitBlocks; + L->getExitBlocks(ExitBlocks); + if (ExitBlocks.size() != 1) return; + + // Okay, there is one exit block. Try to find the condition that causes the + // loop to be exited. + BasicBlock *ExitBlock = ExitBlocks[0]; + + BasicBlock *ExitingBlock = 0; + for (pred_iterator PI = pred_begin(ExitBlock), E = pred_end(ExitBlock); + PI != E; ++PI) + if (L->contains(*PI)) { + if (ExitingBlock == 0) + ExitingBlock = *PI; + else + return; // More than one block exiting! + } + assert(ExitingBlock && "No exits from loop, something is broken!"); + + // Okay, we've computed the exiting block. See what condition causes us to + // exit. + // + // FIXME: we should be able to handle switch instructions (with a single exit) + BranchInst *TermBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator()); + if (TermBr == 0) return; + assert(TermBr->isConditional() && "If unconditional, it can't be in loop!"); + if (!isa<ICmpInst>(TermBr->getCondition())) + return; + ICmpInst *Cond = cast<ICmpInst>(TermBr->getCondition()); + + // Handle only tests for equality for the moment, and only stride 1. + if (Cond->getPredicate() != CmpInst::ICMP_EQ) + return; + SCEVHandle IV = SE->getSCEV(Cond->getOperand(0)); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV); + SCEVHandle One = SE->getIntegerSCEV(1, BackedgeTakenCount->getType()); + if (!AR || !AR->isAffine() || AR->getStepRecurrence(*SE) != One) + return; + // If the RHS of the comparison is defined inside the loop, the rewrite + // cannot be done. + if (Instruction *CR = dyn_cast<Instruction>(Cond->getOperand(1))) + if (L->contains(CR->getParent())) + return; + + // Make sure the IV is only used for counting. Value may be preinc or + // postinc; 2 uses in either case. + if (!Cond->getOperand(0)->hasNUses(2)) + return; + PHINode *phi = dyn_cast<PHINode>(Cond->getOperand(0)); + Instruction *incr; + if (phi && phi->getParent()==L->getHeader()) { + // value tested is preinc. Find the increment. + // A CmpInst is not a BinaryOperator; we depend on this. + Instruction::use_iterator UI = phi->use_begin(); + incr = dyn_cast<BinaryOperator>(UI); + if (!incr) + incr = dyn_cast<BinaryOperator>(++UI); + // 1 use for postinc value, the phi. Unnecessarily conservative? + if (!incr || !incr->hasOneUse() || incr->getOpcode()!=Instruction::Add) + return; + } else { + // Value tested is postinc. Find the phi node. + incr = dyn_cast<BinaryOperator>(Cond->getOperand(0)); + if (!incr || incr->getOpcode()!=Instruction::Add) + return; + + Instruction::use_iterator UI = Cond->getOperand(0)->use_begin(); + phi = dyn_cast<PHINode>(UI); + if (!phi) + phi = dyn_cast<PHINode>(++UI); + // 1 use for preinc value, the increment. 
+ if (!phi || phi->getParent()!=L->getHeader() || !phi->hasOneUse()) + return; + } + + // Replace the increment with a decrement. + BinaryOperator *decr = + BinaryOperator::Create(Instruction::Sub, incr->getOperand(0), + incr->getOperand(1), "tmp", incr); + incr->replaceAllUsesWith(decr); + incr->eraseFromParent(); + + // Substitute endval-startval for the original startval, and 0 for the + // original endval. Since we're only testing for equality this is OK even + // if the computation wraps around. + BasicBlock *Preheader = L->getLoopPreheader(); + Instruction *PreInsertPt = Preheader->getTerminator(); + int inBlock = L->contains(phi->getIncomingBlock(0)) ? 1 : 0; + Value *startVal = phi->getIncomingValue(inBlock); + Value *endVal = Cond->getOperand(1); + // FIXME check for case where both are constant + ConstantInt* Zero = ConstantInt::get(Cond->getOperand(1)->getType(), 0); + BinaryOperator *NewStartVal = + BinaryOperator::Create(Instruction::Sub, endVal, startVal, + "tmp", PreInsertPt); + phi->setIncomingValue(inBlock, NewStartVal); + Cond->setOperand(1, Zero); + + Changed = true; +} + +bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) { + + IU = &getAnalysis<IVUsers>(); + LI = &getAnalysis<LoopInfo>(); + DT = &getAnalysis<DominatorTree>(); + SE = &getAnalysis<ScalarEvolution>(); + Changed = false; + + if (!IU->IVUsesByStride.empty()) { +#ifndef NDEBUG + DOUT << "\nLSR on \"" << L->getHeader()->getParent()->getNameStart() + << "\" "; + DEBUG(L->dump()); +#endif + + // Sort the StrideOrder so we process larger strides first. + std::stable_sort(IU->StrideOrder.begin(), IU->StrideOrder.end(), + StrideCompare(SE)); + + // Optimize induction variables. Some indvar uses can be transformed to use + // strides that will be needed for other purposes. A common example of this + // is the exit test for the loop, which can often be rewritten to use the + // computation of some other indvar to decide when to terminate the loop. + OptimizeIndvars(L); + + // Change loop terminating condition to use the postinc iv when possible + // and optimize loop terminating compare. FIXME: Move this after + // StrengthReduceStridedIVUsers? + OptimizeLoopTermCond(L); + + // FIXME: We can shrink overlarge IV's here. e.g. if the code has + // computation in i64 values and the target doesn't support i64, demote + // the computation to 32-bit if safe. + + // FIXME: Attempt to reuse values across multiple IV's. In particular, we + // could have something like "for(i) { foo(i*8); bar(i*16) }", which should + // be codegened as "for (j = 0;; j+=8) { foo(j); bar(j+j); }" on X86/PPC. + // Need to be careful that IV's are all the same type. Only works for + // intptr_t indvars. + + // IVsByStride keeps IVs for one particular loop. + assert(IVsByStride.empty() && "Stale entries in IVsByStride?"); + + // Note: this processes each stride/type pair individually. All users + // passed into StrengthReduceStridedIVUsers have the same type AND stride. + // Also, note that we iterate over IVUsesByStride indirectly by using + // StrideOrder. This extra layer of indirection makes the ordering of + // strides deterministic - not dependent on map order. + for (unsigned Stride = 0, e = IU->StrideOrder.size(); + Stride != e; ++Stride) { + std::map<SCEVHandle, IVUsersOfOneStride *>::iterator SI = + IU->IVUsesByStride.find(IU->StrideOrder[Stride]); + assert(SI != IU->IVUsesByStride.end() && "Stride doesn't exist!"); + // FIXME: Generalize to non-affine IV's. 
+ if (!SI->first->isLoopInvariant(L)) + continue; + StrengthReduceStridedIVUsers(SI->first, *SI->second, L); + } + } + + // After all sharing is done, see if we can adjust the loop to test against + // zero instead of counting up to a maximum. This is usually faster. + OptimizeLoopCountIV(L); + + // We're done analyzing this loop; release all the state we built up for it. + IVsByStride.clear(); + StrideNoReuse.clear(); + + // Clean up after ourselves + if (!DeadInsts.empty()) + DeleteTriviallyDeadInstructions(); + + // At this point, it is worth checking to see if any recurrence PHIs are also + // dead, so that we can remove them as well. + DeleteDeadPHIs(L->getHeader()); + + return Changed; +} diff --git a/lib/Transforms/Scalar/LoopUnroll.cpp b/lib/Transforms/Scalar/LoopUnroll.cpp new file mode 100644 index 0000000..23757cd --- /dev/null +++ b/lib/Transforms/Scalar/LoopUnroll.cpp @@ -0,0 +1,183 @@ +//===-- LoopUnroll.cpp - Loop unroller pass -------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements a simple loop unroller. It works best when loops have +// been canonicalized by the -indvars pass, allowing it to determine the trip +// counts of loops easily. +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-unroll" +#include "llvm/IntrinsicInst.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" +#include <climits> + +using namespace llvm; + +static cl::opt<unsigned> +UnrollThreshold("unroll-threshold", cl::init(100), cl::Hidden, + cl::desc("The cut-off point for automatic loop unrolling")); + +static cl::opt<unsigned> +UnrollCount("unroll-count", cl::init(0), cl::Hidden, + cl::desc("Use this unroll count for all loops, for testing purposes")); + +static cl::opt<bool> +UnrollAllowPartial("unroll-allow-partial", cl::init(false), cl::Hidden, + cl::desc("Allows loops to be partially unrolled until " + "-unroll-threshold loop size is reached.")); + +namespace { + class VISIBILITY_HIDDEN LoopUnroll : public LoopPass { + public: + static char ID; // Pass ID, replacement for typeid + LoopUnroll() : LoopPass(&ID) {} + + /// A magic value for use with the Threshold parameter to indicate + /// that the loop unroll should be performed regardless of how much + /// code expansion would result. + static const unsigned NoThreshold = UINT_MAX; + + bool runOnLoop(Loop *L, LPPassManager &LPM); + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... + /// + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(LoopSimplifyID); + AU.addRequiredID(LCSSAID); + AU.addRequired<LoopInfo>(); + AU.addPreservedID(LCSSAID); + AU.addPreserved<LoopInfo>(); + // FIXME: Loop unroll requires LCSSA. And LCSSA requires dom info. + // If loop unroll does not preserve dom info then LCSSA pass on next + // loop will receive invalid dom info. + // For now, recreate dom info, if loop is unrolled. 
+ AU.addPreserved<DominatorTree>(); + AU.addPreserved<DominanceFrontier>(); + } + }; +} + +char LoopUnroll::ID = 0; +static RegisterPass<LoopUnroll> X("loop-unroll", "Unroll loops"); + +Pass *llvm::createLoopUnrollPass() { return new LoopUnroll(); } + +/// ApproximateLoopSize - Approximate the size of the loop. +static unsigned ApproximateLoopSize(const Loop *L) { + unsigned Size = 0; + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) { + BasicBlock *BB = *I; + Instruction *Term = BB->getTerminator(); + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + if (isa<PHINode>(I) && BB == L->getHeader()) { + // Ignore PHI nodes in the header. + } else if (I->hasOneUse() && I->use_back() == Term) { + // Ignore instructions only used by the loop terminator. + } else if (isa<DbgInfoIntrinsic>(I)) { + // Ignore debug instructions + } else if (isa<GetElementPtrInst>(I) && I->hasOneUse()) { + // Ignore GEP as they generally are subsumed into a load or store. + } else if (isa<CallInst>(I)) { + // Estimate size overhead introduced by call instructions which + // is higher than other instructions. Here 3 and 10 are magic + // numbers that help one isolated test case from PR2067 without + // negatively impacting measured benchmarks. + if (isa<IntrinsicInst>(I)) + Size = Size + 3; + else + Size = Size + 10; + } else { + ++Size; + } + + // TODO: Ignore expressions derived from PHI and constants if inval of phi + // is a constant, or if operation is associative. This will get induction + // variables. + } + } + + return Size; +} + +bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { + assert(L->isLCSSAForm()); + LoopInfo *LI = &getAnalysis<LoopInfo>(); + + BasicBlock *Header = L->getHeader(); + DOUT << "Loop Unroll: F[" << Header->getParent()->getName() + << "] Loop %" << Header->getName() << "\n"; + + // Find trip count + unsigned TripCount = L->getSmallConstantTripCount(); + unsigned Count = UnrollCount; + + // Automatically select an unroll count. + if (Count == 0) { + // Conservative heuristic: if we know the trip count, see if we can + // completely unroll (subject to the threshold, checked below); otherwise + // try to find greatest modulo of the trip count which is still under + // threshold value. + if (TripCount != 0) { + Count = TripCount; + } else { + return false; + } + } + + // Enforce the threshold. + if (UnrollThreshold != NoThreshold) { + unsigned LoopSize = ApproximateLoopSize(L); + DOUT << " Loop Size = " << LoopSize << "\n"; + uint64_t Size = (uint64_t)LoopSize*Count; + if (TripCount != 1 && Size > UnrollThreshold) { + DOUT << " Too large to fully unroll with count: " << Count + << " because size: " << Size << ">" << UnrollThreshold << "\n"; + if (UnrollAllowPartial) { + // Reduce unroll count to be modulo of TripCount for partial unrolling + Count = UnrollThreshold / LoopSize; + while (Count != 0 && TripCount%Count != 0) { + Count--; + } + if (Count < 2) { + DOUT << " could not unroll partially\n"; + return false; + } else { + DOUT << " partially unrolling with count: " << Count << "\n"; + } + } else { + DOUT << " will not try to unroll partially because " + << "-unroll-allow-partial not given\n"; + return false; + } + } + } + + // Unroll the loop. + Function *F = L->getHeader()->getParent(); + if (!UnrollLoop(L, Count, LI, &LPM)) + return false; + + // FIXME: Reconstruct dom info, because it is not preserved properly. 
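+ // Worked example of the heuristic above (illustrative numbers): with
+ // LoopSize = 40, TripCount = 12 and the default -unroll-threshold=100,
+ // full unrolling gives Size = 480 > 100 and is rejected; with
+ // -unroll-allow-partial, Count = 100/40 = 2 and 12 % 2 == 0, so the loop
+ // body is emitted twice.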
+ DominatorTree *DT = getAnalysisIfAvailable<DominatorTree>(); + if (DT) { + DT->runOnFunction(*F); + DominanceFrontier *DF = getAnalysisIfAvailable<DominanceFrontier>(); + if (DF) + DF->runOnFunction(*F); + } + return true; +} diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp new file mode 100644 index 0000000..e3e881f --- /dev/null +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -0,0 +1,1098 @@ +//===-- LoopUnswitch.cpp - Hoist loop-invariant conditionals in loop ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass transforms loops that contain branches on loop-invariant conditions +// to have multiple loops. For example, it turns the left into the right code: +// +// for (...) if (lic) +// A for (...) +// if (lic) A; B; C +// B else +// C for (...) +// A; C +// +// This can increase the size of the code exponentially (doubling it every time +// a loop is unswitched) so we only unswitch if the resultant code will be +// smaller than a threshold. +// +// This pass expects LICM to be run before it to hoist invariant conditions out +// of the loop, to make the unswitching opportunity obvious. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-unswitch" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include <algorithm> +#include <set> +using namespace llvm; + +STATISTIC(NumBranches, "Number of branches unswitched"); +STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumSelects , "Number of selects unswitched"); +STATISTIC(NumTrivial , "Number of unswitches that are trivial"); +STATISTIC(NumSimplify, "Number of simplifications of unswitched code"); + +static cl::opt<unsigned> +Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"), + cl::init(10), cl::Hidden); + +namespace { + class VISIBILITY_HIDDEN LoopUnswitch : public LoopPass { + LoopInfo *LI; // Loop information + LPPassManager *LPM; + + // LoopProcessWorklist - Used to check if second loop needs processing + // after RewriteLoopBodyWithConditionConstant rewrites first loop. + std::vector<Loop*> LoopProcessWorklist; + SmallPtrSet<Value *,8> UnswitchedVals; + + bool OptimizeForSize; + bool redoLoop; + + Loop *currentLoop; + DominanceFrontier *DF; + DominatorTree *DT; + BasicBlock *loopHeader; + BasicBlock *loopPreheader; + + // LoopBlocks contains all of the basic blocks of the loop, including the + // preheader of the loop, the body of the loop, and the exit blocks of the + // loop, in that order. + std::vector<BasicBlock*> LoopBlocks; + // NewBlocks contained cloned copy of basic blocks from LoopBlocks. 
+ std::vector<BasicBlock*> NewBlocks; + + public: + static char ID; // Pass ID, replacement for typeid + explicit LoopUnswitch(bool Os = false) : + LoopPass(&ID), OptimizeForSize(Os), redoLoop(false), + currentLoop(NULL), DF(NULL), DT(NULL), loopHeader(NULL), + loopPreheader(NULL) {} + + bool runOnLoop(Loop *L, LPPassManager &LPM); + bool processCurrentLoop(); + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... + /// + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(LoopSimplifyID); + AU.addPreservedID(LoopSimplifyID); + AU.addRequired<LoopInfo>(); + AU.addPreserved<LoopInfo>(); + AU.addRequiredID(LCSSAID); + AU.addPreservedID(LCSSAID); + AU.addPreserved<DominatorTree>(); + AU.addPreserved<DominanceFrontier>(); + } + + private: + + /// RemoveLoopFromWorklist - If the specified loop is on the loop worklist, + /// remove it. + void RemoveLoopFromWorklist(Loop *L) { + std::vector<Loop*>::iterator I = std::find(LoopProcessWorklist.begin(), + LoopProcessWorklist.end(), L); + if (I != LoopProcessWorklist.end()) + LoopProcessWorklist.erase(I); + } + + void initLoopData() { + loopHeader = currentLoop->getHeader(); + loopPreheader = currentLoop->getLoopPreheader(); + } + + /// Split all of the edges from inside the loop to their exit blocks. + /// Update the appropriate Phi nodes as we do so. + void SplitExitEdges(Loop *L, const SmallVector<BasicBlock *, 8> &ExitBlocks); + + bool UnswitchIfProfitable(Value *LoopCond, Constant *Val); + unsigned getLoopUnswitchCost(Value *LIC); + void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, + BasicBlock *ExitBlock); + void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L); + + void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, + Constant *Val, bool isEqual); + + void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, + BasicBlock *TrueDest, + BasicBlock *FalseDest, + Instruction *InsertPt); + + void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); + void RemoveBlockIfDead(BasicBlock *BB, + std::vector<Instruction*> &Worklist, Loop *l); + void RemoveLoopFromHierarchy(Loop *L); + bool IsTrivialUnswitchCondition(Value *Cond, Constant **Val = 0, + BasicBlock **LoopExit = 0); + + }; +} +char LoopUnswitch::ID = 0; +static RegisterPass<LoopUnswitch> X("loop-unswitch", "Unswitch loops"); + +Pass *llvm::createLoopUnswitchPass(bool Os) { + return new LoopUnswitch(Os); +} + +/// FindLIVLoopCondition - Cond is a condition that occurs in L. If it is +/// invariant in the loop, or has an invariant piece, return the invariant. +/// Otherwise, return null. +static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { + // Constants should be folded, not unswitched on! + if (isa<Constant>(Cond)) return 0; + + // TODO: Handle: br (VARIANT|INVARIANT). + // TODO: Hoist simple expressions out of loops. + if (L->isLoopInvariant(Cond)) return Cond; + + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond)) + if (BO->getOpcode() == Instruction::And || + BO->getOpcode() == Instruction::Or) { + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. 
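+ // An illustrative case (ours, not from the original source): for a branch
+ // on (%x & %inv) where only %inv is loop-invariant, unswitching on %inv
+ // yields one loop copy in which the condition folds to false and another
+ // in which it simplifies to a branch on %x alone.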
+ if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed)) + return LHS; + if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed)) + return RHS; + } + + return 0; +} + +bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { + LI = &getAnalysis<LoopInfo>(); + LPM = &LPM_Ref; + DF = getAnalysisIfAvailable<DominanceFrontier>(); + DT = getAnalysisIfAvailable<DominatorTree>(); + currentLoop = L; + Function *F = currentLoop->getHeader()->getParent(); + bool Changed = false; + do { + assert(currentLoop->isLCSSAForm()); + redoLoop = false; + Changed |= processCurrentLoop(); + } while(redoLoop); + + if (Changed) { + // FIXME: Reconstruct dom info, because it is not preserved properly. + if (DT) + DT->runOnFunction(*F); + if (DF) + DF->runOnFunction(*F); + } + return Changed; +} + +/// processCurrentLoop - Do actual work and unswitch loop if possible +/// and profitable. +bool LoopUnswitch::processCurrentLoop() { + bool Changed = false; + + // Loop over all of the basic blocks in the loop. If we find an interior + // block that is branching on a loop-invariant condition, we can unswitch this + // loop. + for (Loop::block_iterator I = currentLoop->block_begin(), + E = currentLoop->block_end(); + I != E; ++I) { + TerminatorInst *TI = (*I)->getTerminator(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + // If this isn't branching on an invariant condition, we can't unswitch + // it. + if (BI->isConditional()) { + // See if this, or some part of it, is loop invariant. If so, we can + // unswitch on it if we desire. + Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), + currentLoop, Changed); + if (LoopCond && UnswitchIfProfitable(LoopCond, + ConstantInt::getTrue())) { + ++NumBranches; + return true; + } + } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), + currentLoop, Changed); + if (LoopCond && SI->getNumCases() > 1) { + // Find a value to unswitch on: + // FIXME: this should chose the most expensive case! + Constant *UnswitchVal = SI->getCaseValue(1); + // Do not process same value again and again. + if (!UnswitchedVals.insert(UnswitchVal)) + continue; + + if (UnswitchIfProfitable(LoopCond, UnswitchVal)) { + ++NumSwitches; + return true; + } + } + } + + // Scan the instructions to check for unswitchable values. + for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); + BBI != E; ++BBI) + if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { + Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), + currentLoop, Changed); + if (LoopCond && UnswitchIfProfitable(LoopCond, + ConstantInt::getTrue())) { + ++NumSelects; + return true; + } + } + } + return Changed; +} + +/// isTrivialLoopExitBlock - Check to see if all paths from BB either: +/// 1. Exit the loop with no side effects. +/// 2. Branch to the latch block with no side-effects. +/// +/// If these conditions are true, we return true and set ExitBB to the block we +/// exit through. +/// +static bool isTrivialLoopExitBlockHelper(Loop *L, BasicBlock *BB, + BasicBlock *&ExitBB, + std::set<BasicBlock*> &Visited) { + if (!Visited.insert(BB).second) { + // Already visited and Ok, end of recursion. + return true; + } else if (!L->contains(BB)) { + // Otherwise, this is a loop exit, this is fine so long as this is the + // first exit. + if (ExitBB != 0) return false; + ExitBB = BB; + return true; + } + + // Otherwise, this is an unvisited intra-loop node. Check all successors. 
+ for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) { + // Check to see if the successor is a trivial loop exit. + if (!isTrivialLoopExitBlockHelper(L, *SI, ExitBB, Visited)) + return false; + } + + // Okay, everything after this looks good, check to make sure that this block + // doesn't include any side effects. + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (I->mayHaveSideEffects()) + return false; + + return true; +} + +/// isTrivialLoopExitBlock - Return true if the specified block unconditionally +/// leads to an exit from the specified loop, and has no side-effects in the +/// process. If so, return the block that is exited to, otherwise return null. +static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) { + std::set<BasicBlock*> Visited; + Visited.insert(L->getHeader()); // Branches to header are ok. + BasicBlock *ExitBB = 0; + if (isTrivialLoopExitBlockHelper(L, BB, ExitBB, Visited)) + return ExitBB; + return 0; +} + +/// IsTrivialUnswitchCondition - Check to see if this unswitch condition is +/// trivial: that is, that the condition controls whether or not the loop does +/// anything at all. If this is a trivial condition, unswitching produces no +/// code duplications (equivalently, it produces a simpler loop and a new empty +/// loop, which gets deleted). +/// +/// If this is a trivial condition, return true, otherwise return false. When +/// returning true, this sets Cond and Val to the condition that controls the +/// trivial condition: when Cond dynamically equals Val, the loop is known to +/// exit. Finally, this sets LoopExit to the BB that the loop exits to when +/// Cond == Val. +/// +bool LoopUnswitch::IsTrivialUnswitchCondition(Value *Cond, Constant **Val, + BasicBlock **LoopExit) { + BasicBlock *Header = currentLoop->getHeader(); + TerminatorInst *HeaderTerm = Header->getTerminator(); + + BasicBlock *LoopExitBB = 0; + if (BranchInst *BI = dyn_cast<BranchInst>(HeaderTerm)) { + // If the header block doesn't end with a conditional branch on Cond, we + // can't handle it. + if (!BI->isConditional() || BI->getCondition() != Cond) + return false; + + // Check to see if a successor of the branch is guaranteed to go to the + // latch block or exit through a one exit block without having any + // side-effects. If so, determine the value of Cond that causes it to do + // this. + if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, + BI->getSuccessor(0)))) { + if (Val) *Val = ConstantInt::getTrue(); + } else if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, + BI->getSuccessor(1)))) { + if (Val) *Val = ConstantInt::getFalse(); + } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(HeaderTerm)) { + // If this isn't a switch on Cond, we can't handle it. + if (SI->getCondition() != Cond) return false; + + // Check to see if a successor of the switch is guaranteed to go to the + // latch block or exit through a one exit block without having any + // side-effects. If so, determine the value of Cond that causes it to do + // this. Note that we can't trivially unswitch on the default case. + for (unsigned i = 1, e = SI->getNumSuccessors(); i != e; ++i) + if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, + SI->getSuccessor(i)))) { + // Okay, we found a trivial case, remember the value that is trivial. + if (Val) *Val = SI->getCaseValue(i); + break; + } + } + + // If we didn't find a single unique LoopExit block, or if the loop exit block + // contains phi nodes, this isn't trivial. 
+ if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))
+ return false; // Can't handle this.
+
+ if (LoopExit) *LoopExit = LoopExitBB;
+
+ // We already know that nothing uses any scalar values defined inside of this
+ // loop. As such, we just have to check to see if this loop will execute any
+ // side-effecting instructions (e.g. stores, calls, volatile loads) in the
+ // part of the loop that the code *would* execute. We already checked the
+ // tail, check the header now.
+ for (BasicBlock::iterator I = Header->begin(), E = Header->end(); I != E; ++I)
+ if (I->mayHaveSideEffects())
+ return false;
+ return true;
+}
+
+/// getLoopUnswitchCost - Return the cost (code size growth) that will happen if
+/// we choose to unswitch the current loop on the specified value.
+///
+unsigned LoopUnswitch::getLoopUnswitchCost(Value *LIC) {
+ // If the condition is trivial, always unswitch. There is no code growth for
+ // this case.
+ if (IsTrivialUnswitchCondition(LIC))
+ return 0;
+
+ // FIXME: This is really overly conservative. However, more liberal
+ // estimations have thus far resulted in excessive unswitching, which is bad
+ // both in compile time and in code size. This should be replaced once
+ // someone figures out how to make a good estimation.
+ return currentLoop->getBlocks().size();
+
+ unsigned Cost = 0;
+ // FIXME: this is brain dead. It should take into consideration code
+ // shrinkage.
+ for (Loop::block_iterator I = currentLoop->block_begin(),
+ E = currentLoop->block_end();
+ I != E; ++I) {
+ BasicBlock *BB = *I;
+ // Do not include empty blocks in the cost calculation. These arise from
+ // loop canonicalization and will be removed.
+ if (BB->begin() == BasicBlock::iterator(BB->getTerminator()))
+ continue;
+
+ // Count basic blocks.
+ ++Cost;
+ }
+
+ return Cost;
+}
+
+/// UnswitchIfProfitable - We have found that we can unswitch currentLoop when
+/// LoopCond == Val to simplify the loop. If we decide that this is profitable,
+/// unswitch the loop, reprocess the pieces, then return true.
+bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val) {
+
+ initLoopData();
+ Function *F = loopHeader->getParent();
+
+
+ // Check to see if it would be profitable to unswitch the current loop.
+ unsigned Cost = getLoopUnswitchCost(LoopCond);
+
+ // Do not do non-trivial unswitching while optimizing for size.
+ if (Cost && OptimizeForSize)
+ return false;
+ if (Cost && !F->isDeclaration() && F->hasFnAttr(Attribute::OptimizeForSize))
+ return false;
+
+ if (Cost > Threshold) {
+ // FIXME: this should estimate growth by the amount of code shared by the
+ // resultant unswitched loops.
+ //
+ DOUT << "NOT unswitching loop %"
+ << currentLoop->getHeader()->getName() << ", cost too high: "
+ << currentLoop->getBlocks().size() << "\n";
+ return false;
+ }
+
+ Constant *CondVal;
+ BasicBlock *ExitBlock;
+ if (IsTrivialUnswitchCondition(LoopCond, &CondVal, &ExitBlock)) {
+ UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, ExitBlock);
+ } else {
+ UnswitchNontrivialCondition(LoopCond, Val, currentLoop);
+ }
+
+ return true;
+}
+
+// RemapInstruction - Convert the instruction operands from referencing the
+// current values into those specified by ValueMap.
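+// For example (an illustrative case, not from the original source): if
+// ValueMap maps %x to its clone %x.us, a cloned instruction 'add i32 %x, 1'
+// is rewritten in place to 'add i32 %x.us, 1'.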
+// +static inline void RemapInstruction(Instruction *I, + DenseMap<const Value *, Value*> &ValueMap) { + for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) { + Value *Op = I->getOperand(op); + DenseMap<const Value *, Value*>::iterator It = ValueMap.find(Op); + if (It != ValueMap.end()) Op = It->second; + I->setOperand(op, Op); + } +} + +/// CloneLoop - Recursively clone the specified loop and all of its children, +/// mapping the blocks with the specified map. +static Loop *CloneLoop(Loop *L, Loop *PL, DenseMap<const Value*, Value*> &VM, + LoopInfo *LI, LPPassManager *LPM) { + Loop *New = new Loop(); + + LPM->insertLoop(New, PL); + + // Add all of the blocks in L to the new loop. + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) + if (LI->getLoopFor(*I) == L) + New->addBasicBlockToLoop(cast<BasicBlock>(VM[*I]), LI->getBase()); + + // Add all of the subloops to the new loop. + for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) + CloneLoop(*I, New, VM, LI, LPM); + + return New; +} + +/// EmitPreheaderBranchOnCondition - Emit a conditional branch on two values +/// if LIC == Val, branch to TrueDst, otherwise branch to FalseDest. Insert the +/// code immediately before InsertPt. +void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, + BasicBlock *TrueDest, + BasicBlock *FalseDest, + Instruction *InsertPt) { + // Insert a conditional branch on LIC to the two preheaders. The original + // code is the true version and the new code is the false version. + Value *BranchVal = LIC; + if (!isa<ConstantInt>(Val) || Val->getType() != Type::Int1Ty) + BranchVal = new ICmpInst(ICmpInst::ICMP_EQ, LIC, Val, "tmp", InsertPt); + else if (Val != ConstantInt::getTrue()) + // We want to enter the new loop when the condition is true. + std::swap(TrueDest, FalseDest); + + // Insert the new branch. + BranchInst::Create(TrueDest, FalseDest, BranchVal, InsertPt); +} + +/// UnswitchTrivialCondition - Given a loop that has a trivial unswitchable +/// condition in it (a cond branch from its header block to its latch block, +/// where the path through the loop that doesn't execute its body has no +/// side-effects), unswitch it. This doesn't involve any code duplication, just +/// moving the conditional branch outside of the loop and updating loop info. +void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, + Constant *Val, + BasicBlock *ExitBlock) { + DOUT << "loop-unswitch: Trivial-Unswitch loop %" + << loopHeader->getName() << " [" << L->getBlocks().size() + << " blocks] in Function " << L->getHeader()->getParent()->getName() + << " on cond: " << *Val << " == " << *Cond << "\n"; + + // First step, split the preheader, so that we know that there is a safe place + // to insert the conditional branch. We will change loopPreheader to have a + // conditional branch on Cond. + BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, this); + + // Now that we have a place to insert the conditional branch, create a place + // to branch to: this is the exit block out of the loop that we should + // short-circuit to. + + // Split this block now, so that the loop maintains its exit block, and so + // that the jump from the preheader can execute the contents of the exit block + // without actually branching to it (the exit block should be dominated by the + // loop header, not the preheader). 
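+ // An illustrative sketch of the result (ours, not from the original
+ // source): the preheader branches to NewExit when Cond == Val and into the
+ // loop through NewPH otherwise; the loop itself still exits through
+ // ExitBlock, which now falls through to NewExit.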
+ assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); + BasicBlock *NewExit = SplitBlock(ExitBlock, ExitBlock->begin(), this); + + // Okay, now we have a position to branch from and a position to branch to, + // insert the new conditional branch. + EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, + loopPreheader->getTerminator()); + LPM->deleteSimpleAnalysisValue(loopPreheader->getTerminator(), L); + loopPreheader->getTerminator()->eraseFromParent(); + + // We need to reprocess this loop, it could be unswitched again. + redoLoop = true; + + // Now that we know that the loop is never entered when this condition is a + // particular value, rewrite the loop with this info. We know that this will + // at least eliminate the old branch. + RewriteLoopBodyWithConditionConstant(L, Cond, Val, false); + ++NumTrivial; +} + +/// SplitExitEdges - Split all of the edges from inside the loop to their exit +/// blocks. Update the appropriate Phi nodes as we do so. +void LoopUnswitch::SplitExitEdges(Loop *L, + const SmallVector<BasicBlock *, 8> &ExitBlocks) +{ + + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { + BasicBlock *ExitBlock = ExitBlocks[i]; + std::vector<BasicBlock*> Preds(pred_begin(ExitBlock), pred_end(ExitBlock)); + + for (unsigned j = 0, e = Preds.size(); j != e; ++j) { + BasicBlock* NewExitBlock = SplitEdge(Preds[j], ExitBlock, this); + BasicBlock* StartBlock = Preds[j]; + BasicBlock* EndBlock; + if (NewExitBlock->getSinglePredecessor() == ExitBlock) { + EndBlock = NewExitBlock; + NewExitBlock = EndBlock->getSinglePredecessor(); + } else { + EndBlock = ExitBlock; + } + + std::set<PHINode*> InsertedPHIs; + PHINode* OldLCSSA = 0; + for (BasicBlock::iterator I = EndBlock->begin(); + (OldLCSSA = dyn_cast<PHINode>(I)); ++I) { + Value* OldValue = OldLCSSA->getIncomingValueForBlock(NewExitBlock); + PHINode* NewLCSSA = PHINode::Create(OldLCSSA->getType(), + OldLCSSA->getName() + ".us-lcssa", + NewExitBlock->getTerminator()); + NewLCSSA->addIncoming(OldValue, StartBlock); + OldLCSSA->setIncomingValue(OldLCSSA->getBasicBlockIndex(NewExitBlock), + NewLCSSA); + InsertedPHIs.insert(NewLCSSA); + } + + BasicBlock::iterator InsertPt = EndBlock->getFirstNonPHI(); + for (BasicBlock::iterator I = NewExitBlock->begin(); + (OldLCSSA = dyn_cast<PHINode>(I)) && InsertedPHIs.count(OldLCSSA) == 0; + ++I) { + PHINode *NewLCSSA = PHINode::Create(OldLCSSA->getType(), + OldLCSSA->getName() + ".us-lcssa", + InsertPt); + OldLCSSA->replaceAllUsesWith(NewLCSSA); + NewLCSSA->addIncoming(OldLCSSA, NewExitBlock); + } + + } + } + +} + +/// UnswitchNontrivialCondition - We determined that the loop is profitable +/// to unswitch when LIC equal Val. Split it into loop versions and test the +/// condition outside of either loop. Return the loops created as Out1/Out2. +void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, + Loop *L) { + Function *F = loopHeader->getParent(); + DOUT << "loop-unswitch: Unswitching loop %" + << loopHeader->getName() << " [" << L->getBlocks().size() + << " blocks] in Function " << F->getName() + << " when '" << *Val << "' == " << *LIC << "\n"; + + LoopBlocks.clear(); + NewBlocks.clear(); + + // First step, split the preheader and exit blocks, and add these blocks to + // the LoopBlocks list. + BasicBlock *NewPreheader = SplitEdge(loopPreheader, loopHeader, this); + LoopBlocks.push_back(NewPreheader); + + // We want the loop to come after the preheader, but before the exit blocks. 
+ LoopBlocks.insert(LoopBlocks.end(), L->block_begin(), L->block_end());
+
+ SmallVector<BasicBlock*, 8> ExitBlocks;
+ L->getUniqueExitBlocks(ExitBlocks);
+
+ // Split all of the edges from inside the loop to their exit blocks. Update
+ // the appropriate Phi nodes as we do so.
+ SplitExitEdges(L, ExitBlocks);
+
+ // The exit blocks may have been changed due to edge splitting, recompute.
+ ExitBlocks.clear();
+ L->getUniqueExitBlocks(ExitBlocks);
+
+ // Add exit blocks to the loop blocks.
+ LoopBlocks.insert(LoopBlocks.end(), ExitBlocks.begin(), ExitBlocks.end());
+
+ // Next step, clone all of the basic blocks that make up the loop (including
+ // the loop preheader and exit blocks), keeping track of the mapping between
+ // the instructions and blocks.
+ NewBlocks.reserve(LoopBlocks.size());
+ DenseMap<const Value*, Value*> ValueMap;
+ for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) {
+ BasicBlock *New = CloneBasicBlock(LoopBlocks[i], ValueMap, ".us", F);
+ NewBlocks.push_back(New);
+ ValueMap[LoopBlocks[i]] = New; // Keep the BB mapping.
+ LPM->cloneBasicBlockSimpleAnalysis(LoopBlocks[i], New, L);
+ }
+
+ // Splice the newly inserted blocks into the function right before the
+ // original preheader.
+ F->getBasicBlockList().splice(LoopBlocks[0], F->getBasicBlockList(),
+ NewBlocks[0], F->end());
+
+ // Now we create the new Loop object for the versioned loop.
+ Loop *NewLoop = CloneLoop(L, L->getParentLoop(), ValueMap, LI, LPM);
+ Loop *ParentLoop = L->getParentLoop();
+ if (ParentLoop) {
+ // Make sure to add the cloned preheader and exit blocks to the parent loop
+ // as well.
+ ParentLoop->addBasicBlockToLoop(NewBlocks[0], LI->getBase());
+ }
+
+ for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
+ BasicBlock *NewExit = cast<BasicBlock>(ValueMap[ExitBlocks[i]]);
+ // The new exit block should be in the same loop as the old one.
+ if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[i]))
+ ExitBBLoop->addBasicBlockToLoop(NewExit, LI->getBase());
+
+ assert(NewExit->getTerminator()->getNumSuccessors() == 1 &&
+ "Exit block should have been split to have one successor!");
+ BasicBlock *ExitSucc = NewExit->getTerminator()->getSuccessor(0);
+
+ // If the successor of the exit block had PHI nodes, add an entry for
+ // NewExit.
+ PHINode *PN;
+ for (BasicBlock::iterator I = ExitSucc->begin();
+ (PN = dyn_cast<PHINode>(I)); ++I) {
+ Value *V = PN->getIncomingValueForBlock(ExitBlocks[i]);
+ DenseMap<const Value *, Value*>::iterator It = ValueMap.find(V);
+ if (It != ValueMap.end()) V = It->second;
+ PN->addIncoming(V, NewExit);
+ }
+ }
+
+ // Rewrite the code to refer to itself.
+ for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i)
+ for (BasicBlock::iterator I = NewBlocks[i]->begin(),
+ E = NewBlocks[i]->end(); I != E; ++I)
+ RemapInstruction(I, ValueMap);
+
+ // Rewrite the original preheader to select between versions of the loop.
+ BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator());
+ assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] &&
+ "Preheader splitting did not work correctly!");
+
+ // Emit the new branch that selects between the two versions of this loop.
+ EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR);
+ LPM->deleteSimpleAnalysisValue(OldBR, L);
+ OldBR->eraseFromParent();
+
+ LoopProcessWorklist.push_back(NewLoop);
+ redoLoop = true;
+
+ // Now we rewrite the original code to know that the condition is false and
+ // the new code to know that the condition is true, since the new loop is the
+ // one entered when LIC == Val.
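+ // E.g. (an illustrative case, not from the original source): when
+ // unswitching a branch on %inv with Val == true, uses of %inv inside the
+ // original loop are replaced by false and uses inside the cloned loop by
+ // true, so the unswitched branch in each copy now tests a constant.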
+ RewriteLoopBodyWithConditionConstant(L, LIC, Val, false);
+
+ // It's possible that simplifying one loop could cause the other to be
+ // deleted. If so, don't simplify it.
+ if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop)
+ RewriteLoopBodyWithConditionConstant(NewLoop, LIC, Val, true);
+
+}
+
+/// RemoveFromWorklist - Remove all instances of I from the worklist vector
+/// specified.
+static void RemoveFromWorklist(Instruction *I,
+ std::vector<Instruction*> &Worklist) {
+ std::vector<Instruction*>::iterator WI = std::find(Worklist.begin(),
+ Worklist.end(), I);
+ while (WI != Worklist.end()) {
+ unsigned Offset = WI-Worklist.begin();
+ Worklist.erase(WI);
+ WI = std::find(Worklist.begin()+Offset, Worklist.end(), I);
+ }
+}
+
+/// ReplaceUsesOfWith - When we find that I really equals V, remove I from the
+/// program, replacing all uses with V and updating the worklist.
+static void ReplaceUsesOfWith(Instruction *I, Value *V,
+ std::vector<Instruction*> &Worklist,
+ Loop *L, LPPassManager *LPM) {
+ DOUT << "Replace with '" << *V << "': " << *I;
+
+ // Add uses to the worklist, which may be dead now.
+ for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
+ if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i)))
+ Worklist.push_back(Use);
+
+ // Add users to the worklist which may be simplified now.
+ for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
+ UI != E; ++UI)
+ Worklist.push_back(cast<Instruction>(*UI));
+ LPM->deleteSimpleAnalysisValue(I, L);
+ RemoveFromWorklist(I, Worklist);
+ I->replaceAllUsesWith(V);
+ I->eraseFromParent();
+ ++NumSimplify;
+}
+
+/// RemoveBlockIfDead - If the specified block is dead, remove it, update loop
+/// information, and remove any dead successors it has.
+///
+void LoopUnswitch::RemoveBlockIfDead(BasicBlock *BB,
+ std::vector<Instruction*> &Worklist,
+ Loop *L) {
+ if (pred_begin(BB) != pred_end(BB)) {
+ // This block isn't dead; since an edge to BB was just removed, see if there
+ // are any easy simplifications we can do now.
+ if (BasicBlock *Pred = BB->getSinglePredecessor()) {
+ // If it has one pred, fold phi nodes in BB.
+ while (isa<PHINode>(BB->begin()))
+ ReplaceUsesOfWith(BB->begin(),
+ cast<PHINode>(BB->begin())->getIncomingValue(0),
+ Worklist, L, LPM);
+
+ // If this is the header of a loop and the only pred is the latch, we now
+ // have an unreachable loop.
+ if (Loop *L = LI->getLoopFor(BB))
+ if (loopHeader == BB && L->contains(Pred)) {
+ // Remove the branch from the latch to the header block, this makes
+ // the header dead, which will make the latch dead (because the header
+ // dominates the latch).
+ LPM->deleteSimpleAnalysisValue(Pred->getTerminator(), L);
+ Pred->getTerminator()->eraseFromParent();
+ new UnreachableInst(Pred);
+
+ // The loop is now broken, remove it from LI.
+ RemoveLoopFromHierarchy(L);
+
+ // Reprocess the header, which now IS dead.
+ RemoveBlockIfDead(BB, Worklist, L);
+ return;
+ }
+
+ // If pred ends in an unconditional branch, add that branch to the
+ // worklist so that the two blocks will get merged.
+ if (BranchInst *BI = dyn_cast<BranchInst>(Pred->getTerminator()))
+ if (BI->isUnconditional())
+ Worklist.push_back(BI);
+ }
+ return;
+ }
+
+ DOUT << "Nuking dead block: " << *BB;
+
+ // Remove the instructions in the basic block from the worklist.
+ for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + RemoveFromWorklist(I, Worklist); + + // Anything that uses the instructions in this basic block should have their + // uses replaced with undefs. + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + } + + // If this is the edge to the header block for a loop, remove the loop and + // promote all subloops. + if (Loop *BBLoop = LI->getLoopFor(BB)) { + if (BBLoop->getLoopLatch() == BB) + RemoveLoopFromHierarchy(BBLoop); + } + + // Remove the block from the loop info, which removes it from any loops it + // was in. + LI->removeBlock(BB); + + + // Remove phi node entries in successors for this block. + TerminatorInst *TI = BB->getTerminator(); + SmallVector<BasicBlock*, 4> Succs; + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + Succs.push_back(TI->getSuccessor(i)); + TI->getSuccessor(i)->removePredecessor(BB); + } + + // Unique the successors, remove anything with multiple uses. + array_pod_sort(Succs.begin(), Succs.end()); + Succs.erase(std::unique(Succs.begin(), Succs.end()), Succs.end()); + + // Remove the basic block, including all of the instructions contained in it. + LPM->deleteSimpleAnalysisValue(BB, L); + BB->eraseFromParent(); + // Remove successor blocks here that are not dead, so that we know we only + // have dead blocks in this list. Nondead blocks have a way of becoming dead, + // then getting removed before we revisit them, which is badness. + // + for (unsigned i = 0; i != Succs.size(); ++i) + if (pred_begin(Succs[i]) != pred_end(Succs[i])) { + // One exception is loop headers. If this block was the preheader for a + // loop, then we DO want to visit the loop so the loop gets deleted. + // We know that if the successor is a loop header, that this loop had to + // be the preheader: the case where this was the latch block was handled + // above and headers can only have two predecessors. + if (!LI->isLoopHeader(Succs[i])) { + Succs.erase(Succs.begin()+i); + --i; + } + } + + for (unsigned i = 0, e = Succs.size(); i != e; ++i) + RemoveBlockIfDead(Succs[i], Worklist, L); +} + +/// RemoveLoopFromHierarchy - We have discovered that the specified loop has +/// become unwrapped, either because the backedge was deleted, or because the +/// edge into the header was removed. If the edge into the header from the +/// latch block was removed, the loop is unwrapped but subloops are still alive, +/// so they just reparent loops. If the loops are actually dead, they will be +/// removed later. +void LoopUnswitch::RemoveLoopFromHierarchy(Loop *L) { + LPM->deleteLoopFromQueue(L); + RemoveLoopFromWorklist(L); +} + +// RewriteLoopBodyWithConditionConstant - We know either that the value LIC has +// the value specified by Val in the specified loop, or we know it does NOT have +// that value. Rewrite any uses of LIC or of properties correlated to it. +void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, + Constant *Val, + bool IsEqual) { + assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); + + // FIXME: Support correlated properties, like: + // for (...) + // if (li1 < li2) + // ... + // if (li1 > li2) + // ... + + // FOLD boolean conditions (X|LIC), (X&LIC). Fold conditional branches, + // selects, switches. 
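+ // E.g. (an illustrative case, not from the original source): in the loop
+ // copy where LIC is known to be false, (X & LIC) constant-folds to false
+ // and (X | LIC) simplifies to just X in SimplifyCode below.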
+ std::vector<User*> Users(LIC->use_begin(), LIC->use_end()); + std::vector<Instruction*> Worklist; + + // If we know that LIC == Val, or that LIC == NotVal, just replace uses of LIC + // in the loop with the appropriate one directly. + if (IsEqual || (isa<ConstantInt>(Val) && Val->getType() == Type::Int1Ty)) { + Value *Replacement; + if (IsEqual) + Replacement = Val; + else + Replacement = ConstantInt::get(Type::Int1Ty, + !cast<ConstantInt>(Val)->getZExtValue()); + + for (unsigned i = 0, e = Users.size(); i != e; ++i) + if (Instruction *U = cast<Instruction>(Users[i])) { + if (!L->contains(U->getParent())) + continue; + U->replaceUsesOfWith(LIC, Replacement); + Worklist.push_back(U); + } + } else { + // Otherwise, we don't know the precise value of LIC, but we do know that it + // is certainly NOT "Val". As such, simplify any uses in the loop that we + // can. This case occurs when we unswitch switch statements. + for (unsigned i = 0, e = Users.size(); i != e; ++i) + if (Instruction *U = cast<Instruction>(Users[i])) { + if (!L->contains(U->getParent())) + continue; + + Worklist.push_back(U); + + // If we know that LIC is not Val, use this info to simplify code. + if (SwitchInst *SI = dyn_cast<SwitchInst>(U)) { + for (unsigned i = 1, e = SI->getNumCases(); i != e; ++i) { + if (SI->getCaseValue(i) == Val) { + // Found a dead case value. Don't remove PHI nodes in the + // successor if they become single-entry, those PHI nodes may + // be in the Users list. + + // FIXME: This is a hack. We need to keep the successor around + // and hooked up so as to preserve the loop structure, because + // trying to update it is complicated. So instead we preserve the + // loop structure and put the block on an dead code path. + + BasicBlock *SISucc = SI->getSuccessor(i); + BasicBlock* Old = SI->getParent(); + BasicBlock* Split = SplitBlock(Old, SI, this); + + Instruction* OldTerm = Old->getTerminator(); + BranchInst::Create(Split, SISucc, + ConstantInt::getTrue(), OldTerm); + + LPM->deleteSimpleAnalysisValue(Old->getTerminator(), L); + Old->getTerminator()->eraseFromParent(); + + PHINode *PN; + for (BasicBlock::iterator II = SISucc->begin(); + (PN = dyn_cast<PHINode>(II)); ++II) { + Value *InVal = PN->removeIncomingValue(Split, false); + PN->addIncoming(InVal, Old); + } + + SI->removeCase(i); + break; + } + } + } + + // TODO: We could do other simplifications, for example, turning + // LIC == Val -> false. + } + } + + SimplifyCode(Worklist, L); +} + +/// SimplifyCode - Okay, now that we have simplified some instructions in the +/// loop, walk over it and constant prop, dce, and fold control flow where +/// possible. Note that this is effectively a very simple loop-structure-aware +/// optimizer. During processing of this loop, L could very well be deleted, so +/// it must not be used. +/// +/// FIXME: When the loop optimizer is more mature, separate this out to a new +/// pass. +/// +void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { + while (!Worklist.empty()) { + Instruction *I = Worklist.back(); + Worklist.pop_back(); + + // Simple constant folding. + if (Constant *C = ConstantFoldInstruction(I)) { + ReplaceUsesOfWith(I, C, Worklist, L, LPM); + continue; + } + + // Simple DCE. + if (isInstructionTriviallyDead(I)) { + DOUT << "Remove dead instruction '" << *I; + + // Add uses to the worklist, which may be dead now. 
+ for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) + if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i))) + Worklist.push_back(Use); + LPM->deleteSimpleAnalysisValue(I, L); + RemoveFromWorklist(I, Worklist); + I->eraseFromParent(); + ++NumSimplify; + continue; + } + + // Special case hacks that appear commonly in unswitched code. + switch (I->getOpcode()) { + case Instruction::Select: + if (ConstantInt *CB = dyn_cast<ConstantInt>(I->getOperand(0))) { + ReplaceUsesOfWith(I, I->getOperand(!CB->getZExtValue()+1), Worklist, L, + LPM); + continue; + } + break; + case Instruction::And: + if (isa<ConstantInt>(I->getOperand(0)) && + I->getOperand(0)->getType() == Type::Int1Ty) // constant -> RHS + cast<BinaryOperator>(I)->swapOperands(); + if (ConstantInt *CB = dyn_cast<ConstantInt>(I->getOperand(1))) + if (CB->getType() == Type::Int1Ty) { + if (CB->isOne()) // X & 1 -> X + ReplaceUsesOfWith(I, I->getOperand(0), Worklist, L, LPM); + else // X & 0 -> 0 + ReplaceUsesOfWith(I, I->getOperand(1), Worklist, L, LPM); + continue; + } + break; + case Instruction::Or: + if (isa<ConstantInt>(I->getOperand(0)) && + I->getOperand(0)->getType() == Type::Int1Ty) // constant -> RHS + cast<BinaryOperator>(I)->swapOperands(); + if (ConstantInt *CB = dyn_cast<ConstantInt>(I->getOperand(1))) + if (CB->getType() == Type::Int1Ty) { + if (CB->isOne()) // X | 1 -> 1 + ReplaceUsesOfWith(I, I->getOperand(1), Worklist, L, LPM); + else // X | 0 -> X + ReplaceUsesOfWith(I, I->getOperand(0), Worklist, L, LPM); + continue; + } + break; + case Instruction::Br: { + BranchInst *BI = cast<BranchInst>(I); + if (BI->isUnconditional()) { + // If BI's parent is the only pred of the successor, fold the two blocks + // together. + BasicBlock *Pred = BI->getParent(); + BasicBlock *Succ = BI->getSuccessor(0); + BasicBlock *SinglePred = Succ->getSinglePredecessor(); + if (!SinglePred) continue; // Nothing to do. + assert(SinglePred == Pred && "CFG broken"); + + DOUT << "Merging blocks: " << Pred->getName() << " <- " + << Succ->getName() << "\n"; + + // Resolve any single entry PHI nodes in Succ. + while (PHINode *PN = dyn_cast<PHINode>(Succ->begin())) + ReplaceUsesOfWith(PN, PN->getIncomingValue(0), Worklist, L, LPM); + + // Move all of the successor contents from Succ to Pred. + Pred->getInstList().splice(BI, Succ->getInstList(), Succ->begin(), + Succ->end()); + LPM->deleteSimpleAnalysisValue(BI, L); + BI->eraseFromParent(); + RemoveFromWorklist(BI, Worklist); + + // If Succ has any successors with PHI nodes, update them to have + // entries coming from Pred instead of Succ. + Succ->replaceAllUsesWith(Pred); + + // Remove Succ from the loop tree. + LI->removeBlock(Succ); + LPM->deleteSimpleAnalysisValue(Succ, L); + Succ->eraseFromParent(); + ++NumSimplify; + } else if (ConstantInt *CB = dyn_cast<ConstantInt>(BI->getCondition())){ + // Conditional branch. Turn it into an unconditional branch, then + // remove dead blocks. + break; // FIXME: Enable. 
+ + DOUT << "Folded branch: " << *BI; + BasicBlock *DeadSucc = BI->getSuccessor(CB->getZExtValue()); + BasicBlock *LiveSucc = BI->getSuccessor(!CB->getZExtValue()); + DeadSucc->removePredecessor(BI->getParent(), true); + Worklist.push_back(BranchInst::Create(LiveSucc, BI)); + LPM->deleteSimpleAnalysisValue(BI, L); + BI->eraseFromParent(); + RemoveFromWorklist(BI, Worklist); + ++NumSimplify; + + RemoveBlockIfDead(DeadSucc, Worklist, L); + } + break; + } + } + } +} diff --git a/lib/Transforms/Scalar/Makefile b/lib/Transforms/Scalar/Makefile new file mode 100644 index 0000000..cc42fd0 --- /dev/null +++ b/lib/Transforms/Scalar/Makefile @@ -0,0 +1,15 @@ +##===- lib/Transforms/Scalar/Makefile ----------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file is distributed under the University of Illinois Open Source +# License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. +LIBRARYNAME = LLVMScalarOpts +BUILD_ARCHIVE = 1 + +include $(LEVEL)/Makefile.common + diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp new file mode 100644 index 0000000..5cf0518 --- /dev/null +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -0,0 +1,741 @@ +//===- MemCpyOptimizer.cpp - Optimize use of memcpy and friends -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs various transformations related to eliminating memcpy +// calls, or transforming sets of stores into memset's. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "memcpyopt" +#include "llvm/Transforms/Scalar.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Instructions.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Target/TargetData.h" +#include <list> +using namespace llvm; + +STATISTIC(NumMemCpyInstr, "Number of memcpy instructions deleted"); +STATISTIC(NumMemSetInfer, "Number of memsets inferred"); + +/// isBytewiseValue - If the specified value can be set by repeating the same +/// byte in memory, return the i8 value that it is represented with. This is +/// true for all i8 values obviously, but is also true for i32 0, i32 -1, +/// i16 0xF0F0, double 0.0 etc. If the value can't be handled with a repeated +/// byte store (e.g. i16 0x1234), return null. +static Value *isBytewiseValue(Value *V) { + // All byte-wide stores are splatable, even of arbitrary variables. + if (V->getType() == Type::Int8Ty) return V; + + // Constant float and double values can be handled as integer values if the + // corresponding integer value is "byteable". An important case is 0.0. + if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { + if (CFP->getType() == Type::FloatTy) + V = ConstantExpr::getBitCast(CFP, Type::Int32Ty); + if (CFP->getType() == Type::DoubleTy) + V = ConstantExpr::getBitCast(CFP, Type::Int64Ty); + // Don't handle long double formats, which have strange constraints. 
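+ // (Illustrative examples, ours rather than from the original source:
+ // float 0.0 bitcasts to i32 0, which splats to the byte 0x00 below, while
+ // double -0.0 becomes i64 0x8000000000000000, whose top and bottom halves
+ // differ, so it is rejected.)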
+ } + + // We can handle constant integers that are power of two in size and a + // multiple of 8 bits. + if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { + unsigned Width = CI->getBitWidth(); + if (isPowerOf2_32(Width) && Width > 8) { + // We can handle this value if the recursive binary decomposition is the + // same at all levels. + APInt Val = CI->getValue(); + APInt Val2; + while (Val.getBitWidth() != 8) { + unsigned NextWidth = Val.getBitWidth()/2; + Val2 = Val.lshr(NextWidth); + Val2.trunc(Val.getBitWidth()/2); + Val.trunc(Val.getBitWidth()/2); + + // If the top/bottom halves aren't the same, reject it. + if (Val != Val2) + return 0; + } + return ConstantInt::get(Val); + } + } + + // Conceptually, we could handle things like: + // %a = zext i8 %X to i16 + // %b = shl i16 %a, 8 + // %c = or i16 %a, %b + // but until there is an example that actually needs this, it doesn't seem + // worth worrying about. + return 0; +} + +static int64_t GetOffsetFromIndex(const GetElementPtrInst *GEP, unsigned Idx, + bool &VariableIdxFound, TargetData &TD) { + // Skip over the first indices. + gep_type_iterator GTI = gep_type_begin(GEP); + for (unsigned i = 1; i != Idx; ++i, ++GTI) + /*skip along*/; + + // Compute the offset implied by the rest of the indices. + int64_t Offset = 0; + for (unsigned i = Idx, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { + ConstantInt *OpC = dyn_cast<ConstantInt>(GEP->getOperand(i)); + if (OpC == 0) + return VariableIdxFound = true; + if (OpC->isZero()) continue; // No offset. + + // Handle struct indices, which add their field offset to the pointer. + if (const StructType *STy = dyn_cast<StructType>(*GTI)) { + Offset += TD.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); + continue; + } + + // Otherwise, we have a sequential type like an array or vector. Multiply + // the index by the ElementSize. + uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType()); + Offset += Size*OpC->getSExtValue(); + } + + return Offset; +} + +/// IsPointerOffset - Return true if Ptr1 is provably equal to Ptr2 plus a +/// constant offset, and return that constant offset. For example, Ptr1 might +/// be &A[42], and Ptr2 might be &A[40]. In this case offset would be -8. +static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset, + TargetData &TD) { + // Right now we handle the case when Ptr1/Ptr2 are both GEPs with an identical + // base. After that base, they may have some number of common (and + // potentially variable) indices. After that they handle some constant + // offset, which determines their offset from each other. At this point, we + // handle no other case. + GetElementPtrInst *GEP1 = dyn_cast<GetElementPtrInst>(Ptr1); + GetElementPtrInst *GEP2 = dyn_cast<GetElementPtrInst>(Ptr2); + if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0)) + return false; + + // Skip any common indices and track the GEP types. + unsigned Idx = 1; + for (; Idx != GEP1->getNumOperands() && Idx != GEP2->getNumOperands(); ++Idx) + if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx)) + break; + + bool VariableIdxFound = false; + int64_t Offset1 = GetOffsetFromIndex(GEP1, Idx, VariableIdxFound, TD); + int64_t Offset2 = GetOffsetFromIndex(GEP2, Idx, VariableIdxFound, TD); + if (VariableIdxFound) return false; + + Offset = Offset2-Offset1; + return true; +} + + +/// MemsetRange - Represents a range of memset'd bytes with the ByteVal value. 
+/// This allows us to analyze stores like: +/// store 0 -> P+1 +/// store 0 -> P+0 +/// store 0 -> P+3 +/// store 0 -> P+2 +/// which sometimes happens with stores to arrays of structs etc. When we see +/// the first store, we make a range [1, 2). The second store extends the range +/// to [0, 2). The third makes a new range [2, 3). The fourth store joins the +/// two ranges into [0, 3) which is memset'able. +namespace { +struct MemsetRange { + // Start/End - A semi range that describes the span that this range covers. + // The range is closed at the start and open at the end: [Start, End). + int64_t Start, End; + + /// StartPtr - The getelementptr instruction that points to the start of the + /// range. + Value *StartPtr; + + /// Alignment - The known alignment of the first store. + unsigned Alignment; + + /// TheStores - The actual stores that make up this range. + SmallVector<StoreInst*, 16> TheStores; + + bool isProfitableToUseMemset(const TargetData &TD) const; + +}; +} // end anon namespace + +bool MemsetRange::isProfitableToUseMemset(const TargetData &TD) const { + // If we found more than 8 stores to merge or 64 bytes, use memset. + if (TheStores.size() >= 8 || End-Start >= 64) return true; + + // Assume that the code generator is capable of merging pairs of stores + // together if it wants to. + if (TheStores.size() <= 2) return false; + + // If we have fewer than 8 stores, it can still be worthwhile to do this. + // For example, merging 4 i8 stores into an i32 store is useful almost always. + // However, merging 2 32-bit stores isn't useful on a 32-bit architecture (the + // memset will be split into 2 32-bit stores anyway) and doing so can + // pessimize the llvm optimizer. + // + // Since we don't have perfect knowledge here, make some assumptions: assume + // the maximum GPR width is the same size as the pointer size and assume that + // this width can be stored. If so, check to see whether we will end up + // actually reducing the number of stores used. + unsigned Bytes = unsigned(End-Start); + unsigned NumPointerStores = Bytes/TD.getPointerSize(); + + // Assume the remaining bytes if any are done a byte at a time. + unsigned NumByteStores = Bytes - NumPointerStores*TD.getPointerSize(); + + // If we will reduce the # stores (according to this heuristic), do the + // transformation. This encourages merging 4 x i8 -> i32 and 2 x i16 -> i32 + // etc. + return TheStores.size() > NumPointerStores+NumByteStores; +} + + +namespace { +class MemsetRanges { + /// Ranges - A sorted list of the memset ranges. We use std::list here + /// because each element is relatively large and expensive to copy. + std::list<MemsetRange> Ranges; + typedef std::list<MemsetRange>::iterator range_iterator; + TargetData &TD; +public: + MemsetRanges(TargetData &td) : TD(td) {} + + typedef std::list<MemsetRange>::const_iterator const_iterator; + const_iterator begin() const { return Ranges.begin(); } + const_iterator end() const { return Ranges.end(); } + bool empty() const { return Ranges.empty(); } + + void addStore(int64_t OffsetFromFirst, StoreInst *SI); +}; + +} // end anon namespace + + +/// addStore - Add a new store to the MemsetRanges data structure. This adds a +/// new range for the specified store at the specified offset, merging into +/// existing ranges as appropriate. 
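+/// For example (an illustrative trace, not from the original source): with
+/// existing ranges [0, 4) and [8, 12), adding a 4-byte store at offset 4
+/// first extends the first range to [0, 8) and then merges the following
+/// range in, leaving a single range [0, 12).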
+void MemsetRanges::addStore(int64_t Start, StoreInst *SI) {
+ int64_t End = Start+TD.getTypeStoreSize(SI->getOperand(0)->getType());
+
+ // Do a linear search of the ranges to see if this can be joined and/or to
+ // find the insertion point in the list. We keep the ranges sorted for
+ // simplicity here. This is a linear search of a linked list, which is ugly,
+ // however the number of ranges is limited, so this won't get crazy slow.
+ range_iterator I = Ranges.begin(), E = Ranges.end();
+
+ while (I != E && Start > I->End)
+ ++I;
+
+ // We now know that I == E, in which case we didn't find anything to merge
+ // with, or that Start <= I->End. If End < I->Start or I == E, then we need
+ // to insert a new range. Handle this now.
+ if (I == E || End < I->Start) {
+ MemsetRange &R = *Ranges.insert(I, MemsetRange());
+ R.Start = Start;
+ R.End = End;
+ R.StartPtr = SI->getPointerOperand();
+ R.Alignment = SI->getAlignment();
+ R.TheStores.push_back(SI);
+ return;
+ }
+
+ // This store overlaps with I, add it.
+ I->TheStores.push_back(SI);
+
+ // At this point, we may have an interval that completely contains our store.
+ // If so, just add it to the interval and return.
+ if (I->Start <= Start && I->End >= End)
+ return;
+
+ // Now we know that Start <= I->End and End >= I->Start, so the store
+ // overlaps the range but is not entirely contained within it.
+
+ // See if the store extends the start of the range. In this case, it couldn't
+ // possibly cause it to join the prior range, because otherwise we would have
+ // stopped on *it*.
+ if (Start < I->Start) {
+ I->Start = Start;
+ I->StartPtr = SI->getPointerOperand();
+ }
+
+ // Now we know that Start <= I->End and Start >= I->Start (so the startpoint
+ // is in or right at the end of I), and that End >= I->Start. Extend I out to
+ // End.
+ if (End > I->End) {
+ I->End = End;
+ range_iterator NextI = I;
+ while (++NextI != E && End >= NextI->Start) {
+ // Merge the range in.
+ I->TheStores.append(NextI->TheStores.begin(), NextI->TheStores.end());
+ if (NextI->End > I->End)
+ I->End = NextI->End;
+ Ranges.erase(NextI);
+ NextI = I;
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// MemCpyOpt Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+ class VISIBILITY_HIDDEN MemCpyOpt : public FunctionPass {
+ bool runOnFunction(Function &F);
+ public:
+ static char ID; // Pass identification, replacement for typeid
+ MemCpyOpt() : FunctionPass(&ID) {}
+
+ private:
+ // This transformation requires dominator info.
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesCFG();
+ AU.addRequired<DominatorTree>();
+ AU.addRequired<MemoryDependenceAnalysis>();
+ AU.addRequired<AliasAnalysis>();
+ AU.addRequired<TargetData>();
+ AU.addPreserved<AliasAnalysis>();
+ AU.addPreserved<MemoryDependenceAnalysis>();
+ AU.addPreserved<TargetData>();
+ }
+
+ // Helper functions.
+ bool processStore(StoreInst *SI, BasicBlock::iterator& BBI);
+ bool processMemCpy(MemCpyInst* M);
+ bool performCallSlotOptzn(MemCpyInst* cpy, CallInst* C);
+ bool iterateOnFunction(Function &F);
+ };
+
+ char MemCpyOpt::ID = 0;
+}
+
+// createMemCpyOptPass - The public interface to this file...
+FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOpt(); }
+
+static RegisterPass<MemCpyOpt> X("memcpyopt",
+ "MemCpy Optimization");
+
+
+/// processStore - When scanning forward over instructions, we look for some
+/// other patterns to fold away. In particular, this looks for stores to
+/// neighboring locations of memory. If it sees enough consecutive ones
+/// (currently 4) it attempts to merge them together into a memcpy/memset.
+bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator& BBI) {
+ if (SI->isVolatile()) return false;
+
+ // There are two cases that are interesting for this code to handle: memcpy
+ // and memset. Right now we only handle memset.
+
+ // Ensure that the value being stored is something that can be memset a byte
+ // at a time, like "0" or "-1" of any width, as well as things like
+ // 0xA0A0A0A0 and 0.0.
+ Value *ByteVal = isBytewiseValue(SI->getOperand(0));
+ if (!ByteVal)
+ return false;
+
+ TargetData &TD = getAnalysis<TargetData>();
+ AliasAnalysis &AA = getAnalysis<AliasAnalysis>();
+
+ // Okay, so we now have a single store whose value is splatable. Scan to find
+ // all subsequent stores of the same value at offsets from the same pointer.
+ // Join these together into ranges, so we can decide whether contiguous blocks
+ // are stored.
+ MemsetRanges Ranges(TD);
+
+ Value *StartPtr = SI->getPointerOperand();
+
+ BasicBlock::iterator BI = SI;
+ for (++BI; !isa<TerminatorInst>(BI); ++BI) {
+ if (isa<CallInst>(BI) || isa<InvokeInst>(BI)) {
+ // If the call is readnone, ignore it, otherwise bail out. We don't even
+ // allow readonly here because we don't want something like:
+ // A[1] = 2; strlen(A); A[2] = 2; -> memcpy(A, ...); strlen(A).
+ if (AA.getModRefBehavior(CallSite::get(BI)) ==
+ AliasAnalysis::DoesNotAccessMemory)
+ continue;
+
+ // TODO: If this is a memset, try to join it in.
+
+ break;
+ } else if (isa<VAArgInst>(BI) || isa<LoadInst>(BI))
+ break;
+
+ // If this is a non-store instruction it is fine, ignore it.
+ StoreInst *NextStore = dyn_cast<StoreInst>(BI);
+ if (NextStore == 0) continue;
+
+ // If this is a store, see if we can merge it in.
+ if (NextStore->isVolatile()) break;
+
+ // Check to see if this stored value is of the same byte-splattable value.
+ if (ByteVal != isBytewiseValue(NextStore->getOperand(0)))
+ break;
+
+ // Check to see if this store is to a constant offset from the start ptr.
+ int64_t Offset;
+ if (!IsPointerOffset(StartPtr, NextStore->getPointerOperand(), Offset, TD))
+ break;
+
+ Ranges.addStore(Offset, NextStore);
+ }
+
+ // If we have no ranges, then we just had a single store with nothing that
+ // could be merged in. This is a very common case of course.
+ if (Ranges.empty())
+ return false;
+
+ // If we had at least one store that could be merged in, add the starting
+ // store as well. We try to avoid this unless there is at least something
+ // interesting as a small compile-time optimization.
+ Ranges.addStore(0, SI);
+
+
+ Function *MemSetF = 0;
+
+ // Now that we have full information about ranges, loop over the ranges and
+ // emit memset's for anything big enough to be worthwhile.
+ bool MadeChange = false;
+ for (MemsetRanges::const_iterator I = Ranges.begin(), E = Ranges.end();
+ I != E; ++I) {
+ const MemsetRange &Range = *I;
+
+ if (Range.TheStores.size() == 1) continue;
+
+ // If it is profitable to lower this range to memset, do so now.
+ if (!Range.isProfitableToUseMemset(TD))
+ continue;
+
+ // Otherwise, we do want to transform this! Create a new memset. 
We put
+ // the memset right before the first instruction that isn't part of this
+ // memset block. This ensures that the memset is dominated by any addressing
+ // instruction needed by the start of the block.
+ BasicBlock::iterator InsertPt = BI;
+
+ if (MemSetF == 0) {
+ const Type *Tys[] = {Type::Int64Ty};
+ MemSetF = Intrinsic::getDeclaration(SI->getParent()->getParent()
+ ->getParent(), Intrinsic::memset,
+ Tys, 1);
+ }
+
+ // Get the starting pointer of the block.
+ StartPtr = Range.StartPtr;
+
+ // Cast the start ptr to be i8* as memset requires.
+ const Type *i8Ptr = PointerType::getUnqual(Type::Int8Ty);
+ if (StartPtr->getType() != i8Ptr)
+ StartPtr = new BitCastInst(StartPtr, i8Ptr, StartPtr->getNameStart(),
+ InsertPt);
+
+ Value *Ops[] = {
+ StartPtr, ByteVal, // Start, value
+ ConstantInt::get(Type::Int64Ty, Range.End-Range.Start), // size
+ ConstantInt::get(Type::Int32Ty, Range.Alignment) // align
+ };
+ Value *C = CallInst::Create(MemSetF, Ops, Ops+4, "", InsertPt);
+ DEBUG(cerr << "Replace stores:\n";
+ for (unsigned i = 0, e = Range.TheStores.size(); i != e; ++i)
+ cerr << *Range.TheStores[i];
+ cerr << "With: " << *C); C=C;
+
+ // Don't invalidate the iterator.
+ BBI = BI;
+
+ // Zap all the stores.
+ for (SmallVector<StoreInst*, 16>::const_iterator SI = Range.TheStores.begin(),
+ SE = Range.TheStores.end(); SI != SE; ++SI)
+ (*SI)->eraseFromParent();
+ ++NumMemSetInfer;
+ MadeChange = true;
+ }
+
+ return MadeChange;
+}
+
+
+/// performCallSlotOptzn - Takes a memcpy and a call that it depends on,
+/// and checks for the possibility of a call slot optimization by having
+/// the call write its result directly into the destination of the memcpy.
+bool MemCpyOpt::performCallSlotOptzn(MemCpyInst *cpy, CallInst *C) {
+ // The general transformation to keep in mind is
+ //
+ // call @func(..., src, ...)
+ // memcpy(dest, src, ...)
+ //
+ // ->
+ //
+ // memcpy(dest, src, ...)
+ // call @func(..., dest, ...)
+ //
+ // Since moving the memcpy is technically awkward, we additionally check that
+ // src only holds uninitialized values at the moment of the call, meaning that
+ // the memcpy can be discarded rather than moved.
+
+ // Deliberately get the source and destination with bitcasts stripped away,
+ // because we'll need to do type comparisons based on the underlying type.
+ Value* cpyDest = cpy->getDest();
+ Value* cpySrc = cpy->getSource();
+ CallSite CS = CallSite::get(C);
+
+ // We need to be able to reason about the size of the memcpy, so we require
+ // that it be a constant.
+ ConstantInt* cpyLength = dyn_cast<ConstantInt>(cpy->getLength());
+ if (!cpyLength)
+ return false;
+
+ // Require that src be an alloca. This simplifies the reasoning considerably.
+ AllocaInst* srcAlloca = dyn_cast<AllocaInst>(cpySrc);
+ if (!srcAlloca)
+ return false;
+
+ // Check that all of src is copied to dest.
+ TargetData& TD = getAnalysis<TargetData>();
+
+ ConstantInt* srcArraySize = dyn_cast<ConstantInt>(srcAlloca->getArraySize());
+ if (!srcArraySize)
+ return false;
+
+ uint64_t srcSize = TD.getTypeAllocSize(srcAlloca->getAllocatedType()) *
+ srcArraySize->getZExtValue();
+
+ if (cpyLength->getZExtValue() < srcSize)
+ return false;
+
+ // Check that accessing the first srcSize bytes of dest will not cause a
+ // trap. Otherwise the transform is invalid since it might cause a trap
+ // to occur earlier than it otherwise would.
+ if (AllocaInst* A = dyn_cast<AllocaInst>(cpyDest)) {
+ // The destination is an alloca. Check that it is no smaller than srcSize. 
+
+
+/// performCallSlotOptzn - takes a memcpy and a call that it depends on,
+/// and checks for the possibility of a call slot optimization by having
+/// the call write its result directly into the destination of the memcpy.
+bool MemCpyOpt::performCallSlotOptzn(MemCpyInst *cpy, CallInst *C) {
+ // The general transformation to keep in mind is
+ //
+ // call @func(..., src, ...)
+ // memcpy(dest, src, ...)
+ //
+ // ->
+ //
+ // memcpy(dest, src, ...)
+ // call @func(..., dest, ...)
+ //
+ // Since moving the memcpy is technically awkward, we additionally check that
+ // src only holds uninitialized values at the moment of the call, meaning that
+ // the memcpy can be discarded rather than moved.
+
+ // Deliberately get the source and destination with bitcasts stripped away,
+ // because we'll need to do type comparisons based on the underlying type.
+ Value* cpyDest = cpy->getDest();
+ Value* cpySrc = cpy->getSource();
+ CallSite CS = CallSite::get(C);
+
+ // We need to be able to reason about the size of the memcpy, so we require
+ // that it be a constant.
+ ConstantInt* cpyLength = dyn_cast<ConstantInt>(cpy->getLength());
+ if (!cpyLength)
+ return false;
+
+ // Require that src be an alloca. This simplifies the reasoning considerably.
+ AllocaInst* srcAlloca = dyn_cast<AllocaInst>(cpySrc);
+ if (!srcAlloca)
+ return false;
+
+ // Check that all of src is copied to dest.
+ TargetData& TD = getAnalysis<TargetData>();
+
+ ConstantInt* srcArraySize = dyn_cast<ConstantInt>(srcAlloca->getArraySize());
+ if (!srcArraySize)
+ return false;
+
+ uint64_t srcSize = TD.getTypeAllocSize(srcAlloca->getAllocatedType()) *
+ srcArraySize->getZExtValue();
+
+ if (cpyLength->getZExtValue() < srcSize)
+ return false;
+
+ // Check that accessing the first srcSize bytes of dest will not cause a
+ // trap. Otherwise the transform is invalid since it might cause a trap
+ // to occur earlier than it otherwise would.
+ if (AllocaInst* A = dyn_cast<AllocaInst>(cpyDest)) {
+ // The destination is an alloca. Check that it is at least as large as
+ // srcSize.
+ ConstantInt* destArraySize = dyn_cast<ConstantInt>(A->getArraySize());
+ if (!destArraySize)
+ return false;
+
+ uint64_t destSize = TD.getTypeAllocSize(A->getAllocatedType()) *
+ destArraySize->getZExtValue();
+
+ if (destSize < srcSize)
+ return false;
+ } else if (Argument* A = dyn_cast<Argument>(cpyDest)) {
+ // If the destination is an sret parameter then only accesses that are
+ // outside of the returned struct type can trap.
+ if (!A->hasStructRetAttr())
+ return false;
+
+ const Type* StructTy = cast<PointerType>(A->getType())->getElementType();
+ uint64_t destSize = TD.getTypeAllocSize(StructTy);
+
+ if (destSize < srcSize)
+ return false;
+ } else {
+ return false;
+ }
+
+ // Check that src is not accessed except via the call and the memcpy. This
+ // guarantees that it holds only undefined values when passed in (so the final
+ // memcpy can be dropped), that it is not read or written between the call and
+ // the memcpy, and that writing beyond the end of it is undefined.
+ SmallVector<User*, 8> srcUseList(srcAlloca->use_begin(),
+ srcAlloca->use_end());
+ while (!srcUseList.empty()) {
+ User* UI = srcUseList.back();
+ srcUseList.pop_back();
+
+ if (isa<BitCastInst>(UI)) {
+ for (User::use_iterator I = UI->use_begin(), E = UI->use_end();
+ I != E; ++I)
+ srcUseList.push_back(*I);
+ } else if (GetElementPtrInst* G = dyn_cast<GetElementPtrInst>(UI)) {
+ if (G->hasAllZeroIndices())
+ for (User::use_iterator I = UI->use_begin(), E = UI->use_end();
+ I != E; ++I)
+ srcUseList.push_back(*I);
+ else
+ return false;
+ } else if (UI != C && UI != cpy) {
+ return false;
+ }
+ }
+
+ // Since we're changing the parameter to the callsite, we need to make sure
+ // that what would be the new parameter dominates the callsite.
+ DominatorTree& DT = getAnalysis<DominatorTree>();
+ if (Instruction* cpyDestInst = dyn_cast<Instruction>(cpyDest))
+ if (!DT.dominates(cpyDestInst, C))
+ return false;
+
+ // In addition to knowing that the call does not access src in some
+ // unexpected manner, for example via a global, which we deduce from
+ // the use analysis, we also need to know that it does not sneakily
+ // access dest. We rely on AA to figure this out for us.
+ AliasAnalysis& AA = getAnalysis<AliasAnalysis>();
+ if (AA.getModRefInfo(C, cpy->getRawDest(), srcSize) !=
+ AliasAnalysis::NoModRef)
+ return false;
+
+ // All the checks have passed, so do the transformation.
+ bool changedArgument = false;
+ for (unsigned i = 0; i < CS.arg_size(); ++i)
+ if (CS.getArgument(i)->stripPointerCasts() == cpySrc) {
+ if (cpySrc->getType() != cpyDest->getType())
+ cpyDest = CastInst::CreatePointerCast(cpyDest, cpySrc->getType(),
+ cpyDest->getName(), C);
+ changedArgument = true;
+ if (CS.getArgument(i)->getType() != cpyDest->getType())
+ CS.setArgument(i, CastInst::CreatePointerCast(cpyDest,
+ CS.getArgument(i)->getType(), cpyDest->getName(), C));
+ else
+ CS.setArgument(i, cpyDest);
+ }
+
+ if (!changedArgument)
+ return false;
+
+ // Drop any cached information about the call, because we may have changed
+ // its dependence information by changing its parameter.
+ MemoryDependenceAnalysis& MD = getAnalysis<MemoryDependenceAnalysis>();
+ MD.removeInstruction(C);
+
+ // Remove the memcpy.
+ MD.removeInstruction(cpy);
+ cpy->eraseFromParent();
+ NumMemCpyInstr++;
+
+ return true;
+}
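+
+// Editor's note: a sketch of the call slot pattern above, on hypothetical IR.
+// Here %tmp is an alloca whose only uses are the call (which initializes it)
+// and the memcpy that moves the result into %dst:
+//
+// %tmp = alloca [24 x i8]
+// %src = bitcast [24 x i8]* %tmp to i8*
+// call void @func(i8* %src)
+// call void @llvm.memcpy.i64(i8* %dst, i8* %src, i64 24, i32 1)
+//
+// Once all of the checks above pass, the call is rewritten to write its
+// result directly into the destination and the memcpy is erased:
+//
+// call void @func(i8* %dst)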
+
+/// processMemCpy - perform simplification of memcpys. If we have memcpy A
+/// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite
+/// B to be a memcpy from X to Z (or potentially a memmove, depending on
+/// circumstances). This allows later passes to remove the first memcpy
+/// altogether.
+bool MemCpyOpt::processMemCpy(MemCpyInst* M) {
+ MemoryDependenceAnalysis& MD = getAnalysis<MemoryDependenceAnalysis>();
+
+ // There are two possible optimizations we can do for memcpy:
+ // a) memcpy-memcpy xform which exposes redundancy for DSE
+ // b) call-memcpy xform for return slot optimization
+ MemDepResult dep = MD.getDependency(M);
+ if (!dep.isClobber())
+ return false;
+ if (!isa<MemCpyInst>(dep.getInst())) {
+ if (CallInst* C = dyn_cast<CallInst>(dep.getInst()))
+ return performCallSlotOptzn(M, C);
+ return false;
+ }
+
+ MemCpyInst* MDep = cast<MemCpyInst>(dep.getInst());
+
+ // We can only transform memcpys where the dest of one is the source of the
+ // other.
+ if (M->getSource() != MDep->getDest())
+ return false;
+
+ // Second, the length of the memcpys must be the same, or the preceding one
+ // must be larger than the following one.
+ ConstantInt* C1 = dyn_cast<ConstantInt>(MDep->getLength());
+ ConstantInt* C2 = dyn_cast<ConstantInt>(M->getLength());
+ if (!C1 || !C2)
+ return false;
+
+ uint64_t DepSize = C1->getValue().getZExtValue();
+ uint64_t CpySize = C2->getValue().getZExtValue();
+
+ if (DepSize < CpySize)
+ return false;
+
+ // Finally, we have to make sure that the dest of the second does not
+ // alias the source of the first.
+ AliasAnalysis& AA = getAnalysis<AliasAnalysis>();
+ if (AA.alias(M->getRawDest(), CpySize, MDep->getRawSource(), DepSize) !=
+ AliasAnalysis::NoAlias)
+ return false;
+ else if (AA.alias(M->getRawDest(), CpySize, M->getRawSource(), CpySize) !=
+ AliasAnalysis::NoAlias)
+ return false;
+ else if (AA.alias(MDep->getRawDest(), DepSize, MDep->getRawSource(), DepSize)
+ != AliasAnalysis::NoAlias)
+ return false;
+
+ // If all checks passed, then we can transform these memcpys.
+ const Type *Tys[1];
+ Tys[0] = M->getLength()->getType();
+ Function* MemCpyFun = Intrinsic::getDeclaration(
+ M->getParent()->getParent()->getParent(),
+ M->getIntrinsicID(), Tys, 1);
+
+ Value *Args[4] = {
+ M->getRawDest(), MDep->getRawSource(), M->getLength(), M->getAlignmentCst()
+ };
+
+ CallInst* C = CallInst::Create(MemCpyFun, Args, Args+4, "", M);
+
+
+ // If C and M don't interfere, then this is a valid transformation. If they
+ // did, this would mean that the two sources overlap, which would be bad.
+ if (MD.getDependency(C) == dep) {
+ MD.removeInstruction(M);
+ M->eraseFromParent();
+ NumMemCpyInstr++;
+ return true;
+ }
+
+ // Otherwise, there was no point in doing this, so we remove the call we
+ // inserted and act like nothing happened.
+ MD.removeInstruction(C);
+ C->eraseFromParent();
+ return false;
+}
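+
+// Editor's note: the memcpy-memcpy transform above, on hypothetical IR. With
+// two constant-length copies where the first one's dest feeds the second:
+//
+// call void @llvm.memcpy.i64(i8* %B, i8* %A, i64 32, i32 8) ; A -> B
+// call void @llvm.memcpy.i64(i8* %C, i8* %B, i64 32, i32 8) ; B -> C
+//
+// the second call is rewritten to copy straight from the original source:
+//
+// call void @llvm.memcpy.i64(i8* %C, i8* %A, i64 32, i32 8) ; A -> C
+//
+// leaving the A -> B copy exposed as a candidate for dead store elimination.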
+
+
+// MemCpyOpt::runOnFunction - This is the main transformation entry point for a
+// function.
+//
+bool MemCpyOpt::runOnFunction(Function& F) {
+
+ bool changed = false;
+ bool shouldContinue = true;
+
+ while (shouldContinue) {
+ shouldContinue = iterateOnFunction(F);
+ changed |= shouldContinue;
+ }
+
+ return changed;
+}
+
+
+// MemCpyOpt::iterateOnFunction - Executes one iteration of MemCpyOpt.
+bool MemCpyOpt::iterateOnFunction(Function &F) {
+ bool changed_function = false;
+
+ // Walk all instructions in the function.
+ for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
+ for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
+ BI != BE;) {
+ // Avoid invalidating the iterator.
+ Instruction* I = BI++;
+
+ if (StoreInst *SI = dyn_cast<StoreInst>(I))
+ changed_function |= processStore(SI, BI);
+ else if (MemCpyInst* M = dyn_cast<MemCpyInst>(I)) {
+ changed_function |= processMemCpy(M);
+ }
+ }
+ }
+
+ return changed_function;
+}
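Editor's note: the pass registered above under the name "memcpyopt" can be exercised on its own through the standard opt driver, e.g. opt -memcpyopt -stats in.bc -o out.bc; with -stats, the NumMemSetInfer and NumMemCpyInstr counters incremented above are printed when the tool exits. The exact command line is an assumption about the usual opt workflow, not part of this import.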
diff --git a/lib/Transforms/Scalar/PredicateSimplifier.cpp b/lib/Transforms/Scalar/PredicateSimplifier.cpp
new file mode 100644
index 0000000..a7e4d6e
--- /dev/null
+++ b/lib/Transforms/Scalar/PredicateSimplifier.cpp
@@ -0,0 +1,2725 @@
+//===-- PredicateSimplifier.cpp - Path Sensitive Simplifier ---------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Path-sensitive optimizer. In a branch where x == y, replace uses of
+// x with y. Permits further optimization, such as the elimination of
+// the unreachable call:
+//
+// void test(int *p, int *q)
+// {
+// if (p != q)
+// return;
+//
+// if (*p != *q)
+// foo(); // unreachable
+// }
+//
+//===----------------------------------------------------------------------===//
+//
+// The InequalityGraph focuses on four properties: equals, not equals,
+// less-than and less-than-or-equal-to. The greater-than forms are also held
+// just to allow walking from a lesser node to a greater one. These properties
+// are stored in a lattice; LE can become LT or EQ, NE can become LT or GT.
+//
+// These relationships define a graph between values of the same type. Each
+// Value is stored in a map table that retrieves the associated Node. This
+// is how EQ relationships are stored; the map contains pointers from equal
+// Values to the same node. The node contains the most canonical Value* form
+// and the list of known relationships with other nodes.
+//
+// If two nodes are known to be inequal, then they will contain pointers to
+// each other with an "NE" relationship. If node getNode(%x) is less than
+// getNode(%y), then the %x node will contain <%y, LT> and %y will contain
+// <%x, GT>. This allows us to tie nodes together into a graph like this:
+//
+// %a < %b < %c < %d
+//
+// with four nodes representing the properties. The InequalityGraph provides
+// querying with "isRelatedBy" and mutators "addEquality" and "addInequality".
+// To find a relationship, we start with one of the nodes and binary search
+// through its list to find where the relationships with the second node start.
+// Then we iterate through those to find the first relationship that dominates
+// our context node.
+//
+// To create these properties, we wait until a branch or switch instruction
+// implies that a particular value is true (or false). The VRPSolver is
+// responsible for analyzing the variable and seeing what new inferences
+// can be made from each property. For example:
+//
+// %P = icmp ne i32* %ptr, null
+// %a = and i1 %P, %Q
+// br i1 %a label %cond_true, label %cond_false
+//
+// For the true branch, the VRPSolver will start with %a EQ true and look at
+// the definition of %a and find that it can infer that %P and %Q are both
+// true. From %P being true, it can infer that %ptr NE null. For the false
+// branch it can't infer anything from the "and" instruction.
+//
+// Besides branches, we can also infer properties from instructions that may
+// have undefined behaviour in certain cases. For example, the divisor of
+// a division may never be zero. After the division instruction, we may assume
+// that the divisor is not equal to zero.
+//
+//===----------------------------------------------------------------------===//
+//
+// The ValueRanges class stores the known integer bounds of a Value. When we
+// encounter i8 %a u< %b, the ValueRanges stores that %a = [0, 254] and
+// %b = [1, 255].
+//
+// It never stores an empty range, since that means the code is unreachable
+// (a fact better recorded in UnreachableBlocks), and it never stores a
+// single-element range, since that is an equality relationship better stored
+// in the InequalityGraph.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "predsimplify"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Instructions.h"
+#include "llvm/Pass.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/Dominators.h"
+#include "llvm/Assembly/Writer.h"
+#include "llvm/Support/CFG.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/ConstantRange.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/InstVisitor.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <algorithm>
+#include <deque>
+#include <stack>
+using namespace llvm;
+
+STATISTIC(NumVarsReplaced, "Number of argument substitutions");
+STATISTIC(NumInstruction , "Number of instructions removed");
+STATISTIC(NumSimple , "Number of simple replacements");
+STATISTIC(NumBlocks , "Number of blocks marked unreachable");
+STATISTIC(NumSnuggle , "Number of comparisons snuggled");
+
+namespace {
+ class DomTreeDFS {
+ public:
+ class Node {
+ friend class DomTreeDFS;
+ public:
+ typedef std::vector<Node *>::iterator iterator;
+ typedef std::vector<Node *>::const_iterator const_iterator;
+
+ unsigned getDFSNumIn() const { return DFSin; }
+ unsigned getDFSNumOut() const { return DFSout; }
+
+ BasicBlock *getBlock() const { return BB; }
+
+ iterator begin() { return Children.begin(); }
+ iterator end() { return Children.end(); }
+
+ const_iterator begin() const { return Children.begin(); }
+ const_iterator end() const { return Children.end(); }
+
+ bool dominates(const Node *N) const {
+ return DFSin <= N->DFSin && DFSout >= N->DFSout;
+ }
+
+ bool DominatedBy(const Node *N) const {
+ return N->dominates(this);
+ }
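+
+ // Editor's note: a small worked example of the interval test above
+ // (illustrative only). renumber() below assigns each node a DFS-in
+ // number on entry and a DFS-out number on exit, so for the tree
+ //
+ // entry {in=1, out=8}
+ // / \
+ // a {2, 5} b {6, 7}
+ // |
+ // c {3, 4}
+ //
+ // "entry" dominates "c" since 1 <= 3 and 8 >= 4, while "a" does not
+ // dominate "b" since 5 >= 7 fails: every node's [in, out] interval
+ // encloses exactly the intervals of the nodes it dominates.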
+
+ /// Sorts by the number of descendants. With this, you can iterate
+ /// through a sorted list and the first matching entry is the most
+ /// specific match for your basic block. The order provided is stable;
+ /// DomTreeDFS::Nodes with the same number of descendants are sorted by
+ /// DFS in number.
+ bool operator<(const Node &N) const {
+ unsigned spread = DFSout - DFSin;
+ unsigned N_spread = N.DFSout - N.DFSin;
+ if (spread == N_spread) return DFSin < N.DFSin;
+ return spread < N_spread;
+ }
+ bool operator>(const Node &N) const { return N < *this; }
+
+ private:
+ unsigned DFSin, DFSout;
+ BasicBlock *BB;
+
+ std::vector<Node *> Children;
+ };
+
+ // XXX: this may be slow. Instead of using "new" for each node, consider
+ // putting them in a vector to keep them contiguous.
+ explicit DomTreeDFS(DominatorTree *DT) {
+ std::stack<std::pair<Node *, DomTreeNode *> > S;
+
+ Entry = new Node;
+ Entry->BB = DT->getRootNode()->getBlock();
+ S.push(std::make_pair(Entry, DT->getRootNode()));
+
+ NodeMap[Entry->BB] = Entry;
+
+ while (!S.empty()) {
+ std::pair<Node *, DomTreeNode *> &Pair = S.top();
+ Node *N = Pair.first;
+ DomTreeNode *DTNode = Pair.second;
+ S.pop();
+
+ for (DomTreeNode::iterator I = DTNode->begin(), E = DTNode->end();
+ I != E; ++I) {
+ Node *NewNode = new Node;
+ NewNode->BB = (*I)->getBlock();
+ N->Children.push_back(NewNode);
+ S.push(std::make_pair(NewNode, *I));
+
+ NodeMap[NewNode->BB] = NewNode;
+ }
+ }
+
+ renumber();
+
+#ifndef NDEBUG
+ DEBUG(dump());
+#endif
+ }
+
+#ifndef NDEBUG
+ virtual
+#endif
+ ~DomTreeDFS() {
+ std::stack<Node *> S;
+
+ S.push(Entry);
+ while (!S.empty()) {
+ Node *N = S.top(); S.pop();
+
+ for (Node::iterator I = N->begin(), E = N->end(); I != E; ++I)
+ S.push(*I);
+
+ delete N;
+ }
+ }
+
+ /// getRootNode - This returns the entry node for the CFG of the function.
+ Node *getRootNode() const { return Entry; }
+
+ /// getNodeForBlock - returns the node for the specified basic block.
+ Node *getNodeForBlock(BasicBlock *BB) const {
+ if (!NodeMap.count(BB)) return 0;
+ return const_cast<DomTreeDFS*>(this)->NodeMap[BB];
+ }
+
+ /// dominates - returns true if the basic block containing I1 dominates the
+ /// basic block containing I2. If the instructions belong to the same basic
+ /// block, the instruction that comes first sequentially in the block is
+ /// considered the dominating one.
+ bool dominates(Instruction *I1, Instruction *I2) {
+ BasicBlock *BB1 = I1->getParent(),
+ *BB2 = I2->getParent();
+ if (BB1 == BB2) {
+ if (isa<TerminatorInst>(I1)) return false;
+ if (isa<TerminatorInst>(I2)) return true;
+ if ( isa<PHINode>(I1) && !isa<PHINode>(I2)) return true;
+ if (!isa<PHINode>(I1) && isa<PHINode>(I2)) return false;
+
+ for (BasicBlock::const_iterator I = BB2->begin(), E = BB2->end();
+ I != E; ++I) {
+ if (&*I == I1) return true;
+ else if (&*I == I2) return false;
+ }
+ assert(!"Instructions not found in parent BasicBlock?");
+ } else {
+ Node *Node1 = getNodeForBlock(BB1),
+ *Node2 = getNodeForBlock(BB2);
+ return Node1 && Node2 && Node1->dominates(Node2);
+ }
+ return false; // Not reached
+ }
+
+ private:
+ /// renumber - calculates the depth first search numberings and applies
+ /// them onto the nodes.
+ void renumber() { + std::stack<std::pair<Node *, Node::iterator> > S; + unsigned n = 0; + + Entry->DFSin = ++n; + S.push(std::make_pair(Entry, Entry->begin())); + + while (!S.empty()) { + std::pair<Node *, Node::iterator> &Pair = S.top(); + Node *N = Pair.first; + Node::iterator &I = Pair.second; + + if (I == N->end()) { + N->DFSout = ++n; + S.pop(); + } else { + Node *Next = *I++; + Next->DFSin = ++n; + S.push(std::make_pair(Next, Next->begin())); + } + } + } + +#ifndef NDEBUG + virtual void dump() const { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) const { + os << "Predicate simplifier DomTreeDFS: \n"; + dump(Entry, 0, os); + os << "\n\n"; + } + + void dump(Node *N, int depth, std::ostream &os) const { + ++depth; + for (int i = 0; i < depth; ++i) { os << " "; } + os << "[" << depth << "] "; + + os << N->getBlock()->getName() << " (" << N->getDFSNumIn() + << ", " << N->getDFSNumOut() << ")\n"; + + for (Node::iterator I = N->begin(), E = N->end(); I != E; ++I) + dump(*I, depth, os); + } +#endif + + Node *Entry; + std::map<BasicBlock *, Node *> NodeMap; + }; + + // SLT SGT ULT UGT EQ + // 0 1 0 1 0 -- GT 10 + // 0 1 0 1 1 -- GE 11 + // 0 1 1 0 0 -- SGTULT 12 + // 0 1 1 0 1 -- SGEULE 13 + // 0 1 1 1 0 -- SGT 14 + // 0 1 1 1 1 -- SGE 15 + // 1 0 0 1 0 -- SLTUGT 18 + // 1 0 0 1 1 -- SLEUGE 19 + // 1 0 1 0 0 -- LT 20 + // 1 0 1 0 1 -- LE 21 + // 1 0 1 1 0 -- SLT 22 + // 1 0 1 1 1 -- SLE 23 + // 1 1 0 1 0 -- UGT 26 + // 1 1 0 1 1 -- UGE 27 + // 1 1 1 0 0 -- ULT 28 + // 1 1 1 0 1 -- ULE 29 + // 1 1 1 1 0 -- NE 30 + enum LatticeBits { + EQ_BIT = 1, UGT_BIT = 2, ULT_BIT = 4, SGT_BIT = 8, SLT_BIT = 16 + }; + enum LatticeVal { + GT = SGT_BIT | UGT_BIT, + GE = GT | EQ_BIT, + LT = SLT_BIT | ULT_BIT, + LE = LT | EQ_BIT, + NE = SLT_BIT | SGT_BIT | ULT_BIT | UGT_BIT, + SGTULT = SGT_BIT | ULT_BIT, + SGEULE = SGTULT | EQ_BIT, + SLTUGT = SLT_BIT | UGT_BIT, + SLEUGE = SLTUGT | EQ_BIT, + ULT = SLT_BIT | SGT_BIT | ULT_BIT, + UGT = SLT_BIT | SGT_BIT | UGT_BIT, + SLT = SLT_BIT | ULT_BIT | UGT_BIT, + SGT = SGT_BIT | ULT_BIT | UGT_BIT, + SLE = SLT | EQ_BIT, + SGE = SGT | EQ_BIT, + ULE = ULT | EQ_BIT, + UGE = UGT | EQ_BIT + }; + +#ifndef NDEBUG + /// validPredicate - determines whether a given value is actually a lattice + /// value. Only used in assertions or debugging. + static bool validPredicate(LatticeVal LV) { + switch (LV) { + case GT: case GE: case LT: case LE: case NE: + case SGTULT: case SGT: case SGEULE: + case SLTUGT: case SLT: case SLEUGE: + case ULT: case UGT: + case SLE: case SGE: case ULE: case UGE: + return true; + default: + return false; + } + } +#endif + + /// reversePredicate - reverse the direction of the inequality + static LatticeVal reversePredicate(LatticeVal LV) { + unsigned reverse = LV ^ (SLT_BIT|SGT_BIT|ULT_BIT|UGT_BIT); //preserve EQ_BIT + + if ((reverse & (SLT_BIT|SGT_BIT)) == 0) + reverse |= (SLT_BIT|SGT_BIT); + + if ((reverse & (ULT_BIT|UGT_BIT)) == 0) + reverse |= (ULT_BIT|UGT_BIT); + + LatticeVal Rev = static_cast<LatticeVal>(reverse); + assert(validPredicate(Rev) && "Failed reversing predicate."); + return Rev; + } + + /// ValueNumbering stores the scope-specific value numbers for a given Value. + class VISIBILITY_HIDDEN ValueNumbering { + + /// VNPair is a tuple of {Value, index number, DomTreeDFS::Node}. It + /// includes the comparison operators necessary to allow you to store it + /// in a sorted vector. 
+ class VISIBILITY_HIDDEN VNPair { + public: + Value *V; + unsigned index; + DomTreeDFS::Node *Subtree; + + VNPair(Value *V, unsigned index, DomTreeDFS::Node *Subtree) + : V(V), index(index), Subtree(Subtree) {} + + bool operator==(const VNPair &RHS) const { + return V == RHS.V && Subtree == RHS.Subtree; + } + + bool operator<(const VNPair &RHS) const { + if (V != RHS.V) return V < RHS.V; + return *Subtree < *RHS.Subtree; + } + + bool operator<(Value *RHS) const { + return V < RHS; + } + + bool operator>(Value *RHS) const { + return V > RHS; + } + + friend bool operator<(Value *RHS, const VNPair &pair) { + return pair.operator>(RHS); + } + }; + + typedef std::vector<VNPair> VNMapType; + VNMapType VNMap; + + /// The canonical choice for value number at index. + std::vector<Value *> Values; + + DomTreeDFS *DTDFS; + + public: +#ifndef NDEBUG + virtual ~ValueNumbering() {} + virtual void dump() { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) { + for (unsigned i = 1; i <= Values.size(); ++i) { + os << i << " = "; + WriteAsOperand(os, Values[i-1]); + os << " {"; + for (unsigned j = 0; j < VNMap.size(); ++j) { + if (VNMap[j].index == i) { + WriteAsOperand(os, VNMap[j].V); + os << " (" << VNMap[j].Subtree->getDFSNumIn() << ") "; + } + } + os << "}\n"; + } + } +#endif + + /// compare - returns true if V1 is a better canonical value than V2. + bool compare(Value *V1, Value *V2) const { + if (isa<Constant>(V1)) + return !isa<Constant>(V2); + else if (isa<Constant>(V2)) + return false; + else if (isa<Argument>(V1)) + return !isa<Argument>(V2); + else if (isa<Argument>(V2)) + return false; + + Instruction *I1 = dyn_cast<Instruction>(V1); + Instruction *I2 = dyn_cast<Instruction>(V2); + + if (!I1 || !I2) + return V1->getNumUses() < V2->getNumUses(); + + return DTDFS->dominates(I1, I2); + } + + ValueNumbering(DomTreeDFS *DTDFS) : DTDFS(DTDFS) {} + + /// valueNumber - finds the value number for V under the Subtree. If + /// there is no value number, returns zero. + unsigned valueNumber(Value *V, DomTreeDFS::Node *Subtree) { + if (!(isa<Constant>(V) || isa<Argument>(V) || isa<Instruction>(V)) + || V->getType() == Type::VoidTy) return 0; + + VNMapType::iterator E = VNMap.end(); + VNPair pair(V, 0, Subtree); + VNMapType::iterator I = std::lower_bound(VNMap.begin(), E, pair); + while (I != E && I->V == V) { + if (I->Subtree->dominates(Subtree)) + return I->index; + ++I; + } + return 0; + } + + /// getOrInsertVN - always returns a value number, creating it if necessary. + unsigned getOrInsertVN(Value *V, DomTreeDFS::Node *Subtree) { + if (unsigned n = valueNumber(V, Subtree)) + return n; + else + return newVN(V); + } + + /// newVN - creates a new value number. Value V must not already have a + /// value number assigned. + unsigned newVN(Value *V) { + assert((isa<Constant>(V) || isa<Argument>(V) || isa<Instruction>(V)) && + "Bad Value for value numbering."); + assert(V->getType() != Type::VoidTy && "Won't value number a void value"); + + Values.push_back(V); + + VNPair pair = VNPair(V, Values.size(), DTDFS->getRootNode()); + VNMapType::iterator I = std::lower_bound(VNMap.begin(), VNMap.end(), pair); + assert((I == VNMap.end() || value(I->index) != V) && + "Attempt to create a duplicate value number."); + VNMap.insert(I, pair); + + return Values.size(); + } + + /// value - returns the Value associated with a value number. 
+ Value *value(unsigned index) const { + assert(index != 0 && "Zero index is reserved for not found."); + assert(index <= Values.size() && "Index out of range."); + return Values[index-1]; + } + + /// canonicalize - return a Value that is equal to V under Subtree. + Value *canonicalize(Value *V, DomTreeDFS::Node *Subtree) { + if (isa<Constant>(V)) return V; + + if (unsigned n = valueNumber(V, Subtree)) + return value(n); + else + return V; + } + + /// addEquality - adds that value V belongs to the set of equivalent + /// values defined by value number n under Subtree. + void addEquality(unsigned n, Value *V, DomTreeDFS::Node *Subtree) { + assert(canonicalize(value(n), Subtree) == value(n) && + "Node's 'canonical' choice isn't best within this subtree."); + + // Suppose that we are given "%x -> node #1 (%y)". The problem is that + // we may already have "%z -> node #2 (%x)" somewhere above us in the + // graph. We need to find those edges and add "%z -> node #1 (%y)" + // to keep the lookups canonical. + + std::vector<Value *> ToRepoint(1, V); + + if (unsigned Conflict = valueNumber(V, Subtree)) { + for (VNMapType::iterator I = VNMap.begin(), E = VNMap.end(); + I != E; ++I) { + if (I->index == Conflict && I->Subtree->dominates(Subtree)) + ToRepoint.push_back(I->V); + } + } + + for (std::vector<Value *>::iterator VI = ToRepoint.begin(), + VE = ToRepoint.end(); VI != VE; ++VI) { + Value *V = *VI; + + VNPair pair(V, n, Subtree); + VNMapType::iterator B = VNMap.begin(), E = VNMap.end(); + VNMapType::iterator I = std::lower_bound(B, E, pair); + if (I != E && I->V == V && I->Subtree == Subtree) + I->index = n; // Update best choice + else + VNMap.insert(I, pair); // New Value + + // XXX: we currently don't have to worry about updating values with + // more specific Subtrees, but we will need to for PHI node support. + +#ifndef NDEBUG + Value *V_n = value(n); + if (isa<Constant>(V) && isa<Constant>(V_n)) { + assert(V == V_n && "Constant equals different constant?"); + } +#endif + } + } + + /// remove - removes all references to value V. + void remove(Value *V) { + VNMapType::iterator B = VNMap.begin(), E = VNMap.end(); + VNPair pair(V, 0, DTDFS->getRootNode()); + VNMapType::iterator J = std::upper_bound(B, E, pair); + VNMapType::iterator I = J; + + while (I != B && (I == E || I->V == V)) --I; + + VNMap.erase(I, J); + } + }; + + /// The InequalityGraph stores the relationships between values. + /// Each Value in the graph is assigned to a Node. Nodes are pointer + /// comparable for equality. The caller is expected to maintain the logical + /// consistency of the system. + /// + /// The InequalityGraph class may invalidate Node*s after any mutator call. + /// @brief The InequalityGraph stores the relationships between values. + class VISIBILITY_HIDDEN InequalityGraph { + ValueNumbering &VN; + DomTreeDFS::Node *TreeRoot; + + InequalityGraph(); // DO NOT IMPLEMENT + InequalityGraph(InequalityGraph &); // DO NOT IMPLEMENT + public: + InequalityGraph(ValueNumbering &VN, DomTreeDFS::Node *TreeRoot) + : VN(VN), TreeRoot(TreeRoot) {} + + class Node; + + /// An Edge is contained inside a Node making one end of the edge implicit + /// and contains a pointer to the other end. The edge contains a lattice + /// value specifying the relationship and an DomTreeDFS::Node specifying + /// the root in the dominator tree to which this edge applies. 
+ class VISIBILITY_HIDDEN Edge { + public: + Edge(unsigned T, LatticeVal V, DomTreeDFS::Node *ST) + : To(T), LV(V), Subtree(ST) {} + + unsigned To; + LatticeVal LV; + DomTreeDFS::Node *Subtree; + + bool operator<(const Edge &edge) const { + if (To != edge.To) return To < edge.To; + return *Subtree < *edge.Subtree; + } + + bool operator<(unsigned to) const { + return To < to; + } + + bool operator>(unsigned to) const { + return To > to; + } + + friend bool operator<(unsigned to, const Edge &edge) { + return edge.operator>(to); + } + }; + + /// A single node in the InequalityGraph. This stores the canonical Value + /// for the node, as well as the relationships with the neighbours. + /// + /// @brief A single node in the InequalityGraph. + class VISIBILITY_HIDDEN Node { + friend class InequalityGraph; + + typedef SmallVector<Edge, 4> RelationsType; + RelationsType Relations; + + // TODO: can this idea improve performance? + //friend class std::vector<Node>; + //Node(Node &N) { RelationsType.swap(N.RelationsType); } + + public: + typedef RelationsType::iterator iterator; + typedef RelationsType::const_iterator const_iterator; + +#ifndef NDEBUG + virtual ~Node() {} + virtual void dump() const { + dump(*cerr.stream()); + } + private: + void dump(std::ostream &os) const { + static const std::string names[32] = + { "000000", "000001", "000002", "000003", "000004", "000005", + "000006", "000007", "000008", "000009", " >", " >=", + " s>u<", "s>=u<=", " s>", " s>=", "000016", "000017", + " s<u>", "s<=u>=", " <", " <=", " s<", " s<=", + "000024", "000025", " u>", " u>=", " u<", " u<=", + " !=", "000031" }; + for (Node::const_iterator NI = begin(), NE = end(); NI != NE; ++NI) { + os << names[NI->LV] << " " << NI->To + << " (" << NI->Subtree->getDFSNumIn() << "), "; + } + } + public: +#endif + + iterator begin() { return Relations.begin(); } + iterator end() { return Relations.end(); } + const_iterator begin() const { return Relations.begin(); } + const_iterator end() const { return Relations.end(); } + + iterator find(unsigned n, DomTreeDFS::Node *Subtree) { + iterator E = end(); + for (iterator I = std::lower_bound(begin(), E, n); + I != E && I->To == n; ++I) { + if (Subtree->DominatedBy(I->Subtree)) + return I; + } + return E; + } + + const_iterator find(unsigned n, DomTreeDFS::Node *Subtree) const { + const_iterator E = end(); + for (const_iterator I = std::lower_bound(begin(), E, n); + I != E && I->To == n; ++I) { + if (Subtree->DominatedBy(I->Subtree)) + return I; + } + return E; + } + + /// update - updates the lattice value for a given node, creating a new + /// entry if one doesn't exist. The new lattice value must not be + /// inconsistent with any previously existing value. + void update(unsigned n, LatticeVal R, DomTreeDFS::Node *Subtree) { + assert(validPredicate(R) && "Invalid predicate."); + + Edge edge(n, R, Subtree); + iterator B = begin(), E = end(); + iterator I = std::lower_bound(B, E, edge); + + iterator J = I; + while (J != E && J->To == n) { + if (Subtree->DominatedBy(J->Subtree)) + break; + ++J; + } + + if (J != E && J->To == n) { + edge.LV = static_cast<LatticeVal>(J->LV & R); + assert(validPredicate(edge.LV) && "Invalid union of lattice values."); + + if (edge.LV == J->LV) + return; // This update adds nothing new. + } + + if (I != B) { + // We also have to tighten any edge beneath our update. 
+ for (iterator K = I - 1; K->To == n; --K) { + if (K->Subtree->DominatedBy(Subtree)) { + LatticeVal LV = static_cast<LatticeVal>(K->LV & edge.LV); + assert(validPredicate(LV) && "Invalid union of lattice values"); + K->LV = LV; + } + if (K == B) break; + } + } + + // Insert new edge at Subtree if it isn't already there. + if (I == E || I->To != n || Subtree != I->Subtree) + Relations.insert(I, edge); + } + }; + + private: + + std::vector<Node> Nodes; + + public: + /// node - returns the node object at a given value number. The pointer + /// returned may be invalidated on the next call to node(). + Node *node(unsigned index) { + assert(VN.value(index)); // This triggers the necessary checks. + if (Nodes.size() < index) Nodes.resize(index); + return &Nodes[index-1]; + } + + /// isRelatedBy - true iff n1 op n2 + bool isRelatedBy(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV) { + if (n1 == n2) return LV & EQ_BIT; + + Node *N1 = node(n1); + Node::iterator I = N1->find(n2, Subtree), E = N1->end(); + if (I != E) return (I->LV & LV) == I->LV; + + return false; + } + + // The add* methods assume that your input is logically valid and may + // assertion-fail or infinitely loop if you attempt a contradiction. + + /// addInequality - Sets n1 op n2. + /// It is also an error to call this on an inequality that is already true. + void addInequality(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV1) { + assert(n1 != n2 && "A node can't be inequal to itself."); + + if (LV1 != NE) + assert(!isRelatedBy(n1, n2, Subtree, reversePredicate(LV1)) && + "Contradictory inequality."); + + // Suppose we're adding %n1 < %n2. Find all the %a < %n1 and + // add %a < %n2 too. This keeps the graph fully connected. + if (LV1 != NE) { + // Break up the relationship into signed and unsigned comparison parts. + // If the signed parts of %a op1 %n1 match that of %n1 op2 %n2, and + // op1 and op2 aren't NE, then add %a op3 %n2. The new relationship + // should have the EQ_BIT iff it's set for both op1 and op2. 
+ + unsigned LV1_s = LV1 & (SLT_BIT|SGT_BIT); + unsigned LV1_u = LV1 & (ULT_BIT|UGT_BIT); + + for (Node::iterator I = node(n1)->begin(), E = node(n1)->end(); I != E; ++I) { + if (I->LV != NE && I->To != n2) { + + DomTreeDFS::Node *Local_Subtree = NULL; + if (Subtree->DominatedBy(I->Subtree)) + Local_Subtree = Subtree; + else if (I->Subtree->DominatedBy(Subtree)) + Local_Subtree = I->Subtree; + + if (Local_Subtree) { + unsigned new_relationship = 0; + LatticeVal ILV = reversePredicate(I->LV); + unsigned ILV_s = ILV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = ILV & (ULT_BIT|UGT_BIT); + + if (LV1_s != (SLT_BIT|SGT_BIT) && ILV_s == LV1_s) + new_relationship |= ILV_s; + if (LV1_u != (ULT_BIT|UGT_BIT) && ILV_u == LV1_u) + new_relationship |= ILV_u; + + if (new_relationship) { + if ((new_relationship & (SLT_BIT|SGT_BIT)) == 0) + new_relationship |= (SLT_BIT|SGT_BIT); + if ((new_relationship & (ULT_BIT|UGT_BIT)) == 0) + new_relationship |= (ULT_BIT|UGT_BIT); + if ((LV1 & EQ_BIT) && (ILV & EQ_BIT)) + new_relationship |= EQ_BIT; + + LatticeVal NewLV = static_cast<LatticeVal>(new_relationship); + + node(I->To)->update(n2, NewLV, Local_Subtree); + node(n2)->update(I->To, reversePredicate(NewLV), Local_Subtree); + } + } + } + } + + for (Node::iterator I = node(n2)->begin(), E = node(n2)->end(); I != E; ++I) { + if (I->LV != NE && I->To != n1) { + DomTreeDFS::Node *Local_Subtree = NULL; + if (Subtree->DominatedBy(I->Subtree)) + Local_Subtree = Subtree; + else if (I->Subtree->DominatedBy(Subtree)) + Local_Subtree = I->Subtree; + + if (Local_Subtree) { + unsigned new_relationship = 0; + unsigned ILV_s = I->LV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = I->LV & (ULT_BIT|UGT_BIT); + + if (LV1_s != (SLT_BIT|SGT_BIT) && ILV_s == LV1_s) + new_relationship |= ILV_s; + + if (LV1_u != (ULT_BIT|UGT_BIT) && ILV_u == LV1_u) + new_relationship |= ILV_u; + + if (new_relationship) { + if ((new_relationship & (SLT_BIT|SGT_BIT)) == 0) + new_relationship |= (SLT_BIT|SGT_BIT); + if ((new_relationship & (ULT_BIT|UGT_BIT)) == 0) + new_relationship |= (ULT_BIT|UGT_BIT); + if ((LV1 & EQ_BIT) && (I->LV & EQ_BIT)) + new_relationship |= EQ_BIT; + + LatticeVal NewLV = static_cast<LatticeVal>(new_relationship); + + node(n1)->update(I->To, NewLV, Local_Subtree); + node(I->To)->update(n1, reversePredicate(NewLV), Local_Subtree); + } + } + } + } + } + + node(n1)->update(n2, LV1, Subtree); + node(n2)->update(n1, reversePredicate(LV1), Subtree); + } + + /// remove - removes a node from the graph by removing all references to + /// and from it. + void remove(unsigned n) { + Node *N = node(n); + for (Node::iterator NI = N->begin(), NE = N->end(); NI != NE; ++NI) { + Node::iterator Iter = node(NI->To)->find(n, TreeRoot); + do { + node(NI->To)->Relations.erase(Iter); + Iter = node(NI->To)->find(n, TreeRoot); + } while (Iter != node(NI->To)->end()); + } + N->Relations.clear(); + } + +#ifndef NDEBUG + virtual ~InequalityGraph() {} + virtual void dump() { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) { + for (unsigned i = 1; i <= Nodes.size(); ++i) { + os << i << " = {"; + node(i)->dump(os); + os << "}\n"; + } + } +#endif + }; + + class VRPSolver; + + /// ValueRanges tracks the known integer ranges and anti-ranges of the nodes + /// in the InequalityGraph. 
+ class VISIBILITY_HIDDEN ValueRanges { + ValueNumbering &VN; + TargetData *TD; + + class VISIBILITY_HIDDEN ScopedRange { + typedef std::vector<std::pair<DomTreeDFS::Node *, ConstantRange> > + RangeListType; + RangeListType RangeList; + + static bool swo(const std::pair<DomTreeDFS::Node *, ConstantRange> &LHS, + const std::pair<DomTreeDFS::Node *, ConstantRange> &RHS) { + return *LHS.first < *RHS.first; + } + + public: +#ifndef NDEBUG + virtual ~ScopedRange() {} + virtual void dump() const { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) const { + os << "{"; + for (const_iterator I = begin(), E = end(); I != E; ++I) { + os << &I->second << " (" << I->first->getDFSNumIn() << "), "; + } + os << "}"; + } +#endif + + typedef RangeListType::iterator iterator; + typedef RangeListType::const_iterator const_iterator; + + iterator begin() { return RangeList.begin(); } + iterator end() { return RangeList.end(); } + const_iterator begin() const { return RangeList.begin(); } + const_iterator end() const { return RangeList.end(); } + + iterator find(DomTreeDFS::Node *Subtree) { + static ConstantRange empty(1, false); + iterator E = end(); + iterator I = std::lower_bound(begin(), E, + std::make_pair(Subtree, empty), swo); + + while (I != E && !I->first->dominates(Subtree)) ++I; + return I; + } + + const_iterator find(DomTreeDFS::Node *Subtree) const { + static const ConstantRange empty(1, false); + const_iterator E = end(); + const_iterator I = std::lower_bound(begin(), E, + std::make_pair(Subtree, empty), swo); + + while (I != E && !I->first->dominates(Subtree)) ++I; + return I; + } + + void update(const ConstantRange &CR, DomTreeDFS::Node *Subtree) { + assert(!CR.isEmptySet() && "Empty ConstantRange."); + assert(!CR.isSingleElement() && "Refusing to store single element."); + + static ConstantRange empty(1, false); + iterator E = end(); + iterator I = + std::lower_bound(begin(), E, std::make_pair(Subtree, empty), swo); + + if (I != end() && I->first == Subtree) { + ConstantRange CR2 = I->second.maximalIntersectWith(CR); + assert(!CR2.isEmptySet() && !CR2.isSingleElement() && + "Invalid union of ranges."); + I->second = CR2; + } else + RangeList.insert(I, std::make_pair(Subtree, CR)); + } + }; + + std::vector<ScopedRange> Ranges; + + void update(unsigned n, const ConstantRange &CR, DomTreeDFS::Node *Subtree){ + if (CR.isFullSet()) return; + if (Ranges.size() < n) Ranges.resize(n); + Ranges[n-1].update(CR, Subtree); + } + + /// create - Creates a ConstantRange that matches the given LatticeVal + /// relation with a given integer. + ConstantRange create(LatticeVal LV, const ConstantRange &CR) { + assert(!CR.isEmptySet() && "Can't deal with empty set."); + + if (LV == NE) + return makeConstantRange(ICmpInst::ICMP_NE, CR); + + unsigned LV_s = LV & (SGT_BIT|SLT_BIT); + unsigned LV_u = LV & (UGT_BIT|ULT_BIT); + bool hasEQ = LV & EQ_BIT; + + ConstantRange Range(CR.getBitWidth()); + + if (LV_s == SGT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SGT, CR)); + } else if (LV_s == SLT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_SLT, CR)); + } + + if (LV_u == UGT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_UGT, CR)); + } else if (LV_u == ULT_BIT) { + Range = Range.maximalIntersectWith(makeConstantRange( + hasEQ ? 
ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT, CR)); + } + + return Range; + } + + /// makeConstantRange - Creates a ConstantRange representing the set of all + /// value that match the ICmpInst::Predicate with any of the values in CR. + ConstantRange makeConstantRange(ICmpInst::Predicate ICmpOpcode, + const ConstantRange &CR) { + uint32_t W = CR.getBitWidth(); + switch (ICmpOpcode) { + default: assert(!"Invalid ICmp opcode to makeConstantRange()"); + case ICmpInst::ICMP_EQ: + return ConstantRange(CR.getLower(), CR.getUpper()); + case ICmpInst::ICMP_NE: + if (CR.isSingleElement()) + return ConstantRange(CR.getUpper(), CR.getLower()); + return ConstantRange(W); + case ICmpInst::ICMP_ULT: + return ConstantRange(APInt::getMinValue(W), CR.getUnsignedMax()); + case ICmpInst::ICMP_SLT: + return ConstantRange(APInt::getSignedMinValue(W), CR.getSignedMax()); + case ICmpInst::ICMP_ULE: { + APInt UMax(CR.getUnsignedMax()); + if (UMax.isMaxValue()) + return ConstantRange(W); + return ConstantRange(APInt::getMinValue(W), UMax + 1); + } + case ICmpInst::ICMP_SLE: { + APInt SMax(CR.getSignedMax()); + if (SMax.isMaxSignedValue() || (SMax+1).isMaxSignedValue()) + return ConstantRange(W); + return ConstantRange(APInt::getSignedMinValue(W), SMax + 1); + } + case ICmpInst::ICMP_UGT: + return ConstantRange(CR.getUnsignedMin() + 1, APInt::getNullValue(W)); + case ICmpInst::ICMP_SGT: + return ConstantRange(CR.getSignedMin() + 1, + APInt::getSignedMinValue(W)); + case ICmpInst::ICMP_UGE: { + APInt UMin(CR.getUnsignedMin()); + if (UMin.isMinValue()) + return ConstantRange(W); + return ConstantRange(UMin, APInt::getNullValue(W)); + } + case ICmpInst::ICMP_SGE: { + APInt SMin(CR.getSignedMin()); + if (SMin.isMinSignedValue()) + return ConstantRange(W); + return ConstantRange(SMin, APInt::getSignedMinValue(W)); + } + } + } + +#ifndef NDEBUG + bool isCanonical(Value *V, DomTreeDFS::Node *Subtree) { + return V == VN.canonicalize(V, Subtree); + } +#endif + + public: + + ValueRanges(ValueNumbering &VN, TargetData *TD) : VN(VN), TD(TD) {} + +#ifndef NDEBUG + virtual ~ValueRanges() {} + + virtual void dump() const { + dump(*cerr.stream()); + } + + void dump(std::ostream &os) const { + for (unsigned i = 0, e = Ranges.size(); i != e; ++i) { + os << (i+1) << " = "; + Ranges[i].dump(os); + os << "\n"; + } + } +#endif + + /// range - looks up the ConstantRange associated with a value number. + ConstantRange range(unsigned n, DomTreeDFS::Node *Subtree) { + assert(VN.value(n)); // performs range checks + + if (n <= Ranges.size()) { + ScopedRange::iterator I = Ranges[n-1].find(Subtree); + if (I != Ranges[n-1].end()) return I->second; + } + + Value *V = VN.value(n); + ConstantRange CR = range(V); + return CR; + } + + /// range - determine a range from a Value without performing any lookups. + ConstantRange range(Value *V) const { + if (ConstantInt *C = dyn_cast<ConstantInt>(V)) + return ConstantRange(C->getValue()); + else if (isa<ConstantPointerNull>(V)) + return ConstantRange(APInt::getNullValue(typeToWidth(V->getType()))); + else + return ConstantRange(typeToWidth(V->getType())); + } + + // typeToWidth - returns the number of bits necessary to store a value of + // this type, or zero if unknown. 
+ uint32_t typeToWidth(const Type *Ty) const { + if (TD) + return TD->getTypeSizeInBits(Ty); + else + return Ty->getPrimitiveSizeInBits(); + } + + static bool isRelatedBy(const ConstantRange &CR1, const ConstantRange &CR2, + LatticeVal LV) { + switch (LV) { + default: assert(!"Impossible lattice value!"); + case NE: + return CR1.maximalIntersectWith(CR2).isEmptySet(); + case ULT: + return CR1.getUnsignedMax().ult(CR2.getUnsignedMin()); + case ULE: + return CR1.getUnsignedMax().ule(CR2.getUnsignedMin()); + case UGT: + return CR1.getUnsignedMin().ugt(CR2.getUnsignedMax()); + case UGE: + return CR1.getUnsignedMin().uge(CR2.getUnsignedMax()); + case SLT: + return CR1.getSignedMax().slt(CR2.getSignedMin()); + case SLE: + return CR1.getSignedMax().sle(CR2.getSignedMin()); + case SGT: + return CR1.getSignedMin().sgt(CR2.getSignedMax()); + case SGE: + return CR1.getSignedMin().sge(CR2.getSignedMax()); + case LT: + return CR1.getUnsignedMax().ult(CR2.getUnsignedMin()) && + CR1.getSignedMax().slt(CR2.getUnsignedMin()); + case LE: + return CR1.getUnsignedMax().ule(CR2.getUnsignedMin()) && + CR1.getSignedMax().sle(CR2.getUnsignedMin()); + case GT: + return CR1.getUnsignedMin().ugt(CR2.getUnsignedMax()) && + CR1.getSignedMin().sgt(CR2.getSignedMax()); + case GE: + return CR1.getUnsignedMin().uge(CR2.getUnsignedMax()) && + CR1.getSignedMin().sge(CR2.getSignedMax()); + case SLTUGT: + return CR1.getSignedMax().slt(CR2.getSignedMin()) && + CR1.getUnsignedMin().ugt(CR2.getUnsignedMax()); + case SLEUGE: + return CR1.getSignedMax().sle(CR2.getSignedMin()) && + CR1.getUnsignedMin().uge(CR2.getUnsignedMax()); + case SGTULT: + return CR1.getSignedMin().sgt(CR2.getSignedMax()) && + CR1.getUnsignedMax().ult(CR2.getUnsignedMin()); + case SGEULE: + return CR1.getSignedMin().sge(CR2.getSignedMax()) && + CR1.getUnsignedMax().ule(CR2.getUnsignedMin()); + } + } + + bool isRelatedBy(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV) { + ConstantRange CR1 = range(n1, Subtree); + ConstantRange CR2 = range(n2, Subtree); + + // True iff all values in CR1 are LV to all values in CR2. + return isRelatedBy(CR1, CR2, LV); + } + + void addToWorklist(Value *V, Constant *C, ICmpInst::Predicate Pred, + VRPSolver *VRP); + void markBlock(VRPSolver *VRP); + + void mergeInto(Value **I, unsigned n, unsigned New, + DomTreeDFS::Node *Subtree, VRPSolver *VRP) { + ConstantRange CR_New = range(New, Subtree); + ConstantRange Merged = CR_New; + + for (; n != 0; ++I, --n) { + unsigned i = VN.valueNumber(*I, Subtree); + ConstantRange CR_Kill = i ? range(i, Subtree) : range(*I); + if (CR_Kill.isFullSet()) continue; + Merged = Merged.maximalIntersectWith(CR_Kill); + } + + if (Merged.isFullSet() || Merged == CR_New) return; + + applyRange(New, Merged, Subtree, VRP); + } + + void applyRange(unsigned n, const ConstantRange &CR, + DomTreeDFS::Node *Subtree, VRPSolver *VRP) { + ConstantRange Merged = CR.maximalIntersectWith(range(n, Subtree)); + if (Merged.isEmptySet()) { + markBlock(VRP); + return; + } + + if (const APInt *I = Merged.getSingleElement()) { + Value *V = VN.value(n); // XXX: redesign worklist. 
+ const Type *Ty = V->getType(); + if (Ty->isInteger()) { + addToWorklist(V, ConstantInt::get(*I), ICmpInst::ICMP_EQ, VRP); + return; + } else if (const PointerType *PTy = dyn_cast<PointerType>(Ty)) { + assert(*I == 0 && "Pointer is null but not zero?"); + addToWorklist(V, ConstantPointerNull::get(PTy), + ICmpInst::ICMP_EQ, VRP); + return; + } + } + + update(n, Merged, Subtree); + } + + void addNotEquals(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + VRPSolver *VRP) { + ConstantRange CR1 = range(n1, Subtree); + ConstantRange CR2 = range(n2, Subtree); + + uint32_t W = CR1.getBitWidth(); + + if (const APInt *I = CR1.getSingleElement()) { + if (CR2.isFullSet()) { + ConstantRange NewCR2(CR1.getUpper(), CR1.getLower()); + applyRange(n2, NewCR2, Subtree, VRP); + } else if (*I == CR2.getLower()) { + APInt NewLower(CR2.getLower() + 1), + NewUpper(CR2.getUpper()); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR2(NewLower, NewUpper); + applyRange(n2, NewCR2, Subtree, VRP); + } else if (*I == CR2.getUpper() - 1) { + APInt NewLower(CR2.getLower()), + NewUpper(CR2.getUpper() - 1); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR2(NewLower, NewUpper); + applyRange(n2, NewCR2, Subtree, VRP); + } + } + + if (const APInt *I = CR2.getSingleElement()) { + if (CR1.isFullSet()) { + ConstantRange NewCR1(CR2.getUpper(), CR2.getLower()); + applyRange(n1, NewCR1, Subtree, VRP); + } else if (*I == CR1.getLower()) { + APInt NewLower(CR1.getLower() + 1), + NewUpper(CR1.getUpper()); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR1(NewLower, NewUpper); + applyRange(n1, NewCR1, Subtree, VRP); + } else if (*I == CR1.getUpper() - 1) { + APInt NewLower(CR1.getLower()), + NewUpper(CR1.getUpper() - 1); + if (NewLower == NewUpper) + NewLower = NewUpper = APInt::getMinValue(W); + + ConstantRange NewCR1(NewLower, NewUpper); + applyRange(n1, NewCR1, Subtree, VRP); + } + } + } + + void addInequality(unsigned n1, unsigned n2, DomTreeDFS::Node *Subtree, + LatticeVal LV, VRPSolver *VRP) { + assert(!isRelatedBy(n1, n2, Subtree, LV) && "Asked to do useless work."); + + if (LV == NE) { + addNotEquals(n1, n2, Subtree, VRP); + return; + } + + ConstantRange CR1 = range(n1, Subtree); + ConstantRange CR2 = range(n2, Subtree); + + if (!CR1.isSingleElement()) { + ConstantRange NewCR1 = CR1.maximalIntersectWith(create(LV, CR2)); + if (NewCR1 != CR1) + applyRange(n1, NewCR1, Subtree, VRP); + } + + if (!CR2.isSingleElement()) { + ConstantRange NewCR2 = CR2.maximalIntersectWith( + create(reversePredicate(LV), CR1)); + if (NewCR2 != CR2) + applyRange(n2, NewCR2, Subtree, VRP); + } + } + }; + + /// UnreachableBlocks keeps tracks of blocks that are for one reason or + /// another discovered to be unreachable. This is used to cull the graph when + /// analyzing instructions, and to mark blocks with the "unreachable" + /// terminator instruction after the function has executed. 
+ class VISIBILITY_HIDDEN UnreachableBlocks { + private: + std::vector<BasicBlock *> DeadBlocks; + + public: + /// mark - mark a block as dead + void mark(BasicBlock *BB) { + std::vector<BasicBlock *>::iterator E = DeadBlocks.end(); + std::vector<BasicBlock *>::iterator I = + std::lower_bound(DeadBlocks.begin(), E, BB); + + if (I == E || *I != BB) DeadBlocks.insert(I, BB); + } + + /// isDead - returns whether a block is known to be dead already + bool isDead(BasicBlock *BB) { + std::vector<BasicBlock *>::iterator E = DeadBlocks.end(); + std::vector<BasicBlock *>::iterator I = + std::lower_bound(DeadBlocks.begin(), E, BB); + + return I != E && *I == BB; + } + + /// kill - replace the dead blocks' terminator with an UnreachableInst. + bool kill() { + bool modified = false; + for (std::vector<BasicBlock *>::iterator I = DeadBlocks.begin(), + E = DeadBlocks.end(); I != E; ++I) { + BasicBlock *BB = *I; + + DOUT << "unreachable block: " << BB->getName() << "\n"; + + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); + SI != SE; ++SI) { + BasicBlock *Succ = *SI; + Succ->removePredecessor(BB); + } + + TerminatorInst *TI = BB->getTerminator(); + TI->replaceAllUsesWith(UndefValue::get(TI->getType())); + TI->eraseFromParent(); + new UnreachableInst(BB); + ++NumBlocks; + modified = true; + } + DeadBlocks.clear(); + return modified; + } + }; + + /// VRPSolver keeps track of how changes to one variable affect other + /// variables, and forwards changes along to the InequalityGraph. It + /// also maintains the correct choice for "canonical" in the IG. + /// @brief VRPSolver calculates inferences from a new relationship. + class VISIBILITY_HIDDEN VRPSolver { + private: + friend class ValueRanges; + + struct Operation { + Value *LHS, *RHS; + ICmpInst::Predicate Op; + + BasicBlock *ContextBB; // XXX use a DomTreeDFS::Node instead + Instruction *ContextInst; + }; + std::deque<Operation> WorkList; + + ValueNumbering &VN; + InequalityGraph &IG; + UnreachableBlocks &UB; + ValueRanges &VR; + DomTreeDFS *DTDFS; + DomTreeDFS::Node *Top; + BasicBlock *TopBB; + Instruction *TopInst; + bool &modified; + + typedef InequalityGraph::Node Node; + + // below - true if the Instruction is dominated by the current context + // block or instruction + bool below(Instruction *I) { + BasicBlock *BB = I->getParent(); + if (TopInst && TopInst->getParent() == BB) { + if (isa<TerminatorInst>(TopInst)) return false; + if (isa<TerminatorInst>(I)) return true; + if ( isa<PHINode>(TopInst) && !isa<PHINode>(I)) return true; + if (!isa<PHINode>(TopInst) && isa<PHINode>(I)) return false; + + for (BasicBlock::const_iterator Iter = BB->begin(), E = BB->end(); + Iter != E; ++Iter) { + if (&*Iter == TopInst) return true; + else if (&*Iter == I) return false; + } + assert(!"Instructions not found in parent BasicBlock?"); + } else { + DomTreeDFS::Node *Node = DTDFS->getNodeForBlock(BB); + if (!Node) return false; + return Top->dominates(Node); + } + return false; // Not reached + } + + // aboveOrBelow - true if the Instruction either dominates or is dominated + // by the current context block or instruction + bool aboveOrBelow(Instruction *I) { + BasicBlock *BB = I->getParent(); + DomTreeDFS::Node *Node = DTDFS->getNodeForBlock(BB); + if (!Node) return false; + + return Top == Node || Top->dominates(Node) || Node->dominates(Top); + } + + bool makeEqual(Value *V1, Value *V2) { + DOUT << "makeEqual(" << *V1 << ", " << *V2 << ")\n"; + DOUT << "context is "; + if (TopInst) DOUT << "I: " << *TopInst << "\n"; + else DOUT << "BB: " << 
TopBB->getName() + << "(" << Top->getDFSNumIn() << ")\n"; + + assert(V1->getType() == V2->getType() && + "Can't make two values with different types equal."); + + if (V1 == V2) return true; + + if (isa<Constant>(V1) && isa<Constant>(V2)) + return false; + + unsigned n1 = VN.valueNumber(V1, Top), n2 = VN.valueNumber(V2, Top); + + if (n1 && n2) { + if (n1 == n2) return true; + if (IG.isRelatedBy(n1, n2, Top, NE)) return false; + } + + if (n1) assert(V1 == VN.value(n1) && "Value isn't canonical."); + if (n2) assert(V2 == VN.value(n2) && "Value isn't canonical."); + + assert(!VN.compare(V2, V1) && "Please order parameters to makeEqual."); + + assert(!isa<Constant>(V2) && "Tried to remove a constant."); + + SetVector<unsigned> Remove; + if (n2) Remove.insert(n2); + + if (n1 && n2) { + // Suppose we're being told that %x == %y, and %x <= %z and %y >= %z. + // We can't just merge %x and %y because the relationship with %z would + // be EQ and that's invalid. What we're doing is looking for any nodes + // %z such that %x <= %z and %y >= %z, and vice versa. + + Node::iterator end = IG.node(n2)->end(); + + // Find the intersection between N1 and N2 which is dominated by + // Top. If we find %x where N1 <= %x <= N2 (or >=) then add %x to + // Remove. + for (Node::iterator I = IG.node(n1)->begin(), E = IG.node(n1)->end(); + I != E; ++I) { + if (!(I->LV & EQ_BIT) || !Top->DominatedBy(I->Subtree)) continue; + + unsigned ILV_s = I->LV & (SLT_BIT|SGT_BIT); + unsigned ILV_u = I->LV & (ULT_BIT|UGT_BIT); + Node::iterator NI = IG.node(n2)->find(I->To, Top); + if (NI != end) { + LatticeVal NILV = reversePredicate(NI->LV); + unsigned NILV_s = NILV & (SLT_BIT|SGT_BIT); + unsigned NILV_u = NILV & (ULT_BIT|UGT_BIT); + + if ((ILV_s != (SLT_BIT|SGT_BIT) && ILV_s == NILV_s) || + (ILV_u != (ULT_BIT|UGT_BIT) && ILV_u == NILV_u)) + Remove.insert(I->To); + } + } + + // See if one of the nodes about to be removed is actually a better + // canonical choice than n1. + unsigned orig_n1 = n1; + SetVector<unsigned>::iterator DontRemove = Remove.end(); + for (SetVector<unsigned>::iterator I = Remove.begin()+1 /* skip n2 */, + E = Remove.end(); I != E; ++I) { + unsigned n = *I; + Value *V = VN.value(n); + if (VN.compare(V, V1)) { + V1 = V; + n1 = n; + DontRemove = I; + } + } + if (DontRemove != Remove.end()) { + unsigned n = *DontRemove; + Remove.remove(n); + Remove.insert(orig_n1); + } + } + + // We'd like to allow makeEqual on two values to perform a simple + // substitution without creating nodes in the IG whenever possible. + // + // The first iteration through this loop operates on V2 before going + // through the Remove list and operating on those too. If all of the + // iterations performed simple replacements then we exit early. + bool mergeIGNode = false; + unsigned i = 0; + for (Value *R = V2; i == 0 || i < Remove.size(); ++i) { + if (i) R = VN.value(Remove[i]); // skip n2. + + // Try to replace the whole instruction. If we can, we're done. + Instruction *I2 = dyn_cast<Instruction>(R); + if (I2 && below(I2)) { + std::vector<Instruction *> ToNotify; + for (Value::use_iterator UI = R->use_begin(), UE = R->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) + ToNotify.push_back(I); + } + + DOUT << "Simply removing " << *I2 + << ", replacing with " << *V1 << "\n"; + I2->replaceAllUsesWith(V1); + // leave it dead; it'll get erased later. 
+ ++NumInstruction; + modified = true; + + for (std::vector<Instruction *>::iterator II = ToNotify.begin(), + IE = ToNotify.end(); II != IE; ++II) { + opsToDef(*II); + } + + continue; + } + + // Otherwise, replace all dominated uses. + for (Value::use_iterator UI = R->use_begin(), UE = R->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) { + if (below(I)) { + TheUse.set(V1); + modified = true; + ++NumVarsReplaced; + opsToDef(I); + } + } + } + + // If that killed the instruction, stop here. + if (I2 && isInstructionTriviallyDead(I2)) { + DOUT << "Killed all uses of " << *I2 + << ", replacing with " << *V1 << "\n"; + continue; + } + + // If we make it to here, then we will need to create a node for N1. + // Otherwise, we can skip out early! + mergeIGNode = true; + } + + if (!isa<Constant>(V1)) { + if (Remove.empty()) { + VR.mergeInto(&V2, 1, VN.getOrInsertVN(V1, Top), Top, this); + } else { + std::vector<Value*> RemoveVals; + RemoveVals.reserve(Remove.size()); + + for (SetVector<unsigned>::iterator I = Remove.begin(), + E = Remove.end(); I != E; ++I) { + Value *V = VN.value(*I); + if (!V->use_empty()) + RemoveVals.push_back(V); + } + VR.mergeInto(&RemoveVals[0], RemoveVals.size(), + VN.getOrInsertVN(V1, Top), Top, this); + } + } + + if (mergeIGNode) { + // Create N1. + if (!n1) n1 = VN.getOrInsertVN(V1, Top); + IG.node(n1); // Ensure that IG.Nodes won't get resized + + // Migrate relationships from removed nodes to N1. + for (SetVector<unsigned>::iterator I = Remove.begin(), E = Remove.end(); + I != E; ++I) { + unsigned n = *I; + for (Node::iterator NI = IG.node(n)->begin(), NE = IG.node(n)->end(); + NI != NE; ++NI) { + if (NI->Subtree->DominatedBy(Top)) { + if (NI->To == n1) { + assert((NI->LV & EQ_BIT) && "Node inequal to itself."); + continue; + } + if (Remove.count(NI->To)) + continue; + + IG.node(NI->To)->update(n1, reversePredicate(NI->LV), Top); + IG.node(n1)->update(NI->To, NI->LV, Top); + } + } + } + + // Point V2 (and all items in Remove) to N1. + if (!n2) + VN.addEquality(n1, V2, Top); + else { + for (SetVector<unsigned>::iterator I = Remove.begin(), + E = Remove.end(); I != E; ++I) { + VN.addEquality(n1, VN.value(*I), Top); + } + } + + // If !Remove.empty() then V2 = Remove[0]->getValue(). + // Even when Remove is empty, we still want to process V2. + i = 0; + for (Value *R = V2; i == 0 || i < Remove.size(); ++i) { + if (i) R = VN.value(Remove[i]); // skip n2. + + if (Instruction *I2 = dyn_cast<Instruction>(R)) { + if (aboveOrBelow(I2)) + defToOps(I2); + } + for (Value::use_iterator UI = V2->use_begin(), UE = V2->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) { + if (aboveOrBelow(I)) + opsToDef(I); + } + } + } + } + + // re-opsToDef all dominated users of V1. + if (Instruction *I = dyn_cast<Instruction>(V1)) { + for (Value::use_iterator UI = I->use_begin(), UE = I->use_end(); + UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + Value *V = TheUse.getUser(); + if (!V->use_empty()) { + if (Instruction *Inst = dyn_cast<Instruction>(V)) { + if (aboveOrBelow(Inst)) + opsToDef(Inst); + } + } + } + } + + return true; + } + + /// cmpInstToLattice - converts an CmpInst::Predicate to lattice value + /// Requires that the lattice value be valid; does not accept ICMP_EQ. 
+ static LatticeVal cmpInstToLattice(ICmpInst::Predicate Pred) { + switch (Pred) { + case ICmpInst::ICMP_EQ: + assert(!"No matching lattice value."); + return static_cast<LatticeVal>(EQ_BIT); + default: + assert(!"Invalid 'icmp' predicate."); + case ICmpInst::ICMP_NE: + return NE; + case ICmpInst::ICMP_UGT: + return UGT; + case ICmpInst::ICMP_UGE: + return UGE; + case ICmpInst::ICMP_ULT: + return ULT; + case ICmpInst::ICMP_ULE: + return ULE; + case ICmpInst::ICMP_SGT: + return SGT; + case ICmpInst::ICMP_SGE: + return SGE; + case ICmpInst::ICMP_SLT: + return SLT; + case ICmpInst::ICMP_SLE: + return SLE; + } + } + + public: + VRPSolver(ValueNumbering &VN, InequalityGraph &IG, UnreachableBlocks &UB, + ValueRanges &VR, DomTreeDFS *DTDFS, bool &modified, + BasicBlock *TopBB) + : VN(VN), + IG(IG), + UB(UB), + VR(VR), + DTDFS(DTDFS), + Top(DTDFS->getNodeForBlock(TopBB)), + TopBB(TopBB), + TopInst(NULL), + modified(modified) + { + assert(Top && "VRPSolver created for unreachable basic block."); + } + + VRPSolver(ValueNumbering &VN, InequalityGraph &IG, UnreachableBlocks &UB, + ValueRanges &VR, DomTreeDFS *DTDFS, bool &modified, + Instruction *TopInst) + : VN(VN), + IG(IG), + UB(UB), + VR(VR), + DTDFS(DTDFS), + Top(DTDFS->getNodeForBlock(TopInst->getParent())), + TopBB(TopInst->getParent()), + TopInst(TopInst), + modified(modified) + { + assert(Top && "VRPSolver created for unreachable basic block."); + assert(Top->getBlock() == TopInst->getParent() && "Context mismatch."); + } + + bool isRelatedBy(Value *V1, Value *V2, ICmpInst::Predicate Pred) const { + if (Constant *C1 = dyn_cast<Constant>(V1)) + if (Constant *C2 = dyn_cast<Constant>(V2)) + return ConstantExpr::getCompare(Pred, C1, C2) == + ConstantInt::getTrue(); + + unsigned n1 = VN.valueNumber(V1, Top); + unsigned n2 = VN.valueNumber(V2, Top); + + if (n1 && n2) { + if (n1 == n2) return Pred == ICmpInst::ICMP_EQ || + Pred == ICmpInst::ICMP_ULE || + Pred == ICmpInst::ICMP_UGE || + Pred == ICmpInst::ICMP_SLE || + Pred == ICmpInst::ICMP_SGE; + if (Pred == ICmpInst::ICMP_EQ) return false; + if (IG.isRelatedBy(n1, n2, Top, cmpInstToLattice(Pred))) return true; + if (VR.isRelatedBy(n1, n2, Top, cmpInstToLattice(Pred))) return true; + } + + if ((n1 && !n2 && isa<Constant>(V2)) || + (n2 && !n1 && isa<Constant>(V1))) { + ConstantRange CR1 = n1 ? VR.range(n1, Top) : VR.range(V1); + ConstantRange CR2 = n2 ? VR.range(n2, Top) : VR.range(V2); + + if (Pred == ICmpInst::ICMP_EQ) + return CR1.isSingleElement() && + CR1.getSingleElement() == CR2.getSingleElement(); + + return VR.isRelatedBy(CR1, CR2, cmpInstToLattice(Pred)); + } + if (Pred == ICmpInst::ICMP_EQ) return V1 == V2; + return false; + } + + /// add - adds a new property to the work queue + void add(Value *V1, Value *V2, ICmpInst::Predicate Pred, + Instruction *I = NULL) { + DOUT << "adding " << *V1 << " " << Pred << " " << *V2; + if (I) DOUT << " context: " << *I; + else DOUT << " default context (" << Top->getDFSNumIn() << ")"; + DOUT << "\n"; + + assert(V1->getType() == V2->getType() && + "Can't relate two values with different types."); + + WorkList.push_back(Operation()); + Operation &O = WorkList.back(); + O.LHS = V1, O.RHS = V2, O.Op = Pred, O.ContextInst = I; + O.ContextBB = I ? I->getParent() : TopBB; + } + + /// defToOps - Given an instruction definition that we've learned something + /// new about, find any new relationships between its operands. + void defToOps(Instruction *I) { + Instruction *NewContext = below(I) ? 
I : TopInst; + Value *Canonical = VN.canonicalize(I, Top); + + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { + const Type *Ty = BO->getType(); + assert(!Ty->isFPOrFPVector() && "Float in work queue!"); + + Value *Op0 = VN.canonicalize(BO->getOperand(0), Top); + Value *Op1 = VN.canonicalize(BO->getOperand(1), Top); + + // TODO: "and i32 -1, %x" EQ %y then %x EQ %y. + + switch (BO->getOpcode()) { + case Instruction::And: { + // "and i32 %a, %b" EQ -1 then %a EQ -1 and %b EQ -1 + ConstantInt *CI = ConstantInt::getAllOnesValue(Ty); + if (Canonical == CI) { + add(CI, Op0, ICmpInst::ICMP_EQ, NewContext); + add(CI, Op1, ICmpInst::ICMP_EQ, NewContext); + } + } break; + case Instruction::Or: { + // "or i32 %a, %b" EQ 0 then %a EQ 0 and %b EQ 0 + Constant *Zero = Constant::getNullValue(Ty); + if (Canonical == Zero) { + add(Zero, Op0, ICmpInst::ICMP_EQ, NewContext); + add(Zero, Op1, ICmpInst::ICMP_EQ, NewContext); + } + } break; + case Instruction::Xor: { + // "xor i32 %c, %a" EQ %b then %a EQ %c ^ %b + // "xor i32 %c, %a" EQ %c then %a EQ 0 + // "xor i32 %c, %a" NE %c then %a NE 0 + // Repeat the above, with order of operands reversed. + Value *LHS = Op0; + Value *RHS = Op1; + if (!isa<Constant>(LHS)) std::swap(LHS, RHS); + + if (ConstantInt *CI = dyn_cast<ConstantInt>(Canonical)) { + if (ConstantInt *Arg = dyn_cast<ConstantInt>(LHS)) { + add(RHS, ConstantInt::get(CI->getValue() ^ Arg->getValue()), + ICmpInst::ICMP_EQ, NewContext); + } + } + if (Canonical == LHS) { + if (isa<ConstantInt>(Canonical)) + add(RHS, Constant::getNullValue(Ty), ICmpInst::ICMP_EQ, + NewContext); + } else if (isRelatedBy(LHS, Canonical, ICmpInst::ICMP_NE)) { + add(RHS, Constant::getNullValue(Ty), ICmpInst::ICMP_NE, + NewContext); + } + } break; + default: + break; + } + } else if (ICmpInst *IC = dyn_cast<ICmpInst>(I)) { + // "icmp ult i32 %a, %y" EQ true then %a u< y + // etc. + + if (Canonical == ConstantInt::getTrue()) { + add(IC->getOperand(0), IC->getOperand(1), IC->getPredicate(), + NewContext); + } else if (Canonical == ConstantInt::getFalse()) { + add(IC->getOperand(0), IC->getOperand(1), + ICmpInst::getInversePredicate(IC->getPredicate()), NewContext); + } + } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { + if (I->getType()->isFPOrFPVector()) return; + + // Given: "%a = select i1 %x, i32 %b, i32 %c" + // %a EQ %b and %b NE %c then %x EQ true + // %a EQ %c and %b NE %c then %x EQ false + + Value *True = SI->getTrueValue(); + Value *False = SI->getFalseValue(); + if (isRelatedBy(True, False, ICmpInst::ICMP_NE)) { + if (Canonical == VN.canonicalize(True, Top) || + isRelatedBy(Canonical, False, ICmpInst::ICMP_NE)) + add(SI->getCondition(), ConstantInt::getTrue(), + ICmpInst::ICMP_EQ, NewContext); + else if (Canonical == VN.canonicalize(False, Top) || + isRelatedBy(Canonical, True, ICmpInst::ICMP_NE)) + add(SI->getCondition(), ConstantInt::getFalse(), + ICmpInst::ICMP_EQ, NewContext); + } + } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) { + for (GetElementPtrInst::op_iterator OI = GEPI->idx_begin(), + OE = GEPI->idx_end(); OI != OE; ++OI) { + ConstantInt *Op = dyn_cast<ConstantInt>(VN.canonicalize(*OI, Top)); + if (!Op || !Op->isZero()) return; + } + // TODO: The GEPI indices are all zero. Copy from definition to operand, + // jumping the type plane as needed. 
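+ // All of the indices are zero, so the GEP computes the same address as
+ // its base pointer; e.g. "%g = getelementptr i32* %p, i32 0" with
+ // %g NE null lets us conclude %p NE null, which is what the code below
+ // records.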
+ if (isRelatedBy(GEPI, Constant::getNullValue(GEPI->getType()), + ICmpInst::ICMP_NE)) { + Value *Ptr = GEPI->getPointerOperand(); + add(Ptr, Constant::getNullValue(Ptr->getType()), ICmpInst::ICMP_NE, + NewContext); + } + } else if (CastInst *CI = dyn_cast<CastInst>(I)) { + const Type *SrcTy = CI->getSrcTy(); + + unsigned ci = VN.getOrInsertVN(CI, Top); + uint32_t W = VR.typeToWidth(SrcTy); + if (!W) return; + ConstantRange CR = VR.range(ci, Top); + + if (CR.isFullSet()) return; + + switch (CI->getOpcode()) { + default: break; + case Instruction::ZExt: + case Instruction::SExt: + VR.applyRange(VN.getOrInsertVN(CI->getOperand(0), Top), + CR.truncate(W), Top, this); + break; + case Instruction::BitCast: + VR.applyRange(VN.getOrInsertVN(CI->getOperand(0), Top), + CR, Top, this); + break; + } + } + } + + /// opsToDef - A new relationship was discovered involving one of this + /// instruction's operands. Find any new relationship involving the + /// definition, or another operand. + void opsToDef(Instruction *I) { + Instruction *NewContext = below(I) ? I : TopInst; + + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { + Value *Op0 = VN.canonicalize(BO->getOperand(0), Top); + Value *Op1 = VN.canonicalize(BO->getOperand(1), Top); + + if (ConstantInt *CI0 = dyn_cast<ConstantInt>(Op0)) + if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Op1)) { + add(BO, ConstantExpr::get(BO->getOpcode(), CI0, CI1), + ICmpInst::ICMP_EQ, NewContext); + return; + } + + // "%y = and i1 true, %x" then %x EQ %y + // "%y = or i1 false, %x" then %x EQ %y + // "%x = add i32 %y, 0" then %x EQ %y + // "%x = mul i32 %y, 0" then %x EQ 0 + + Instruction::BinaryOps Opcode = BO->getOpcode(); + const Type *Ty = BO->getType(); + assert(!Ty->isFPOrFPVector() && "Float in work queue!"); + + Constant *Zero = Constant::getNullValue(Ty); + Constant *One = ConstantInt::get(Ty, 1); + ConstantInt *AllOnes = ConstantInt::getAllOnesValue(Ty); + + switch (Opcode) { + default: break; + case Instruction::LShr: + case Instruction::AShr: + case Instruction::Shl: + if (Op1 == Zero) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; + } + break; + case Instruction::Sub: + if (Op1 == Zero) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; + } + if (ConstantInt *CI0 = dyn_cast<ConstantInt>(Op0)) { + unsigned n_ci0 = VN.getOrInsertVN(Op1, Top); + ConstantRange CR = VR.range(n_ci0, Top); + if (!CR.isFullSet()) { + CR.subtract(CI0->getValue()); + unsigned n_bo = VN.getOrInsertVN(BO, Top); + VR.applyRange(n_bo, CR, Top, this); + return; + } + } + if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Op1)) { + unsigned n_ci1 = VN.getOrInsertVN(Op0, Top); + ConstantRange CR = VR.range(n_ci1, Top); + if (!CR.isFullSet()) { + CR.subtract(CI1->getValue()); + unsigned n_bo = VN.getOrInsertVN(BO, Top); + VR.applyRange(n_bo, CR, Top, this); + return; + } + } + break; + case Instruction::Or: + if (Op0 == AllOnes || Op1 == AllOnes) { + add(BO, AllOnes, ICmpInst::ICMP_EQ, NewContext); + return; + } + if (Op0 == Zero) { + add(BO, Op1, ICmpInst::ICMP_EQ, NewContext); + return; + } else if (Op1 == Zero) { + add(BO, Op0, ICmpInst::ICMP_EQ, NewContext); + return; + } + break; + case Instruction::Add: + if (ConstantInt *CI0 = dyn_cast<ConstantInt>(Op0)) { + unsigned n_ci0 = VN.getOrInsertVN(Op1, Top); + ConstantRange CR = VR.range(n_ci0, Top); + if (!CR.isFullSet()) { + CR.subtract(-CI0->getValue()); + unsigned n_bo = VN.getOrInsertVN(BO, Top); + VR.applyRange(n_bo, CR, Top, this); + return; + } + } + if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Op1)) { + 
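+ // A constant operand shifts the other operand's known range: the result
+ // of "add %y, C" lies in range(%y) + C. The code shifts the range by C
+ // via CR.subtract(-C); e.g. %y in [2, 5) gives "add i32 %y, 7" the
+ // range [9, 12).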
unsigned n_ci1 = VN.getOrInsertVN(Op0, Top);
+ ConstantRange CR = VR.range(n_ci1, Top);
+ if (!CR.isFullSet()) {
+ CR.subtract(-CI1->getValue());
+ unsigned n_bo = VN.getOrInsertVN(BO, Top);
+ VR.applyRange(n_bo, CR, Top, this);
+ return;
+ }
+ }
+ // fall-through
+ case Instruction::Xor:
+ if (Op0 == Zero) {
+ add(BO, Op1, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ } else if (Op1 == Zero) {
+ add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ }
+ break;
+ case Instruction::And:
+ if (Op0 == AllOnes) {
+ add(BO, Op1, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ } else if (Op1 == AllOnes) {
+ add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ }
+ if (Op0 == Zero || Op1 == Zero) {
+ add(BO, Zero, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ }
+ break;
+ case Instruction::Mul:
+ if (Op0 == Zero || Op1 == Zero) {
+ add(BO, Zero, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ }
+ if (Op0 == One) {
+ add(BO, Op1, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ } else if (Op1 == One) {
+ add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
+ return;
+ }
+ break;
+ }
+
+ // "%x = add i32 %y, %z" and %x EQ %y then %z EQ 0
+ // "%x = add i32 %y, %z" and %x EQ %z then %y EQ 0
+ // "%x = shl i32 %y, %z" and %x EQ %y and %y NE 0 then %z EQ 0
+ // "%x = udiv i32 %y, %z" and %x EQ %y and %y NE 0 then %z EQ 1
+
+ Value *Known = Op0, *Unknown = Op1,
+ *TheBO = VN.canonicalize(BO, Top);
+ if (Known != TheBO) std::swap(Known, Unknown);
+ if (Known == TheBO) {
+ switch (Opcode) {
+ default: break;
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::Shl:
+ if (!isRelatedBy(Known, Zero, ICmpInst::ICMP_NE)) break;
+ // otherwise, fall-through.
+ case Instruction::Sub:
+ if (Unknown == Op0) break;
+ // otherwise, fall-through.
+ case Instruction::Xor:
+ case Instruction::Add:
+ add(Unknown, Zero, ICmpInst::ICMP_EQ, NewContext);
+ break;
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ if (Unknown == Op0) break; // %x EQ %z (the divisor) implies nothing.
+ if (isRelatedBy(Known, Zero, ICmpInst::ICMP_NE))
+ add(Unknown, One, ICmpInst::ICMP_EQ, NewContext);
+ break;
+ }
+ }
+
+ // TODO: "%a = add i32 %b, 1" and %b > %z then %a >= %z.
+
+ } else if (ICmpInst *IC = dyn_cast<ICmpInst>(I)) {
+ // "%a = icmp ult i32 %b, %c" and %b u< %c then %a EQ true
+ // "%a = icmp ult i32 %b, %c" and %b u>= %c then %a EQ false
+ // etc.
+ + Value *Op0 = VN.canonicalize(IC->getOperand(0), Top); + Value *Op1 = VN.canonicalize(IC->getOperand(1), Top); + + ICmpInst::Predicate Pred = IC->getPredicate(); + if (isRelatedBy(Op0, Op1, Pred)) + add(IC, ConstantInt::getTrue(), ICmpInst::ICMP_EQ, NewContext); + else if (isRelatedBy(Op0, Op1, ICmpInst::getInversePredicate(Pred))) + add(IC, ConstantInt::getFalse(), ICmpInst::ICMP_EQ, NewContext); + + } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { + if (I->getType()->isFPOrFPVector()) return; + + // Given: "%a = select i1 %x, i32 %b, i32 %c" + // %x EQ true then %a EQ %b + // %x EQ false then %a EQ %c + // %b EQ %c then %a EQ %b + + Value *Canonical = VN.canonicalize(SI->getCondition(), Top); + if (Canonical == ConstantInt::getTrue()) { + add(SI, SI->getTrueValue(), ICmpInst::ICMP_EQ, NewContext); + } else if (Canonical == ConstantInt::getFalse()) { + add(SI, SI->getFalseValue(), ICmpInst::ICMP_EQ, NewContext); + } else if (VN.canonicalize(SI->getTrueValue(), Top) == + VN.canonicalize(SI->getFalseValue(), Top)) { + add(SI, SI->getTrueValue(), ICmpInst::ICMP_EQ, NewContext); + } + } else if (CastInst *CI = dyn_cast<CastInst>(I)) { + const Type *DestTy = CI->getDestTy(); + if (DestTy->isFPOrFPVector()) return; + + Value *Op = VN.canonicalize(CI->getOperand(0), Top); + Instruction::CastOps Opcode = CI->getOpcode(); + + if (Constant *C = dyn_cast<Constant>(Op)) { + add(CI, ConstantExpr::getCast(Opcode, C, DestTy), + ICmpInst::ICMP_EQ, NewContext); + } + + uint32_t W = VR.typeToWidth(DestTy); + unsigned ci = VN.getOrInsertVN(CI, Top); + ConstantRange CR = VR.range(VN.getOrInsertVN(Op, Top), Top); + + if (!CR.isFullSet()) { + switch (Opcode) { + default: break; + case Instruction::ZExt: + VR.applyRange(ci, CR.zeroExtend(W), Top, this); + break; + case Instruction::SExt: + VR.applyRange(ci, CR.signExtend(W), Top, this); + break; + case Instruction::Trunc: { + ConstantRange Result = CR.truncate(W); + if (!Result.isFullSet()) + VR.applyRange(ci, Result, Top, this); + } break; + case Instruction::BitCast: + VR.applyRange(ci, CR, Top, this); + break; + // TODO: other casts? + } + } + } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) { + for (GetElementPtrInst::op_iterator OI = GEPI->idx_begin(), + OE = GEPI->idx_end(); OI != OE; ++OI) { + ConstantInt *Op = dyn_cast<ConstantInt>(VN.canonicalize(*OI, Top)); + if (!Op || !Op->isZero()) return; + } + // TODO: The GEPI indices are all zero. Copy from operand to definition, + // jumping the type plane as needed. 
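+ // As above, an all-zero-index GEP computes exactly its base pointer's
+ // address, so a nonnull base makes the GEP nonnull as well.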
+ Value *Ptr = GEPI->getPointerOperand(); + if (isRelatedBy(Ptr, Constant::getNullValue(Ptr->getType()), + ICmpInst::ICMP_NE)) { + add(GEPI, Constant::getNullValue(GEPI->getType()), ICmpInst::ICMP_NE, + NewContext); + } + } + } + + /// solve - process the work queue + void solve() { + //DOUT << "WorkList entry, size: " << WorkList.size() << "\n"; + while (!WorkList.empty()) { + //DOUT << "WorkList size: " << WorkList.size() << "\n"; + + Operation &O = WorkList.front(); + TopInst = O.ContextInst; + TopBB = O.ContextBB; + Top = DTDFS->getNodeForBlock(TopBB); // XXX move this into Context + + O.LHS = VN.canonicalize(O.LHS, Top); + O.RHS = VN.canonicalize(O.RHS, Top); + + assert(O.LHS == VN.canonicalize(O.LHS, Top) && "Canonicalize isn't."); + assert(O.RHS == VN.canonicalize(O.RHS, Top) && "Canonicalize isn't."); + + DOUT << "solving " << *O.LHS << " " << O.Op << " " << *O.RHS; + if (O.ContextInst) DOUT << " context inst: " << *O.ContextInst; + else DOUT << " context block: " << O.ContextBB->getName(); + DOUT << "\n"; + + DEBUG(VN.dump()); + DEBUG(IG.dump()); + DEBUG(VR.dump()); + + // If they're both Constant, skip it. Check for contradiction and mark + // the BB as unreachable if so. + if (Constant *CI_L = dyn_cast<Constant>(O.LHS)) { + if (Constant *CI_R = dyn_cast<Constant>(O.RHS)) { + if (ConstantExpr::getCompare(O.Op, CI_L, CI_R) == + ConstantInt::getFalse()) + UB.mark(TopBB); + + WorkList.pop_front(); + continue; + } + } + + if (VN.compare(O.LHS, O.RHS)) { + std::swap(O.LHS, O.RHS); + O.Op = ICmpInst::getSwappedPredicate(O.Op); + } + + if (O.Op == ICmpInst::ICMP_EQ) { + if (!makeEqual(O.RHS, O.LHS)) + UB.mark(TopBB); + } else { + LatticeVal LV = cmpInstToLattice(O.Op); + + if ((LV & EQ_BIT) && + isRelatedBy(O.LHS, O.RHS, ICmpInst::getSwappedPredicate(O.Op))) { + if (!makeEqual(O.RHS, O.LHS)) + UB.mark(TopBB); + } else { + if (isRelatedBy(O.LHS, O.RHS, ICmpInst::getInversePredicate(O.Op))){ + UB.mark(TopBB); + WorkList.pop_front(); + continue; + } + + unsigned n1 = VN.getOrInsertVN(O.LHS, Top); + unsigned n2 = VN.getOrInsertVN(O.RHS, Top); + + if (n1 == n2) { + if (O.Op != ICmpInst::ICMP_UGE && O.Op != ICmpInst::ICMP_ULE && + O.Op != ICmpInst::ICMP_SGE && O.Op != ICmpInst::ICMP_SLE) + UB.mark(TopBB); + + WorkList.pop_front(); + continue; + } + + if (VR.isRelatedBy(n1, n2, Top, LV) || + IG.isRelatedBy(n1, n2, Top, LV)) { + WorkList.pop_front(); + continue; + } + + VR.addInequality(n1, n2, Top, LV, this); + if ((!isa<ConstantInt>(O.RHS) && !isa<ConstantInt>(O.LHS)) || + LV == NE) + IG.addInequality(n1, n2, Top, LV); + + if (Instruction *I1 = dyn_cast<Instruction>(O.LHS)) { + if (aboveOrBelow(I1)) + defToOps(I1); + } + if (isa<Instruction>(O.LHS) || isa<Argument>(O.LHS)) { + for (Value::use_iterator UI = O.LHS->use_begin(), + UE = O.LHS->use_end(); UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) { + if (aboveOrBelow(I)) + opsToDef(I); + } + } + } + if (Instruction *I2 = dyn_cast<Instruction>(O.RHS)) { + if (aboveOrBelow(I2)) + defToOps(I2); + } + if (isa<Instruction>(O.RHS) || isa<Argument>(O.RHS)) { + for (Value::use_iterator UI = O.RHS->use_begin(), + UE = O.RHS->use_end(); UI != UE;) { + Use &TheUse = UI.getUse(); + ++UI; + if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) { + if (aboveOrBelow(I)) + opsToDef(I); + } + } + } + } + } + WorkList.pop_front(); + } + } + }; + + void ValueRanges::addToWorklist(Value *V, Constant *C, + ICmpInst::Predicate Pred, VRPSolver *VRP) { + VRP->add(V, C, Pred, VRP->TopInst); + 
} + + void ValueRanges::markBlock(VRPSolver *VRP) { + VRP->UB.mark(VRP->TopBB); + } + + /// PredicateSimplifier - This class is a simplifier that replaces + /// one equivalent variable with another. It also tracks what + /// can't be equal and will solve setcc instructions when possible. + /// @brief Root of the predicate simplifier optimization. + class VISIBILITY_HIDDEN PredicateSimplifier : public FunctionPass { + DomTreeDFS *DTDFS; + bool modified; + ValueNumbering *VN; + InequalityGraph *IG; + UnreachableBlocks UB; + ValueRanges *VR; + + std::vector<DomTreeDFS::Node *> WorkList; + + public: + static char ID; // Pass identification, replacement for typeid + PredicateSimplifier() : FunctionPass(&ID) {} + + bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequiredID(BreakCriticalEdgesID); + AU.addRequired<DominatorTree>(); + AU.addRequired<TargetData>(); + AU.addPreserved<TargetData>(); + } + + private: + /// Forwards - Adds new properties to VRPSolver and uses them to + /// simplify instructions. Because new properties sometimes apply to + /// a transition from one BasicBlock to another, this will use the + /// PredicateSimplifier::proceedToSuccessor(s) interface to enter the + /// basic block. + /// @brief Performs abstract execution of the program. + class VISIBILITY_HIDDEN Forwards : public InstVisitor<Forwards> { + friend class InstVisitor<Forwards>; + PredicateSimplifier *PS; + DomTreeDFS::Node *DTNode; + + public: + ValueNumbering &VN; + InequalityGraph &IG; + UnreachableBlocks &UB; + ValueRanges &VR; + + Forwards(PredicateSimplifier *PS, DomTreeDFS::Node *DTNode) + : PS(PS), DTNode(DTNode), VN(*PS->VN), IG(*PS->IG), UB(PS->UB), + VR(*PS->VR) {} + + void visitTerminatorInst(TerminatorInst &TI); + void visitBranchInst(BranchInst &BI); + void visitSwitchInst(SwitchInst &SI); + + void visitAllocaInst(AllocaInst &AI); + void visitLoadInst(LoadInst &LI); + void visitStoreInst(StoreInst &SI); + + void visitSExtInst(SExtInst &SI); + void visitZExtInst(ZExtInst &ZI); + + void visitBinaryOperator(BinaryOperator &BO); + void visitICmpInst(ICmpInst &IC); + }; + + // Used by terminator instructions to proceed from the current basic + // block to the next. Verifies that "current" dominates "next", + // then calls visitBasicBlock. + void proceedToSuccessors(DomTreeDFS::Node *Current) { + for (DomTreeDFS::Node::iterator I = Current->begin(), + E = Current->end(); I != E; ++I) { + WorkList.push_back(*I); + } + } + + void proceedToSuccessor(DomTreeDFS::Node *Next) { + WorkList.push_back(Next); + } + + // Visits each instruction in the basic block. + void visitBasicBlock(DomTreeDFS::Node *Node) { + BasicBlock *BB = Node->getBlock(); + DOUT << "Entering Basic Block: " << BB->getName() + << " (" << Node->getDFSNumIn() << ")\n"; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + visitInstruction(I++, Node); + } + } + + // Tries to simplify each Instruction and add new properties. + void visitInstruction(Instruction *I, DomTreeDFS::Node *DT) { + DOUT << "Considering instruction " << *I << "\n"; + DEBUG(VN->dump()); + DEBUG(IG->dump()); + DEBUG(VR->dump()); + + // Sometimes instructions are killed in earlier analysis. + if (isInstructionTriviallyDead(I)) { + ++NumSimple; + modified = true; + if (unsigned n = VN->valueNumber(I, DTDFS->getRootNode())) + if (VN->value(n) == I) IG->remove(n); + VN->remove(I); + I->eraseFromParent(); + return; + } + +#ifndef NDEBUG + // Try to replace the whole instruction. 
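+ // Debug-build-only verification: the asserts below check that I and its
+ // operands were already canonicalized by the time we get here; the
+ // replacement code after each assert is unreachable whenever the asserts
+ // hold.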
+ Value *V = VN->canonicalize(I, DT); + assert(V == I && "Late instruction canonicalization."); + if (V != I) { + modified = true; + ++NumInstruction; + DOUT << "Removing " << *I << ", replacing with " << *V << "\n"; + if (unsigned n = VN->valueNumber(I, DTDFS->getRootNode())) + if (VN->value(n) == I) IG->remove(n); + VN->remove(I); + I->replaceAllUsesWith(V); + I->eraseFromParent(); + return; + } + + // Try to substitute operands. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + Value *Oper = I->getOperand(i); + Value *V = VN->canonicalize(Oper, DT); + assert(V == Oper && "Late operand canonicalization."); + if (V != Oper) { + modified = true; + ++NumVarsReplaced; + DOUT << "Resolving " << *I; + I->setOperand(i, V); + DOUT << " into " << *I; + } + } +#endif + + std::string name = I->getParent()->getName(); + DOUT << "push (%" << name << ")\n"; + Forwards visit(this, DT); + visit.visit(*I); + DOUT << "pop (%" << name << ")\n"; + } + }; + + bool PredicateSimplifier::runOnFunction(Function &F) { + DominatorTree *DT = &getAnalysis<DominatorTree>(); + DTDFS = new DomTreeDFS(DT); + TargetData *TD = &getAnalysis<TargetData>(); + + DOUT << "Entering Function: " << F.getName() << "\n"; + + modified = false; + DomTreeDFS::Node *Root = DTDFS->getRootNode(); + VN = new ValueNumbering(DTDFS); + IG = new InequalityGraph(*VN, Root); + VR = new ValueRanges(*VN, TD); + WorkList.push_back(Root); + + do { + DomTreeDFS::Node *DTNode = WorkList.back(); + WorkList.pop_back(); + if (!UB.isDead(DTNode->getBlock())) visitBasicBlock(DTNode); + } while (!WorkList.empty()); + + delete DTDFS; + delete VR; + delete IG; + delete VN; + + modified |= UB.kill(); + + return modified; + } + + void PredicateSimplifier::Forwards::visitTerminatorInst(TerminatorInst &TI) { + PS->proceedToSuccessors(DTNode); + } + + void PredicateSimplifier::Forwards::visitBranchInst(BranchInst &BI) { + if (BI.isUnconditional()) { + PS->proceedToSuccessors(DTNode); + return; + } + + Value *Condition = BI.getCondition(); + BasicBlock *TrueDest = BI.getSuccessor(0); + BasicBlock *FalseDest = BI.getSuccessor(1); + + if (isa<Constant>(Condition) || TrueDest == FalseDest) { + PS->proceedToSuccessors(DTNode); + return; + } + + for (DomTreeDFS::Node::iterator I = DTNode->begin(), E = DTNode->end(); + I != E; ++I) { + BasicBlock *Dest = (*I)->getBlock(); + DOUT << "Branch thinking about %" << Dest->getName() + << "(" << PS->DTDFS->getNodeForBlock(Dest)->getDFSNumIn() << ")\n"; + + if (Dest == TrueDest) { + DOUT << "(" << DTNode->getBlock()->getName() << ") true set:\n"; + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, Dest); + VRP.add(ConstantInt::getTrue(), Condition, ICmpInst::ICMP_EQ); + VRP.solve(); + DEBUG(VN.dump()); + DEBUG(IG.dump()); + DEBUG(VR.dump()); + } else if (Dest == FalseDest) { + DOUT << "(" << DTNode->getBlock()->getName() << ") false set:\n"; + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, Dest); + VRP.add(ConstantInt::getFalse(), Condition, ICmpInst::ICMP_EQ); + VRP.solve(); + DEBUG(VN.dump()); + DEBUG(IG.dump()); + DEBUG(VR.dump()); + } + + PS->proceedToSuccessor(*I); + } + } + + void PredicateSimplifier::Forwards::visitSwitchInst(SwitchInst &SI) { + Value *Condition = SI.getCondition(); + + // Set the EQProperty in each of the cases BBs, and the NEProperties + // in the default BB. 
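+ // For example, given "switch i32 %x" with case "i32 1, label %a": inside
+ // %a we may assume %x EQ 1, and inside the default destination %x NE 1
+ // for every case value that does not also branch there.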
+ + for (DomTreeDFS::Node::iterator I = DTNode->begin(), E = DTNode->end(); + I != E; ++I) { + BasicBlock *BB = (*I)->getBlock(); + DOUT << "Switch thinking about BB %" << BB->getName() + << "(" << PS->DTDFS->getNodeForBlock(BB)->getDFSNumIn() << ")\n"; + + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, BB); + if (BB == SI.getDefaultDest()) { + for (unsigned i = 1, e = SI.getNumCases(); i < e; ++i) + if (SI.getSuccessor(i) != BB) + VRP.add(Condition, SI.getCaseValue(i), ICmpInst::ICMP_NE); + VRP.solve(); + } else if (ConstantInt *CI = SI.findCaseDest(BB)) { + VRP.add(Condition, CI, ICmpInst::ICMP_EQ); + VRP.solve(); + } + PS->proceedToSuccessor(*I); + } + } + + void PredicateSimplifier::Forwards::visitAllocaInst(AllocaInst &AI) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &AI); + VRP.add(Constant::getNullValue(AI.getType()), &AI, ICmpInst::ICMP_NE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitLoadInst(LoadInst &LI) { + Value *Ptr = LI.getPointerOperand(); + // avoid "load i8* null" -> null NE null. + if (isa<Constant>(Ptr)) return; + + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &LI); + VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitStoreInst(StoreInst &SI) { + Value *Ptr = SI.getPointerOperand(); + if (isa<Constant>(Ptr)) return; + + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &SI); + VRP.add(Constant::getNullValue(Ptr->getType()), Ptr, ICmpInst::ICMP_NE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitSExtInst(SExtInst &SI) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &SI); + uint32_t SrcBitWidth = cast<IntegerType>(SI.getSrcTy())->getBitWidth(); + uint32_t DstBitWidth = cast<IntegerType>(SI.getDestTy())->getBitWidth(); + APInt Min(APInt::getHighBitsSet(DstBitWidth, DstBitWidth-SrcBitWidth+1)); + APInt Max(APInt::getLowBitsSet(DstBitWidth, SrcBitWidth-1)); + VRP.add(ConstantInt::get(Min), &SI, ICmpInst::ICMP_SLE); + VRP.add(ConstantInt::get(Max), &SI, ICmpInst::ICMP_SGE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitZExtInst(ZExtInst &ZI) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &ZI); + uint32_t SrcBitWidth = cast<IntegerType>(ZI.getSrcTy())->getBitWidth(); + uint32_t DstBitWidth = cast<IntegerType>(ZI.getDestTy())->getBitWidth(); + APInt Max(APInt::getLowBitsSet(DstBitWidth, SrcBitWidth)); + VRP.add(ConstantInt::get(Max), &ZI, ICmpInst::ICMP_UGE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitBinaryOperator(BinaryOperator &BO) { + Instruction::BinaryOps ops = BO.getOpcode(); + + switch (ops) { + default: break; + case Instruction::URem: + case Instruction::SRem: + case Instruction::UDiv: + case Instruction::SDiv: { + Value *Divisor = BO.getOperand(1); + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(Constant::getNullValue(Divisor->getType()), Divisor, + ICmpInst::ICMP_NE); + VRP.solve(); + break; + } + } + + switch (ops) { + default: break; + case Instruction::Shl: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_UGE); + VRP.solve(); + } break; + case Instruction::AShr: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_SLE); + VRP.solve(); + } break; + case Instruction::LShr: + case Instruction::UDiv: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), 
ICmpInst::ICMP_ULE); + VRP.solve(); + } break; + case Instruction::URem: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_ULE); + VRP.solve(); + } break; + case Instruction::And: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_ULE); + VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_ULE); + VRP.solve(); + } break; + case Instruction::Or: { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &BO); + VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_UGE); + VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_UGE); + VRP.solve(); + } break; + } + } + + void PredicateSimplifier::Forwards::visitICmpInst(ICmpInst &IC) { + // If possible, squeeze the ICmp predicate into something simpler. + // Eg., if x = [0, 4) and we're being asked icmp uge %x, 3 then change + // the predicate to eq. + + // XXX: once we do full PHI handling, modifying the instruction in the + // Forwards visitor will cause missed optimizations. + + ICmpInst::Predicate Pred = IC.getPredicate(); + + switch (Pred) { + default: break; + case ICmpInst::ICMP_ULE: Pred = ICmpInst::ICMP_ULT; break; + case ICmpInst::ICMP_UGE: Pred = ICmpInst::ICMP_UGT; break; + case ICmpInst::ICMP_SLE: Pred = ICmpInst::ICMP_SLT; break; + case ICmpInst::ICMP_SGE: Pred = ICmpInst::ICMP_SGT; break; + } + if (Pred != IC.getPredicate()) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &IC); + if (VRP.isRelatedBy(IC.getOperand(1), IC.getOperand(0), + ICmpInst::ICMP_NE)) { + ++NumSnuggle; + PS->modified = true; + IC.setPredicate(Pred); + } + } + + Pred = IC.getPredicate(); + + if (ConstantInt *Op1 = dyn_cast<ConstantInt>(IC.getOperand(1))) { + ConstantInt *NextVal = 0; + switch (Pred) { + default: break; + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_ULT: + if (Op1->getValue() != 0) + NextVal = ConstantInt::get(Op1->getValue()-1); + break; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + if (!Op1->getValue().isAllOnesValue()) + NextVal = ConstantInt::get(Op1->getValue()+1); + break; + } + + if (NextVal) { + VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &IC); + if (VRP.isRelatedBy(IC.getOperand(0), NextVal, + ICmpInst::getInversePredicate(Pred))) { + ICmpInst *NewIC = new ICmpInst(ICmpInst::ICMP_EQ, IC.getOperand(0), + NextVal, "", &IC); + NewIC->takeName(&IC); + IC.replaceAllUsesWith(NewIC); + + // XXX: prove this isn't necessary + if (unsigned n = VN.valueNumber(&IC, PS->DTDFS->getRootNode())) + if (VN.value(n) == &IC) IG.remove(n); + VN.remove(&IC); + + IC.eraseFromParent(); + ++NumSnuggle; + PS->modified = true; + } + } + } + } +} + +char PredicateSimplifier::ID = 0; +static RegisterPass<PredicateSimplifier> +X("predsimplify", "Predicate Simplifier"); + +FunctionPass *llvm::createPredicateSimplifierPass() { + return new PredicateSimplifier(); +} diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp new file mode 100644 index 0000000..293cf92 --- /dev/null +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -0,0 +1,896 @@ +//===- Reassociate.cpp - Reassociate binary expressions -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass reassociates commutative expressions in an order that is designed
+// to promote better constant propagation, GCSE, LICM, PRE...
+//
+// For example: 4 + (x + 5) -> x + (4 + 5)
+//
+// In the implementation of this algorithm, constants are assigned rank = 0,
+// function arguments are rank = 1, and other values are assigned ranks
+// corresponding to the reverse post order traversal of the current function
+// (starting at 2), which effectively gives values in deep loops higher rank
+// than values not in loops.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "reassociate"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Function.h"
+#include "llvm/Instructions.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Pass.h"
+#include "llvm/Assembly/Writer.h"
+#include "llvm/Support/CFG.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ValueHandle.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include <algorithm>
+#include <map>
+using namespace llvm;
+
+STATISTIC(NumLinear , "Number of insts linearized");
+STATISTIC(NumChanged, "Number of insts reassociated");
+STATISTIC(NumAnnihil, "Number of expr trees annihilated");
+STATISTIC(NumFactor , "Number of multiplies factored");
+
+namespace {
+ struct VISIBILITY_HIDDEN ValueEntry {
+ unsigned Rank;
+ Value *Op;
+ ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {}
+ };
+ inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) {
+ return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start.
+ }
+}
+
+#ifndef NDEBUG
+/// PrintOps - Print out the expression identified in the Ops list.
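+/// e.g. an add of %x (rank 3) and the constant 4 (rank 0) prints as:
+///   add i32 %x,3 4,0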
+///
+static void PrintOps(Instruction *I, const std::vector<ValueEntry> &Ops) {
+ Module *M = I->getParent()->getParent()->getParent();
+ cerr << Instruction::getOpcodeName(I->getOpcode()) << " "
+ << *Ops[0].Op->getType();
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+ WriteAsOperand(*cerr.stream() << " ", Ops[i].Op, false, M);
+ cerr << "," << Ops[i].Rank;
+ }
+}
+#endif
+
+namespace {
+ class VISIBILITY_HIDDEN Reassociate : public FunctionPass {
+ std::map<BasicBlock*, unsigned> RankMap;
+ std::map<AssertingVH<>, unsigned> ValueRankMap;
+ bool MadeChange;
+ public:
+ static char ID; // Pass identification, replacement for typeid
+ Reassociate() : FunctionPass(&ID) {}
+
+ bool runOnFunction(Function &F);
+
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesCFG();
+ }
+ private:
+ void BuildRankMap(Function &F);
+ unsigned getRank(Value *V);
+ void ReassociateExpression(BinaryOperator *I);
+ void RewriteExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops,
+ unsigned Idx = 0);
+ Value *OptimizeExpression(BinaryOperator *I, std::vector<ValueEntry> &Ops);
+ void LinearizeExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops);
+ void LinearizeExpr(BinaryOperator *I);
+ Value *RemoveFactorFromExpression(Value *V, Value *Factor);
+ void ReassociateBB(BasicBlock *BB);
+
+ void RemoveDeadBinaryOp(Value *V);
+ };
+}
+
+char Reassociate::ID = 0;
+static RegisterPass<Reassociate> X("reassociate", "Reassociate expressions");
+
+// Public interface to the Reassociate pass.
+FunctionPass *llvm::createReassociatePass() { return new Reassociate(); }
+
+void Reassociate::RemoveDeadBinaryOp(Value *V) {
+ Instruction *Op = dyn_cast<Instruction>(V);
+ // An instruction is never both a BinaryOperator and a CmpInst, so these
+ // two tests must be combined with && rather than ||. Bail out unless Op
+ // is one of the two and has no remaining uses.
+ if (!Op || (!isa<BinaryOperator>(Op) && !isa<CmpInst>(Op)) ||
+ !Op->use_empty())
+ return;
+
+ // Erase the dead instruction first, then recurse on its operands, which
+ // may have just become dead themselves.
+ Value *LHS = Op->getOperand(0), *RHS = Op->getOperand(1);
+ ValueRankMap.erase(Op);
+ Op->eraseFromParent();
+ RemoveDeadBinaryOp(LHS);
+ if (RHS != LHS) RemoveDeadBinaryOp(RHS);
+}
+
+
+static bool isUnmovableInstruction(Instruction *I) {
+ if (I->getOpcode() == Instruction::PHI ||
+ I->getOpcode() == Instruction::Alloca ||
+ I->getOpcode() == Instruction::Load ||
+ I->getOpcode() == Instruction::Malloc ||
+ I->getOpcode() == Instruction::Invoke ||
+ (I->getOpcode() == Instruction::Call &&
+ !isa<DbgInfoIntrinsic>(I)) ||
+ I->getOpcode() == Instruction::UDiv ||
+ I->getOpcode() == Instruction::SDiv ||
+ I->getOpcode() == Instruction::FDiv ||
+ I->getOpcode() == Instruction::URem ||
+ I->getOpcode() == Instruction::SRem ||
+ I->getOpcode() == Instruction::FRem)
+ return true;
+ return false;
+}
+
+void Reassociate::BuildRankMap(Function &F) {
+ unsigned i = 2;
+
+ // Assign distinct ranks to function arguments.
+ for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I)
+ ValueRankMap[&*I] = ++i;
+
+ ReversePostOrderTraversal<Function*> RPOT(&F);
+ for (ReversePostOrderTraversal<Function*>::rpo_iterator I = RPOT.begin(),
+ E = RPOT.end(); I != E; ++I) {
+ BasicBlock *BB = *I;
+ unsigned BBRank = RankMap[BB] = ++i << 16;
+
+ // Walk the basic block, adding precomputed ranks for any instructions that
+ // we cannot move. This ensures that the ranks for these instructions are
+ // all different in the block.
+ for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
+ if (isUnmovableInstruction(I))
+ ValueRankMap[&*I] = ++BBRank;
+ }
+}
+
+unsigned Reassociate::getRank(Value *V) {
+ if (isa<Argument>(V)) return ValueRankMap[V]; // Function argument...
+
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (I == 0) return 0; // Otherwise it's a global or constant, rank 0.
+
+ unsigned &CachedRank = ValueRankMap[I];
+ if (CachedRank) return CachedRank; // Rank already known?
+
+ // If this is an expression, return 1+MAX(rank(LHS), rank(RHS)) so that
+ // we can reassociate expressions for code motion! Since we do not recurse
+ // for PHI nodes, we cannot have infinite recursion here, because there
+ // cannot be loops in the value graph that do not go through PHI nodes.
+ unsigned Rank = 0, MaxRank = RankMap[I->getParent()];
+ for (unsigned i = 0, e = I->getNumOperands();
+ i != e && Rank != MaxRank; ++i)
+ Rank = std::max(Rank, getRank(I->getOperand(i)));
+
+ // If this is a not or neg instruction, do not count it for rank. This
+ // assures us that X and ~X will have the same rank.
+ if (!I->getType()->isInteger() ||
+ (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I)))
+ ++Rank;
+
+ //DOUT << "Calculated Rank[" << V->getName() << "] = "
+ // << Rank << "\n";
+
+ return CachedRank = Rank;
+}
+
+/// isReassociableOp - Return V as a BinaryOperator if it is an instruction
+/// of the specified opcode with at most one use; otherwise return null.
+static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) {
+ if ((V->hasOneUse() || V->use_empty()) && isa<Instruction>(V) &&
+ cast<Instruction>(V)->getOpcode() == Opcode)
+ return cast<BinaryOperator>(V);
+ return 0;
+}
+
+/// LowerNegateToMultiply - Replace 0-X with X*-1.
+///
+static Instruction *LowerNegateToMultiply(Instruction *Neg,
+ std::map<AssertingVH<>, unsigned> &ValueRankMap) {
+ Constant *Cst = ConstantInt::getAllOnesValue(Neg->getType());
+
+ Instruction *Res = BinaryOperator::CreateMul(Neg->getOperand(1), Cst, "",Neg);
+ ValueRankMap.erase(Neg);
+ Res->takeName(Neg);
+ Neg->replaceAllUsesWith(Res);
+ Neg->eraseFromParent();
+ return Res;
+}
+
+// Given an expression of the form '(A+B)+(D+C)', turn it into '(((A+B)+C)+D)'.
+// Note that if D is also part of the expression tree, we recurse to
+// linearize it as well. Besides that case, this does not recurse into A, B,
+// or C.
+void Reassociate::LinearizeExpr(BinaryOperator *I) {
+ BinaryOperator *LHS = cast<BinaryOperator>(I->getOperand(0));
+ BinaryOperator *RHS = cast<BinaryOperator>(I->getOperand(1));
+ assert(isReassociableOp(LHS, I->getOpcode()) &&
+ isReassociableOp(RHS, I->getOpcode()) &&
+ "Not an expression that needs linearization?");
+
+ DOUT << "Linear" << *LHS << *RHS << *I;
+
+ // Move the RHS instruction to live immediately before I, avoiding breaking
+ // dominator properties.
+ RHS->moveBefore(I);
+
+ // Move operands around to do the linearization.
+ I->setOperand(1, RHS->getOperand(0));
+ RHS->setOperand(0, LHS);
+ I->setOperand(0, RHS);
+
+ ++NumLinear;
+ MadeChange = true;
+ DOUT << "Linearized: " << *I;
+
+ // If D is part of this expression tree, tail recurse.
+ if (isReassociableOp(I->getOperand(1), I->getOpcode()))
+ LinearizeExpr(I);
+}
+
+
+/// LinearizeExprTree - Given an associative binary expression tree, traverse
+/// all of the uses putting it into canonical form. This forces a left-linear
+/// form of the expression (((a+b)+c)+d), and collects information about the
+/// rank of the non-tree operands.
+///
+/// NOTE: This intentionally destroys the expression tree operands (turning
+/// them into undef values) to reduce #uses of the values. This means that the
+/// caller MUST use something like RewriteExprTree to put the values back in.
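+/// For example, (a+b)+(c+d) is first rewritten to ((a+b)+c)+d; the leaves
+/// a, b, c and d are then pushed onto Ops with their ranks, and the operand
+/// slots they occupied are overwritten with undef for RewriteExprTree to
+/// refill.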
+///
+void Reassociate::LinearizeExprTree(BinaryOperator *I,
+ std::vector<ValueEntry> &Ops) {
+ Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
+ unsigned Opcode = I->getOpcode();
+
+ // First step, linearize the expression if it is in ((A+B)+(C+D)) form.
+ BinaryOperator *LHSBO = isReassociableOp(LHS, Opcode);
+ BinaryOperator *RHSBO = isReassociableOp(RHS, Opcode);
+
+ // If this is a multiply expression tree and it contains internal negations,
+ // transform them into multiplies by -1 so they can be reassociated.
+ if (I->getOpcode() == Instruction::Mul) {
+ if (!LHSBO && LHS->hasOneUse() && BinaryOperator::isNeg(LHS)) {
+ LHS = LowerNegateToMultiply(cast<Instruction>(LHS), ValueRankMap);
+ LHSBO = isReassociableOp(LHS, Opcode);
+ }
+ if (!RHSBO && RHS->hasOneUse() && BinaryOperator::isNeg(RHS)) {
+ RHS = LowerNegateToMultiply(cast<Instruction>(RHS), ValueRankMap);
+ RHSBO = isReassociableOp(RHS, Opcode);
+ }
+ }
+
+ if (!LHSBO) {
+ if (!RHSBO) {
+ // Neither the LHS nor the RHS is part of the tree, so this is a leaf.
+ // As such, just remember these operands and their rank.
+ Ops.push_back(ValueEntry(getRank(LHS), LHS));
+ Ops.push_back(ValueEntry(getRank(RHS), RHS));
+
+ // Clear the leaves out.
+ I->setOperand(0, UndefValue::get(I->getType()));
+ I->setOperand(1, UndefValue::get(I->getType()));
+ return;
+ } else {
+ // Turn X+(Y+Z) -> (Y+Z)+X
+ std::swap(LHSBO, RHSBO);
+ std::swap(LHS, RHS);
+ bool Success = !I->swapOperands();
+ assert(Success && "swapOperands failed");
+ Success = false;
+ MadeChange = true;
+ }
+ } else if (RHSBO) {
+ // Turn (A+B)+(C+D) -> (((A+B)+C)+D). This guarantees that the RHS is not
+ // part of the expression tree.
+ LinearizeExpr(I);
+ LHS = LHSBO = cast<BinaryOperator>(I->getOperand(0));
+ RHS = I->getOperand(1);
+ RHSBO = 0;
+ }
+
+ // Okay, now we know that the LHS is a nested expression and that the RHS is
+ // not. Perform reassociation.
+ assert(!isReassociableOp(RHS, Opcode) && "LinearizeExpr failed!");
+
+ // Move LHS right before I to make sure that the tree expression dominates
+ // all values.
+ LHSBO->moveBefore(I);
+
+ // Linearize the expression tree on the LHS.
+ LinearizeExprTree(LHSBO, Ops);
+
+ // Remember the RHS operand and its rank.
+ Ops.push_back(ValueEntry(getRank(RHS), RHS));
+
+ // Clear the RHS leaf out.
+ I->setOperand(1, UndefValue::get(I->getType()));
+}
+
+// RewriteExprTree - Now that the operands for this expression tree are
+// linearized and optimized, emit them in-order. This function is written to
+// be tail recursive.
+void Reassociate::RewriteExprTree(BinaryOperator *I,
+ std::vector<ValueEntry> &Ops,
+ unsigned i) {
+ if (i+2 == Ops.size()) {
+ if (I->getOperand(0) != Ops[i].Op ||
+ I->getOperand(1) != Ops[i+1].Op) {
+ Value *OldLHS = I->getOperand(0);
+ DOUT << "RA: " << *I;
+ I->setOperand(0, Ops[i].Op);
+ I->setOperand(1, Ops[i+1].Op);
+ DOUT << "TO: " << *I;
+ MadeChange = true;
+ ++NumChanged;
+
+ // If we reassociated a tree to fewer operands (e.g. (1+a+2) -> (a+3)),
+ // delete the extra, now dead, nodes.
+ RemoveDeadBinaryOp(OldLHS);
+ }
+ return;
+ }
+ assert(i+2 < Ops.size() && "Ops index out of range!");
+
+ if (I->getOperand(1) != Ops[i].Op) {
+ DOUT << "RA: " << *I;
+ I->setOperand(1, Ops[i].Op);
+ DOUT << "TO: " << *I;
+ MadeChange = true;
+ ++NumChanged;
+ }
+
+ BinaryOperator *LHS = cast<BinaryOperator>(I->getOperand(0));
+ assert(LHS->getOpcode() == I->getOpcode() &&
+ "Improper expression tree!");
+
+ // Compactify the tree instructions together with each other to guarantee
+ // that the expression tree is dominated by all of Ops.
+ LHS->moveBefore(I);
+ RewriteExprTree(LHS, Ops, i+1);
+}
+
+
+
+// NegateValue - Insert instructions before the instruction pointed to by BI
+// that compute the negative version of the specified value. The negated
+// value is returned, and BI is left pointing at the instruction that should
+// be processed next by the reassociation pass.
+//
+static Value *NegateValue(Value *V, Instruction *BI) {
+ // We are trying to expose opportunity for reassociation. One of the things
+ // that we want to do to achieve this is to push a negation as deep into an
+ // expression chain as possible, to expose the add instructions. In practice,
+ // this means that we turn this:
+ // X = -(A+12+C+D) into X = -A + -12 + -C + -D = -12 + -A + -C + -D
+ // so that a later 'Y = 12+X' can be reassociated with the -12 to eliminate
+ // the constant. We assume that instcombine will clean up the mess later if
+ // we introduce tons of unnecessary negation instructions...
+ //
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ if (I->getOpcode() == Instruction::Add && I->hasOneUse()) {
+ // Push the negates through the add.
+ I->setOperand(0, NegateValue(I->getOperand(0), BI));
+ I->setOperand(1, NegateValue(I->getOperand(1), BI));
+
+ // We must move the add instruction here, because the neg instructions do
+ // not dominate the old add instruction in general. By moving it, we are
+ // assured that the neg instructions we just inserted dominate the
+ // instruction we are about to insert after them.
+ //
+ I->moveBefore(BI);
+ I->setName(I->getName()+".neg");
+ return I;
+ }
+
+ // Insert a 'neg' instruction that subtracts the value from zero to get the
+ // negation.
+ //
+ return BinaryOperator::CreateNeg(V, V->getName() + ".neg", BI);
+}
+
+/// ShouldBreakUpSubtract - Return true if we should break up this subtract of
+/// X-Y into (X + -Y).
+static bool ShouldBreakUpSubtract(Instruction *Sub) {
+ // If this is a negation, we can't split it up!
+ if (BinaryOperator::isNeg(Sub))
+ return false;
+
+ // Don't bother breaking this up unless one of the operands is a
+ // reassociable add or subtract, or unless the subtract's only user is one.
+ if (isReassociableOp(Sub->getOperand(0), Instruction::Add) ||
+ isReassociableOp(Sub->getOperand(0), Instruction::Sub))
+ return true;
+ if (isReassociableOp(Sub->getOperand(1), Instruction::Add) ||
+ isReassociableOp(Sub->getOperand(1), Instruction::Sub))
+ return true;
+ if (Sub->hasOneUse() &&
+ (isReassociableOp(Sub->use_back(), Instruction::Add) ||
+ isReassociableOp(Sub->use_back(), Instruction::Sub)))
+ return true;
+
+ return false;
+}
+
+/// BreakUpSubtract - If we have (X-Y), and if either X is an add, or if this
+/// is only used by an add, transform this into (X+(0-Y)) to promote better
+/// reassociation.
+static Instruction *BreakUpSubtract(Instruction *Sub,
+ std::map<AssertingVH<>, unsigned> &ValueRankMap) {
+ // Convert a subtract into an add and a neg instruction...
so that sub + // instructions can be commuted with other add instructions... + // + // Calculate the negative value of Operand 1 of the sub instruction... + // and set it as the RHS of the add instruction we just made... + // + Value *NegVal = NegateValue(Sub->getOperand(1), Sub); + Instruction *New = + BinaryOperator::CreateAdd(Sub->getOperand(0), NegVal, "", Sub); + New->takeName(Sub); + + // Everyone now refers to the add instruction. + ValueRankMap.erase(Sub); + Sub->replaceAllUsesWith(New); + Sub->eraseFromParent(); + + DOUT << "Negated: " << *New; + return New; +} + +/// ConvertShiftToMul - If this is a shift of a reassociable multiply or is used +/// by one, change this into a multiply by a constant to assist with further +/// reassociation. +static Instruction *ConvertShiftToMul(Instruction *Shl, + std::map<AssertingVH<>, unsigned> &ValueRankMap) { + // If an operand of this shift is a reassociable multiply, or if the shift + // is used by a reassociable multiply or add, turn into a multiply. + if (isReassociableOp(Shl->getOperand(0), Instruction::Mul) || + (Shl->hasOneUse() && + (isReassociableOp(Shl->use_back(), Instruction::Mul) || + isReassociableOp(Shl->use_back(), Instruction::Add)))) { + Constant *MulCst = ConstantInt::get(Shl->getType(), 1); + MulCst = ConstantExpr::getShl(MulCst, cast<Constant>(Shl->getOperand(1))); + + Instruction *Mul = BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, + "", Shl); + ValueRankMap.erase(Shl); + Mul->takeName(Shl); + Shl->replaceAllUsesWith(Mul); + Shl->eraseFromParent(); + return Mul; + } + return 0; +} + +// Scan backwards and forwards among values with the same rank as element i to +// see if X exists. If X does not exist, return i. +static unsigned FindInOperandList(std::vector<ValueEntry> &Ops, unsigned i, + Value *X) { + unsigned XRank = Ops[i].Rank; + unsigned e = Ops.size(); + for (unsigned j = i+1; j != e && Ops[j].Rank == XRank; ++j) + if (Ops[j].Op == X) + return j; + // Scan backwards + for (unsigned j = i-1; j != ~0U && Ops[j].Rank == XRank; --j) + if (Ops[j].Op == X) + return j; + return i; +} + +/// EmitAddTreeOfValues - Emit a tree of add instructions, summing Ops together +/// and returning the result. Insert the tree before I. +static Value *EmitAddTreeOfValues(Instruction *I, std::vector<Value*> &Ops) { + if (Ops.size() == 1) return Ops.back(); + + Value *V1 = Ops.back(); + Ops.pop_back(); + Value *V2 = EmitAddTreeOfValues(I, Ops); + return BinaryOperator::CreateAdd(V2, V1, "tmp", I); +} + +/// RemoveFactorFromExpression - If V is an expression tree that is a +/// multiplication sequence, and if this sequence contains a multiply by Factor, +/// remove Factor from the tree and return the new tree. +Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { + BinaryOperator *BO = isReassociableOp(V, Instruction::Mul); + if (!BO) return 0; + + std::vector<ValueEntry> Factors; + LinearizeExprTree(BO, Factors); + + bool FoundFactor = false; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) + if (Factors[i].Op == Factor) { + FoundFactor = true; + Factors.erase(Factors.begin()+i); + break; + } + if (!FoundFactor) { + // Make sure to restore the operands to the expression tree. + RewriteExprTree(BO, Factors); + return 0; + } + + if (Factors.size() == 1) return Factors[0].Op; + + RewriteExprTree(BO, Factors); + return BO; +} + +/// FindSingleUseMultiplyFactors - If V is a single-use multiply, recursively +/// add its operands as factors, otherwise add V to the list of factors. 
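+/// e.g. for %m = mul (mul %a, %b), %c, where the inner mul has one use, the
+/// factors collected are %c, %b and %a.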
+static void FindSingleUseMultiplyFactors(Value *V,
+ std::vector<Value*> &Factors) {
+ BinaryOperator *BO;
+ if ((!V->hasOneUse() && !V->use_empty()) ||
+ !(BO = dyn_cast<BinaryOperator>(V)) ||
+ BO->getOpcode() != Instruction::Mul) {
+ Factors.push_back(V);
+ return;
+ }
+
+ // Otherwise, add the LHS and RHS to the list of factors.
+ FindSingleUseMultiplyFactors(BO->getOperand(1), Factors);
+ FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);
+}
+
+
+
+Value *Reassociate::OptimizeExpression(BinaryOperator *I,
+ std::vector<ValueEntry> &Ops) {
+ // Now that we have the linearized expression tree, try to optimize it.
+ // Start by folding any constants that we found.
+ bool IterateOptimization = false;
+ if (Ops.size() == 1) return Ops[0].Op;
+
+ unsigned Opcode = I->getOpcode();
+
+ if (Constant *V1 = dyn_cast<Constant>(Ops[Ops.size()-2].Op))
+ if (Constant *V2 = dyn_cast<Constant>(Ops.back().Op)) {
+ Ops.pop_back();
+ Ops.back().Op = ConstantExpr::get(Opcode, V1, V2);
+ return OptimizeExpression(I, Ops);
+ }
+
+ // Check for destructive annihilation due to a constant being used.
+ if (ConstantInt *CstVal = dyn_cast<ConstantInt>(Ops.back().Op))
+ switch (Opcode) {
+ default: break;
+ case Instruction::And:
+ if (CstVal->isZero()) { // ... & 0 -> 0
+ ++NumAnnihil;
+ return CstVal;
+ } else if (CstVal->isAllOnesValue()) { // ... & -1 -> ...
+ Ops.pop_back();
+ }
+ break;
+ case Instruction::Mul:
+ if (CstVal->isZero()) { // ... * 0 -> 0
+ ++NumAnnihil;
+ return CstVal;
+ } else if (CstVal->isOne()) {
+ Ops.pop_back(); // ... * 1 -> ...
+ }
+ break;
+ case Instruction::Or:
+ if (CstVal->isAllOnesValue()) { // ... | -1 -> -1
+ ++NumAnnihil;
+ return CstVal;
+ }
+ // FALLTHROUGH!
+ case Instruction::Add:
+ case Instruction::Xor:
+ if (CstVal->isZero()) // ... [|^+] 0 -> ...
+ Ops.pop_back();
+ break;
+ }
+ if (Ops.size() == 1) return Ops[0].Op;
+
+ // Handle destructive annihilation due to identities between elements in
+ // the argument list here.
+ switch (Opcode) {
+ default: break;
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ // Scan the operand lists looking for X and ~X pairs, along with X,X pairs.
+ // If we find any, we can simplify the expression. X&~X == 0, X|~X == -1.
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+ // First, check for X and ~X in the operand list.
+ assert(i < Ops.size());
+ if (BinaryOperator::isNot(Ops[i].Op)) { // Cannot occur for ^.
+ Value *X = BinaryOperator::getNotArgument(Ops[i].Op);
+ unsigned FoundX = FindInOperandList(Ops, i, X);
+ if (FoundX != i) {
+ if (Opcode == Instruction::And) { // ...&X&~X = 0
+ ++NumAnnihil;
+ return Constant::getNullValue(X->getType());
+ } else if (Opcode == Instruction::Or) { // ...|X|~X = -1
+ ++NumAnnihil;
+ return ConstantInt::getAllOnesValue(X->getType());
+ }
+ }
+ }
+
+ // Next, check for duplicate pairs of values, which we assume are next to
+ // each other, due to our sorting criteria.
+ assert(i < Ops.size());
+ if (i+1 != Ops.size() && Ops[i+1].Op == Ops[i].Op) {
+ if (Opcode == Instruction::And || Opcode == Instruction::Or) {
+ // Drop duplicate values.
+ Ops.erase(Ops.begin()+i);
+ --i; --e;
+ IterateOptimization = true;
+ ++NumAnnihil;
+ } else {
+ assert(Opcode == Instruction::Xor);
+ if (e == 2) {
+ ++NumAnnihil;
+ return Constant::getNullValue(Ops[0].Op->getType());
+ }
+ // ... X^X -> ...
+ Ops.erase(Ops.begin()+i, Ops.begin()+i+2); + i -= 1; e -= 2; + IterateOptimization = true; + ++NumAnnihil; + } + } + } + break; + + case Instruction::Add: + // Scan the operand lists looking for X and -X pairs. If we find any, we + // can simplify the expression. X+-X == 0. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + assert(i < Ops.size()); + // Check for X and -X in the operand list. + if (BinaryOperator::isNeg(Ops[i].Op)) { + Value *X = BinaryOperator::getNegArgument(Ops[i].Op); + unsigned FoundX = FindInOperandList(Ops, i, X); + if (FoundX != i) { + // Remove X and -X from the operand list. + if (Ops.size() == 2) { + ++NumAnnihil; + return Constant::getNullValue(X->getType()); + } else { + Ops.erase(Ops.begin()+i); + if (i < FoundX) + --FoundX; + else + --i; // Need to back up an extra one. + Ops.erase(Ops.begin()+FoundX); + IterateOptimization = true; + ++NumAnnihil; + --i; // Revisit element. + e -= 2; // Removed two elements. + } + } + } + } + + + // Scan the operand list, checking to see if there are any common factors + // between operands. Consider something like A*A+A*B*C+D. We would like to + // reassociate this to A*(A+B*C)+D, which reduces the number of multiplies. + // To efficiently find this, we count the number of times a factor occurs + // for any ADD operands that are MULs. + std::map<Value*, unsigned> FactorOccurrences; + unsigned MaxOcc = 0; + Value *MaxOccVal = 0; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(Ops[i].Op)) { + if (BOp->getOpcode() == Instruction::Mul && BOp->use_empty()) { + // Compute all of the factors of this added value. + std::vector<Value*> Factors; + FindSingleUseMultiplyFactors(BOp, Factors); + assert(Factors.size() > 1 && "Bad linearize!"); + + // Add one to FactorOccurrences for each unique factor in this op. + if (Factors.size() == 2) { + unsigned Occ = ++FactorOccurrences[Factors[0]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0]; } + if (Factors[0] != Factors[1]) { // Don't double count A*A. + Occ = ++FactorOccurrences[Factors[1]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1]; } + } + } else { + std::set<Value*> Duplicates; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) { + if (Duplicates.insert(Factors[i]).second) { + unsigned Occ = ++FactorOccurrences[Factors[i]]; + if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i]; } + } + } + } + } + } + } + + // If any factor occurred more than one time, we can pull it out. + if (MaxOcc > 1) { + DOUT << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << "\n"; + + // Create a new instruction that uses the MaxOccVal twice. If we don't do + // this, we could otherwise run into situations where removing a factor + // from an expression will drop a use of maxocc, and this can cause + // RemoveFactorFromExpression on successive values to behave differently. + Instruction *DummyInst = BinaryOperator::CreateAdd(MaxOccVal, MaxOccVal); + std::vector<Value*> NewMulOps; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) { + NewMulOps.push_back(V); + Ops.erase(Ops.begin()+i); + --i; --e; + } + } + + // No need for extra uses anymore. + delete DummyInst; + + unsigned NumAddedValues = NewMulOps.size(); + Value *V = EmitAddTreeOfValues(I, NewMulOps); + Value *V2 = BinaryOperator::CreateMul(V, MaxOccVal, "tmp", I); + + // Now that we have inserted V and its sole use, optimize it. 
This allows + // us to handle cases that require multiple factoring steps, such as this: + // A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C)) + if (NumAddedValues > 1) + ReassociateExpression(cast<BinaryOperator>(V)); + + ++NumFactor; + + if (Ops.empty()) + return V2; + + // Add the new value to the list of things being added. + Ops.insert(Ops.begin(), ValueEntry(getRank(V2), V2)); + + // Rewrite the tree so that there is now a use of V. + RewriteExprTree(I, Ops); + return OptimizeExpression(I, Ops); + } + break; + //case Instruction::Mul: + } + + if (IterateOptimization) + return OptimizeExpression(I, Ops); + return 0; +} + + +/// ReassociateBB - Inspect all of the instructions in this basic block, +/// reassociating them as we go. +void Reassociate::ReassociateBB(BasicBlock *BB) { + for (BasicBlock::iterator BBI = BB->begin(); BBI != BB->end(); ) { + Instruction *BI = BBI++; + if (BI->getOpcode() == Instruction::Shl && + isa<ConstantInt>(BI->getOperand(1))) + if (Instruction *NI = ConvertShiftToMul(BI, ValueRankMap)) { + MadeChange = true; + BI = NI; + } + + // Reject cases where it is pointless to do this. + if (!isa<BinaryOperator>(BI) || BI->getType()->isFloatingPoint() || + isa<VectorType>(BI->getType())) + continue; // Floating point ops are not associative. + + // If this is a subtract instruction which is not already in negate form, + // see if we can convert it to X+-Y. + if (BI->getOpcode() == Instruction::Sub) { + if (ShouldBreakUpSubtract(BI)) { + BI = BreakUpSubtract(BI, ValueRankMap); + MadeChange = true; + } else if (BinaryOperator::isNeg(BI)) { + // Otherwise, this is a negation. See if the operand is a multiply tree + // and if this is not an inner node of a multiply tree. + if (isReassociableOp(BI->getOperand(1), Instruction::Mul) && + (!BI->hasOneUse() || + !isReassociableOp(BI->use_back(), Instruction::Mul))) { + BI = LowerNegateToMultiply(BI, ValueRankMap); + MadeChange = true; + } + } + } + + // If this instruction is a commutative binary operator, process it. + if (!BI->isAssociative()) continue; + BinaryOperator *I = cast<BinaryOperator>(BI); + + // If this is an interior node of a reassociable tree, ignore it until we + // get to the root of the tree, to avoid N^2 analysis. + if (I->hasOneUse() && isReassociableOp(I->use_back(), I->getOpcode())) + continue; + + // If this is an add tree that is used by a sub instruction, ignore it + // until we process the subtract. + if (I->hasOneUse() && I->getOpcode() == Instruction::Add && + cast<Instruction>(I->use_back())->getOpcode() == Instruction::Sub) + continue; + + ReassociateExpression(I); + } +} + +void Reassociate::ReassociateExpression(BinaryOperator *I) { + + // First, walk the expression tree, linearizing the tree, collecting + std::vector<ValueEntry> Ops; + LinearizeExprTree(I, Ops); + + DOUT << "RAIn:\t"; DEBUG(PrintOps(I, Ops)); DOUT << "\n"; + + // Now that we have linearized the tree to a list and have gathered all of + // the operands and their ranks, sort the operands by their rank. Use a + // stable_sort so that values with equal ranks will have their relative + // positions maintained (and so the compiler is deterministic). Note that + // this sorts so that the highest ranking values end up at the beginning of + // the vector. + std::stable_sort(Ops.begin(), Ops.end()); + + // OptimizeExpression - Now that we have the expression tree in a convenient + // sorted form, optimize it globally if possible. 
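The stable_sort above is what establishes the canonical order the optimizer relies on. As an aside, here is a minimal standalone sketch of that ordering (Entry and the sample ranks are illustrative stand-ins, not LLVM's ValueEntry): highest rank first, ties kept in their original order, which is what keeps the pass deterministic.

  // Toy stand-in for the rank ordering described above; not LLVM code.
  #include <algorithm>
  #include <cstdio>
  #include <vector>

  struct Entry {
    unsigned Rank;
    const char *Name;
    // Sort highest rank first, as the pass does.
    bool operator<(const Entry &RHS) const { return Rank > RHS.Rank; }
  };

  int main() {
    std::vector<Entry> Ops;
    Ops.push_back(Entry{3, "x"});  // a "late" value: high rank
    Ops.push_back(Entry{1, "a"});  // argument-like value: low rank
    Ops.push_back(Entry{3, "y"});  // ties with x; stable_sort keeps x first
    Ops.push_back(Entry{2, "t"});
    std::stable_sort(Ops.begin(), Ops.end());
    for (size_t i = 0; i != Ops.size(); ++i)
      std::printf("%s ", Ops[i].Name);  // prints: x y t a
    std::printf("\n");
    return 0;
  }

With the list in this shape, low-rank constants sit together at the end, which is exactly why the code below can inspect Ops.back() for constant operands.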
+  if (Value *V = OptimizeExpression(I, Ops)) {
+    // This expression tree simplified to something that isn't a tree,
+    // eliminate it.
+    DOUT << "Reassoc to scalar: " << *V << "\n";
+    I->replaceAllUsesWith(V);
+    RemoveDeadBinaryOp(I);
+    return;
+  }
+
+  // We want to sink immediates as deeply as possible except in the case where
+  // this is a multiply tree used only by an add, and the immediate is a -1.
+  // In this case we reassociate to put the negation on the outside so that we
+  // can fold the negation into the add: (-X)*Y + Z -> Z-X*Y
+  if (I->getOpcode() == Instruction::Mul && I->hasOneUse() &&
+      cast<Instruction>(I->use_back())->getOpcode() == Instruction::Add &&
+      isa<ConstantInt>(Ops.back().Op) &&
+      cast<ConstantInt>(Ops.back().Op)->isAllOnesValue()) {
+    Ops.insert(Ops.begin(), Ops.back());
+    Ops.pop_back();
+  }
+
+  DOUT << "RAOut:\t"; DEBUG(PrintOps(I, Ops)); DOUT << "\n";
+
+  if (Ops.size() == 1) {
+    // This expression tree simplified to something that isn't a tree,
+    // eliminate it.
+    I->replaceAllUsesWith(Ops[0].Op);
+    RemoveDeadBinaryOp(I);
+  } else {
+    // Now that we ordered and optimized the expressions, splat them back into
+    // the expression tree, removing any unneeded nodes.
+    RewriteExprTree(I, Ops);
+  }
+}
+
+
+bool Reassociate::runOnFunction(Function &F) {
+  // Recalculate the rank map for F
+  BuildRankMap(F);
+
+  MadeChange = false;
+  for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ++FI)
+    ReassociateBB(FI);
+
+  // We are done with the rank map...
+  RankMap.clear();
+  ValueRankMap.clear();
+  return MadeChange;
+}
+
diff --git a/lib/Transforms/Scalar/Reg2Mem.cpp b/lib/Transforms/Scalar/Reg2Mem.cpp
new file mode 100644
index 0000000..46b2952
--- /dev/null
+++ b/lib/Transforms/Scalar/Reg2Mem.cpp
@@ -0,0 +1,125 @@
+//===- Reg2Mem.cpp - Convert registers to allocas -------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License.  See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file demotes all registers to memory references.  It is intended to be
+// the inverse of PromoteMemoryToRegister.  By converting to loads, the only
+// values live across basic blocks are allocas and loads before phi nodes.
+// It is intended that this should make CFG hacking much easier.
+// To make later hacking easier, the entry block is split into two, such that
+// all introduced allocas and nothing else are in the entry block.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "reg2mem"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Pass.h"
+#include "llvm/Function.h"
+#include "llvm/Module.h"
+#include "llvm/BasicBlock.h"
+#include "llvm/Instructions.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/CFG.h"
+#include <list>
+using namespace llvm;
+
+STATISTIC(NumRegsDemoted, "Number of registers demoted");
+STATISTIC(NumPhisDemoted, "Number of phi-nodes demoted");
+
+namespace {
+  struct VISIBILITY_HIDDEN RegToMem : public FunctionPass {
+    static char ID; // Pass identification, replacement for typeid
+    RegToMem() : FunctionPass(&ID) {}
+
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.addRequiredID(BreakCriticalEdgesID);
+      AU.addPreservedID(BreakCriticalEdgesID);
+    }
+
+    bool valueEscapes(Instruction* i) {
+      BasicBlock* bb = i->getParent();
+      for (Value::use_iterator ii = i->use_begin(), ie = i->use_end();
+           ii != ie; ++ii)
+        if (cast<Instruction>(*ii)->getParent() != bb ||
+            isa<PHINode>(*ii))
+          return true;
+      return false;
+    }
+
+    virtual bool runOnFunction(Function &F) {
+      if (!F.isDeclaration()) {
+        // Insert all new allocas into entry block.
+        BasicBlock* BBEntry = &F.getEntryBlock();
+        assert(pred_begin(BBEntry) == pred_end(BBEntry) &&
+               "Entry block to function must not have predecessors!");
+
+        // Find the first non-alloca instruction and create an insertion
+        // point. This is safe if the block is well-formed: it will always
+        // have a terminator; otherwise we'll trip an assertion.
+        BasicBlock::iterator I = BBEntry->begin();
+        while (isa<AllocaInst>(I)) ++I;
+
+        CastInst *AllocaInsertionPoint =
+          CastInst::Create(Instruction::BitCast,
+                           Constant::getNullValue(Type::Int32Ty), Type::Int32Ty,
+                           "reg2mem alloca point", I);
+
+        // Find the escaped instructions. But don't create stack slots for
+        // allocas in entry block.
+        std::list<Instruction*> worklist;
+        for (Function::iterator ibb = F.begin(), ibe = F.end();
+             ibb != ibe; ++ibb)
+          for (BasicBlock::iterator iib = ibb->begin(), iie = ibb->end();
+               iib != iie; ++iib) {
+            if (!(isa<AllocaInst>(iib) && iib->getParent() == BBEntry) &&
+                valueEscapes(iib)) {
+              worklist.push_front(&*iib);
+            }
+          }
+
+        // Demote escaped instructions
+        NumRegsDemoted += worklist.size();
+        for (std::list<Instruction*>::iterator ilb = worklist.begin(),
+             ile = worklist.end(); ilb != ile; ++ilb)
+          DemoteRegToStack(**ilb, false, AllocaInsertionPoint);
+
+        worklist.clear();
+
+        // Find all PHIs.
+        for (Function::iterator ibb = F.begin(), ibe = F.end();
+             ibb != ibe; ++ibb)
+          for (BasicBlock::iterator iib = ibb->begin(), iie = ibb->end();
+               iib != iie; ++iib)
+            if (isa<PHINode>(iib))
+              worklist.push_front(&*iib);
+
+        // Demote phi nodes
+        NumPhisDemoted += worklist.size();
+        for (std::list<Instruction*>::iterator ilb = worklist.begin(),
+             ile = worklist.end(); ilb != ile; ++ilb)
+          DemotePHIToStack(cast<PHINode>(*ilb), AllocaInsertionPoint);
+
+        return true;
+      }
+      return false;
+    }
+  };
+}
+
+char RegToMem::ID = 0;
+static RegisterPass<RegToMem>
+X("reg2mem", "Demote all values to stack slots");
+
+// createDemoteRegisterToMemory - Provide an entry point to create this pass.
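As a usage sketch only: the driver below is hypothetical, assuming a 2.5-era PassManager and an already-loaded Module, none of which are part of this file. It shows how the entry point declared next would typically be scheduled.

  // Hypothetical driver, not part of this commit: schedules reg2mem on a
  // module. Assumes LLVM 2.5-era headers and linkage.
  #include "llvm/Module.h"
  #include "llvm/PassManager.h"
  #include "llvm/Transforms/Scalar.h"

  bool demoteAllRegisters(llvm::Module &M) {
    llvm::PassManager PM;
    // BreakCriticalEdges is pulled in automatically via getAnalysisUsage.
    PM.add(llvm::createDemoteRegisterToMemoryPass());
    return PM.run(M);  // true if any function was modified
  }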
+//
+const PassInfo *const llvm::DemoteRegisterToMemoryID = &X;
+FunctionPass *llvm::createDemoteRegisterToMemoryPass() {
+  return new RegToMem();
+}
diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp
new file mode 100644
index 0000000..d73519c
--- /dev/null
+++ b/lib/Transforms/Scalar/SCCP.cpp
@@ -0,0 +1,1855 @@
+//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License.  See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements sparse conditional constant propagation and merging:
+//
+// Specifically, this:
+//   * Assumes values are constant unless proven otherwise
+//   * Assumes BasicBlocks are dead unless proven otherwise
+//   * Proves values to be constant, and replaces them with constants
+//   * Proves conditional branches to be unconditional
+//
+// Notice that:
+//   * This pass has a habit of making definitions be dead.  It is a good idea
+//     to run a DCE pass sometime after running this pass.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "sccp"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Instructions.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/CallSite.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/InstVisitor.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/STLExtras.h"
+#include <algorithm>
+#include <map>
+using namespace llvm;
+
+STATISTIC(NumInstRemoved, "Number of instructions removed");
+STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable");
+
+STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP");
+STATISTIC(IPNumDeadBlocks , "Number of basic blocks unreachable by IPSCCP");
+STATISTIC(IPNumArgsElimed, "Number of arguments constant propagated by IPSCCP");
+STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP");
+
+namespace {
+/// LatticeVal class - This class represents the different lattice values that
+/// an LLVM value may occupy.  It is a simple class with value semantics.
+///
+class VISIBILITY_HIDDEN LatticeVal {
+  enum {
+    /// undefined - This LLVM Value has no known value yet.
+    undefined,
+
+    /// constant - This LLVM Value has a specific constant value.
+    constant,
+
+    /// forcedconstant - This LLVM Value was thought to be undef until
+    /// ResolvedUndefsIn.  This is treated just like 'constant', but if merged
+    /// with another (different) constant, it goes to overdefined, instead of
+    /// asserting.
+    forcedconstant,
+
+    /// overdefined - This instruction is not known to be constant, and we know
+    /// it has a value.
+    overdefined
+  } LatticeValue;    // The current lattice position
+
+  Constant *ConstantVal; // If Constant value, the current value
+public:
+  inline LatticeVal() : LatticeValue(undefined), ConstantVal(0) {}
+
+  // markOverdefined - Return true if this is a new status to be in...
+  inline bool markOverdefined() {
+    if (LatticeValue != overdefined) {
+      LatticeValue = overdefined;
+      return true;
+    }
+    return false;
+  }
+
+  // markConstant - Return true if this is a new status for us.
+  inline bool markConstant(Constant *V) {
+    if (LatticeValue != constant) {
+      if (LatticeValue == undefined) {
+        LatticeValue = constant;
+        assert(V && "Marking constant with NULL");
+        ConstantVal = V;
+      } else {
+        assert(LatticeValue == forcedconstant &&
+               "Cannot move from overdefined to constant!");
+        // Stay at forcedconstant if the constant is the same.
+        if (V == ConstantVal) return false;
+
+        // Otherwise, we go to overdefined.  Assumptions made based on the
+        // forced value are possibly wrong.  Assuming this is another constant
+        // could expose a contradiction.
+        LatticeValue = overdefined;
+      }
+      return true;
+    } else {
+      assert(ConstantVal == V && "Marking constant with different value");
+    }
+    return false;
+  }
+
+  inline void markForcedConstant(Constant *V) {
+    assert(LatticeValue == undefined && "Can't force a defined value!");
+    LatticeValue = forcedconstant;
+    ConstantVal = V;
+  }
+
+  inline bool isUndefined() const { return LatticeValue == undefined; }
+  inline bool isConstant() const {
+    return LatticeValue == constant || LatticeValue == forcedconstant;
+  }
+  inline bool isOverdefined() const { return LatticeValue == overdefined; }
+
+  inline Constant *getConstant() const {
+    assert(isConstant() && "Cannot get the constant of a non-constant!");
+    return ConstantVal;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+//
+/// SCCPSolver - This class is a general purpose solver for Sparse Conditional
+/// Constant Propagation.
+///
+class SCCPSolver : public InstVisitor<SCCPSolver> {
+  DenseSet<BasicBlock*> BBExecutable;  // The basic blocks that are executable
+  std::map<Value*, LatticeVal> ValueState;  // The state each value is in.
+
+  /// TrackedGlobals - If we are tracking any values for the contents of a
+  /// global variable, we keep a mapping from the constant accessor to the
+  /// element of the global, to the currently known value.  If the value
+  /// becomes overdefined, its entry is simply removed from this map.
+  DenseMap<GlobalVariable*, LatticeVal> TrackedGlobals;
+
+  /// TrackedRetVals - If we are tracking arguments into and the return
+  /// value out of a function, it will have an entry in this map, indicating
+  /// what the known return value for the function is.
+  DenseMap<Function*, LatticeVal> TrackedRetVals;
+
+  /// TrackedMultipleRetVals - Same as TrackedRetVals, but used for functions
+  /// that return multiple values.
+  DenseMap<std::pair<Function*, unsigned>, LatticeVal> TrackedMultipleRetVals;
+
+  // The reason for two worklists is that overdefined is the lowest state
+  // on the lattice, and moving things to overdefined as fast as possible
+  // makes SCCP converge much faster.
+  // By having a separate worklist, we accomplish this because everything
+  // possibly overdefined will become overdefined at the soonest possible
+  // point.
+  SmallVector<Value*, 64> OverdefinedInstWorkList;
+  SmallVector<Value*, 64> InstWorkList;
+
+
+  SmallVector<BasicBlock*, 64> BBWorkList;  // The BasicBlock work list
+
+  /// UsersOfOverdefinedPHIs - Keep track of any users of PHI nodes that are not
+  /// overdefined, despite the fact that the PHI node is overdefined.
+  std::multimap<PHINode*, Instruction*> UsersOfOverdefinedPHIs;
+
+  /// KnownFeasibleEdges - Entries in this set are edges which have already had
+  /// PHI nodes retriggered.
+  typedef std::pair<BasicBlock*, BasicBlock*> Edge;
+  DenseSet<Edge> KnownFeasibleEdges;
+public:
+
+  /// MarkBlockExecutable - This method can be used by clients to mark all of
+  /// the blocks that are known to be intrinsically live in the processed unit.
+  void MarkBlockExecutable(BasicBlock *BB) {
+    DOUT << "Marking Block Executable: " << BB->getNameStart() << "\n";
+    BBExecutable.insert(BB);   // Basic block is executable!
+    BBWorkList.push_back(BB);  // Add the block to the work list!
+  }
+
+  /// TrackValueOfGlobalVariable - Clients can use this method to
+  /// inform the SCCPSolver that it should track loads and stores to the
+  /// specified global variable if it can.  This is only legal to call if
+  /// performing Interprocedural SCCP.
+  void TrackValueOfGlobalVariable(GlobalVariable *GV) {
+    const Type *ElTy = GV->getType()->getElementType();
+    if (ElTy->isFirstClassType()) {
+      LatticeVal &IV = TrackedGlobals[GV];
+      if (!isa<UndefValue>(GV->getInitializer()))
+        IV.markConstant(GV->getInitializer());
+    }
+  }
+
+  /// AddTrackedFunction - If the SCCP solver is supposed to track calls into
+  /// and out of the specified function (which cannot have its address taken),
+  /// this method must be called.
+  void AddTrackedFunction(Function *F) {
+    assert(F->hasLocalLinkage() && "Can only track internal functions!");
+    // Add an entry, F -> undef.
+    if (const StructType *STy = dyn_cast<StructType>(F->getReturnType())) {
+      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
+        TrackedMultipleRetVals.insert(std::make_pair(std::make_pair(F, i),
+                                                     LatticeVal()));
+    } else
+      TrackedRetVals.insert(std::make_pair(F, LatticeVal()));
+  }
+
+  /// Solve - Solve for constants and executable blocks.
+  ///
+  void Solve();
+
+  /// ResolvedUndefsIn - While solving the dataflow for a function, we assume
+  /// that branches on undef values cannot reach any of their successors.
+  /// However, this is not a safe assumption.  After we solve dataflow, this
+  /// method should be used to handle this.  If this returns true, the solver
+  /// should be rerun.
+  bool ResolvedUndefsIn(Function &F);
+
+  bool isBlockExecutable(BasicBlock *BB) const {
+    return BBExecutable.count(BB);
+  }
+
+  /// getValueMapping - Once we have solved for constants, return the mapping of
+  /// LLVM values to LatticeVals.
+  std::map<Value*, LatticeVal> &getValueMapping() {
+    return ValueState;
+  }
+
+  /// getTrackedRetVals - Get the inferred return value map.
+  ///
+  const DenseMap<Function*, LatticeVal> &getTrackedRetVals() {
+    return TrackedRetVals;
+  }
+
+  /// getTrackedGlobals - Get and return the set of inferred initializers for
+  /// global variables.
+  const DenseMap<GlobalVariable*, LatticeVal> &getTrackedGlobals() {
+    return TrackedGlobals;
+  }
+
+  inline void markOverdefined(Value *V) {
+    markOverdefined(ValueState[V], V);
+  }
+
+private:
+  // markConstant - Make a value be marked as "constant".  If the value
+  // is not already a constant, add it to the instruction work list so that
+  // the users of the instruction are updated later.
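The marker helpers that follow implement a monotone merge. A toy three-state version (illustrative only, independent of the LatticeVal class above) makes the discipline concrete: values only move down the lattice, and disagreeing constants collapse to overdefined.

  // Toy lattice merge, mirroring the discipline of mergeInValue below:
  // undefined is the top, overdefined the bottom, and a merge never
  // moves a value back up.
  #include <cassert>

  enum State { Undefined, Constant, Overdefined };

  struct Val {
    State S;
    int C;  // meaningful only when S == Constant
  };

  Val merge(Val A, Val B) {
    if (A.S == Overdefined || B.S == Undefined) return A;  // no-op cases
    if (B.S == Overdefined || (A.S == Constant && A.C != B.C))
      return Val{Overdefined, 0};
    return B;  // A was undefined, or both agree: take the constant
  }

  int main() {
    Val U = {Undefined, 0}, C7 = {Constant, 7}, C8 = {Constant, 8};
    assert(merge(U, C7).S == Constant && merge(U, C7).C == 7);
    assert(merge(C7, C7).S == Constant);     // agreeing constants stay put
    assert(merge(C7, C8).S == Overdefined);  // disagreement falls to bottom
    return 0;
  }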
+  //
+  inline void markConstant(LatticeVal &IV, Value *V, Constant *C) {
+    if (IV.markConstant(C)) {
+      DOUT << "markConstant: " << *C << ": " << *V;
+      InstWorkList.push_back(V);
+    }
+  }
+
+  inline void markForcedConstant(LatticeVal &IV, Value *V, Constant *C) {
+    IV.markForcedConstant(C);
+    DOUT << "markForcedConstant: " << *C << ": " << *V;
+    InstWorkList.push_back(V);
+  }
+
+  inline void markConstant(Value *V, Constant *C) {
+    markConstant(ValueState[V], V, C);
+  }
+
+  // markOverdefined - Make a value be marked as "overdefined". If the
+  // value is not already overdefined, add it to the overdefined instruction
+  // work list so that the users of the instruction are updated later.
+  inline void markOverdefined(LatticeVal &IV, Value *V) {
+    if (IV.markOverdefined()) {
+      DEBUG(DOUT << "markOverdefined: ";
+            if (Function *F = dyn_cast<Function>(V))
+              DOUT << "Function '" << F->getName() << "'\n";
+            else
+              DOUT << *V);
+      // Only instructions go on the work list
+      OverdefinedInstWorkList.push_back(V);
+    }
+  }
+
+  inline void mergeInValue(LatticeVal &IV, Value *V, LatticeVal &MergeWithV) {
+    if (IV.isOverdefined() || MergeWithV.isUndefined())
+      return;  // Noop.
+    if (MergeWithV.isOverdefined())
+      markOverdefined(IV, V);
+    else if (IV.isUndefined())
+      markConstant(IV, V, MergeWithV.getConstant());
+    else if (IV.getConstant() != MergeWithV.getConstant())
+      markOverdefined(IV, V);
+  }
+
+  inline void mergeInValue(Value *V, LatticeVal &MergeWithV) {
+    return mergeInValue(ValueState[V], V, MergeWithV);
+  }
+
+
+  // getValueState - Return the LatticeVal object that corresponds to the value.
+  // This function is necessary because not all values should start out in the
+  // undefined state... Arguments should be overdefined, and
+  // constants should be marked as constants.  If a value is not known to be an
+  // Instruction object, then use this accessor to get its value from the map.
+  //
+  inline LatticeVal &getValueState(Value *V) {
+    std::map<Value*, LatticeVal>::iterator I = ValueState.find(V);
+    if (I != ValueState.end()) return I->second;  // Common case, in the map
+
+    if (Constant *C = dyn_cast<Constant>(V)) {
+      if (isa<UndefValue>(V)) {
+        // Nothing to do, remain undefined.
+      } else {
+        LatticeVal &LV = ValueState[C];
+        LV.markConstant(C);          // Constants are constant
+        return LV;
+      }
+    }
+    // All others are undefined by default...
+    return ValueState[V];
+  }
+
+  // markEdgeExecutable - Mark a basic block as executable, adding it to the BB
+  // work list if it is not already executable...
+  //
+  void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest) {
+    if (!KnownFeasibleEdges.insert(Edge(Source, Dest)).second)
+      return;  // This edge is already known to be executable!
+
+    if (BBExecutable.count(Dest)) {
+      DOUT << "Marking Edge Executable: " << Source->getNameStart()
+           << " -> " << Dest->getNameStart() << "\n";
+
+      // The destination is already executable, but we just made an edge
+      // feasible that wasn't before.  Revisit the PHI nodes in the block
+      // because they have potentially new operands.
+      for (BasicBlock::iterator I = Dest->begin(); isa<PHINode>(I); ++I)
+        visitPHINode(*cast<PHINode>(I));
+
+    } else {
+      MarkBlockExecutable(Dest);
+    }
+  }
+
+  // getFeasibleSuccessors - Return a vector of booleans to indicate which
+  // successors are reachable from a given terminator instruction.
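The rule getFeasibleSuccessors implements for a two-way conditional branch can be stated in a few standalone lines (toy condition states, not the real lattice): an overdefined condition makes both edges feasible, a known boolean picks one, and an undefined condition optimistically picks neither.

  // Sketch of successor feasibility for a two-way branch, over toy states.
  #include <cassert>

  enum CondState { CondUndefined, CondFalse, CondTrue, CondOverdefined };

  void feasibleSuccs(CondState C, bool &TakeTrue, bool &TakeFalse) {
    TakeTrue = TakeFalse = false;
    if (C == CondUndefined) return;  // assume neither, until resolved
    if (C == CondOverdefined) { TakeTrue = TakeFalse = true; return; }
    (C == CondTrue ? TakeTrue : TakeFalse) = true;
  }

  int main() {
    bool T, F;
    feasibleSuccs(CondTrue, T, F);        assert(T && !F);
    feasibleSuccs(CondOverdefined, T, F); assert(T && F);
    feasibleSuccs(CondUndefined, T, F);   assert(!T && !F);  // optimistic
    return 0;
  }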
+ // + void getFeasibleSuccessors(TerminatorInst &TI, SmallVector<bool, 16> &Succs); + + // isEdgeFeasible - Return true if the control flow edge from the 'From' basic + // block to the 'To' basic block is currently feasible... + // + bool isEdgeFeasible(BasicBlock *From, BasicBlock *To); + + // OperandChangedState - This method is invoked on all of the users of an + // instruction that was just changed state somehow.... Based on this + // information, we need to update the specified user of this instruction. + // + void OperandChangedState(User *U) { + // Only instructions use other variable values! + Instruction &I = cast<Instruction>(*U); + if (BBExecutable.count(I.getParent())) // Inst is executable? + visit(I); + } + +private: + friend class InstVisitor<SCCPSolver>; + + // visit implementations - Something changed in this instruction... Either an + // operand made a transition, or the instruction is newly executable. Change + // the value type of I to reflect these changes if appropriate. + // + void visitPHINode(PHINode &I); + + // Terminators + void visitReturnInst(ReturnInst &I); + void visitTerminatorInst(TerminatorInst &TI); + + void visitCastInst(CastInst &I); + void visitSelectInst(SelectInst &I); + void visitBinaryOperator(Instruction &I); + void visitCmpInst(CmpInst &I); + void visitExtractElementInst(ExtractElementInst &I); + void visitInsertElementInst(InsertElementInst &I); + void visitShuffleVectorInst(ShuffleVectorInst &I); + void visitExtractValueInst(ExtractValueInst &EVI); + void visitInsertValueInst(InsertValueInst &IVI); + + // Instructions that cannot be folded away... + void visitStoreInst (Instruction &I); + void visitLoadInst (LoadInst &I); + void visitGetElementPtrInst(GetElementPtrInst &I); + void visitCallInst (CallInst &I) { visitCallSite(CallSite::get(&I)); } + void visitInvokeInst (InvokeInst &II) { + visitCallSite(CallSite::get(&II)); + visitTerminatorInst(II); + } + void visitCallSite (CallSite CS); + void visitUnwindInst (TerminatorInst &I) { /*returns void*/ } + void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ } + void visitAllocationInst(Instruction &I) { markOverdefined(&I); } + void visitVANextInst (Instruction &I) { markOverdefined(&I); } + void visitVAArgInst (Instruction &I) { markOverdefined(&I); } + void visitFreeInst (Instruction &I) { /*returns void*/ } + + void visitInstruction(Instruction &I) { + // If a new instruction is added to LLVM that we don't handle... + cerr << "SCCP: Don't know how to handle: " << I; + markOverdefined(&I); // Just in case + } +}; + +} // end anonymous namespace + + +// getFeasibleSuccessors - Return a vector of booleans to indicate which +// successors are reachable from a given terminator instruction. +// +void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, + SmallVector<bool, 16> &Succs) { + Succs.resize(TI.getNumSuccessors()); + if (BranchInst *BI = dyn_cast<BranchInst>(&TI)) { + if (BI->isUnconditional()) { + Succs[0] = true; + } else { + LatticeVal &BCValue = getValueState(BI->getCondition()); + if (BCValue.isOverdefined() || + (BCValue.isConstant() && !isa<ConstantInt>(BCValue.getConstant()))) { + // Overdefined condition variables, and branches on unfoldable constant + // conditions, mean the branch could go either way. 
+ Succs[0] = Succs[1] = true; + } else if (BCValue.isConstant()) { + // Constant condition variables mean the branch can only go a single way + Succs[BCValue.getConstant() == ConstantInt::getFalse()] = true; + } + } + } else if (isa<InvokeInst>(&TI)) { + // Invoke instructions successors are always executable. + Succs[0] = Succs[1] = true; + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(&TI)) { + LatticeVal &SCValue = getValueState(SI->getCondition()); + if (SCValue.isOverdefined() || // Overdefined condition? + (SCValue.isConstant() && !isa<ConstantInt>(SCValue.getConstant()))) { + // All destinations are executable! + Succs.assign(TI.getNumSuccessors(), true); + } else if (SCValue.isConstant()) + Succs[SI->findCaseValue(cast<ConstantInt>(SCValue.getConstant()))] = true; + } else { + assert(0 && "SCCP: Don't know how to handle this terminator!"); + } +} + + +// isEdgeFeasible - Return true if the control flow edge from the 'From' basic +// block to the 'To' basic block is currently feasible... +// +bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { + assert(BBExecutable.count(To) && "Dest should always be alive!"); + + // Make sure the source basic block is executable!! + if (!BBExecutable.count(From)) return false; + + // Check to make sure this edge itself is actually feasible now... + TerminatorInst *TI = From->getTerminator(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (BI->isUnconditional()) + return true; + else { + LatticeVal &BCValue = getValueState(BI->getCondition()); + if (BCValue.isOverdefined()) { + // Overdefined condition variables mean the branch could go either way. + return true; + } else if (BCValue.isConstant()) { + // Not branching on an evaluatable constant? + if (!isa<ConstantInt>(BCValue.getConstant())) return true; + + // Constant condition variables mean the branch can only go a single way + return BI->getSuccessor(BCValue.getConstant() == + ConstantInt::getFalse()) == To; + } + return false; + } + } else if (isa<InvokeInst>(TI)) { + // Invoke instructions successors are always executable. + return true; + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + LatticeVal &SCValue = getValueState(SI->getCondition()); + if (SCValue.isOverdefined()) { // Overdefined condition? + // All destinations are executable! + return true; + } else if (SCValue.isConstant()) { + Constant *CPV = SCValue.getConstant(); + if (!isa<ConstantInt>(CPV)) + return true; // not a foldable constant? + + // Make sure to skip the "default value" which isn't a value + for (unsigned i = 1, E = SI->getNumSuccessors(); i != E; ++i) + if (SI->getSuccessorValue(i) == CPV) // Found the taken branch... + return SI->getSuccessor(i) == To; + + // Constant value not equal to any of the branches... must execute + // default branch then... + return SI->getDefaultDest() == To; + } + return false; + } else { + cerr << "Unknown terminator instruction: " << *TI; + abort(); + } +} + +// visit Implementations - Something changed in this instruction... Either an +// operand made a transition, or the instruction is newly executable. Change +// the value type of I to reflect these changes if appropriate. This method +// makes sure to do the following actions: +// +// 1. If a phi node merges two constants in, and has conflicting value coming +// from different branches, or if the PHI node merges in an overdefined +// value, then the PHI node becomes overdefined. +// 2. 
If a phi node merges only constants in, and they all agree on value, the +// PHI node becomes a constant value equal to that. +// 3. If V <- x (op) y && isConstant(x) && isConstant(y) V = Constant +// 4. If V <- x (op) y && (isOverdefined(x) || isOverdefined(y)) V = Overdefined +// 5. If V <- MEM or V <- CALL or V <- (unknown) then V = Overdefined +// 6. If a conditional branch has a value that is constant, make the selected +// destination executable +// 7. If a conditional branch has a value that is overdefined, make all +// successors executable. +// +void SCCPSolver::visitPHINode(PHINode &PN) { + LatticeVal &PNIV = getValueState(&PN); + if (PNIV.isOverdefined()) { + // There may be instructions using this PHI node that are not overdefined + // themselves. If so, make sure that they know that the PHI node operand + // changed. + std::multimap<PHINode*, Instruction*>::iterator I, E; + tie(I, E) = UsersOfOverdefinedPHIs.equal_range(&PN); + if (I != E) { + SmallVector<Instruction*, 16> Users; + for (; I != E; ++I) Users.push_back(I->second); + while (!Users.empty()) { + visit(Users.back()); + Users.pop_back(); + } + } + return; // Quick exit + } + + // Super-extra-high-degree PHI nodes are unlikely to ever be marked constant, + // and slow us down a lot. Just mark them overdefined. + if (PN.getNumIncomingValues() > 64) { + markOverdefined(PNIV, &PN); + return; + } + + // Look at all of the executable operands of the PHI node. If any of them + // are overdefined, the PHI becomes overdefined as well. If they are all + // constant, and they agree with each other, the PHI becomes the identical + // constant. If they are constant and don't agree, the PHI is overdefined. + // If there are no executable operands, the PHI remains undefined. + // + Constant *OperandVal = 0; + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + LatticeVal &IV = getValueState(PN.getIncomingValue(i)); + if (IV.isUndefined()) continue; // Doesn't influence PHI node. + + if (isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent())) { + if (IV.isOverdefined()) { // PHI node becomes overdefined! + markOverdefined(&PN); + return; + } + + if (OperandVal == 0) { // Grab the first value... + OperandVal = IV.getConstant(); + } else { // Another value is being merged in! + // There is already a reachable operand. If we conflict with it, + // then the PHI node becomes overdefined. If we agree with it, we + // can continue on. + + // Check to see if there are two different constants merging... + if (IV.getConstant() != OperandVal) { + // Yes there is. This means the PHI node is not constant. + // You must be overdefined poor PHI. + // + markOverdefined(&PN); // The PHI node now becomes overdefined + return; // I'm done analyzing you + } + } + } + } + + // If we exited the loop, this means that the PHI node only has constant + // arguments that agree with each other(and OperandVal is the constant) or + // OperandVal is null because there are no defined incoming arguments. If + // this is the case, the PHI remains undefined. + // + if (OperandVal) + markConstant(&PN, OperandVal); // Acquire operand value +} + +void SCCPSolver::visitReturnInst(ReturnInst &I) { + if (I.getNumOperands() == 0) return; // Ret void + + Function *F = I.getParent()->getParent(); + // If we are tracking the return value of this function, merge it in. 
+ if (!F->hasLocalLinkage()) + return; + + if (!TrackedRetVals.empty() && I.getNumOperands() == 1) { + DenseMap<Function*, LatticeVal>::iterator TFRVI = + TrackedRetVals.find(F); + if (TFRVI != TrackedRetVals.end() && + !TFRVI->second.isOverdefined()) { + LatticeVal &IV = getValueState(I.getOperand(0)); + mergeInValue(TFRVI->second, F, IV); + return; + } + } + + // Handle functions that return multiple values. + if (!TrackedMultipleRetVals.empty() && I.getNumOperands() > 1) { + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { + DenseMap<std::pair<Function*, unsigned>, LatticeVal>::iterator + It = TrackedMultipleRetVals.find(std::make_pair(F, i)); + if (It == TrackedMultipleRetVals.end()) break; + mergeInValue(It->second, F, getValueState(I.getOperand(i))); + } + } else if (!TrackedMultipleRetVals.empty() && + I.getNumOperands() == 1 && + isa<StructType>(I.getOperand(0)->getType())) { + for (unsigned i = 0, e = I.getOperand(0)->getType()->getNumContainedTypes(); + i != e; ++i) { + DenseMap<std::pair<Function*, unsigned>, LatticeVal>::iterator + It = TrackedMultipleRetVals.find(std::make_pair(F, i)); + if (It == TrackedMultipleRetVals.end()) break; + Value *Val = FindInsertedValue(I.getOperand(0), i); + mergeInValue(It->second, F, getValueState(Val)); + } + } +} + +void SCCPSolver::visitTerminatorInst(TerminatorInst &TI) { + SmallVector<bool, 16> SuccFeasible; + getFeasibleSuccessors(TI, SuccFeasible); + + BasicBlock *BB = TI.getParent(); + + // Mark all feasible successors executable... + for (unsigned i = 0, e = SuccFeasible.size(); i != e; ++i) + if (SuccFeasible[i]) + markEdgeExecutable(BB, TI.getSuccessor(i)); +} + +void SCCPSolver::visitCastInst(CastInst &I) { + Value *V = I.getOperand(0); + LatticeVal &VState = getValueState(V); + if (VState.isOverdefined()) // Inherit overdefinedness of operand + markOverdefined(&I); + else if (VState.isConstant()) // Propagate constant value + markConstant(&I, ConstantExpr::getCast(I.getOpcode(), + VState.getConstant(), I.getType())); +} + +void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { + Value *Aggr = EVI.getAggregateOperand(); + + // If the operand to the extractvalue is an undef, the result is undef. + if (isa<UndefValue>(Aggr)) + return; + + // Currently only handle single-index extractvalues. + if (EVI.getNumIndices() != 1) { + markOverdefined(&EVI); + return; + } + + Function *F = 0; + if (CallInst *CI = dyn_cast<CallInst>(Aggr)) + F = CI->getCalledFunction(); + else if (InvokeInst *II = dyn_cast<InvokeInst>(Aggr)) + F = II->getCalledFunction(); + + // TODO: If IPSCCP resolves the callee of this function, we could propagate a + // result back! + if (F == 0 || TrackedMultipleRetVals.empty()) { + markOverdefined(&EVI); + return; + } + + // See if we are tracking the result of the callee. If not tracking this + // function (for example, it is a declaration) just move to overdefined. + if (!TrackedMultipleRetVals.count(std::make_pair(F, *EVI.idx_begin()))) { + markOverdefined(&EVI); + return; + } + + // Otherwise, the value will be merged in here as a result of CallSite + // handling. +} + +void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { + Value *Aggr = IVI.getAggregateOperand(); + Value *Val = IVI.getInsertedValueOperand(); + + // If the operands to the insertvalue are undef, the result is undef. + if (isa<UndefValue>(Aggr) && isa<UndefValue>(Val)) + return; + + // Currently only handle single-index insertvalues. 
+ if (IVI.getNumIndices() != 1) { + markOverdefined(&IVI); + return; + } + + // Currently only handle insertvalue instructions that are in a single-use + // chain that builds up a return value. + for (const InsertValueInst *TmpIVI = &IVI; ; ) { + if (!TmpIVI->hasOneUse()) { + markOverdefined(&IVI); + return; + } + const Value *V = *TmpIVI->use_begin(); + if (isa<ReturnInst>(V)) + break; + TmpIVI = dyn_cast<InsertValueInst>(V); + if (!TmpIVI) { + markOverdefined(&IVI); + return; + } + } + + // See if we are tracking the result of the callee. + Function *F = IVI.getParent()->getParent(); + DenseMap<std::pair<Function*, unsigned>, LatticeVal>::iterator + It = TrackedMultipleRetVals.find(std::make_pair(F, *IVI.idx_begin())); + + // Merge in the inserted member value. + if (It != TrackedMultipleRetVals.end()) + mergeInValue(It->second, F, getValueState(Val)); + + // Mark the aggregate result of the IVI overdefined; any tracking that we do + // will be done on the individual member values. + markOverdefined(&IVI); +} + +void SCCPSolver::visitSelectInst(SelectInst &I) { + LatticeVal &CondValue = getValueState(I.getCondition()); + if (CondValue.isUndefined()) + return; + if (CondValue.isConstant()) { + if (ConstantInt *CondCB = dyn_cast<ConstantInt>(CondValue.getConstant())){ + mergeInValue(&I, getValueState(CondCB->getZExtValue() ? I.getTrueValue() + : I.getFalseValue())); + return; + } + } + + // Otherwise, the condition is overdefined or a constant we can't evaluate. + // See if we can produce something better than overdefined based on the T/F + // value. + LatticeVal &TVal = getValueState(I.getTrueValue()); + LatticeVal &FVal = getValueState(I.getFalseValue()); + + // select ?, C, C -> C. + if (TVal.isConstant() && FVal.isConstant() && + TVal.getConstant() == FVal.getConstant()) { + markConstant(&I, FVal.getConstant()); + return; + } + + if (TVal.isUndefined()) { // select ?, undef, X -> X. + mergeInValue(&I, FVal); + } else if (FVal.isUndefined()) { // select ?, X, undef -> X. + mergeInValue(&I, TVal); + } else { + markOverdefined(&I); + } +} + +// Handle BinaryOperators and Shift Instructions... +void SCCPSolver::visitBinaryOperator(Instruction &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + LatticeVal &V1State = getValueState(I.getOperand(0)); + LatticeVal &V2State = getValueState(I.getOperand(1)); + + if (V1State.isOverdefined() || V2State.isOverdefined()) { + // If this is an AND or OR with 0 or -1, it doesn't matter that the other + // operand is overdefined. + if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Or) { + LatticeVal *NonOverdefVal = 0; + if (!V1State.isOverdefined()) { + NonOverdefVal = &V1State; + } else if (!V2State.isOverdefined()) { + NonOverdefVal = &V2State; + } + + if (NonOverdefVal) { + if (NonOverdefVal->isUndefined()) { + // Could annihilate value. 
+ if (I.getOpcode() == Instruction::And) + markConstant(IV, &I, Constant::getNullValue(I.getType())); + else if (const VectorType *PT = dyn_cast<VectorType>(I.getType())) + markConstant(IV, &I, ConstantVector::getAllOnesValue(PT)); + else + markConstant(IV, &I, ConstantInt::getAllOnesValue(I.getType())); + return; + } else { + if (I.getOpcode() == Instruction::And) { + if (NonOverdefVal->getConstant()->isNullValue()) { + markConstant(IV, &I, NonOverdefVal->getConstant()); + return; // X and 0 = 0 + } + } else { + if (ConstantInt *CI = + dyn_cast<ConstantInt>(NonOverdefVal->getConstant())) + if (CI->isAllOnesValue()) { + markConstant(IV, &I, NonOverdefVal->getConstant()); + return; // X or -1 = -1 + } + } + } + } + } + + + // If both operands are PHI nodes, it is possible that this instruction has + // a constant value, despite the fact that the PHI node doesn't. Check for + // this condition now. + if (PHINode *PN1 = dyn_cast<PHINode>(I.getOperand(0))) + if (PHINode *PN2 = dyn_cast<PHINode>(I.getOperand(1))) + if (PN1->getParent() == PN2->getParent()) { + // Since the two PHI nodes are in the same basic block, they must have + // entries for the same predecessors. Walk the predecessor list, and + // if all of the incoming values are constants, and the result of + // evaluating this expression with all incoming value pairs is the + // same, then this expression is a constant even though the PHI node + // is not a constant! + LatticeVal Result; + for (unsigned i = 0, e = PN1->getNumIncomingValues(); i != e; ++i) { + LatticeVal &In1 = getValueState(PN1->getIncomingValue(i)); + BasicBlock *InBlock = PN1->getIncomingBlock(i); + LatticeVal &In2 = + getValueState(PN2->getIncomingValueForBlock(InBlock)); + + if (In1.isOverdefined() || In2.isOverdefined()) { + Result.markOverdefined(); + break; // Cannot fold this operation over the PHI nodes! + } else if (In1.isConstant() && In2.isConstant()) { + Constant *V = ConstantExpr::get(I.getOpcode(), In1.getConstant(), + In2.getConstant()); + if (Result.isUndefined()) + Result.markConstant(V); + else if (Result.isConstant() && Result.getConstant() != V) { + Result.markOverdefined(); + break; + } + } + } + + // If we found a constant value here, then we know the instruction is + // constant despite the fact that the PHI nodes are overdefined. + if (Result.isConstant()) { + markConstant(IV, &I, Result.getConstant()); + // Remember that this instruction is virtually using the PHI node + // operands. + UsersOfOverdefinedPHIs.insert(std::make_pair(PN1, &I)); + UsersOfOverdefinedPHIs.insert(std::make_pair(PN2, &I)); + return; + } else if (Result.isUndefined()) { + return; + } + + // Okay, this really is overdefined now. Since we might have + // speculatively thought that this was not overdefined before, and + // added ourselves to the UsersOfOverdefinedPHIs list for the PHIs, + // make sure to clean out any entries that we put there, for + // efficiency. 
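The cleanup code that follows uses the erase-while-iterating multimap idiom, which is easy to get wrong. Here it is in isolation, with toy key/value types:

  // The multimap erase idiom used below: erase through a post-incremented
  // copy so the live iterator stays valid.
  #include <cassert>
  #include <map>
  #include <utility>

  int main() {
    std::multimap<int, char> M;
    M.insert(std::make_pair(1, 'a'));
    M.insert(std::make_pair(1, 'b'));
    M.insert(std::make_pair(1, 'a'));

    typedef std::multimap<int, char>::iterator Iter;
    std::pair<Iter, Iter> R = M.equal_range(1);
    for (Iter It = R.first; It != R.second; )
      if (It->second == 'a')
        M.erase(It++);  // advance past the node first, then erase the copy
      else
        ++It;

    assert(M.size() == 1 && M.begin()->second == 'b');
    return 0;
  }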
+ std::multimap<PHINode*, Instruction*>::iterator It, E; + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN1); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN2); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + } + + markOverdefined(IV, &I); + } else if (V1State.isConstant() && V2State.isConstant()) { + markConstant(IV, &I, ConstantExpr::get(I.getOpcode(), V1State.getConstant(), + V2State.getConstant())); + } +} + +// Handle ICmpInst instruction... +void SCCPSolver::visitCmpInst(CmpInst &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + LatticeVal &V1State = getValueState(I.getOperand(0)); + LatticeVal &V2State = getValueState(I.getOperand(1)); + + if (V1State.isOverdefined() || V2State.isOverdefined()) { + // If both operands are PHI nodes, it is possible that this instruction has + // a constant value, despite the fact that the PHI node doesn't. Check for + // this condition now. + if (PHINode *PN1 = dyn_cast<PHINode>(I.getOperand(0))) + if (PHINode *PN2 = dyn_cast<PHINode>(I.getOperand(1))) + if (PN1->getParent() == PN2->getParent()) { + // Since the two PHI nodes are in the same basic block, they must have + // entries for the same predecessors. Walk the predecessor list, and + // if all of the incoming values are constants, and the result of + // evaluating this expression with all incoming value pairs is the + // same, then this expression is a constant even though the PHI node + // is not a constant! + LatticeVal Result; + for (unsigned i = 0, e = PN1->getNumIncomingValues(); i != e; ++i) { + LatticeVal &In1 = getValueState(PN1->getIncomingValue(i)); + BasicBlock *InBlock = PN1->getIncomingBlock(i); + LatticeVal &In2 = + getValueState(PN2->getIncomingValueForBlock(InBlock)); + + if (In1.isOverdefined() || In2.isOverdefined()) { + Result.markOverdefined(); + break; // Cannot fold this operation over the PHI nodes! + } else if (In1.isConstant() && In2.isConstant()) { + Constant *V = ConstantExpr::getCompare(I.getPredicate(), + In1.getConstant(), + In2.getConstant()); + if (Result.isUndefined()) + Result.markConstant(V); + else if (Result.isConstant() && Result.getConstant() != V) { + Result.markOverdefined(); + break; + } + } + } + + // If we found a constant value here, then we know the instruction is + // constant despite the fact that the PHI nodes are overdefined. + if (Result.isConstant()) { + markConstant(IV, &I, Result.getConstant()); + // Remember that this instruction is virtually using the PHI node + // operands. + UsersOfOverdefinedPHIs.insert(std::make_pair(PN1, &I)); + UsersOfOverdefinedPHIs.insert(std::make_pair(PN2, &I)); + return; + } else if (Result.isUndefined()) { + return; + } + + // Okay, this really is overdefined now. Since we might have + // speculatively thought that this was not overdefined before, and + // added ourselves to the UsersOfOverdefinedPHIs list for the PHIs, + // make sure to clean out any entries that we put there, for + // efficiency. 
+ std::multimap<PHINode*, Instruction*>::iterator It, E; + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN1); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + tie(It, E) = UsersOfOverdefinedPHIs.equal_range(PN2); + while (It != E) { + if (It->second == &I) { + UsersOfOverdefinedPHIs.erase(It++); + } else + ++It; + } + } + + markOverdefined(IV, &I); + } else if (V1State.isConstant() && V2State.isConstant()) { + markConstant(IV, &I, ConstantExpr::getCompare(I.getPredicate(), + V1State.getConstant(), + V2State.getConstant())); + } +} + +void SCCPSolver::visitExtractElementInst(ExtractElementInst &I) { + // FIXME : SCCP does not handle vectors properly. + markOverdefined(&I); + return; + +#if 0 + LatticeVal &ValState = getValueState(I.getOperand(0)); + LatticeVal &IdxState = getValueState(I.getOperand(1)); + + if (ValState.isOverdefined() || IdxState.isOverdefined()) + markOverdefined(&I); + else if(ValState.isConstant() && IdxState.isConstant()) + markConstant(&I, ConstantExpr::getExtractElement(ValState.getConstant(), + IdxState.getConstant())); +#endif +} + +void SCCPSolver::visitInsertElementInst(InsertElementInst &I) { + // FIXME : SCCP does not handle vectors properly. + markOverdefined(&I); + return; +#if 0 + LatticeVal &ValState = getValueState(I.getOperand(0)); + LatticeVal &EltState = getValueState(I.getOperand(1)); + LatticeVal &IdxState = getValueState(I.getOperand(2)); + + if (ValState.isOverdefined() || EltState.isOverdefined() || + IdxState.isOverdefined()) + markOverdefined(&I); + else if(ValState.isConstant() && EltState.isConstant() && + IdxState.isConstant()) + markConstant(&I, ConstantExpr::getInsertElement(ValState.getConstant(), + EltState.getConstant(), + IdxState.getConstant())); + else if (ValState.isUndefined() && EltState.isConstant() && + IdxState.isConstant()) + markConstant(&I,ConstantExpr::getInsertElement(UndefValue::get(I.getType()), + EltState.getConstant(), + IdxState.getConstant())); +#endif +} + +void SCCPSolver::visitShuffleVectorInst(ShuffleVectorInst &I) { + // FIXME : SCCP does not handle vectors properly. + markOverdefined(&I); + return; +#if 0 + LatticeVal &V1State = getValueState(I.getOperand(0)); + LatticeVal &V2State = getValueState(I.getOperand(1)); + LatticeVal &MaskState = getValueState(I.getOperand(2)); + + if (MaskState.isUndefined() || + (V1State.isUndefined() && V2State.isUndefined())) + return; // Undefined output if mask or both inputs undefined. + + if (V1State.isOverdefined() || V2State.isOverdefined() || + MaskState.isOverdefined()) { + markOverdefined(&I); + } else { + // A mix of constant/undef inputs. + Constant *V1 = V1State.isConstant() ? + V1State.getConstant() : UndefValue::get(I.getType()); + Constant *V2 = V2State.isConstant() ? + V2State.getConstant() : UndefValue::get(I.getType()); + Constant *Mask = MaskState.isConstant() ? + MaskState.getConstant() : UndefValue::get(I.getOperand(2)->getType()); + markConstant(&I, ConstantExpr::getShuffleVector(V1, V2, Mask)); + } +#endif +} + +// Handle getelementptr instructions... if all operands are constants then we +// can turn this into a getelementptr ConstantExpr. 
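The folding performed by the next function amounts to compile-time address arithmetic. A standalone sketch of that arithmetic, with an assumed [4 x i32]-style layout (the sizes are illustrative, not taken from any target):

  // What folding a GEP with all-constant operands computes: a fixed
  // byte offset, known before the program ever runs.
  #include <cassert>
  #include <cstddef>

  int main() {
    const size_t ElemSize = 4;              // sizeof(i32), assumed
    const size_t ArraySize = 4 * ElemSize;  // sizeof([4 x i32])
    // getelementptr [4 x i32]* %p, i32 1, i32 2
    size_t Offset = 1 * ArraySize + 2 * ElemSize;
    assert(Offset == 24);                   // a constant; no runtime work
    return 0;
  }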
+// +void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + SmallVector<Constant*, 8> Operands; + Operands.reserve(I.getNumOperands()); + + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { + LatticeVal &State = getValueState(I.getOperand(i)); + if (State.isUndefined()) + return; // Operands are not resolved yet... + else if (State.isOverdefined()) { + markOverdefined(IV, &I); + return; + } + assert(State.isConstant() && "Unknown state!"); + Operands.push_back(State.getConstant()); + } + + Constant *Ptr = Operands[0]; + Operands.erase(Operands.begin()); // Erase the pointer from idx list... + + markConstant(IV, &I, ConstantExpr::getGetElementPtr(Ptr, &Operands[0], + Operands.size())); +} + +void SCCPSolver::visitStoreInst(Instruction &SI) { + if (TrackedGlobals.empty() || !isa<GlobalVariable>(SI.getOperand(1))) + return; + GlobalVariable *GV = cast<GlobalVariable>(SI.getOperand(1)); + DenseMap<GlobalVariable*, LatticeVal>::iterator I = TrackedGlobals.find(GV); + if (I == TrackedGlobals.end() || I->second.isOverdefined()) return; + + // Get the value we are storing into the global. + LatticeVal &PtrVal = getValueState(SI.getOperand(0)); + + mergeInValue(I->second, GV, PtrVal); + if (I->second.isOverdefined()) + TrackedGlobals.erase(I); // No need to keep tracking this! +} + + +// Handle load instructions. If the operand is a constant pointer to a constant +// global, we can replace the load with the loaded constant value! +void SCCPSolver::visitLoadInst(LoadInst &I) { + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + LatticeVal &PtrVal = getValueState(I.getOperand(0)); + if (PtrVal.isUndefined()) return; // The pointer is not resolved yet! + if (PtrVal.isConstant() && !I.isVolatile()) { + Value *Ptr = PtrVal.getConstant(); + // TODO: Consider a target hook for valid address spaces for this xform. + if (isa<ConstantPointerNull>(Ptr) && + cast<PointerType>(Ptr->getType())->getAddressSpace() == 0) { + // load null -> null + markConstant(IV, &I, Constant::getNullValue(I.getType())); + return; + } + + // Transform load (constant global) into the value loaded. + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) { + if (GV->isConstant()) { + if (GV->hasDefinitiveInitializer()) { + markConstant(IV, &I, GV->getInitializer()); + return; + } + } else if (!TrackedGlobals.empty()) { + // If we are tracking this global, merge in the known value for it. + DenseMap<GlobalVariable*, LatticeVal>::iterator It = + TrackedGlobals.find(GV); + if (It != TrackedGlobals.end()) { + mergeInValue(IV, &I, It->second); + return; + } + } + } + + // Transform load (constantexpr_GEP global, 0, ...) into the value loaded. + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) + if (CE->getOpcode() == Instruction::GetElementPtr) + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(CE->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer()) + if (Constant *V = + ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE)) { + markConstant(IV, &I, V); + return; + } + } + + // Otherwise we cannot say for certain what value this load will produce. + // Bail out. + markOverdefined(IV, &I); +} + +void SCCPSolver::visitCallSite(CallSite CS) { + Function *F = CS.getCalledFunction(); + Instruction *I = CS.getInstruction(); + + // The common case is that we aren't tracking the callee, either because we + // are not doing interprocedural analysis or the callee is indirect, or is + // external. 
Handle these cases first. + if (F == 0 || !F->hasLocalLinkage()) { +CallOverdefined: + // Void return and not tracking callee, just bail. + if (I->getType() == Type::VoidTy) return; + + // Otherwise, if we have a single return value case, and if the function is + // a declaration, maybe we can constant fold it. + if (!isa<StructType>(I->getType()) && F && F->isDeclaration() && + canConstantFoldCallTo(F)) { + + SmallVector<Constant*, 8> Operands; + for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); + AI != E; ++AI) { + LatticeVal &State = getValueState(*AI); + if (State.isUndefined()) + return; // Operands are not resolved yet. + else if (State.isOverdefined()) { + markOverdefined(I); + return; + } + assert(State.isConstant() && "Unknown state!"); + Operands.push_back(State.getConstant()); + } + + // If we can constant fold this, mark the result of the call as a + // constant. + if (Constant *C = ConstantFoldCall(F, Operands.data(), Operands.size())) { + markConstant(I, C); + return; + } + } + + // Otherwise, we don't know anything about this call, mark it overdefined. + markOverdefined(I); + return; + } + + // If this is a single/zero retval case, see if we're tracking the function. + DenseMap<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F); + if (TFRVI != TrackedRetVals.end()) { + // If so, propagate the return value of the callee into this call result. + mergeInValue(I, TFRVI->second); + } else if (isa<StructType>(I->getType())) { + // Check to see if we're tracking this callee, if not, handle it in the + // common path above. + DenseMap<std::pair<Function*, unsigned>, LatticeVal>::iterator + TMRVI = TrackedMultipleRetVals.find(std::make_pair(F, 0)); + if (TMRVI == TrackedMultipleRetVals.end()) + goto CallOverdefined; + + // If we are tracking this callee, propagate the return values of the call + // into this call site. We do this by walking all the uses. Single-index + // ExtractValueInst uses can be tracked; anything more complicated is + // currently handled conservatively. + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) { + if (ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(*UI)) { + if (EVI->getNumIndices() == 1) { + mergeInValue(EVI, + TrackedMultipleRetVals[std::make_pair(F, *EVI->idx_begin())]); + continue; + } + } + // The aggregate value is used in a way not handled here. Assume nothing. + markOverdefined(*UI); + } + } else { + // Otherwise we're not tracking this callee, so handle it in the + // common path above. + goto CallOverdefined; + } + + // Finally, if this is the first call to the function hit, mark its entry + // block executable. + if (!BBExecutable.count(F->begin())) + MarkBlockExecutable(F->begin()); + + // Propagate information from this call site into the callee. + CallSite::arg_iterator CAI = CS.arg_begin(); + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI, ++CAI) { + LatticeVal &IV = ValueState[AI]; + if (!IV.isOverdefined()) + mergeInValue(IV, AI, getValueState(*CAI)); + } +} + + +void SCCPSolver::Solve() { + // Process the work lists until they are empty! + while (!BBWorkList.empty() || !InstWorkList.empty() || + !OverdefinedInstWorkList.empty()) { + // Process the instruction work list... 
+    while (!OverdefinedInstWorkList.empty()) {
+      Value *I = OverdefinedInstWorkList.back();
+      OverdefinedInstWorkList.pop_back();
+
+      DOUT << "\nPopped off OI-WL: " << *I;
+
+      // "I" got into this work list because it made the transition to
+      // overdefined.
+      //
+      // Anything on this worklist that is overdefined need not be visited
+      // since all of its users will have already been marked as overdefined.
+      // Update all of the users of this instruction's value...
+      //
+      for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
+           UI != E; ++UI)
+        OperandChangedState(*UI);
+    }
+    // Process the instruction work list...
+    while (!InstWorkList.empty()) {
+      Value *I = InstWorkList.back();
+      InstWorkList.pop_back();
+
+      DOUT << "\nPopped off I-WL: " << *I;
+
+      // "I" got into this work list because it made the transition from
+      // undefined to constant.
+      //
+      // Anything on this worklist that is overdefined need not be visited
+      // since all of its users will have already been marked as overdefined.
+      // Update all of the users of this instruction's value...
+      //
+      if (!getValueState(I).isOverdefined())
+        for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
+             UI != E; ++UI)
+          OperandChangedState(*UI);
+    }
+
+    // Process the basic block work list...
+    while (!BBWorkList.empty()) {
+      BasicBlock *BB = BBWorkList.back();
+      BBWorkList.pop_back();
+
+      DOUT << "\nPopped off BBWL: " << *BB;
+
+      // Notify all instructions in this basic block that they are newly
+      // executable.
+      visit(BB);
+    }
+  }
+}
+
+/// ResolvedUndefsIn - While solving the dataflow for a function, we assume
+/// that branches on undef values cannot reach any of their successors.
+/// However, this is not a safe assumption.  After we solve dataflow, this
+/// method should be used to handle this.  If this returns true, the solver
+/// should be rerun.
+///
+/// This method handles this by finding an unresolved branch and marking one of
+/// the edges from the block as being feasible, even though the condition
+/// doesn't say it would otherwise be.  This allows SCCP to find the rest of the
+/// CFG and only slightly pessimizes the analysis results (by marking one,
+/// potentially infeasible, edge feasible).  This cannot usefully modify the
+/// constraints on the condition of the branch, as that would impact other users
+/// of the value.
+///
+/// This scan also checks for values that use undefs, whose results are actually
+/// defined.  For example, 'zext i8 undef to i32' should produce all zeros
+/// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero,
+/// even if X isn't defined.
+bool SCCPSolver::ResolvedUndefsIn(Function &F) {
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
+    if (!BBExecutable.count(BB))
+      continue;
+
+    for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
+      // Look for instructions which produce undef values.
+      if (I->getType() == Type::VoidTy) continue;
+
+      LatticeVal &LV = getValueState(I);
+      if (!LV.isUndefined()) continue;
+
+      // Get the lattice values of the first two operands for use below.
+      LatticeVal &Op0LV = getValueState(I->getOperand(0));
+      LatticeVal Op1LV;
+      if (I->getNumOperands() == 2) {
+        // If this is a two-operand instruction, and if both operands are
+        // undefs, the result stays undef.
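The ZExt rule applied in the switch below rests on a fact that is cheap to verify exhaustively: every zero extension of an i8 has clear high bits, so forcing 'zext i8 undef' to zero can never be observed to be wrong. A standalone check of that claim:

  // Exhaustive check: after zext i8 -> i32, bits 8..31 are always zero.
  #include <cassert>

  int main() {
    for (unsigned v = 0; v != 256; ++v) {
      unsigned z = static_cast<unsigned char>(v);  // models zext i8 -> i32
      assert((z & 0xFFFFFF00u) == 0);              // high bits provably zero
    }
    return 0;
  }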
+ Op1LV = getValueState(I->getOperand(1)); + if (Op0LV.isUndefined() && Op1LV.isUndefined()) + continue; + } + + // If this is an instructions whose result is defined even if the input is + // not fully defined, propagate the information. + const Type *ITy = I->getType(); + switch (I->getOpcode()) { + default: break; // Leave the instruction as an undef. + case Instruction::ZExt: + // After a zero extend, we know the top part is zero. SExt doesn't have + // to be handled here, because we don't know whether the top part is 1's + // or 0's. + assert(Op0LV.isUndefined()); + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + case Instruction::Mul: + case Instruction::And: + // undef * X -> 0. X could be zero. + // undef & X -> 0. X could be zero. + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + + case Instruction::Or: + // undef | X -> -1. X could be -1. + if (const VectorType *PTy = dyn_cast<VectorType>(ITy)) + markForcedConstant(LV, I, ConstantVector::getAllOnesValue(PTy)); + else + markForcedConstant(LV, I, ConstantInt::getAllOnesValue(ITy)); + return true; + + case Instruction::SDiv: + case Instruction::UDiv: + case Instruction::SRem: + case Instruction::URem: + // X / undef -> undef. No change. + // X % undef -> undef. No change. + if (Op1LV.isUndefined()) break; + + // undef / X -> 0. X could be maxint. + // undef % X -> 0. X could be 1. + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + + case Instruction::AShr: + // undef >>s X -> undef. No change. + if (Op0LV.isUndefined()) break; + + // X >>s undef -> X. X could be 0, X could have the high-bit known set. + if (Op0LV.isConstant()) + markForcedConstant(LV, I, Op0LV.getConstant()); + else + markOverdefined(LV, I); + return true; + case Instruction::LShr: + case Instruction::Shl: + // undef >> X -> undef. No change. + // undef << X -> undef. No change. + if (Op0LV.isUndefined()) break; + + // X >> undef -> 0. X could be 0. + // X << undef -> 0. X could be 0. + markForcedConstant(LV, I, Constant::getNullValue(ITy)); + return true; + case Instruction::Select: + // undef ? X : Y -> X or Y. There could be commonality between X/Y. + if (Op0LV.isUndefined()) { + if (!Op1LV.isConstant()) // Pick the constant one if there is any. + Op1LV = getValueState(I->getOperand(2)); + } else if (Op1LV.isUndefined()) { + // c ? undef : undef -> undef. No change. + Op1LV = getValueState(I->getOperand(2)); + if (Op1LV.isUndefined()) + break; + // Otherwise, c ? undef : x -> x. + } else { + // Leave Op1LV as Operand(1)'s LatticeValue. + } + + if (Op1LV.isConstant()) + markForcedConstant(LV, I, Op1LV.getConstant()); + else + markOverdefined(LV, I); + return true; + case Instruction::Call: + // If a call has an undef result, it is because it is constant foldable + // but one of the inputs was undef. Just force the result to + // overdefined. + markOverdefined(LV, I); + return true; + } + } + + TerminatorInst *TI = BB->getTerminator(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (!BI->isConditional()) continue; + if (!getValueState(BI->getCondition()).isUndefined()) + continue; + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + if (SI->getNumSuccessors()<2) // no cases + continue; + if (!getValueState(SI->getCondition()).isUndefined()) + continue; + } else { + continue; + } + + // If the edge to the second successor isn't thought to be feasible yet, + // mark it so now. 
We pick the second one so that this goes to some
+ // enumerated value in a switch instead of going to the default destination.
+ if (KnownFeasibleEdges.count(Edge(BB, TI->getSuccessor(1))))
+ continue;
+
+ // Otherwise, it isn't already thought to be feasible. Mark it as such now
+ // and return. This will make other blocks reachable, which will allow new
+ // values to be discovered and existing ones to be moved in the lattice.
+ markEdgeExecutable(BB, TI->getSuccessor(1));
+
+ // This must be a conditional branch or switch on undef. At this point,
+ // force the old terminator to take the edge we just marked feasible (the
+ // second successor). This is required because we are now influencing the
+ // dataflow of the function with the assumption that this edge is taken.
+ // If we leave the branch condition as undef, then further analysis could
+ // think the undef went another way, leading to an inconsistent set of
+ // conclusions.
+ if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+ BI->setCondition(ConstantInt::getFalse());
+ } else {
+ SwitchInst *SI = cast<SwitchInst>(TI);
+ SI->setCondition(SI->getCaseValue(1));
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
+
+namespace {
+ //===--------------------------------------------------------------------===//
+ //
+ /// SCCP Class - This class uses the SCCPSolver to implement a per-function
+ /// Sparse Conditional Constant Propagator.
+ ///
+ struct VISIBILITY_HIDDEN SCCP : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ SCCP() : FunctionPass(&ID) {}
+
+ // runOnFunction - Run the Sparse Conditional Constant Propagation
+ // algorithm, and return true if the function was modified.
+ //
+ bool runOnFunction(Function &F);
+
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesCFG();
+ }
+ };
+} // end anonymous namespace
+
+char SCCP::ID = 0;
+static RegisterPass<SCCP>
+X("sccp", "Sparse Conditional Constant Propagation");
+
+// createSCCPPass - This is the public interface to this file...
+FunctionPass *llvm::createSCCPPass() {
+ return new SCCP();
+}
+
+
+// runOnFunction() - Run the Sparse Conditional Constant Propagation algorithm,
+// and return true if the function was modified.
+//
+bool SCCP::runOnFunction(Function &F) {
+ DOUT << "SCCP on function '" << F.getNameStart() << "'\n";
+ SCCPSolver Solver;
+
+ // Mark the first block of the function as being executable.
+ Solver.MarkBlockExecutable(F.begin());
+
+ // Mark all arguments to the function as being overdefined.
+ for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E;++AI)
+ Solver.markOverdefined(AI);
+
+ // Solve for constants.
+ bool ResolvedUndefs = true;
+ while (ResolvedUndefs) {
+ Solver.Solve();
+ DOUT << "RESOLVING UNDEFs\n";
+ ResolvedUndefs = Solver.ResolvedUndefsIn(F);
+ }
+
+ bool MadeChanges = false;
+
+ // If we decided that there are basic blocks that are dead in this function,
+ // delete their contents now. Note that we cannot actually delete the blocks,
+ // as we cannot modify the CFG of the function.
+ //
+ SmallVector<Instruction*, 512> Insts;
+ std::map<Value*, LatticeVal> &Values = Solver.getValueMapping();
+
+ for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+ if (!Solver.isBlockExecutable(BB)) {
+ DOUT << " BasicBlock Dead:" << *BB;
+ ++NumDeadBlocks;
+
+ // Delete the instructions backwards; this reduces the number of def-use
+ // and use-def chains that have to be updated.
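+ // (Illustrative note: the RAUW-with-undef below is needed because an
+ // instruction in a dead block may still be referenced, e.g. by a PHI node
+ // in a live successor, since this pass is not allowed to restructure the
+ // CFG.)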
+ for (BasicBlock::iterator I = BB->begin(), E = BB->getTerminator(); + I != E; ++I) + Insts.push_back(I); + while (!Insts.empty()) { + Instruction *I = Insts.back(); + Insts.pop_back(); + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + BB->getInstList().erase(I); + MadeChanges = true; + ++NumInstRemoved; + } + } else { + // Iterate over all of the instructions in a function, replacing them with + // constants if we have found them to be of constant values. + // + for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { + Instruction *Inst = BI++; + if (Inst->getType() == Type::VoidTy || + isa<TerminatorInst>(Inst)) + continue; + + LatticeVal &IV = Values[Inst]; + if (!IV.isConstant() && !IV.isUndefined()) + continue; + + Constant *Const = IV.isConstant() + ? IV.getConstant() : UndefValue::get(Inst->getType()); + DOUT << " Constant: " << *Const << " = " << *Inst; + + // Replaces all of the uses of a variable with uses of the constant. + Inst->replaceAllUsesWith(Const); + + // Delete the instruction. + Inst->eraseFromParent(); + + // Hey, we just changed something! + MadeChanges = true; + ++NumInstRemoved; + } + } + + return MadeChanges; +} + +namespace { + //===--------------------------------------------------------------------===// + // + /// IPSCCP Class - This class implements interprocedural Sparse Conditional + /// Constant Propagation. + /// + struct VISIBILITY_HIDDEN IPSCCP : public ModulePass { + static char ID; + IPSCCP() : ModulePass(&ID) {} + bool runOnModule(Module &M); + }; +} // end anonymous namespace + +char IPSCCP::ID = 0; +static RegisterPass<IPSCCP> +Y("ipsccp", "Interprocedural Sparse Conditional Constant Propagation"); + +// createIPSCCPPass - This is the public interface to this file... +ModulePass *llvm::createIPSCCPPass() { + return new IPSCCP(); +} + + +static bool AddressIsTaken(GlobalValue *GV) { + // Delete any dead constantexpr klingons. + GV->removeDeadConstantUsers(); + + for (Value::use_iterator UI = GV->use_begin(), E = GV->use_end(); + UI != E; ++UI) + if (StoreInst *SI = dyn_cast<StoreInst>(*UI)) { + if (SI->getOperand(0) == GV || SI->isVolatile()) + return true; // Storing addr of GV. + } else if (isa<InvokeInst>(*UI) || isa<CallInst>(*UI)) { + // Make sure we are calling the function, not passing the address. + CallSite CS = CallSite::get(cast<Instruction>(*UI)); + if (CS.hasArgument(GV)) + return true; + } else if (LoadInst *LI = dyn_cast<LoadInst>(*UI)) { + if (LI->isVolatile()) + return true; + } else { + return true; + } + return false; +} + +bool IPSCCP::runOnModule(Module &M) { + SCCPSolver Solver; + + // Loop over all functions, marking arguments to those with their addresses + // taken or that are external as overdefined. + // + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + if (!F->hasLocalLinkage() || AddressIsTaken(F)) { + if (!F->isDeclaration()) + Solver.MarkBlockExecutable(F->begin()); + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI) + Solver.markOverdefined(AI); + } else { + Solver.AddTrackedFunction(F); + } + + // Loop over global variables. We inform the solver about any internal global + // variables that do not have their 'addresses taken'. If they don't have + // their addresses taken, we can propagate constants through them. 
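+ // For example (illustrative): a module-private counter such as
+ //     @cnt = internal global i32 0
+ // that is only ever loaded from and stored to qualifies, while a global
+ // whose address is passed to a call does not (AddressIsTaken above
+ // rejects it).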
+ for (Module::global_iterator G = M.global_begin(), E = M.global_end(); + G != E; ++G) + if (!G->isConstant() && G->hasLocalLinkage() && !AddressIsTaken(G)) + Solver.TrackValueOfGlobalVariable(G); + + // Solve for constants. + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + Solver.Solve(); + + DOUT << "RESOLVING UNDEFS\n"; + ResolvedUndefs = false; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + ResolvedUndefs |= Solver.ResolvedUndefsIn(*F); + } + + bool MadeChanges = false; + + // Iterate over all of the instructions in the module, replacing them with + // constants if we have found them to be of constant values. + // + SmallVector<Instruction*, 512> Insts; + SmallVector<BasicBlock*, 512> BlocksToErase; + std::map<Value*, LatticeVal> &Values = Solver.getValueMapping(); + + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI) + if (!AI->use_empty()) { + LatticeVal &IV = Values[AI]; + if (IV.isConstant() || IV.isUndefined()) { + Constant *CST = IV.isConstant() ? + IV.getConstant() : UndefValue::get(AI->getType()); + DOUT << "*** Arg " << *AI << " = " << *CST <<"\n"; + + // Replaces all of the uses of a variable with uses of the + // constant. + AI->replaceAllUsesWith(CST); + ++IPNumArgsElimed; + } + } + + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + if (!Solver.isBlockExecutable(BB)) { + DOUT << " BasicBlock Dead:" << *BB; + ++IPNumDeadBlocks; + + // Delete the instructions backwards, as it has a reduced likelihood of + // having to update as many def-use and use-def chains. + TerminatorInst *TI = BB->getTerminator(); + for (BasicBlock::iterator I = BB->begin(), E = TI; I != E; ++I) + Insts.push_back(I); + + while (!Insts.empty()) { + Instruction *I = Insts.back(); + Insts.pop_back(); + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + BB->getInstList().erase(I); + MadeChanges = true; + ++IPNumInstRemoved; + } + + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + BasicBlock *Succ = TI->getSuccessor(i); + if (!Succ->empty() && isa<PHINode>(Succ->begin())) + TI->getSuccessor(i)->removePredecessor(BB); + } + if (!TI->use_empty()) + TI->replaceAllUsesWith(UndefValue::get(TI->getType())); + BB->getInstList().erase(TI); + + if (&*BB != &F->front()) + BlocksToErase.push_back(BB); + else + new UnreachableInst(BB); + + } else { + for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { + Instruction *Inst = BI++; + if (Inst->getType() == Type::VoidTy) + continue; + + LatticeVal &IV = Values[Inst]; + if (!IV.isConstant() && !IV.isUndefined()) + continue; + + Constant *Const = IV.isConstant() + ? IV.getConstant() : UndefValue::get(Inst->getType()); + DOUT << " Constant: " << *Const << " = " << *Inst; + + // Replaces all of the uses of a variable with uses of the + // constant. + Inst->replaceAllUsesWith(Const); + + // Delete the instruction. + if (!isa<CallInst>(Inst) && !isa<TerminatorInst>(Inst)) + Inst->eraseFromParent(); + + // Hey, we just changed something! + MadeChanges = true; + ++IPNumInstRemoved; + } + } + + // Now that all instructions in the function are constant folded, erase dead + // blocks, because we can now use ConstantFoldTerminator to get rid of + // in-edges. + for (unsigned i = 0, e = BlocksToErase.size(); i != e; ++i) { + // If there are any PHI nodes in this successor, drop entries for BB now. 
+ BasicBlock *DeadBB = BlocksToErase[i];
+ while (!DeadBB->use_empty()) {
+ Instruction *I = cast<Instruction>(DeadBB->use_back());
+ bool Folded = ConstantFoldTerminator(I->getParent());
+ if (!Folded) {
+ // The constant folder may not have been able to fold the terminator
+ // if this is a branch or switch on undef. Fold it manually as a
+ // branch to the first successor.
+#ifndef NDEBUG
+ if (BranchInst *BI = dyn_cast<BranchInst>(I)) {
+ assert(BI->isConditional() && isa<UndefValue>(BI->getCondition()) &&
+ "Branch should be foldable!");
+ } else if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) {
+ assert(isa<UndefValue>(SI->getCondition()) && "Switch should fold");
+ } else {
+ assert(0 && "Didn't fold away reference to block!");
+ }
+#endif
+
+ // Make this an uncond branch to the first successor.
+ TerminatorInst *TI = I->getParent()->getTerminator();
+ BranchInst::Create(TI->getSuccessor(0), TI);
+
+ // Remove entries in successor phi nodes to remove edges.
+ for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i)
+ TI->getSuccessor(i)->removePredecessor(TI->getParent());
+
+ // Remove the old terminator.
+ TI->eraseFromParent();
+ }
+ }
+
+ // Finally, delete the basic block.
+ F->getBasicBlockList().erase(DeadBB);
+ }
+ BlocksToErase.clear();
+ }
+
+ // If we inferred constant or undef return values for a function, we replaced
+ // all call uses with the inferred value. This means we don't need to bother
+ // actually returning anything from the function. Replace all return
+ // instructions with return undef.
+ // TODO: Process multiple value ret instructions also.
+ const DenseMap<Function*, LatticeVal> &RV = Solver.getTrackedRetVals();
+ for (DenseMap<Function*, LatticeVal>::const_iterator I = RV.begin(),
+ E = RV.end(); I != E; ++I)
+ if (!I->second.isOverdefined() &&
+ I->first->getReturnType() != Type::VoidTy) {
+ Function *F = I->first;
+ for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB)
+ if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator()))
+ if (!isa<UndefValue>(RI->getOperand(0)))
+ RI->setOperand(0, UndefValue::get(F->getReturnType()));
+ }
+
+ // If we inferred constant or undef values for global variables, we can delete
+ // the global and any stores that remain to it.
+ const DenseMap<GlobalVariable*, LatticeVal> &TG = Solver.getTrackedGlobals();
+ for (DenseMap<GlobalVariable*, LatticeVal>::const_iterator I = TG.begin(),
+ E = TG.end(); I != E; ++I) {
+ GlobalVariable *GV = I->first;
+ assert(!I->second.isOverdefined() &&
+ "Overdefined values should have been taken out of the map!");
+ DOUT << "Found that GV '" << GV->getNameStart() << "' is constant!\n";
+ while (!GV->use_empty()) {
+ StoreInst *SI = cast<StoreInst>(GV->use_back());
+ SI->eraseFromParent();
+ }
+ M.getGlobalList().erase(GV);
+ ++IPNumGlobalConst;
+ }
+
+ return MadeChanges;
+}
diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp
new file mode 100644
index 0000000..5669da0
--- /dev/null
+++ b/lib/Transforms/Scalar/Scalar.cpp
@@ -0,0 +1,111 @@
+//===-- Scalar.cpp --------------------------------------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the C bindings for libLLVMScalarOpts.a, which implements
+// several scalar transformations over the LLVM intermediate representation.
+// +//===----------------------------------------------------------------------===// + +#include "llvm-c/Transforms/Scalar.h" +#include "llvm/PassManager.h" +#include "llvm/Transforms/Scalar.h" + +using namespace llvm; + +void LLVMAddAggressiveDCEPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createAggressiveDCEPass()); +} + +void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCFGSimplificationPass()); +} + +void LLVMAddCondPropagationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCondPropagationPass()); +} + +void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createDeadStoreEliminationPass()); +} + +void LLVMAddGVNPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createGVNPass()); +} + +void LLVMAddIndVarSimplifyPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createIndVarSimplifyPass()); +} + +void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createInstructionCombiningPass()); +} + +void LLVMAddJumpThreadingPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createJumpThreadingPass()); +} + +void LLVMAddLICMPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLICMPass()); +} + +void LLVMAddLoopDeletionPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopDeletionPass()); +} + +void LLVMAddLoopIndexSplitPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopIndexSplitPass()); +} + +void LLVMAddLoopRotatePass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopRotatePass()); +} + +void LLVMAddLoopUnrollPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopUnrollPass()); +} + +void LLVMAddLoopUnswitchPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopUnswitchPass()); +} + +void LLVMAddMemCpyOptPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createMemCpyOptPass()); +} + +void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createPromoteMemoryToRegisterPass()); +} + +void LLVMAddReassociatePass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createReassociatePass()); +} + +void LLVMAddSCCPPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createSCCPPass()); +} + +void LLVMAddScalarReplAggregatesPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createScalarReplAggregatesPass()); +} + +void LLVMAddSimplifyLibCallsPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createSimplifyLibCallsPass()); +} + +void LLVMAddTailCallEliminationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createTailCallEliminationPass()); +} + +void LLVMAddConstantPropagationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createConstantPropagationPass()); +} + +void LLVMAddDemoteMemoryToRegisterPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createDemoteRegisterToMemoryPass()); +} diff --git a/lib/Transforms/Scalar/ScalarReplAggregates.cpp b/lib/Transforms/Scalar/ScalarReplAggregates.cpp new file mode 100644 index 0000000..9935f12 --- /dev/null +++ b/lib/Transforms/Scalar/ScalarReplAggregates.cpp @@ -0,0 +1,1820 @@ +//===- ScalarReplAggregates.cpp - Scalar Replacement of Aggregates --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This transformation implements the well known scalar replacement of +// aggregates transformation. This xform breaks up alloca instructions of +// aggregate type (structure or array) into individual alloca instructions for +// each member (if possible). 
Then, if possible, it transforms the individual
+// alloca instructions into nice clean scalar SSA form.
+//
+// This combines a simple SRoA algorithm with the Mem2Reg algorithm because
+// the two often interact, especially for C++ programs. As such, iterating
+// between SRoA and Mem2Reg until we run out of things to promote works well.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "scalarrepl"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Function.h"
+#include "llvm/GlobalVariable.h"
+#include "llvm/Instructions.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/Dominators.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/Transforms/Utils/PromoteMemToReg.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/GetElementPtrTypeIterator.h"
+#include "llvm/Support/IRBuilder.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringExtras.h"
+using namespace llvm;
+
+STATISTIC(NumReplaced, "Number of allocas broken up");
+STATISTIC(NumPromoted, "Number of allocas promoted");
+STATISTIC(NumConverted, "Number of aggregates converted to scalar");
+STATISTIC(NumGlobals, "Number of allocas copied from constant global");
+
+namespace {
+ struct VISIBILITY_HIDDEN SROA : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+ explicit SROA(signed T = -1) : FunctionPass(&ID) {
+ if (T == -1)
+ SRThreshold = 128;
+ else
+ SRThreshold = T;
+ }
+
+ bool runOnFunction(Function &F);
+
+ bool performScalarRepl(Function &F);
+ bool performPromotion(Function &F);
+
+ // getAnalysisUsage - This pass does not require any passes, but we know it
+ // will not alter the CFG, so say so.
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<DominatorTree>();
+ AU.addRequired<DominanceFrontier>();
+ AU.addRequired<TargetData>();
+ AU.setPreservesCFG();
+ }
+
+ private:
+ TargetData *TD;
+
+ /// AllocaInfo - When analyzing uses of an alloca instruction, this captures
+ /// information about the uses. All these fields are initialized to false
+ /// and set to true when something is learned.
+ struct AllocaInfo {
+ /// isUnsafe - This is set to true if the alloca cannot be SROA'd.
+ bool isUnsafe : 1;
+
+ /// needsCleanup - This is set to true if there is some use of the alloca
+ /// that requires cleanup.
+ bool needsCleanup : 1;
+
+ /// isMemCpySrc - This is true if this aggregate is memcpy'd from.
+ bool isMemCpySrc : 1;
+
+ /// isMemCpyDst - This is true if this aggregate is memcpy'd into.
+ bool isMemCpyDst : 1; + + AllocaInfo() + : isUnsafe(false), needsCleanup(false), + isMemCpySrc(false), isMemCpyDst(false) {} + }; + + unsigned SRThreshold; + + void MarkUnsafe(AllocaInfo &I) { I.isUnsafe = true; } + + int isSafeAllocaToScalarRepl(AllocationInst *AI); + + void isSafeUseOfAllocation(Instruction *User, AllocationInst *AI, + AllocaInfo &Info); + void isSafeElementUse(Value *Ptr, bool isFirstElt, AllocationInst *AI, + AllocaInfo &Info); + void isSafeMemIntrinsicOnAllocation(MemIntrinsic *MI, AllocationInst *AI, + unsigned OpNo, AllocaInfo &Info); + void isSafeUseOfBitCastedAllocation(BitCastInst *User, AllocationInst *AI, + AllocaInfo &Info); + + void DoScalarReplacement(AllocationInst *AI, + std::vector<AllocationInst*> &WorkList); + void CleanupGEP(GetElementPtrInst *GEP); + void CleanupAllocaUsers(AllocationInst *AI); + AllocaInst *AddNewAlloca(Function &F, const Type *Ty, AllocationInst *Base); + + void RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts); + + void RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *BCInst, + AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts); + void RewriteStoreUserOfWholeAlloca(StoreInst *SI, AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts); + void RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts); + + bool CanConvertToScalar(Value *V, bool &IsNotTrivial, const Type *&VecTy, + bool &SawVec, uint64_t Offset, unsigned AllocaSize); + void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset); + Value *ConvertScalar_ExtractValue(Value *NV, const Type *ToType, + uint64_t Offset, IRBuilder<> &Builder); + Value *ConvertScalar_InsertValue(Value *StoredVal, Value *ExistingVal, + uint64_t Offset, IRBuilder<> &Builder); + static Instruction *isOnlyCopiedFromConstantGlobal(AllocationInst *AI); + }; +} + +char SROA::ID = 0; +static RegisterPass<SROA> X("scalarrepl", "Scalar Replacement of Aggregates"); + +// Public interface to the ScalarReplAggregates pass +FunctionPass *llvm::createScalarReplAggregatesPass(signed int Threshold) { + return new SROA(Threshold); +} + + +bool SROA::runOnFunction(Function &F) { + TD = &getAnalysis<TargetData>(); + + bool Changed = performPromotion(F); + while (1) { + bool LocalChange = performScalarRepl(F); + if (!LocalChange) break; // No need to repromote if no scalarrepl + Changed = true; + LocalChange = performPromotion(F); + if (!LocalChange) break; // No need to re-scalarrepl if no promotion + } + + return Changed; +} + + +bool SROA::performPromotion(Function &F) { + std::vector<AllocaInst*> Allocas; + DominatorTree &DT = getAnalysis<DominatorTree>(); + DominanceFrontier &DF = getAnalysis<DominanceFrontier>(); + + BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function + + bool Changed = false; + + while (1) { + Allocas.clear(); + + // Find allocas that are safe to promote, by looking at all instructions in + // the entry node + for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I) + if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) // Is it an alloca? + if (isAllocaPromotable(AI)) + Allocas.push_back(AI); + + if (Allocas.empty()) break; + + PromoteMemToReg(Allocas, DT, DF); + NumPromoted += Allocas.size(); + Changed = true; + } + + return Changed; +} + +/// getNumSAElements - Return the number of elements in the specific struct or +/// array. 
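+/// For example (illustrative): {i32, float, i8} and [3 x i8] both have three
+/// elements; the struct case counts fields, not bytes.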
+static uint64_t getNumSAElements(const Type *T) { + if (const StructType *ST = dyn_cast<StructType>(T)) + return ST->getNumElements(); + return cast<ArrayType>(T)->getNumElements(); +} + +// performScalarRepl - This algorithm is a simple worklist driven algorithm, +// which runs on all of the malloc/alloca instructions in the function, removing +// them if they are only used by getelementptr instructions. +// +bool SROA::performScalarRepl(Function &F) { + std::vector<AllocationInst*> WorkList; + + // Scan the entry basic block, adding any alloca's and mallocs to the worklist + BasicBlock &BB = F.getEntryBlock(); + for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) + if (AllocationInst *A = dyn_cast<AllocationInst>(I)) + WorkList.push_back(A); + + // Process the worklist + bool Changed = false; + while (!WorkList.empty()) { + AllocationInst *AI = WorkList.back(); + WorkList.pop_back(); + + // Handle dead allocas trivially. These can be formed by SROA'ing arrays + // with unused elements. + if (AI->use_empty()) { + AI->eraseFromParent(); + continue; + } + + // If this alloca is impossible for us to promote, reject it early. + if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized()) + continue; + + // Check to see if this allocation is only modified by a memcpy/memmove from + // a constant global. If this is the case, we can change all users to use + // the constant global instead. This is commonly produced by the CFE by + // constructs like "void foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' + // is only subsequently read. + if (Instruction *TheCopy = isOnlyCopiedFromConstantGlobal(AI)) { + DOUT << "Found alloca equal to global: " << *AI; + DOUT << " memcpy = " << *TheCopy; + Constant *TheSrc = cast<Constant>(TheCopy->getOperand(2)); + AI->replaceAllUsesWith(ConstantExpr::getBitCast(TheSrc, AI->getType())); + TheCopy->eraseFromParent(); // Don't mutate the global. + AI->eraseFromParent(); + ++NumGlobals; + Changed = true; + continue; + } + + // Check to see if we can perform the core SROA transformation. We cannot + // transform the allocation instruction if it is an array allocation + // (allocations OF arrays are ok though), and an allocation of a scalar + // value cannot be decomposed at all. + uint64_t AllocaSize = TD->getTypeAllocSize(AI->getAllocatedType()); + + // Do not promote any struct whose size is too big. + if (AllocaSize > SRThreshold) continue; + + if ((isa<StructType>(AI->getAllocatedType()) || + isa<ArrayType>(AI->getAllocatedType())) && + // Do not promote any struct into more than "32" separate vars. + getNumSAElements(AI->getAllocatedType()) <= SRThreshold/4) { + // Check that all of the users of the allocation are capable of being + // transformed. + switch (isSafeAllocaToScalarRepl(AI)) { + default: assert(0 && "Unexpected value!"); + case 0: // Not safe to scalar replace. + break; + case 1: // Safe, but requires cleanup/canonicalizations first + CleanupAllocaUsers(AI); + // FALL THROUGH. + case 3: // Safe to scalar replace. + DoScalarReplacement(AI, WorkList); + Changed = true; + continue; + } + } + + // If we can turn this aggregate value (potentially with casts) into a + // simple scalar value that can be mem2reg'd into a register value. + // IsNotTrivial tracks whether this is something that mem2reg could have + // promoted itself. If so, we don't want to transform it needlessly. Note + // that we can't just check based on the type: the alloca may be of an i32 + // but that has pointer arithmetic to set byte 3 of it or something. 
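+    // Illustrative example (hypothetical IR): mem2reg cannot promote
+    //     %A = alloca i32
+    //     %p = bitcast i32* %A to i8*
+    //     %b = getelementptr i8* %p, i32 3
+    //     store i8 0, i8* %b
+    // directly, but the alloca can still be rewritten as a single scalar
+    // integer by the conversion below.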
+ bool IsNotTrivial = false; + const Type *VectorTy = 0; + bool HadAVector = false; + if (CanConvertToScalar(AI, IsNotTrivial, VectorTy, HadAVector, + 0, unsigned(AllocaSize)) && IsNotTrivial) { + AllocaInst *NewAI; + // If we were able to find a vector type that can handle this with + // insert/extract elements, and if there was at least one use that had + // a vector type, promote this to a vector. We don't want to promote + // random stuff that doesn't use vectors (e.g. <9 x double>) because then + // we just get a lot of insert/extracts. If at least one vector is + // involved, then we probably really do have a union of vector/array. + if (VectorTy && isa<VectorType>(VectorTy) && HadAVector) { + DOUT << "CONVERT TO VECTOR: " << *AI << " TYPE = " << *VectorTy <<"\n"; + + // Create and insert the vector alloca. + NewAI = new AllocaInst(VectorTy, 0, "", AI->getParent()->begin()); + ConvertUsesToScalar(AI, NewAI, 0); + } else { + DOUT << "CONVERT TO SCALAR INTEGER: " << *AI << "\n"; + + // Create and insert the integer alloca. + const Type *NewTy = IntegerType::get(AllocaSize*8); + NewAI = new AllocaInst(NewTy, 0, "", AI->getParent()->begin()); + ConvertUsesToScalar(AI, NewAI, 0); + } + NewAI->takeName(AI); + AI->eraseFromParent(); + ++NumConverted; + Changed = true; + continue; + } + + // Otherwise, couldn't process this alloca. + } + + return Changed; +} + +/// DoScalarReplacement - This alloca satisfied the isSafeAllocaToScalarRepl +/// predicate, do SROA now. +void SROA::DoScalarReplacement(AllocationInst *AI, + std::vector<AllocationInst*> &WorkList) { + DOUT << "Found inst to SROA: " << *AI; + SmallVector<AllocaInst*, 32> ElementAllocas; + if (const StructType *ST = dyn_cast<StructType>(AI->getAllocatedType())) { + ElementAllocas.reserve(ST->getNumContainedTypes()); + for (unsigned i = 0, e = ST->getNumContainedTypes(); i != e; ++i) { + AllocaInst *NA = new AllocaInst(ST->getContainedType(i), 0, + AI->getAlignment(), + AI->getName() + "." + utostr(i), AI); + ElementAllocas.push_back(NA); + WorkList.push_back(NA); // Add to worklist for recursive processing + } + } else { + const ArrayType *AT = cast<ArrayType>(AI->getAllocatedType()); + ElementAllocas.reserve(AT->getNumElements()); + const Type *ElTy = AT->getElementType(); + for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) { + AllocaInst *NA = new AllocaInst(ElTy, 0, AI->getAlignment(), + AI->getName() + "." + utostr(i), AI); + ElementAllocas.push_back(NA); + WorkList.push_back(NA); // Add to worklist for recursive processing + } + } + + // Now that we have created the alloca instructions that we want to use, + // expand the getelementptr instructions to use them. 
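+  // For example (illustrative): once %T = alloca { i32, float } has been
+  // split into %T.0 = alloca i32 and %T.1 = alloca float, a use such as
+  //     %f = getelementptr { i32, float }* %T, i32 0, i32 1
+  // is rewritten to refer to %T.1 directly.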
+ //
+ while (!AI->use_empty()) {
+ Instruction *User = cast<Instruction>(AI->use_back());
+ if (BitCastInst *BCInst = dyn_cast<BitCastInst>(User)) {
+ RewriteBitCastUserOfAlloca(BCInst, AI, ElementAllocas);
+ BCInst->eraseFromParent();
+ continue;
+ }
+
+ // Replace:
+ //   %res = load { i32, i32 }* %alloc
+ // with:
+ //   %load.0 = load i32* %alloc.0
+ //   %insert.0 = insertvalue { i32, i32 } zeroinitializer, i32 %load.0, 0
+ //   %load.1 = load i32* %alloc.1
+ //   %insert = insertvalue { i32, i32 } %insert.0, i32 %load.1, 1
+ // (Also works for arrays instead of structs)
+ if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
+ Value *Insert = UndefValue::get(LI->getType());
+ for (unsigned i = 0, e = ElementAllocas.size(); i != e; ++i) {
+ Value *Load = new LoadInst(ElementAllocas[i], "load", LI);
+ Insert = InsertValueInst::Create(Insert, Load, i, "insert", LI);
+ }
+ LI->replaceAllUsesWith(Insert);
+ LI->eraseFromParent();
+ continue;
+ }
+
+ // Replace:
+ //   store { i32, i32 } %val, { i32, i32 }* %alloc
+ // with:
+ //   %val.0 = extractvalue { i32, i32 } %val, 0
+ //   store i32 %val.0, i32* %alloc.0
+ //   %val.1 = extractvalue { i32, i32 } %val, 1
+ //   store i32 %val.1, i32* %alloc.1
+ // (Also works for arrays instead of structs)
+ if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+ Value *Val = SI->getOperand(0);
+ for (unsigned i = 0, e = ElementAllocas.size(); i != e; ++i) {
+ Value *Extract = ExtractValueInst::Create(Val, i, Val->getName(), SI);
+ new StoreInst(Extract, ElementAllocas[i], SI);
+ }
+ SI->eraseFromParent();
+ continue;
+ }
+
+ GetElementPtrInst *GEPI = cast<GetElementPtrInst>(User);
+ // We now know that the GEP is of the form: GEP <ptr>, 0, <cst>
+ unsigned Idx =
+ (unsigned)cast<ConstantInt>(GEPI->getOperand(2))->getZExtValue();
+
+ assert(Idx < ElementAllocas.size() && "Index out of range?");
+ AllocaInst *AllocaToUse = ElementAllocas[Idx];
+
+ Value *RepValue;
+ if (GEPI->getNumOperands() == 3) {
+ // Do not insert a new getelementptr instruction with zero indices, only
+ // to have it optimized out later.
+ RepValue = AllocaToUse;
+ } else {
+ // We are indexing deeply into the structure, so we still need a
+ // getelementptr instruction to finish the indexing. This may be
+ // expanded itself once the worklist is rerun.
+ //
+ SmallVector<Value*, 8> NewArgs;
+ NewArgs.push_back(Constant::getNullValue(Type::Int32Ty));
+ NewArgs.append(GEPI->op_begin()+3, GEPI->op_end());
+ RepValue = GetElementPtrInst::Create(AllocaToUse, NewArgs.begin(),
+ NewArgs.end(), "", GEPI);
+ RepValue->takeName(GEPI);
+ }
+
+ // If this GEP is to the start of the aggregate, check for memcpys.
+ if (Idx == 0 && GEPI->hasAllZeroIndices())
+ RewriteBitCastUserOfAlloca(GEPI, AI, ElementAllocas);
+
+ // Move all of the users over to the new GEP.
+ GEPI->replaceAllUsesWith(RepValue);
+ // Delete the old GEP
+ GEPI->eraseFromParent();
+ }
+
+ // Finally, delete the Alloca instruction
+ AI->eraseFromParent();
+ NumReplaced++;
+}
+
+
+/// isSafeElementUse - Check to see if this use is an allowed use for a
+/// getelementptr instruction of an array aggregate allocation. isFirstElt
+/// indicates whether Ptr is known to point to the start of the aggregate.
+/// +void SROA::isSafeElementUse(Value *Ptr, bool isFirstElt, AllocationInst *AI, + AllocaInfo &Info) { + for (Value::use_iterator I = Ptr->use_begin(), E = Ptr->use_end(); + I != E; ++I) { + Instruction *User = cast<Instruction>(*I); + switch (User->getOpcode()) { + case Instruction::Load: break; + case Instruction::Store: + // Store is ok if storing INTO the pointer, not storing the pointer + if (User->getOperand(0) == Ptr) return MarkUnsafe(Info); + break; + case Instruction::GetElementPtr: { + GetElementPtrInst *GEP = cast<GetElementPtrInst>(User); + bool AreAllZeroIndices = isFirstElt; + if (GEP->getNumOperands() > 1) { + if (!isa<ConstantInt>(GEP->getOperand(1)) || + !cast<ConstantInt>(GEP->getOperand(1))->isZero()) + // Using pointer arithmetic to navigate the array. + return MarkUnsafe(Info); + + if (AreAllZeroIndices) + AreAllZeroIndices = GEP->hasAllZeroIndices(); + } + isSafeElementUse(GEP, AreAllZeroIndices, AI, Info); + if (Info.isUnsafe) return; + break; + } + case Instruction::BitCast: + if (isFirstElt) { + isSafeUseOfBitCastedAllocation(cast<BitCastInst>(User), AI, Info); + if (Info.isUnsafe) return; + break; + } + DOUT << " Transformation preventing inst: " << *User; + return MarkUnsafe(Info); + case Instruction::Call: + if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) { + if (isFirstElt) { + isSafeMemIntrinsicOnAllocation(MI, AI, I.getOperandNo(), Info); + if (Info.isUnsafe) return; + break; + } + } + DOUT << " Transformation preventing inst: " << *User; + return MarkUnsafe(Info); + default: + DOUT << " Transformation preventing inst: " << *User; + return MarkUnsafe(Info); + } + } + return; // All users look ok :) +} + +/// AllUsersAreLoads - Return true if all users of this value are loads. +static bool AllUsersAreLoads(Value *Ptr) { + for (Value::use_iterator I = Ptr->use_begin(), E = Ptr->use_end(); + I != E; ++I) + if (cast<Instruction>(*I)->getOpcode() != Instruction::Load) + return false; + return true; +} + +/// isSafeUseOfAllocation - Check to see if this user is an allowed use for an +/// aggregate allocation. +/// +void SROA::isSafeUseOfAllocation(Instruction *User, AllocationInst *AI, + AllocaInfo &Info) { + if (BitCastInst *C = dyn_cast<BitCastInst>(User)) + return isSafeUseOfBitCastedAllocation(C, AI, Info); + + if (LoadInst *LI = dyn_cast<LoadInst>(User)) + if (!LI->isVolatile()) + return;// Loads (returning a first class aggregrate) are always rewritable + + if (StoreInst *SI = dyn_cast<StoreInst>(User)) + if (!SI->isVolatile() && SI->getOperand(0) != AI) + return;// Store is ok if storing INTO the pointer, not storing the pointer + + GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User); + if (GEPI == 0) + return MarkUnsafe(Info); + + gep_type_iterator I = gep_type_begin(GEPI), E = gep_type_end(GEPI); + + // The GEP is not safe to transform if not of the form "GEP <ptr>, 0, <cst>". + if (I == E || + I.getOperand() != Constant::getNullValue(I.getOperand()->getType())) { + return MarkUnsafe(Info); + } + + ++I; + if (I == E) return MarkUnsafe(Info); // ran out of GEP indices?? + + bool IsAllZeroIndices = true; + + // If the first index is a non-constant index into an array, see if we can + // handle it as a special case. + if (const ArrayType *AT = dyn_cast<ArrayType>(*I)) { + if (!isa<ConstantInt>(I.getOperand())) { + IsAllZeroIndices = 0; + uint64_t NumElements = AT->getNumElements(); + + // If this is an array index and the index is not constant, we cannot + // promote... 
that is unless the array has exactly one or two elements in + // it, in which case we CAN promote it, but we have to canonicalize this + // out if this is the only problem. + if ((NumElements == 1 || NumElements == 2) && + AllUsersAreLoads(GEPI)) { + Info.needsCleanup = true; + return; // Canonicalization required! + } + return MarkUnsafe(Info); + } + } + + // Walk through the GEP type indices, checking the types that this indexes + // into. + for (; I != E; ++I) { + // Ignore struct elements, no extra checking needed for these. + if (isa<StructType>(*I)) + continue; + + ConstantInt *IdxVal = dyn_cast<ConstantInt>(I.getOperand()); + if (!IdxVal) return MarkUnsafe(Info); + + // Are all indices still zero? + IsAllZeroIndices &= IdxVal->isZero(); + + if (const ArrayType *AT = dyn_cast<ArrayType>(*I)) { + // This GEP indexes an array. Verify that this is an in-range constant + // integer. Specifically, consider A[0][i]. We cannot know that the user + // isn't doing invalid things like allowing i to index an out-of-range + // subscript that accesses A[1]. Because of this, we have to reject SROA + // of any accesses into structs where any of the components are variables. + if (IdxVal->getZExtValue() >= AT->getNumElements()) + return MarkUnsafe(Info); + } else if (const VectorType *VT = dyn_cast<VectorType>(*I)) { + if (IdxVal->getZExtValue() >= VT->getNumElements()) + return MarkUnsafe(Info); + } + } + + // If there are any non-simple uses of this getelementptr, make sure to reject + // them. + return isSafeElementUse(GEPI, IsAllZeroIndices, AI, Info); +} + +/// isSafeMemIntrinsicOnAllocation - Return true if the specified memory +/// intrinsic can be promoted by SROA. At this point, we know that the operand +/// of the memintrinsic is a pointer to the beginning of the allocation. +void SROA::isSafeMemIntrinsicOnAllocation(MemIntrinsic *MI, AllocationInst *AI, + unsigned OpNo, AllocaInfo &Info) { + // If not constant length, give up. + ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength()); + if (!Length) return MarkUnsafe(Info); + + // If not the whole aggregate, give up. + if (Length->getZExtValue() != + TD->getTypeAllocSize(AI->getType()->getElementType())) + return MarkUnsafe(Info); + + // We only know about memcpy/memset/memmove. + if (!isa<MemIntrinsic>(MI)) + return MarkUnsafe(Info); + + // Otherwise, we can transform it. Determine whether this is a memcpy/set + // into or out of the aggregate. + if (OpNo == 1) + Info.isMemCpyDst = true; + else { + assert(OpNo == 2); + Info.isMemCpySrc = true; + } +} + +/// isSafeUseOfBitCastedAllocation - Return true if all users of this bitcast +/// are +void SROA::isSafeUseOfBitCastedAllocation(BitCastInst *BC, AllocationInst *AI, + AllocaInfo &Info) { + for (Value::use_iterator UI = BC->use_begin(), E = BC->use_end(); + UI != E; ++UI) { + if (BitCastInst *BCU = dyn_cast<BitCastInst>(UI)) { + isSafeUseOfBitCastedAllocation(BCU, AI, Info); + } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(UI)) { + isSafeMemIntrinsicOnAllocation(MI, AI, UI.getOperandNo(), Info); + } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) { + if (SI->isVolatile()) + return MarkUnsafe(Info); + + // If storing the entire alloca in one chunk through a bitcasted pointer + // to integer, we can transform it. This happens (for example) when you + // cast a {i32,i32}* to i64* and store through it. This is similar to the + // memcpy case and occurs in various "byval" cases and emulated memcpys. 
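+      // Illustrative example (hypothetical IR):
+      //     %p = bitcast { i32, i32 }* %A to i64*
+      //     store i64 %v, i64* %p
+      // writes both fields at once; such stores are later split up per
+      // element by RewriteStoreUserOfWholeAlloca.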
+ if (isa<IntegerType>(SI->getOperand(0)->getType()) && + TD->getTypeAllocSize(SI->getOperand(0)->getType()) == + TD->getTypeAllocSize(AI->getType()->getElementType())) { + Info.isMemCpyDst = true; + continue; + } + return MarkUnsafe(Info); + } else if (LoadInst *LI = dyn_cast<LoadInst>(UI)) { + if (LI->isVolatile()) + return MarkUnsafe(Info); + + // If loading the entire alloca in one chunk through a bitcasted pointer + // to integer, we can transform it. This happens (for example) when you + // cast a {i32,i32}* to i64* and load through it. This is similar to the + // memcpy case and occurs in various "byval" cases and emulated memcpys. + if (isa<IntegerType>(LI->getType()) && + TD->getTypeAllocSize(LI->getType()) == + TD->getTypeAllocSize(AI->getType()->getElementType())) { + Info.isMemCpySrc = true; + continue; + } + return MarkUnsafe(Info); + } else if (isa<DbgInfoIntrinsic>(UI)) { + // If one user is DbgInfoIntrinsic then check if all users are + // DbgInfoIntrinsics. + if (OnlyUsedByDbgInfoIntrinsics(BC)) { + Info.needsCleanup = true; + return; + } + else + MarkUnsafe(Info); + } + else { + return MarkUnsafe(Info); + } + if (Info.isUnsafe) return; + } +} + +/// RewriteBitCastUserOfAlloca - BCInst (transitively) bitcasts AI, or indexes +/// to its first element. Transform users of the cast to use the new values +/// instead. +void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts) { + Value::use_iterator UI = BCInst->use_begin(), UE = BCInst->use_end(); + while (UI != UE) { + Instruction *User = cast<Instruction>(*UI++); + if (BitCastInst *BCU = dyn_cast<BitCastInst>(User)) { + RewriteBitCastUserOfAlloca(BCU, AI, NewElts); + if (BCU->use_empty()) BCU->eraseFromParent(); + continue; + } + + if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) { + // This must be memcpy/memmove/memset of the entire aggregate. + // Split into one per element. + RewriteMemIntrinUserOfAlloca(MI, BCInst, AI, NewElts); + continue; + } + + if (StoreInst *SI = dyn_cast<StoreInst>(User)) { + // If this is a store of the entire alloca from an integer, rewrite it. + RewriteStoreUserOfWholeAlloca(SI, AI, NewElts); + continue; + } + + if (LoadInst *LI = dyn_cast<LoadInst>(User)) { + // If this is a load of the entire alloca to an integer, rewrite it. + RewriteLoadUserOfWholeAlloca(LI, AI, NewElts); + continue; + } + + // Otherwise it must be some other user of a gep of the first pointer. Just + // leave these alone. + continue; + } +} + +/// RewriteMemIntrinUserOfAlloca - MI is a memcpy/memset/memmove from or to AI. +/// Rewrite it to copy or set the elements of the scalarized memory. +void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *BCInst, + AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts) { + + // If this is a memcpy/memmove, construct the other pointer as the + // appropriate type. The "Other" pointer is the pointer that goes to memory + // that doesn't have anything to do with the alloca that we are promoting. For + // memset, this Value* stays null. + Value *OtherPtr = 0; + unsigned MemAlignment = MI->getAlignment(); + if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove/memcopy + if (BCInst == MTI->getRawDest()) + OtherPtr = MTI->getRawSource(); + else { + assert(BCInst == MTI->getRawSource()); + OtherPtr = MTI->getRawDest(); + } + } + + // If there is an other pointer, we want to convert it to the same pointer + // type as AI has, so we can GEP through it safely. 
+ if (OtherPtr) { + // It is likely that OtherPtr is a bitcast, if so, remove it. + if (BitCastInst *BC = dyn_cast<BitCastInst>(OtherPtr)) + OtherPtr = BC->getOperand(0); + // All zero GEPs are effectively bitcasts. + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(OtherPtr)) + if (GEP->hasAllZeroIndices()) + OtherPtr = GEP->getOperand(0); + + if (ConstantExpr *BCE = dyn_cast<ConstantExpr>(OtherPtr)) + if (BCE->getOpcode() == Instruction::BitCast) + OtherPtr = BCE->getOperand(0); + + // If the pointer is not the right type, insert a bitcast to the right + // type. + if (OtherPtr->getType() != AI->getType()) + OtherPtr = new BitCastInst(OtherPtr, AI->getType(), OtherPtr->getName(), + MI); + } + + // Process each element of the aggregate. + Value *TheFn = MI->getOperand(0); + const Type *BytePtrTy = MI->getRawDest()->getType(); + bool SROADest = MI->getRawDest() == BCInst; + + Constant *Zero = Constant::getNullValue(Type::Int32Ty); + + for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { + // If this is a memcpy/memmove, emit a GEP of the other element address. + Value *OtherElt = 0; + unsigned OtherEltAlign = MemAlignment; + + if (OtherPtr) { + Value *Idx[2] = { Zero, ConstantInt::get(Type::Int32Ty, i) }; + OtherElt = GetElementPtrInst::Create(OtherPtr, Idx, Idx + 2, + OtherPtr->getNameStr()+"."+utostr(i), + MI); + uint64_t EltOffset; + const PointerType *OtherPtrTy = cast<PointerType>(OtherPtr->getType()); + if (const StructType *ST = + dyn_cast<StructType>(OtherPtrTy->getElementType())) { + EltOffset = TD->getStructLayout(ST)->getElementOffset(i); + } else { + const Type *EltTy = + cast<SequentialType>(OtherPtr->getType())->getElementType(); + EltOffset = TD->getTypeAllocSize(EltTy)*i; + } + + // The alignment of the other pointer is the guaranteed alignment of the + // element, which is affected by both the known alignment of the whole + // mem intrinsic and the alignment of the element. If the alignment of + // the memcpy (f.e.) is 32 but the element is at a 4-byte offset, then the + // known alignment is just 4 bytes. + OtherEltAlign = (unsigned)MinAlign(OtherEltAlign, EltOffset); + } + + Value *EltPtr = NewElts[i]; + const Type *EltTy = cast<PointerType>(EltPtr->getType())->getElementType(); + + // If we got down to a scalar, insert a load or store as appropriate. + if (EltTy->isSingleValueType()) { + if (isa<MemTransferInst>(MI)) { + if (SROADest) { + // From Other to Alloca. + Value *Elt = new LoadInst(OtherElt, "tmp", false, OtherEltAlign, MI); + new StoreInst(Elt, EltPtr, MI); + } else { + // From Alloca to Other. + Value *Elt = new LoadInst(EltPtr, "tmp", MI); + new StoreInst(Elt, OtherElt, false, OtherEltAlign, MI); + } + continue; + } + assert(isa<MemSetInst>(MI)); + + // If the stored element is zero (common case), just store a null + // constant. + Constant *StoreVal; + if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getOperand(2))) { + if (CI->isZero()) { + StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0> + } else { + // If EltTy is a vector type, get the element type. + const Type *ValTy = EltTy; + if (const VectorType *VTy = dyn_cast<VectorType>(ValTy)) + ValTy = VTy->getElementType(); + + // Construct an integer with the right value. + unsigned EltSize = TD->getTypeSizeInBits(ValTy); + APInt OneVal(EltSize, CI->getZExtValue()); + APInt TotalVal(OneVal); + // Set each byte. + for (unsigned i = 0; 8*i < EltSize; ++i) { + TotalVal = TotalVal.shl(8); + TotalVal |= OneVal; + } + + // Convert the integer value to the appropriate type. 
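+          // (Illustrative: for a memset of byte 0xAB over an i32 element,
+          // the shift/or loop above produces TotalVal = 0xABABABAB.)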
+ StoreVal = ConstantInt::get(TotalVal); + if (isa<PointerType>(ValTy)) + StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy); + else if (ValTy->isFloatingPoint()) + StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy); + assert(StoreVal->getType() == ValTy && "Type mismatch!"); + + // If the requested value was a vector constant, create it. + if (EltTy != ValTy) { + unsigned NumElts = cast<VectorType>(ValTy)->getNumElements(); + SmallVector<Constant*, 16> Elts(NumElts, StoreVal); + StoreVal = ConstantVector::get(&Elts[0], NumElts); + } + } + new StoreInst(StoreVal, EltPtr, MI); + continue; + } + // Otherwise, if we're storing a byte variable, use a memset call for + // this element. + } + + // Cast the element pointer to BytePtrTy. + if (EltPtr->getType() != BytePtrTy) + EltPtr = new BitCastInst(EltPtr, BytePtrTy, EltPtr->getNameStr(), MI); + + // Cast the other pointer (if we have one) to BytePtrTy. + if (OtherElt && OtherElt->getType() != BytePtrTy) + OtherElt = new BitCastInst(OtherElt, BytePtrTy,OtherElt->getNameStr(), + MI); + + unsigned EltSize = TD->getTypeAllocSize(EltTy); + + // Finally, insert the meminst for this element. + if (isa<MemTransferInst>(MI)) { + Value *Ops[] = { + SROADest ? EltPtr : OtherElt, // Dest ptr + SROADest ? OtherElt : EltPtr, // Src ptr + ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size + ConstantInt::get(Type::Int32Ty, OtherEltAlign) // Align + }; + CallInst::Create(TheFn, Ops, Ops + 4, "", MI); + } else { + assert(isa<MemSetInst>(MI)); + Value *Ops[] = { + EltPtr, MI->getOperand(2), // Dest, Value, + ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size + Zero // Align + }; + CallInst::Create(TheFn, Ops, Ops + 4, "", MI); + } + } + MI->eraseFromParent(); +} + +/// RewriteStoreUserOfWholeAlloca - We found an store of an integer that +/// overwrites the entire allocation. Extract out the pieces of the stored +/// integer and store them individually. +void SROA::RewriteStoreUserOfWholeAlloca(StoreInst *SI, + AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts){ + // Extract each element out of the integer according to its structure offset + // and store the element value to the individual alloca. + Value *SrcVal = SI->getOperand(0); + const Type *AllocaEltTy = AI->getType()->getElementType(); + uint64_t AllocaSizeBits = TD->getTypeAllocSizeInBits(AllocaEltTy); + + // If this isn't a store of an integer to the whole alloca, it may be a store + // to the first element. Just ignore the store in this case and normal SROA + // will handle it. + if (!isa<IntegerType>(SrcVal->getType()) || + TD->getTypeAllocSizeInBits(SrcVal->getType()) != AllocaSizeBits) + return; + // Handle tail padding by extending the operand + if (TD->getTypeSizeInBits(SrcVal->getType()) != AllocaSizeBits) + SrcVal = new ZExtInst(SrcVal, IntegerType::get(AllocaSizeBits), "", SI); + + DOUT << "PROMOTING STORE TO WHOLE ALLOCA: " << *AI << *SI; + + // There are two forms here: AI could be an array or struct. Both cases + // have different ways to compute the element offset. + if (const StructType *EltSTy = dyn_cast<StructType>(AllocaEltTy)) { + const StructLayout *Layout = TD->getStructLayout(EltSTy); + + for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { + // Get the number of bits to shift SrcVal to get the value. 
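+      // For example (illustrative): storing an i64 over an alloca of
+      // { i32, i32 } on a little-endian target extracts the two fields with
+      // shifts of 0 and 32 bits; big-endian targets flip the shifts below.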
+ const Type *FieldTy = EltSTy->getElementType(i); + uint64_t Shift = Layout->getElementOffsetInBits(i); + + if (TD->isBigEndian()) + Shift = AllocaSizeBits-Shift-TD->getTypeAllocSizeInBits(FieldTy); + + Value *EltVal = SrcVal; + if (Shift) { + Value *ShiftVal = ConstantInt::get(EltVal->getType(), Shift); + EltVal = BinaryOperator::CreateLShr(EltVal, ShiftVal, + "sroa.store.elt", SI); + } + + // Truncate down to an integer of the right size. + uint64_t FieldSizeBits = TD->getTypeSizeInBits(FieldTy); + + // Ignore zero sized fields like {}, they obviously contain no data. + if (FieldSizeBits == 0) continue; + + if (FieldSizeBits != AllocaSizeBits) + EltVal = new TruncInst(EltVal, IntegerType::get(FieldSizeBits), "", SI); + Value *DestField = NewElts[i]; + if (EltVal->getType() == FieldTy) { + // Storing to an integer field of this size, just do it. + } else if (FieldTy->isFloatingPoint() || isa<VectorType>(FieldTy)) { + // Bitcast to the right element type (for fp/vector values). + EltVal = new BitCastInst(EltVal, FieldTy, "", SI); + } else { + // Otherwise, bitcast the dest pointer (for aggregates). + DestField = new BitCastInst(DestField, + PointerType::getUnqual(EltVal->getType()), + "", SI); + } + new StoreInst(EltVal, DestField, SI); + } + + } else { + const ArrayType *ATy = cast<ArrayType>(AllocaEltTy); + const Type *ArrayEltTy = ATy->getElementType(); + uint64_t ElementOffset = TD->getTypeAllocSizeInBits(ArrayEltTy); + uint64_t ElementSizeBits = TD->getTypeSizeInBits(ArrayEltTy); + + uint64_t Shift; + + if (TD->isBigEndian()) + Shift = AllocaSizeBits-ElementOffset; + else + Shift = 0; + + for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { + // Ignore zero sized fields like {}, they obviously contain no data. + if (ElementSizeBits == 0) continue; + + Value *EltVal = SrcVal; + if (Shift) { + Value *ShiftVal = ConstantInt::get(EltVal->getType(), Shift); + EltVal = BinaryOperator::CreateLShr(EltVal, ShiftVal, + "sroa.store.elt", SI); + } + + // Truncate down to an integer of the right size. + if (ElementSizeBits != AllocaSizeBits) + EltVal = new TruncInst(EltVal, IntegerType::get(ElementSizeBits),"",SI); + Value *DestField = NewElts[i]; + if (EltVal->getType() == ArrayEltTy) { + // Storing to an integer field of this size, just do it. + } else if (ArrayEltTy->isFloatingPoint() || isa<VectorType>(ArrayEltTy)) { + // Bitcast to the right element type (for fp/vector values). + EltVal = new BitCastInst(EltVal, ArrayEltTy, "", SI); + } else { + // Otherwise, bitcast the dest pointer (for aggregates). + DestField = new BitCastInst(DestField, + PointerType::getUnqual(EltVal->getType()), + "", SI); + } + new StoreInst(EltVal, DestField, SI); + + if (TD->isBigEndian()) + Shift -= ElementOffset; + else + Shift += ElementOffset; + } + } + + SI->eraseFromParent(); +} + +/// RewriteLoadUserOfWholeAlloca - We found an load of the entire allocation to +/// an integer. Load the individual pieces to form the aggregate value. +void SROA::RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocationInst *AI, + SmallVector<AllocaInst*, 32> &NewElts) { + // Extract each element out of the NewElts according to its structure offset + // and form the result value. + const Type *AllocaEltTy = AI->getType()->getElementType(); + uint64_t AllocaSizeBits = TD->getTypeAllocSizeInBits(AllocaEltTy); + + // If this isn't a load of the whole alloca to an integer, it may be a load + // of the first element. Just ignore the load in this case and normal SROA + // will handle it. 
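+  // For example (illustrative): %v = load i64* %p, where %p is a bitcast of
+  // a { i32, i32 } alloca, is rebuilt below by loading both new allocas and
+  // zero-extending, shifting, and or'ing the pieces together.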
+ if (!isa<IntegerType>(LI->getType()) || + TD->getTypeAllocSizeInBits(LI->getType()) != AllocaSizeBits) + return; + + DOUT << "PROMOTING LOAD OF WHOLE ALLOCA: " << *AI << *LI; + + // There are two forms here: AI could be an array or struct. Both cases + // have different ways to compute the element offset. + const StructLayout *Layout = 0; + uint64_t ArrayEltBitOffset = 0; + if (const StructType *EltSTy = dyn_cast<StructType>(AllocaEltTy)) { + Layout = TD->getStructLayout(EltSTy); + } else { + const Type *ArrayEltTy = cast<ArrayType>(AllocaEltTy)->getElementType(); + ArrayEltBitOffset = TD->getTypeAllocSizeInBits(ArrayEltTy); + } + + Value *ResultVal = Constant::getNullValue(IntegerType::get(AllocaSizeBits)); + + for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { + // Load the value from the alloca. If the NewElt is an aggregate, cast + // the pointer to an integer of the same size before doing the load. + Value *SrcField = NewElts[i]; + const Type *FieldTy = + cast<PointerType>(SrcField->getType())->getElementType(); + uint64_t FieldSizeBits = TD->getTypeSizeInBits(FieldTy); + + // Ignore zero sized fields like {}, they obviously contain no data. + if (FieldSizeBits == 0) continue; + + const IntegerType *FieldIntTy = IntegerType::get(FieldSizeBits); + if (!isa<IntegerType>(FieldTy) && !FieldTy->isFloatingPoint() && + !isa<VectorType>(FieldTy)) + SrcField = new BitCastInst(SrcField, PointerType::getUnqual(FieldIntTy), + "", LI); + SrcField = new LoadInst(SrcField, "sroa.load.elt", LI); + + // If SrcField is a fp or vector of the right size but that isn't an + // integer type, bitcast to an integer so we can shift it. + if (SrcField->getType() != FieldIntTy) + SrcField = new BitCastInst(SrcField, FieldIntTy, "", LI); + + // Zero extend the field to be the same size as the final alloca so that + // we can shift and insert it. + if (SrcField->getType() != ResultVal->getType()) + SrcField = new ZExtInst(SrcField, ResultVal->getType(), "", LI); + + // Determine the number of bits to shift SrcField. + uint64_t Shift; + if (Layout) // Struct case. + Shift = Layout->getElementOffsetInBits(i); + else // Array case. + Shift = i*ArrayEltBitOffset; + + if (TD->isBigEndian()) + Shift = AllocaSizeBits-Shift-FieldIntTy->getBitWidth(); + + if (Shift) { + Value *ShiftVal = ConstantInt::get(SrcField->getType(), Shift); + SrcField = BinaryOperator::CreateShl(SrcField, ShiftVal, "", LI); + } + + ResultVal = BinaryOperator::CreateOr(SrcField, ResultVal, "", LI); + } + + // Handle tail padding by truncating the result + if (TD->getTypeSizeInBits(LI->getType()) != AllocaSizeBits) + ResultVal = new TruncInst(ResultVal, LI->getType(), "", LI); + + LI->replaceAllUsesWith(ResultVal); + LI->eraseFromParent(); +} + + +/// HasPadding - Return true if the specified type has any structure or +/// alignment padding, false otherwise. +static bool HasPadding(const Type *Ty, const TargetData &TD) { + if (const StructType *STy = dyn_cast<StructType>(Ty)) { + const StructLayout *SL = TD.getStructLayout(STy); + unsigned PrevFieldBitOffset = 0; + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + unsigned FieldBitOffset = SL->getElementOffsetInBits(i); + + // Padding in sub-elements? + if (HasPadding(STy->getElementType(i), TD)) + return true; + + // Check to see if there is any padding between this element and the + // previous one. 
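+      // (For instance -- an illustrative assumption, the real layout comes
+      // from TargetData -- {i8, i32} typically leaves 24 bits of padding
+      // between its two fields.)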
+      if (i) {
+        unsigned PrevFieldEnd =
+          PrevFieldBitOffset+TD.getTypeSizeInBits(STy->getElementType(i-1));
+        if (PrevFieldEnd < FieldBitOffset)
+          return true;
+      }
+
+      PrevFieldBitOffset = FieldBitOffset;
+    }
+
+    // Check for tail padding.
+    if (unsigned EltCount = STy->getNumElements()) {
+      unsigned PrevFieldEnd = PrevFieldBitOffset +
+                   TD.getTypeSizeInBits(STy->getElementType(EltCount-1));
+      if (PrevFieldEnd < SL->getSizeInBits())
+        return true;
+    }
+
+  } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
+    return HasPadding(ATy->getElementType(), TD);
+  } else if (const VectorType *VTy = dyn_cast<VectorType>(Ty)) {
+    return HasPadding(VTy->getElementType(), TD);
+  }
+  return TD.getTypeSizeInBits(Ty) != TD.getTypeAllocSizeInBits(Ty);
+}
+
+/// isSafeAllocaToScalarRepl - Check to see if the specified allocation of an
+/// aggregate can be broken down into elements.  Return 0 if not, 3 if safe,
+/// or 1 if safe after canonicalization has been performed.
+///
+int SROA::isSafeAllocaToScalarRepl(AllocationInst *AI) {
+  // Loop over the use list of the alloca.  We can only transform it if all of
+  // the users are safe to transform.
+  AllocaInfo Info;
+
+  for (Value::use_iterator I = AI->use_begin(), E = AI->use_end();
+       I != E; ++I) {
+    isSafeUseOfAllocation(cast<Instruction>(*I), AI, Info);
+    if (Info.isUnsafe) {
+      DOUT << "Cannot transform: " << *AI << " due to user: " << **I;
+      return 0;
+    }
+  }
+
+  // Okay, we know all the users are promotable.  If the aggregate is a memcpy
+  // source and destination, we have to be careful.  In particular, the memcpy
+  // could be moving around elements that live in structure padding of the LLVM
+  // types, but may actually be used.  In these cases, we refuse to promote the
+  // struct.
+  if (Info.isMemCpySrc && Info.isMemCpyDst &&
+      HasPadding(AI->getType()->getElementType(), *TD))
+    return 0;
+
+  // If we require cleanup, return 1, otherwise return 3.
+  return Info.needsCleanup ? 1 : 3;
+}
+
+/// CleanupGEP - GEP is used by an Alloca, which can be promoted after the GEP
+/// is canonicalized here.
+void SROA::CleanupGEP(GetElementPtrInst *GEPI) {
+  gep_type_iterator I = gep_type_begin(GEPI);
+  ++I;
+
+  const ArrayType *AT = dyn_cast<ArrayType>(*I);
+  if (!AT)
+    return;
+
+  uint64_t NumElements = AT->getNumElements();
+
+  if (isa<ConstantInt>(I.getOperand()))
+    return;
+
+  if (NumElements == 1) {
+    GEPI->setOperand(2, Constant::getNullValue(Type::Int32Ty));
+    return;
+  }
+
+  assert(NumElements == 2 && "Unhandled case!");
+  // All users of the GEP must be loads.  At each use of the GEP, insert
+  // two loads of the appropriate indexed GEP and select between them.
+  Value *IsOne = new ICmpInst(ICmpInst::ICMP_NE, I.getOperand(),
+                              Constant::getNullValue(I.getOperand()->getType()),
+                              "isone", GEPI);
+  // Insert the new GEP instructions, which are properly indexed.
+  SmallVector<Value*, 8> Indices(GEPI->op_begin()+1, GEPI->op_end());
+  Indices[1] = Constant::getNullValue(Type::Int32Ty);
+  Value *ZeroIdx = GetElementPtrInst::Create(GEPI->getOperand(0),
+                                             Indices.begin(),
+                                             Indices.end(),
+                                             GEPI->getName()+".0", GEPI);
+  Indices[1] = ConstantInt::get(Type::Int32Ty, 1);
+  Value *OneIdx = GetElementPtrInst::Create(GEPI->getOperand(0),
+                                            Indices.begin(),
+                                            Indices.end(),
+                                            GEPI->getName()+".1", GEPI);
+  // Replace all loads of the variable index GEP with loads from both
+  // indexes and a select.
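+  // (Sketch of the rewrite, added for illustration: a load of A[i] with i
+  // known to be 0 or 1 becomes
+  //    %v0 = load A[0]
+  //    %v1 = load A[1]
+  //    %r  = select (i != 0), %v1, %v0
+  // which sroa+mem2reg can then promote.)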
+  while (!GEPI->use_empty()) {
+    LoadInst *LI = cast<LoadInst>(GEPI->use_back());
+    Value *Zero = new LoadInst(ZeroIdx, LI->getName()+".0", LI);
+    Value *One  = new LoadInst(OneIdx , LI->getName()+".1", LI);
+    Value *R = SelectInst::Create(IsOne, One, Zero, LI->getName(), LI);
+    LI->replaceAllUsesWith(R);
+    LI->eraseFromParent();
+  }
+  GEPI->eraseFromParent();
+}
+
+
+/// CleanupAllocaUsers - If SROA reported that it can promote the specified
+/// allocation, but only if cleaned up, perform the cleanups required.
+void SROA::CleanupAllocaUsers(AllocationInst *AI) {
+  // At this point, we know that the end result will be SROA'd and promoted, so
+  // we can insert ugly code if required so long as sroa+mem2reg will clean it
+  // up.
+  for (Value::use_iterator UI = AI->use_begin(), E = AI->use_end();
+       UI != E; ) {
+    User *U = *UI++;
+    if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(U))
+      CleanupGEP(GEPI);
+    else if (Instruction *I = dyn_cast<Instruction>(U)) {
+      SmallVector<DbgInfoIntrinsic *, 2> DbgInUses;
+      if (!isa<StoreInst>(I) && OnlyUsedByDbgInfoIntrinsics(I, &DbgInUses)) {
+        // Safe to remove debug info uses.
+        while (!DbgInUses.empty()) {
+          DbgInfoIntrinsic *DI = DbgInUses.back(); DbgInUses.pop_back();
+          DI->eraseFromParent();
+        }
+        I->eraseFromParent();
+      }
+    }
+  }
+}
+
+/// MergeInType - Add the 'In' type to the accumulated type (VecTy) so far at
+/// the offset specified by Offset (which is specified in bytes).
+///
+/// There are two cases we handle here:
+///   1) A union of vector types of the same size and potentially its elements.
+///      Here we turn element accesses into insert/extract element operations.
+///      This promotes a <4 x float> with a store of float to the third element
+///      into a <4 x float> that uses insert element.
+///   2) A fully general blob of memory, which we turn into some (potentially
+///      large) integer type with extract and insert operations where the loads
+///      and stores would mutate the memory.
+static void MergeInType(const Type *In, uint64_t Offset, const Type *&VecTy,
+                        unsigned AllocaSize, const TargetData &TD) {
+  // If this could be contributing to a vector, analyze it.
+  if (VecTy != Type::VoidTy) { // either null or a vector type.
+
+    // If the In type is a vector that is the same size as the alloca, see if
+    // it matches the existing VecTy.
+    if (const VectorType *VInTy = dyn_cast<VectorType>(In)) {
+      if (VInTy->getBitWidth()/8 == AllocaSize && Offset == 0) {
+        // If we're storing/loading a vector of the right size, allow it as a
+        // vector.  If this is the first vector we see, remember the type so
+        // that we know the element size.
+        if (VecTy == 0)
+          VecTy = VInTy;
+        return;
+      }
+    } else if (In == Type::FloatTy || In == Type::DoubleTy ||
+               (isa<IntegerType>(In) && In->getPrimitiveSizeInBits() >= 8 &&
+                isPowerOf2_32(In->getPrimitiveSizeInBits()))) {
+      // If we're accessing something that could be an element of a vector, see
+      // if the implied vector agrees with what we already have and if Offset
+      // is compatible with it.
+      unsigned EltSize = In->getPrimitiveSizeInBits()/8;
+      if (Offset % EltSize == 0 &&
+          AllocaSize % EltSize == 0 &&
+          (VecTy == 0 ||
+           cast<VectorType>(VecTy)->getElementType()
+                 ->getPrimitiveSizeInBits()/8 == EltSize)) {
+        if (VecTy == 0)
+          VecTy = VectorType::get(In, AllocaSize/EltSize);
+        return;
+      }
+    }
+  }
+
+  // Otherwise, we have a case that we can't handle with an optimized vector
+  // form.  We can still turn this into a large integer.
+  VecTy = Type::VoidTy;
+}
+
+/// CanConvertToScalar - V is a pointer.
+/// If we can convert the pointee and all its accesses to use a single vector
+/// type, return true, and set VecTy to the new type.  If we could convert the
+/// alloca into a single promotable integer, return true but set VecTy to
+/// VoidTy.  Further, if the use is not a completely trivial use that mem2reg
+/// could promote, set IsNotTrivial.  Offset is the current offset from the
+/// base of the alloca being analyzed.
+///
+/// If we see at least one access to the value that is a vector type, set the
+/// SawVec flag.
+///
+bool SROA::CanConvertToScalar(Value *V, bool &IsNotTrivial, const Type *&VecTy,
+                              bool &SawVec, uint64_t Offset,
+                              unsigned AllocaSize) {
+  for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI!=E; ++UI) {
+    Instruction *User = cast<Instruction>(*UI);
+
+    if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
+      // Don't break volatile loads.
+      if (LI->isVolatile())
+        return false;
+      MergeInType(LI->getType(), Offset, VecTy, AllocaSize, *TD);
+      SawVec |= isa<VectorType>(LI->getType());
+      continue;
+    }
+
+    if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+      // Storing the pointer, not into the value?
+      if (SI->getOperand(0) == V || SI->isVolatile()) return false;
+      MergeInType(SI->getOperand(0)->getType(), Offset, VecTy, AllocaSize, *TD);
+      SawVec |= isa<VectorType>(SI->getOperand(0)->getType());
+      continue;
+    }
+
+    if (BitCastInst *BCI = dyn_cast<BitCastInst>(User)) {
+      if (!CanConvertToScalar(BCI, IsNotTrivial, VecTy, SawVec, Offset,
+                              AllocaSize))
+        return false;
+      IsNotTrivial = true;
+      continue;
+    }
+
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
+      // If this is a GEP with variable indices, we can't handle it.
+      if (!GEP->hasAllConstantIndices())
+        return false;
+
+      // Compute the offset that this GEP adds to the pointer.
+      SmallVector<Value*, 8> Indices(GEP->op_begin()+1, GEP->op_end());
+      uint64_t GEPOffset = TD->getIndexedOffset(GEP->getOperand(0)->getType(),
+                                                &Indices[0], Indices.size());
+      // See if all uses can be converted.
+      if (!CanConvertToScalar(GEP, IsNotTrivial, VecTy, SawVec,Offset+GEPOffset,
+                              AllocaSize))
+        return false;
+      IsNotTrivial = true;
+      continue;
+    }
+
+    // If this is a constant sized memset of a constant value (e.g. 0) we can
+    // handle it.
+    if (MemSetInst *MSI = dyn_cast<MemSetInst>(User)) {
+      // Store of constant value and constant size.
+      if (isa<ConstantInt>(MSI->getValue()) &&
+          isa<ConstantInt>(MSI->getLength())) {
+        IsNotTrivial = true;
+        continue;
+      }
+    }
+
+    // If this is a memcpy or memmove into or out of the whole allocation, we
+    // can handle it like a load or store of the scalar type.
+    if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(User)) {
+      if (ConstantInt *Len = dyn_cast<ConstantInt>(MTI->getLength()))
+        if (Len->getZExtValue() == AllocaSize && Offset == 0) {
+          IsNotTrivial = true;
+          continue;
+        }
+    }
+
+    // Ignore dbg intrinsics.
+    if (isa<DbgInfoIntrinsic>(User))
+      continue;
+
+    // Otherwise, we cannot handle this!
+    return false;
+  }
+
+  return true;
+}
+
+
+/// ConvertUsesToScalar - Convert all of the users of Ptr to use the new alloca
+/// directly.  This happens when we are converting an "integer union" to a
+/// single integer scalar, or when we are converting a "vector union" to a
+/// vector with insert/extractelement instructions.
+///
+/// Offset is an offset from the original alloca, in bits that need to be
+/// shifted to the right.  By the end of this, there should be no uses of Ptr.
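+///
+/// (Illustrative example added here, not part of the original comment: for a
+/// <4 x float> alloca, a store of a float at byte offset 8 becomes a load of
+/// the vector, an insertelement at index 2, and a store of the updated
+/// vector.)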
+void SROA::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset) { + while (!Ptr->use_empty()) { + Instruction *User = cast<Instruction>(Ptr->use_back()); + + if (BitCastInst *CI = dyn_cast<BitCastInst>(User)) { + ConvertUsesToScalar(CI, NewAI, Offset); + CI->eraseFromParent(); + continue; + } + + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) { + // Compute the offset that this GEP adds to the pointer. + SmallVector<Value*, 8> Indices(GEP->op_begin()+1, GEP->op_end()); + uint64_t GEPOffset = TD->getIndexedOffset(GEP->getOperand(0)->getType(), + &Indices[0], Indices.size()); + ConvertUsesToScalar(GEP, NewAI, Offset+GEPOffset*8); + GEP->eraseFromParent(); + continue; + } + + IRBuilder<> Builder(User->getParent(), User); + + if (LoadInst *LI = dyn_cast<LoadInst>(User)) { + // The load is a bit extract from NewAI shifted right by Offset bits. + Value *LoadedVal = Builder.CreateLoad(NewAI, "tmp"); + Value *NewLoadVal + = ConvertScalar_ExtractValue(LoadedVal, LI->getType(), Offset, Builder); + LI->replaceAllUsesWith(NewLoadVal); + LI->eraseFromParent(); + continue; + } + + if (StoreInst *SI = dyn_cast<StoreInst>(User)) { + assert(SI->getOperand(0) != Ptr && "Consistency error!"); + Value *Old = Builder.CreateLoad(NewAI, (NewAI->getName()+".in").c_str()); + Value *New = ConvertScalar_InsertValue(SI->getOperand(0), Old, Offset, + Builder); + Builder.CreateStore(New, NewAI); + SI->eraseFromParent(); + continue; + } + + // If this is a constant sized memset of a constant value (e.g. 0) we can + // transform it into a store of the expanded constant value. + if (MemSetInst *MSI = dyn_cast<MemSetInst>(User)) { + assert(MSI->getRawDest() == Ptr && "Consistency error!"); + unsigned NumBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue(); + if (NumBytes != 0) { + unsigned Val = cast<ConstantInt>(MSI->getValue())->getZExtValue(); + + // Compute the value replicated the right number of times. + APInt APVal(NumBytes*8, Val); + + // Splat the value if non-zero. + if (Val) + for (unsigned i = 1; i != NumBytes; ++i) + APVal |= APVal << 8; + + Value *Old = Builder.CreateLoad(NewAI, (NewAI->getName()+".in").c_str()); + Value *New = ConvertScalar_InsertValue(ConstantInt::get(APVal), Old, + Offset, Builder); + Builder.CreateStore(New, NewAI); + } + MSI->eraseFromParent(); + continue; + } + + // If this is a memcpy or memmove into or out of the whole allocation, we + // can handle it like a load or store of the scalar type. + if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(User)) { + assert(Offset == 0 && "must be store to start of alloca"); + + // If the source and destination are both to the same alloca, then this is + // a noop copy-to-self, just delete it. Otherwise, emit a load and store + // as appropriate. + AllocaInst *OrigAI = cast<AllocaInst>(Ptr->getUnderlyingObject()); + + if (MTI->getSource()->getUnderlyingObject() != OrigAI) { + // Dest must be OrigAI, change this to be a load from the original + // pointer (bitcasted), then a store to our new alloca. + assert(MTI->getRawDest() == Ptr && "Neither use is of pointer?"); + Value *SrcPtr = MTI->getSource(); + SrcPtr = Builder.CreateBitCast(SrcPtr, NewAI->getType()); + + LoadInst *SrcVal = Builder.CreateLoad(SrcPtr, "srcval"); + SrcVal->setAlignment(MTI->getAlignment()); + Builder.CreateStore(SrcVal, NewAI); + } else if (MTI->getDest()->getUnderlyingObject() != OrigAI) { + // Src must be OrigAI, change this to be a load from NewAI then a store + // through the original dest pointer (bitcasted). 
+ assert(MTI->getRawSource() == Ptr && "Neither use is of pointer?"); + LoadInst *SrcVal = Builder.CreateLoad(NewAI, "srcval"); + + Value *DstPtr = Builder.CreateBitCast(MTI->getDest(), NewAI->getType()); + StoreInst *NewStore = Builder.CreateStore(SrcVal, DstPtr); + NewStore->setAlignment(MTI->getAlignment()); + } else { + // Noop transfer. Src == Dst + } + + + MTI->eraseFromParent(); + continue; + } + + // If user is a dbg info intrinsic then it is safe to remove it. + if (isa<DbgInfoIntrinsic>(User)) { + User->eraseFromParent(); + continue; + } + + assert(0 && "Unsupported operation!"); + abort(); + } +} + +/// ConvertScalar_ExtractValue - Extract a value of type ToType from an integer +/// or vector value FromVal, extracting the bits from the offset specified by +/// Offset. This returns the value, which is of type ToType. +/// +/// This happens when we are converting an "integer union" to a single +/// integer scalar, or when we are converting a "vector union" to a vector with +/// insert/extractelement instructions. +/// +/// Offset is an offset from the original alloca, in bits that need to be +/// shifted to the right. +Value *SROA::ConvertScalar_ExtractValue(Value *FromVal, const Type *ToType, + uint64_t Offset, IRBuilder<> &Builder) { + // If the load is of the whole new alloca, no conversion is needed. + if (FromVal->getType() == ToType && Offset == 0) + return FromVal; + + // If the result alloca is a vector type, this is either an element + // access or a bitcast to another vector type of the same size. + if (const VectorType *VTy = dyn_cast<VectorType>(FromVal->getType())) { + if (isa<VectorType>(ToType)) + return Builder.CreateBitCast(FromVal, ToType, "tmp"); + + // Otherwise it must be an element access. + unsigned Elt = 0; + if (Offset) { + unsigned EltSize = TD->getTypeAllocSizeInBits(VTy->getElementType()); + Elt = Offset/EltSize; + assert(EltSize*Elt == Offset && "Invalid modulus in validity checking"); + } + // Return the element extracted out of it. + Value *V = Builder.CreateExtractElement(FromVal, + ConstantInt::get(Type::Int32Ty,Elt), + "tmp"); + if (V->getType() != ToType) + V = Builder.CreateBitCast(V, ToType, "tmp"); + return V; + } + + // If ToType is a first class aggregate, extract out each of the pieces and + // use insertvalue's to form the FCA. + if (const StructType *ST = dyn_cast<StructType>(ToType)) { + const StructLayout &Layout = *TD->getStructLayout(ST); + Value *Res = UndefValue::get(ST); + for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { + Value *Elt = ConvertScalar_ExtractValue(FromVal, ST->getElementType(i), + Offset+Layout.getElementOffsetInBits(i), + Builder); + Res = Builder.CreateInsertValue(Res, Elt, i, "tmp"); + } + return Res; + } + + if (const ArrayType *AT = dyn_cast<ArrayType>(ToType)) { + uint64_t EltSize = TD->getTypeAllocSizeInBits(AT->getElementType()); + Value *Res = UndefValue::get(AT); + for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) { + Value *Elt = ConvertScalar_ExtractValue(FromVal, AT->getElementType(), + Offset+i*EltSize, Builder); + Res = Builder.CreateInsertValue(Res, Elt, i, "tmp"); + } + return Res; + } + + // Otherwise, this must be a union that was converted to an integer value. + const IntegerType *NTy = cast<IntegerType>(FromVal->getType()); + + // If this is a big-endian system and the load is narrower than the + // full alloca type, we need to do a shift to get the right bits. 
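+  // (Worked example, assumed for illustration: extracting an i8 at bit
+  // offset 16 from an i32 union gives ShAmt = Offset = 16 on little-endian,
+  // and ShAmt = 32 - 8 - 16 = 8 on big-endian, before the final truncate.)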
+  int ShAmt = 0;
+  if (TD->isBigEndian()) {
+    // On big-endian machines, the lowest bit is stored at the bit offset
+    // from the pointer given by getTypeStoreSizeInBits.  This matters for
+    // integers with a bitwidth that is not a multiple of 8.
+    ShAmt = TD->getTypeStoreSizeInBits(NTy) -
+            TD->getTypeStoreSizeInBits(ToType) - Offset;
+  } else {
+    ShAmt = Offset;
+  }
+
+  // Note: we support negative shift amounts (with shl) which are not defined.
+  // We do this to support (f.e.) loads off the end of a structure where
+  // only some bits are used.
+  if (ShAmt > 0 && (unsigned)ShAmt < NTy->getBitWidth())
+    FromVal = Builder.CreateLShr(FromVal, ConstantInt::get(FromVal->getType(),
+                                                           ShAmt), "tmp");
+  else if (ShAmt < 0 && (unsigned)-ShAmt < NTy->getBitWidth())
+    FromVal = Builder.CreateShl(FromVal, ConstantInt::get(FromVal->getType(),
+                                                          -ShAmt), "tmp");
+
+  // Finally, unconditionally truncate the integer to the right width.
+  unsigned LIBitWidth = TD->getTypeSizeInBits(ToType);
+  if (LIBitWidth < NTy->getBitWidth())
+    FromVal = Builder.CreateTrunc(FromVal, IntegerType::get(LIBitWidth), "tmp");
+  else if (LIBitWidth > NTy->getBitWidth())
+    FromVal = Builder.CreateZExt(FromVal, IntegerType::get(LIBitWidth), "tmp");
+
+  // If the result is an integer, this is a trunc or bitcast.
+  if (isa<IntegerType>(ToType)) {
+    // Should be done.
+  } else if (ToType->isFloatingPoint() || isa<VectorType>(ToType)) {
+    // Just do a bitcast, we know the sizes match up.
+    FromVal = Builder.CreateBitCast(FromVal, ToType, "tmp");
+  } else {
+    // Otherwise must be a pointer.
+    FromVal = Builder.CreateIntToPtr(FromVal, ToType, "tmp");
+  }
+  assert(FromVal->getType() == ToType && "Didn't convert right?");
+  return FromVal;
+}
+
+
+/// ConvertScalar_InsertValue - Insert the value "SV" into the existing integer
+/// or vector value "Old" at the offset specified by Offset.
+///
+/// This happens when we are converting an "integer union" to a
+/// single integer scalar, or when we are converting a "vector union" to a
+/// vector with insert/extractelement instructions.
+///
+/// Offset is an offset from the original alloca, in bits that need to be
+/// shifted to the right.
+Value *SROA::ConvertScalar_InsertValue(Value *SV, Value *Old,
+                                       uint64_t Offset, IRBuilder<> &Builder) {
+
+  // Convert the stored type to the actual type, shift it left to insert
+  // then 'or' into place.
+  const Type *AllocaType = Old->getType();
+
+  if (const VectorType *VTy = dyn_cast<VectorType>(AllocaType)) {
+    uint64_t VecSize = TD->getTypeAllocSizeInBits(VTy);
+    uint64_t ValSize = TD->getTypeAllocSizeInBits(SV->getType());
+
+    // Changing the whole vector with memset or with an access of a different
+    // vector type?
+    if (ValSize == VecSize)
+      return Builder.CreateBitCast(SV, AllocaType, "tmp");
+
+    uint64_t EltSize = TD->getTypeAllocSizeInBits(VTy->getElementType());
+
+    // Must be an element insertion.
+    unsigned Elt = Offset/EltSize;
+
+    if (SV->getType() != VTy->getElementType())
+      SV = Builder.CreateBitCast(SV, VTy->getElementType(), "tmp");
+
+    SV = Builder.CreateInsertElement(Old, SV,
+                                     ConstantInt::get(Type::Int32Ty, Elt),
+                                     "tmp");
+    return SV;
+  }
+
+  // If SV is a first-class aggregate value, insert each value recursively.
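+  // (Illustrative: a store of {i32, float} is split with extractvalue below,
+  // and each member is recursively inserted at its struct-layout bit offset.)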
+  if (const StructType *ST = dyn_cast<StructType>(SV->getType())) {
+    const StructLayout &Layout = *TD->getStructLayout(ST);
+    for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
+      Value *Elt = Builder.CreateExtractValue(SV, i, "tmp");
+      Old = ConvertScalar_InsertValue(Elt, Old,
+                                      Offset+Layout.getElementOffsetInBits(i),
+                                      Builder);
+    }
+    return Old;
+  }
+
+  if (const ArrayType *AT = dyn_cast<ArrayType>(SV->getType())) {
+    uint64_t EltSize = TD->getTypeAllocSizeInBits(AT->getElementType());
+    for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) {
+      Value *Elt = Builder.CreateExtractValue(SV, i, "tmp");
+      Old = ConvertScalar_InsertValue(Elt, Old, Offset+i*EltSize, Builder);
+    }
+    return Old;
+  }
+
+  // If SV is a float, convert it to the appropriate integer type.
+  // If it is a pointer, do the same.
+  unsigned SrcWidth = TD->getTypeSizeInBits(SV->getType());
+  unsigned DestWidth = TD->getTypeSizeInBits(AllocaType);
+  unsigned SrcStoreWidth = TD->getTypeStoreSizeInBits(SV->getType());
+  unsigned DestStoreWidth = TD->getTypeStoreSizeInBits(AllocaType);
+  if (SV->getType()->isFloatingPoint() || isa<VectorType>(SV->getType()))
+    SV = Builder.CreateBitCast(SV, IntegerType::get(SrcWidth), "tmp");
+  else if (isa<PointerType>(SV->getType()))
+    SV = Builder.CreatePtrToInt(SV, TD->getIntPtrType(), "tmp");
+
+  // Zero extend or truncate the value if needed.
+  if (SV->getType() != AllocaType) {
+    if (SV->getType()->getPrimitiveSizeInBits() <
+        AllocaType->getPrimitiveSizeInBits())
+      SV = Builder.CreateZExt(SV, AllocaType, "tmp");
+    else {
+      // Truncation may be needed if storing more than the alloca can hold
+      // (undefined behavior).
+      SV = Builder.CreateTrunc(SV, AllocaType, "tmp");
+      SrcWidth = DestWidth;
+      SrcStoreWidth = DestStoreWidth;
+    }
+  }
+
+  // If this is a big-endian system and the store is narrower than the
+  // full alloca type, we need to do a shift to get the right bits.
+  int ShAmt = 0;
+  if (TD->isBigEndian()) {
+    // On big-endian machines, the lowest bit is stored at the bit offset
+    // from the pointer given by getTypeStoreSizeInBits.  This matters for
+    // integers with a bitwidth that is not a multiple of 8.
+    ShAmt = DestStoreWidth - SrcStoreWidth - Offset;
+  } else {
+    ShAmt = Offset;
+  }
+
+  // Note: we support negative shift amounts (with shr) which are not defined.
+  // We do this to support (f.e.) stores off the end of a structure where
+  // only some bits in the structure are set.
+  APInt Mask(APInt::getLowBitsSet(DestWidth, SrcWidth));
+  if (ShAmt > 0 && (unsigned)ShAmt < DestWidth) {
+    SV = Builder.CreateShl(SV, ConstantInt::get(SV->getType(), ShAmt), "tmp");
+    Mask <<= ShAmt;
+  } else if (ShAmt < 0 && (unsigned)-ShAmt < DestWidth) {
+    SV = Builder.CreateLShr(SV, ConstantInt::get(SV->getType(), -ShAmt), "tmp");
+    Mask = Mask.lshr(-ShAmt);
+  }
+
+  // Mask out the bits we are about to insert from the old value, and or
+  // in the new bits.
+  if (SrcWidth != DestWidth) {
+    assert(DestWidth > SrcWidth);
+    Old = Builder.CreateAnd(Old, ConstantInt::get(~Mask), "mask");
+    SV = Builder.CreateOr(Old, SV, "ins");
+  }
+  return SV;
+}
+
+
+
+/// PointsToConstantGlobal - Return true if V (possibly indirectly) points to
+/// some part of a constant global variable.  This intentionally only accepts
+/// constant expressions because we can't rewrite arbitrary instructions.
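+/// (For example -- an illustration, not original text -- a getelementptr
+/// ConstantExpr on a constant GlobalVariable is looked through, while the
+/// equivalent GEP instruction is rejected.)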
+static bool PointsToConstantGlobal(Value *V) {
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+    return GV->isConstant();
+  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
+    if (CE->getOpcode() == Instruction::BitCast ||
+        CE->getOpcode() == Instruction::GetElementPtr)
+      return PointsToConstantGlobal(CE->getOperand(0));
+  return false;
+}
+
+/// isOnlyCopiedFromConstantGlobal - Recursively walk the uses of a (derived)
+/// pointer to an alloca.  Ignore any reads of the pointer, return false if we
+/// see any stores or other unknown uses.  If we see pointer arithmetic, keep
+/// track of whether it moves the pointer (with isOffset) but otherwise
+/// traverse the uses.  If we see a memcpy/memmove that targets an unoffseted
+/// pointer to the alloca, and if the source pointer is a pointer to a constant
+/// global, we can optimize this.
+static bool isOnlyCopiedFromConstantGlobal(Value *V, Instruction *&TheCopy,
+                                           bool isOffset) {
+  for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI!=E; ++UI) {
+    if (LoadInst *LI = dyn_cast<LoadInst>(*UI))
+      // Ignore non-volatile loads, they are always ok.
+      if (!LI->isVolatile())
+        continue;
+
+    if (BitCastInst *BCI = dyn_cast<BitCastInst>(*UI)) {
+      // If uses of the bitcast are ok, we are ok.
+      if (!isOnlyCopiedFromConstantGlobal(BCI, TheCopy, isOffset))
+        return false;
+      continue;
+    }
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(*UI)) {
+      // If the GEP has all zero indices, it doesn't offset the pointer.  If it
+      // doesn't, it does.
+      if (!isOnlyCopiedFromConstantGlobal(GEP, TheCopy,
+                                          isOffset || !GEP->hasAllZeroIndices()))
+        return false;
+      continue;
+    }
+
+    // If this isn't our memcpy/memmove, reject it as something we can't
+    // handle.
+    if (!isa<MemTransferInst>(*UI))
+      return false;
+
+    // If we already have seen a copy, reject the second one.
+    if (TheCopy) return false;
+
+    // If the pointer has been offset from the start of the alloca, we can't
+    // safely handle this.
+    if (isOffset) return false;
+
+    // If the memintrinsic isn't using the alloca as the dest, reject it.
+    if (UI.getOperandNo() != 1) return false;
+
+    MemIntrinsic *MI = cast<MemIntrinsic>(*UI);
+
+    // If the source of the memcpy/move is not a constant global, reject it.
+    if (!PointsToConstantGlobal(MI->getOperand(2)))
+      return false;
+
+    // Otherwise, the transform is safe.  Remember the copy instruction.
+    TheCopy = MI;
+  }
+  return true;
+}
+
+/// isOnlyCopiedFromConstantGlobal - Return the copying memcpy/memmove
+/// instruction if the specified alloca is only modified by a copy from a
+/// constant global, or null if not.  If we can prove this, we can replace any
+/// uses of the alloca with uses of the global directly.
+Instruction *SROA::isOnlyCopiedFromConstantGlobal(AllocationInst *AI) {
+  Instruction *TheCopy = 0;
+  if (::isOnlyCopiedFromConstantGlobal(AI, TheCopy, false))
+    return TheCopy;
+  return 0;
+}
diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp
new file mode 100644
index 0000000..b499279
--- /dev/null
+++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp
@@ -0,0 +1,232 @@
+//===- SimplifyCFGPass.cpp - CFG Simplification Pass ----------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// +// +// This file implements dead code elimination and basic block merging, along +// with a collection of other peephole control flow optimizations. For example: +// +// * Removes basic blocks with no predecessors. +// * Merges a basic block into its predecessor if there is only one and the +// predecessor only has one successor. +// * Eliminates PHI nodes for basic blocks with a single predecessor. +// * Eliminates a basic block that only contains an unconditional branch. +// * Changes invoke instructions to nounwind functions to be calls. +// * Change things like "if (x) if (y)" into "if (x&y)". +// * etc.. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "simplifycfg" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Attributes.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +using namespace llvm; + +STATISTIC(NumSimpl, "Number of blocks simplified"); + +namespace { + struct VISIBILITY_HIDDEN CFGSimplifyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + CFGSimplifyPass() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F); + }; +} + +char CFGSimplifyPass::ID = 0; +static RegisterPass<CFGSimplifyPass> X("simplifycfg", "Simplify the CFG"); + +// Public interface to the CFGSimplification pass +FunctionPass *llvm::createCFGSimplificationPass() { + return new CFGSimplifyPass(); +} + +/// ChangeToUnreachable - Insert an unreachable instruction before the specified +/// instruction, making it and the rest of the code in the block dead. +static void ChangeToUnreachable(Instruction *I) { + BasicBlock *BB = I->getParent(); + // Loop over all of the successors, removing BB's entry from any PHI + // nodes. + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) + (*SI)->removePredecessor(BB); + + new UnreachableInst(I); + + // All instructions after this are dead. + BasicBlock::iterator BBI = I, BBE = BB->end(); + while (BBI != BBE) { + if (!BBI->use_empty()) + BBI->replaceAllUsesWith(UndefValue::get(BBI->getType())); + BB->getInstList().erase(BBI++); + } +} + +/// ChangeToCall - Convert the specified invoke into a normal call. +static void ChangeToCall(InvokeInst *II) { + BasicBlock *BB = II->getParent(); + SmallVector<Value*, 8> Args(II->op_begin()+3, II->op_end()); + CallInst *NewCall = CallInst::Create(II->getCalledValue(), Args.begin(), + Args.end(), "", II); + NewCall->takeName(II); + NewCall->setCallingConv(II->getCallingConv()); + NewCall->setAttributes(II->getAttributes()); + II->replaceAllUsesWith(NewCall); + + // Follow the call by a branch to the normal destination. 
+ BranchInst::Create(II->getNormalDest(), II); + + // Update PHI nodes in the unwind destination + II->getUnwindDest()->removePredecessor(BB); + BB->getInstList().erase(II); +} + +static bool MarkAliveBlocks(BasicBlock *BB, + SmallPtrSet<BasicBlock*, 128> &Reachable) { + + SmallVector<BasicBlock*, 128> Worklist; + Worklist.push_back(BB); + bool Changed = false; + while (!Worklist.empty()) { + BB = Worklist.back(); + Worklist.pop_back(); + + if (!Reachable.insert(BB)) + continue; + + // Do a quick scan of the basic block, turning any obviously unreachable + // instructions into LLVM unreachable insts. The instruction combining pass + // canonicalizes unreachable insts into stores to null or undef. + for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E;++BBI){ + if (CallInst *CI = dyn_cast<CallInst>(BBI)) { + if (CI->doesNotReturn()) { + // If we found a call to a no-return function, insert an unreachable + // instruction after it. Make sure there isn't *already* one there + // though. + ++BBI; + if (!isa<UnreachableInst>(BBI)) { + ChangeToUnreachable(BBI); + Changed = true; + } + break; + } + } + + if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) + if (isa<ConstantPointerNull>(SI->getOperand(1)) || + isa<UndefValue>(SI->getOperand(1))) { + ChangeToUnreachable(SI); + Changed = true; + break; + } + } + + // Turn invokes that call 'nounwind' functions into ordinary calls. + if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) + if (II->doesNotThrow()) { + ChangeToCall(II); + Changed = true; + } + + Changed |= ConstantFoldTerminator(BB); + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) + Worklist.push_back(*SI); + } + return Changed; +} + +/// RemoveUnreachableBlocksFromFn - Remove blocks that are not reachable, even +/// if they are in a dead cycle. Return true if a change was made, false +/// otherwise. +static bool RemoveUnreachableBlocksFromFn(Function &F) { + SmallPtrSet<BasicBlock*, 128> Reachable; + bool Changed = MarkAliveBlocks(F.begin(), Reachable); + + // If there are unreachable blocks in the CFG... + if (Reachable.size() == F.size()) + return Changed; + + assert(Reachable.size() < F.size()); + NumSimpl += F.size()-Reachable.size(); + + // Loop over all of the basic blocks that are not reachable, dropping all of + // their internal references... + for (Function::iterator BB = ++F.begin(), E = F.end(); BB != E; ++BB) { + if (Reachable.count(BB)) + continue; + + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) + if (Reachable.count(*SI)) + (*SI)->removePredecessor(BB); + BB->dropAllReferences(); + } + + for (Function::iterator I = ++F.begin(); I != F.end();) + if (!Reachable.count(I)) + I = F.getBasicBlockList().erase(I); + else + ++I; + + return true; +} + +/// IterativeSimplifyCFG - Call SimplifyCFG on all the blocks in the function, +/// iterating until no more changes are made. +static bool IterativeSimplifyCFG(Function &F) { + bool Changed = false; + bool LocalChange = true; + while (LocalChange) { + LocalChange = false; + + // Loop over all of the basic blocks (except the first one) and remove them + // if they are unneeded... + // + for (Function::iterator BBIt = ++F.begin(); BBIt != F.end(); ) { + if (SimplifyCFG(BBIt++)) { + LocalChange = true; + ++NumSimpl; + } + } + Changed |= LocalChange; + } + return Changed; +} + +// It is possible that we may require multiple passes over the code to fully +// simplify the CFG. 
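+// (For instance -- an illustrative note -- IterativeSimplifyCFG can fold a
+// conditional branch and leave an entire loop unreachable; only a rerun of
+// RemoveUnreachableBlocksFromFn can then delete that dead cycle.)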
+//
+bool CFGSimplifyPass::runOnFunction(Function &F) {
+  bool EverChanged = RemoveUnreachableBlocksFromFn(F);
+  EverChanged |= IterativeSimplifyCFG(F);
+
+  // If neither pass changed anything, we're done.
+  if (!EverChanged) return false;
+
+  // IterativeSimplifyCFG can (rarely) make some loops dead.  If this happens,
+  // RemoveUnreachableBlocksFromFn is needed to nuke them, which means we
+  // should iterate between the two optimizations.  We structure the code like
+  // this to avoid rerunning IterativeSimplifyCFG if the second pass of
+  // RemoveUnreachableBlocksFromFn doesn't do anything.
+  if (!RemoveUnreachableBlocksFromFn(F))
+    return true;
+
+  do {
+    EverChanged = IterativeSimplifyCFG(F);
+    EverChanged |= RemoveUnreachableBlocksFromFn(F);
+  } while (EverChanged);
+
+  return true;
+}
diff --git a/lib/Transforms/Scalar/SimplifyHalfPowrLibCalls.cpp b/lib/Transforms/Scalar/SimplifyHalfPowrLibCalls.cpp
new file mode 100644
index 0000000..4aad17d
--- /dev/null
+++ b/lib/Transforms/Scalar/SimplifyHalfPowrLibCalls.cpp
@@ -0,0 +1,159 @@
+//===- SimplifyHalfPowrLibCalls.cpp - Optimize specific half_powr calls ---===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a simple pass that applies an experimental
+// transformation on calls to specific functions.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "simplify-libcalls-halfpowr"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Instructions.h"
+#include "llvm/Intrinsics.h"
+#include "llvm/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Config/config.h"
+using namespace llvm;
+
+namespace {
+  /// This pass optimizes calls to the half_powr family of functions.
+  ///
+  class VISIBILITY_HIDDEN SimplifyHalfPowrLibCalls : public FunctionPass {
+    const TargetData *TD;
+  public:
+    static char ID; // Pass identification
+    SimplifyHalfPowrLibCalls() : FunctionPass(&ID) {}
+
+    bool runOnFunction(Function &F);
+
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.addRequired<TargetData>();
+    }
+
+    Instruction *
+    InlineHalfPowrs(const std::vector<Instruction *> &HalfPowrs,
+                    Instruction *InsertPt);
+  };
+  char SimplifyHalfPowrLibCalls::ID = 0;
+} // end anonymous namespace.
+
+static RegisterPass<SimplifyHalfPowrLibCalls>
+X("simplify-libcalls-halfpowr", "Simplify half_powr library calls");
+
+// Public interface to the Simplify HalfPowr LibCalls pass.
+FunctionPass *llvm::createSimplifyHalfPowrLibCallsPass() {
+  return new SimplifyHalfPowrLibCalls();
+}
+
+/// InlineHalfPowrs - Inline a sequence of adjacent half_powr calls, rearranging
+/// their control flow to better facilitate subsequent optimization.
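+/// (Descriptive note added for clarity: the code below assumes each half_powr
+/// callee has a three-block shape -- prologue, subnormal handling, body --
+/// and splices the inlined bodies of adjacent calls into a single block.)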
+Instruction *
+SimplifyHalfPowrLibCalls::InlineHalfPowrs(const std::vector<Instruction *> &HalfPowrs,
+                                          Instruction *InsertPt) {
+  std::vector<BasicBlock *> Bodies;
+  BasicBlock *NewBlock = 0;
+
+  for (unsigned i = 0, e = HalfPowrs.size(); i != e; ++i) {
+    CallInst *Call = cast<CallInst>(HalfPowrs[i]);
+    Function *Callee = Call->getCalledFunction();
+
+    // Minimally sanity-check the CFG of half_powr to ensure that it contains
+    // the kind of code we expect.  If we're running this pass, we have
+    // reason to believe it will be what we expect.
+    Function::iterator I = Callee->begin();
+    BasicBlock *Prologue = I++;
+    if (I == Callee->end()) break;
+    BasicBlock *SubnormalHandling = I++;
+    if (I == Callee->end()) break;
+    BasicBlock *Body = I++;
+    if (I != Callee->end()) break;
+    if (SubnormalHandling->getSinglePredecessor() != Prologue)
+      break;
+    BranchInst *PBI = dyn_cast<BranchInst>(Prologue->getTerminator());
+    if (!PBI || !PBI->isConditional())
+      break;
+    BranchInst *SNBI = dyn_cast<BranchInst>(SubnormalHandling->getTerminator());
+    if (!SNBI || SNBI->isConditional())
+      break;
+    if (!isa<ReturnInst>(Body->getTerminator()))
+      break;
+
+    Instruction *NextInst = next(BasicBlock::iterator(Call));
+
+    // Inline the call, taking care of what code ends up where.
+    NewBlock = SplitBlock(NextInst->getParent(), NextInst, this);
+
+    bool B = InlineFunction(Call, 0, TD);
+    assert(B && "half_powr didn't inline?"); B=B;
+
+    BasicBlock *NewBody = NewBlock->getSinglePredecessor();
+    assert(NewBody);
+    Bodies.push_back(NewBody);
+  }
+
+  if (!NewBlock)
+    return InsertPt;
+
+  // Put the code for all the bodies into one block, to facilitate
+  // subsequent optimization.
+  (void)SplitEdge(NewBlock->getSinglePredecessor(), NewBlock, this);
+  for (unsigned i = 0, e = Bodies.size(); i != e; ++i) {
+    BasicBlock *Body = Bodies[i];
+    Instruction *FNP = Body->getFirstNonPHI();
+    // Splice the insts from body into NewBlock.
+    NewBlock->getInstList().splice(NewBlock->begin(), Body->getInstList(),
+                                   FNP, Body->getTerminator());
+  }
+
+  return NewBlock->begin();
+}
+
+/// runOnFunction - Top level algorithm.
+///
+bool SimplifyHalfPowrLibCalls::runOnFunction(Function &F) {
+  TD = &getAnalysis<TargetData>();
+
+  bool Changed = false;
+  std::vector<Instruction *> HalfPowrs;
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
+    for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
+      // Look for calls.
+      bool IsHalfPowr = false;
+      if (CallInst *CI = dyn_cast<CallInst>(I)) {
+        // Look for direct calls to functions with external linkage.
+        Function *Callee = CI->getCalledFunction();
+        if (Callee && Callee->hasExternalLinkage()) {
+          // Look for calls with well-known names.
+          const char *CalleeName = Callee->getNameStart();
+          if (strcmp(CalleeName, "__half_powrf4") == 0)
+            IsHalfPowr = true;
+        }
+      }
+      if (IsHalfPowr)
+        HalfPowrs.push_back(I);
+      // We're looking for sequences of up to three such calls, which we'll
+      // simplify as a group.
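+      // (Illustrative: up to three back-to-back __half_powrf4 calls are
+      // flushed here as one group so that their inlined bodies land in a
+      // single block for later passes.)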
+      if ((!IsHalfPowr && !HalfPowrs.empty()) || HalfPowrs.size() == 3) {
+        I = InlineHalfPowrs(HalfPowrs, I);
+        E = I->getParent()->end();
+        HalfPowrs.clear();
+        Changed = true;
+      }
+    }
+    assert(HalfPowrs.empty() && "Block had no terminator!");
+  }
+
+  return Changed;
+}
diff --git a/lib/Transforms/Scalar/SimplifyLibCalls.cpp b/lib/Transforms/Scalar/SimplifyLibCalls.cpp
new file mode 100644
index 0000000..4b00640
--- /dev/null
+++ b/lib/Transforms/Scalar/SimplifyLibCalls.cpp
@@ -0,0 +1,2429 @@
+//===- SimplifyLibCalls.cpp - Optimize specific well-known library calls --===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a simple pass that applies a variety of small
+// optimizations for calls to specific well-known function calls (e.g. runtime
+// library functions).  For example, a call to the function "exit(3)" that
+// occurs within the main() function can be transformed into a simple "return 3"
+// instruction.  Any optimization that takes this form (replace call to library
+// function with simpler code that provides the same result) belongs in this
+// file.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "simplify-libcalls"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Intrinsics.h"
+#include "llvm/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/IRBuilder.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Target/TargetData.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Config/config.h"
+using namespace llvm;
+
+STATISTIC(NumSimplified, "Number of library calls simplified");
+STATISTIC(NumAnnotated, "Number of attributes added to library functions");
+
+//===----------------------------------------------------------------------===//
+// Optimizer Base Class
+//===----------------------------------------------------------------------===//
+
+/// This class is the abstract base class for the set of optimizations that
+/// corresponds to one library call.
+namespace {
+class VISIBILITY_HIDDEN LibCallOptimization {
+protected:
+  Function *Caller;
+  const TargetData *TD;
+public:
+  LibCallOptimization() { }
+  virtual ~LibCallOptimization() {}
+
+  /// CallOptimizer - This pure virtual method is implemented by derived
+  /// classes to do various optimizations.  If this returns null then no
+  /// transformation was performed.  If it returns CI, then it transformed the
+  /// call and CI is to be deleted.  If it returns something else, replace CI
+  /// with the new value and delete CI.
+  virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B)
+    =0;
+
+  Value *OptimizeCall(CallInst *CI, const TargetData &TD, IRBuilder<> &B) {
+    Caller = CI->getParent()->getParent();
+    this->TD = &TD;
+    return CallOptimizer(CI->getCalledFunction(), CI, B);
+  }
+
+  /// CastToCStr - Return V if it is an i8*, otherwise cast it to i8*.
+  Value *CastToCStr(Value *V, IRBuilder<> &B);
+
+  /// EmitStrLen - Emit a call to the strlen function to the builder, for the
+  /// specified pointer.  Ptr is required to be some pointer type, and the
+  /// return value has 'intptr_t' type.
+  Value *EmitStrLen(Value *Ptr, IRBuilder<> &B);
+
+  /// EmitMemCpy - Emit a call to the memcpy function to the builder.  This
+  /// always expects that the size has type 'intptr_t' and Dst/Src are
+  /// pointers.
+  Value *EmitMemCpy(Value *Dst, Value *Src, Value *Len,
+                    unsigned Align, IRBuilder<> &B);
+
+  /// EmitMemChr - Emit a call to the memchr function.  This assumes that Ptr is
+  /// a pointer, Val is an i32 value, and Len is an 'intptr_t' value.
+  Value *EmitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B);
+
+  /// EmitMemCmp - Emit a call to the memcmp function.
+  Value *EmitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B);
+
+  /// EmitMemSet - Emit a call to the memset function.
+  Value *EmitMemSet(Value *Dst, Value *Val, Value *Len, IRBuilder<> &B);
+
+  /// EmitUnaryFloatFnCall - Emit a call to the unary function named 'Name'
+  /// (e.g. 'floor').  This function is known to take a single argument of type
+  /// matching 'Op' and returns one value with the same type.  If 'Op' is a
+  /// long double, 'l' is added as the suffix of name, if 'Op' is a float, we
+  /// add an 'f' suffix.
+  Value *EmitUnaryFloatFnCall(Value *Op, const char *Name, IRBuilder<> &B);
+
+  /// EmitPutChar - Emit a call to the putchar function.  This assumes that Char
+  /// is an integer.
+  void EmitPutChar(Value *Char, IRBuilder<> &B);
+
+  /// EmitPutS - Emit a call to the puts function.  This assumes that Str is
+  /// some pointer.
+  void EmitPutS(Value *Str, IRBuilder<> &B);
+
+  /// EmitFPutC - Emit a call to the fputc function.  This assumes that Char is
+  /// an i32, and File is a pointer to FILE.
+  void EmitFPutC(Value *Char, Value *File, IRBuilder<> &B);
+
+  /// EmitFPutS - Emit a call to the fputs function.  Str is required to be a
+  /// pointer and File is a pointer to FILE.
+  void EmitFPutS(Value *Str, Value *File, IRBuilder<> &B);
+
+  /// EmitFWrite - Emit a call to the fwrite function.  This assumes that Ptr is
+  /// a pointer, Size is an 'intptr_t', and File is a pointer to FILE.
+  void EmitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B);
+
+};
+} // End anonymous namespace.
+
+/// CastToCStr - Return V if it is an i8*, otherwise cast it to i8*.
+Value *LibCallOptimization::CastToCStr(Value *V, IRBuilder<> &B) {
+  return B.CreateBitCast(V, PointerType::getUnqual(Type::Int8Ty), "cstr");
+}
+
+/// EmitStrLen - Emit a call to the strlen function to the builder, for the
+/// specified pointer.  This always returns an integer value of size intptr_t.
+Value *LibCallOptimization::EmitStrLen(Value *Ptr, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  AttributeWithIndex AWI[2];
+  AWI[0] = AttributeWithIndex::get(1, Attribute::NoCapture);
+  AWI[1] = AttributeWithIndex::get(~0u, Attribute::ReadOnly |
+                                   Attribute::NoUnwind);
+
+  Constant *StrLen = M->getOrInsertFunction("strlen", AttrListPtr::get(AWI, 2),
+                                            TD->getIntPtrType(),
+                                            PointerType::getUnqual(Type::Int8Ty),
+                                            NULL);
+  return B.CreateCall(StrLen, CastToCStr(Ptr, B), "strlen");
+}
+
+/// EmitMemCpy - Emit a call to the memcpy function to the builder.  This
+/// always expects that the size has type 'intptr_t' and Dst/Src are pointers.
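+/// (Usage sketch, an added illustration: EmitMemCpy(Dst, Src, Len, 1, B)
+/// emits a call to the llvm.memcpy intrinsic with both pointers cast to i8*
+/// and an alignment operand of 1.)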
+Value *LibCallOptimization::EmitMemCpy(Value *Dst, Value *Src, Value *Len,
+                                       unsigned Align, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  Intrinsic::ID IID = Intrinsic::memcpy;
+  const Type *Tys[1];
+  Tys[0] = Len->getType();
+  Value *MemCpy = Intrinsic::getDeclaration(M, IID, Tys, 1);
+  return B.CreateCall4(MemCpy, CastToCStr(Dst, B), CastToCStr(Src, B), Len,
+                       ConstantInt::get(Type::Int32Ty, Align));
+}
+
+/// EmitMemChr - Emit a call to the memchr function.  This assumes that Ptr is
+/// a pointer, Val is an i32 value, and Len is an 'intptr_t' value.
+Value *LibCallOptimization::EmitMemChr(Value *Ptr, Value *Val,
+                                       Value *Len, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  AttributeWithIndex AWI;
+  AWI = AttributeWithIndex::get(~0u, Attribute::ReadOnly | Attribute::NoUnwind);
+
+  Value *MemChr = M->getOrInsertFunction("memchr", AttrListPtr::get(&AWI, 1),
+                                         PointerType::getUnqual(Type::Int8Ty),
+                                         PointerType::getUnqual(Type::Int8Ty),
+                                         Type::Int32Ty, TD->getIntPtrType(),
+                                         NULL);
+  return B.CreateCall3(MemChr, CastToCStr(Ptr, B), Val, Len, "memchr");
+}
+
+/// EmitMemCmp - Emit a call to the memcmp function.
+Value *LibCallOptimization::EmitMemCmp(Value *Ptr1, Value *Ptr2,
+                                       Value *Len, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  AttributeWithIndex AWI[3];
+  AWI[0] = AttributeWithIndex::get(1, Attribute::NoCapture);
+  AWI[1] = AttributeWithIndex::get(2, Attribute::NoCapture);
+  AWI[2] = AttributeWithIndex::get(~0u, Attribute::ReadOnly |
+                                   Attribute::NoUnwind);
+
+  Value *MemCmp = M->getOrInsertFunction("memcmp", AttrListPtr::get(AWI, 3),
+                                         Type::Int32Ty,
+                                         PointerType::getUnqual(Type::Int8Ty),
+                                         PointerType::getUnqual(Type::Int8Ty),
+                                         TD->getIntPtrType(), NULL);
+  return B.CreateCall3(MemCmp, CastToCStr(Ptr1, B), CastToCStr(Ptr2, B),
+                       Len, "memcmp");
+}
+
+/// EmitMemSet - Emit a call to the memset function.
+Value *LibCallOptimization::EmitMemSet(Value *Dst, Value *Val,
+                                       Value *Len, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  Intrinsic::ID IID = Intrinsic::memset;
+  const Type *Tys[1];
+  Tys[0] = Len->getType();
+  Value *MemSet = Intrinsic::getDeclaration(M, IID, Tys, 1);
+  Value *Align = ConstantInt::get(Type::Int32Ty, 1);
+  return B.CreateCall4(MemSet, CastToCStr(Dst, B), Val, Len, Align);
+}
+
+/// EmitUnaryFloatFnCall - Emit a call to the unary function named 'Name' (e.g.
+/// 'floor').  This function is known to take a single argument of type
+/// matching 'Op' and returns one value with the same type.  If 'Op' is a long
+/// double, 'l' is added as the suffix of name, if 'Op' is a float, we add an
+/// 'f' suffix.
+Value *LibCallOptimization::EmitUnaryFloatFnCall(Value *Op, const char *Name,
+                                                 IRBuilder<> &B) {
+  char NameBuffer[20];
+  if (Op->getType() != Type::DoubleTy) {
+    // If we need to add a suffix, copy into NameBuffer.
+    unsigned NameLen = strlen(Name);
+    assert(NameLen < sizeof(NameBuffer)-2);
+    memcpy(NameBuffer, Name, NameLen);
+    if (Op->getType() == Type::FloatTy)
+      NameBuffer[NameLen] = 'f';  // floorf
+    else
+      NameBuffer[NameLen] = 'l';  // floorl
+    NameBuffer[NameLen+1] = 0;
+    Name = NameBuffer;
+  }
+
+  Module *M = Caller->getParent();
+  Value *Callee = M->getOrInsertFunction(Name, Op->getType(),
+                                         Op->getType(), NULL);
+  return B.CreateCall(Callee, Op, Name);
+}
+
+/// EmitPutChar - Emit a call to the putchar function.  This assumes that Char
+/// is an integer.
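+/// (Added note: an i8 Char is first widened to i32 with CreateIntCast before
+/// the putchar call is emitted.)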
+void LibCallOptimization::EmitPutChar(Value *Char, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  Value *F = M->getOrInsertFunction("putchar", Type::Int32Ty,
+                                    Type::Int32Ty, NULL);
+  B.CreateCall(F, B.CreateIntCast(Char, Type::Int32Ty, "chari"), "putchar");
+}
+
+/// EmitPutS - Emit a call to the puts function.  This assumes that Str is
+/// some pointer.
+void LibCallOptimization::EmitPutS(Value *Str, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  AttributeWithIndex AWI[2];
+  AWI[0] = AttributeWithIndex::get(1, Attribute::NoCapture);
+  AWI[1] = AttributeWithIndex::get(~0u, Attribute::NoUnwind);
+
+  Value *F = M->getOrInsertFunction("puts", AttrListPtr::get(AWI, 2),
+                                    Type::Int32Ty,
+                                    PointerType::getUnqual(Type::Int8Ty), NULL);
+  B.CreateCall(F, CastToCStr(Str, B), "puts");
+}
+
+/// EmitFPutC - Emit a call to the fputc function.  This assumes that Char is
+/// an integer and File is a pointer to FILE.
+void LibCallOptimization::EmitFPutC(Value *Char, Value *File, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  AttributeWithIndex AWI[2];
+  AWI[0] = AttributeWithIndex::get(2, Attribute::NoCapture);
+  AWI[1] = AttributeWithIndex::get(~0u, Attribute::NoUnwind);
+  Constant *F;
+  if (isa<PointerType>(File->getType()))
+    F = M->getOrInsertFunction("fputc", AttrListPtr::get(AWI, 2), Type::Int32Ty,
+                               Type::Int32Ty, File->getType(), NULL);
+  else
+    F = M->getOrInsertFunction("fputc", Type::Int32Ty, Type::Int32Ty,
+                               File->getType(), NULL);
+  Char = B.CreateIntCast(Char, Type::Int32Ty, "chari");
+  B.CreateCall2(F, Char, File, "fputc");
+}
+
+/// EmitFPutS - Emit a call to the fputs function.  Str is required to be a
+/// pointer and File is a pointer to FILE.
+void LibCallOptimization::EmitFPutS(Value *Str, Value *File, IRBuilder<> &B) {
+  Module *M = Caller->getParent();
+  AttributeWithIndex AWI[3];
+  AWI[0] = AttributeWithIndex::get(1, Attribute::NoCapture);
+  AWI[1] = AttributeWithIndex::get(2, Attribute::NoCapture);
+  AWI[2] = AttributeWithIndex::get(~0u, Attribute::NoUnwind);
+  Constant *F;
+  if (isa<PointerType>(File->getType()))
+    F = M->getOrInsertFunction("fputs", AttrListPtr::get(AWI, 3), Type::Int32Ty,
+                               PointerType::getUnqual(Type::Int8Ty),
+                               File->getType(), NULL);
+  else
+    F = M->getOrInsertFunction("fputs", Type::Int32Ty,
+                               PointerType::getUnqual(Type::Int8Ty),
+                               File->getType(), NULL);
+  B.CreateCall2(F, CastToCStr(Str, B), File, "fputs");
+}
+
+/// EmitFWrite - Emit a call to the fwrite function.  This assumes that Ptr is
+/// a pointer, Size is an 'intptr_t', and File is a pointer to FILE.
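+/// (Added note: the emitted call has the shape fwrite(Ptr, Size, 1, File),
+/// i.e. the element count operand is hard-wired to one element of Size
+/// bytes.)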
+void LibCallOptimization::EmitFWrite(Value *Ptr, Value *Size, Value *File, + IRBuilder<> &B) { + Module *M = Caller->getParent(); + AttributeWithIndex AWI[3]; + AWI[0] = AttributeWithIndex::get(1, Attribute::NoCapture); + AWI[1] = AttributeWithIndex::get(4, Attribute::NoCapture); + AWI[2] = AttributeWithIndex::get(~0u, Attribute::NoUnwind); + Constant *F; + if (isa<PointerType>(File->getType())) + F = M->getOrInsertFunction("fwrite", AttrListPtr::get(AWI, 3), + TD->getIntPtrType(), + PointerType::getUnqual(Type::Int8Ty), + TD->getIntPtrType(), TD->getIntPtrType(), + File->getType(), NULL); + else + F = M->getOrInsertFunction("fwrite", TD->getIntPtrType(), + PointerType::getUnqual(Type::Int8Ty), + TD->getIntPtrType(), TD->getIntPtrType(), + File->getType(), NULL); + B.CreateCall4(F, CastToCStr(Ptr, B), Size, + ConstantInt::get(TD->getIntPtrType(), 1), File); +} + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +/// GetStringLengthH - If we can compute the length of the string pointed to by +/// the specified pointer, return 'len+1'. If we can't, return 0. +static uint64_t GetStringLengthH(Value *V, SmallPtrSet<PHINode*, 32> &PHIs) { + // Look through noop bitcast instructions. + if (BitCastInst *BCI = dyn_cast<BitCastInst>(V)) + return GetStringLengthH(BCI->getOperand(0), PHIs); + + // If this is a PHI node, there are two cases: either we have already seen it + // or we haven't. + if (PHINode *PN = dyn_cast<PHINode>(V)) { + if (!PHIs.insert(PN)) + return ~0ULL; // already in the set. + + // If it was new, see if all the input strings are the same length. + uint64_t LenSoFar = ~0ULL; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + uint64_t Len = GetStringLengthH(PN->getIncomingValue(i), PHIs); + if (Len == 0) return 0; // Unknown length -> unknown. + + if (Len == ~0ULL) continue; + + if (Len != LenSoFar && LenSoFar != ~0ULL) + return 0; // Disagree -> unknown. + LenSoFar = Len; + } + + // Success, all agree. + return LenSoFar; + } + + // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y) + if (SelectInst *SI = dyn_cast<SelectInst>(V)) { + uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs); + if (Len1 == 0) return 0; + uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs); + if (Len2 == 0) return 0; + if (Len1 == ~0ULL) return Len2; + if (Len2 == ~0ULL) return Len1; + if (Len1 != Len2) return 0; + return Len1; + } + + // If the value is not a GEP instruction nor a constant expression with a + // GEP instruction, then return unknown. + User *GEP = 0; + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(V)) { + GEP = GEPI; + } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { + if (CE->getOpcode() != Instruction::GetElementPtr) + return 0; + GEP = CE; + } else { + return 0; + } + + // Make sure the GEP has exactly three arguments. + if (GEP->getNumOperands() != 3) + return 0; + + // Check to make sure that the first operand of the GEP is an integer and + // has value 0 so that we are sure we're indexing into the initializer. + if (ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(1))) { + if (!Idx->isZero()) + return 0; + } else + return 0; + + // If the second index isn't a ConstantInt, then this is a variable index + // into the array. If this occurs, we can't say anything meaningful about + // the string. 
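+  // (E.g., an illustrative case: getelementptr [6 x i8]* @str, i32 0, i32 %n
+  // with a non-constant %n gives up here.)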
+ uint64_t StartIdx = 0;
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(2)))
+ StartIdx = CI->getZExtValue();
+ else
+ return 0;
+
+ // The GEP, whether an instruction or a constant expression, must reference
+ // a global variable that is a constant and is initialized. The referenced
+ // constant initializer is the array that we'll use for optimization.
+ GlobalVariable* GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
+ if (!GV || !GV->isConstant() || !GV->hasInitializer())
+ return 0;
+ Constant *GlobalInit = GV->getInitializer();
+
+ // Handle the ConstantAggregateZero case, which is a degenerate case. The
+ // initializer is constant zero so the length of the string must be zero.
+ if (isa<ConstantAggregateZero>(GlobalInit))
+ return 1; // Len = 0 offset by 1.
+
+ // Must be a ConstantArray of i8.
+ ConstantArray *Array = dyn_cast<ConstantArray>(GlobalInit);
+ if (!Array || Array->getType()->getElementType() != Type::Int8Ty)
+ return 0;
+
+ // Get the number of elements in the array
+ uint64_t NumElts = Array->getType()->getNumElements();
+
+ // Traverse the constant array from StartIdx (derived above) which is
+ // the place the GEP refers to in the array.
+ for (unsigned i = StartIdx; i != NumElts; ++i) {
+ Constant *Elt = Array->getOperand(i);
+ ConstantInt *CI = dyn_cast<ConstantInt>(Elt);
+ if (!CI) // This array isn't suitable, non-int initializer.
+ return 0;
+ if (CI->isZero())
+ return i-StartIdx+1; // We found end of string, success!
+ }
+
+ return 0; // The array isn't null terminated, conservatively return 'unknown'.
+}
+
+/// GetStringLength - If we can compute the length of the string pointed to by
+/// the specified pointer, return 'len+1'. If we can't, return 0.
+static uint64_t GetStringLength(Value *V) {
+ if (!isa<PointerType>(V->getType())) return 0;
+
+ SmallPtrSet<PHINode*, 32> PHIs;
+ uint64_t Len = GetStringLengthH(V, PHIs);
+ // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return
+ // an empty string as a length.
+ return Len == ~0ULL ? 1 : Len;
+}
+
+/// IsOnlyUsedInZeroEqualityComparison - Return true if it only matters that the
+/// value is equal or not-equal to zero.
+static bool IsOnlyUsedInZeroEqualityComparison(Value *V) {
+ for (Value::use_iterator UI = V->use_begin(), E = V->use_end();
+ UI != E; ++UI) {
+ if (ICmpInst *IC = dyn_cast<ICmpInst>(*UI))
+ if (IC->isEquality())
+ if (Constant *C = dyn_cast<Constant>(IC->getOperand(1)))
+ if (C->isNullValue())
+ continue;
+ // Unknown instruction.
+ return false;
+ }
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Miscellaneous LibCall Optimizations
+//===----------------------------------------------------------------------===//
+
+namespace {
+//===---------------------------------------===//
+// 'exit' Optimizations
+
+/// ExitOpt - int main() { exit(4); } --> int main() { return 4; }
+struct VISIBILITY_HIDDEN ExitOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ // Verify we have a reasonable prototype for exit.
+ if (Callee->arg_size() == 0 || !CI->use_empty())
+ return 0;
+
+ // Verify the caller is main, and that the result type of main matches the
+ // argument type of exit.
+ if (!Caller->isName("main") || !Caller->hasExternalLinkage() ||
+ Caller->getReturnType() != CI->getOperand(1)->getType())
+ return 0;
+
+ TerminatorInst *OldTI = CI->getParent()->getTerminator();
+
+ // Create the return after the call.
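+ // In IR terms (illustrative): a "call void @exit(i32 %v)" in main gets a
+ // "ret i32 %v" placed right after it; the old terminator's successors and
+ // the now-dead tail of the block are then cleaned up below.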
+ ReturnInst *RI = B.CreateRet(CI->getOperand(1)); + + // Drop all successor phi node entries. + for (unsigned i = 0, e = OldTI->getNumSuccessors(); i != e; ++i) + OldTI->getSuccessor(i)->removePredecessor(CI->getParent()); + + // Erase all instructions from after our return instruction until the end of + // the block. + BasicBlock::iterator FirstDead = RI; ++FirstDead; + CI->getParent()->getInstList().erase(FirstDead, CI->getParent()->end()); + return CI; + } +}; + +//===----------------------------------------------------------------------===// +// String and Memory LibCall Optimizations +//===----------------------------------------------------------------------===// + +//===---------------------------------------===// +// 'strcat' Optimizations + +struct VISIBILITY_HIDDEN StrCatOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + // Verify the "strcat" function prototype. + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || + FT->getReturnType() != PointerType::getUnqual(Type::Int8Ty) || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType()) + return 0; + + // Extract some information from the instruction + Value *Dst = CI->getOperand(1); + Value *Src = CI->getOperand(2); + + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) return 0; + --Len; // Unbias length. + + // Handle the simple, do-nothing case: strcat(x, "") -> x + if (Len == 0) + return Dst; + + EmitStrLenMemCpy(Src, Dst, Len, B); + return Dst; + } + + void EmitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, IRBuilder<> &B) { + // We need to find the end of the destination string. That's where the + // memory is to be moved to. We just generate a call to strlen. + Value *DstLen = EmitStrLen(Dst, B); + + // Now that we have the destination's length, we must index into the + // destination's pointer to get the actual memcpy destination (end of + // the string .. we're concatenating). + Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr"); + + // We have enough information to now generate the memcpy call to do the + // concatenation for us. Make a memcpy to copy the nul byte with align = 1. + EmitMemCpy(CpyDst, Src, ConstantInt::get(TD->getIntPtrType(), Len+1), 1, B); + } +}; + +//===---------------------------------------===// +// 'strncat' Optimizations + +struct VISIBILITY_HIDDEN StrNCatOpt : public StrCatOpt { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + // Verify the "strncat" function prototype. + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || + FT->getReturnType() != PointerType::getUnqual(Type::Int8Ty) || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType() || + !isa<IntegerType>(FT->getParamType(2))) + return 0; + + // Extract some information from the instruction + Value *Dst = CI->getOperand(1); + Value *Src = CI->getOperand(2); + uint64_t Len; + + // We don't do anything if length is not constant + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getOperand(3))) + Len = LengthArg->getZExtValue(); + else + return 0; + + // See if we can get the length of the input string. + uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) return 0; + --SrcLen; // Unbias length. 
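+ // (GetStringLength returned strlen(Src)+1, so SrcLen now holds the real
+ // length: e.g. SrcLen == 1 for a constant source string "a".)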
+ + // Handle the simple, do-nothing cases: + // strncat(x, "", c) -> x + // strncat(x, c, 0) -> x + if (SrcLen == 0 || Len == 0) return Dst; + + // We don't optimize this case + if (Len < SrcLen) return 0; + + // strncat(x, s, c) -> strcat(x, s) + // s is constant so the strcat can be optimized further + EmitStrLenMemCpy(Src, Dst, SrcLen, B); + return Dst; + } +}; + +//===---------------------------------------===// +// 'strchr' Optimizations + +struct VISIBILITY_HIDDEN StrChrOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + // Verify the "strchr" function prototype. + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || + FT->getReturnType() != PointerType::getUnqual(Type::Int8Ty) || + FT->getParamType(0) != FT->getReturnType()) + return 0; + + Value *SrcStr = CI->getOperand(1); + + // If the second operand is non-constant, see if we can compute the length + // of the input string and turn this into memchr. + ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getOperand(2)); + if (CharC == 0) { + uint64_t Len = GetStringLength(SrcStr); + if (Len == 0 || FT->getParamType(1) != Type::Int32Ty) // memchr needs i32. + return 0; + + return EmitMemChr(SrcStr, CI->getOperand(2), // include nul. + ConstantInt::get(TD->getIntPtrType(), Len), B); + } + + // Otherwise, the character is a constant, see if the first argument is + // a string literal. If so, we can constant fold. + std::string Str; + if (!GetConstantStringInfo(SrcStr, Str)) + return 0; + + // strchr can find the nul character. + Str += '\0'; + char CharValue = CharC->getSExtValue(); + + // Compute the offset. + uint64_t i = 0; + while (1) { + if (i == Str.size()) // Didn't find the char. strchr returns null. + return Constant::getNullValue(CI->getType()); + // Did we find our match? + if (Str[i] == CharValue) + break; + ++i; + } + + // strchr(s+n,c) -> gep(s+n+i,c) + Value *Idx = ConstantInt::get(Type::Int64Ty, i); + return B.CreateGEP(SrcStr, Idx, "strchr"); + } +}; + +//===---------------------------------------===// +// 'strcmp' Optimizations + +struct VISIBILITY_HIDDEN StrCmpOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + // Verify the "strcmp" function prototype. + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != Type::Int32Ty || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != PointerType::getUnqual(Type::Int8Ty)) + return 0; + + Value *Str1P = CI->getOperand(1), *Str2P = CI->getOperand(2); + if (Str1P == Str2P) // strcmp(x,x) -> 0 + return ConstantInt::get(CI->getType(), 0); + + std::string Str1, Str2; + bool HasStr1 = GetConstantStringInfo(Str1P, Str1); + bool HasStr2 = GetConstantStringInfo(Str2P, Str2); + + if (HasStr1 && Str1.empty()) // strcmp("", x) -> *x + return B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType()); + + if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x + return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + + // strcmp(x, y) -> cnst (if both x and y are constant strings) + if (HasStr1 && HasStr2) + return ConstantInt::get(CI->getType(), strcmp(Str1.c_str(),Str2.c_str())); + + // strcmp(P, "x") -> memcmp(P, "x", 2) + uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len2 = GetStringLength(Str2P); + if (Len1 || Len2) { + // Choose the smallest Len excluding 0 which means 'unknown'. 
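+ // e.g. for strcmp(P, "x"): Len1 == 0 (unknown) but Len2 == 2 (the 'x'
+ // plus the nul), so the call becomes memcmp(P, "x", 2).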
+ if (!Len1 || (Len2 && Len2 < Len1))
+ Len1 = Len2;
+ return EmitMemCmp(Str1P, Str2P,
+ ConstantInt::get(TD->getIntPtrType(), Len1), B);
+ }
+
+ return 0;
+ }
+};
+
+//===---------------------------------------===//
+// 'strncmp' Optimizations
+
+struct VISIBILITY_HIDDEN StrNCmpOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ // Verify the "strncmp" function prototype.
+ const FunctionType *FT = Callee->getFunctionType();
+ if (FT->getNumParams() != 3 || FT->getReturnType() != Type::Int32Ty ||
+ FT->getParamType(0) != FT->getParamType(1) ||
+ FT->getParamType(0) != PointerType::getUnqual(Type::Int8Ty) ||
+ !isa<IntegerType>(FT->getParamType(2)))
+ return 0;
+
+ Value *Str1P = CI->getOperand(1), *Str2P = CI->getOperand(2);
+ if (Str1P == Str2P) // strncmp(x,x,n) -> 0
+ return ConstantInt::get(CI->getType(), 0);
+
+ // Get the length argument if it is constant.
+ uint64_t Length;
+ if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getOperand(3)))
+ Length = LengthArg->getZExtValue();
+ else
+ return 0;
+
+ if (Length == 0) // strncmp(x,y,0) -> 0
+ return ConstantInt::get(CI->getType(), 0);
+
+ std::string Str1, Str2;
+ bool HasStr1 = GetConstantStringInfo(Str1P, Str1);
+ bool HasStr2 = GetConstantStringInfo(Str2P, Str2);
+
+ if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> *x
+ return B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType());
+
+ if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x
+ return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType());
+
+ // strncmp(x, y, n) -> cnst (if both x and y are constant strings)
+ if (HasStr1 && HasStr2)
+ return ConstantInt::get(CI->getType(),
+ strncmp(Str1.c_str(), Str2.c_str(), Length));
+ return 0;
+ }
+};
+
+
+//===---------------------------------------===//
+// 'strcpy' Optimizations
+
+struct VISIBILITY_HIDDEN StrCpyOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ // Verify the "strcpy" function prototype.
+ const FunctionType *FT = Callee->getFunctionType();
+ if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) ||
+ FT->getParamType(0) != FT->getParamType(1) ||
+ FT->getParamType(0) != PointerType::getUnqual(Type::Int8Ty))
+ return 0;
+
+ Value *Dst = CI->getOperand(1), *Src = CI->getOperand(2);
+ if (Dst == Src) // strcpy(x,x) -> x
+ return Src;
+
+ // See if we can get the length of the input string.
+ uint64_t Len = GetStringLength(Src);
+ if (Len == 0) return 0;
+
+ // We have enough information to now generate the memcpy call to do the
+ // copy for us. Len includes the nul byte; make a memcpy with align = 1.
+ EmitMemCpy(Dst, Src, ConstantInt::get(TD->getIntPtrType(), Len), 1, B);
+ return Dst;
+ }
+};
+
+//===---------------------------------------===//
+// 'strncpy' Optimizations
+
+struct VISIBILITY_HIDDEN StrNCpyOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ const FunctionType *FT = Callee->getFunctionType();
+ if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) ||
+ FT->getParamType(0) != FT->getParamType(1) ||
+ FT->getParamType(0) != PointerType::getUnqual(Type::Int8Ty) ||
+ !isa<IntegerType>(FT->getParamType(2)))
+ return 0;
+
+ Value *Dst = CI->getOperand(1);
+ Value *Src = CI->getOperand(2);
+ Value *LenOp = CI->getOperand(3);
+
+ // See if we can get the length of the input string.
+ uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) return 0; + --SrcLen; + + if (SrcLen == 0) { + // strncpy(x, "", y) -> memset(x, '\0', y, 1) + EmitMemSet(Dst, ConstantInt::get(Type::Int8Ty, '\0'), LenOp, B); + return Dst; + } + + uint64_t Len; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) + Len = LengthArg->getZExtValue(); + else + return 0; + + if (Len == 0) return Dst; // strncpy(x, y, 0) -> x + + // Let strncpy handle the zero padding + if (Len > SrcLen+1) return 0; + + // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] + EmitMemCpy(Dst, Src, ConstantInt::get(TD->getIntPtrType(), Len), 1, B); + + return Dst; + } +}; + +//===---------------------------------------===// +// 'strlen' Optimizations + +struct VISIBILITY_HIDDEN StrLenOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || + FT->getParamType(0) != PointerType::getUnqual(Type::Int8Ty) || + !isa<IntegerType>(FT->getReturnType())) + return 0; + + Value *Src = CI->getOperand(1); + + // Constant folding: strlen("xyz") -> 3 + if (uint64_t Len = GetStringLength(Src)) + return ConstantInt::get(CI->getType(), Len-1); + + // Handle strlen(p) != 0. + if (!IsOnlyUsedInZeroEqualityComparison(CI)) return 0; + + // strlen(x) != 0 --> *x != 0 + // strlen(x) == 0 --> *x == 0 + return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); + } +}; + +//===---------------------------------------===// +// 'strto*' Optimizations + +struct VISIBILITY_HIDDEN StrToOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || + !isa<PointerType>(FT->getParamType(0)) || + !isa<PointerType>(FT->getParamType(1))) + return 0; + + Value *EndPtr = CI->getOperand(2); + if (isa<ConstantPointerNull>(EndPtr)) { + CI->setOnlyReadsMemory(); + CI->addAttribute(1, Attribute::NoCapture); + } + + return 0; + } +}; + + +//===---------------------------------------===// +// 'memcmp' Optimizations + +struct VISIBILITY_HIDDEN MemCmpOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || !isa<PointerType>(FT->getParamType(0)) || + !isa<PointerType>(FT->getParamType(1)) || + FT->getReturnType() != Type::Int32Ty) + return 0; + + Value *LHS = CI->getOperand(1), *RHS = CI->getOperand(2); + + if (LHS == RHS) // memcmp(s,s,x) -> 0 + return Constant::getNullValue(CI->getType()); + + // Make sure we have a constant length. + ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getOperand(3)); + if (!LenC) return 0; + uint64_t Len = LenC->getZExtValue(); + + if (Len == 0) // memcmp(s1,s2,0) -> 0 + return Constant::getNullValue(CI->getType()); + + if (Len == 1) { // memcmp(S1,S2,1) -> *LHS - *RHS + Value *LHSV = B.CreateLoad(CastToCStr(LHS, B), "lhsv"); + Value *RHSV = B.CreateLoad(CastToCStr(RHS, B), "rhsv"); + return B.CreateSExt(B.CreateSub(LHSV, RHSV, "chardiff"), CI->getType()); + } + + // memcmp(S1,S2,2) != 0 -> (*(short*)LHS ^ *(short*)RHS) != 0 + // memcmp(S1,S2,4) != 0 -> (*(int*)LHS ^ *(int*)RHS) != 0 + if ((Len == 2 || Len == 4) && IsOnlyUsedInZeroEqualityComparison(CI)) { + const Type *PTy = PointerType::getUnqual(Len == 2 ? 
+ Type::Int16Ty : Type::Int32Ty); + LHS = B.CreateBitCast(LHS, PTy, "tmp"); + RHS = B.CreateBitCast(RHS, PTy, "tmp"); + LoadInst *LHSV = B.CreateLoad(LHS, "lhsv"); + LoadInst *RHSV = B.CreateLoad(RHS, "rhsv"); + LHSV->setAlignment(1); RHSV->setAlignment(1); // Unaligned loads. + return B.CreateZExt(B.CreateXor(LHSV, RHSV, "shortdiff"), CI->getType()); + } + + return 0; + } +}; + +//===---------------------------------------===// +// 'memcpy' Optimizations + +struct VISIBILITY_HIDDEN MemCpyOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + !isa<PointerType>(FT->getParamType(0)) || + !isa<PointerType>(FT->getParamType(1)) || + FT->getParamType(2) != TD->getIntPtrType()) + return 0; + + // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) + EmitMemCpy(CI->getOperand(1), CI->getOperand(2), CI->getOperand(3), 1, B); + return CI->getOperand(1); + } +}; + +//===---------------------------------------===// +// 'memmove' Optimizations + +struct VISIBILITY_HIDDEN MemMoveOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + !isa<PointerType>(FT->getParamType(0)) || + !isa<PointerType>(FT->getParamType(1)) || + FT->getParamType(2) != TD->getIntPtrType()) + return 0; + + // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) + Module *M = Caller->getParent(); + Intrinsic::ID IID = Intrinsic::memmove; + const Type *Tys[1]; + Tys[0] = TD->getIntPtrType(); + Value *MemMove = Intrinsic::getDeclaration(M, IID, Tys, 1); + Value *Dst = CastToCStr(CI->getOperand(1), B); + Value *Src = CastToCStr(CI->getOperand(2), B); + Value *Size = CI->getOperand(3); + Value *Align = ConstantInt::get(Type::Int32Ty, 1); + B.CreateCall4(MemMove, Dst, Src, Size, Align); + return CI->getOperand(1); + } +}; + +//===---------------------------------------===// +// 'memset' Optimizations + +struct VISIBILITY_HIDDEN MemSetOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || + !isa<PointerType>(FT->getParamType(0)) || + FT->getParamType(1) != TD->getIntPtrType() || + FT->getParamType(2) != TD->getIntPtrType()) + return 0; + + // memset(p, v, n) -> llvm.memset(p, v, n, 1) + Value *Val = B.CreateTrunc(CI->getOperand(2), Type::Int8Ty); + EmitMemSet(CI->getOperand(1), Val, CI->getOperand(3), B); + return CI->getOperand(1); + } +}; + +//===----------------------------------------------------------------------===// +// Math Library Optimizations +//===----------------------------------------------------------------------===// + +//===---------------------------------------===// +// 'pow*' Optimizations + +struct VISIBILITY_HIDDEN PowOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. 
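+ // (i.e. double pow(double, double), float powf(float, float), and so on.)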
+ if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + !FT->getParamType(0)->isFloatingPoint()) + return 0; + + Value *Op1 = CI->getOperand(1), *Op2 = CI->getOperand(2); + if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { + if (Op1C->isExactlyValue(1.0)) // pow(1.0, x) -> 1.0 + return Op1C; + if (Op1C->isExactlyValue(2.0)) // pow(2.0, x) -> exp2(x) + return EmitUnaryFloatFnCall(Op2, "exp2", B); + } + + ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); + if (Op2C == 0) return 0; + + if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 + return ConstantFP::get(CI->getType(), 1.0); + + if (Op2C->isExactlyValue(0.5)) { + // FIXME: This is not safe for -0.0 and -inf. This can only be done when + // 'unsafe' math optimizations are allowed. + // x pow(x, 0.5) sqrt(x) + // --------------------------------------------- + // -0.0 +0.0 -0.0 + // -inf +inf NaN +#if 0 + // pow(x, 0.5) -> sqrt(x) + return B.CreateCall(get_sqrt(), Op1, "sqrt"); +#endif + } + + if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x + return Op1; + if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x + return B.CreateMul(Op1, Op1, "pow2"); + if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x + return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); + return 0; + } +}; + +//===---------------------------------------===// +// 'exp2' Optimizations + +struct VISIBILITY_HIDDEN Exp2Opt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 1 argument of FP type, which matches the + // result type. + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPoint()) + return 0; + + Value *Op = CI->getOperand(1); + // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 + // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 + Value *LdExpArg = 0; + if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) + LdExpArg = B.CreateSExt(OpC->getOperand(0), Type::Int32Ty, "tmp"); + } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) + LdExpArg = B.CreateZExt(OpC->getOperand(0), Type::Int32Ty, "tmp"); + } + + if (LdExpArg) { + const char *Name; + if (Op->getType() == Type::FloatTy) + Name = "ldexpf"; + else if (Op->getType() == Type::DoubleTy) + Name = "ldexp"; + else + Name = "ldexpl"; + + Constant *One = ConstantFP::get(APFloat(1.0f)); + if (Op->getType() != Type::FloatTy) + One = ConstantExpr::getFPExtend(One, Op->getType()); + + Module *M = Caller->getParent(); + Value *Callee = M->getOrInsertFunction(Name, Op->getType(), + Op->getType(), Type::Int32Ty,NULL); + return B.CreateCall2(Callee, One, LdExpArg); + } + return 0; + } +}; + + +//===---------------------------------------===// +// Double -> Float Shrinking Optimizations for Unary Functions like 'floor' + +struct VISIBILITY_HIDDEN UnaryDoubleFPOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || FT->getReturnType() != Type::DoubleTy || + FT->getParamType(0) != Type::DoubleTy) + return 0; + + // If this is something like 'floor((double)floatval)', convert to floorf. 
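+ // Illustrative IR for the pattern:
+ //   %e = fpext float %x to double
+ //   %d = call double @floor(double %e)
+ // becomes a call to floorf on %x, with the result fpext'ed back to double.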
+ FPExtInst *Cast = dyn_cast<FPExtInst>(CI->getOperand(1));
+ if (Cast == 0 || Cast->getOperand(0)->getType() != Type::FloatTy)
+ return 0;
+
+ // floor((double)floatval) -> (double)floorf(floatval)
+ Value *V = Cast->getOperand(0);
+ V = EmitUnaryFloatFnCall(V, Callee->getNameStart(), B);
+ return B.CreateFPExt(V, Type::DoubleTy);
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Integer Optimizations
+//===----------------------------------------------------------------------===//
+
+//===---------------------------------------===//
+// 'ffs*' Optimizations
+
+struct VISIBILITY_HIDDEN FFSOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ const FunctionType *FT = Callee->getFunctionType();
+ // Just make sure this has one integer argument and an i32 result.
+ if (FT->getNumParams() != 1 || FT->getReturnType() != Type::Int32Ty ||
+ !isa<IntegerType>(FT->getParamType(0)))
+ return 0;
+
+ Value *Op = CI->getOperand(1);
+
+ // Constant fold.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
+ if (CI->getValue() == 0) // ffs(0) -> 0.
+ return Constant::getNullValue(CI->getType());
+ return ConstantInt::get(Type::Int32Ty, // ffs(c) -> cttz(c)+1
+ CI->getValue().countTrailingZeros()+1);
+ }
+
+ // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0
+ const Type *ArgType = Op->getType();
+ Value *F = Intrinsic::getDeclaration(Callee->getParent(),
+ Intrinsic::cttz, &ArgType, 1);
+ Value *V = B.CreateCall(F, Op, "cttz");
+ V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1), "tmp");
+ V = B.CreateIntCast(V, Type::Int32Ty, false, "tmp");
+
+ Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType), "tmp");
+ return B.CreateSelect(Cond, V, ConstantInt::get(Type::Int32Ty, 0));
+ }
+};
+
+//===---------------------------------------===//
+// 'isdigit' Optimizations
+
+struct VISIBILITY_HIDDEN IsDigitOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ const FunctionType *FT = Callee->getFunctionType();
+ // We require integer(i32)
+ if (FT->getNumParams() != 1 || !isa<IntegerType>(FT->getReturnType()) ||
+ FT->getParamType(0) != Type::Int32Ty)
+ return 0;
+
+ // isdigit(c) -> (c-'0') <u 10
+ Value *Op = CI->getOperand(1);
+ Op = B.CreateSub(Op, ConstantInt::get(Type::Int32Ty, '0'), "isdigittmp");
+ Op = B.CreateICmpULT(Op, ConstantInt::get(Type::Int32Ty, 10), "isdigit");
+ return B.CreateZExt(Op, CI->getType());
+ }
+};
+
+//===---------------------------------------===//
+// 'isascii' Optimizations
+
+struct VISIBILITY_HIDDEN IsAsciiOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ const FunctionType *FT = Callee->getFunctionType();
+ // We require integer(i32)
+ if (FT->getNumParams() != 1 || !isa<IntegerType>(FT->getReturnType()) ||
+ FT->getParamType(0) != Type::Int32Ty)
+ return 0;
+
+ // isascii(c) -> c <u 128
+ Value *Op = CI->getOperand(1);
+ Op = B.CreateICmpULT(Op, ConstantInt::get(Type::Int32Ty, 128), "isascii");
+ return B.CreateZExt(Op, CI->getType());
+ }
+};
+
+//===---------------------------------------===//
+// 'abs', 'labs', 'llabs' Optimizations
+
+struct VISIBILITY_HIDDEN AbsOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ const FunctionType *FT = Callee->getFunctionType();
+ // We require integer(integer) where the types agree.
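+ // (e.g. i32 abs(i32) or i64 llabs(i64), depending on the target's C types.)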
+ if (FT->getNumParams() != 1 || !isa<IntegerType>(FT->getReturnType()) ||
+ FT->getParamType(0) != FT->getReturnType())
+ return 0;
+
+ // abs(x) -> x >s -1 ? x : -x
+ Value *Op = CI->getOperand(1);
+ Value *Pos = B.CreateICmpSGT(Op,ConstantInt::getAllOnesValue(Op->getType()),
+ "ispos");
+ Value *Neg = B.CreateNeg(Op, "neg");
+ return B.CreateSelect(Pos, Op, Neg);
+ }
+};
+
+
+//===---------------------------------------===//
+// 'toascii' Optimizations
+
+struct VISIBILITY_HIDDEN ToAsciiOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ const FunctionType *FT = Callee->getFunctionType();
+ // We require i32(i32)
+ if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) ||
+ FT->getParamType(0) != Type::Int32Ty)
+ return 0;
+
+ // toascii(c) -> c & 0x7f
+ return B.CreateAnd(CI->getOperand(1), ConstantInt::get(CI->getType(),0x7F));
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Formatting and IO Optimizations
+//===----------------------------------------------------------------------===//
+
+//===---------------------------------------===//
+// 'printf' Optimizations
+
+struct VISIBILITY_HIDDEN PrintFOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ // Require one fixed pointer argument and an integer/void result.
+ const FunctionType *FT = Callee->getFunctionType();
+ if (FT->getNumParams() < 1 || !isa<PointerType>(FT->getParamType(0)) ||
+ !(isa<IntegerType>(FT->getReturnType()) ||
+ FT->getReturnType() == Type::VoidTy))
+ return 0;
+
+ // Check for a fixed format string.
+ std::string FormatStr;
+ if (!GetConstantStringInfo(CI->getOperand(1), FormatStr))
+ return 0;
+
+ // Empty format string -> noop.
+ if (FormatStr.empty()) // Tolerate printf's declared void.
+ return CI->use_empty() ? (Value*)CI : ConstantInt::get(CI->getType(), 0);
+
+ // printf("x") -> putchar('x'), even for '%'.
+ if (FormatStr.size() == 1) {
+ EmitPutChar(ConstantInt::get(Type::Int32Ty, FormatStr[0]), B);
+ return CI->use_empty() ? (Value*)CI : ConstantInt::get(CI->getType(), 1);
+ }
+
+ // printf("foo\n") --> puts("foo")
+ if (FormatStr[FormatStr.size()-1] == '\n' &&
+ FormatStr.find('%') == std::string::npos) { // no format characters.
+ // Create a string literal with no \n on it. We expect the constant merge
+ // pass to be run after this pass, to merge duplicate strings.
+ FormatStr.erase(FormatStr.end()-1);
+ Constant *C = ConstantArray::get(FormatStr, true);
+ C = new GlobalVariable(C->getType(), true, GlobalVariable::InternalLinkage,
+ C, "str", Callee->getParent());
+ EmitPutS(C, B);
+ return CI->use_empty() ? (Value*)CI :
+ ConstantInt::get(CI->getType(), FormatStr.size()+1);
+ }
+
+ // Optimize specific format strings.
+ // printf("%c", chr) --> putchar(chr)
+ if (FormatStr == "%c" && CI->getNumOperands() > 2 &&
+ isa<IntegerType>(CI->getOperand(2)->getType())) {
+ EmitPutChar(CI->getOperand(2), B);
+ return CI->use_empty() ?
(Value*)CI : ConstantInt::get(CI->getType(), 1); + } + + // printf("%s\n", str) --> puts(str) + if (FormatStr == "%s\n" && CI->getNumOperands() > 2 && + isa<PointerType>(CI->getOperand(2)->getType()) && + CI->use_empty()) { + EmitPutS(CI->getOperand(2), B); + return CI; + } + return 0; + } +}; + +//===---------------------------------------===// +// 'sprintf' Optimizations + +struct VISIBILITY_HIDDEN SPrintFOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + // Require two fixed pointer arguments and an integer result. + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !isa<PointerType>(FT->getParamType(0)) || + !isa<PointerType>(FT->getParamType(1)) || + !isa<IntegerType>(FT->getReturnType())) + return 0; + + // Check for a fixed format string. + std::string FormatStr; + if (!GetConstantStringInfo(CI->getOperand(2), FormatStr)) + return 0; + + // If we just have a format string (nothing else crazy) transform it. + if (CI->getNumOperands() == 3) { + // Make sure there's no % in the constant array. We could try to handle + // %% -> % in the future if we cared. + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') + return 0; // we found a format specifier, bail out. + + // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) + EmitMemCpy(CI->getOperand(1), CI->getOperand(2), // Copy the nul byte. + ConstantInt::get(TD->getIntPtrType(), FormatStr.size()+1),1,B); + return ConstantInt::get(CI->getType(), FormatStr.size()); + } + + // The remaining optimizations require the format string to be "%s" or "%c" + // and have an extra operand. + if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->getNumOperands() <4) + return 0; + + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 + if (!isa<IntegerType>(CI->getOperand(3)->getType())) return 0; + Value *V = B.CreateTrunc(CI->getOperand(3), Type::Int8Ty, "char"); + Value *Ptr = CastToCStr(CI->getOperand(1), B); + B.CreateStore(V, Ptr); + Ptr = B.CreateGEP(Ptr, ConstantInt::get(Type::Int32Ty, 1), "nul"); + B.CreateStore(Constant::getNullValue(Type::Int8Ty), Ptr); + + return ConstantInt::get(CI->getType(), 1); + } + + if (FormatStr[1] == 's') { + // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) + if (!isa<PointerType>(CI->getOperand(3)->getType())) return 0; + + Value *Len = EmitStrLen(CI->getOperand(3), B); + Value *IncLen = B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), + "leninc"); + EmitMemCpy(CI->getOperand(1), CI->getOperand(3), IncLen, 1, B); + + // The sprintf result is the unincremented number of bytes in the string. + return B.CreateIntCast(Len, CI->getType(), false); + } + return 0; + } +}; + +//===---------------------------------------===// +// 'fwrite' Optimizations + +struct VISIBILITY_HIDDEN FWriteOpt : public LibCallOptimization { + virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) { + // Require a pointer, an integer, an integer, a pointer, returning integer. + const FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 4 || !isa<PointerType>(FT->getParamType(0)) || + !isa<IntegerType>(FT->getParamType(1)) || + !isa<IntegerType>(FT->getParamType(2)) || + !isa<PointerType>(FT->getParamType(3)) || + !isa<IntegerType>(FT->getReturnType())) + return 0; + + // Get the element size and count. 
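+ // e.g. fwrite(S, 1, 1, F): both constants below are 1, so Bytes == 1 and
+ // the call is rewritten to fputc(S[0], F).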
+ ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getOperand(2));
+ ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getOperand(3));
+ if (!SizeC || !CountC) return 0;
+ uint64_t Bytes = SizeC->getZExtValue()*CountC->getZExtValue();
+
+ // If this is writing zero records, remove the call (it's a noop).
+ if (Bytes == 0)
+ return ConstantInt::get(CI->getType(), 0);
+
+ // If this is writing one byte, turn it into fputc.
+ if (Bytes == 1) { // fwrite(S,1,1,F) -> fputc(S[0],F)
+ Value *Char = B.CreateLoad(CastToCStr(CI->getOperand(1), B), "char");
+ EmitFPutC(Char, CI->getOperand(4), B);
+ return ConstantInt::get(CI->getType(), 1);
+ }
+
+ return 0;
+ }
+};
+
+//===---------------------------------------===//
+// 'fputs' Optimizations
+
+struct VISIBILITY_HIDDEN FPutsOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ // Require two pointers. Also, we can't optimize if return value is used.
+ const FunctionType *FT = Callee->getFunctionType();
+ if (FT->getNumParams() != 2 || !isa<PointerType>(FT->getParamType(0)) ||
+ !isa<PointerType>(FT->getParamType(1)) ||
+ !CI->use_empty())
+ return 0;
+
+ // fputs(s,F) --> fwrite(s,1,strlen(s),F)
+ uint64_t Len = GetStringLength(CI->getOperand(1));
+ if (!Len) return 0;
+ EmitFWrite(CI->getOperand(1), ConstantInt::get(TD->getIntPtrType(), Len-1),
+ CI->getOperand(2), B);
+ return CI; // Known to have no uses (see above).
+ }
+};
+
+//===---------------------------------------===//
+// 'fprintf' Optimizations
+
+struct VISIBILITY_HIDDEN FPrintFOpt : public LibCallOptimization {
+ virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
+ // Require two fixed parameters as pointers and an integer result.
+ const FunctionType *FT = Callee->getFunctionType();
+ if (FT->getNumParams() != 2 || !isa<PointerType>(FT->getParamType(0)) ||
+ !isa<PointerType>(FT->getParamType(1)) ||
+ !isa<IntegerType>(FT->getReturnType()))
+ return 0;
+
+ // All the optimizations depend on the format string.
+ std::string FormatStr;
+ if (!GetConstantStringInfo(CI->getOperand(2), FormatStr))
+ return 0;
+
+ // fprintf(F, "foo") --> fwrite("foo", 3, 1, F)
+ if (CI->getNumOperands() == 3) {
+ for (unsigned i = 0, e = FormatStr.size(); i != e; ++i)
+ if (FormatStr[i] == '%') // Could handle %% -> % if we cared.
+ return 0; // We found a format specifier.
+
+ EmitFWrite(CI->getOperand(2), ConstantInt::get(TD->getIntPtrType(),
+ FormatStr.size()),
+ CI->getOperand(1), B);
+ return ConstantInt::get(CI->getType(), FormatStr.size());
+ }
+
+ // The remaining optimizations require the format string to be "%s" or "%c"
+ // and have an extra operand.
+ if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->getNumOperands() <4)
+ return 0;
+
+ // Decode the second character of the format string.
+ if (FormatStr[1] == 'c') {
+ // fprintf(F, "%c", chr) --> fputc(chr, F)
+ if (!isa<IntegerType>(CI->getOperand(3)->getType())) return 0;
+ EmitFPutC(CI->getOperand(3), CI->getOperand(1), B);
+ return ConstantInt::get(CI->getType(), 1);
+ }
+
+ if (FormatStr[1] == 's') {
+ // fprintf(F, "%s", str) -> fputs(str, F)
+ if (!isa<PointerType>(CI->getOperand(3)->getType()) || !CI->use_empty())
+ return 0;
+ EmitFPutS(CI->getOperand(3), CI->getOperand(1), B);
+ return CI;
+ }
+ return 0;
+ }
+};
+
+} // end anonymous namespace.
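+
+// A worked example of the rewrites above (illustrative IR, assuming a 64-bit
+// intptr_t): with @.str holding the constant "x", the driver below turns
+//   %c = call i32 @strcmp(i8* %p, i8* getelementptr ([2 x i8]* @.str, i32 0, i32 0))
+// into
+//   %c = call i32 @memcmp(i8* %p, i8* ..., i64 2)
+// via StrCmpOpt::CallOptimizer.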
+
+//===----------------------------------------------------------------------===//
+// SimplifyLibCalls Pass Implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+ /// This pass optimizes well known library functions from libc and libm.
+ ///
+ class VISIBILITY_HIDDEN SimplifyLibCalls : public FunctionPass {
+ StringMap<LibCallOptimization*> Optimizations;
+ // Miscellaneous LibCall Optimizations
+ ExitOpt Exit;
+ // String and Memory LibCall Optimizations
+ StrCatOpt StrCat; StrNCatOpt StrNCat; StrChrOpt StrChr; StrCmpOpt StrCmp;
+ StrNCmpOpt StrNCmp; StrCpyOpt StrCpy; StrNCpyOpt StrNCpy; StrLenOpt StrLen;
+ StrToOpt StrTo; MemCmpOpt MemCmp; MemCpyOpt MemCpy; MemMoveOpt MemMove;
+ MemSetOpt MemSet;
+ // Math Library Optimizations
+ PowOpt Pow; Exp2Opt Exp2; UnaryDoubleFPOpt UnaryDoubleFP;
+ // Integer Optimizations
+ FFSOpt FFS; AbsOpt Abs; IsDigitOpt IsDigit; IsAsciiOpt IsAscii;
+ ToAsciiOpt ToAscii;
+ // Formatting and IO Optimizations
+ SPrintFOpt SPrintF; PrintFOpt PrintF;
+ FWriteOpt FWrite; FPutsOpt FPuts; FPrintFOpt FPrintF;
+
+ bool Modified; // This is only used by doInitialization.
+ public:
+ static char ID; // Pass identification
+ SimplifyLibCalls() : FunctionPass(&ID) {}
+
+ void InitOptimizations();
+ bool runOnFunction(Function &F);
+
+ void setDoesNotAccessMemory(Function &F);
+ void setOnlyReadsMemory(Function &F);
+ void setDoesNotThrow(Function &F);
+ void setDoesNotCapture(Function &F, unsigned n);
+ void setDoesNotAlias(Function &F, unsigned n);
+ bool doInitialization(Module &M);
+
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<TargetData>();
+ }
+ };
+ char SimplifyLibCalls::ID = 0;
+} // end anonymous namespace.
+
+static RegisterPass<SimplifyLibCalls>
+X("simplify-libcalls", "Simplify well-known library calls");
+
+// Public interface to the SimplifyLibCalls pass.
+FunctionPass *llvm::createSimplifyLibCallsPass() {
+ return new SimplifyLibCalls();
+}
+
+/// InitOptimizations - Populate the Optimizations map with all the
+/// optimizations we know.
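+/// The table is keyed by callee name: e.g. Optimizations["strlen"] ends up
+/// pointing at the StrLen handler above.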
+void SimplifyLibCalls::InitOptimizations() { + // Miscellaneous LibCall Optimizations + Optimizations["exit"] = &Exit; + + // String and Memory LibCall Optimizations + Optimizations["strcat"] = &StrCat; + Optimizations["strncat"] = &StrNCat; + Optimizations["strchr"] = &StrChr; + Optimizations["strcmp"] = &StrCmp; + Optimizations["strncmp"] = &StrNCmp; + Optimizations["strcpy"] = &StrCpy; + Optimizations["strncpy"] = &StrNCpy; + Optimizations["strlen"] = &StrLen; + Optimizations["strtol"] = &StrTo; + Optimizations["strtod"] = &StrTo; + Optimizations["strtof"] = &StrTo; + Optimizations["strtoul"] = &StrTo; + Optimizations["strtoll"] = &StrTo; + Optimizations["strtold"] = &StrTo; + Optimizations["strtoull"] = &StrTo; + Optimizations["memcmp"] = &MemCmp; + Optimizations["memcpy"] = &MemCpy; + Optimizations["memmove"] = &MemMove; + Optimizations["memset"] = &MemSet; + + // Math Library Optimizations + Optimizations["powf"] = &Pow; + Optimizations["pow"] = &Pow; + Optimizations["powl"] = &Pow; + Optimizations["llvm.pow.f32"] = &Pow; + Optimizations["llvm.pow.f64"] = &Pow; + Optimizations["llvm.pow.f80"] = &Pow; + Optimizations["llvm.pow.f128"] = &Pow; + Optimizations["llvm.pow.ppcf128"] = &Pow; + Optimizations["exp2l"] = &Exp2; + Optimizations["exp2"] = &Exp2; + Optimizations["exp2f"] = &Exp2; + Optimizations["llvm.exp2.ppcf128"] = &Exp2; + Optimizations["llvm.exp2.f128"] = &Exp2; + Optimizations["llvm.exp2.f80"] = &Exp2; + Optimizations["llvm.exp2.f64"] = &Exp2; + Optimizations["llvm.exp2.f32"] = &Exp2; + +#ifdef HAVE_FLOORF + Optimizations["floor"] = &UnaryDoubleFP; +#endif +#ifdef HAVE_CEILF + Optimizations["ceil"] = &UnaryDoubleFP; +#endif +#ifdef HAVE_ROUNDF + Optimizations["round"] = &UnaryDoubleFP; +#endif +#ifdef HAVE_RINTF + Optimizations["rint"] = &UnaryDoubleFP; +#endif +#ifdef HAVE_NEARBYINTF + Optimizations["nearbyint"] = &UnaryDoubleFP; +#endif + + // Integer Optimizations + Optimizations["ffs"] = &FFS; + Optimizations["ffsl"] = &FFS; + Optimizations["ffsll"] = &FFS; + Optimizations["abs"] = &Abs; + Optimizations["labs"] = &Abs; + Optimizations["llabs"] = &Abs; + Optimizations["isdigit"] = &IsDigit; + Optimizations["isascii"] = &IsAscii; + Optimizations["toascii"] = &ToAscii; + + // Formatting and IO Optimizations + Optimizations["sprintf"] = &SPrintF; + Optimizations["printf"] = &PrintF; + Optimizations["fwrite"] = &FWrite; + Optimizations["fputs"] = &FPuts; + Optimizations["fprintf"] = &FPrintF; +} + + +/// runOnFunction - Top level algorithm. +/// +bool SimplifyLibCalls::runOnFunction(Function &F) { + if (Optimizations.empty()) + InitOptimizations(); + + const TargetData &TD = getAnalysis<TargetData>(); + + IRBuilder<> Builder; + + bool Changed = false; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { + // Ignore non-calls. + CallInst *CI = dyn_cast<CallInst>(I++); + if (!CI) continue; + + // Ignore indirect calls and calls to non-external functions. + Function *Callee = CI->getCalledFunction(); + if (Callee == 0 || !Callee->isDeclaration() || + !(Callee->hasExternalLinkage() || Callee->hasDLLImportLinkage())) + continue; + + // Ignore unknown calls. + const char *CalleeName = Callee->getNameStart(); + StringMap<LibCallOptimization*>::iterator OMI = + Optimizations.find(CalleeName, CalleeName+Callee->getNameLen()); + if (OMI == Optimizations.end()) continue; + + // Set the builder to the instruction after the call. + Builder.SetInsertPoint(BB, I); + + // Try to optimize this call. 
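+ // (OptimizeCall is the shared LibCallOptimization entry point: it records
+ // the calling function and TargetData, then dispatches to the virtual
+ // CallOptimizer hook implemented by each struct above.)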
+ Value *Result = OMI->second->OptimizeCall(CI, TD, Builder); + if (Result == 0) continue; + + DEBUG(DOUT << "SimplifyLibCalls simplified: " << *CI; + DOUT << " into: " << *Result << "\n"); + + // Something changed! + Changed = true; + ++NumSimplified; + + // Inspect the instruction after the call (which was potentially just + // added) next. + I = CI; ++I; + + if (CI != Result && !CI->use_empty()) { + CI->replaceAllUsesWith(Result); + if (!Result->hasName()) + Result->takeName(CI); + } + CI->eraseFromParent(); + } + } + return Changed; +} + +// Utility methods for doInitialization. + +void SimplifyLibCalls::setDoesNotAccessMemory(Function &F) { + if (!F.doesNotAccessMemory()) { + F.setDoesNotAccessMemory(); + ++NumAnnotated; + Modified = true; + } +} +void SimplifyLibCalls::setOnlyReadsMemory(Function &F) { + if (!F.onlyReadsMemory()) { + F.setOnlyReadsMemory(); + ++NumAnnotated; + Modified = true; + } +} +void SimplifyLibCalls::setDoesNotThrow(Function &F) { + if (!F.doesNotThrow()) { + F.setDoesNotThrow(); + ++NumAnnotated; + Modified = true; + } +} +void SimplifyLibCalls::setDoesNotCapture(Function &F, unsigned n) { + if (!F.doesNotCapture(n)) { + F.setDoesNotCapture(n); + ++NumAnnotated; + Modified = true; + } +} +void SimplifyLibCalls::setDoesNotAlias(Function &F, unsigned n) { + if (!F.doesNotAlias(n)) { + F.setDoesNotAlias(n); + ++NumAnnotated; + Modified = true; + } +} + +/// doInitialization - Add attributes to well-known functions. +/// +bool SimplifyLibCalls::doInitialization(Module &M) { + Modified = false; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { + Function &F = *I; + if (!F.isDeclaration()) + continue; + + unsigned NameLen = F.getNameLen(); + if (!NameLen) + continue; + + const FunctionType *FTy = F.getFunctionType(); + + const char *NameStr = F.getNameStart(); + switch (NameStr[0]) { + case 's': + if (NameLen == 6 && !strcmp(NameStr, "strlen")) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setOnlyReadsMemory(F); + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 6 && !strcmp(NameStr, "strcpy")) || + (NameLen == 6 && !strcmp(NameStr, "stpcpy")) || + (NameLen == 6 && !strcmp(NameStr, "strcat")) || + (NameLen == 6 && !strcmp(NameStr, "strtol")) || + (NameLen == 6 && !strcmp(NameStr, "strtod")) || + (NameLen == 6 && !strcmp(NameStr, "strtof")) || + (NameLen == 7 && !strcmp(NameStr, "strtoul")) || + (NameLen == 7 && !strcmp(NameStr, "strtoll")) || + (NameLen == 7 && !strcmp(NameStr, "strtold")) || + (NameLen == 7 && !strcmp(NameStr, "strncat")) || + (NameLen == 7 && !strcmp(NameStr, "strncpy")) || + (NameLen == 8 && !strcmp(NameStr, "strtoull"))) { + if (FTy->getNumParams() < 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if (NameLen == 7 && !strcmp(NameStr, "strxfrm")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if ((NameLen == 6 && !strcmp(NameStr, "strcmp")) || + (NameLen == 6 && !strcmp(NameStr, "strspn")) || + (NameLen == 7 && !strcmp(NameStr, "strncmp")) || + (NameLen == 7 && !strcmp(NameStr, "strcspn")) || + (NameLen == 7 && !strcmp(NameStr, "strcoll")) || + (NameLen == 10 && !strcmp(NameStr, "strcasecmp")) || + (NameLen == 11 && !strcmp(NameStr, "strncasecmp"))) { + if (FTy->getNumParams() < 2 || + 
!isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setOnlyReadsMemory(F); + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if ((NameLen == 6 && !strcmp(NameStr, "strstr")) || + (NameLen == 7 && !strcmp(NameStr, "strpbrk"))) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setOnlyReadsMemory(F); + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if ((NameLen == 6 && !strcmp(NameStr, "strtok")) || + (NameLen == 8 && !strcmp(NameStr, "strtok_r"))) { + if (FTy->getNumParams() < 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if ((NameLen == 5 && !strcmp(NameStr, "scanf")) || + (NameLen == 6 && !strcmp(NameStr, "setbuf")) || + (NameLen == 7 && !strcmp(NameStr, "setvbuf"))) { + if (FTy->getNumParams() < 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 6 && !strcmp(NameStr, "strdup")) || + (NameLen == 7 && !strcmp(NameStr, "strndup"))) { + if (FTy->getNumParams() < 1 || + !isa<PointerType>(FTy->getReturnType()) || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotAlias(F, 0); + setDoesNotCapture(F, 1); + } else if ((NameLen == 4 && !strcmp(NameStr, "stat")) || + (NameLen == 6 && !strcmp(NameStr, "sscanf")) || + (NameLen == 7 && !strcmp(NameStr, "sprintf")) || + (NameLen == 7 && !strcmp(NameStr, "statvfs"))) { + if (FTy->getNumParams() < 2 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if (NameLen == 8 && !strcmp(NameStr, "snprintf")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(2))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 3); + } else if (NameLen == 9 && !strcmp(NameStr, "setitimer")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(1)) || + !isa<PointerType>(FTy->getParamType(2))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + setDoesNotCapture(F, 3); + } else if (NameLen == 6 && !strcmp(NameStr, "system")) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + // May throw; "system" is a valid pthread cancellation point. 
+ setDoesNotCapture(F, 1); + } + break; + case 'm': + if (NameLen == 6 && !strcmp(NameStr, "memcmp")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setOnlyReadsMemory(F); + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if ((NameLen == 6 && !strcmp(NameStr, "memchr")) || + (NameLen == 7 && !strcmp(NameStr, "memrchr"))) { + if (FTy->getNumParams() != 3) + continue; + setOnlyReadsMemory(F); + setDoesNotThrow(F); + } else if ((NameLen == 4 && !strcmp(NameStr, "modf")) || + (NameLen == 5 && !strcmp(NameStr, "modff")) || + (NameLen == 5 && !strcmp(NameStr, "modfl")) || + (NameLen == 6 && !strcmp(NameStr, "memcpy")) || + (NameLen == 7 && !strcmp(NameStr, "memccpy")) || + (NameLen == 7 && !strcmp(NameStr, "memmove"))) { + if (FTy->getNumParams() < 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if (NameLen == 8 && !strcmp(NameStr, "memalign")) { + if (!isa<PointerType>(FTy->getReturnType())) + continue; + setDoesNotAlias(F, 0); + } else if ((NameLen == 5 && !strcmp(NameStr, "mkdir")) || + (NameLen == 6 && !strcmp(NameStr, "mktime"))) { + if (FTy->getNumParams() == 0 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } + break; + case 'r': + if (NameLen == 7 && !strcmp(NameStr, "realloc")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getReturnType())) + continue; + setDoesNotThrow(F); + setDoesNotAlias(F, 0); + setDoesNotCapture(F, 1); + } else if (NameLen == 4 && !strcmp(NameStr, "read")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + // May throw; "read" is a valid pthread cancellation point. + setDoesNotCapture(F, 2); + } else if ((NameLen == 5 && !strcmp(NameStr, "rmdir")) || + (NameLen == 6 && !strcmp(NameStr, "rewind")) || + (NameLen == 6 && !strcmp(NameStr, "remove")) || + (NameLen == 8 && !strcmp(NameStr, "realpath"))) { + if (FTy->getNumParams() < 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 6 && !strcmp(NameStr, "rename")) || + (NameLen == 8 && !strcmp(NameStr, "readlink"))) { + if (FTy->getNumParams() < 2 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } + break; + case 'w': + if (NameLen == 5 && !strcmp(NameStr, "write")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + // May throw; "write" is a valid pthread cancellation point. 
+ setDoesNotCapture(F, 2); + } + break; + case 'b': + if (NameLen == 5 && !strcmp(NameStr, "bcopy")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if (NameLen == 4 && !strcmp(NameStr, "bcmp")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setOnlyReadsMemory(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if (NameLen == 5 && !strcmp(NameStr, "bzero")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } + break; + case 'c': + if (NameLen == 6 && !strcmp(NameStr, "calloc")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getReturnType())) + continue; + setDoesNotThrow(F); + setDoesNotAlias(F, 0); + } else if ((NameLen == 5 && !strcmp(NameStr, "chmod")) || + (NameLen == 5 && !strcmp(NameStr, "chown")) || + (NameLen == 7 && !strcmp(NameStr, "ctermid")) || + (NameLen == 8 && !strcmp(NameStr, "clearerr")) || + (NameLen == 8 && !strcmp(NameStr, "closedir"))) { + if (FTy->getNumParams() == 0 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } + break; + case 'a': + if ((NameLen == 4 && !strcmp(NameStr, "atoi")) || + (NameLen == 4 && !strcmp(NameStr, "atol")) || + (NameLen == 4 && !strcmp(NameStr, "atof")) || + (NameLen == 5 && !strcmp(NameStr, "atoll"))) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setOnlyReadsMemory(F); + setDoesNotCapture(F, 1); + } else if (NameLen == 6 && !strcmp(NameStr, "access")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } + break; + case 'f': + if (NameLen == 5 && !strcmp(NameStr, "fopen")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getReturnType()) || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotAlias(F, 0); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if (NameLen == 6 && !strcmp(NameStr, "fdopen")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getReturnType()) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotAlias(F, 0); + setDoesNotCapture(F, 2); + } else if ((NameLen == 4 && !strcmp(NameStr, "feof")) || + (NameLen == 4 && !strcmp(NameStr, "free")) || + (NameLen == 5 && !strcmp(NameStr, "fseek")) || + (NameLen == 5 && !strcmp(NameStr, "ftell")) || + (NameLen == 5 && !strcmp(NameStr, "fgetc")) || + (NameLen == 6 && !strcmp(NameStr, "fseeko")) || + (NameLen == 6 && !strcmp(NameStr, "ftello")) || + (NameLen == 6 && !strcmp(NameStr, "fileno")) || + (NameLen == 6 && !strcmp(NameStr, "fflush")) || + (NameLen == 6 && !strcmp(NameStr, "fclose")) || + (NameLen == 7 && !strcmp(NameStr, "fsetpos")) || + (NameLen == 9 && !strcmp(NameStr, "flockfile")) || + (NameLen == 11 && !strcmp(NameStr, "funlockfile")) || + (NameLen == 12 && !strcmp(NameStr, "ftrylockfile"))) { + if (FTy->getNumParams() == 0 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if (NameLen == 6 && !strcmp(NameStr, "ferror")) { + if 
(FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setOnlyReadsMemory(F); + } else if ((NameLen == 5 && !strcmp(NameStr, "fputc")) || + (NameLen == 5 && !strcmp(NameStr, "fstat")) || + (NameLen == 5 && !strcmp(NameStr, "frexp")) || + (NameLen == 6 && !strcmp(NameStr, "frexpf")) || + (NameLen == 6 && !strcmp(NameStr, "frexpl")) || + (NameLen == 8 && !strcmp(NameStr, "fstatvfs"))) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if (NameLen == 5 && !strcmp(NameStr, "fgets")) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(2))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 3); + } else if ((NameLen == 5 && !strcmp(NameStr, "fread")) || + (NameLen == 6 && !strcmp(NameStr, "fwrite"))) { + if (FTy->getNumParams() != 4 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(3))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 4); + } else if ((NameLen == 5 && !strcmp(NameStr, "fputs")) || + (NameLen == 6 && !strcmp(NameStr, "fscanf")) || + (NameLen == 7 && !strcmp(NameStr, "fprintf")) || + (NameLen == 7 && !strcmp(NameStr, "fgetpos"))) { + if (FTy->getNumParams() < 2 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } + break; + case 'g': + if ((NameLen == 4 && !strcmp(NameStr, "getc")) || + (NameLen == 10 && !strcmp(NameStr, "getlogin_r")) || + (NameLen == 13 && !strcmp(NameStr, "getc_unlocked"))) { + if (FTy->getNumParams() == 0 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if (NameLen == 6 && !strcmp(NameStr, "getenv")) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setOnlyReadsMemory(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 4 && !strcmp(NameStr, "gets")) || + (NameLen == 7 && !strcmp(NameStr, "getchar"))) { + setDoesNotThrow(F); + } else if (NameLen == 9 && !strcmp(NameStr, "getitimer")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if (NameLen == 8 && !strcmp(NameStr, "getpwnam")) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } + break; + case 'u': + if (NameLen == 6 && !strcmp(NameStr, "ungetc")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if ((NameLen == 5 && !strcmp(NameStr, "uname")) || + (NameLen == 6 && !strcmp(NameStr, "unlink")) || + (NameLen == 8 && !strcmp(NameStr, "unsetenv"))) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 5 && !strcmp(NameStr, "utime")) || + (NameLen == 6 && !strcmp(NameStr, "utimes"))) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } + break; + case 'p': + if (NameLen == 4 && 
!strcmp(NameStr, "putc")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 2); + } else if ((NameLen == 4 && !strcmp(NameStr, "puts")) || + (NameLen == 6 && !strcmp(NameStr, "printf")) || + (NameLen == 6 && !strcmp(NameStr, "perror"))) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 5 && !strcmp(NameStr, "pread")) || + (NameLen == 6 && !strcmp(NameStr, "pwrite"))) { + if (FTy->getNumParams() != 4 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + // May throw; these are valid pthread cancellation points. + setDoesNotCapture(F, 2); + } else if (NameLen == 7 && !strcmp(NameStr, "putchar")) { + setDoesNotThrow(F); + } else if (NameLen == 5 && !strcmp(NameStr, "popen")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getReturnType()) || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotAlias(F, 0); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if (NameLen == 6 && !strcmp(NameStr, "pclose")) { + if (FTy->getNumParams() != 1 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } + break; + case 'v': + if (NameLen == 6 && !strcmp(NameStr, "vscanf")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 7 && !strcmp(NameStr, "vsscanf")) || + (NameLen == 7 && !strcmp(NameStr, "vfscanf"))) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(1)) || + !isa<PointerType>(FTy->getParamType(2))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if (NameLen == 6 && !strcmp(NameStr, "valloc")) { + if (!isa<PointerType>(FTy->getReturnType())) + continue; + setDoesNotThrow(F); + setDoesNotAlias(F, 0); + } else if (NameLen == 7 && !strcmp(NameStr, "vprintf")) { + if (FTy->getNumParams() != 2 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + } else if ((NameLen == 8 && !strcmp(NameStr, "vfprintf")) || + (NameLen == 8 && !strcmp(NameStr, "vsprintf"))) { + if (FTy->getNumParams() != 3 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(1))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 2); + } else if (NameLen == 9 && !strcmp(NameStr, "vsnprintf")) { + if (FTy->getNumParams() != 4 || + !isa<PointerType>(FTy->getParamType(0)) || + !isa<PointerType>(FTy->getParamType(2))) + continue; + setDoesNotThrow(F); + setDoesNotCapture(F, 1); + setDoesNotCapture(F, 3); + } + break; + case 'o': + if (NameLen == 4 && !strcmp(NameStr, "open")) { + if (FTy->getNumParams() < 2 || + !isa<PointerType>(FTy->getParamType(0))) + continue; + // May throw; "open" is a valid pthread cancellation point. 
+      setDoesNotCapture(F, 1);
+    } else if (NameLen == 7 && !strcmp(NameStr, "opendir")) {
+      if (FTy->getNumParams() != 1 ||
+          !isa<PointerType>(FTy->getReturnType()) ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotAlias(F, 0);
+      setDoesNotCapture(F, 1);
+    }
+    break;
+  case 't':
+    if (NameLen == 7 && !strcmp(NameStr, "tmpfile")) {
+      if (!isa<PointerType>(FTy->getReturnType()))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotAlias(F, 0);
+    } else if (NameLen == 5 && !strcmp(NameStr, "times")) {
+      if (FTy->getNumParams() != 1 ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 1);
+    }
+    break;
+  case 'h':
+    if ((NameLen == 5 && !strcmp(NameStr, "htonl")) ||
+        (NameLen == 5 && !strcmp(NameStr, "htons"))) {
+      setDoesNotThrow(F);
+      setDoesNotAccessMemory(F);
+    }
+    break;
+  case 'n':
+    if ((NameLen == 5 && !strcmp(NameStr, "ntohl")) ||
+        (NameLen == 5 && !strcmp(NameStr, "ntohs"))) {
+      setDoesNotThrow(F);
+      setDoesNotAccessMemory(F);
+    }
+    break;
+  case 'l':
+    if (NameLen == 5 && !strcmp(NameStr, "lstat")) {
+      if (FTy->getNumParams() != 2 ||
+          !isa<PointerType>(FTy->getParamType(0)) ||
+          !isa<PointerType>(FTy->getParamType(1)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 1);
+      setDoesNotCapture(F, 2);
+    } else if (NameLen == 6 && !strcmp(NameStr, "lchown")) {
+      if (FTy->getNumParams() != 3 ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 1);
+    }
+    break;
+  case 'q':
+    if (NameLen == 5 && !strcmp(NameStr, "qsort")) {
+      if (FTy->getNumParams() != 4 ||
+          !isa<PointerType>(FTy->getParamType(3)))
+        continue;
+      // May throw; places call through function pointer.
+      setDoesNotCapture(F, 4);
+    }
+    break;
+  case '_':
+    if ((NameLen == 8 && !strcmp(NameStr, "__strdup")) ||
+        (NameLen == 9 && !strcmp(NameStr, "__strndup"))) {
+      if (FTy->getNumParams() < 1 ||
+          !isa<PointerType>(FTy->getReturnType()) ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotAlias(F, 0);
+      setDoesNotCapture(F, 1);
+    } else if (NameLen == 10 && !strcmp(NameStr, "__strtok_r")) {
+      if (FTy->getNumParams() != 3 ||
+          !isa<PointerType>(FTy->getParamType(1)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 2);
+    } else if (NameLen == 8 && !strcmp(NameStr, "_IO_getc")) {
+      if (FTy->getNumParams() != 1 ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 1);
+    } else if (NameLen == 8 && !strcmp(NameStr, "_IO_putc")) {
+      if (FTy->getNumParams() != 2 ||
+          !isa<PointerType>(FTy->getParamType(1)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 2);
+    }
+    break;
+  case 1:
+    if (NameLen == 15 && !strcmp(NameStr, "\1__isoc99_scanf")) {
+      if (FTy->getNumParams() < 1 ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 1);
+    } else if ((NameLen == 7 && !strcmp(NameStr, "\1stat64")) ||
+               (NameLen == 8 && !strcmp(NameStr, "\1lstat64")) ||
+               (NameLen == 10 && !strcmp(NameStr, "\1statvfs64")) ||
+               (NameLen == 16 && !strcmp(NameStr, "\1__isoc99_sscanf"))) {
+      if (FTy->getNumParams() < 1 ||
+          !isa<PointerType>(FTy->getParamType(0)) ||
+          !isa<PointerType>(FTy->getParamType(1)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 1);
+      setDoesNotCapture(F, 2);
+    } else if (NameLen == 8 && !strcmp(NameStr, "\1fopen64")) {
+      if (FTy->getNumParams() != 2 ||
+          !isa<PointerType>(FTy->getReturnType()) ||
+          !isa<PointerType>(FTy->getParamType(0)) ||
+          !isa<PointerType>(FTy->getParamType(1)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotAlias(F, 0);
+      setDoesNotCapture(F, 1);
+      setDoesNotCapture(F, 2);
+    } else if ((NameLen == 9 && !strcmp(NameStr, "\1fseeko64")) ||
+               (NameLen == 9 && !strcmp(NameStr, "\1ftello64"))) {
+      if (FTy->getNumParams() == 0 ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 1);
+    } else if (NameLen == 10 && !strcmp(NameStr, "\1tmpfile64")) {
+      if (!isa<PointerType>(FTy->getReturnType()))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotAlias(F, 0);
+    } else if ((NameLen == 8 && !strcmp(NameStr, "\1fstat64")) ||
+               (NameLen == 11 && !strcmp(NameStr, "\1fstatvfs64"))) {
+      if (FTy->getNumParams() != 2 ||
+          !isa<PointerType>(FTy->getParamType(1)))
+        continue;
+      setDoesNotThrow(F);
+      setDoesNotCapture(F, 2);
+    } else if (NameLen == 7 && !strcmp(NameStr, "\1open64")) {
+      if (FTy->getNumParams() < 2 ||
+          !isa<PointerType>(FTy->getParamType(0)))
+        continue;
+      // May throw; "open" is a valid pthread cancellation point.
+      setDoesNotCapture(F, 1);
+    }
+    break;
+    }
+  }
+  return Modified;
+}
+
+// TODO:
+//   Additional cases that we need to add to this file:
+//
+// cbrt:
+//   * cbrt(expN(X))  -> expN(x/3)
+//   * cbrt(sqrt(x))  -> pow(x,1/6)
+//   * cbrt(cbrt(x))  -> pow(x,1/9)
+//
+// cos, cosf, cosl:
+//   * cos(-x)  -> cos(x)
+//
+// exp, expf, expl:
+//   * exp(log(x))  -> x
+//
+// log, logf, logl:
+//   * log(exp(x))   -> x
+//   * log(x**y)     -> y*log(x)
+//   * log(exp(y))   -> y*log(e)
+//   * log(exp2(y))  -> y*log(2)
+//   * log(exp10(y)) -> y*log(10)
+//   * log(sqrt(x))  -> 0.5*log(x)
+//   * log(pow(x,y)) -> y*log(x)
+//
+// lround, lroundf, lroundl:
+//   * lround(cnst) -> cnst'
+//
+// memcmp:
+//   * memcmp(x,y,l) -> cnst
+//      (if all arguments are constant and strlen(x) <= l and strlen(y) <= l)
+//
+// pow, powf, powl:
+//   * pow(exp(x),y)  -> exp(x*y)
+//   * pow(sqrt(x),y) -> pow(x,y*0.5)
+//   * pow(pow(x,y),z)-> pow(x,y*z)
+//
+// puts:
+//   * puts("") -> putchar('\n')
+//
+// round, roundf, roundl:
+//   * round(cnst) -> cnst'
+//
+// signbit:
+//   * signbit(cnst) -> cnst'
+//   * signbit(nncst) -> 0 (if nncst is a non-negative constant)
+//
+// sqrt, sqrtf, sqrtl:
+//   * sqrt(expN(x))  -> expN(x*0.5)
+//   * sqrt(Nroot(x)) -> pow(x,1/(2*N))
+//   * sqrt(pow(x,y)) -> pow(|x|,y*0.5)
+//
+// stpcpy:
+//   * stpcpy(str, "literal") ->
+//           llvm.memcpy(str,"literal",strlen("literal")+1,1)
+//
+// strrchr:
+//   * strrchr(s,c)  -> reverse_offset_of_in(c,s)
+//      (if c is a constant integer and s is a constant string)
+//   * strrchr(s1,0) -> strchr(s1,0)
+//
+// strpbrk:
+//   * strpbrk(s,a) -> offset_in_for(s,a)
+//      (if s and a are both constant strings)
+//   * strpbrk(s,"") -> 0
+//   * strpbrk(s,a) -> strchr(s,a[0]) (if a is constant string of length 1)
+//
+// strspn, strcspn:
+//   * strspn(s,a)   -> const_int (if both args are constant)
+//   * strspn("",a)  -> 0
+//   * strspn(s,"")  -> 0
+//   * strcspn(s,a)  -> const_int (if both args are constant)
+//   * strcspn("",a) -> 0
+//   * strcspn(s,"") -> strlen(s)
+//
+// strstr:
+//   * strstr(x,x)   -> x
+//   * strstr(s1,s2) -> offset_of_s2_in(s1)
+//       (if s1 and s2 are constant strings)
+//
+// tan, tanf, tanl:
+//   * tan(atan(x)) -> x
+//
+// trunc, truncf, truncl:
+//   * trunc(cnst) -> cnst'
+//
+//
diff --git a/lib/Transforms/Scalar/TailDuplication.cpp b/lib/Transforms/Scalar/TailDuplication.cpp
new file mode 100644
index 0000000..99a7dee
--- /dev/null
+++ b/lib/Transforms/Scalar/TailDuplication.cpp
@@ -0,0 +1,365 @@
+//===- TailDuplication.cpp - Simplify CFG through tail duplication --------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass performs a limited form of tail duplication, intended to simplify
+// CFGs by removing some unconditional branches. This pass is necessary to
+// straighten out loops created by the C front-end, but also is capable of
+// making other code nicer. After this pass is run, the CFG simplify pass
+// should be run to clean up the mess.
+//
+// This pass could be enhanced in the future to use profile information to be
+// more aggressive.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "tailduplicate"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Constant.h"
+#include "llvm/Function.h"
+#include "llvm/Instructions.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Pass.h"
+#include "llvm/Type.h"
+#include "llvm/Support/CFG.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include <map>
+using namespace llvm;
+
+STATISTIC(NumEliminated, "Number of unconditional branches eliminated");
+
+static cl::opt<unsigned>
+TailDupThreshold("taildup-threshold",
+                 cl::desc("Max block size to tail duplicate"),
+                 cl::init(1), cl::Hidden);
+
+namespace {
+  class VISIBILITY_HIDDEN TailDup : public FunctionPass {
+    bool runOnFunction(Function &F);
+  public:
+    static char ID; // Pass identification, replacement for typeid
+    TailDup() : FunctionPass(&ID) {}
+
+  private:
+    inline bool shouldEliminateUnconditionalBranch(TerminatorInst *, unsigned);
+    inline void eliminateUnconditionalBranch(BranchInst *BI);
+    SmallPtrSet<BasicBlock*, 4> CycleDetector;
+  };
+}
+
+char TailDup::ID = 0;
+static RegisterPass<TailDup> X("tailduplicate", "Tail Duplication");
+
+// Public interface to the Tail Duplication pass
+FunctionPass *llvm::createTailDuplicationPass() { return new TailDup(); }
+
+/// runOnFunction - Top level algorithm - Loop over each unconditional branch in
+/// the function, eliminating it if it looks attractive enough. CycleDetector
+/// prevents infinite loops by checking that we aren't redirecting a branch to
+/// a place it already pointed to earlier; see PR 2323.
+bool TailDup::runOnFunction(Function &F) {
+  bool Changed = false;
+  CycleDetector.clear();
+  for (Function::iterator I = F.begin(), E = F.end(); I != E; ) {
+    if (shouldEliminateUnconditionalBranch(I->getTerminator(),
+                                           TailDupThreshold)) {
+      eliminateUnconditionalBranch(cast<BranchInst>(I->getTerminator()));
+      Changed = true;
+    } else {
+      ++I;
+      CycleDetector.clear();
+    }
+  }
+  return Changed;
+}
+
+/// shouldEliminateUnconditionalBranch - Return true if this branch looks
+/// attractive to eliminate. We eliminate the branch if the destination basic
+/// block contains at most Threshold instructions (default 1, settable with
+/// -taildup-threshold), not counting PHI nodes or debug intrinsics. Since one
+/// of the counted instructions is the block's terminator, this means that we
+/// will add at most Threshold-1 instructions to the new block.
+///
+/// We don't count PHI nodes since they will be removed when the contents of
+/// the block are copied over.
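+///
+/// For illustration (a hypothetical example, not from the original sources):
+/// with the default threshold of 1, a candidate destination block may contain
+/// nothing but PHI nodes, debug intrinsics, and its terminator, e.g.
+///
+///   dest:                                  ; preds = %a, %b
+///     %p = phi i32 [ %x, %a ], [ %y, %b ]  ; skipped, PHIs are not counted
+///     ret i32 %p                           ; the single counted instruction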
+///
+bool TailDup::shouldEliminateUnconditionalBranch(TerminatorInst *TI,
                                                 unsigned Threshold) {
+  BranchInst *BI = dyn_cast<BranchInst>(TI);
+  if (!BI || !BI->isUnconditional()) return false;  // Not an uncond branch!
+
+  BasicBlock *Dest = BI->getSuccessor(0);
+  if (Dest == BI->getParent()) return false;        // Do not loop infinitely!
+
+  // Do not inline a block if we will just get another branch to the same block!
+  TerminatorInst *DTI = Dest->getTerminator();
+  if (BranchInst *DBI = dyn_cast<BranchInst>(DTI))
+    if (DBI->isUnconditional() && DBI->getSuccessor(0) == Dest)
+      return false;                                 // Do not loop infinitely!
+
+  // FIXME: DemoteRegToStack cannot yet demote invoke instructions to the stack,
+  // because doing so would require breaking critical edges. This should be
+  // fixed eventually.
+  if (!DTI->use_empty())
+    return false;
+
+  // Do not bother with blocks with only a single predecessor: simplify
+  // CFG will fold these two blocks together!
+  pred_iterator PI = pred_begin(Dest), PE = pred_end(Dest);
+  ++PI;
+  if (PI == PE) return false;  // Exactly one predecessor!
+
+  BasicBlock::iterator I = Dest->getFirstNonPHI();
+
+  for (unsigned Size = 0; I != Dest->end(); ++I) {
+    if (Size == Threshold) return false;  // The block is too large.
+
+    // Don't tail duplicate call instructions. They are very large compared to
+    // other instructions.
+    if (isa<CallInst>(I) || isa<InvokeInst>(I)) return false;
+
+    // Also reject alloca and malloc.
+    if (isa<AllocationInst>(I)) return false;
+
+    // Some vector instructions can expand into a number of instructions.
+    if (isa<ShuffleVectorInst>(I) || isa<ExtractElementInst>(I) ||
+        isa<InsertElementInst>(I)) return false;
+
+    // Only count instructions that are not debugger intrinsics.
+    if (!isa<DbgInfoIntrinsic>(I)) ++Size;
+  }
+
+  // Do not tail duplicate a block that has thousands of successors into a block
+  // with a single successor if the block has many other predecessors. This can
+  // cause an N^2 explosion in CFG edges (and PHI node entries), as seen in
+  // cases that have a large number of indirect gotos.
+  unsigned NumSuccs = DTI->getNumSuccessors();
+  if (NumSuccs > 8) {
+    unsigned TooMany = 128;
+    if (NumSuccs >= TooMany) return false;
+    TooMany = TooMany/NumSuccs;
+    for (; PI != PE; ++PI)
+      if (TooMany-- == 0) return false;
+  }
+
+  // If this unconditional branch is a fall-through, be careful about
+  // tail duplicating it. In particular, we don't want to taildup it if the
+  // original block will still be there after taildup is completed: doing so
+  // would eliminate the fall-through, requiring unconditional branches.
+  Function::iterator DestI = Dest;
+  if (&*--DestI == BI->getParent()) {
+    // The uncond branch is a fall-through. Tail duplication of the block
+    // will eliminate the fall-through-ness and end up cloning the terminator
+    // at the end of the Dest block. Since the original Dest block will
+    // continue to exist, this means that one or the other will not be able to
+    // fall through. One typical example that this helps with is code like:
+    //   if (a)
+    //     foo();
+    //   if (b)
+    //     foo();
+    // Cloning the 'if b' block into the end of the first foo block is messy.
+
+    // The messy case is when the fall-through block falls through to other
+    // blocks. This is what we would be preventing if we cloned the block.
+    DestI = Dest;
+    if (++DestI != Dest->getParent()->end()) {
+      BasicBlock *DestSucc = DestI;
+      // If any of Dest's successors are fall-throughs, don't do this xform.
+      for (succ_iterator SI = succ_begin(Dest), SE = succ_end(Dest);
+           SI != SE; ++SI)
+        if (*SI == DestSucc)
+          return false;
+    }
+  }
+
+  // Finally, check that we haven't redirected to this target block earlier;
+  // there are cases where we loop forever if we don't check this (PR 2323).
+  if (!CycleDetector.insert(Dest))
+    return false;
+
+  return true;
+}
+
+/// FindObviousSharedDomOf - We know there is a branch from SrcBlock to
+/// DstBlock, and that SrcBlock is not the only predecessor of DstBlock. If we
+/// can find a predecessor of SrcBlock that is a dominator of both SrcBlock and
+/// DstBlock, return it.
+static BasicBlock *FindObviousSharedDomOf(BasicBlock *SrcBlock,
+                                          BasicBlock *DstBlock) {
+  // SrcBlock must have a single predecessor.
+  pred_iterator PI = pred_begin(SrcBlock), PE = pred_end(SrcBlock);
+  if (PI == PE || ++PI != PE) return 0;
+
+  BasicBlock *SrcPred = *pred_begin(SrcBlock);
+
+  // Look at the predecessors of DstBlock. One of them will be SrcBlock. If
+  // there is only one other pred, get it, otherwise we can't handle it.
+  PI = pred_begin(DstBlock); PE = pred_end(DstBlock);
+  BasicBlock *DstOtherPred = 0;
+  if (*PI == SrcBlock) {
+    if (++PI == PE) return 0;
+    DstOtherPred = *PI;
+    if (++PI != PE) return 0;
+  } else {
+    DstOtherPred = *PI;
+    if (++PI == PE || *PI != SrcBlock || ++PI != PE) return 0;
+  }
+
+  // We can handle two situations here: "if then" and "if then else" blocks. An
+  // 'if then' situation is just where DstOtherPred == SrcPred.
+  if (DstOtherPred == SrcPred)
+    return SrcPred;
+
+  // Check to see if we have an "if then else" situation, which means that
+  // DstOtherPred will have a single predecessor and it will be SrcPred.
+  PI = pred_begin(DstOtherPred); PE = pred_end(DstOtherPred);
+  if (PI != PE && *PI == SrcPred) {
+    if (++PI != PE) return 0; // Not a single pred.
+    return SrcPred; // Otherwise, it's an "if then else" situation. Return the if.
+  }
+
+  // Otherwise, this is something we can't handle.
+  return 0;
+}
+
+
+/// eliminateUnconditionalBranch - Clone the instructions from the destination
+/// block into the source block, eliminating the specified unconditional branch.
+/// If the destination block defines values used by successors of the dest
+/// block, we may need to insert PHI nodes.
+///
+void TailDup::eliminateUnconditionalBranch(BranchInst *Branch) {
+  BasicBlock *SourceBlock = Branch->getParent();
+  BasicBlock *DestBlock = Branch->getSuccessor(0);
+  assert(SourceBlock != DestBlock && "Our predicate is broken!");
+
+  DOUT << "TailDuplication[" << SourceBlock->getParent()->getName()
+       << "]: Eliminating branch: " << *Branch;
+
+  // See if we can avoid duplicating code by moving it up to a dominator of both
+  // blocks.
+  if (BasicBlock *DomBlock = FindObviousSharedDomOf(SourceBlock, DestBlock)) {
+    DOUT << "Found shared dominator: " << DomBlock->getName() << "\n";
+
+    // If there are non-phi instructions in DestBlock that have no operands
+    // defined in DestBlock, and if the instruction has no side effects, we can
+    // move the instruction to DomBlock instead of duplicating it.
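+    // For illustration (a hypothetical example, not from the original
+    // sources): if DestBlock computes "%t = add i32 %a, %b" and both %a and
+    // %b are defined above DomBlock, the add can be hoisted into DomBlock
+    // once instead of being cloned into every block that duplicates the tail.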
+    BasicBlock::iterator BBI = DestBlock->getFirstNonPHI();
+    while (!isa<TerminatorInst>(BBI)) {
+      Instruction *I = BBI++;
+
+      bool CanHoist = !I->isTrapping() && !I->mayHaveSideEffects();
+      if (CanHoist) {
+        for (unsigned op = 0, e = I->getNumOperands(); op != e; ++op)
+          if (Instruction *OpI = dyn_cast<Instruction>(I->getOperand(op)))
+            if (OpI->getParent() == DestBlock ||
+                (isa<InvokeInst>(OpI) && OpI->getParent() == DomBlock)) {
+              CanHoist = false;
+              break;
+            }
+        if (CanHoist) {
+          // Remove from DestBlock, move right before the term in DomBlock.
+          DestBlock->getInstList().remove(I);
+          DomBlock->getInstList().insert(DomBlock->getTerminator(), I);
+          DOUT << "Hoisted: " << *I;
+        }
+      }
+    }
+  }
+
+  // Tail duplication cannot update SSA properties correctly if the values
+  // defined in the duplicated tail are used outside of the tail itself. For
+  // this reason, we spill all values that are used outside of the tail to the
+  // stack.
+  for (BasicBlock::iterator I = DestBlock->begin(); I != DestBlock->end(); ++I)
+    if (I->isUsedOutsideOfBlock(DestBlock)) {
+      // We found a use outside of the tail. Create a new stack slot to
+      // break this inter-block usage pattern.
+      DemoteRegToStack(*I);
+    }
+
+  // We are going to have to map operands from the original block B to the new
+  // copy of the block B'. If there are PHI nodes in the DestBlock, these PHI
+  // nodes also define part of this mapping. Loop over these PHI nodes, adding
+  // them to our mapping.
+  //
+  std::map<Value*, Value*> ValueMapping;
+
+  BasicBlock::iterator BI = DestBlock->begin();
+  bool HadPHINodes = isa<PHINode>(BI);
+  for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI)
+    ValueMapping[PN] = PN->getIncomingValueForBlock(SourceBlock);
+
+  // Clone the non-phi instructions of the dest block into the source block,
+  // keeping track of the mapping...
+  //
+  for (; BI != DestBlock->end(); ++BI) {
+    Instruction *New = BI->clone();
+    New->setName(BI->getName());
+    SourceBlock->getInstList().push_back(New);
+    ValueMapping[BI] = New;
+  }
+
+  // Now that we have built the mapping information and cloned all of the
+  // instructions (giving us a new terminator, among other things), walk the new
+  // instructions, rewriting references of old instructions to use new
+  // instructions.
+  //
+  BI = Branch; ++BI;  // Get an iterator to the first new instruction
+  for (; BI != SourceBlock->end(); ++BI)
+    for (unsigned i = 0, e = BI->getNumOperands(); i != e; ++i)
+      if (Value *Remapped = ValueMapping[BI->getOperand(i)])
+        BI->setOperand(i, Remapped);
+
+  // Next we check to see if any of the successors of DestBlock had PHI nodes.
+  // If so, we need to add entries to the PHI nodes for SourceBlock now.
+  for (succ_iterator SI = succ_begin(DestBlock), SE = succ_end(DestBlock);
+       SI != SE; ++SI) {
+    BasicBlock *Succ = *SI;
+    for (BasicBlock::iterator PNI = Succ->begin(); isa<PHINode>(PNI); ++PNI) {
+      PHINode *PN = cast<PHINode>(PNI);
+      // Ok, we have a PHI node. Figure out what the incoming value was for the
+      // DestBlock.
+      Value *IV = PN->getIncomingValueForBlock(DestBlock);
+
+      // Remap the value if necessary...
+      if (Value *MappedIV = ValueMapping[IV])
+        IV = MappedIV;
+      PN->addIncoming(IV, SourceBlock);
+    }
+  }
+
+  // Next, remove the old branch instruction, and any PHI node entries that we
+  // had.
+  BI = Branch; ++BI;  // Get an iterator to the first new instruction
+  DestBlock->removePredecessor(SourceBlock);  // Remove entries in PHI nodes...
+  SourceBlock->getInstList().erase(Branch);   // Destroy the uncond branch...
+
+  // Final step: now that we have finished everything up, walk the cloned
+  // instructions one last time, constant propagating and DCE'ing them, because
+  // they may not be needed anymore.
+  //
+  if (HadPHINodes) {
+    while (BI != SourceBlock->end()) {
+      Instruction *Inst = BI++;
+      if (isInstructionTriviallyDead(Inst))
+        Inst->eraseFromParent();
+      else if (Constant *C = ConstantFoldInstruction(Inst)) {
+        Inst->replaceAllUsesWith(C);
+        Inst->eraseFromParent();
+      }
+    }
+  }
+
+  ++NumEliminated;  // We just killed a branch!
+}
diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp
new file mode 100644
index 0000000..682d069
--- /dev/null
+++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -0,0 +1,479 @@
+//===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file transforms calls of the current function (self recursion) followed
+// by a return instruction with a branch to the entry of the function, creating
+// a loop. This pass also implements the following extensions to the basic
+// algorithm:
+//
+//  1. Trivial instructions between the call and return do not prevent the
+//     transformation from taking place, though currently the analysis cannot
+//     support moving any really useful instructions (only dead ones).
+//  2. This pass transforms functions that are prevented from being tail
+//     recursive by an associative expression to use an accumulator variable,
+//     thus compiling the typical naive factorial or 'fib' implementation into
+//     efficient code.
+//  3. TRE is performed if the function returns void, if the return
+//     returns the result returned by the call, or if the function returns a
+//     run-time constant on all exits from the function. It is possible, though
+//     unlikely, that the return returns something else (like constant 0), and
+//     can still be TRE'd. It can be TRE'd if ALL OTHER return instructions in
+//     the function return the exact same value.
+//  4. If it can prove that callees do not access their caller stack frame,
+//     they are marked as eligible for tail call elimination (by the code
+//     generator).
+//
+// There are several improvements that could be made:
+//
+//  1. If the function has any alloca instructions, these instructions will be
+//     moved out of the entry block of the function, causing them to be
+//     evaluated each time through the tail recursion. Safely keeping allocas
+//     in the entry block requires analysis to prove that the tail-called
+//     function does not read or write the stack object.
+//  2. Tail recursion is only performed if the call immediately precedes the
+//     return instruction. It's possible that there could be a jump between
+//     the call and the return.
+//  3. There can be intervening operations between the call and the return that
+//     prevent the TRE from occurring. For example, there could be GEP's and
+//     stores to memory that will not be read or written by the call. This
+//     requires some substantial analysis (such as with DSA) to prove safe to
+//     move ahead of the call, but doing so could allow many more TREs to be
+//     performed, for example in TreeAdd/TreeAlloc from the treeadd benchmark.
+//  4. The algorithm we use to detect if callees access their caller stack
+//     frames is very primitive.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "tailcallelim"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Function.h"
+#include "llvm/Instructions.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CFG.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Compiler.h"
+using namespace llvm;
+
+STATISTIC(NumEliminated, "Number of tail calls removed");
+STATISTIC(NumAccumAdded, "Number of accumulators introduced");
+
+namespace {
+  struct VISIBILITY_HIDDEN TailCallElim : public FunctionPass {
+    static char ID; // Pass identification, replacement for typeid
+    TailCallElim() : FunctionPass(&ID) {}
+
+    virtual bool runOnFunction(Function &F);
+
+  private:
+    bool ProcessReturningBlock(ReturnInst *RI, BasicBlock *&OldEntry,
+                               bool &TailCallsAreMarkedTail,
+                               std::vector<PHINode*> &ArgumentPHIs,
+                               bool CannotTailCallElimCallsMarkedTail);
+    bool CanMoveAboveCall(Instruction *I, CallInst *CI);
+    Value *CanTransformAccumulatorRecursion(Instruction *I, CallInst *CI);
+  };
+}
+
+char TailCallElim::ID = 0;
+static RegisterPass<TailCallElim> X("tailcallelim", "Tail Call Elimination");
+
+// Public interface to the TailCallElimination pass
+FunctionPass *llvm::createTailCallEliminationPass() {
+  return new TailCallElim();
+}
+
+
+/// AllocaMightEscapeToCalls - Return true if this alloca may be accessed by
+/// callees of this function. We only do very simple analysis right now; this
+/// could be expanded in the future to use mod/ref information for particular
+/// call sites if desired.
+static bool AllocaMightEscapeToCalls(AllocaInst *AI) {
+  // FIXME: do simple 'address taken' analysis.
+  return true;
+}
+
+/// CheckForEscapingAllocas - Scan the specified basic block for alloca
+/// instructions. If it contains any that might be accessed by calls, return
+/// true.
+static bool CheckForEscapingAllocas(BasicBlock *BB,
+                                    bool &CannotTCETailMarkedCall) {
+  bool RetVal = false;
+  for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
+    if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
+      RetVal |= AllocaMightEscapeToCalls(AI);
+
+      // If this alloca is in the body of the function, or if it is a variable
+      // sized allocation, we cannot tail call eliminate calls marked 'tail'
+      // with this mechanism.
+      if (BB != &BB->getParent()->getEntryBlock() ||
+          !isa<ConstantInt>(AI->getArraySize()))
+        CannotTCETailMarkedCall = true;
+    }
+  return RetVal;
+}
+
+bool TailCallElim::runOnFunction(Function &F) {
+  // If this function is a varargs function, we won't be able to PHI the args
+  // right, so don't even try to convert it...
+  if (F.getFunctionType()->isVarArg()) return false;
+
+  BasicBlock *OldEntry = 0;
+  bool TailCallsAreMarkedTail = false;
+  std::vector<PHINode*> ArgumentPHIs;
+  bool MadeChange = false;
+
+  bool FunctionContainsEscapingAllocas = false;
+
+  // CannotTCETailMarkedCall - If true, we cannot perform TCE on tail calls
+  // marked with the 'tail' attribute, because doing so would cause the stack
+  // size to increase (real TCE would deallocate variable sized allocas, TCE
+  // doesn't).
+  bool CannotTCETailMarkedCall = false;
+
+  // Loop over the function, looking for any returning blocks, and keeping track
+  // of whether this function has any non-trivially used allocas.
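+  // For example (a hypothetical sketch, not from the original sources): in
+  //   void f() { int buf[4]; g(buf); ... f(); }
+  // the alloca for 'buf' is passed to g and might therefore be accessed by a
+  // callee, so it counts as escaping.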
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
+    if (FunctionContainsEscapingAllocas && CannotTCETailMarkedCall)
+      break;
+
+    FunctionContainsEscapingAllocas |=
+      CheckForEscapingAllocas(BB, CannotTCETailMarkedCall);
+  }
+
+  /// FIXME: The code generator produces really bad code when an 'escaping
+  /// alloca' is changed from being a static alloca to being a dynamic alloca.
+  /// Until this is resolved, disable this transformation if that would ever
+  /// happen. This bug is PR962.
+  if (FunctionContainsEscapingAllocas)
+    return false;
+
+
+  // Second pass, change any tail calls to loops.
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+    if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator()))
+      MadeChange |= ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
+                                          ArgumentPHIs, CannotTCETailMarkedCall);
+
+  // If we eliminated any tail recursions, it's possible that we inserted some
+  // silly PHI nodes which just merge an initial value (the incoming operand)
+  // with themselves. Check to see if we did and clean up our mess if so. This
+  // occurs when a function passes an argument straight through to its tail
+  // call.
+  if (!ArgumentPHIs.empty()) {
+    for (unsigned i = 0, e = ArgumentPHIs.size(); i != e; ++i) {
+      PHINode *PN = ArgumentPHIs[i];
+
+      // If the PHI node is a dynamic constant, replace it with that value.
+      if (Value *PNV = PN->hasConstantValue()) {
+        PN->replaceAllUsesWith(PNV);
+        PN->eraseFromParent();
+      }
+    }
+  }
+
+  // Finally, if this function contains no escaping allocas, mark all calls
+  // in the function as eligible for tail calls (there is no stack memory for
+  // them to access).
+  if (!FunctionContainsEscapingAllocas)
+    for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
+      for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
+        if (CallInst *CI = dyn_cast<CallInst>(I)) {
+          CI->setTailCall();
+          MadeChange = true;
+        }
+
+  return MadeChange;
+}
+
+
+/// CanMoveAboveCall - Return true if it is safe to move the specified
+/// instruction from after the call to before the call, assuming that all
+/// instructions between the call and this instruction are movable.
+///
+bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) {
+  // FIXME: We can move load/store/call/free instructions above the call if the
+  // call does not mod/ref the memory location being processed.
+  if (I->mayHaveSideEffects() || isa<LoadInst>(I))
+    return false;
+
+  // Otherwise, if this is a side-effect free instruction, check to make sure
+  // that it does not use the return value of the call. If it doesn't use the
+  // return value of the call, it must only use things that are defined before
+  // the call, or movable instructions between the call and the instruction
+  // itself.
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
+    if (I->getOperand(i) == CI)
+      return false;
+  return true;
+}
+
+// isDynamicConstant - Return true if the specified value is the same when the
+// return would exit as it was when the initial iteration of the recursive
+// function was executed.
+//
+// We currently handle static constants and arguments that are not modified as
+// part of the recursion.
+//
+static bool isDynamicConstant(Value *V, CallInst *CI) {
+  if (isa<Constant>(V)) return true; // Static constants are always dyn consts
+
+  // Check to see if this is an immutable argument; if so, the value
+  // will be available to initialize the accumulator.
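+  // For example (a hypothetical sketch, not from the original sources): in
+  //   int f(int n, int k) { if (n == 0) return k; return f(n - 1, k); }
+  // k is passed through to the recursive call unchanged, so the 'return k'
+  // exit is dynamically constant; a value derived from n would not be.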
+  if (Argument *Arg = dyn_cast<Argument>(V)) {
+    // Figure out which argument number this is...
+    unsigned ArgNo = 0;
+    Function *F = CI->getParent()->getParent();
+    for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI)
+      ++ArgNo;
+
+    // If we are passing this argument into the call as the corresponding
+    // argument operand, then the argument is dynamically constant.
+    // Otherwise, we cannot transform this function safely.
+    if (CI->getOperand(ArgNo+1) == Arg)
+      return true;
+  }
+  // Not a constant or immutable argument; we can't safely transform.
+  return false;
+}
+
+// getCommonReturnValue - Check to see if the function containing the specified
+// return instruction and tail call consistently returns the same
+// runtime-constant value at all exit points. If so, return the returned value.
+//
+static Value *getCommonReturnValue(ReturnInst *TheRI, CallInst *CI) {
+  Function *F = TheRI->getParent()->getParent();
+  Value *ReturnedValue = 0;
+
+  // TODO: Handle multiple value ret instructions.
+  if (isa<StructType>(F->getReturnType()))
+    return 0;
+
+  for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI)
+    if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator()))
+      if (RI != TheRI) {
+        Value *RetOp = RI->getOperand(0);
+
+        // We can only perform this transformation if the value returned is
+        // evaluatable at the start of the initial invocation of the function,
+        // instead of at the end of the evaluation.
+        //
+        if (!isDynamicConstant(RetOp, CI))
+          return 0;
+
+        if (ReturnedValue && RetOp != ReturnedValue)
+          return 0; // Cannot transform if differing values are returned.
+        ReturnedValue = RetOp;
+      }
+  return ReturnedValue;
+}
+
+/// CanTransformAccumulatorRecursion - If the specified instruction can be
+/// transformed using accumulator recursion elimination, return the constant
+/// which is the start of the accumulator value. Otherwise return null.
+///
+Value *TailCallElim::CanTransformAccumulatorRecursion(Instruction *I,
+                                                      CallInst *CI) {
+  if (!I->isAssociative()) return 0;
+  assert(I->getNumOperands() == 2 &&
+         "Associative operations should have 2 args!");
+
+  // Exactly one operand should be the result of the call instruction...
+  if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
+      (I->getOperand(0) != CI && I->getOperand(1) != CI))
+    return 0;
+
+  // The only user of this instruction we allow is a single return instruction.
+  if (!I->hasOneUse() || !isa<ReturnInst>(I->use_back()))
+    return 0;
+
+  // Ok, now we have to check all of the other return instructions in this
+  // function. If they return non-constants or differing values, then we cannot
+  // transform the function safely.
+  return getCommonReturnValue(cast<ReturnInst>(I->use_back()), CI);
+}
+
+bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
+                                         bool &TailCallsAreMarkedTail,
+                                         std::vector<PHINode*> &ArgumentPHIs,
+                                         bool CannotTailCallElimCallsMarkedTail) {
+  BasicBlock *BB = Ret->getParent();
+  Function *F = BB->getParent();
+
+  if (&BB->front() == Ret) // Make sure there is something before the ret...
+    return false;
+
+  // If the return is in the entry block, then making this transformation would
+  // turn infinite recursion into an infinite loop. This transformation is ok
+  // in theory, but breaks some code like:
+  //   double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
+  // We disable this xform in this case, because the code generator will lower
+  // the call to fabs into inline code.
+  if (BB == &F->getEntryBlock())
+    return false;
+
+  // Scan backwards from the return, checking to see if there is a tail call in
+  // this block. If so, set CI to it.
+  CallInst *CI;
+  BasicBlock::iterator BBI = Ret;
+  while (1) {
+    CI = dyn_cast<CallInst>(BBI);
+    if (CI && CI->getCalledFunction() == F)
+      break;
+
+    if (BBI == BB->begin())
+      return false; // Didn't find a potential tail call.
+    --BBI;
+  }
+
+  // If this call is marked as a tail call, and if there are dynamic allocas in
+  // the function, we cannot perform this optimization.
+  if (CI->isTailCall() && CannotTailCallElimCallsMarkedTail)
+    return false;
+
+  // If we are introducing accumulator recursion to eliminate associative
+  // operations after the call instruction, this variable contains the initial
+  // value for the accumulator. If this value is set, we actually perform
+  // accumulator recursion elimination instead of simple tail recursion
+  // elimination.
+  Value *AccumulatorRecursionEliminationInitVal = 0;
+  Instruction *AccumulatorRecursionInstr = 0;
+
+  // Ok, we found a potential tail call. We can currently only transform the
+  // tail call if all of the instructions between the call and the return are
+  // movable to above the call itself, leaving the call next to the return.
+  // Check that this is the case now.
+  for (BBI = CI, ++BBI; &*BBI != Ret; ++BBI)
+    if (!CanMoveAboveCall(BBI, CI)) {
+      // If we can't move the instruction above the call, it might be because it
+      // is an associative operation that could be transformed using accumulator
+      // recursion elimination. Check to see if this is the case, and if so,
+      // remember the initial accumulator value for later.
+      if ((AccumulatorRecursionEliminationInitVal =
+             CanTransformAccumulatorRecursion(BBI, CI))) {
+        // Yes, this is accumulator recursion. Remember which instruction
+        // accumulates.
+        AccumulatorRecursionInstr = BBI;
+      } else {
+        return false; // Otherwise, we cannot eliminate the tail recursion!
+      }
+    }
+
+  // We can only transform call/return pairs that either ignore the return value
+  // of the call and return void, ignore the value of the call and return a
+  // constant, return the value returned by the tail call, or that are being
+  // accumulator recursion variable eliminated.
+  if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI &&
+      !isa<UndefValue>(Ret->getReturnValue()) &&
+      AccumulatorRecursionEliminationInitVal == 0 &&
+      !getCommonReturnValue(Ret, CI))
+    return false;
+
+  // OK! We can transform this tail call. If this is the first one found,
+  // create the new entry block, allowing us to branch back to the old entry.
+  if (OldEntry == 0) {
+    OldEntry = &F->getEntryBlock();
+    BasicBlock *NewEntry = BasicBlock::Create("", F, OldEntry);
+    NewEntry->takeName(OldEntry);
+    OldEntry->setName("tailrecurse");
+    BranchInst::Create(OldEntry, NewEntry);
+
+    // If this tail call is marked 'tail' and if there are any allocas in the
+    // entry block, move them up to the new entry block.
+    TailCallsAreMarkedTail = CI->isTailCall();
+    if (TailCallsAreMarkedTail)
+      // Move all fixed sized allocas from OldEntry to NewEntry.
+      for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(),
+             NEBI = NewEntry->begin(); OEBI != E; )
+        if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
+          if (isa<ConstantInt>(AI->getArraySize()))
+            AI->moveBefore(NEBI);
+
+    // Now that we have created a new block, which jumps to the entry
+    // block, insert a PHI node for each argument of the function.
+    // For now, we initialize each PHI to only have the real arguments
+    // which are passed in.
+    Instruction *InsertPos = OldEntry->begin();
+    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
+         I != E; ++I) {
+      PHINode *PN = PHINode::Create(I->getType(),
+                                    I->getName() + ".tr", InsertPos);
+      I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
+      PN->addIncoming(I, NewEntry);
+      ArgumentPHIs.push_back(PN);
+    }
+  }
+
+  // If this function has self recursive calls in the tail position where some
+  // are marked tail and some are not, only transform one flavor or another. We
+  // have to choose whether we move allocas in the entry block to the new entry
+  // block or not, so we can't make a good choice for both. NOTE: We could do
+  // slightly better here in the case that the function has no entry block
+  // allocas.
+  if (TailCallsAreMarkedTail && !CI->isTailCall())
+    return false;
+
+  // Ok, now that we know we have a pseudo-entry block WITH all of the
+  // required PHI nodes, add entries into the PHI node for the actual
+  // parameters passed into the tail-recursive call.
+  for (unsigned i = 0, e = CI->getNumOperands()-1; i != e; ++i)
+    ArgumentPHIs[i]->addIncoming(CI->getOperand(i+1), BB);
+
+  // If we are introducing an accumulator variable to eliminate the recursion,
+  // do so now. Note that we _know_ that no subsequent tail recursion
+  // eliminations will happen on this function because of the way the
+  // accumulator recursion predicate is set up.
+  //
+  if (AccumulatorRecursionEliminationInitVal) {
+    Instruction *AccRecInstr = AccumulatorRecursionInstr;
+    // Start by inserting a new PHI node for the accumulator.
+    PHINode *AccPN = PHINode::Create(AccRecInstr->getType(), "accumulator.tr",
+                                     OldEntry->begin());
+
+    // Loop over all of the predecessors of the tail recursion block. For the
+    // real entry into the function we seed the PHI with the initial value,
+    // computed earlier. For any other existing branches to this block (due to
+    // other tail recursions eliminated) the accumulator is not modified.
+    // Because we haven't added the branch in the current block to OldEntry yet,
+    // it will not show up as a predecessor.
+    for (pred_iterator PI = pred_begin(OldEntry), PE = pred_end(OldEntry);
+         PI != PE; ++PI) {
+      if (*PI == &F->getEntryBlock())
+        AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, *PI);
+      else
+        AccPN->addIncoming(AccPN, *PI);
+    }
+
+    // Add an incoming argument for the current block, which is computed by our
+    // associative accumulator instruction.
+    AccPN->addIncoming(AccRecInstr, BB);
+
+    // Next, rewrite the accumulator recursion instruction so that it does not
+    // use the result of the call anymore, instead, use the PHI node we just
+    // inserted.
+    AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+
+    // Finally, rewrite any return instructions in the program to return the PHI
+    // node instead of the "initval" that they do currently. This loop will
+    // actually rewrite the return value we are destroying, but that's ok.
+    for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI)
+      if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator()))
+        RI->setOperand(0, AccPN);
+    ++NumAccumAdded;
+  }
+
+  // Now that all of the PHI nodes are in place, remove the call and
+  // ret instructions, replacing them with an unconditional branch.
+  BranchInst::Create(OldEntry, Ret);
+  BB->getInstList().erase(Ret);  // Remove return.
+  BB->getInstList().erase(CI);   // Remove call.
+
+  ++NumEliminated;
+  return true;
+}
|
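As context for extension 2 in the TailRecursionElimination header comment, the
accumulator transformation corresponds to the following source-level rewrite of
a naive factorial. This is a hand-written sketch of the effect, for
illustration only; it is not output produced by the pass:

    // Before: the multiply keeps the recursive call out of tail position.
    unsigned fact(unsigned n) {
      if (n <= 1) return 1;
      return n * fact(n - 1);
    }

    // After, conceptually: the recursion becomes a branch back to the old
    // entry block ("tailrecurse"), and an accumulator PHI seeded with the
    // base-case return value (1 here) absorbs the pending multiplies.
    unsigned fact(unsigned n) {
      unsigned accumulator = 1;
      for (; n > 1; --n)
        accumulator *= n;
      return accumulator;
    }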