diff options
Diffstat (limited to 'lib/VMCore/ConstantFold.cpp')
-rw-r--r-- | lib/VMCore/ConstantFold.cpp | 159 |
1 files changed, 111 insertions, 48 deletions
diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index 9a91daf..573efb7 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -42,6 +42,10 @@ using namespace llvm; /// input vector constant are all simple integer or FP values. static Constant *BitCastConstantVector(ConstantVector *CV, const VectorType *DstTy) { + + if (CV->isAllOnesValue()) return Constant::getAllOnesValue(DstTy); + if (CV->isNullValue()) return Constant::getNullValue(DstTy); + // If this cast changes element count then we can't handle it here: // doing so requires endianness information. This should be handled by // Analysis/ConstantFolding.cpp @@ -145,7 +149,7 @@ static Constant *FoldBitCast(Constant *V, const Type *DestTy) { // This allows for other simplifications (although some of them // can only be handled by Analysis/ConstantFolding.cpp). if (isa<ConstantInt>(V) || isa<ConstantFP>(V)) - return ConstantExpr::getBitCast(ConstantVector::get(&V, 1), DestPTy); + return ConstantExpr::getBitCast(ConstantVector::get(V), DestPTy); } // Finally, implement bitcast folding now. The code below doesn't handle @@ -202,7 +206,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart, APInt V = CI->getValue(); if (ByteStart) V = V.lshr(ByteStart*8); - V.trunc(ByteSize*8); + V = V.trunc(ByteSize*8); return ConstantInt::get(CI->getContext(), V); } @@ -511,10 +515,14 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, return Constant::getNullValue(DestTy); return UndefValue::get(DestTy); } + // No compile-time operations on this type yet. if (V->getType()->isPPC_FP128Ty() || DestTy->isPPC_FP128Ty()) return 0; + if (V->isNullValue() && !DestTy->isX86_MMXTy()) + return Constant::getNullValue(DestTy); + // If the cast operand is a constant expression, there's a few things we can // do to try to simplify it. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { @@ -637,9 +645,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, case Instruction::SIToFP: if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { APInt api = CI->getValue(); - const uint64_t zero[] = {0, 0}; - APFloat apf = APFloat(APInt(DestTy->getPrimitiveSizeInBits(), - 2, zero)); + APFloat apf(APInt::getNullValue(DestTy->getPrimitiveSizeInBits()), true); (void)apf.convertFromAPInt(api, opc==Instruction::SIToFP, APFloat::rmNearestTiesToEven); @@ -649,25 +655,22 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, case Instruction::ZExt: if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth(); - APInt Result(CI->getValue()); - Result.zext(BitWidth); - return ConstantInt::get(V->getContext(), Result); + return ConstantInt::get(V->getContext(), + CI->getValue().zext(BitWidth)); } return 0; case Instruction::SExt: if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth(); - APInt Result(CI->getValue()); - Result.sext(BitWidth); - return ConstantInt::get(V->getContext(), Result); + return ConstantInt::get(V->getContext(), + CI->getValue().sext(BitWidth)); } return 0; case Instruction::Trunc: { uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth(); if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - APInt Result(CI->getValue()); - Result.trunc(DestBitWidth); - return ConstantInt::get(V->getContext(), Result); + return ConstantInt::get(V->getContext(), + CI->getValue().trunc(DestBitWidth)); } // The input must be a constantexpr. See if we can simplify this based on @@ -690,10 +693,58 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond, if (ConstantInt *CB = dyn_cast<ConstantInt>(Cond)) return CB->getZExtValue() ? V1 : V2; + // Check for zero aggregate and ConstantVector of zeros + if (Cond->isNullValue()) return V2; + + if (ConstantVector* CondV = dyn_cast<ConstantVector>(Cond)) { + + if (CondV->isAllOnesValue()) return V1; + + const VectorType *VTy = cast<VectorType>(V1->getType()); + ConstantVector *CP1 = dyn_cast<ConstantVector>(V1); + ConstantVector *CP2 = dyn_cast<ConstantVector>(V2); + + if ((CP1 || isa<ConstantAggregateZero>(V1)) && + (CP2 || isa<ConstantAggregateZero>(V2))) { + + // Find the element type of the returned vector + const Type *EltTy = VTy->getElementType(); + unsigned NumElem = VTy->getNumElements(); + std::vector<Constant*> Res(NumElem); + + bool Valid = true; + for (unsigned i = 0; i < NumElem; ++i) { + ConstantInt* c = dyn_cast<ConstantInt>(CondV->getOperand(i)); + if (!c) { + Valid = false; + break; + } + Constant *C1 = CP1 ? CP1->getOperand(i) : Constant::getNullValue(EltTy); + Constant *C2 = CP2 ? CP2->getOperand(i) : Constant::getNullValue(EltTy); + Res[i] = c->getZExtValue() ? C1 : C2; + } + // If we were able to build the vector, return it + if (Valid) return ConstantVector::get(Res); + } + } + + if (isa<UndefValue>(V1)) return V2; if (isa<UndefValue>(V2)) return V1; if (isa<UndefValue>(Cond)) return V1; if (V1 == V2) return V1; + + if (ConstantExpr *TrueVal = dyn_cast<ConstantExpr>(V1)) { + if (TrueVal->getOpcode() == Instruction::Select) + if (TrueVal->getOperand(0) == Cond) + return ConstantExpr::getSelect(Cond, TrueVal->getOperand(1), V2); + } + if (ConstantExpr *FalseVal = dyn_cast<ConstantExpr>(V2)) { + if (FalseVal->getOpcode() == Instruction::Select) + if (FalseVal->getOperand(0) == Cond) + return ConstantExpr::getSelect(Cond, V1, FalseVal->getOperand(2)); + } + return 0; } @@ -821,7 +872,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Result.push_back(InElt); } - return ConstantVector::get(&Result[0], Result.size()); + return ConstantVector::get(Result); } Constant *llvm::ConstantFoldExtractValueInstruction(Constant *Agg, @@ -982,8 +1033,8 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, return Constant::getNullValue(C1->getType()); // X lshr undef -> 0 // undef lshr X -> 0 case Instruction::AShr: - if (!isa<UndefValue>(C2)) - return C1; // undef ashr X --> undef + if (!isa<UndefValue>(C2)) // undef ashr X --> all ones + return Constant::getAllOnesValue(C1->getType()); else if (isa<UndefValue>(C1)) return C1; // undef ashr undef -> undef else @@ -1343,8 +1394,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, // Given ((a + b) + c), if (b + c) folds to something interesting, return // (a + (b + c)). - if (Instruction::isAssociative(Opcode, C1->getType()) && - CE1->getOpcode() == Opcode) { + if (Instruction::isAssociative(Opcode) && CE1->getOpcode() == Opcode) { Constant *T = ConstantExpr::get(Opcode, CE1->getOperand(1), C2); if (!isa<ConstantExpr>(T) || cast<ConstantExpr>(T)->getOpcode() != Opcode) return ConstantExpr::get(Opcode, CE1->getOperand(0), T); @@ -1413,7 +1463,7 @@ static bool isMaybeZeroSizedType(const Type *Ty) { /// first is less than the second, return -1, if the second is less than the /// first, return 1. If the constants are not integral, return -2. /// -static int IdxCompare(Constant *C1, Constant *C2, const Type *ElTy) { +static int IdxCompare(Constant *C1, Constant *C2, const Type *ElTy) { if (C1 == C2) return 0; // Ok, we found a different index. If they are not ConstantInt, we can't do @@ -1896,11 +1946,11 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, // If we can constant fold the comparison of each element, constant fold // the whole vector comparison. SmallVector<Constant*, 4> ResElts; - for (unsigned i = 0, e = C1Elts.size(); i != e; ++i) { - // Compare the elements, producing an i1 result or constant expr. + // Compare the elements, producing an i1 result or constant expr. + for (unsigned i = 0, e = C1Elts.size(); i != e; ++i) ResElts.push_back(ConstantExpr::getCompare(pred, C1Elts[i], C2Elts[i])); - } - return ConstantVector::get(&ResElts[0], ResElts.size()); + + return ConstantVector::get(ResElts); } if (C1->getType()->isFloatingPointTy()) { @@ -1948,7 +1998,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, else if (pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT) Result = 1; break; - case ICmpInst::ICMP_NE: // We know that C1 != C2 + case FCmpInst::FCMP_ONE: // We know that C1 != C2 // We can only partially decide this relation. if (pred == FCmpInst::FCMP_OEQ || pred == FCmpInst::FCMP_UEQ) Result = 0; @@ -2073,56 +2123,55 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, /// isInBoundsIndices - Test whether the given sequence of *normalized* indices /// is "inbounds". -static bool isInBoundsIndices(Constant *const *Idxs, size_t NumIdx) { +template<typename IndexTy> +static bool isInBoundsIndices(IndexTy const *Idxs, size_t NumIdx) { // No indices means nothing that could be out of bounds. if (NumIdx == 0) return true; // If the first index is zero, it's in bounds. - if (Idxs[0]->isNullValue()) return true; + if (cast<Constant>(Idxs[0])->isNullValue()) return true; // If the first index is one and all the rest are zero, it's in bounds, // by the one-past-the-end rule. if (!cast<ConstantInt>(Idxs[0])->isOne()) return false; for (unsigned i = 1, e = NumIdx; i != e; ++i) - if (!Idxs[i]->isNullValue()) + if (!cast<Constant>(Idxs[i])->isNullValue()) return false; return true; } -Constant *llvm::ConstantFoldGetElementPtr(Constant *C, - bool inBounds, - Constant* const *Idxs, - unsigned NumIdx) { +template<typename IndexTy> +static Constant *ConstantFoldGetElementPtrImpl(Constant *C, + bool inBounds, + IndexTy const *Idxs, + unsigned NumIdx) { + Constant *Idx0 = cast<Constant>(Idxs[0]); if (NumIdx == 0 || - (NumIdx == 1 && Idxs[0]->isNullValue())) + (NumIdx == 1 && Idx0->isNullValue())) return C; if (isa<UndefValue>(C)) { const PointerType *Ptr = cast<PointerType>(C->getType()); - const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, - (Value **)Idxs, - (Value **)Idxs+NumIdx); + const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs, Idxs+NumIdx); assert(Ty != 0 && "Invalid indices for GEP!"); return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace())); } - Constant *Idx0 = Idxs[0]; if (C->isNullValue()) { bool isNull = true; for (unsigned i = 0, e = NumIdx; i != e; ++i) - if (!Idxs[i]->isNullValue()) { + if (!cast<Constant>(Idxs[i])->isNullValue()) { isNull = false; break; } if (isNull) { const PointerType *Ptr = cast<PointerType>(C->getType()); - const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, - (Value**)Idxs, - (Value**)Idxs+NumIdx); + const Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs, + Idxs+NumIdx); assert(Ty != 0 && "Invalid indices for GEP!"); - return ConstantPointerNull::get( - PointerType::get(Ty,Ptr->getAddressSpace())); + return ConstantPointerNull::get(PointerType::get(Ty, + Ptr->getAddressSpace())); } } @@ -2173,9 +2222,9 @@ Constant *llvm::ConstantFoldGetElementPtr(Constant *C, } // Implement folding of: - // int* getelementptr ([2 x int]* bitcast ([3 x int]* %X to [2 x int]*), - // long 0, long 0) - // To: int* getelementptr ([3 x int]* %X, long 0, long 0) + // i32* getelementptr ([2 x i32]* bitcast ([3 x i32]* %X to [2 x i32]*), + // i64 0, i64 0) + // To: i32* getelementptr ([3 x i32]* %X, i64 0, i64 0) // if (CE->isCast() && NumIdx > 1 && Idx0->isNullValue()) { if (const PointerType *SPT = @@ -2214,7 +2263,7 @@ Constant *llvm::ConstantFoldGetElementPtr(Constant *C, ATy->getNumElements()); NewIdxs[i] = ConstantExpr::getSRem(CI, Factor); - Constant *PrevIdx = Idxs[i-1]; + Constant *PrevIdx = cast<Constant>(Idxs[i-1]); Constant *Div = ConstantExpr::getSDiv(CI, Factor); // Before adding, extend both operands to i64 to avoid @@ -2242,7 +2291,7 @@ Constant *llvm::ConstantFoldGetElementPtr(Constant *C, // If we did any factoring, start over with the adjusted indices. if (!NewIdxs.empty()) { for (unsigned i = 0; i != NumIdx; ++i) - if (!NewIdxs[i]) NewIdxs[i] = Idxs[i]; + if (!NewIdxs[i]) NewIdxs[i] = cast<Constant>(Idxs[i]); return inBounds ? ConstantExpr::getInBoundsGetElementPtr(C, NewIdxs.data(), NewIdxs.size()) : @@ -2257,3 +2306,17 @@ Constant *llvm::ConstantFoldGetElementPtr(Constant *C, return 0; } + +Constant *llvm::ConstantFoldGetElementPtr(Constant *C, + bool inBounds, + Constant* const *Idxs, + unsigned NumIdx) { + return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs, NumIdx); +} + +Constant *llvm::ConstantFoldGetElementPtr(Constant *C, + bool inBounds, + Value* const *Idxs, + unsigned NumIdx) { + return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs, NumIdx); +} |