diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r-- | contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 173 |
1 files changed, 125 insertions, 48 deletions
diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 788097f..45a19fb 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -48,8 +48,8 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, BinaryOperator *I = dyn_cast<BinaryOperator>(V); if (I && I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, - IC.getAssumptionCache(), &CxtI, - IC.getDominatorTree())) { + &IC.getAssumptionCache(), &CxtI, + &IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { @@ -179,7 +179,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) @@ -267,14 +267,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { - // Try to fold constant mul into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) + return FoldedMul; // Canonicalize (X+C1)*CI -> X*CI+C1*CI. { @@ -389,6 +383,80 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + // Check for (mul (sext x), y), see if we can merge this into an + // integer mul followed by a sext. + if (SExtInst *Op0Conv = dyn_cast<SExtInst>(Op0)) { + // (mul (sext x), cst) --> (sext (mul x, cst')) + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + if (Op0Conv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); + if (ConstantExpr::getSExt(CI, I.getType()) == Op1C && + WillNotOverflowSignedMul(Op0Conv->getOperand(0), CI, I)) { + // Insert the new, smaller mul. + Value *NewMul = + Builder->CreateNSWMul(Op0Conv->getOperand(0), CI, "mulconv"); + return new SExtInst(NewMul, I.getType()); + } + } + } + + // (mul (sext x), (sext y)) --> (sext (mul int x, y)) + if (SExtInst *Op1Conv = dyn_cast<SExtInst>(Op1)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of sexts), and if the + // integer mul will not overflow. + if (Op0Conv->getOperand(0)->getType() == + Op1Conv->getOperand(0)->getType() && + (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && + WillNotOverflowSignedMul(Op0Conv->getOperand(0), + Op1Conv->getOperand(0), I)) { + // Insert the new integer mul. + Value *NewMul = Builder->CreateNSWMul( + Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); + return new SExtInst(NewMul, I.getType()); + } + } + } + + // Check for (mul (zext x), y), see if we can merge this into an + // integer mul followed by a zext. + if (auto *Op0Conv = dyn_cast<ZExtInst>(Op0)) { + // (mul (zext x), cst) --> (zext (mul x, cst')) + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + if (Op0Conv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); + if (ConstantExpr::getZExt(CI, I.getType()) == Op1C && + computeOverflowForUnsignedMul(Op0Conv->getOperand(0), CI, &I) == + OverflowResult::NeverOverflows) { + // Insert the new, smaller mul. + Value *NewMul = + Builder->CreateNUWMul(Op0Conv->getOperand(0), CI, "mulconv"); + return new ZExtInst(NewMul, I.getType()); + } + } + } + + // (mul (zext x), (zext y)) --> (zext (mul int x, y)) + if (auto *Op1Conv = dyn_cast<ZExtInst>(Op1)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of zexts), and if the + // integer mul will not overflow. + if (Op0Conv->getOperand(0)->getType() == + Op1Conv->getOperand(0)->getType() && + (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && + computeOverflowForUnsignedMul(Op0Conv->getOperand(0), + Op1Conv->getOperand(0), + &I) == OverflowResult::NeverOverflows) { + // Insert the new integer mul. + Value *NewMul = Builder->CreateNUWMul( + Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); + return new ZExtInst(NewMul, I.getType()); + } + } + } + if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); @@ -545,21 +613,15 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { std::swap(Op0, Op1); if (Value *V = - SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) + SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { - // Try to fold constant mul into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) + return FoldedMul; // (fmul X, -1.0) --> (fsub -0.0, X) if (match(Op1, m_SpecificFP(-1.0))) { @@ -709,7 +771,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { BuilderTy::FastMathFlagGuard Guard(*Builder); Builder->setFastMathFlags(I.getFastMathFlags()); Value *T = Builder->CreateFMul(Opnd1, Opnd1); - Value *R = Builder->CreateFMul(T, Y); R->takeName(&I); return replaceInstUsesWith(I, R); @@ -883,14 +944,9 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { } } - if (*C2 != 0) { // avoid X udiv 0 - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; - } + if (*C2 != 0) // avoid X udiv 0 + if (Instruction *FoldedDiv = foldOpWithConstantIntoOperand(I)) + return FoldedDiv; } } @@ -991,19 +1047,22 @@ static Instruction *foldUDivNegCst(Value *Op0, Value *Op1, } // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) +// X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2) static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC) { - Instruction *ShiftLeft = cast<Instruction>(Op1); - if (isa<ZExtInst>(ShiftLeft)) - ShiftLeft = cast<Instruction>(ShiftLeft->getOperand(0)); - - const APInt &CI = - cast<Constant>(ShiftLeft->getOperand(0))->getUniqueInteger(); - Value *N = ShiftLeft->getOperand(1); - if (CI != 1) - N = IC.Builder->CreateAdd(N, ConstantInt::get(N->getType(), CI.logBase2())); - if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1)) - N = IC.Builder->CreateZExt(N, Z->getDestTy()); + Value *ShiftLeft; + if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) + ShiftLeft = Op1; + + const APInt *CI; + Value *N; + if (!match(ShiftLeft, m_Shl(m_APInt(CI), m_Value(N)))) + llvm_unreachable("match should never fail here!"); + if (*CI != 1) + N = IC.Builder->CreateAdd(N, + ConstantInt::get(N->getType(), CI->logBase2())); + if (Op1 != ShiftLeft) + N = IC.Builder->CreateZExt(N, Op1->getType()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); if (I.isExact()) LShr->setIsExact(); @@ -1059,7 +1118,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1132,7 +1191,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1195,7 +1254,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return BO; } - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) { + if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) // Safe because the only negative value (1 << Y) can take on is // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have @@ -1247,7 +1306,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1367,6 +1426,16 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { } } + Value *LHS; + Value *RHS; + + // -x / -y -> x / y + if (match(Op0, m_FNeg(m_Value(LHS))) && match(Op1, m_FNeg(m_Value(RHS)))) { + I.setOperand(0, LHS); + I.setOperand(1, RHS); + return &I; + } + return nullptr; } @@ -1421,7 +1490,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1434,7 +1503,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) { + if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1447,6 +1516,14 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { return replaceInstUsesWith(I, Ext); } + // X urem C -> X < C ? X : X - C, where C >= signbit. + const APInt *DivisorC; + if (match(Op1, m_APInt(DivisorC)) && DivisorC->isNegative()) { + Value *Cmp = Builder->CreateICmpULT(Op0, Op1); + Value *Sub = Builder->CreateSub(Op0, Op1); + return SelectInst::Create(Cmp, Op0, Sub); + } + return nullptr; } @@ -1456,7 +1533,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer rem common cases @@ -1532,7 +1609,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) |