Diffstat (limited to 'contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r--	contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 686
1 file changed, 467 insertions(+), 219 deletions(-)
diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 9bb65ef..5e71c5c 100644
--- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -15,28 +15,21 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
-#include "llvm/Support/ConstantRange.h"
-#include "llvm/Support/GetElementPtrTypeIterator.h"
-#include "llvm/Support/PatternMatch.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Target/TargetLibraryInfo.h"
 
 using namespace llvm;
 using namespace PatternMatch;
 
+#define DEBUG_TYPE "instcombine"
+
 static ConstantInt *getOne(Constant *C) {
   return ConstantInt::get(cast<IntegerType>(C->getType()), 1);
 }
 
-/// AddOne - Add one to a ConstantInt
-static Constant *AddOne(Constant *C) {
-  return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
-}
-/// SubOne - Subtract one from a ConstantInt
-static Constant *SubOne(Constant *C) {
-  return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
-}
-
 static ConstantInt *ExtractElement(Constant *V, Constant *Idx) {
   return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx));
 }
@@ -227,15 +220,15 @@ Instruction *InstCombiner::
 FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
                              CmpInst &ICI, ConstantInt *AndCst) {
   // We need TD information to know the pointer size unless this is inbounds.
-  if (!GEP->isInBounds() && TD == 0)
-    return 0;
+  if (!GEP->isInBounds() && !DL)
+    return nullptr;
 
   Constant *Init = GV->getInitializer();
   if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init))
-    return 0;
+    return nullptr;
 
   uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
-  if (ArrayElementCount > 1024) return 0;  // Don't blow up on huge arrays.
+  if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays.
 
   // There are many forms of this optimization we can handle, for now, just do
   // the simple index into a single-dimensional array.
@@ -245,7 +238,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
       !isa<ConstantInt>(GEP->getOperand(1)) ||
       !cast<ConstantInt>(GEP->getOperand(1))->isZero() ||
       isa<Constant>(GEP->getOperand(2)))
-    return 0;
+    return nullptr;
 
   // Check that indices after the variable are constants and in-range for the
   // type they index.  Collect the indices.  This is typically for arrays of
@@ -255,18 +248,18 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   Type *EltTy = Init->getType()->getArrayElementType();
   for (unsigned i = 3, e = GEP->getNumOperands(); i != e; ++i) {
     ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(i));
-    if (Idx == 0) return 0;  // Variable index.
+    if (!Idx) return nullptr; // Variable index.
 
     uint64_t IdxVal = Idx->getZExtValue();
-    if ((unsigned)IdxVal != IdxVal) return 0;  // Too large array index.
+    if ((unsigned)IdxVal != IdxVal) return nullptr; // Too large array index.
 
     if (StructType *STy = dyn_cast<StructType>(EltTy))
       EltTy = STy->getElementType(IdxVal);
     else if (ArrayType *ATy = dyn_cast<ArrayType>(EltTy)) {
-      if (IdxVal >= ATy->getNumElements()) return 0;
+      if (IdxVal >= ATy->getNumElements()) return nullptr;
       EltTy = ATy->getElementType();
     } else {
-      return 0; // Unknown type.
+      return nullptr; // Unknown type.
     }
 
     LaterIndices.push_back(IdxVal);
@@ -305,7 +298,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   Constant *CompareRHS = cast<Constant>(ICI.getOperand(1));
   for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) {
     Constant *Elt = Init->getAggregateElement(i);
-    if (Elt == 0) return 0;
+    if (!Elt) return nullptr;
 
     // If this is indexing an array of structures, get the structure element.
     if (!LaterIndices.empty())
@@ -316,7 +309,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
 
     // Find out if the comparison would be true or false for the i'th element.
     Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt,
-                                                  CompareRHS, TD, TLI);
+                                                  CompareRHS, DL, TLI);
     // If the result is undef for this element, ignore it.
     if (isa<UndefValue>(C)) {
       // Extend range state machines to cover this element in case there is an
@@ -330,7 +323,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
 
     // If we can't compute the result for any of the elements, we have to give
     // up evaluating the entire conditional.
-    if (!isa<ConstantInt>(C)) return 0;
+    if (!isa<ConstantInt>(C)) return nullptr;
 
     // Otherwise, we know if the comparison is true or false for this element,
     // update our state machines.
@@ -384,7 +377,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
     if ((i & 8) == 0 && i >= 64 && SecondTrueElement == Overdefined &&
         SecondFalseElement == Overdefined && TrueRangeEnd == Overdefined &&
         FalseRangeEnd == Overdefined)
-      return 0;
+      return nullptr;
   }
 
   // Now that we've scanned the entire array, emit our new comparison(s).  We
@@ -395,7 +388,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // index down like the GEP would do implicitly.  We don't have to do this for
   // an inbounds GEP because the index can't be out of range.
   if (!GEP->isInBounds()) {
-    Type *IntPtrTy = TD->getIntPtrType(GEP->getType());
+    Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
     unsigned PtrSize = IntPtrTy->getIntegerBitWidth();
     if (Idx->getType()->getPrimitiveSizeInBits() > PtrSize)
       Idx = Builder->CreateTrunc(Idx, IntPtrTy);
@@ -476,7 +469,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // of this load, replace it with computation that does:
   //   ((magic_cst >> i) & 1) != 0
   {
-    Type *Ty = 0;
+    Type *Ty = nullptr;
 
     // Look for an appropriate type:
     // - The type of Idx if the magic fits
@@ -484,12 +477,12 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
     // - Default to i32
     if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth())
       Ty = Idx->getType();
-    else if (TD)
-      Ty = TD->getSmallestLegalIntType(Init->getContext(), ArrayElementCount);
+    else if (DL)
+      Ty = DL->getSmallestLegalIntType(Init->getContext(), ArrayElementCount);
    else if (ArrayElementCount <= 32)
       Ty = Type::getInt32Ty(Init->getContext());
 
-    if (Ty != 0) {
+    if (Ty) {
       Value *V = Builder->CreateIntCast(Idx, Ty, false);
       V = Builder->CreateLShr(ConstantInt::get(Ty, MagicBitvector), V);
       V = Builder->CreateAnd(ConstantInt::get(Ty, 1), V);
@@ -497,7 +490,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
     }
   }
 
-  return 0;
+  return nullptr;
 }
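
Reviewer note (not part of the patch): the hunks above close out FoldCmpLoadFromIndexedGlobal, which turns a compare of a load from a small constant global array into the bit test ((magic_cst >> i) & 1) != 0. A minimal standalone C++ sketch of that equivalence, with a hypothetical table and compare value:

    #include <cassert>
    #include <cstdint>

    // Hypothetical constant table; "table[i] == 3" is the compare being folded.
    static const int table[8] = {1, 3, 5, 3, 7, 3, 9, 11};

    int main() {
      // Build the "magic bitvector": bit i is set iff the comparison is true
      // for element i. InstCombine computes this constant at compile time.
      uint32_t magic = 0;
      for (unsigned i = 0; i != 8; ++i)
        magic |= uint32_t(table[i] == 3) << i;

      // The fold: a load plus compare becomes a shift and mask of the constant.
      for (unsigned i = 0; i != 8; ++i)
        assert((table[i] == 3) == (((magic >> i) & 1) != 0));
      return 0;
    }
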
@@ -512,7 +505,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
 /// If we can't emit an optimized form for this expression, this returns null.
 ///
 static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
-  DataLayout &TD = *IC.getDataLayout();
+  const DataLayout &DL = *IC.getDataLayout();
   gep_type_iterator GTI = gep_type_begin(GEP);
 
   // Check to see if this gep only has a single variable index.  If so, and if
@@ -529,9 +522,9 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
       // Handle a struct index, which adds its field offset to the pointer.
       if (StructType *STy = dyn_cast<StructType>(*GTI)) {
-        Offset += TD.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
+        Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
       } else {
-        uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType());
+        uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
         Offset += Size*CI->getSExtValue();
       }
     } else {
@@ -542,26 +535,26 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
 
   // If there are no variable indices, we must have a constant offset, just
   // evaluate it the general way.
-  if (i == e) return 0;
+  if (i == e) return nullptr;
 
   Value *VariableIdx = GEP->getOperand(i);
   // Determine the scale factor of the variable element.  For example, this is
   // 4 if the variable index is into an array of i32.
-  uint64_t VariableScale = TD.getTypeAllocSize(GTI.getIndexedType());
+  uint64_t VariableScale = DL.getTypeAllocSize(GTI.getIndexedType());
 
   // Verify that there are no other variable indices.  If so, emit the hard way.
   for (++i, ++GTI; i != e; ++i, ++GTI) {
     ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i));
-    if (!CI) return 0;
+    if (!CI) return nullptr;
 
     // Compute the aggregate offset of constant indices.
     if (CI->isZero()) continue;
 
     // Handle a struct index, which adds its field offset to the pointer.
     if (StructType *STy = dyn_cast<StructType>(*GTI)) {
-      Offset += TD.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
+      Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
     } else {
-      uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType());
+      uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
       Offset += Size*CI->getSExtValue();
     }
   }
 
@@ -571,7 +564,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
   // Okay, we know we have a single variable index, which must be a
   // pointer/array/vector index.  If there is no offset, life is simple, return
   // the index.
-  Type *IntPtrTy = TD.getIntPtrType(GEP->getOperand(0)->getType());
+  Type *IntPtrTy = DL.getIntPtrType(GEP->getOperand(0)->getType());
   unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth();
   if (Offset == 0) {
     // Cast to intptrty in case a truncation occurs.  If an extension is needed,
@@ -596,7 +589,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
   // multiple of the variable scale.
   int64_t NewOffs = Offset / (int64_t)VariableScale;
   if (Offset != NewOffs*(int64_t)VariableScale)
-    return 0;
+    return nullptr;
 
   // Okay, we can do this evaluation.  Start by converting the index to intptr.
   if (VariableIdx->getType() != IntPtrTy)
@@ -617,14 +610,15 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
   // e.g. "&foo[0] <s &foo[1]" can't be folded to "true" because "foo" could be
   // the maximum signed value for the pointer type.
   if (ICmpInst::isSigned(Cond))
-    return 0;
+    return nullptr;
 
-  // Look through bitcasts.
-  if (BitCastInst *BCI = dyn_cast<BitCastInst>(RHS))
-    RHS = BCI->getOperand(0);
+  // Look through bitcasts and addrspacecasts. We do not however want to remove
+  // 0 GEPs.
+  if (!isa<GetElementPtrInst>(RHS))
+    RHS = RHS->stripPointerCasts();
 
   Value *PtrBase = GEPLHS->getOperand(0);
-  if (TD && PtrBase == RHS && GEPLHS->isInBounds()) {
+  if (DL && PtrBase == RHS && GEPLHS->isInBounds()) {
     // ((gep Ptr, OFFSET) cmp Ptr)   ---> (OFFSET cmp 0).
     // This transformation (ignoring the base and scales) is valid because we
     // know pointers can't overflow since the gep is inbounds.  See if we can
@@ -632,7 +626,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this);
 
     // If not, synthesize the offset the hard way.
-    if (Offset == 0)
+    if (!Offset)
       Offset = EmitGEPOffset(GEPLHS);
     return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
                         Constant::getNullValue(Offset->getType()));
@@ -657,43 +651,44 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     // If we're comparing GEPs with two base pointers that only differ in type
     // and both GEPs have only constant indices or just one use, then fold
     // the compare with the adjusted indices.
-    if (TD && GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
+    if (DL && GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
         (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
         (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
         PtrBase->stripPointerCasts() ==
           GEPRHS->getOperand(0)->stripPointerCasts()) {
+      Value *LOffset = EmitGEPOffset(GEPLHS);
+      Value *ROffset = EmitGEPOffset(GEPRHS);
+
+      // If we looked through an addrspacecast between different sized address
+      // spaces, the LHS and RHS pointers are different sized
+      // integers. Truncate to the smaller one.
+      Type *LHSIndexTy = LOffset->getType();
+      Type *RHSIndexTy = ROffset->getType();
+      if (LHSIndexTy != RHSIndexTy) {
+        if (LHSIndexTy->getPrimitiveSizeInBits() <
+            RHSIndexTy->getPrimitiveSizeInBits()) {
+          ROffset = Builder->CreateTrunc(ROffset, LHSIndexTy);
+        } else
+          LOffset = Builder->CreateTrunc(LOffset, RHSIndexTy);
+      }
+
       Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond),
-                                       EmitGEPOffset(GEPLHS),
-                                       EmitGEPOffset(GEPRHS));
+                                       LOffset, ROffset);
       return ReplaceInstUsesWith(I, Cmp);
     }
 
     // Otherwise, the base pointers are different and the indices are
     // different, bail out.
-    return 0;
+    return nullptr;
   }
 
   // If one of the GEPs has all zero indices, recurse.
-  bool AllZeros = true;
-  for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i)
-    if (!isa<Constant>(GEPLHS->getOperand(i)) ||
-        !cast<Constant>(GEPLHS->getOperand(i))->isNullValue()) {
-      AllZeros = false;
-      break;
-    }
-  if (AllZeros)
+  if (GEPLHS->hasAllZeroIndices())
     return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
                        ICmpInst::getSwappedPredicate(Cond), I);
 
   // If the other GEP has all zero indices, recurse.
-  AllZeros = true;
-  for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i)
-    if (!isa<Constant>(GEPRHS->getOperand(i)) ||
-        !cast<Constant>(GEPRHS->getOperand(i))->isNullValue()) {
-      AllZeros = false;
-      break;
-    }
-  if (AllZeros)
+  if (GEPRHS->hasAllZeroIndices())
     return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I);
 
   bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
@@ -728,7 +723,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
 
     // Only lower this if the icmp is the only user of the GEP or if we expect
     // the result to fold to a constant!
-    if (TD &&
+    if (DL &&
         GEPsInBounds &&
         (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) &&
         (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) {
@@ -738,7 +733,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
       return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R);
     }
   }
-  return 0;
+  return nullptr;
 }
 
 /// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X".
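
Reviewer note (not part of the patch): FoldGEPICmp, updated above, rewrites a comparison of an inbounds GEP against its own base pointer as a signed comparison of the computed offset against zero. A sketch of the source-level intuition, with a hypothetical array and index:

    #include <cassert>
    #include <cstddef>

    // For an inbounds pointer into the same object, "&a[i] > a" carries the
    // same information as "i > 0" -- the fold FoldGEPICmp performs in IR.
    bool viaPointers(const int *a, size_t i) { return a + i > a; }
    bool viaOffsets(size_t i) { return i > 0; }

    int main() {
      int a[16] = {};
      for (size_t i = 0; i != 16; ++i)   // a + 16 is one past the end: valid
        assert(viaPointers(a, i) == viaOffsets(i));
      return 0;
    }
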
@@ -821,11 +816,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
   // if it finds it.
   bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv;
   if (!ICI.isEquality() && DivIsSigned != ICI.isSigned())
-    return 0;
+    return nullptr;
   if (DivRHS->isZero())
-    return 0; // The ProdOV computation fails on divide by zero.
+    return nullptr; // The ProdOV computation fails on divide by zero.
   if (DivIsSigned && DivRHS->isAllOnesValue())
-    return 0; // The overflow computation also screws up here
+    return nullptr; // The overflow computation also screws up here
   if (DivRHS->isOne()) {
     // This eliminates some funny cases with INT_MIN.
     ICI.setOperand(0, DivI->getOperand(0));   // X/1 == X.
@@ -859,7 +854,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
   // overflow variable is set to 0 if it's corresponding bound variable is valid
   // -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
   int LoOverflow = 0, HiOverflow = 0;
-  Constant *LoBound = 0, *HiBound = 0;
+  Constant *LoBound = nullptr, *HiBound = nullptr;
 
   if (!DivIsSigned) {  // udiv
     // e.g. X/5 op 3  --> [15, 20)
@@ -899,7 +894,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
       HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
       if (HiBound == DivRHS) {     // -INTMIN = INTMIN
         HiOverflow = 1;            // [INTMIN+1, overflow)
-        HiBound = 0;               // e.g. X/INTMIN = 0 --> X > INTMIN
+        HiBound = nullptr;         // e.g. X/INTMIN = 0 --> X > INTMIN
       }
     } else if (CmpRHSV.isStrictlyPositive()) {   // (X / neg) op pos
       // e.g. X/-5 op 3  --> [-19, -14)
@@ -973,20 +968,20 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
   uint32_t TypeBits = CmpRHSV.getBitWidth();
   uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
   if (ShAmtVal >= TypeBits || ShAmtVal == 0)
-    return 0;
+    return nullptr;
 
   if (!ICI.isEquality()) {
     // If we have an unsigned comparison and an ashr, we can't simplify this.
     // Similarly for signed comparisons with lshr.
     if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr))
-      return 0;
+      return nullptr;
 
     // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv
     // by a power of 2.  Since we already have logic to simplify these,
     // transform to div and then simplify the resultant comparison.
     if (Shr->getOpcode() == Instruction::AShr &&
        (!Shr->isExact() || ShAmtVal == TypeBits - 1))
-      return 0;
+      return nullptr;
 
     // Revisit the shift (to delete it).
     Worklist.Add(Shr);
@@ -1003,7 +998,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
 
     // If the builder folded the binop, just return it.
     BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp);
-    if (TheDiv == 0)
+    if (!TheDiv)
      return &ICI;
 
     // Otherwise, fold this div/compare.
@@ -1046,7 +1041,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
                                Mask, Shr->getName()+".mask");
     return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS);
   }
-  return 0;
+  return nullptr;
 }
 
@@ -1065,7 +1060,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
     unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(),
       SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits();
     APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0);
-    ComputeMaskedBits(LHSI->getOperand(0), KnownZero, KnownOne);
+    computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne);
 
     // If all the high bits are known, we can do this xform.
     if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) {
@@ -1078,17 +1073,17 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
     }
     break;
 
-  case Instruction::Xor:         // (icmp pred (xor X, XorCST), CI)
-    if (ConstantInt *XorCST = dyn_cast<ConstantInt>(LHSI->getOperand(1))) {
+  case Instruction::Xor:         // (icmp pred (xor X, XorCst), CI)
+    if (ConstantInt *XorCst = dyn_cast<ConstantInt>(LHSI->getOperand(1))) {
      // If this is a comparison that tests the signbit (X < 0) or (x > -1),
      // fold the xor.
      if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) ||
          (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) {
        Value *CompareVal = LHSI->getOperand(0);
 
-        // If the sign bit of the XorCST is not set, there is no change to
+        // If the sign bit of the XorCst is not set, there is no change to
        // the operation, just stop using the Xor.
-        if (!XorCST->isNegative()) {
+        if (!XorCst->isNegative()) {
          ICI.setOperand(0, CompareVal);
          Worklist.Add(LHSI);
          return &ICI;
@@ -1110,8 +1105,8 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
 
      if (LHSI->hasOneUse()) {
        // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit))
-        if (!ICI.isEquality() && XorCST->getValue().isSignBit()) {
-          const APInt &SignBit = XorCST->getValue();
+        if (!ICI.isEquality() && XorCst->getValue().isSignBit()) {
+          const APInt &SignBit = XorCst->getValue();
          ICmpInst::Predicate Pred = ICI.isSigned() ?
                                     ICI.getUnsignedPredicate() :
                                     ICI.getSignedPredicate();
@@ -1120,8 +1115,8 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
        }
 
        // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A)
-        if (!ICI.isEquality() && XorCST->isMaxValue(true)) {
-          const APInt &NotSignBit = XorCST->getValue();
+        if (!ICI.isEquality() && XorCst->isMaxValue(true)) {
+          const APInt &NotSignBit = XorCst->getValue();
          ICmpInst::Predicate Pred = ICI.isSigned() ?
                                     ICI.getUnsignedPredicate() :
                                     ICI.getSignedPredicate();
@@ -1134,20 +1129,20 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
      // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C)
      //   iff -C is a power of 2
      if (ICI.getPredicate() == ICmpInst::ICMP_UGT &&
-          XorCST->getValue() == ~RHSV && (RHSV + 1).isPowerOf2())
-        return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCST);
+          XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2())
+        return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst);
 
      // (icmp ult (xor X, C), -C) -> (icmp uge X, C)
      //   iff -C is a power of 2
      if (ICI.getPredicate() == ICmpInst::ICMP_ULT &&
-          XorCST->getValue() == -RHSV && RHSV.isPowerOf2())
-        return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCST);
+          XorCst->getValue() == -RHSV && RHSV.isPowerOf2())
+        return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst);
    }
    break;
-  case Instruction::And:         // (icmp pred (and X, AndCST), RHS)
+  case Instruction::And:         // (icmp pred (and X, AndCst), RHS)
    if (LHSI->hasOneUse() && isa<ConstantInt>(LHSI->getOperand(1)) &&
        LHSI->getOperand(0)->hasOneUse()) {
-      ConstantInt *AndCST = cast<ConstantInt>(LHSI->getOperand(1));
+      ConstantInt *AndCst = cast<ConstantInt>(LHSI->getOperand(1));
 
      // If the LHS is an AND of a truncating cast, we can widen the
      // and/compare to be the input width without changing the value
@@ -1158,10 +1153,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
        // Extending a relational comparison when we're checking the sign
        // bit would not work.
        if (ICI.isEquality() ||
-            (!AndCST->isNegative() && RHSV.isNonNegative())) {
+            (!AndCst->isNegative() && RHSV.isNonNegative())) {
          Value *NewAnd =
            Builder->CreateAnd(Cast->getOperand(0),
-                               ConstantExpr::getZExt(AndCST, Cast->getSrcTy()));
+                               ConstantExpr::getZExt(AndCst, Cast->getSrcTy()));
          NewAnd->takeName(LHSI);
          return new ICmpInst(ICI.getPredicate(), NewAnd,
                              ConstantExpr::getZExt(RHS, Cast->getSrcTy()));
@@ -1177,7 +1172,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
        if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) {
          Value *NewAnd =
            Builder->CreateAnd(Cast->getOperand(0),
-                               ConstantExpr::getTrunc(AndCST, Ty));
+                               ConstantExpr::getTrunc(AndCst, Ty));
          NewAnd->takeName(LHSI);
          return new ICmpInst(ICI.getPredicate(), NewAnd,
                              ConstantExpr::getTrunc(RHS, Ty));
@@ -1190,49 +1185,58 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
      // access.
      BinaryOperator *Shift = dyn_cast<BinaryOperator>(LHSI->getOperand(0));
      if (Shift && !Shift->isShift())
-        Shift = 0;
+        Shift = nullptr;
 
      ConstantInt *ShAmt;
-      ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : 0;
-      Type *Ty = Shift ? Shift->getType() : 0;  // Type of the shift.
-      Type *AndTy = AndCST->getType();          // Type of the and.
-
-      // We can fold this as long as we can't shift unknown bits
-      // into the mask.  This can happen with signed shift
-      // rights, as they sign-extend.  With logical shifts,
With logical shifts, - // we must still make sure the comparison is not signed - // because we are effectively changing the - // position of the sign bit (PR17827). - // TODO: We can relax these constraints a bit more. + ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : nullptr; + + // This seemingly simple opportunity to fold away a shift turns out to + // be rather complicated. See PR17827 + // ( http://llvm.org/bugs/show_bug.cgi?id=17827 ) for details. if (ShAmt) { bool CanFold = false; unsigned ShiftOpcode = Shift->getOpcode(); if (ShiftOpcode == Instruction::AShr) { - // To test for the bad case of the signed shr, see if any - // of the bits shifted in could be tested after the mask. - uint32_t TyBits = Ty->getPrimitiveSizeInBits(); - int ShAmtVal = TyBits - ShAmt->getLimitedValue(TyBits); - - uint32_t BitWidth = AndTy->getPrimitiveSizeInBits(); - if ((APInt::getHighBitsSet(BitWidth, BitWidth-ShAmtVal) & - AndCST->getValue()) == 0) + // There may be some constraints that make this possible, + // but nothing simple has been discovered yet. + CanFold = false; + } else if (ShiftOpcode == Instruction::Shl) { + // For a left shift, we can fold if the comparison is not signed. + // We can also fold a signed comparison if the mask value and + // comparison value are not negative. These constraints may not be + // obvious, but we can prove that they are correct using an SMT + // solver. + if (!ICI.isSigned() || (!AndCst->isNegative() && !RHS->isNegative())) + CanFold = true; + } else if (ShiftOpcode == Instruction::LShr) { + // For a logical right shift, we can fold if the comparison is not + // signed. We can also fold a signed comparison if the shifted mask + // value and the shifted comparison value are not negative. + // These constraints may not be obvious, but we can prove that they + // are correct using an SMT solver. + if (!ICI.isSigned()) CanFold = true; - } else if (ShiftOpcode == Instruction::Shl || - ShiftOpcode == Instruction::LShr) { - CanFold = !ICI.isSigned(); + else { + ConstantInt *ShiftedAndCst = + cast<ConstantInt>(ConstantExpr::getShl(AndCst, ShAmt)); + ConstantInt *ShiftedRHSCst = + cast<ConstantInt>(ConstantExpr::getShl(RHS, ShAmt)); + + if (!ShiftedAndCst->isNegative() && !ShiftedRHSCst->isNegative()) + CanFold = true; + } } if (CanFold) { Constant *NewCst; - if (Shift->getOpcode() == Instruction::Shl) + if (ShiftOpcode == Instruction::Shl) NewCst = ConstantExpr::getLShr(RHS, ShAmt); else NewCst = ConstantExpr::getShl(RHS, ShAmt); // Check to see if we are shifting out any of the bits being // compared. - if (ConstantExpr::get(Shift->getOpcode(), - NewCst, ShAmt) != RHS) { + if (ConstantExpr::get(ShiftOpcode, NewCst, ShAmt) != RHS) { // If we shifted bits out, the fold is not going to work out. // As a special case, check to see if this means that the // result is always true or false now. @@ -1242,12 +1246,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return ReplaceInstUsesWith(ICI, Builder->getTrue()); } else { ICI.setOperand(1, NewCst); - Constant *NewAndCST; - if (Shift->getOpcode() == Instruction::Shl) - NewAndCST = ConstantExpr::getLShr(AndCST, ShAmt); + Constant *NewAndCst; + if (ShiftOpcode == Instruction::Shl) + NewAndCst = ConstantExpr::getLShr(AndCst, ShAmt); else - NewAndCST = ConstantExpr::getShl(AndCST, ShAmt); - LHSI->setOperand(1, NewAndCST); + NewAndCst = ConstantExpr::getShl(AndCst, ShAmt); + LHSI->setOperand(1, NewAndCst); LHSI->setOperand(0, Shift->getOperand(0)); Worklist.Add(Shift); // Shift is dead. 
@@ -1264,10 +1268,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
        // Compute C << Y.
        Value *NS;
        if (Shift->getOpcode() == Instruction::LShr) {
-          NS = Builder->CreateShl(AndCST, Shift->getOperand(1));
+          NS = Builder->CreateShl(AndCst, Shift->getOperand(1));
        } else {
          // Insert a logical shift.
-          NS = Builder->CreateLShr(AndCST, Shift->getOperand(1));
+          NS = Builder->CreateLShr(AndCst, Shift->getOperand(1));
        }
 
        // Compute X & (C << Y).
@@ -1278,12 +1282,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
        return &ICI;
      }
 
-      // Replace ((X & AndCST) > RHSV) with ((X & AndCST) != 0), if any
-      // bit set in (X & AndCST) will produce a result greater than RHSV.
+      // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any
+      // bit set in (X & AndCst) will produce a result greater than RHSV.
      if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
-        unsigned NTZ = AndCST->getValue().countTrailingZeros();
-        if ((NTZ < AndCST->getBitWidth()) &&
-            APInt::getOneBitSet(AndCST->getBitWidth(), NTZ).ugt(RHSV))
+        unsigned NTZ = AndCst->getValue().countTrailingZeros();
+        if ((NTZ < AndCst->getBitWidth()) &&
+            APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV))
          return new ICmpInst(ICmpInst::ICMP_NE, LHSI,
                              Constant::getNullValue(RHS->getType()));
      }
@@ -1777,7 +1781,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
      }
    }
  }
-  return 0;
+  return nullptr;
}
 
 /// visitICmpInstWithCastAndCast - Handle icmp (cast x to y), (cast/cst).
@@ -1792,9 +1796,9 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
 
   // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
   // integer type is the same size as the pointer type.
-  if (TD && LHSCI->getOpcode() == Instruction::PtrToInt &&
-      TD->getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) {
-    Value *RHSOp = 0;
+  if (DL && LHSCI->getOpcode() == Instruction::PtrToInt &&
+      DL->getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) {
+    Value *RHSOp = nullptr;
    if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) {
      RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy);
    } else if (PtrToIntInst *RHSC = dyn_cast<PtrToIntInst>(ICI.getOperand(1))) {
@@ -1812,7 +1816,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
   // Enforce this.
   if (LHSCI->getOpcode() != Instruction::ZExt &&
       LHSCI->getOpcode() != Instruction::SExt)
-    return 0;
+    return nullptr;
 
   bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt;
   bool isSignedCmp = ICI.isSigned();
@@ -1821,12 +1825,12 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
    // Not an extension from the same type?
    RHSCIOp = CI->getOperand(0);
    if (RHSCIOp->getType() != LHSCIOp->getType())
-      return 0;
+      return nullptr;
 
    // If the signedness of the two casts doesn't agree (i.e. one is a sext
    // and the other is a zext), then we can't handle this.
    if (CI->getOpcode() != LHSCI->getOpcode())
-      return 0;
+      return nullptr;
 
    // Deal with equality cases early.
    if (ICI.isEquality())
@@ -1844,7 +1848,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
   // If we aren't dealing with a constant on the RHS, exit early
   ConstantInt *CI = dyn_cast<ConstantInt>(ICI.getOperand(1));
   if (!CI)
-    return 0;
+    return nullptr;
 
   // Compute the constant that would happen if we truncated to SrcTy then
   // reextended to DestTy.
@@ -1873,7 +1877,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
   // by SimplifyICmpInst, so only deal with the tricky case.
 
   if (isSignedCmp || !isSignedExt)
-    return 0;
+    return nullptr;
 
   // Evaluate the comparison for LT (we invert for GT below). LE and GE cases
   // should have been folded away previously and not enter in here.
@@ -1909,12 +1913,12 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // In order to eliminate the add-with-constant, the compare can be its only
   // use.
   Instruction *AddWithCst = cast<Instruction>(I.getOperand(0));
-  if (!AddWithCst->hasOneUse()) return 0;
+  if (!AddWithCst->hasOneUse()) return nullptr;
 
   // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow.
-  if (!CI2->getValue().isPowerOf2()) return 0;
+  if (!CI2->getValue().isPowerOf2()) return nullptr;
   unsigned NewWidth = CI2->getValue().countTrailingZeros();
-  if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return 0;
+  if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr;
 
   // The width of the new add formed is 1 more than the bias.
   ++NewWidth;
@@ -1922,7 +1926,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // Check to see that CI1 is an all-ones value with NewWidth bits.
   if (CI1->getBitWidth() == NewWidth ||
       CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
-    return 0;
+    return nullptr;
 
   // This is only really a signed overflow check if the inputs have been
   // sign-extended; check for that condition. For example, if CI2 is 2^31 and
@@ -1930,25 +1934,24 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
   if (IC.ComputeNumSignBits(A) < NeededSignBits ||
       IC.ComputeNumSignBits(B) < NeededSignBits)
-    return 0;
+    return nullptr;
 
   // In order to replace the original add with a narrower
   // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
   // and truncates that discard the high bits of the add.  Verify that this is
   // the case.
   Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0));
-  for (Value::use_iterator UI = OrigAdd->use_begin(), E = OrigAdd->use_end();
-       UI != E; ++UI) {
-    if (*UI == AddWithCst) continue;
+  for (User *U : OrigAdd->users()) {
+    if (U == AddWithCst) continue;
 
    // Only accept truncates for now.  We would really like a nice recursive
    // predicate like SimplifyDemandedBits, but which goes downwards the use-def
    // chain to see which bits of a value are actually demanded.  If the
    // original add had another add which was then immediately truncated, we
    // could still do the transformation.
-    TruncInst *TI = dyn_cast<TruncInst>(*UI);
-    if (TI == 0 ||
-        TI->getType()->getPrimitiveSizeInBits() > NewWidth) return 0;
+    TruncInst *TI = dyn_cast<TruncInst>(U);
+    if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth)
+      return nullptr;
  }
 
   // If the pattern matches, truncate the inputs to the narrower type and
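
Reviewer note (not part of the patch): ProcessUGT_ADDCST_ADD matches the classic overflow-safe addition idiom — sign-extend, add, bias by 2^(w-1), and compare unsigned against 2^w - 1 — and rewrites it to llvm.sadd.with.overflow. A brute-force C++ check of the idiom for w = 8:

    #include <cassert>
    #include <cstdint>

    // True iff a + b overflows int8_t. This is the source-level shape the hunk
    // above matches: icmp ugt (add (add (sext a), (sext b)), 128), 255.
    bool overflowsViaIdiom(int8_t a, int8_t b) {
      int32_t sum = int32_t(a) + int32_t(b);  // the wide add
      return uint32_t(sum + 128) > 255u;      // add-with-constant, then ugt
    }

    int main() {
      for (int a = -128; a <= 127; ++a)
        for (int b = -128; b <= 127; ++b) {
          int sum = a + b;
          bool overflows = sum < -128 || sum > 127;
          assert(overflowsViaIdiom(int8_t(a), int8_t(b)) == overflows);
        }
      return 0;
    }
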
@@ -1984,11 +1987,11 @@ static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV,
                                      InstCombiner &IC) {
   // Don't bother doing this transformation for pointers, don't do it for
   // vectors.
-  if (!isa<IntegerType>(OrigAddV->getType())) return 0;
+  if (!isa<IntegerType>(OrigAddV->getType())) return nullptr;
 
   // If the add is a constant expr, then we don't bother transforming it.
   Instruction *OrigAdd = dyn_cast<Instruction>(OrigAddV);
-  if (OrigAdd == 0) return 0;
+  if (!OrigAdd) return nullptr;
 
   Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1);
 
@@ -2009,6 +2012,240 @@ static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV,
   return ExtractValueInst::Create(Call, 1, "uadd.overflow");
 }
 
+/// \brief Recognize and process idiom involving test for multiplication
+/// overflow.
+///
+/// The caller has matched a pattern of the form:
+///   I = cmp u (mul(zext A, zext B), V
+/// The function checks if this is a test for overflow and if so replaces
+/// multiplication with call to 'mul.with.overflow' intrinsic.
+///
+/// \param I Compare instruction.
+/// \param MulVal Result of 'mult' instruction.  It is one of the arguments of
+///               the compare instruction.  Must be of integer type.
+/// \param OtherVal The other argument of compare instruction.
+/// \returns Instruction which must replace the compare instruction, NULL if no
+///          replacement required.
+static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
+                                         Value *OtherVal, InstCombiner &IC) {
+  // Don't bother doing this transformation for pointers, don't do it for
+  // vectors.
+  if (!isa<IntegerType>(MulVal->getType()))
+    return nullptr;
+
+  assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
+  assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
+  Instruction *MulInstr = cast<Instruction>(MulVal);
+  assert(MulInstr->getOpcode() == Instruction::Mul);
+
+  Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
+              *RHS = cast<Instruction>(MulInstr->getOperand(1));
+  assert(LHS->getOpcode() == Instruction::ZExt);
+  assert(RHS->getOpcode() == Instruction::ZExt);
+  Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
+
+  // Calculate type and width of the result produced by mul.with.overflow.
+  Type *TyA = A->getType(), *TyB = B->getType();
+  unsigned WidthA = TyA->getPrimitiveSizeInBits(),
+           WidthB = TyB->getPrimitiveSizeInBits();
+  unsigned MulWidth;
+  Type *MulType;
+  if (WidthB > WidthA) {
+    MulWidth = WidthB;
+    MulType = TyB;
+  } else {
+    MulWidth = WidthA;
+    MulType = TyA;
+  }
+
+  // In order to replace the original mul with a narrower mul.with.overflow,
+  // all uses must ignore upper bits of the product.  The number of used low
+  // bits must be not greater than the width of mul.with.overflow.
+  if (MulVal->hasNUsesOrMore(2))
+    for (User *U : MulVal->users()) {
+      if (U == &I)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        // Check if truncation ignores bits above MulWidth.
+        unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
+        if (TruncWidth > MulWidth)
+          return nullptr;
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        // Check if AND ignores bits above MulWidth.
+        if (BO->getOpcode() != Instruction::And)
+          return nullptr;
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+          const APInt &CVal = CI->getValue();
+          if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth)
+            return nullptr;
+        }
+      } else {
+        // Other uses prohibit this transformation.
+        return nullptr;
+      }
+    }
+
+  // Recognize patterns
+  switch (I.getPredicate()) {
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp eq/neq mulval, zext trunc mulval
+    if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal))
+      if (Zext->hasOneUse()) {
+        Value *ZextArg = Zext->getOperand(0);
+        if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg))
+          if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth)
+            break; //Recognized
+      }
+
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
+    ConstantInt *CI;
+    Value *ValToMask;
+    if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
+      if (ValToMask != MulVal)
+        return nullptr;
+      const APInt &CVal = CI->getValue() + 1;
+      if (CVal.isPowerOf2()) {
+        unsigned MaskWidth = CVal.logBase2();
+        if (MaskWidth == MulWidth)
+          break; // Recognized
+      }
+    }
+    return nullptr;
+
+  case ICmpInst::ICMP_UGT:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp ugt mulval, max
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getMaxValue(MulWidth);
+      MaxVal = MaxVal.zext(CI->getBitWidth());
+      if (MaxVal.eq(CI->getValue()))
+        break; // Recognized
+    }
+    return nullptr;
+
+  case ICmpInst::ICMP_UGE:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp uge mulval, max+1
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
+      if (MaxVal.eq(CI->getValue()))
+        break; // Recognized
+    }
+    return nullptr;
+
+  case ICmpInst::ICMP_ULE:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp ule mulval, max
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getMaxValue(MulWidth);
+      MaxVal = MaxVal.zext(CI->getBitWidth());
+      if (MaxVal.eq(CI->getValue()))
+        break; // Recognized
+    }
+    return nullptr;
+
+  case ICmpInst::ICMP_ULT:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp ule mulval, max + 1
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
+      if (MaxVal.eq(CI->getValue()))
+        break; // Recognized
+    }
+    return nullptr;
+
+  default:
+    return nullptr;
+  }
+
+  InstCombiner::BuilderTy *Builder = IC.Builder;
+  Builder->SetInsertPoint(MulInstr);
+  Module *M = I.getParent()->getParent()->getParent();
+
+  // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
+  Value *MulA = A, *MulB = B;
+  if (WidthA < MulWidth)
+    MulA = Builder->CreateZExt(A, MulType);
+  if (WidthB < MulWidth)
+    MulB = Builder->CreateZExt(B, MulType);
+  Value *F =
+      Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
+  CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul");
+  IC.Worklist.Add(MulInstr);
+
+  // If there are uses of mul result other than the comparison, we know that
+  // they are truncation or binary AND.  Change them to use result of
+  // mul.with.overflow and adjust properly mask/size.
+  if (MulVal->hasNUsesOrMore(2)) {
+    Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value");
+    for (User *U : MulVal->users()) {
+      if (U == &I || U == OtherVal)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
+          IC.ReplaceInstUsesWith(*TI, Mul);
+        else
+          TI->setOperand(0, Mul);
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        assert(BO->getOpcode() == Instruction::And);
+        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+        ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
+        APInt ShortMask = CI->getValue().trunc(MulWidth);
+        Value *ShortAnd = Builder->CreateAnd(Mul, ShortMask);
+        Instruction *Zext =
+            cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType()));
+        IC.Worklist.Add(Zext);
+        IC.ReplaceInstUsesWith(*BO, Zext);
+      } else {
+        llvm_unreachable("Unexpected Binary operation");
+      }
+      IC.Worklist.Add(cast<Instruction>(U));
+    }
+  }
+  if (isa<Instruction>(OtherVal))
+    IC.Worklist.Add(cast<Instruction>(OtherVal));
+
+  // The original icmp gets replaced with the overflow value, maybe inverted
+  // depending on predicate.
+  bool Inverse = false;
+  switch (I.getPredicate()) {
+  case ICmpInst::ICMP_NE:
+    break;
+  case ICmpInst::ICMP_EQ:
+    Inverse = true;
+    break;
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+    if (I.getOperand(0) == MulVal)
+      break;
+    Inverse = true;
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    if (I.getOperand(1) == MulVal)
+      break;
+    Inverse = true;
+    break;
+  default:
+    llvm_unreachable("Unexpected predicate");
+  }
+  if (Inverse) {
+    Value *Res = Builder->CreateExtractValue(Call, 1);
+    return BinaryOperator::CreateNot(Res);
+  }
+
+  return ExtractValueInst::Create(Call, 1);
+}
+
 // DemandedBitsLHSMask - When performing a comparison against a constant,
 // it is possible that not all the bits in the LHS are demanded. This helper
 // method computes the mask that IS demanded.
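
Reviewer note (not part of the patch): the new ProcessUMulZExtIdiom added above recognizes the widened-multiply overflow test and rewrites it to llvm.umul.with.overflow. The C++ source shape it targets looks roughly like this (values hypothetical):

    #include <cassert>
    #include <cstdint>

    // The idiom: widen both operands, multiply, and test whether the product
    // exceeds the narrow type's maximum.
    bool mulOverflows(uint32_t a, uint32_t b) {
      uint64_t wide = uint64_t(a) * uint64_t(b);  // mul(zext a, zext b)
      return wide > 0xFFFFFFFFull;                // icmp ugt mulval, max
    }

    int main() {
      assert(!mulOverflows(65536u, 65535u));  // 0xFFFF0000 still fits in 32 bits
      assert(mulOverflows(65536u, 65536u));   // 2^32 does not
      // After the transformation, the wide multiply becomes
      // llvm.umul.with.overflow.i32 and the compare reads its overflow bit.
      return 0;
    }
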
@@ -2048,7 +2285,7 @@ static APInt DemandedBitsLHSMask(ICmpInst &I,
 
 /// \brief Check if the order of \p Op0 and \p Op1 as operand in an ICmpInst
 /// should be swapped.
-/// The descision is based on how many times these two operands are reused
+/// The decision is based on how many times these two operands are reused
 /// as subtract operands and their positions in those instructions.
 /// The rational is that several architectures use the same instruction for
 /// both subtract and cmp, thus it is better if the order of those operands
@@ -2064,12 +2301,12 @@ static bool swapMayExposeCSEOpportunities(const Value * Op0,
   // Each time Op0 is the first operand, count -1: swapping is bad, the
   // subtract has already the same layout as the compare.
   // Each time Op0 is the second operand, count +1: swapping is good, the
-  // subtract has a diffrent layout as the compare.
+  // subtract has a different layout as the compare.
   // At the end, if the benefit is greater than 0, Op0 should come second to
   // expose more CSE opportunities.
   int GlobalSwapBenefits = 0;
-  for (Value::const_use_iterator UI = Op0->use_begin(), UIEnd = Op0->use_end(); UI != UIEnd; ++UI) {
-    const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(*UI);
+  for (const User *U : Op0->users()) {
+    const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(U);
     if (!BinOp || BinOp->getOpcode() != Instruction::Sub)
       continue;
     // If Op0 is the first argument, this is not beneficial to swap the
@@ -2104,7 +2341,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     Changed = true;
   }
 
-  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, TD))
+  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL))
     return ReplaceInstUsesWith(I, V);
 
   // comparing -val or val with non-zero is the same as just comparing val
@@ -2172,14 +2409,14 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   unsigned BitWidth = 0;
   if (Ty->isIntOrIntVectorTy())
     BitWidth = Ty->getScalarSizeInBits();
-  else if (TD)  // Pointers require TD info to get their size.
-    BitWidth = TD->getTypeSizeInBits(Ty->getScalarType());
+  else if (DL)  // Pointers require DL info to get their size.
+    BitWidth = DL->getTypeSizeInBits(Ty->getScalarType());
 
   bool isSignBit = false;
 
   // See if we are doing a comparison with a constant.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-    Value *A = 0, *B = 0;
+    Value *A = nullptr, *B = nullptr;
 
     // Match the following pattern, which is a common idiom when writing
     // overflow-safe integer arithmetic function.  The source performs an
@@ -2292,21 +2529,29 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
      // bit is set. If the comparison is against zero, then this is a check
      // to see if *that* bit is set.
      APInt Op0KnownZeroInverted = ~Op0KnownZero;
-      if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) {
+      if (~Op1KnownZero == 0) {
        // If the LHS is an AND with the same constant, look through it.
-        Value *LHS = 0;
-        ConstantInt *LHSC = 0;
+        Value *LHS = nullptr;
+        ConstantInt *LHSC = nullptr;
        if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
            LHSC->getValue() != Op0KnownZeroInverted)
          LHS = Op0;
 
        // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
        // then turn "((1 << x)&8) == 0" into "x != 3".
-        Value *X = 0;
+        // or turn "((1 << x)&7) == 0" into "x > 2".
+        Value *X = nullptr;
        if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
-          unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros();
-          return new ICmpInst(ICmpInst::ICMP_NE, X,
-                              ConstantInt::get(X->getType(), CmpVal));
+          APInt ValToCheck = Op0KnownZeroInverted;
+          if (ValToCheck.isPowerOf2()) {
+            unsigned CmpVal = ValToCheck.countTrailingZeros();
+            return new ICmpInst(ICmpInst::ICMP_NE, X,
+                                ConstantInt::get(X->getType(), CmpVal));
+          } else if ((++ValToCheck).isPowerOf2()) {
+            unsigned CmpVal = ValToCheck.countTrailingZeros() - 1;
+            return new ICmpInst(ICmpInst::ICMP_UGT, X,
+                                ConstantInt::get(X->getType(), CmpVal));
+          }
        }
 
        // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
@@ -2329,21 +2574,29 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
      // bit is set. If the comparison is against zero, then this is a check
      // to see if *that* bit is set.
      APInt Op0KnownZeroInverted = ~Op0KnownZero;
-      if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) {
+      if (~Op1KnownZero == 0) {
        // If the LHS is an AND with the same constant, look through it.
-        Value *LHS = 0;
-        ConstantInt *LHSC = 0;
+        Value *LHS = nullptr;
+        ConstantInt *LHSC = nullptr;
        if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) ||
            LHSC->getValue() != Op0KnownZeroInverted)
          LHS = Op0;
 
        // If the LHS is 1 << x, and we know the result is a power of 2 like 8,
        // then turn "((1 << x)&8) != 0" into "x == 3".
-        Value *X = 0;
+        // or turn "((1 << x)&7) != 0" into "x < 3".
+        Value *X = nullptr;
        if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
-          unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros();
-          return new ICmpInst(ICmpInst::ICMP_EQ, X,
-                              ConstantInt::get(X->getType(), CmpVal));
+          APInt ValToCheck = Op0KnownZeroInverted;
+          if (ValToCheck.isPowerOf2()) {
+            unsigned CmpVal = ValToCheck.countTrailingZeros();
+            return new ICmpInst(ICmpInst::ICMP_EQ, X,
+                                ConstantInt::get(X->getType(), CmpVal));
+          } else if ((++ValToCheck).isPowerOf2()) {
+            unsigned CmpVal = ValToCheck.countTrailingZeros();
+            return new ICmpInst(ICmpInst::ICMP_ULT, X,
+                                ConstantInt::get(X->getType(), CmpVal));
+          }
        }
 
        // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
@@ -2468,10 +2721,10 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   // operands has at least one user besides the compare (the select),
   // which would often largely negate the benefit of folding anyway.
   if (I.hasOneUse())
-    if (SelectInst *SI = dyn_cast<SelectInst>(*I.use_begin()))
+    if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin()))
      if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
          (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
-        return 0;
+        return nullptr;
 
   // See if we are doing a comparison between a constant and an instruction that
   // can be folded into the comparison.
@@ -2507,7 +2760,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
        // If either operand of the select is a constant, we can fold the
        // comparison into the select arms, which will cause one to be
        // constant folded and the select turned into a bitwise or.
-        Value *Op1 = 0, *Op2 = 0;
+        Value *Op1 = nullptr, *Op2 = nullptr;
        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1)))
          Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2)))
@@ -2532,8 +2785,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
      }
    case Instruction::IntToPtr:
      // icmp pred inttoptr(X), null -> icmp pred X, 0
-      if (RHSC->isNullValue() && TD &&
-          TD->getIntPtrType(RHSC->getType()) ==
+      if (RHSC->isNullValue() && DL &&
+          DL->getIntPtrType(RHSC->getType()) ==
              LHSI->getOperand(0)->getType())
        return new ICmpInst(I.getPredicate(), LHSI->getOperand(0),
                        Constant::getNullValue(LHSI->getOperand(0)->getType()));
@@ -2619,7 +2872,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
 
   // Analyze the case when either Op0 or Op1 is an add instruction.
   // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null).
-  Value *A = 0, *B = 0, *C = 0, *D = 0;
+  Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
   if (BO0 && BO0->getOpcode() == Instruction::Add)
     A = BO0->getOperand(0), B = BO0->getOperand(1);
   if (BO1 && BO1->getOpcode() == Instruction::Add)
     C = BO1->getOperand(0), D = BO1->getOperand(1);
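
Reviewer note (not part of the patch): the two hunks above extend the (1 << x) & mask folds from power-of-2 masks to masks that are one less than a power of 2, per the new comments: "((1 << x)&7) == 0" becomes "x > 2", and the != 0 form becomes "x < 3". An exhaustive check of both forms:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t x = 0; x != 32; ++x) {
        uint32_t bit = uint32_t(1) << x;
        // Power-of-2 mask:   ((1 << x) & 8) != 0  <=>  x == 3
        assert(((bit & 8u) != 0) == (x == 3u));
        // Mask+1 power of 2: ((1 << x) & 7) != 0  <=>  x < 3
        assert(((bit & 7u) != 0) == (x < 3u));
        // ...and the == 0 forms invert to x != 3 and x > 2 respectively.
        assert(((bit & 7u) == 0) == (x > 2u));
      }
      return 0;
    }
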
@@ -2714,7 +2967,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
 
   // Analyze the case when either Op0 or Op1 is a sub instruction.
   // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null).
-  A = 0; B = 0; C = 0; D = 0;
+  A = nullptr; B = nullptr; C = nullptr; D = nullptr;
   if (BO0 && BO0->getOpcode() == Instruction::Sub)
     A = BO0->getOperand(0), B = BO0->getOperand(1);
   if (BO1 && BO1->getOpcode() == Instruction::Sub)
@@ -2740,7 +2993,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       BO0->hasOneUse() && BO1->hasOneUse())
     return new ICmpInst(Pred, D, B);
 
-  BinaryOperator *SRem = NULL;
+  // icmp (0-X) < cst --> x > -cst
+  if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) {
+    Value *X;
+    if (match(BO0, m_Neg(m_Value(X))))
+      if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1))
+        if (!RHSC->isMinValue(/*isSigned=*/true))
+          return new ICmpInst(I.getSwappedPredicate(), X,
+                              ConstantExpr::getNeg(RHSC));
+  }
+
+  BinaryOperator *SRem = nullptr;
   // icmp (srem X, Y), Y
   if (BO0 && BO0->getOpcode() == Instruction::SRem &&
       Op1 == BO0->getOperand(1))
@@ -2878,6 +3141,16 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         (Op0 == A || Op0 == B))
       if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
         return R;
+
+    // (zext a) * (zext b)  --> llvm.umul.with.overflow.
+    if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+      if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this))
+        return R;
+    }
+    if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+      if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this))
+        return R;
+    }
   }
 
   if (I.isEquality()) {
@@ -2919,7 +3192,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     // (X&Z) == (Y&Z) -> (X^Y) & Z == 0
     if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) &&
         match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) {
-      Value *X = 0, *Y = 0, *Z = 0;
+      Value *X = nullptr, *Y = nullptr, *Z = nullptr;
 
       if (A == C) {
         X = B; Y = D; Z = A;
@@ -3010,7 +3283,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X)
       return FoldICmpAddOpCst(I, X, Cst, I.getSwappedPredicate());
   }
-  return Changed ? &I : 0;
+  return Changed ? &I : nullptr;
 }
 
 /// FoldFCmp_IntToFP_Cst - Fold fcmp ([us]itofp x, cst) if possible.
 ///
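
Reviewer note (not part of the patch): visitICmpInst, closed out above, gains the fold icmp (0-X) < cst --> X > -cst, guarded so that cst is not the minimum signed value (negating INT_MIN would itself overflow). A small 8-bit check with a hypothetical constant:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int8_t C = 5;  // any constant except INT8_MIN, which the fold excludes
      for (int x = -127; x <= 127; ++x) {  // X = -128 is excluded: 0 - X wraps
        int8_t X = int8_t(x);
        // icmp slt (0 - X), C   <=>   icmp sgt X, -C  (note the swapped predicate)
        assert((int8_t(0 - X) < C) == (X > int8_t(-C)));
      }
      return 0;
    }
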
@@ -3018,13 +3291,13 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
 Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I,
                                                 Instruction *LHSI,
                                                 Constant *RHSC) {
-  if (!isa<ConstantFP>(RHSC)) return 0;
+  if (!isa<ConstantFP>(RHSC)) return nullptr;
   const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF();
 
   // Get the width of the mantissa.  We don't want to hack on conversions that
   // might lose information from the integer, e.g. "i64 -> float"
   int MantissaWidth = LHSI->getType()->getFPMantissaWidth();
-  if (MantissaWidth == -1) return 0;  // Unknown.
+  if (MantissaWidth == -1) return nullptr;  // Unknown.
 
   // Check to see that the input is converted from an integer type that is small
   // enough that preserves all bits.  TODO: check here for "known" sign bits.
@@ -3038,7 +3311,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I,
 
   // If the conversion would lose info, don't hack on this.
   if ((int)InputSize > MantissaWidth)
-    return 0;
+    return nullptr;
 
   // Otherwise, we can potentially simplify the comparison.  We know that it
   // will always come through as an integer value and we know the constant is
@@ -3229,7 +3502,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, TD))
+  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL))
     return ReplaceInstUsesWith(I, V);
 
   // Simplify 'fcmp pred X, X'
@@ -3313,31 +3586,6 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
         if (Instruction *NV = FoldFCmp_IntToFP_Cst(I, LHSI, RHSC))
           return NV;
         break;
-      case Instruction::Select: {
-        // If either operand of the select is a constant, we can fold the
-        // comparison into the select arms, which will cause one to be
-        // constant folded and the select turned into a bitwise or.
-        Value *Op1 = 0, *Op2 = 0;
-        if (LHSI->hasOneUse()) {
-          if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) {
-            // Fold the known value into the constant operand.
-            Op1 = ConstantExpr::getCompare(I.getPredicate(), C, RHSC);
-            // Insert a new FCmp of the other select operand.
-            Op2 = Builder->CreateFCmp(I.getPredicate(),
-                                      LHSI->getOperand(2), RHSC, I.getName());
-          } else if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) {
-            // Fold the known value into the constant operand.
-            Op2 = ConstantExpr::getCompare(I.getPredicate(), C, RHSC);
-            // Insert a new FCmp of the other select operand.
-            Op1 = Builder->CreateFCmp(I.getPredicate(), LHSI->getOperand(1),
-                                      RHSC, I.getName());
-          }
-        }
-
-        if (Op1)
-          return SelectInst::Create(LHSI->getOperand(0), Op1, Op2);
-        break;
-      }
       case Instruction::FSub: {
         // fcmp pred (fneg x), C -> fcmp swap(pred) x, -C
         Value *Op;
@@ -3409,5 +3657,5 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
       return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0),
                           RHSExt->getOperand(0));
 
-  return Changed ? &I : 0;
+  return Changed ? &I : nullptr;
 }
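
Reviewer note (not part of the patch): FoldFCmp_IntToFP_Cst bails out when the integer is wider than the destination's mantissa, because the conversion may round and an integer-side fold would then be wrong. A short C++ illustration with float's 24-bit mantissa:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Any i32 magnitude up to 2^24 converts to float exactly, so
      // "fcmp oeq (sitofp x), K" can be evaluated on the integer side.
      assert(float(1 << 24) == 16777216.0f);
      // One past that, the conversion rounds: 16777217 becomes 16777216.0f,
      // so a naive integer-side fold would be wrong -- hence the
      // MantissaWidth check above.
      int32_t x = (1 << 24) + 1;
      assert(float(x) == 16777216.0f);
      return 0;
    }
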