diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineAddSub.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAddSub.cpp | 350 |
1 files changed, 158 insertions, 192 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 4d2c89e..c36a955 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -84,43 +84,37 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) { } Instruction *InstCombiner::visitAdd(BinaryOperator &I) { - bool Changed = SimplifyCommutative(I); + bool Changed = SimplifyAssociativeOrCommutative(I); Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), TD)) return ReplaceInstUsesWith(I, V); - - if (Constant *RHSC = dyn_cast<Constant>(RHS)) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHSC)) { - // X + (signbit) --> X ^ signbit - const APInt& Val = CI->getValue(); - uint32_t BitWidth = Val.getBitWidth(); - if (Val == APInt::getSignBit(BitWidth)) - return BinaryOperator::CreateXor(LHS, RHS); - - // See if SimplifyDemandedBits can simplify this. This handles stuff like - // (X & 254)+1 -> (X&254)|1 - if (SimplifyDemandedInstructionBits(I)) - return &I; - - // zext(bool) + C -> bool ? C + 1 : C - if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) - if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext())) - return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI); - } + // (A*B)+(A*C) -> A*(B+C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return ReplaceInstUsesWith(I, V); - if (isa<PHINode>(LHS)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + // X + (signbit) --> X ^ signbit + const APInt &Val = CI->getValue(); + if (Val.isSignBit()) + return BinaryOperator::CreateXor(LHS, RHS); + + // See if SimplifyDemandedBits can simplify this. This handles stuff like + // (X & 254)+1 -> (X&254)|1 + if (SimplifyDemandedInstructionBits(I)) + return &I; + + // zext(bool) + C -> bool ? C + 1 : C + if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) + if (ZI->getSrcTy()->isIntegerTy(1)) + return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI); - ConstantInt *XorRHS = 0; - Value *XorLHS = 0; - if (isa<ConstantInt>(RHSC) && - match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { + Value *XorLHS = 0; ConstantInt *XorRHS = 0; + if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { uint32_t TySizeBits = I.getType()->getScalarSizeInBits(); - const APInt& RHSVal = cast<ConstantInt>(RHSC)->getValue(); + const APInt &RHSVal = CI->getValue(); unsigned ExtendAmt = 0; // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext. @@ -130,13 +124,13 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { else if (XorRHS->getValue().isPowerOf2()) ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1; } - + if (ExtendAmt) { APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt); if (!MaskedValueIsZero(XorLHS, Mask)) ExtendAmt = 0; } - + if (ExtendAmt) { Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt); Value *NewShl = Builder->CreateShl(XorLHS, ShAmt, "sext"); @@ -145,34 +139,28 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } + if (isa<Constant>(RHS) && isa<PHINode>(LHS)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + if (I.getType()->isIntegerTy(1)) return BinaryOperator::CreateXor(LHS, RHS); - if (I.getType()->isIntegerTy()) { - // X + X --> X << 1 - if (LHS == RHS) - return BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1)); - - if (Instruction *RHSI = dyn_cast<Instruction>(RHS)) { - if (RHSI->getOpcode() == Instruction::Sub) - if (LHS == RHSI->getOperand(1)) // A + (B - A) --> B - return ReplaceInstUsesWith(I, RHSI->getOperand(0)); - } - if (Instruction *LHSI = dyn_cast<Instruction>(LHS)) { - if (LHSI->getOpcode() == Instruction::Sub) - if (RHS == LHSI->getOperand(1)) // (B - A) + A --> B - return ReplaceInstUsesWith(I, LHSI->getOperand(0)); - } + // X + X --> X << 1 + if (LHS == RHS) { + BinaryOperator *New = + BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1)); + New->setHasNoSignedWrap(I.hasNoSignedWrap()); + New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return New; } // -A + B --> B - A // -A + -B --> -(A + B) if (Value *LHSV = dyn_castNegVal(LHS)) { - if (LHS->getType()->isIntOrIntVectorTy()) { - if (Value *RHSV = dyn_castNegVal(RHS)) { - Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum"); - return BinaryOperator::CreateNeg(NewAdd); - } + if (Value *RHSV = dyn_castNegVal(RHS)) { + Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum"); + return BinaryOperator::CreateNeg(NewAdd); } return BinaryOperator::CreateSub(RHS, LHSV); @@ -199,11 +187,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (dyn_castFoldableMul(RHS, C2) == LHS) return BinaryOperator::CreateMul(LHS, AddOne(C2)); - // X + ~X --> -1 since ~X = -X-1 - if (match(LHS, m_Not(m_Specific(RHS))) || - match(RHS, m_Not(m_Specific(LHS)))) - return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType())); - // A+B --> A|B iff A and B have no bits set in common. if (const IntegerType *IT = dyn_cast<IntegerType>(I.getType())) { APInt Mask = APInt::getAllOnesValue(IT->getBitWidth()); @@ -222,7 +205,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } // W*X + Y*Z --> W * (X+Z) iff W == Y - if (I.getType()->isIntOrIntVectorTy()) { + { Value *W, *X, *Y, *Z; if (match(LHS, m_Mul(m_Value(W), m_Value(X))) && match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) { @@ -251,24 +234,22 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (X & FF00) + xx00 -> (X+xx00) & FF00 if (LHS->hasOneUse() && - match(LHS, m_And(m_Value(X), m_ConstantInt(C2)))) { - Constant *Anded = ConstantExpr::getAnd(CRHS, C2); - if (Anded == CRHS) { - // See if all bits from the first bit set in the Add RHS up are included - // in the mask. First, get the rightmost bit. - const APInt &AddRHSV = CRHS->getValue(); - - // Form a mask of all bits from the lowest bit added through the top. - APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); - - // See if the and mask includes all of these bits. - APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); - - if (AddRHSHighBits == AddRHSHighBitsAnd) { - // Okay, the xform is safe. Insert the new add pronto. - Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName()); - return BinaryOperator::CreateAnd(NewAdd, C2); - } + match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) && + CRHS->getValue() == (CRHS->getValue() & C2->getValue())) { + // See if all bits from the first bit set in the Add RHS up are included + // in the mask. First, get the rightmost bit. + const APInt &AddRHSV = CRHS->getValue(); + + // Form a mask of all bits from the lowest bit added through the top. + APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); + + // See if the and mask includes all of these bits. + APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); + + if (AddRHSHighBits == AddRHSHighBitsAnd) { + // Okay, the xform is safe. Insert the new add pronto. + Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName()); + return BinaryOperator::CreateAnd(NewAdd, C2); } } @@ -293,12 +274,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // Can we fold the add into the argument of the select? // We check both true and false select arguments for a matching subtract. - if (match(FV, m_Zero()) && - match(TV, m_Sub(m_Value(N), m_Specific(A)))) + if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A)))) // Fold the add into the true select value. return SelectInst::Create(SI->getCondition(), N, A); - if (match(TV, m_Zero()) && - match(FV, m_Sub(m_Value(N), m_Specific(A)))) + + if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A)))) // Fold the add into the false select value. return SelectInst::Create(SI->getCondition(), A, N); } @@ -342,7 +322,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { - bool Changed = SimplifyCommutative(I); + bool Changed = SimplifyAssociativeOrCommutative(I); Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Constant *RHSC = dyn_cast<Constant>(RHS)) { @@ -424,6 +404,10 @@ Value *InstCombiner::EmitGEPOffset(User *GEP) { const Type *IntPtrTy = TD.getIntPtrType(GEP->getContext()); Value *Result = Constant::getNullValue(IntPtrTy); + // If the GEP is inbounds, we know that none of the addressing operations will + // overflow in an unsigned sense. + bool isInBounds = cast<GEPOperator>(GEP)->isInBounds(); + // Build a mask for high order bits. unsigned IntPtrWidth = TD.getPointerSizeInBits(); uint64_t PtrSizeMask = ~0ULL >> (64-IntPtrWidth); @@ -439,16 +423,16 @@ Value *InstCombiner::EmitGEPOffset(User *GEP) { if (const StructType *STy = dyn_cast<StructType>(*GTI)) { Size = TD.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); - Result = Builder->CreateAdd(Result, - ConstantInt::get(IntPtrTy, Size), - GEP->getName()+".offs"); + if (Size) + Result = Builder->CreateAdd(Result, ConstantInt::get(IntPtrTy, Size), + GEP->getName()+".offs"); continue; } Constant *Scale = ConstantInt::get(IntPtrTy, Size); Constant *OC = ConstantExpr::getIntegerCast(OpC, IntPtrTy, true /*SExt*/); - Scale = ConstantExpr::getMul(OC, Scale); + Scale = ConstantExpr::getMul(OC, Scale, isInBounds/*NUW*/); // Emit an add instruction. Result = Builder->CreateAdd(Result, Scale, GEP->getName()+".offs"); continue; @@ -457,9 +441,9 @@ Value *InstCombiner::EmitGEPOffset(User *GEP) { if (Op->getType() != IntPtrTy) Op = Builder->CreateIntCast(Op, IntPtrTy, true, Op->getName()+".c"); if (Size != 1) { - Constant *Scale = ConstantInt::get(IntPtrTy, Size); // We'll let instcombine(mul) convert this to a shl if possible. - Op = Builder->CreateMul(Op, Scale, GEP->getName()+".idx"); + Op = Builder->CreateMul(Op, ConstantInt::get(IntPtrTy, Size), + GEP->getName()+".idx", isInBounds /*NUW*/); } // Emit an add instruction. @@ -545,8 +529,13 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Op0 == Op1) // sub X, X -> 0 - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), + I.hasNoUnsignedWrap(), TD)) + return ReplaceInstUsesWith(I, V); + + // (A*B)-(A*C) -> A*(B-C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return ReplaceInstUsesWith(I, V); // If this is a 'B = x-(-A)', change to B = x+A. This preserves NSW/NUW. if (Value *V = dyn_castNegVal(Op1)) { @@ -556,18 +545,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return Res; } - if (isa<UndefValue>(Op0)) - return ReplaceInstUsesWith(I, Op0); // undef - X -> undef - if (isa<UndefValue>(Op1)) - return ReplaceInstUsesWith(I, Op1); // X - undef -> undef if (I.getType()->isIntegerTy(1)) return BinaryOperator::CreateXor(Op0, Op1); + + // Replace (-1 - A) with (~A). + if (match(Op0, m_AllOnes())) + return BinaryOperator::CreateNot(Op1); if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) { - // Replace (-1 - A) with (~A). - if (C->isAllOnesValue()) - return BinaryOperator::CreateNot(Op1); - // C - ~X == X + (1+C) Value *X = 0; if (match(Op1, m_Not(m_Value(X)))) @@ -576,29 +561,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // -(X >>u 31) -> (X >>s 31) // -(X >>s 31) -> (X >>u 31) if (C->isZero()) { - if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op1)) { - if (SI->getOpcode() == Instruction::LShr) { - if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) { - // Check to see if we are shifting out everything but the sign bit. - if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == - SI->getType()->getPrimitiveSizeInBits()-1) { - // Ok, the transformation is safe. Insert AShr. - return BinaryOperator::Create(Instruction::AShr, - SI->getOperand(0), CU, SI->getName()); - } - } - } else if (SI->getOpcode() == Instruction::AShr) { - if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) { - // Check to see if we are shifting out everything but the sign bit. - if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) == - SI->getType()->getPrimitiveSizeInBits()-1) { - // Ok, the transformation is safe. Insert LShr. - return BinaryOperator::CreateLShr( - SI->getOperand(0), CU, SI->getName()); - } - } - } - } + Value *X; ConstantInt *CI; + if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) && + // Verify we are shifting out everything but the sign bit. + CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + return BinaryOperator::CreateAShr(X, CI); + + if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) && + // Verify we are shifting out everything but the sign bit. + CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + return BinaryOperator::CreateLShr(X, CI); } // Try to fold constant sub into select arguments. @@ -608,86 +580,80 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // C - zext(bool) -> bool ? C - 1 : C if (ZExtInst *ZI = dyn_cast<ZExtInst>(Op1)) - if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext())) + if (ZI->getSrcTy()->isIntegerTy(1)) return SelectInst::Create(ZI->getOperand(0), SubOne(C), C); + + // C-(X+C2) --> (C-C2)-X + ConstantInt *C2; + if (match(Op1, m_Add(m_Value(X), m_ConstantInt(C2)))) + return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); } - if (BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1)) { - if (Op1I->getOpcode() == Instruction::Add) { - if (Op1I->getOperand(0) == Op0) // X-(X+Y) == -Y - return BinaryOperator::CreateNeg(Op1I->getOperand(1), - I.getName()); - else if (Op1I->getOperand(1) == Op0) // X-(Y+X) == -Y - return BinaryOperator::CreateNeg(Op1I->getOperand(0), - I.getName()); - else if (ConstantInt *CI1 = dyn_cast<ConstantInt>(I.getOperand(0))) { - if (ConstantInt *CI2 = dyn_cast<ConstantInt>(Op1I->getOperand(1))) - // C1-(X+C2) --> (C1-C2)-X - return BinaryOperator::CreateSub( - ConstantExpr::getSub(CI1, CI2), Op1I->getOperand(0)); - } + + { Value *Y; + // X-(X+Y) == -Y X-(Y+X) == -Y + if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) || + match(Op1, m_Add(m_Value(Y), m_Specific(Op0)))) + return BinaryOperator::CreateNeg(Y); + + // (X-Y)-X == -Y + if (match(Op0, m_Sub(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateNeg(Y); + } + + if (Op1->hasOneUse()) { + Value *X = 0, *Y = 0, *Z = 0; + Constant *C = 0; + ConstantInt *CI = 0; + + // (X - (Y - Z)) --> (X + (Z - Y)). + if (match(Op1, m_Sub(m_Value(Y), m_Value(Z)))) + return BinaryOperator::CreateAdd(Op0, + Builder->CreateSub(Z, Y, Op1->getName())); + + // (X - (X & Y)) --> (X & ~Y) + // + if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) || + match(Op1, m_And(m_Specific(Op0), m_Value(Y)))) + return BinaryOperator::CreateAnd(Op0, + Builder->CreateNot(Y, Y->getName() + ".not")); + + // 0 - (X sdiv C) -> (X sdiv -C) + if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && + match(Op0, m_Zero())) + return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C)); + + // 0 - (X << Y) -> (-X << Y) when X is freely negatable. + if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero())) + if (Value *XNeg = dyn_castNegVal(X)) + return BinaryOperator::CreateShl(XNeg, Y); + + // X - X*C --> X * (1-C) + if (match(Op1, m_Mul(m_Specific(Op0), m_ConstantInt(CI)))) { + Constant *CP1 = ConstantExpr::getSub(ConstantInt::get(I.getType(),1), CI); + return BinaryOperator::CreateMul(Op0, CP1); } - if (Op1I->hasOneUse()) { - // Replace (x - (y - z)) with (x + (z - y)) if the (y - z) subexpression - // is not used by anyone else... - // - if (Op1I->getOpcode() == Instruction::Sub) { - // Swap the two operands of the subexpr... - Value *IIOp0 = Op1I->getOperand(0), *IIOp1 = Op1I->getOperand(1); - Op1I->setOperand(0, IIOp1); - Op1I->setOperand(1, IIOp0); - - // Create the new top level add instruction... - return BinaryOperator::CreateAdd(Op0, Op1); - } - - // Replace (A - (A & B)) with (A & ~B) if this is the only use of (A&B)... - // - if (Op1I->getOpcode() == Instruction::And && - (Op1I->getOperand(0) == Op0 || Op1I->getOperand(1) == Op0)) { - Value *OtherOp = Op1I->getOperand(Op1I->getOperand(0) == Op0); - - Value *NewNot = Builder->CreateNot(OtherOp, "B.not"); - return BinaryOperator::CreateAnd(Op0, NewNot); - } - - // 0 - (X sdiv C) -> (X sdiv -C) - if (Op1I->getOpcode() == Instruction::SDiv) - if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0)) - if (CSI->isZero()) - if (Constant *DivRHS = dyn_cast<Constant>(Op1I->getOperand(1))) - return BinaryOperator::CreateSDiv(Op1I->getOperand(0), - ConstantExpr::getNeg(DivRHS)); - - // 0 - (C << X) -> (-C << X) - if (Op1I->getOpcode() == Instruction::Shl) - if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0)) - if (CSI->isZero()) - if (Value *ShlLHSNeg = dyn_castNegVal(Op1I->getOperand(0))) - return BinaryOperator::CreateShl(ShlLHSNeg, Op1I->getOperand(1)); - - // X - X*C --> X * (1-C) - ConstantInt *C2 = 0; - if (dyn_castFoldableMul(Op1I, C2) == Op0) { - Constant *CP1 = - ConstantExpr::getSub(ConstantInt::get(I.getType(), 1), - C2); - return BinaryOperator::CreateMul(Op0, CP1); - } + // X - X<<C --> X * (1-(1<<C)) + if (match(Op1, m_Shl(m_Specific(Op0), m_ConstantInt(CI)))) { + Constant *One = ConstantInt::get(I.getType(), 1); + C = ConstantExpr::getSub(One, ConstantExpr::getShl(One, CI)); + return BinaryOperator::CreateMul(Op0, C); } - } - - if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - if (Op0I->getOpcode() == Instruction::Add) { - if (Op0I->getOperand(0) == Op1) // (Y+X)-Y == X - return ReplaceInstUsesWith(I, Op0I->getOperand(1)); - else if (Op0I->getOperand(1) == Op1) // (X+Y)-Y == X - return ReplaceInstUsesWith(I, Op0I->getOperand(0)); - } else if (Op0I->getOpcode() == Instruction::Sub) { - if (Op0I->getOperand(0) == Op1) // (X-Y)-X == -Y - return BinaryOperator::CreateNeg(Op0I->getOperand(1), - I.getName()); + + // X - A*-B -> X + A*B + // X - -A*B -> X + A*B + Value *A, *B; + if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) || + match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B)))) + return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B)); + + // X - A*CI -> X + A*-CI + // X - CI*A -> X + A*-CI + if (match(Op1, m_Mul(m_Value(A), m_ConstantInt(CI))) || + match(Op1, m_Mul(m_ConstantInt(CI), m_Value(A)))) { + Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI)); + return BinaryOperator::CreateAdd(Op0, NewMul); } } |