diff options
author | dim <dim@FreeBSD.org> | 2017-04-02 17:24:58 +0000 |
---|---|---|
committer | dim <dim@FreeBSD.org> | 2017-04-02 17:24:58 +0000 |
commit | 60b571e49a90d38697b3aca23020d9da42fc7d7f (patch) | |
tree | 99351324c24d6cb146b6285b6caffa4d26fce188 /contrib/llvm/lib/Analysis/InstructionSimplify.cpp | |
parent | bea1b22c7a9bce1dfdd73e6e5b65bc4752215180 (diff) | |
download | FreeBSD-src-60b571e49a90d38697b3aca23020d9da42fc7d7f.zip FreeBSD-src-60b571e49a90d38697b3aca23020d9da42fc7d7f.tar.gz |
Update clang, llvm, lld, lldb, compiler-rt and libc++ to 4.0.0 release:
MFC r309142 (by emaste):
Add WITH_LLD_AS_LD build knob
If set it installs LLD as /usr/bin/ld. LLD (as of version 3.9) is not
capable of linking the world and kernel, but can self-host and link many
substantial applications. GNU ld continues to be used for the world and
kernel build, regardless of how this knob is set.
It is on by default for arm64, and off for all other CPU architectures.
Sponsored by: The FreeBSD Foundation
MFC r310840:
Reapply 310775, now it also builds correctly if lldb is disabled:
Move llvm-objdump from CLANG_EXTRAS to installed by default
We currently install three tools from binutils 2.17.50: as, ld, and
objdump. Work is underway to migrate to a permissively-licensed
tool-chain, with one goal being the retirement of binutils 2.17.50.
LLVM's llvm-objdump is intended to be compatible with GNU objdump
although it is currently missing some options and may have formatting
differences. Enable it by default for testing and further investigation.
It may later be changed to install as /usr/bin/objdump, it becomes a
fully viable replacement.
Reviewed by: emaste
Differential Revision: https://reviews.freebsd.org/D8879
MFC r312855 (by emaste):
Rename LLD_AS_LD to LLD_IS_LD, for consistency with CLANG_IS_CC
Reported by: Dan McGregor <dan.mcgregor usask.ca>
MFC r313559 | glebius | 2017-02-10 18:34:48 +0100 (Fri, 10 Feb 2017) | 5 lines
Don't check struct rtentry on FreeBSD, it is an internal kernel structure.
On other systems it may be API structure for SIOCADDRT/SIOCDELRT.
Reviewed by: emaste, dim
MFC r314152 (by jkim):
Remove an assembler flag, which is redundant since r309124. The upstream
took care of it by introducing a macro NO_EXEC_STACK_DIRECTIVE.
http://llvm.org/viewvc/llvm-project?rev=273500&view=rev
Reviewed by: dim
MFC r314564:
Upgrade our copies of clang, llvm, lld, lldb, compiler-rt and libc++ to
4.0.0 (branches/release_40 296509). The release will follow soon.
Please note that from 3.5.0 onwards, clang, llvm and lldb require C++11
support to build; see UPDATING for more information.
Also note that as of 4.0.0, lld should be able to link the base system
on amd64 and aarch64. See the WITH_LLD_IS_LLD setting in src.conf(5).
Though please be aware that this is work in progress.
Release notes for llvm, clang and lld will be available here:
<http://releases.llvm.org/4.0.0/docs/ReleaseNotes.html>
<http://releases.llvm.org/4.0.0/tools/clang/docs/ReleaseNotes.html>
<http://releases.llvm.org/4.0.0/tools/lld/docs/ReleaseNotes.html>
Thanks to Ed Maste, Jan Beich, Antoine Brodin and Eric Fiselier for
their help.
Relnotes: yes
Exp-run: antoine
PR: 215969, 216008
MFC r314708:
For now, revert r287232 from upstream llvm trunk (by Daniil Fukalov):
[SCEV] limit recursion depth of CompareSCEVComplexity
Summary:
CompareSCEVComplexity goes too deep (50+ on a quite a big unrolled
loop) and runs almost infinite time.
Added cache of "equal" SCEV pairs to earlier cutoff of further
estimation. Recursion depth limit was also introduced as a parameter.
Reviewers: sanjoy
Subscribers: mzolotukhin, tstellarAMD, llvm-commits
Differential Revision: https://reviews.llvm.org/D26389
This commit is the cause of excessive compile times on skein_block.c
(and possibly other files) during kernel builds on amd64.
We never saw the problematic behavior described in this upstream commit,
so for now it is better to revert it. An upstream bug has been filed
here: https://bugs.llvm.org/show_bug.cgi?id=32142
Reported by: mjg
MFC r314795:
Reapply r287232 from upstream llvm trunk (by Daniil Fukalov):
[SCEV] limit recursion depth of CompareSCEVComplexity
Summary:
CompareSCEVComplexity goes too deep (50+ on a quite a big unrolled
loop) and runs almost infinite time.
Added cache of "equal" SCEV pairs to earlier cutoff of further
estimation. Recursion depth limit was also introduced as a parameter.
Reviewers: sanjoy
Subscribers: mzolotukhin, tstellarAMD, llvm-commits
Differential Revision: https://reviews.llvm.org/D26389
Pull in r296992 from upstream llvm trunk (by Sanjoy Das):
[SCEV] Decrease the recursion threshold for CompareValueComplexity
Fixes PR32142.
r287232 accidentally increased the recursion threshold for
CompareValueComplexity from 2 to 32. This change reverses that
change by introducing a separate flag for CompareValueComplexity's
threshold.
The latter revision fixes the excessive compile times for skein_block.c.
MFC r314907 | mmel | 2017-03-08 12:40:27 +0100 (Wed, 08 Mar 2017) | 7 lines
Unbreak ARMv6 world.
The new compiler_rt library imported with clang 4.0.0 have several fatal
issues (non-functional __udivsi3 for example) with ARM specific instrict
functions. As temporary workaround, until upstream solve these problems,
disable all thumb[1][2] related feature.
MFC r315016:
Update clang, llvm, lld, lldb, compiler-rt and libc++ to 4.0.0 release.
We were already very close to the last release candidate, so this is a
pretty minor update.
Relnotes: yes
MFC r316005:
Revert r314907, and pull in r298713 from upstream compiler-rt trunk (by
Weiming Zhao):
builtins: Select correct code fragments when compiling for Thumb1/Thum2/ARM ISA.
Summary:
Value of __ARM_ARCH_ISA_THUMB isn't based on the actual compilation
mode (-mthumb, -marm), it reflect's capability of given CPU.
Due to this:
- use __tbumb__ and __thumb2__ insteand of __ARM_ARCH_ISA_THUMB
- use '.thumb' directive consistently in all affected files
- decorate all thumb functions using
DEFINE_COMPILERRT_THUMB_FUNCTION()
---------
Note: This patch doesn't fix broken Thumb1 variant of __udivsi3 !
Reviewers: weimingz, rengolin, compnerd
Subscribers: aemerson, dim
Differential Revision: https://reviews.llvm.org/D30938
Discussed with: mmel
Diffstat (limited to 'contrib/llvm/lib/Analysis/InstructionSimplify.cpp')
-rw-r--r-- | contrib/llvm/lib/Analysis/InstructionSimplify.cpp | 1666 |
1 files changed, 988 insertions, 678 deletions
diff --git a/contrib/llvm/lib/Analysis/InstructionSimplify.cpp b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp index aeaf938..796e6e4 100644 --- a/contrib/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp @@ -67,9 +67,12 @@ static Value *SimplifyFPBinOp(unsigned, Value *, Value *, const FastMathFlags &, const Query &, unsigned); static Value *SimplifyCmpInst(unsigned, Value *, Value *, const Query &, unsigned); +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const Query &Q, unsigned MaxRecurse); static Value *SimplifyOrInst(Value *, Value *, const Query &, unsigned); static Value *SimplifyXorInst(Value *, Value *, const Query &, unsigned); -static Value *SimplifyTruncInst(Value *, Type *, const Query &, unsigned); +static Value *SimplifyCastInst(unsigned, Value *, Type *, + const Query &, unsigned); /// For a boolean type, or a vector of boolean type, return false, or /// a vector with every element false, as appropriate for the type. @@ -679,9 +682,26 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (Op0 == Op1) return Constant::getNullValue(Op0->getType()); - // 0 - X -> 0 if the sub is NUW. - if (isNUW && match(Op0, m_Zero())) - return Op0; + // Is this a negation? + if (match(Op0, m_Zero())) { + // 0 - X -> 0 if the sub is NUW. + if (isNUW) + return Op0; + + unsigned BitWidth = Op1->getType()->getScalarSizeInBits(); + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (KnownZero == ~APInt::getSignBit(BitWidth)) { + // Op1 is either 0 or the minimum signed value. If the sub is NSW, then + // Op1 must be 0 because negating the minimum signed value is undefined. + if (isNSW) + return Op0; + + // 0 - X -> X if X is 0 or the minimum signed value. + return Op1; + } + } // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies. // For example, (X + Y) - Y -> X; (Y + X) - Y -> X @@ -747,7 +767,8 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // See if "V === X - Y" simplifies. if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) // It does! Now see if "trunc V" simplifies. - if (Value *W = SimplifyTruncInst(V, Op0->getType(), Q, MaxRecurse-1)) + if (Value *W = SimplifyCastInst(Instruction::Trunc, V, Op0->getType(), + Q, MaxRecurse - 1)) // It does, return the simplified "trunc V". return W; @@ -1085,6 +1106,16 @@ static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, if (Value *V = SimplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse)) return V; + // udiv %V, C -> 0 if %V < C + if (MaxRecurse) { + if (Constant *C = dyn_cast_or_null<Constant>(SimplifyICmpInst( + ICmpInst::ICMP_ULT, Op0, Op1, Q, MaxRecurse - 1))) { + if (C->isAllOnesValue()) { + return Constant::getNullValue(Op0->getType()); + } + } + } + return nullptr; } @@ -1106,6 +1137,10 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (match(Op1, m_Undef())) return Op1; + // X / 1.0 -> X + if (match(Op1, m_FPOne())) + return Op0; + // 0 / X -> 0 // Requires that NaNs are off (X could be zero) and signed zeroes are // ignored (X could be positive or negative, so the output sign is unknown). @@ -1222,6 +1257,16 @@ static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, if (Value *V = SimplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse)) return V; + // urem %V, C -> %V if %V < C + if (MaxRecurse) { + if (Constant *C = dyn_cast_or_null<Constant>(SimplifyICmpInst( + ICmpInst::ICMP_ULT, Op0, Op1, Q, MaxRecurse - 1))) { + if (C->isAllOnesValue()) { + return Op0; + } + } + } + return nullptr; } @@ -1497,17 +1542,45 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, return nullptr; } -static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { - Type *ITy = Op0->getType(); +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - ConstantInt *CI1, *CI2; - Value *V; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; + // We have (icmp Pred0, A, B) & (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op1 from this 'and'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op0; + + // Check for any combination of predicates that are guaranteed to be disjoint. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_EQ && ICmpInst::isFalseWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT) || + (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT)) + return getFalse(Op0->getType()); + + return nullptr; +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) return X; + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) + return X; + // Look for this pattern: (icmp V, C0) & (icmp V, C1)). + Type *ITy = Op0->getType(); + ICmpInst::Predicate Pred0, Pred1; const APInt *C0, *C1; + Value *V; if (match(Op0, m_ICmp(Pred0, m_Value(V), m_APInt(C0))) && match(Op1, m_ICmp(Pred1, m_Specific(V), m_APInt(C1)))) { // Make a constant range that's the intersection of the two icmp ranges. @@ -1518,21 +1591,22 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return getFalse(ITy); } - if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), - m_ConstantInt(CI2)))) + // (icmp (add V, C0), C1) & (icmp V, C0) + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) return nullptr; - if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) return nullptr; auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != Op1->getOperand(1)) + return nullptr; + bool isNSW = AddInst->hasNoSignedWrap(); bool isNUW = AddInst->hasNoUnsignedWrap(); - const APInt &CI1V = CI1->getValue(); - const APInt &CI2V = CI2->getValue(); - const APInt Delta = CI2V - CI1V; - if (CI1V.isStrictlyPositive()) { + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { if (Delta == 2) { if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT) return getFalse(ITy); @@ -1546,7 +1620,7 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return getFalse(ITy); } } - if (CI1V.getBoolValue() && isNUW) { + if (C0->getBoolValue() && isNUW) { if (Delta == 2) if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT) return getFalse(ITy); @@ -1680,33 +1754,61 @@ Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const DataLayout &DL, RecursionLimit); } -/// Simplify (or (icmp ...) (icmp ...)) to true when we can tell that the union -/// contains all possible values. -static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - ConstantInt *CI1, *CI2; - Value *V; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; + + // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op0 from this 'or'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op1; + + // Check for any combination of predicates that cover the entire range of + // possibilities. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || + (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) + return getTrue(Op0->getType()); + + return nullptr; +} +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) return X; - if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), - m_ConstantInt(CI2)))) - return nullptr; + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) + return X; - if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + // (icmp (add V, C0), C1) | (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) return nullptr; - Type *ITy = Op0->getType(); + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) + return nullptr; auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != Op1->getOperand(1)) + return nullptr; + + Type *ITy = Op0->getType(); bool isNSW = AddInst->hasNoSignedWrap(); bool isNUW = AddInst->hasNoUnsignedWrap(); - const APInt &CI1V = CI1->getValue(); - const APInt &CI2V = CI2->getValue(); - const APInt Delta = CI2V - CI1V; - if (CI1V.isStrictlyPositive()) { + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { if (Delta == 2) { if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) return getTrue(ITy); @@ -1720,7 +1822,7 @@ static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return getTrue(ITy); } } - if (CI1V.getBoolValue() && isNUW) { + if (C0->getBoolValue() && isNUW) { if (Delta == 2) if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) return getTrue(ITy); @@ -2102,8 +2204,8 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, GetUnderlyingObjects(RHS, RHSUObjs, DL); // Is the set of underlying objects all noalias calls? - auto IsNAC = [](SmallVectorImpl<Value *> &Objects) { - return std::all_of(Objects.begin(), Objects.end(), isNoAliasCall); + auto IsNAC = [](ArrayRef<Value *> Objects) { + return all_of(Objects, isNoAliasCall); }; // Is the set of underlying objects all things which must be disjoint from @@ -2112,8 +2214,8 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, // live with the compared-to allocation). For globals, we exclude symbols // that might be resolve lazily to symbols in another dynamically-loaded // library (and, thus, could be malloc'ed by the implementation). - auto IsAllocDisjoint = [](SmallVectorImpl<Value *> &Objects) { - return std::all_of(Objects.begin(), Objects.end(), [](Value *V) { + auto IsAllocDisjoint = [](ArrayRef<Value *> Objects) { + return all_of(Objects, [](Value *V) { if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) @@ -2150,470 +2252,275 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, return nullptr; } -/// Given operands for an ICmpInst, see if we can fold the result. -/// If not, this returns null. -static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { - CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; - assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); - - if (Constant *CLHS = dyn_cast<Constant>(LHS)) { - if (Constant *CRHS = dyn_cast<Constant>(RHS)) - return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); - - // If we have a constant, make sure it is on the RHS. - std::swap(LHS, RHS); - Pred = CmpInst::getSwappedPredicate(Pred); - } - +/// Fold an icmp when its operands have i1 scalar type. +static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q) { Type *ITy = GetCompareTy(LHS); // The return type. Type *OpTy = LHS->getType(); // The operand type. + if (!OpTy->getScalarType()->isIntegerTy(1)) + return nullptr; - // icmp X, X -> true/false - // X icmp undef -> true/false. For example, icmp ugt %X, undef -> false - // because X could be 0. - if (LHS == RHS || isa<UndefValue>(RHS)) - return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); - - // Special case logic when the operands have i1 type. - if (OpTy->getScalarType()->isIntegerTy(1)) { - switch (Pred) { - default: break; - case ICmpInst::ICMP_EQ: - // X == 1 -> X - if (match(RHS, m_One())) - return LHS; - break; - case ICmpInst::ICMP_NE: - // X != 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_UGT: - // X >u 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_UGE: { - // X >=u 1 -> X - if (match(RHS, m_One())) - return LHS; - if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) - return getTrue(ITy); - break; - } - case ICmpInst::ICMP_SGE: { - /// For signed comparison, the values for an i1 are 0 and -1 - /// respectively. This maps into a truth table of: - /// LHS | RHS | LHS >=s RHS | LHS implies RHS - /// 0 | 0 | 1 (0 >= 0) | 1 - /// 0 | 1 | 1 (0 >= -1) | 1 - /// 1 | 0 | 0 (-1 >= 0) | 0 - /// 1 | 1 | 1 (-1 >= -1) | 1 - if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) - return getTrue(ITy); - break; - } - case ICmpInst::ICMP_SLT: - // X <s 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_SLE: - // X <=s -1 -> X - if (match(RHS, m_One())) - return LHS; - break; - case ICmpInst::ICMP_ULE: { - if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) - return getTrue(ITy); - break; - } - } - } - - // If we are comparing with zero then try hard since this is a common case. - if (match(RHS, m_Zero())) { - bool LHSKnownNonNegative, LHSKnownNegative; - switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); - case ICmpInst::ICMP_ULT: - return getFalse(ITy); - case ICmpInst::ICMP_UGE: + switch (Pred) { + default: + break; + case ICmpInst::ICMP_EQ: + // X == 1 -> X + if (match(RHS, m_One())) + return LHS; + break; + case ICmpInst::ICMP_NE: + // X != 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_UGT: + // X >u 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_UGE: + // X >=u 1 -> X + if (match(RHS, m_One())) + return LHS; + if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) return getTrue(ITy); - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_ULE: - if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getFalse(ITy); - break; - case ICmpInst::ICMP_NE: - case ICmpInst::ICMP_UGT: - if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getTrue(ITy); - break; - case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getTrue(ITy); - if (LHSKnownNonNegative) - return getFalse(ITy); - break; - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getTrue(ITy); - if (LHSKnownNonNegative && - isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getFalse(ITy); - break; - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getFalse(ITy); - if (LHSKnownNonNegative) - return getTrue(ITy); - break; - case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getFalse(ITy); - if (LHSKnownNonNegative && - isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getTrue(ITy); - break; - } + break; + case ICmpInst::ICMP_SGE: + /// For signed comparison, the values for an i1 are 0 and -1 + /// respectively. This maps into a truth table of: + /// LHS | RHS | LHS >=s RHS | LHS implies RHS + /// 0 | 0 | 1 (0 >= 0) | 1 + /// 0 | 1 | 1 (0 >= -1) | 1 + /// 1 | 0 | 0 (-1 >= 0) | 0 + /// 1 | 1 | 1 (-1 >= -1) | 1 + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SLT: + // X <s 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_SLE: + // X <=s -1 -> X + if (match(RHS, m_One())) + return LHS; + break; + case ICmpInst::ICMP_ULE: + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + return getTrue(ITy); + break; } - // See if we are doing a comparison with a constant integer. - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Rule out tautological comparisons (eg., ult 0 or uge 0). - ConstantRange RHS_CR = ICmpInst::makeConstantRange(Pred, CI->getValue()); - if (RHS_CR.isEmptySet()) - return ConstantInt::getFalse(CI->getContext()); - if (RHS_CR.isFullSet()) - return ConstantInt::getTrue(CI->getContext()); - - // Many binary operators with constant RHS have easy to compute constant - // range. Use them to check whether the comparison is a tautology. - unsigned Width = CI->getBitWidth(); - APInt Lower = APInt(Width, 0); - APInt Upper = APInt(Width, 0); - ConstantInt *CI2; - if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) { - // 'urem x, CI2' produces [0, CI2). - Upper = CI2->getValue(); - } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) { - // 'srem x, CI2' produces (-|CI2|, |CI2|). - Upper = CI2->getValue().abs(); - Lower = (-Upper) + 1; - } else if (match(LHS, m_UDiv(m_ConstantInt(CI2), m_Value()))) { - // 'udiv CI2, x' produces [0, CI2]. - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) { - // 'udiv x, CI2' produces [0, UINT_MAX / CI2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (!CI2->isZero()) - Upper = NegOne.udiv(CI2->getValue()) + 1; - } else if (match(LHS, m_SDiv(m_ConstantInt(CI2), m_Value()))) { - if (CI2->isMinSignedValue()) { - // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. - Lower = CI2->getValue(); - Upper = Lower.lshr(1) + 1; - } else { - // 'sdiv CI2, x' produces [-|CI2|, |CI2|]. - Upper = CI2->getValue().abs() + 1; - Lower = (-Upper) + 1; - } - } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) { - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - const APInt &Val = CI2->getValue(); - if (Val.isAllOnesValue()) { - // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] - // where CI2 != -1 and CI2 != 0 and CI2 != 1 - Lower = IntMin + 1; - Upper = IntMax + 1; - } else if (Val.countLeadingZeros() < Width - 1) { - // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2] - // where CI2 != -1 and CI2 != 0 and CI2 != 1 - Lower = IntMin.sdiv(Val); - Upper = IntMax.sdiv(Val); - if (Lower.sgt(Upper)) - std::swap(Lower, Upper); - Upper = Upper + 1; - assert(Upper != Lower && "Upper part of range has wrapped!"); - } - } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) { - // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)] - Lower = CI2->getValue(); - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; - } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) { - if (CI2->isNegative()) { - // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2] - unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1; - Lower = CI2->getValue().shl(ShiftAmount); - Upper = CI2->getValue() + 1; - } else { - // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1] - unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1; - Lower = CI2->getValue(); - Upper = CI2->getValue().shl(ShiftAmount) + 1; - } - } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) { - // 'lshr x, CI2' produces [0, UINT_MAX >> CI2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (CI2->getValue().ult(Width)) - Upper = NegOne.lshr(CI2->getValue()) + 1; - } else if (match(LHS, m_LShr(m_ConstantInt(CI2), m_Value()))) { - // 'lshr CI2, x' produces [CI2 >> (Width-1), CI2]. - unsigned ShiftAmount = Width - 1; - if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = CI2->getValue().countTrailingZeros(); - Lower = CI2->getValue().lshr(ShiftAmount); - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) { - // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2]. - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (CI2->getValue().ult(Width)) { - Lower = IntMin.ashr(CI2->getValue()); - Upper = IntMax.ashr(CI2->getValue()) + 1; - } - } else if (match(LHS, m_AShr(m_ConstantInt(CI2), m_Value()))) { - unsigned ShiftAmount = Width - 1; - if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = CI2->getValue().countTrailingZeros(); - if (CI2->isNegative()) { - // 'ashr CI2, x' produces [CI2, CI2 >> (Width-1)] - Lower = CI2->getValue(); - Upper = CI2->getValue().ashr(ShiftAmount) + 1; - } else { - // 'ashr CI2, x' produces [CI2 >> (Width-1), CI2] - Lower = CI2->getValue().ashr(ShiftAmount); - Upper = CI2->getValue() + 1; - } - } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) { - // 'or x, CI2' produces [CI2, UINT_MAX]. - Lower = CI2->getValue(); - } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) { - // 'and x, CI2' produces [0, CI2]. - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_NUWAdd(m_Value(), m_ConstantInt(CI2)))) { - // 'add nuw x, CI2' produces [CI2, UINT_MAX]. - Lower = CI2->getValue(); - } - - ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper) - : ConstantRange(Width, true); + return nullptr; +} - if (auto *I = dyn_cast<Instruction>(LHS)) - if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) - LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); +/// Try hard to fold icmp with zero RHS because this is a common case. +static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q) { + if (!match(RHS, m_Zero())) + return nullptr; - if (!LHS_CR.isFullSet()) { - if (RHS_CR.contains(LHS_CR)) - return ConstantInt::getTrue(RHS->getContext()); - if (RHS_CR.inverse().contains(LHS_CR)) - return ConstantInt::getFalse(RHS->getContext()); - } + Type *ITy = GetCompareTy(LHS); // The return type. + bool LHSKnownNonNegative, LHSKnownNegative; + switch (Pred) { + default: + llvm_unreachable("Unknown ICmp predicate!"); + case ICmpInst::ICMP_ULT: + return getFalse(ITy); + case ICmpInst::ICMP_UGE: + return getTrue(ITy); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_ULE: + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getFalse(ITy); + break; + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGT: + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SLT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getTrue(ITy); + if (LHSKnownNonNegative) + return getFalse(ITy); + break; + case ICmpInst::ICMP_SLE: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getTrue(ITy); + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getFalse(ITy); + break; + case ICmpInst::ICMP_SGE: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getFalse(ITy); + if (LHSKnownNonNegative) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SGT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getFalse(ITy); + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getTrue(ITy); + break; } - // If both operands have range metadata, use the metadata - // to simplify the comparison. - if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { - auto RHS_Instr = dyn_cast<Instruction>(RHS); - auto LHS_Instr = dyn_cast<Instruction>(LHS); - - if (RHS_Instr->getMetadata(LLVMContext::MD_range) && - LHS_Instr->getMetadata(LLVMContext::MD_range)) { - auto RHS_CR = getConstantRangeFromMetadata( - *RHS_Instr->getMetadata(LLVMContext::MD_range)); - auto LHS_CR = getConstantRangeFromMetadata( - *LHS_Instr->getMetadata(LLVMContext::MD_range)); + return nullptr; +} - auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); - if (Satisfied_CR.contains(LHS_CR)) - return ConstantInt::getTrue(RHS->getContext()); +static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, + Value *RHS) { + const APInt *C; + if (!match(RHS, m_APInt(C))) + return nullptr; - auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( - CmpInst::getInversePredicate(Pred), RHS_CR); - if (InversedSatisfied_CR.contains(LHS_CR)) - return ConstantInt::getFalse(RHS->getContext()); + // Rule out tautological comparisons (eg., ult 0 or uge 0). + ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C); + if (RHS_CR.isEmptySet()) + return ConstantInt::getFalse(GetCompareTy(RHS)); + if (RHS_CR.isFullSet()) + return ConstantInt::getTrue(GetCompareTy(RHS)); + + // Many binary operators with constant RHS have easy to compute constant + // range. Use them to check whether the comparison is a tautology. + unsigned Width = C->getBitWidth(); + APInt Lower = APInt(Width, 0); + APInt Upper = APInt(Width, 0); + const APInt *C2; + if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) { + // 'urem x, C2' produces [0, C2). + Upper = *C2; + } else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) { + // 'srem x, C2' produces (-|C2|, |C2|). + Upper = C2->abs(); + Lower = (-Upper) + 1; + } else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) { + // 'udiv C2, x' produces [0, C2]. + Upper = *C2 + 1; + } else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) { + // 'udiv x, C2' produces [0, UINT_MAX / C2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (*C2 != 0) + Upper = NegOne.udiv(*C2) + 1; + } else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) { + if (C2->isMinSignedValue()) { + // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. + Lower = *C2; + Upper = Lower.lshr(1) + 1; + } else { + // 'sdiv C2, x' produces [-|C2|, |C2|]. + Upper = C2->abs() + 1; + Lower = (-Upper) + 1; } - } - - // Compare of cast, for example (zext X) != 0 -> X != 0 - if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) { - Instruction *LI = cast<CastInst>(LHS); - Value *SrcOp = LI->getOperand(0); - Type *SrcTy = SrcOp->getType(); - Type *DstTy = LI->getType(); - - // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input - // if the integer type is the same size as the pointer type. - if (MaxRecurse && isa<PtrToIntInst>(LI) && - Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { - if (Constant *RHSC = dyn_cast<Constant>(RHS)) { - // Transfer the cast to the constant. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, - ConstantExpr::getIntToPtr(RHSC, SrcTy), - Q, MaxRecurse-1)) - return V; - } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) { - if (RI->getOperand(0)->getType() == SrcTy) - // Compare without the cast. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) - return V; - } + } else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) { + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C2->isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where C2 != -1 and C2 != 0 and C2 != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (C2->countLeadingZeros() < Width - 1) { + // 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2] + // where C2 != -1 and C2 != 0 and C2 != 1 + Lower = IntMin.sdiv(*C2); + Upper = IntMax.sdiv(*C2); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); } - - if (isa<ZExtInst>(LHS)) { - // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the - // same type. - if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) { - if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) - // Compare X and Y. Note that signed predicates become unsigned. - if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, RI->getOperand(0), Q, - MaxRecurse-1)) - return V; - } - // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended - // too. If not, then try to deduce the result of the comparison. - else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Compute the constant that would happen if we truncated to SrcTy then - // reextended to DstTy. - Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); - Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); - - // If the re-extended constant didn't change then this is effectively - // also a case of comparing two zero-extended values. - if (RExt == CI && MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, Trunc, Q, MaxRecurse-1)) - return V; - - // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit - // there. Use this to work out the result of the comparison. - if (RExt != CI) { - switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); - // LHS <u RHS. - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - return ConstantInt::getFalse(CI->getContext()); - - case ICmpInst::ICMP_NE: - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - return ConstantInt::getTrue(CI->getContext()); - - // LHS is non-negative. If RHS is negative then LHS >s LHS. If RHS - // is non-negative then LHS <s RHS. - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); - - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); - } - } - } + } else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) { + // 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)] + Lower = *C2; + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) { + if (C2->isNegative()) { + // 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2] + unsigned ShiftAmount = C2->countLeadingOnes() - 1; + Lower = C2->shl(ShiftAmount); + Upper = *C2 + 1; + } else { + // 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1] + unsigned ShiftAmount = C2->countLeadingZeros() - 1; + Lower = *C2; + Upper = C2->shl(ShiftAmount) + 1; } + } else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) { + // 'lshr x, C2' produces [0, UINT_MAX >> C2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (C2->ult(Width)) + Upper = NegOne.lshr(*C2) + 1; + } else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) { + // 'lshr C2, x' produces [C2 >> (Width-1), C2]. + unsigned ShiftAmount = Width - 1; + if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) + ShiftAmount = C2->countTrailingZeros(); + Lower = C2->lshr(ShiftAmount); + Upper = *C2 + 1; + } else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) { + // 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2]. + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C2->ult(Width)) { + Lower = IntMin.ashr(*C2); + Upper = IntMax.ashr(*C2) + 1; + } + } else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) { + unsigned ShiftAmount = Width - 1; + if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) + ShiftAmount = C2->countTrailingZeros(); + if (C2->isNegative()) { + // 'ashr C2, x' produces [C2, C2 >> (Width-1)] + Lower = *C2; + Upper = C2->ashr(ShiftAmount) + 1; + } else { + // 'ashr C2, x' produces [C2 >> (Width-1), C2] + Lower = C2->ashr(ShiftAmount); + Upper = *C2 + 1; + } + } else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) { + // 'or x, C2' produces [C2, UINT_MAX]. + Lower = *C2; + } else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) { + // 'and x, C2' produces [0, C2]. + Upper = *C2 + 1; + } else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) { + // 'add nuw x, C2' produces [C2, UINT_MAX]. + Lower = *C2; + } - if (isa<SExtInst>(LHS)) { - // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the - // same type. - if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) { - if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) - // Compare X and Y. Note that the predicate does not change. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) - return V; - } - // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended - // too. If not, then try to deduce the result of the comparison. - else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Compute the constant that would happen if we truncated to SrcTy then - // reextended to DstTy. - Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); - Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy); - - // If the re-extended constant didn't change then this is effectively - // also a case of comparing two sign-extended values. - if (RExt == CI && MaxRecurse) - if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1)) - return V; - - // Otherwise the upper bits of LHS are all equal, while RHS has varying - // bits there. Use this to work out the result of the comparison. - if (RExt != CI) { - switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); - case ICmpInst::ICMP_EQ: - return ConstantInt::getFalse(CI->getContext()); - case ICmpInst::ICMP_NE: - return ConstantInt::getTrue(CI->getContext()); + ConstantRange LHS_CR = + Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true); - // If RHS is non-negative then LHS <s RHS. If RHS is negative then - // LHS >s RHS. - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); + if (auto *I = dyn_cast<Instruction>(LHS)) + if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) + LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); - // If LHS is non-negative then LHS <u RHS. If LHS is negative then - // LHS >u RHS. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - // Comparison is true iff the LHS <s 0. - if (MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp, - Constant::getNullValue(SrcTy), - Q, MaxRecurse-1)) - return V; - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - // Comparison is true iff the LHS >=s 0. - if (MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, - Constant::getNullValue(SrcTy), - Q, MaxRecurse-1)) - return V; - break; - } - } - } - } + if (!LHS_CR.isFullSet()) { + if (RHS_CR.contains(LHS_CR)) + return ConstantInt::getTrue(GetCompareTy(RHS)); + if (RHS_CR.inverse().contains(LHS_CR)) + return ConstantInt::getFalse(GetCompareTy(RHS)); } - // icmp eq|ne X, Y -> false|true if X != Y - if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && - isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) { - LLVMContext &Ctx = LHS->getType()->getContext(); - return Pred == ICmpInst::ICMP_NE ? - ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx); - } + return nullptr; +} + +static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q, + unsigned MaxRecurse) { + Type *ITy = GetCompareTy(LHS); // The return type. - // Special logic for binary operators. BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS); BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS); if (MaxRecurse && (LBO || RBO)) { @@ -2622,35 +2529,39 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // LHS = A + B (or A and B are null); RHS = C + D (or C and D are null). bool NoLHSWrapProblem = false, NoRHSWrapProblem = false; if (LBO && LBO->getOpcode() == Instruction::Add) { - A = LBO->getOperand(0); B = LBO->getOperand(1); - NoLHSWrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap()); + A = LBO->getOperand(0); + B = LBO->getOperand(1); + NoLHSWrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap()); } if (RBO && RBO->getOpcode() == Instruction::Add) { - C = RBO->getOperand(0); D = RBO->getOperand(1); - NoRHSWrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap()); + C = RBO->getOperand(0); + D = RBO->getOperand(1); + NoRHSWrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap()); } // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. if ((A == RHS || B == RHS) && NoLHSWrapProblem) if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A, - Constant::getNullValue(RHS->getType()), - Q, MaxRecurse-1)) + Constant::getNullValue(RHS->getType()), Q, + MaxRecurse - 1)) return V; // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. if ((C == LHS || D == LHS) && NoRHSWrapProblem) - if (Value *V = SimplifyICmpInst(Pred, - Constant::getNullValue(LHS->getType()), - C == LHS ? D : C, Q, MaxRecurse-1)) + if (Value *V = + SimplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()), + C == LHS ? D : C, Q, MaxRecurse - 1)) return V; // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. - if (A && C && (A == C || A == D || B == C || B == D) && - NoLHSWrapProblem && NoRHSWrapProblem) { + if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem && + NoRHSWrapProblem) { // Determine Y and Z in the form icmp (X+Y), (X+Z). Value *Y, *Z; if (A == C) { @@ -2671,7 +2582,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Y = A; Z = C; } - if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1)) return V; } } @@ -2771,7 +2682,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: @@ -2782,7 +2693,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2802,7 +2713,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: @@ -2813,7 +2724,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2832,6 +2743,17 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); } + // x >=u x >> y + // x >=u x udiv y. + if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) || + match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) { + // icmp pred X, (X op Y) + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + } + // handle: // CI2 << X == CI // CI2 << X != CI @@ -2870,18 +2792,19 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && LBO->getOperand(1) == RBO->getOperand(1)) { switch (LBO->getOpcode()) { - default: break; + default: + break; case Instruction::UDiv: case Instruction::LShr: if (ICmpInst::isSigned(Pred)) break; - // fall-through + LLVM_FALLTHROUGH; case Instruction::SDiv: case Instruction::AShr: if (!LBO->isExact() || !RBO->isExact()) break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), - RBO->getOperand(0), Q, MaxRecurse-1)) + RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; case Instruction::Shl: { @@ -2892,40 +2815,51 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (!NSW && ICmpInst::isSigned(Pred)) break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), - RBO->getOperand(0), Q, MaxRecurse-1)) + RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; } } } + return nullptr; +} - // Simplify comparisons involving max/min. +/// Simplify integer comparisons where at least one operand of the compare +/// matches an integer min/max idiom. +static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q, + unsigned MaxRecurse) { + Type *ITy = GetCompareTy(LHS); // The return type. Value *A, *B; CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE; CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B". // Signed variants on "max(a,b)>=a -> true". if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // smax(A, B) pred A. + if (A != RHS) + std::swap(A, B); // smax(A, B) pred A. EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". // We analyze this as smax(A, B) pred A. P = Pred; } else if (match(RHS, m_SMax(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred smax(A, B). + if (A != LHS) + std::swap(A, B); // A pred smax(A, B). EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". // We analyze this as smax(A, B) swapped-pred A. P = CmpInst::getSwappedPredicate(Pred); } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // smin(A, B) pred A. + if (A != RHS) + std::swap(A, B); // smin(A, B) pred A. EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". // We analyze this as smax(-A, -B) swapped-pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. P = CmpInst::getSwappedPredicate(Pred); } else if (match(RHS, m_SMin(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred smin(A, B). + if (A != LHS) + std::swap(A, B); // A pred smin(A, B). EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". // We analyze this as smax(-A, -B) pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. @@ -2946,7 +2880,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A EqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) return V; break; case CmpInst::ICMP_NE: @@ -2960,7 +2894,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A InvEqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) return V; break; } @@ -2976,26 +2910,30 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // Unsigned variants on "max(a,b)>=a -> true". P = CmpInst::BAD_ICMP_PREDICATE; if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // umax(A, B) pred A. + if (A != RHS) + std::swap(A, B); // umax(A, B) pred A. EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". // We analyze this as umax(A, B) pred A. P = Pred; } else if (match(RHS, m_UMax(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred umax(A, B). + if (A != LHS) + std::swap(A, B); // A pred umax(A, B). EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". // We analyze this as umax(A, B) swapped-pred A. P = CmpInst::getSwappedPredicate(Pred); } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // umin(A, B) pred A. + if (A != RHS) + std::swap(A, B); // umin(A, B) pred A. EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". // We analyze this as umax(-A, -B) swapped-pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. P = CmpInst::getSwappedPredicate(Pred); } else if (match(RHS, m_UMin(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred umin(A, B). + if (A != LHS) + std::swap(A, B); // A pred umin(A, B). EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". // We analyze this as umax(-A, -B) pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. @@ -3016,7 +2954,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A EqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) return V; break; case CmpInst::ICMP_NE: @@ -3030,7 +2968,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A InvEqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) return V; break; } @@ -3087,11 +3025,254 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getFalse(ITy); } + return nullptr; +} + +/// Given operands for an ICmpInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const Query &Q, unsigned MaxRecurse) { + CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; + assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); + + if (Constant *CLHS = dyn_cast<Constant>(LHS)) { + if (Constant *CRHS = dyn_cast<Constant>(RHS)) + return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); + + // If we have a constant, make sure it is on the RHS. + std::swap(LHS, RHS); + Pred = CmpInst::getSwappedPredicate(Pred); + } + + Type *ITy = GetCompareTy(LHS); // The return type. + + // icmp X, X -> true/false + // X icmp undef -> true/false. For example, icmp ugt %X, undef -> false + // because X could be 0. + if (LHS == RHS || isa<UndefValue>(RHS)) + return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); + + if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) + return V; + + if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q)) + return V; + + if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS)) + return V; + + // If both operands have range metadata, use the metadata + // to simplify the comparison. + if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { + auto RHS_Instr = dyn_cast<Instruction>(RHS); + auto LHS_Instr = dyn_cast<Instruction>(LHS); + + if (RHS_Instr->getMetadata(LLVMContext::MD_range) && + LHS_Instr->getMetadata(LLVMContext::MD_range)) { + auto RHS_CR = getConstantRangeFromMetadata( + *RHS_Instr->getMetadata(LLVMContext::MD_range)); + auto LHS_CR = getConstantRangeFromMetadata( + *LHS_Instr->getMetadata(LLVMContext::MD_range)); + + auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); + if (Satisfied_CR.contains(LHS_CR)) + return ConstantInt::getTrue(RHS->getContext()); + + auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( + CmpInst::getInversePredicate(Pred), RHS_CR); + if (InversedSatisfied_CR.contains(LHS_CR)) + return ConstantInt::getFalse(RHS->getContext()); + } + } + + // Compare of cast, for example (zext X) != 0 -> X != 0 + if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) { + Instruction *LI = cast<CastInst>(LHS); + Value *SrcOp = LI->getOperand(0); + Type *SrcTy = SrcOp->getType(); + Type *DstTy = LI->getType(); + + // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input + // if the integer type is the same size as the pointer type. + if (MaxRecurse && isa<PtrToIntInst>(LI) && + Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { + if (Constant *RHSC = dyn_cast<Constant>(RHS)) { + // Transfer the cast to the constant. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, + ConstantExpr::getIntToPtr(RHSC, SrcTy), + Q, MaxRecurse-1)) + return V; + } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) { + if (RI->getOperand(0)->getType() == SrcTy) + // Compare without the cast. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), + Q, MaxRecurse-1)) + return V; + } + } + + if (isa<ZExtInst>(LHS)) { + // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the + // same type. + if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) { + if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) + // Compare X and Y. Note that signed predicates become unsigned. + if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, RI->getOperand(0), Q, + MaxRecurse-1)) + return V; + } + // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended + // too. If not, then try to deduce the result of the comparison. + else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DstTy. + Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); + + // If the re-extended constant didn't change then this is effectively + // also a case of comparing two zero-extended values. + if (RExt == CI && MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, Trunc, Q, MaxRecurse-1)) + return V; + + // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit + // there. Use this to work out the result of the comparison. + if (RExt != CI) { + switch (Pred) { + default: llvm_unreachable("Unknown ICmp predicate!"); + // LHS <u RHS. + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + return ConstantInt::getFalse(CI->getContext()); + + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + return ConstantInt::getTrue(CI->getContext()); + + // LHS is non-negative. If RHS is negative then LHS >s LHS. If RHS + // is non-negative then LHS <s RHS. + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return CI->getValue().isNegative() ? + ConstantInt::getTrue(CI->getContext()) : + ConstantInt::getFalse(CI->getContext()); + + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return CI->getValue().isNegative() ? + ConstantInt::getFalse(CI->getContext()) : + ConstantInt::getTrue(CI->getContext()); + } + } + } + } + + if (isa<SExtInst>(LHS)) { + // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the + // same type. + if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) { + if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) + // Compare X and Y. Note that the predicate does not change. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), + Q, MaxRecurse-1)) + return V; + } + // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended + // too. If not, then try to deduce the result of the comparison. + else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DstTy. + Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy); + + // If the re-extended constant didn't change then this is effectively + // also a case of comparing two sign-extended values. + if (RExt == CI && MaxRecurse) + if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1)) + return V; + + // Otherwise the upper bits of LHS are all equal, while RHS has varying + // bits there. Use this to work out the result of the comparison. + if (RExt != CI) { + switch (Pred) { + default: llvm_unreachable("Unknown ICmp predicate!"); + case ICmpInst::ICMP_EQ: + return ConstantInt::getFalse(CI->getContext()); + case ICmpInst::ICMP_NE: + return ConstantInt::getTrue(CI->getContext()); + + // If RHS is non-negative then LHS <s RHS. If RHS is negative then + // LHS >s RHS. + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return CI->getValue().isNegative() ? + ConstantInt::getTrue(CI->getContext()) : + ConstantInt::getFalse(CI->getContext()); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return CI->getValue().isNegative() ? + ConstantInt::getFalse(CI->getContext()) : + ConstantInt::getTrue(CI->getContext()); + + // If LHS is non-negative then LHS <u RHS. If LHS is negative then + // LHS >u RHS. + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + // Comparison is true iff the LHS <s 0. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp, + Constant::getNullValue(SrcTy), + Q, MaxRecurse-1)) + return V; + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + // Comparison is true iff the LHS >=s 0. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, + Constant::getNullValue(SrcTy), + Q, MaxRecurse-1)) + return V; + break; + } + } + } + } + } + + // icmp eq|ne X, Y -> false|true if X != Y + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) { + LLVMContext &Ctx = LHS->getType()->getContext(); + return Pred == ICmpInst::ICMP_NE ? + ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx); + } + + if (Value *V = simplifyICmpWithBinOp(Pred, LHS, RHS, Q, MaxRecurse)) + return V; + + if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse)) + return V; + // Simplify comparisons of related pointers using a powerful, recursive // GEP-walk when we have target data available.. if (LHS->getType()->isPointerTy()) if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.CxtI, LHS, RHS)) return C; + if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS)) + if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS)) + if (Q.DL.getTypeSizeInBits(CLHS->getPointerOperandType()) == + Q.DL.getTypeSizeInBits(CLHS->getType()) && + Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) == + Q.DL.getTypeSizeInBits(CRHS->getType())) + if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.CxtI, + CLHS->getPointerOperand(), + CRHS->getPointerOperand())) + return C; if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) { if (GEPOperator *GRHS = dyn_cast<GEPOperator>(RHS)) { @@ -3119,17 +3300,16 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If a bit is known to be zero for A and known to be one for B, // then A and B cannot be equal. if (ICmpInst::isEquality(Pred)) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - uint32_t BitWidth = CI->getBitWidth(); + const APInt *RHSVal; + if (match(RHS, m_APInt(RHSVal))) { + unsigned BitWidth = RHSVal->getBitWidth(); APInt LHSKnownZero(BitWidth, 0); APInt LHSKnownOne(BitWidth, 0); computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); - const APInt &RHSVal = CI->getValue(); - if (((LHSKnownZero & RHSVal) != 0) || ((LHSKnownOne & ~RHSVal) != 0)) - return Pred == ICmpInst::ICMP_EQ - ? ConstantInt::getFalse(CI->getContext()) - : ConstantInt::getTrue(CI->getContext()); + if (((LHSKnownZero & *RHSVal) != 0) || ((LHSKnownOne & ~(*RHSVal)) != 0)) + return Pred == ICmpInst::ICMP_EQ ? ConstantInt::getFalse(ITy) + : ConstantInt::getTrue(ITy); } } @@ -3175,17 +3355,18 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // Fold trivial predicates. + Type *RetTy = GetCompareTy(LHS); if (Pred == FCmpInst::FCMP_FALSE) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); if (Pred == FCmpInst::FCMP_TRUE) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); // UNO/ORD predicates can be trivially folded if NaNs are ignored. if (FMF.noNaNs()) { if (Pred == FCmpInst::FCMP_UNO) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); if (Pred == FCmpInst::FCMP_ORD) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); } // fcmp pred x, undef and fcmp pred undef, x @@ -3193,15 +3374,15 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) { // Choosing NaN for the undef will always make unordered comparison succeed // and ordered comparison fail. - return ConstantInt::get(GetCompareTy(LHS), CmpInst::isUnordered(Pred)); + return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); } // fcmp x,x -> true/false. Not all compares are foldable. if (LHS == RHS) { if (CmpInst::isTrueWhenEqual(Pred)) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); if (CmpInst::isFalseWhenEqual(Pred)) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); } // Handle fcmp with constant RHS @@ -3216,11 +3397,11 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If the constant is a nan, see if we can fold the comparison based on it. if (CFP->getValueAPF().isNaN()) { if (FCmpInst::isOrdered(Pred)) // True "if ordered and foo" - return ConstantInt::getFalse(CFP->getContext()); + return getFalse(RetTy); assert(FCmpInst::isUnordered(Pred) && "Comparison must be either ordered or unordered!"); // True if unordered. - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); } // Check whether the constant is an infinity. if (CFP->getValueAPF().isInfinity()) { @@ -3228,10 +3409,10 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, switch (Pred) { case FCmpInst::FCMP_OLT: // No value is ordered and less than negative infinity. - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); case FCmpInst::FCMP_UGE: // All values are unordered with or at least negative infinity. - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); default: break; } @@ -3239,10 +3420,10 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, switch (Pred) { case FCmpInst::FCMP_OGT: // No value is ordered and greater than infinity. - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); case FCmpInst::FCMP_ULE: // All values are unordered with and at most infinity. - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); default: break; } @@ -3252,12 +3433,12 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, switch (Pred) { case FCmpInst::FCMP_UGE: if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); break; case FCmpInst::FCMP_OLT: // X < 0 if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); break; default: break; @@ -3371,6 +3552,150 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return nullptr; } +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison where one operand of the compare is a constant. +static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, + const APInt *Y, bool TrueWhenUnset) { + const APInt *C; + + // (X & Y) == 0 ? X & ~Y : X --> X + // (X & Y) != 0 ? X & ~Y : X --> X & ~Y + if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + // (X & Y) == 0 ? X : X & ~Y --> X & ~Y + // (X & Y) != 0 ? X : X & ~Y --> X + if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + if (Y->isPowerOf2()) { + // (X & Y) == 0 ? X | Y : X --> X | Y + // (X & Y) != 0 ? X | Y : X --> X + if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + + // (X & Y) == 0 ? X : X | Y --> X + // (X & Y) != 0 ? X : X | Y --> X | Y + if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + } + + return nullptr; +} + +/// An alternative way to test if a bit is set or not uses sgt/slt instead of +/// eq/ne. +static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *TrueVal, + Value *FalseVal, + bool TrueWhenUnset) { + unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits(); + if (!BitWidth) + return nullptr; + + APInt MinSignedValue; + Value *X; + if (match(CmpLHS, m_Trunc(m_Value(X))) && (X == TrueVal || X == FalseVal)) { + // icmp slt (trunc X), 0 <--> icmp ne (and X, C), 0 + // icmp sgt (trunc X), -1 <--> icmp eq (and X, C), 0 + unsigned DestSize = CmpLHS->getType()->getScalarSizeInBits(); + MinSignedValue = APInt::getSignedMinValue(DestSize).zext(BitWidth); + } else { + // icmp slt X, 0 <--> icmp ne (and X, C), 0 + // icmp sgt X, -1 <--> icmp eq (and X, C), 0 + X = CmpLHS; + MinSignedValue = APInt::getSignedMinValue(BitWidth); + } + + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, &MinSignedValue, + TrueWhenUnset)) + return V; + + return nullptr; +} + +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison. +static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, + Value *FalseVal, const Query &Q, + unsigned MaxRecurse) { + ICmpInst::Predicate Pred; + Value *CmpLHS, *CmpRHS; + if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) + return nullptr; + + // FIXME: This code is nearly duplicated in InstCombine. Using/refactoring + // decomposeBitTestICmp() might help. + if (ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero())) { + Value *X; + const APInt *Y; + if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y)))) + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y, + Pred == ICmpInst::ICMP_EQ)) + return V; + } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { + // Comparing signed-less-than 0 checks if the sign bit is set. + if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, TrueVal, FalseVal, + false)) + return V; + } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { + // Comparing signed-greater-than -1 checks if the sign bit is not set. + if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, TrueVal, FalseVal, + true)) + return V; + } + + if (CondVal->hasOneUse()) { + const APInt *C; + if (match(CmpRHS, m_APInt(C))) { + // X < MIN ? T : F --> F + if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue()) + return FalseVal; + // X < MIN ? T : F --> F + if (Pred == ICmpInst::ICMP_ULT && C->isMinValue()) + return FalseVal; + // X > MAX ? T : F --> F + if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue()) + return FalseVal; + // X > MAX ? T : F --> F + if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue()) + return FalseVal; + } + } + + // If we have an equality comparison, then we know the value in one of the + // arms of the select. See if substituting this value into the arm and + // simplifying the result yields the same value as the other arm. + if (Pred == ICmpInst::ICMP_EQ) { + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + TrueVal) + return FalseVal; + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + FalseVal) + return FalseVal; + } else if (Pred == ICmpInst::ICMP_NE) { + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + FalseVal) + return TrueVal; + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + TrueVal) + return TrueVal; + } + + return nullptr; +} + /// Given operands for a SelectInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, @@ -3399,106 +3724,9 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, if (isa<UndefValue>(FalseVal)) // select C, X, undef -> X return TrueVal; - if (const auto *ICI = dyn_cast<ICmpInst>(CondVal)) { - // FIXME: This code is nearly duplicated in InstCombine. Using/refactoring - // decomposeBitTestICmp() might help. - unsigned BitWidth = - Q.DL.getTypeSizeInBits(TrueVal->getType()->getScalarType()); - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - APInt MinSignedValue = APInt::getSignBit(BitWidth); - Value *X; - const APInt *Y; - bool TrueWhenUnset; - bool IsBitTest = false; - if (ICmpInst::isEquality(Pred) && - match(CmpLHS, m_And(m_Value(X), m_APInt(Y))) && - match(CmpRHS, m_Zero())) { - IsBitTest = true; - TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; - } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { - X = CmpLHS; - Y = &MinSignedValue; - IsBitTest = true; - TrueWhenUnset = false; - } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { - X = CmpLHS; - Y = &MinSignedValue; - IsBitTest = true; - TrueWhenUnset = true; - } - if (IsBitTest) { - const APInt *C; - // (X & Y) == 0 ? X & ~Y : X --> X - // (X & Y) != 0 ? X & ~Y : X --> X & ~Y - if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && - *Y == ~*C) - return TrueWhenUnset ? FalseVal : TrueVal; - // (X & Y) == 0 ? X : X & ~Y --> X & ~Y - // (X & Y) != 0 ? X : X & ~Y --> X - if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && - *Y == ~*C) - return TrueWhenUnset ? FalseVal : TrueVal; - - if (Y->isPowerOf2()) { - // (X & Y) == 0 ? X | Y : X --> X | Y - // (X & Y) != 0 ? X | Y : X --> X - if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) - return TrueWhenUnset ? TrueVal : FalseVal; - // (X & Y) == 0 ? X : X | Y --> X - // (X & Y) != 0 ? X : X | Y --> X | Y - if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) - return TrueWhenUnset ? TrueVal : FalseVal; - } - } - if (ICI->hasOneUse()) { - const APInt *C; - if (match(CmpRHS, m_APInt(C))) { - // X < MIN ? T : F --> F - if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue()) - return FalseVal; - // X < MIN ? T : F --> F - if (Pred == ICmpInst::ICMP_ULT && C->isMinValue()) - return FalseVal; - // X > MAX ? T : F --> F - if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue()) - return FalseVal; - // X > MAX ? T : F --> F - if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue()) - return FalseVal; - } - } - - // If we have an equality comparison then we know the value in one of the - // arms of the select. See if substituting this value into the arm and - // simplifying the result yields the same value as the other arm. - if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - TrueVal) - return FalseVal; - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - FalseVal) - return FalseVal; - } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - FalseVal) - return TrueVal; - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - TrueVal) - return TrueVal; - } - } + if (Value *V = + simplifySelectWithICmpCond(CondVal, TrueVal, FalseVal, Q, MaxRecurse)) + return V; return nullptr; } @@ -3587,6 +3815,32 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, } } + if (Q.DL.getTypeAllocSize(LastType) == 1 && + all_of(Ops.slice(1).drop_back(1), + [](Value *Idx) { return match(Idx, m_Zero()); })) { + unsigned PtrWidth = + Q.DL.getPointerSizeInBits(Ops[0]->getType()->getPointerAddressSpace()); + if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == PtrWidth) { + APInt BasePtrOffset(PtrWidth, 0); + Value *StrippedBasePtr = + Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL, + BasePtrOffset); + + // gep (gep V, C), (sub 0, V) -> C + if (match(Ops.back(), + m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr))))) { + auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset); + return ConstantExpr::getIntToPtr(CI, GEPTy); + } + // gep (gep V, C), (xor V, -1) -> C-1 + if (match(Ops.back(), + m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes()))) { + auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1); + return ConstantExpr::getIntToPtr(CI, GEPTy); + } + } + } + // Check to see if this is constant foldable. for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (!isa<Constant>(Ops[i])) @@ -3742,19 +3996,47 @@ static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { return CommonValue; } -static Value *SimplifyTruncInst(Value *Op, Type *Ty, const Query &Q, unsigned) { - if (Constant *C = dyn_cast<Constant>(Op)) - return ConstantFoldCastOperand(Instruction::Trunc, C, Ty, Q.DL); +static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, + Type *Ty, const Query &Q, unsigned MaxRecurse) { + if (auto *C = dyn_cast<Constant>(Op)) + return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); + + if (auto *CI = dyn_cast<CastInst>(Op)) { + auto *Src = CI->getOperand(0); + Type *SrcTy = Src->getType(); + Type *MidTy = CI->getType(); + Type *DstTy = Ty; + if (Src->getType() == Ty) { + auto FirstOp = static_cast<Instruction::CastOps>(CI->getOpcode()); + auto SecondOp = static_cast<Instruction::CastOps>(CastOpc); + Type *SrcIntPtrTy = + SrcTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(SrcTy) : nullptr; + Type *MidIntPtrTy = + MidTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(MidTy) : nullptr; + Type *DstIntPtrTy = + DstTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(DstTy) : nullptr; + if (CastInst::isEliminableCastPair(FirstOp, SecondOp, SrcTy, MidTy, DstTy, + SrcIntPtrTy, MidIntPtrTy, + DstIntPtrTy) == Instruction::BitCast) + return Src; + } + } + + // bitcast x -> x + if (CastOpc == Instruction::BitCast) + if (Op->getType() == Ty) + return Op; return nullptr; } -Value *llvm::SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyTruncInst(Op, Ty, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, + const DataLayout &DL, + const TargetLibraryInfo *TLI, + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyCastInst(CastOpc, Op, Ty, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } //=== Helper functions for higher up the class hierarchy. @@ -3837,6 +4119,8 @@ static Value *SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FMul: return SimplifyFMulInst(LHS, RHS, FMF, Q, MaxRecurse); + case Instruction::FDiv: + return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); default: return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); } @@ -3968,14 +4252,36 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, const Query &Q, unsigned MaxRecurse) { Intrinsic::ID IID = F->getIntrinsicID(); unsigned NumOperands = std::distance(ArgBegin, ArgEnd); - Type *ReturnType = F->getReturnType(); + + // Unary Ops + if (NumOperands == 1) { + // Perform idempotent optimizations + if (IsIdempotent(IID)) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(*ArgBegin)) { + if (II->getIntrinsicID() == IID) + return II; + } + } + + switch (IID) { + case Intrinsic::fabs: { + if (SignBitMustBeZero(*ArgBegin, Q.TLI)) + return *ArgBegin; + } + default: + return nullptr; + } + } // Binary Ops if (NumOperands == 2) { Value *LHS = *ArgBegin; Value *RHS = *(ArgBegin + 1); - if (IID == Intrinsic::usub_with_overflow || - IID == Intrinsic::ssub_with_overflow) { + Type *ReturnType = F->getReturnType(); + + switch (IID) { + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: { // X - X -> { 0, false } if (LHS == RHS) return Constant::getNullValue(ReturnType); @@ -3984,17 +4290,19 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, // undef - X -> undef if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) return UndefValue::get(ReturnType); - } - if (IID == Intrinsic::uadd_with_overflow || - IID == Intrinsic::sadd_with_overflow) { + return nullptr; + } + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: { // X + undef -> undef if (isa<UndefValue>(RHS)) return UndefValue::get(ReturnType); - } - if (IID == Intrinsic::umul_with_overflow || - IID == Intrinsic::smul_with_overflow) { + return nullptr; + } + case Intrinsic::umul_with_overflow: + case Intrinsic::smul_with_overflow: { // X * 0 -> { 0, false } if (match(RHS, m_Zero())) return Constant::getNullValue(ReturnType); @@ -4002,34 +4310,34 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, // X * undef -> { 0, false } if (match(RHS, m_Undef())) return Constant::getNullValue(ReturnType); - } - if (IID == Intrinsic::load_relative && isa<Constant>(LHS) && - isa<Constant>(RHS)) - return SimplifyRelativeLoad(cast<Constant>(LHS), cast<Constant>(RHS), - Q.DL); + return nullptr; + } + case Intrinsic::load_relative: { + Constant *C0 = dyn_cast<Constant>(LHS); + Constant *C1 = dyn_cast<Constant>(RHS); + if (C0 && C1) + return SimplifyRelativeLoad(C0, C1, Q.DL); + return nullptr; + } + default: + return nullptr; + } } // Simplify calls to llvm.masked.load.* - if (IID == Intrinsic::masked_load) { + switch (IID) { + case Intrinsic::masked_load: { Value *MaskArg = ArgBegin[2]; Value *PassthruArg = ArgBegin[3]; // If the mask is all zeros or undef, the "passthru" argument is the result. if (maskIsAllZeroOrUndef(MaskArg)) return PassthruArg; + return nullptr; } - - // Perform idempotent optimizations - if (!IsIdempotent(IID)) + default: return nullptr; - - // Unary Ops - if (NumOperands == 1) - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(*ArgBegin)) - if (II->getIntrinsicID() == IID) - return II; - - return nullptr; + } } template <typename IterTy> @@ -4223,21 +4531,23 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, TLI, DT, AC, I); break; } - case Instruction::Trunc: - Result = - SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT, AC, I); +#define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: +#include "llvm/IR/Instruction.def" +#undef HANDLE_CAST_INST + Result = SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), + DL, TLI, DT, AC, I); break; } // In general, it is possible for computeKnownBits to determine all bits in a // value even when the operands are not all constants. - if (!Result && I->getType()->isIntegerTy()) { + if (!Result && I->getType()->isIntOrIntVectorTy()) { unsigned BitWidth = I->getType()->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT); if ((KnownZero | KnownOne).isAllOnesValue()) - Result = ConstantInt::get(I->getContext(), KnownOne); + Result = ConstantInt::get(I->getType(), KnownOne); } /// If called on unreachable code, the above logic may report that the |