summaryrefslogtreecommitdiffstats
path: root/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r--contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp345
1 files changed, 240 insertions, 105 deletions
diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index bf3c33e..d2fbcdd 100644
--- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -11,88 +11,55 @@
//
//===----------------------------------------------------------------------===//
-#include "InstCombine.h"
+#include "InstCombineInternal.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
-/// MatchSelectPattern - Pattern match integer [SU]MIN, [SU]MAX, and ABS idioms,
-/// returning the kind and providing the out parameter results if we
-/// successfully match.
static SelectPatternFlavor
-MatchSelectPattern(Value *V, Value *&LHS, Value *&RHS) {
- SelectInst *SI = dyn_cast<SelectInst>(V);
- if (!SI) return SPF_UNKNOWN;
-
- ICmpInst *ICI = dyn_cast<ICmpInst>(SI->getCondition());
- if (!ICI) return SPF_UNKNOWN;
-
- ICmpInst::Predicate Pred = ICI->getPredicate();
- Value *CmpLHS = ICI->getOperand(0);
- Value *CmpRHS = ICI->getOperand(1);
- Value *TrueVal = SI->getTrueValue();
- Value *FalseVal = SI->getFalseValue();
-
- LHS = CmpLHS;
- RHS = CmpRHS;
-
- // (icmp X, Y) ? X : Y
- if (TrueVal == CmpLHS && FalseVal == CmpRHS) {
- switch (Pred) {
- default: return SPF_UNKNOWN; // Equality.
- case ICmpInst::ICMP_UGT:
- case ICmpInst::ICMP_UGE: return SPF_UMAX;
- case ICmpInst::ICMP_SGT:
- case ICmpInst::ICMP_SGE: return SPF_SMAX;
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_ULE: return SPF_UMIN;
- case ICmpInst::ICMP_SLT:
- case ICmpInst::ICMP_SLE: return SPF_SMIN;
- }
- }
-
- // (icmp X, Y) ? Y : X
- if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
- switch (Pred) {
- default: return SPF_UNKNOWN; // Equality.
- case ICmpInst::ICMP_UGT:
- case ICmpInst::ICMP_UGE: return SPF_UMIN;
- case ICmpInst::ICMP_SGT:
- case ICmpInst::ICMP_SGE: return SPF_SMIN;
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_ULE: return SPF_UMAX;
- case ICmpInst::ICMP_SLT:
- case ICmpInst::ICMP_SLE: return SPF_SMAX;
- }
+getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) {
+ switch (SPF) {
+ default:
+ llvm_unreachable("unhandled!");
+
+ case SPF_SMIN:
+ return SPF_SMAX;
+ case SPF_UMIN:
+ return SPF_UMAX;
+ case SPF_SMAX:
+ return SPF_SMIN;
+ case SPF_UMAX:
+ return SPF_UMIN;
}
+}
- if (ConstantInt *C1 = dyn_cast<ConstantInt>(CmpRHS)) {
- if ((CmpLHS == TrueVal && match(FalseVal, m_Neg(m_Specific(CmpLHS)))) ||
- (CmpLHS == FalseVal && match(TrueVal, m_Neg(m_Specific(CmpLHS))))) {
-
- // ABS(X) ==> (X >s 0) ? X : -X and (X >s -1) ? X : -X
- // NABS(X) ==> (X >s 0) ? -X : X and (X >s -1) ? -X : X
- if (Pred == ICmpInst::ICMP_SGT && (C1->isZero() || C1->isMinusOne())) {
- return (CmpLHS == TrueVal) ? SPF_ABS : SPF_NABS;
- }
-
- // ABS(X) ==> (X <s 0) ? -X : X and (X <s 1) ? -X : X
- // NABS(X) ==> (X <s 0) ? X : -X and (X <s 1) ? X : -X
- if (Pred == ICmpInst::ICMP_SLT && (C1->isZero() || C1->isOne())) {
- return (CmpLHS == FalseVal) ? SPF_ABS : SPF_NABS;
- }
- }
+static CmpInst::Predicate getICmpPredicateForMinMax(SelectPatternFlavor SPF) {
+ switch (SPF) {
+ default:
+ llvm_unreachable("unhandled!");
+
+ case SPF_SMIN:
+ return ICmpInst::ICMP_SLT;
+ case SPF_UMIN:
+ return ICmpInst::ICMP_ULT;
+ case SPF_SMAX:
+ return ICmpInst::ICMP_SGT;
+ case SPF_UMAX:
+ return ICmpInst::ICMP_UGT;
}
-
- // TODO: (X > 4) ? X : 5 --> (X >= 5) ? X : 5 --> MAX(X, 5)
-
- return SPF_UNKNOWN;
}
+static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder,
+ SelectPatternFlavor SPF, Value *A,
+ Value *B) {
+ CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
+ return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B);
+}
/// GetSelectFoldableOperands - We want to turn code that looks like this:
/// %C = or %A, %B
@@ -312,9 +279,9 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
/// replaced with RepOp.
static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
- const DataLayout *TD,
const TargetLibraryInfo *TLI,
- DominatorTree *DT, AssumptionCache *AC) {
+ const DataLayout &DL, DominatorTree *DT,
+ AssumptionCache *AC) {
// Trivial replacement.
if (V == Op)
return RepOp;
@@ -326,18 +293,18 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// If this is a binary operator, try to simplify it with the replaced op.
if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
if (B->getOperand(0) == Op)
- return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD, TLI);
+ return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), DL, TLI);
if (B->getOperand(1) == Op)
- return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD, TLI);
+ return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, DL, TLI);
}
// Same for CmpInsts.
if (CmpInst *C = dyn_cast<CmpInst>(I)) {
if (C->getOperand(0) == Op)
- return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD,
+ return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), DL,
TLI, DT, AC);
if (C->getOperand(1) == Op)
- return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD,
+ return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, DL,
TLI, DT, AC);
}
@@ -361,14 +328,14 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (ConstOps.size() == I->getNumOperands()) {
if (CmpInst *C = dyn_cast<CmpInst>(I))
return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
- ConstOps[1], TD, TLI);
+ ConstOps[1], DL, TLI);
if (LoadInst *LI = dyn_cast<LoadInst>(I))
if (!LI->isVolatile())
- return ConstantFoldLoadFromConstPtr(ConstOps[0], TD);
+ return ConstantFoldLoadFromConstPtr(ConstOps[0], DL);
- return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
- ConstOps, TD, TLI);
+ return ConstantFoldInstOperands(I->getOpcode(), I->getType(), ConstOps,
+ DL, TLI);
}
}
@@ -437,6 +404,62 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
return Builder->CreateOr(V, Y);
}
+/// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single
+/// call to cttz/ctlz with flag 'is_zero_undef' cleared.
+///
+/// For example, we can fold the following code sequence:
+/// \code
+/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 true)
+/// %1 = icmp ne i32 %x, 0
+/// %2 = select i1 %1, i32 %0, i32 32
+/// \code
+///
+/// into:
+/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false)
+static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
+ InstCombiner::BuilderTy *Builder) {
+ ICmpInst::Predicate Pred = ICI->getPredicate();
+ Value *CmpLHS = ICI->getOperand(0);
+ Value *CmpRHS = ICI->getOperand(1);
+
+ // Check if the condition value compares a value for equality against zero.
+ if (!ICI->isEquality() || !match(CmpRHS, m_Zero()))
+ return nullptr;
+
+ Value *Count = FalseVal;
+ Value *ValueOnZero = TrueVal;
+ if (Pred == ICmpInst::ICMP_NE)
+ std::swap(Count, ValueOnZero);
+
+ // Skip zero extend/truncate.
+ Value *V = nullptr;
+ if (match(Count, m_ZExt(m_Value(V))) ||
+ match(Count, m_Trunc(m_Value(V))))
+ Count = V;
+
+ // Check if the value propagated on zero is a constant number equal to the
+ // sizeof in bits of 'Count'.
+ unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits();
+ if (!match(ValueOnZero, m_SpecificInt(SizeOfInBits)))
+ return nullptr;
+
+ // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the
+ // input to the cttz/ctlz is used as LHS for the compare instruction.
+ if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) ||
+ match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) {
+ IntrinsicInst *II = cast<IntrinsicInst>(Count);
+ IRBuilder<> Builder(II);
+ // Explicitly clear the 'undef_on_zero' flag.
+ IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone());
+ Type *Ty = NewI->getArgOperand(1)->getType();
+ NewI->setArgOperand(1, Constant::getNullValue(Ty));
+ Builder.Insert(NewI);
+ return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType());
+ }
+
+ return nullptr;
+}
+
/// visitSelectInstWithICmp - Visit a SelectInst that has an
/// ICmpInst as its first operand.
///
@@ -579,25 +602,25 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
// arms of the select. See if substituting this value into the arm and
// simplifying the result yields the same value as the other arm.
if (Pred == ICmpInst::ICMP_EQ) {
- if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
TrueVal ||
- SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
TrueVal)
return ReplaceInstUsesWith(SI, FalseVal);
- if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+ if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
FalseVal ||
- SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+ SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
FalseVal)
return ReplaceInstUsesWith(SI, FalseVal);
} else if (Pred == ICmpInst::ICMP_NE) {
- if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+ if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
FalseVal ||
- SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+ SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
FalseVal)
return ReplaceInstUsesWith(SI, TrueVal);
- if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
TrueVal ||
- SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
TrueVal)
return ReplaceInstUsesWith(SI, TrueVal);
}
@@ -665,6 +688,9 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder))
return ReplaceInstUsesWith(SI, V);
+ if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder))
+ return ReplaceInstUsesWith(SI, V);
+
return Changed ? &SI : nullptr;
}
@@ -770,6 +796,52 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
SI->getCondition(), SI->getFalseValue(), SI->getTrueValue());
return ReplaceInstUsesWith(Outer, NewSI);
}
+
+ auto IsFreeOrProfitableToInvert =
+ [&](Value *V, Value *&NotV, bool &ElidesXor) {
+ if (match(V, m_Not(m_Value(NotV)))) {
+ // If V has at most 2 uses then we can get rid of the xor operation
+ // entirely.
+ ElidesXor |= !V->hasNUsesOrMore(3);
+ return true;
+ }
+
+ if (IsFreeToInvert(V, !V->hasNUsesOrMore(3))) {
+ NotV = nullptr;
+ return true;
+ }
+
+ return false;
+ };
+
+ Value *NotA, *NotB, *NotC;
+ bool ElidesXor = false;
+
+ // MIN(MIN(~A, ~B), ~C) == ~MAX(MAX(A, B), C)
+ // MIN(MAX(~A, ~B), ~C) == ~MAX(MIN(A, B), C)
+ // MAX(MIN(~A, ~B), ~C) == ~MIN(MAX(A, B), C)
+ // MAX(MAX(~A, ~B), ~C) == ~MIN(MIN(A, B), C)
+ //
+ // This transform is performance neutral if we can elide at least one xor from
+ // the set of three operands, since we'll be tacking on an xor at the very
+ // end.
+ if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) &&
+ IsFreeOrProfitableToInvert(B, NotB, ElidesXor) &&
+ IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) {
+ if (!NotA)
+ NotA = Builder->CreateNot(A);
+ if (!NotB)
+ NotB = Builder->CreateNot(B);
+ if (!NotC)
+ NotC = Builder->CreateNot(C);
+
+ Value *NewInner = generateMinMaxSelectPattern(
+ Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB);
+ Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern(
+ Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC));
+ return ReplaceInstUsesWith(Outer, NewOuter);
+ }
+
return nullptr;
}
@@ -868,7 +940,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
return BinaryOperator::CreateAnd(NotCond, FalseVal);
}
if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
- if (C->getZExtValue() == false) {
+ if (!C->getZExtValue()) {
// Change: A = select B, C, false --> A = and B, C
return BinaryOperator::CreateAnd(CondVal, TrueVal);
}
@@ -1082,26 +1154,67 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
// See if we can fold the select into one of our operands.
- if (SI.getType()->isIntegerTy()) {
+ if (SI.getType()->isIntOrIntVectorTy()) {
if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))
return FoldI;
- // MAX(MAX(a, b), a) -> MAX(a, b)
- // MIN(MIN(a, b), a) -> MIN(a, b)
- // MAX(MIN(a, b), a) -> a
- // MIN(MAX(a, b), a) -> a
Value *LHS, *RHS, *LHS2, *RHS2;
- if (SelectPatternFlavor SPF = MatchSelectPattern(&SI, LHS, RHS)) {
- if (SelectPatternFlavor SPF2 = MatchSelectPattern(LHS, LHS2, RHS2))
+ Instruction::CastOps CastOp;
+ SelectPatternFlavor SPF = matchSelectPattern(&SI, LHS, RHS, &CastOp);
+
+ if (SPF) {
+ // Canonicalize so that type casts are outside select patterns.
+ if (LHS->getType()->getPrimitiveSizeInBits() !=
+ SI.getType()->getPrimitiveSizeInBits()) {
+ CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
+ Value *Cmp = Builder->CreateICmp(Pred, LHS, RHS);
+ Value *NewSI = Builder->CreateCast(CastOp,
+ Builder->CreateSelect(Cmp, LHS, RHS),
+ SI.getType());
+ return ReplaceInstUsesWith(SI, NewSI);
+ }
+
+ // MAX(MAX(a, b), a) -> MAX(a, b)
+ // MIN(MIN(a, b), a) -> MIN(a, b)
+ // MAX(MIN(a, b), a) -> a
+ // MIN(MAX(a, b), a) -> a
+ if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2))
if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2,
SI, SPF, RHS))
return R;
- if (SelectPatternFlavor SPF2 = MatchSelectPattern(RHS, LHS2, RHS2))
+ if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2))
if (Instruction *R = FoldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2,
SI, SPF, LHS))
return R;
}
+ // MAX(~a, ~b) -> ~MIN(a, b)
+ if (SPF == SPF_SMAX || SPF == SPF_UMAX) {
+ if (IsFreeToInvert(LHS, LHS->hasNUses(2)) &&
+ IsFreeToInvert(RHS, RHS->hasNUses(2))) {
+
+ // This transform adds a xor operation and that extra cost needs to be
+ // justified. We look for simplifications that will result from
+ // applying this rule:
+
+ bool Profitable =
+ (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) ||
+ (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) ||
+ (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value())));
+
+ if (Profitable) {
+ Value *NewLHS = Builder->CreateNot(LHS);
+ Value *NewRHS = Builder->CreateNot(RHS);
+ Value *NewCmp = SPF == SPF_SMAX
+ ? Builder->CreateICmpSLT(NewLHS, NewRHS)
+ : Builder->CreateICmpULT(NewLHS, NewRHS);
+ Value *NewSI =
+ Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS));
+ return ReplaceInstUsesWith(SI, NewSI);
+ }
+ }
+ }
+
// TODO.
// ABS(-X) -> ABS(X)
}
@@ -1115,19 +1228,41 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
return NV;
if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
- if (TrueSI->getCondition() == CondVal) {
- if (SI.getTrueValue() == TrueSI->getTrueValue())
- return nullptr;
- SI.setOperand(1, TrueSI->getTrueValue());
- return &SI;
+ if (TrueSI->getCondition()->getType() == CondVal->getType()) {
+ // select(C, select(C, a, b), c) -> select(C, a, c)
+ if (TrueSI->getCondition() == CondVal) {
+ if (SI.getTrueValue() == TrueSI->getTrueValue())
+ return nullptr;
+ SI.setOperand(1, TrueSI->getTrueValue());
+ return &SI;
+ }
+ // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
+ // We choose this as normal form to enable folding on the And and shortening
+ // paths for the values (this helps GetUnderlyingObjects() for example).
+ if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) {
+ Value *And = Builder->CreateAnd(CondVal, TrueSI->getCondition());
+ SI.setOperand(0, And);
+ SI.setOperand(1, TrueSI->getTrueValue());
+ return &SI;
+ }
}
}
if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
- if (FalseSI->getCondition() == CondVal) {
- if (SI.getFalseValue() == FalseSI->getFalseValue())
- return nullptr;
- SI.setOperand(2, FalseSI->getFalseValue());
- return &SI;
+ if (FalseSI->getCondition()->getType() == CondVal->getType()) {
+ // select(C, a, select(C, b, c)) -> select(C, a, c)
+ if (FalseSI->getCondition() == CondVal) {
+ if (SI.getFalseValue() == FalseSI->getFalseValue())
+ return nullptr;
+ SI.setOperand(2, FalseSI->getFalseValue());
+ return &SI;
+ }
+ // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
+ if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
+ Value *Or = Builder->CreateOr(CondVal, FalseSI->getCondition());
+ SI.setOperand(0, Or);
+ SI.setOperand(2, FalseSI->getFalseValue());
+ return &SI;
+ }
}
}
OpenPOWER on IntegriCloud