diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 413 |
1 files changed, 224 insertions, 189 deletions
diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index bf3c33e..f51442a 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -11,88 +11,55 @@ // //===----------------------------------------------------------------------===// -#include "InstCombine.h" +#include "InstCombineInternal.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -/// MatchSelectPattern - Pattern match integer [SU]MIN, [SU]MAX, and ABS idioms, -/// returning the kind and providing the out parameter results if we -/// successfully match. static SelectPatternFlavor -MatchSelectPattern(Value *V, Value *&LHS, Value *&RHS) { - SelectInst *SI = dyn_cast<SelectInst>(V); - if (!SI) return SPF_UNKNOWN; - - ICmpInst *ICI = dyn_cast<ICmpInst>(SI->getCondition()); - if (!ICI) return SPF_UNKNOWN; - - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - - LHS = CmpLHS; - RHS = CmpRHS; - - // (icmp X, Y) ? X : Y - if (TrueVal == CmpLHS && FalseVal == CmpRHS) { - switch (Pred) { - default: return SPF_UNKNOWN; // Equality. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: return SPF_UMAX; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: return SPF_SMAX; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: return SPF_UMIN; - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: return SPF_SMIN; - } - } - - // (icmp X, Y) ? Y : X - if (TrueVal == CmpRHS && FalseVal == CmpLHS) { - switch (Pred) { - default: return SPF_UNKNOWN; // Equality. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: return SPF_UMIN; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: return SPF_SMIN; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: return SPF_UMAX; - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: return SPF_SMAX; - } +getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) { + switch (SPF) { + default: + llvm_unreachable("unhandled!"); + + case SPF_SMIN: + return SPF_SMAX; + case SPF_UMIN: + return SPF_UMAX; + case SPF_SMAX: + return SPF_SMIN; + case SPF_UMAX: + return SPF_UMIN; } +} - if (ConstantInt *C1 = dyn_cast<ConstantInt>(CmpRHS)) { - if ((CmpLHS == TrueVal && match(FalseVal, m_Neg(m_Specific(CmpLHS)))) || - (CmpLHS == FalseVal && match(TrueVal, m_Neg(m_Specific(CmpLHS))))) { - - // ABS(X) ==> (X >s 0) ? X : -X and (X >s -1) ? X : -X - // NABS(X) ==> (X >s 0) ? -X : X and (X >s -1) ? -X : X - if (Pred == ICmpInst::ICMP_SGT && (C1->isZero() || C1->isMinusOne())) { - return (CmpLHS == TrueVal) ? SPF_ABS : SPF_NABS; - } - - // ABS(X) ==> (X <s 0) ? -X : X and (X <s 1) ? -X : X - // NABS(X) ==> (X <s 0) ? X : -X and (X <s 1) ? X : -X - if (Pred == ICmpInst::ICMP_SLT && (C1->isZero() || C1->isOne())) { - return (CmpLHS == FalseVal) ? SPF_ABS : SPF_NABS; - } - } +static CmpInst::Predicate getICmpPredicateForMinMax(SelectPatternFlavor SPF) { + switch (SPF) { + default: + llvm_unreachable("unhandled!"); + + case SPF_SMIN: + return ICmpInst::ICMP_SLT; + case SPF_UMIN: + return ICmpInst::ICMP_ULT; + case SPF_SMAX: + return ICmpInst::ICMP_SGT; + case SPF_UMAX: + return ICmpInst::ICMP_UGT; } - - // TODO: (X > 4) ? X : 5 --> (X >= 5) ? X : 5 --> MAX(X, 5) - - return SPF_UNKNOWN; } +static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder, + SelectPatternFlavor SPF, Value *A, + Value *B) { + CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF); + return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B); +} /// GetSelectFoldableOperands - We want to turn code that looks like this: /// %C = or %A, %B @@ -309,72 +276,6 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, return nullptr; } -/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is -/// replaced with RepOp. -static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, - const DataLayout *TD, - const TargetLibraryInfo *TLI, - DominatorTree *DT, AssumptionCache *AC) { - // Trivial replacement. - if (V == Op) - return RepOp; - - Instruction *I = dyn_cast<Instruction>(V); - if (!I) - return nullptr; - - // If this is a binary operator, try to simplify it with the replaced op. - if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) { - if (B->getOperand(0) == Op) - return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD, TLI); - if (B->getOperand(1) == Op) - return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD, TLI); - } - - // Same for CmpInsts. - if (CmpInst *C = dyn_cast<CmpInst>(I)) { - if (C->getOperand(0) == Op) - return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD, - TLI, DT, AC); - if (C->getOperand(1) == Op) - return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD, - TLI, DT, AC); - } - - // TODO: We could hand off more cases to instsimplify here. - - // If all operands are constant after substituting Op for RepOp then we can - // constant fold the instruction. - if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) { - // Build a list of all constant operands. - SmallVector<Constant*, 8> ConstOps; - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - if (I->getOperand(i) == Op) - ConstOps.push_back(CRepOp); - else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i))) - ConstOps.push_back(COp); - else - break; - } - - // All operands were constants, fold it. - if (ConstOps.size() == I->getNumOperands()) { - if (CmpInst *C = dyn_cast<CmpInst>(I)) - return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], - ConstOps[1], TD, TLI); - - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(ConstOps[0], TD); - - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), - ConstOps, TD, TLI); - } - } - - return nullptr; -} - /// foldSelectICmpAndOr - We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: @@ -437,6 +338,62 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, return Builder->CreateOr(V, Y); } +/// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single +/// call to cttz/ctlz with flag 'is_zero_undef' cleared. +/// +/// For example, we can fold the following code sequence: +/// \code +/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 true) +/// %1 = icmp ne i32 %x, 0 +/// %2 = select i1 %1, i32 %0, i32 32 +/// \code +/// +/// into: +/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) +static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, + InstCombiner::BuilderTy *Builder) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + + // Check if the condition value compares a value for equality against zero. + if (!ICI->isEquality() || !match(CmpRHS, m_Zero())) + return nullptr; + + Value *Count = FalseVal; + Value *ValueOnZero = TrueVal; + if (Pred == ICmpInst::ICMP_NE) + std::swap(Count, ValueOnZero); + + // Skip zero extend/truncate. + Value *V = nullptr; + if (match(Count, m_ZExt(m_Value(V))) || + match(Count, m_Trunc(m_Value(V)))) + Count = V; + + // Check if the value propagated on zero is a constant number equal to the + // sizeof in bits of 'Count'. + unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits(); + if (!match(ValueOnZero, m_SpecificInt(SizeOfInBits))) + return nullptr; + + // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the + // input to the cttz/ctlz is used as LHS for the compare instruction. + if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) || + match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) { + IntrinsicInst *II = cast<IntrinsicInst>(Count); + IRBuilder<> Builder(II); + // Explicitly clear the 'undef_on_zero' flag. + IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); + Type *Ty = NewI->getArgOperand(1)->getType(); + NewI->setArgOperand(1, Constant::getNullValue(Ty)); + Builder.Insert(NewI); + return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); + } + + return nullptr; +} + /// visitSelectInstWithICmp - Visit a SelectInst that has an /// ICmpInst as its first operand. /// @@ -454,14 +411,6 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // here, so make sure the select is the only user. if (ICI->hasOneUse()) if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) { - // X < MIN ? T : F --> F - if ((Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) - && CI->isMinValue(Pred == ICmpInst::ICMP_SLT)) - return ReplaceInstUsesWith(SI, FalseVal); - // X > MAX ? T : F --> F - else if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT) - && CI->isMaxValue(Pred == ICmpInst::ICMP_SGT)) - return ReplaceInstUsesWith(SI, FalseVal); switch (Pred) { default: break; case ICmpInst::ICMP_ULT: @@ -575,33 +524,6 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } - // If we have an equality comparison then we know the value in one of the - // arms of the select. See if substituting this value into the arm and - // simplifying the result yields the same value as the other arm. - if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - TrueVal) - return ReplaceInstUsesWith(SI, FalseVal); - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - FalseVal) - return ReplaceInstUsesWith(SI, FalseVal); - } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - FalseVal) - return ReplaceInstUsesWith(SI, TrueVal); - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - TrueVal) - return ReplaceInstUsesWith(SI, TrueVal); - } - // NOTE: if we wanted to, this is where to detect integer MIN/MAX if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { @@ -616,7 +538,8 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } - if (unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits()) { + { + unsigned BitWidth = DL.getTypeSizeInBits(TrueVal->getType()); APInt MinSignedValue = APInt::getSignBit(BitWidth); Value *X; const APInt *Y, *C; @@ -665,6 +588,9 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) return ReplaceInstUsesWith(SI, V); + if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) + return ReplaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } @@ -770,6 +696,52 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, SI->getCondition(), SI->getFalseValue(), SI->getTrueValue()); return ReplaceInstUsesWith(Outer, NewSI); } + + auto IsFreeOrProfitableToInvert = + [&](Value *V, Value *&NotV, bool &ElidesXor) { + if (match(V, m_Not(m_Value(NotV)))) { + // If V has at most 2 uses then we can get rid of the xor operation + // entirely. + ElidesXor |= !V->hasNUsesOrMore(3); + return true; + } + + if (IsFreeToInvert(V, !V->hasNUsesOrMore(3))) { + NotV = nullptr; + return true; + } + + return false; + }; + + Value *NotA, *NotB, *NotC; + bool ElidesXor = false; + + // MIN(MIN(~A, ~B), ~C) == ~MAX(MAX(A, B), C) + // MIN(MAX(~A, ~B), ~C) == ~MAX(MIN(A, B), C) + // MAX(MIN(~A, ~B), ~C) == ~MIN(MAX(A, B), C) + // MAX(MAX(~A, ~B), ~C) == ~MIN(MIN(A, B), C) + // + // This transform is performance neutral if we can elide at least one xor from + // the set of three operands, since we'll be tacking on an xor at the very + // end. + if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && + IsFreeOrProfitableToInvert(B, NotB, ElidesXor) && + IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) { + if (!NotA) + NotA = Builder->CreateNot(A); + if (!NotB) + NotB = Builder->CreateNot(B); + if (!NotC) + NotC = Builder->CreateNot(C); + + Value *NewInner = generateMinMaxSelectPattern( + Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); + Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern( + Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); + return ReplaceInstUsesWith(Outer, NewOuter); + } + return nullptr; } @@ -868,7 +840,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return BinaryOperator::CreateAnd(NotCond, FalseVal); } if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) { - if (C->getZExtValue() == false) { + if (!C->getZExtValue()) { // Change: A = select B, C, false --> A = and B, C return BinaryOperator::CreateAnd(CondVal, TrueVal); } @@ -1082,26 +1054,67 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } // See if we can fold the select into one of our operands. - if (SI.getType()->isIntegerTy()) { + if (SI.getType()->isIntOrIntVectorTy()) { if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; - // MAX(MAX(a, b), a) -> MAX(a, b) - // MIN(MIN(a, b), a) -> MIN(a, b) - // MAX(MIN(a, b), a) -> a - // MIN(MAX(a, b), a) -> a Value *LHS, *RHS, *LHS2, *RHS2; - if (SelectPatternFlavor SPF = MatchSelectPattern(&SI, LHS, RHS)) { - if (SelectPatternFlavor SPF2 = MatchSelectPattern(LHS, LHS2, RHS2)) + Instruction::CastOps CastOp; + SelectPatternFlavor SPF = matchSelectPattern(&SI, LHS, RHS, &CastOp); + + if (SPF) { + // Canonicalize so that type casts are outside select patterns. + if (LHS->getType()->getPrimitiveSizeInBits() != + SI.getType()->getPrimitiveSizeInBits()) { + CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF); + Value *Cmp = Builder->CreateICmp(Pred, LHS, RHS); + Value *NewSI = Builder->CreateCast(CastOp, + Builder->CreateSelect(Cmp, LHS, RHS), + SI.getType()); + return ReplaceInstUsesWith(SI, NewSI); + } + + // MAX(MAX(a, b), a) -> MAX(a, b) + // MIN(MIN(a, b), a) -> MIN(a, b) + // MAX(MIN(a, b), a) -> a + // MIN(MAX(a, b), a) -> a + if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2)) if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, SI, SPF, RHS)) return R; - if (SelectPatternFlavor SPF2 = MatchSelectPattern(RHS, LHS2, RHS2)) + if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2)) if (Instruction *R = FoldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, SI, SPF, LHS)) return R; } + // MAX(~a, ~b) -> ~MIN(a, b) + if (SPF == SPF_SMAX || SPF == SPF_UMAX) { + if (IsFreeToInvert(LHS, LHS->hasNUses(2)) && + IsFreeToInvert(RHS, RHS->hasNUses(2))) { + + // This transform adds a xor operation and that extra cost needs to be + // justified. We look for simplifications that will result from + // applying this rule: + + bool Profitable = + (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) || + (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) || + (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); + + if (Profitable) { + Value *NewLHS = Builder->CreateNot(LHS); + Value *NewRHS = Builder->CreateNot(RHS); + Value *NewCmp = SPF == SPF_SMAX + ? Builder->CreateICmpSLT(NewLHS, NewRHS) + : Builder->CreateICmpULT(NewLHS, NewRHS); + Value *NewSI = + Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); + return ReplaceInstUsesWith(SI, NewSI); + } + } + } + // TODO. // ABS(-X) -> ABS(X) } @@ -1115,19 +1128,41 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return NV; if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { - if (TrueSI->getCondition() == CondVal) { - if (SI.getTrueValue() == TrueSI->getTrueValue()) - return nullptr; - SI.setOperand(1, TrueSI->getTrueValue()); - return &SI; + if (TrueSI->getCondition()->getType() == CondVal->getType()) { + // select(C, select(C, a, b), c) -> select(C, a, c) + if (TrueSI->getCondition() == CondVal) { + if (SI.getTrueValue() == TrueSI->getTrueValue()) + return nullptr; + SI.setOperand(1, TrueSI->getTrueValue()); + return &SI; + } + // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) + // We choose this as normal form to enable folding on the And and shortening + // paths for the values (this helps GetUnderlyingObjects() for example). + if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { + Value *And = Builder->CreateAnd(CondVal, TrueSI->getCondition()); + SI.setOperand(0, And); + SI.setOperand(1, TrueSI->getTrueValue()); + return &SI; + } } } if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) { - if (FalseSI->getCondition() == CondVal) { - if (SI.getFalseValue() == FalseSI->getFalseValue()) - return nullptr; - SI.setOperand(2, FalseSI->getFalseValue()); - return &SI; + if (FalseSI->getCondition()->getType() == CondVal->getType()) { + // select(C, a, select(C, b, c)) -> select(C, a, c) + if (FalseSI->getCondition() == CondVal) { + if (SI.getFalseValue() == FalseSI->getFalseValue()) + return nullptr; + SI.setOperand(2, FalseSI->getFalseValue()); + return &SI; + } + // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) + if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { + Value *Or = Builder->CreateOr(CondVal, FalseSI->getCondition()); + SI.setOperand(0, Or); + SI.setOperand(2, FalseSI->getFalseValue()); + return &SI; + } } } |