diff options
Diffstat (limited to 'contrib/llvm/lib/Analysis/VectorUtils.cpp')
-rw-r--r-- | contrib/llvm/lib/Analysis/VectorUtils.cpp | 94 |
1 files changed, 90 insertions, 4 deletions
diff --git a/contrib/llvm/lib/Analysis/VectorUtils.cpp b/contrib/llvm/lib/Analysis/VectorUtils.cpp index 7e598f4..554d132 100644 --- a/contrib/llvm/lib/Analysis/VectorUtils.cpp +++ b/contrib/llvm/lib/Analysis/VectorUtils.cpp @@ -11,18 +11,19 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Analysis/VectorUtils.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" -#include "llvm/IR/Constants.h" using namespace llvm; using namespace llvm::PatternMatch; @@ -300,7 +301,7 @@ const llvm::Value *llvm::getSplatValue(const Value *V) { auto *InsertEltInst = dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0)); if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) || - !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue()) + !cast<ConstantInt>(InsertEltInst->getOperand(2))->isZero()) return nullptr; return InsertEltInst->getOperand(1); @@ -488,3 +489,88 @@ Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { return Inst; } + +Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF, + unsigned NumVecs) { + SmallVector<Constant *, 16> Mask; + for (unsigned i = 0; i < VF; i++) + for (unsigned j = 0; j < NumVecs; j++) + Mask.push_back(Builder.getInt32(j * VF + i)); + + return ConstantVector::get(Mask); +} + +Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start, + unsigned Stride, unsigned VF) { + SmallVector<Constant *, 16> Mask; + for (unsigned i = 0; i < VF; i++) + Mask.push_back(Builder.getInt32(Start + i * Stride)); + + return ConstantVector::get(Mask); +} + +Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start, + unsigned NumInts, unsigned NumUndefs) { + SmallVector<Constant *, 16> Mask; + for (unsigned i = 0; i < NumInts; i++) + Mask.push_back(Builder.getInt32(Start + i)); + + Constant *Undef = UndefValue::get(Builder.getInt32Ty()); + for (unsigned i = 0; i < NumUndefs; i++) + Mask.push_back(Undef); + + return ConstantVector::get(Mask); +} + +/// A helper function for concatenating vectors. This function concatenates two +/// vectors having the same element type. If the second vector has fewer +/// elements than the first, it is padded with undefs. +static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1, + Value *V2) { + VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType()); + VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType()); + assert(VecTy1 && VecTy2 && + VecTy1->getScalarType() == VecTy2->getScalarType() && + "Expect two vectors with the same element type"); + + unsigned NumElts1 = VecTy1->getNumElements(); + unsigned NumElts2 = VecTy2->getNumElements(); + assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements"); + + if (NumElts1 > NumElts2) { + // Extend with UNDEFs. + Constant *ExtMask = + createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2); + V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask); + } + + Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0); + return Builder.CreateShuffleVector(V1, V2, Mask); +} + +Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) { + unsigned NumVecs = Vecs.size(); + assert(NumVecs > 1 && "Should be at least two vectors"); + + SmallVector<Value *, 8> ResList; + ResList.append(Vecs.begin(), Vecs.end()); + do { + SmallVector<Value *, 8> TmpList; + for (unsigned i = 0; i < NumVecs - 1; i += 2) { + Value *V0 = ResList[i], *V1 = ResList[i + 1]; + assert((V0->getType() == V1->getType() || i == NumVecs - 2) && + "Only the last vector may have a different type"); + + TmpList.push_back(concatenateTwoVectors(Builder, V0, V1)); + } + + // Push the last vector if the total number of vectors is odd. + if (NumVecs % 2 != 0) + TmpList.push_back(ResList[NumVecs - 1]); + + ResList = TmpList; + NumVecs = ResList.size(); + } while (NumVecs > 1); + + return ResList[0]; +} |