Diffstat (limited to 'contrib/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r--  contrib/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 4421
1 file changed, 2964 insertions(+), 1457 deletions(-)
diff --git a/contrib/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/contrib/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2c7bffe..432c86d 100644
--- a/contrib/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/contrib/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -25,6 +25,7 @@
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
@@ -33,6 +34,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLowering.h"
@@ -53,10 +55,6 @@ STATISTIC(SlicedLoads, "Number of load sliced");
namespace {
static cl::opt<bool>
- CombinerAA("combiner-alias-analysis", cl::Hidden,
- cl::desc("Enable DAG combiner alias-analysis heuristics"));
-
- static cl::opt<bool>
CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
cl::desc("Enable DAG combiner's use of IR alias analysis"));
@@ -117,7 +115,7 @@ namespace {
SmallPtrSet<SDNode *, 32> CombinedNodes;
// AA - Used for DAG load/store alias analysis.
- AliasAnalysis &AA;
+ AliasAnalysis *AA;
/// When an instruction is simplified, add all users of the instruction to
/// the work lists because they might get more simplified now.
@@ -133,6 +131,9 @@ namespace {
/// Add to the worklist making sure its instance is at the back (next to be
/// processed.)
void AddToWorklist(SDNode *N) {
+ assert(N->getOpcode() != ISD::DELETED_NODE &&
+ "Deleted Node added to Worklist");
+
// Skip handle nodes as they can't usefully be combined and confuse the
// zero-use deletion strategy.
if (N->getOpcode() == ISD::HANDLENODE)
@@ -177,6 +178,7 @@ namespace {
void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
private:
+ unsigned MaximumLegalStoreInBits;
/// Check the specified integer node value to see if it can be simplified or
/// if things it uses can be simplified by bit propagation.
@@ -232,11 +234,18 @@ namespace {
SDValue visitTokenFactor(SDNode *N);
SDValue visitMERGE_VALUES(SDNode *N);
SDValue visitADD(SDNode *N);
+ SDValue visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference);
SDValue visitSUB(SDNode *N);
SDValue visitADDC(SDNode *N);
+ SDValue visitUADDO(SDNode *N);
+ SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitSUBC(SDNode *N);
+ SDValue visitUSUBO(SDNode *N);
SDValue visitADDE(SDNode *N);
+ SDValue visitADDCARRY(SDNode *N);
+ SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
SDValue visitSUBE(SDNode *N);
+ SDValue visitSUBCARRY(SDNode *N);
SDValue visitMUL(SDNode *N);
SDValue useDivRem(SDNode *N);
SDValue visitSDIV(SDNode *N);
@@ -259,6 +268,7 @@ namespace {
SDValue visitSRA(SDNode *N);
SDValue visitSRL(SDNode *N);
SDValue visitRotate(SDNode *N);
+ SDValue visitABS(SDNode *N);
SDValue visitBSWAP(SDNode *N);
SDValue visitBITREVERSE(SDNode *N);
SDValue visitCTLZ(SDNode *N);
@@ -271,9 +281,11 @@ namespace {
SDValue visitSELECT_CC(SDNode *N);
SDValue visitSETCC(SDNode *N);
SDValue visitSETCCE(SDNode *N);
+ SDValue visitSETCCCARRY(SDNode *N);
SDValue visitSIGN_EXTEND(SDNode *N);
SDValue visitZERO_EXTEND(SDNode *N);
SDValue visitANY_EXTEND(SDNode *N);
+ SDValue visitAssertZext(SDNode *N);
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitSIGN_EXTEND_VECTOR_INREG(SDNode *N);
SDValue visitZERO_EXTEND_VECTOR_INREG(SDNode *N);
@@ -336,6 +348,7 @@ namespace {
SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt);
SDValue foldSelectOfConstants(SDNode *N);
+ SDValue foldBinOpIntoSelect(SDNode *BO);
bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N);
SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
@@ -344,6 +357,8 @@ namespace {
bool NotExtCompare = false);
SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
SDValue N2, SDValue N3, ISD::CondCode CC);
+ SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
+ const SDLoc &DL);
SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, bool foldBooleans = true);
@@ -361,14 +376,14 @@ namespace {
SDValue BuildSDIVPow2(SDNode *N);
SDValue BuildUDIV(SDNode *N);
SDValue BuildLogBase2(SDValue Op, const SDLoc &DL);
- SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags);
- SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
- SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags);
- SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags, bool Recip);
+ SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags);
+ SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
+ SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
+ SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
SDValue buildSqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
- SDNodeFlags *Flags, bool Reciprocal);
+ SDNodeFlags Flags, bool Reciprocal);
SDValue buildSqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
- SDNodeFlags *Flags, bool Reciprocal);
+ SDNodeFlags Flags, bool Reciprocal);
SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
bool DemandHighBits = true);
SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
@@ -377,6 +392,7 @@ namespace {
unsigned PosOpcode, unsigned NegOpcode,
const SDLoc &DL);
SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+ SDValue MatchLoadCombine(SDNode *N);
SDValue ReduceLoadWidth(SDNode *N);
SDValue ReduceLoadOpStoreWidth(SDNode *N);
SDValue splitMergedValStore(StoreSDNode *ST);
@@ -384,9 +400,11 @@ namespace {
SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
SDValue reduceBuildVecConvertToConvertBuildVec(SDNode *N);
SDValue reduceBuildVecToShuffle(SDNode *N);
- SDValue createBuildVecShuffle(SDLoc DL, SDNode *N, ArrayRef<int> VectorMask,
- SDValue VecIn1, SDValue VecIn2,
- unsigned LeftIdx);
+ SDValue reduceBuildVecToTrunc(SDNode *N);
+ SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
+ ArrayRef<int> VectorMask, SDValue VecIn1,
+ SDValue VecIn2, unsigned LeftIdx);
+ SDValue matchVSelectOpSizesWithSetCC(SDNode *N);
SDValue GetDemandedBits(SDValue V, const APInt &Mask);
@@ -416,15 +434,12 @@ namespace {
/// Holds a pointer to an LSBaseSDNode as well as information on where it
/// is located in a sequence of memory operations connected by a chain.
struct MemOpLink {
- MemOpLink (LSBaseSDNode *N, int64_t Offset, unsigned Seq):
- MemNode(N), OffsetFromBase(Offset), SequenceNum(Seq) { }
+ MemOpLink(LSBaseSDNode *N, int64_t Offset)
+ : MemNode(N), OffsetFromBase(Offset) {}
// Ptr to the mem node.
LSBaseSDNode *MemNode;
// Offset from the base ptr.
int64_t OffsetFromBase;
- // What is the sequence number of this mem node.
- // Lowest mem operand in the DAG starts at zero.
- unsigned SequenceNum;
};
/// This is a helper function for visitMUL to check the profitability
@@ -435,12 +450,6 @@ namespace {
SDValue &AddNode,
SDValue &ConstNode);
- /// This is a helper function for MergeStoresOfConstantsOrVecElts. Returns a
- /// constant build_vector of the stored constant values in Stores.
- SDValue getMergedConstantVectorStore(SelectionDAG &DAG, const SDLoc &SL,
- ArrayRef<MemOpLink> Stores,
- SmallVectorImpl<SDValue> &Chains,
- EVT Ty) const;
/// This is a helper function for visitAND and visitZERO_EXTEND. Returns
/// true if the (and (load x) c) pattern matches an extload. ExtVT returns
@@ -451,34 +460,36 @@ namespace {
EVT LoadResultTy, EVT &ExtVT, EVT &LoadedVT,
bool &NarrowLoad);
+ /// Helper function for MergeConsecutiveStores which merges the
+ /// component store chains.
+ SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
+ unsigned NumStores);
+
/// This is a helper function for MergeConsecutiveStores. When the source
/// elements of the consecutive stores are all constants or all extracted
/// vector elements, try to merge them into one larger store.
- /// \return number of stores that were merged into a merged store (always
- /// a prefix of \p StoreNode).
- bool MergeStoresOfConstantsOrVecElts(
- SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
- bool IsConstantSrc, bool UseVector);
+ /// \return True if a merged store was created.
+ bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
+ EVT MemVT, unsigned NumStores,
+ bool IsConstantSrc, bool UseVector,
+ bool UseTrunc);
/// This is a helper function for MergeConsecutiveStores.
/// Stores that may be merged are placed in StoreNodes.
- /// Loads that may alias with those stores are placed in AliasLoadNodes.
- void getStoreMergeAndAliasCandidates(
- StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes,
- SmallVectorImpl<LSBaseSDNode*> &AliasLoadNodes);
+ void getStoreMergeCandidates(StoreSDNode *St,
+ SmallVectorImpl<MemOpLink> &StoreNodes);
/// Helper function for MergeConsecutiveStores. Checks if
/// Candidate stores have indirect dependency through their
/// operands. \return True if safe to merge
bool checkMergeStoreCandidatesForDependencies(
- SmallVectorImpl<MemOpLink> &StoreNodes);
+ SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores);
/// Merge consecutive store operations into a wide store.
/// This optimization uses wide integers or vectors when possible.
/// \return number of stores that were merged into a merged store (the
/// affected nodes are stored as a prefix in \p StoreNodes).
- bool MergeConsecutiveStores(StoreSDNode *N,
- SmallVectorImpl<MemOpLink> &StoreNodes);
+ bool MergeConsecutiveStores(StoreSDNode *N);
/// \brief Try to transform a truncation where C is a constant:
/// (trunc (and X, C)) -> (and (trunc X), (trunc C))
@@ -489,10 +500,17 @@ namespace {
SDValue distributeTruncateThroughAnd(SDNode *N);
public:
- DAGCombiner(SelectionDAG &D, AliasAnalysis &A, CodeGenOpt::Level OL)
+ DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
: DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes),
- OptLevel(OL), LegalOperations(false), LegalTypes(false), AA(A) {
+ OptLevel(OL), LegalOperations(false), LegalTypes(false), AA(AA) {
ForCodeSize = DAG.getMachineFunction().getFunction()->optForSize();
+
+ MaximumLegalStoreInBits = 0;
+ for (MVT VT : MVT::all_valuetypes())
+ if (EVT(VT).isSimple() && VT != MVT::Other &&
+ TLI.isTypeLegal(EVT(VT)) &&
+ VT.getSizeInBits() >= MaximumLegalStoreInBits)
+ MaximumLegalStoreInBits = VT.getSizeInBits();
}
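[Illustration — editorial sketch, not part of the patch. The constructor loop above caches the widest store width the target can legally emit; the target below is hypothetical.]
  // On a target whose widest legal type is v4i32, after the loop runs:
  //   MaximumLegalStoreInBits == 128
  // MergeConsecutiveStores later uses this bound to stop widening a
  // candidate merged store past what the target can actually store.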
/// Runs the dag combiner on all nodes in the work list
@@ -607,10 +625,16 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations,
switch (Op.getOpcode()) {
default: return false;
- case ISD::ConstantFP:
- // Don't invert constant FP values after legalize. The negated constant
- // isn't necessarily legal.
- return LegalOperations ? 0 : 1;
+ case ISD::ConstantFP: {
+ if (!LegalOperations)
+ return 1;
+
+ // Don't invert constant FP values after legalization unless the target says
+ // the negated constant is legal.
+ EVT VT = Op.getValueType();
+ return TLI.isOperationLegal(ISD::ConstantFP, VT) ||
+ TLI.isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT);
+ }
case ISD::FADD:
// FIXME: determine better conditions for this xform.
if (!Options->UnsafeFPMath) return 0;
@@ -629,7 +653,8 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations,
Depth + 1);
case ISD::FSUB:
// We can't turn -(A-B) into B-A when we honor signed zeros.
- if (!Options->UnsafeFPMath && !Op.getNode()->getFlags()->hasNoSignedZeros())
+ if (!Options->NoSignedZerosFPMath &&
+ !Op.getNode()->getFlags().hasNoSignedZeros())
return 0;
// fold (fneg (fsub A, B)) -> (fsub B, A)
@@ -667,7 +692,7 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree");
- const SDNodeFlags *Flags = Op.getNode()->getFlags();
+ const SDNodeFlags Flags = Op.getNode()->getFlags();
switch (Op.getOpcode()) {
default: llvm_unreachable("Unknown code");
@@ -950,8 +975,8 @@ CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
/// things it uses can be simplified by bit propagation. If so, return true.
bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) {
TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
- APInt KnownZero, KnownOne;
- if (!TLI.SimplifyDemandedBits(Op, Demanded, KnownZero, KnownOne, TLO))
+ KnownBits Known;
+ if (!TLI.SimplifyDemandedBits(Op, Demanded, Known, TLO))
return false;
// Revisit the node.
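[Illustration — editorial sketch, not part of the patch. This hunk is part of the tree-wide migration from paired APInt out-parameters to the KnownBits struct; Known.Zero/Known.One carry the same masks. The out-parameter form of computeKnownBits is assumed, matching this revision.]
  // Old style:
  //   APInt KnownZero, KnownOne;
  //   DAG.computeKnownBits(V, KnownZero, KnownOne);
  //   bool AnyBitKnownSet = KnownOne.getBoolValue();
  // New style:
  //   KnownBits Known;
  //   DAG.computeKnownBits(V, Known);
  //   bool AnyBitKnownSet = Known.One.getBoolValue();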
@@ -1006,13 +1031,13 @@ SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
switch (Opc) {
default: break;
case ISD::AssertSext:
- return DAG.getNode(ISD::AssertSext, DL, PVT,
- SExtPromoteOperand(Op.getOperand(0), PVT),
- Op.getOperand(1));
+ if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
+ return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
+ break;
case ISD::AssertZext:
- return DAG.getNode(ISD::AssertZext, DL, PVT,
- ZExtPromoteOperand(Op.getOperand(0), PVT),
- Op.getOperand(1));
+ if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
+ return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
+ break;
case ISD::Constant: {
unsigned ExtOpc =
Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
@@ -1079,37 +1104,44 @@ SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
assert(PVT != VT && "Don't know what type to promote to!");
+ DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
+
bool Replace0 = false;
SDValue N0 = Op.getOperand(0);
SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
- if (!NN0.getNode())
- return SDValue();
bool Replace1 = false;
SDValue N1 = Op.getOperand(1);
- SDValue NN1;
- if (N0 == N1)
- NN1 = NN0;
- else {
- NN1 = PromoteOperand(N1, PVT, Replace1);
- if (!NN1.getNode())
- return SDValue();
- }
+ SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
+ SDLoc DL(Op);
- AddToWorklist(NN0.getNode());
- if (NN1.getNode())
- AddToWorklist(NN1.getNode());
+ SDValue RV =
+ DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
+
+ // We are always replacing N0/N1's use in N and only need
+ // additional replacements if there are additional uses.
+ Replace0 &= !N0->hasOneUse();
+ Replace1 &= (N0 != N1) && !N1->hasOneUse();
+
+ // Combine Op here so it is preserved past replacements.
+ CombineTo(Op.getNode(), RV);
- if (Replace0)
+ // If operands have a use ordering, make sure we deal with the
+ // predecessor first.
+ if (Replace0 && Replace1 && N0.getNode()->isPredecessorOf(N1.getNode())) {
+ std::swap(N0, N1);
+ std::swap(NN0, NN1);
+ }
+
+ if (Replace0) {
+ AddToWorklist(NN0.getNode());
ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
- if (Replace1)
+ }
+ if (Replace1) {
+ AddToWorklist(NN1.getNode());
ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
-
- DEBUG(dbgs() << "\nPromoting ";
- Op.getNode()->dump(&DAG));
- SDLoc DL(Op);
- return DAG.getNode(ISD::TRUNCATE, DL, VT,
- DAG.getNode(Opc, DL, PVT, NN0, NN1));
+ }
+ return Op;
}
return SDValue();
}
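[Illustration — editorial sketch, not part of the patch. The reworked PromoteIntBinOp still builds the same promote-then-truncate pattern; the i16/i32 types are chosen only for illustration.]
  //   (add:i16 a, b)
  //     --> (truncate:i16 (add:i32 (ext:i32 a), (ext:i32 b)))
  // The change is ordering: CombineTo(Op, RV) commits the replacement
  // first, so Op survives while ReplaceLoadWithPromotedLoad rewrites the
  // remaining uses of N0/N1, predecessors before successors.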
@@ -1137,26 +1169,32 @@ SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
assert(PVT != VT && "Don't know what type to promote to!");
+ DEBUG(dbgs() << "\nPromoting "; Op.getNode()->dump(&DAG));
+
bool Replace = false;
SDValue N0 = Op.getOperand(0);
+ SDValue N1 = Op.getOperand(1);
if (Opc == ISD::SRA)
- N0 = SExtPromoteOperand(Op.getOperand(0), PVT);
+ N0 = SExtPromoteOperand(N0, PVT);
else if (Opc == ISD::SRL)
- N0 = ZExtPromoteOperand(Op.getOperand(0), PVT);
+ N0 = ZExtPromoteOperand(N0, PVT);
else
N0 = PromoteOperand(N0, PVT, Replace);
+
if (!N0.getNode())
return SDValue();
+ SDLoc DL(Op);
+ SDValue RV =
+ DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
+
AddToWorklist(N0.getNode());
if (Replace)
ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
- DEBUG(dbgs() << "\nPromoting ";
- Op.getNode()->dump(&DAG));
- SDLoc DL(Op);
- return DAG.getNode(ISD::TRUNCATE, DL, VT,
- DAG.getNode(Opc, DL, PVT, N0, Op.getOperand(1)));
+ // Deal with Op being deleted.
+ if (Op && Op.getOpcode() != ISD::DELETED_NODE)
+ return RV;
}
return SDValue();
}
@@ -1361,8 +1399,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
else {
assert(N->getValueType(0) == RV.getValueType() &&
N->getNumValues() == 1 && "Type mismatch");
- SDValue OpV = RV;
- DAG.ReplaceAllUsesWith(N, &OpV);
+ DAG.ReplaceAllUsesWith(N, &RV);
}
// Push the new node and any users onto the worklist
@@ -1389,9 +1426,13 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::ADD: return visitADD(N);
case ISD::SUB: return visitSUB(N);
case ISD::ADDC: return visitADDC(N);
+ case ISD::UADDO: return visitUADDO(N);
case ISD::SUBC: return visitSUBC(N);
+ case ISD::USUBO: return visitUSUBO(N);
case ISD::ADDE: return visitADDE(N);
+ case ISD::ADDCARRY: return visitADDCARRY(N);
case ISD::SUBE: return visitSUBE(N);
+ case ISD::SUBCARRY: return visitSUBCARRY(N);
case ISD::MUL: return visitMUL(N);
case ISD::SDIV: return visitSDIV(N);
case ISD::UDIV: return visitUDIV(N);
@@ -1415,6 +1456,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::SRL: return visitSRL(N);
case ISD::ROTR:
case ISD::ROTL: return visitRotate(N);
+ case ISD::ABS: return visitABS(N);
case ISD::BSWAP: return visitBSWAP(N);
case ISD::BITREVERSE: return visitBITREVERSE(N);
case ISD::CTLZ: return visitCTLZ(N);
@@ -1427,9 +1469,11 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::SELECT_CC: return visitSELECT_CC(N);
case ISD::SETCC: return visitSETCC(N);
case ISD::SETCCE: return visitSETCCE(N);
+ case ISD::SETCCCARRY: return visitSETCCCARRY(N);
case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
+ case ISD::AssertZext: return visitAssertZext(N);
case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
case ISD::SIGN_EXTEND_VECTOR_INREG: return visitSIGN_EXTEND_VECTOR_INREG(N);
case ISD::ZERO_EXTEND_VECTOR_INREG: return visitZERO_EXTEND_VECTOR_INREG(N);
@@ -1530,7 +1574,7 @@ SDValue DAGCombiner::combine(SDNode *N) {
// If N is a commutative binary node, try commuting it to enable more
// sdisel CSE.
- if (!RV.getNode() && SelectionDAG::isCommutativeBinOp(N->getOpcode()) &&
+ if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) &&
N->getNumValues() == 1) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -1574,7 +1618,7 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
}
SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
- SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
+ SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
SmallPtrSet<SDNode*, 16> SeenOps;
bool Changed = false; // If we should replace this token factor.
@@ -1618,26 +1662,108 @@ SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
}
}
- SDValue Result;
+ // Remove nodes that are chained to another node in the list. Do so
+ // by walking up chains breadth-first, stopping when we've seen
+ // another operand. In general we must climb to the EntryNode, but we can
+ // exit early if we find that all remaining work is associated with just
+ // one operand, as no further pruning is possible.
+
+ // List of nodes to search through and original Ops from which they originate.
+ SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
+ SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
+ SmallPtrSet<SDNode *, 16> SeenChains;
+ bool DidPruneOps = false;
+
+ unsigned NumLeftToConsider = 0;
+ for (const SDValue &Op : Ops) {
+ Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
+ OpWorkCount.push_back(1);
+ }
+
+ auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
+ // If this is an Op, we can remove the op from the list. Re-mark any
+ // search associated with it as coming from the current OpNumber.
+ if (SeenOps.count(Op) != 0) {
+ Changed = true;
+ DidPruneOps = true;
+ unsigned OrigOpNumber = 0;
+ while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
+ OrigOpNumber++;
+ assert((OrigOpNumber != Ops.size()) &&
+ "expected to find TokenFactor Operand");
+ // Re-mark worklist from OrigOpNumber to OpNumber
+ for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
+ if (Worklist[i].second == OrigOpNumber) {
+ Worklist[i].second = OpNumber;
+ }
+ }
+ OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
+ OpWorkCount[OrigOpNumber] = 0;
+ NumLeftToConsider--;
+ }
+ // Add if it's a new chain
+ if (SeenChains.insert(Op).second) {
+ OpWorkCount[OpNumber]++;
+ Worklist.push_back(std::make_pair(Op, OpNumber));
+ }
+ };
+
+ for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
+ // We need to consider at least 2 Ops to prune.
+ if (NumLeftToConsider <= 1)
+ break;
+ auto CurNode = Worklist[i].first;
+ auto CurOpNumber = Worklist[i].second;
+ assert((OpWorkCount[CurOpNumber] > 0) &&
+ "Node should not appear in worklist");
+ switch (CurNode->getOpcode()) {
+ case ISD::EntryToken:
+ // Hitting EntryToken is the only way for the search to terminate
+ // without hitting another operand's search. Prevent us from marking
+ // this operand considered.
+ NumLeftToConsider++;
+ break;
+ case ISD::TokenFactor:
+ for (const SDValue &Op : CurNode->op_values())
+ AddToWorklist(i, Op.getNode(), CurOpNumber);
+ break;
+ case ISD::CopyFromReg:
+ case ISD::CopyToReg:
+ AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
+ break;
+ default:
+ if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
+ AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
+ break;
+ }
+ OpWorkCount[CurOpNumber]--;
+ if (OpWorkCount[CurOpNumber] == 0)
+ NumLeftToConsider--;
+ }
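[Illustration — editorial sketch, not part of the patch. A worked example of the pruning loop above, with made-up nodes.]
  //   A = CopyFromReg ...          ; some chain
  //   X = store ..., chain A
  //   N = TokenFactor(A, X)
  // The breadth-first walk up X's chain reaches A, which is already in
  // SeenOps, so A is marked pruned: X already orders after A, and keeping
  // both operands is redundant. DidPruneOps then makes the rebuild below
  // emit the TokenFactor without A.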
// If we've changed things around then replace token factor.
if (Changed) {
+ SDValue Result;
if (Ops.empty()) {
// The entry token is the only possible outcome.
Result = DAG.getEntryNode();
} else {
- // New and improved token factor.
- Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Ops);
+ if (DidPruneOps) {
+ SmallVector<SDValue, 8> PrunedOps;
+ // Drop operands that were reached while walking another operand's chain.
+ for (const SDValue &Op : Ops) {
+ if (SeenChains.count(Op.getNode()) == 0)
+ PrunedOps.push_back(Op);
+ }
+ Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, PrunedOps);
+ } else {
+ Result = DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Ops);
+ }
}
-
- // Add users to worklist if AA is enabled, since it may introduce
- // a lot of new chained token factors while removing memory deps.
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
- return CombineTo(N, Result, UseAA /*add to worklist*/);
+ return Result;
}
-
- return Result;
+ return SDValue();
}
/// MERGE_VALUES can always be eliminated.
@@ -1664,6 +1790,60 @@ static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
}
+SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
+ auto BinOpcode = BO->getOpcode();
+ assert((BinOpcode == ISD::ADD || BinOpcode == ISD::SUB ||
+ BinOpcode == ISD::MUL || BinOpcode == ISD::SDIV ||
+ BinOpcode == ISD::UDIV || BinOpcode == ISD::SREM ||
+ BinOpcode == ISD::UREM || BinOpcode == ISD::AND ||
+ BinOpcode == ISD::OR || BinOpcode == ISD::XOR ||
+ BinOpcode == ISD::SHL || BinOpcode == ISD::SRL ||
+ BinOpcode == ISD::SRA || BinOpcode == ISD::FADD ||
+ BinOpcode == ISD::FSUB || BinOpcode == ISD::FMUL ||
+ BinOpcode == ISD::FDIV || BinOpcode == ISD::FREM) &&
+ "Unexpected binary operator");
+
+ // Bail out if any constants are opaque because we can't constant fold those.
+ SDValue C1 = BO->getOperand(1);
+ if (!isConstantOrConstantVector(C1, true) &&
+ !isConstantFPBuildVectorOrConstantFP(C1))
+ return SDValue();
+
+ // Don't do this unless the old select is going away. We want to eliminate the
+ // binary operator, not replace a binop with a select.
+ // TODO: Handle ISD::SELECT_CC.
+ SDValue Sel = BO->getOperand(0);
+ if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
+ return SDValue();
+
+ SDValue CT = Sel.getOperand(1);
+ if (!isConstantOrConstantVector(CT, true) &&
+ !isConstantFPBuildVectorOrConstantFP(CT))
+ return SDValue();
+
+ SDValue CF = Sel.getOperand(2);
+ if (!isConstantOrConstantVector(CF, true) &&
+ !isConstantFPBuildVectorOrConstantFP(CF))
+ return SDValue();
+
+ // We have a select-of-constants followed by a binary operator with a
+ // constant. Eliminate the binop by pulling the constant math into the select.
+ // Example: add (select Cond, CT, CF), C1 --> select Cond, CT + C1, CF + C1
+ EVT VT = Sel.getValueType();
+ SDLoc DL(Sel);
+ SDValue NewCT = DAG.getNode(BinOpcode, DL, VT, CT, C1);
+ assert((NewCT.isUndef() || isConstantOrConstantVector(NewCT) ||
+ isConstantFPBuildVectorOrConstantFP(NewCT)) &&
+ "Failed to constant fold a binop with constant operands");
+
+ SDValue NewCF = DAG.getNode(BinOpcode, DL, VT, CF, C1);
+ assert((NewCF.isUndef() || isConstantOrConstantVector(NewCF) ||
+ isConstantFPBuildVectorOrConstantFP(NewCF)) &&
+ "Failed to constant fold a binop with constant operands");
+
+ return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
+}
+
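[Illustration — editorial sketch, not part of the patch. A concrete instance of foldBinOpIntoSelect, with made-up constants.]
  //   add (select Cond, 2, 7), 10  -->  select Cond, 12, 17
  // Both select arms and the outer operand must be (splat vectors of)
  // constants and the select must have one use, so the binop disappears
  // rather than trading one binop for a new select.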
SDValue DAGCombiner::visitADD(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -1702,16 +1882,36 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
if (isNullConstant(N1))
return N0;
- // fold ((c1-A)+c2) -> (c1+c2)-A
if (isConstantOrConstantVector(N1, /* NoOpaque */ true)) {
- if (N0.getOpcode() == ISD::SUB)
- if (isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) {
- return DAG.getNode(ISD::SUB, DL, VT,
- DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
- N0.getOperand(1));
+ // fold ((c1-A)+c2) -> (c1+c2)-A
+ if (N0.getOpcode() == ISD::SUB &&
+ isConstantOrConstantVector(N0.getOperand(0), /* NoOpaque */ true)) {
+ // FIXME: Adding 2 constants should be handled by FoldConstantArithmetic.
+ return DAG.getNode(ISD::SUB, DL, VT,
+ DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
+ N0.getOperand(1));
+ }
+
+ // add (sext i1 X), 1 -> zext (not i1 X)
+ // We don't transform this pattern:
+ // add (zext i1 X), -1 -> sext (not i1 X)
+ // because most (?) targets generate better code for the zext form.
+ if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
+ isOneConstantOrOneSplatConstant(N1)) {
+ SDValue X = N0.getOperand(0);
+ if ((!LegalOperations ||
+ (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
+ TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
+ X.getScalarValueSizeInBits() == 1) {
+ SDValue Not = DAG.getNOT(DL, X, X.getValueType());
+ return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
}
+ }
}
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// reassociate add
if (SDValue RADD = ReassociateOps(ISD::ADD, DL, N0, N1))
return RADD;
@@ -1771,9 +1971,60 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
// fold (a+b) -> (a|b) iff a and b share no bits.
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
- VT.isInteger() && DAG.haveNoCommonBitsSet(N0, N1))
+ DAG.haveNoCommonBitsSet(N0, N1))
return DAG.getNode(ISD::OR, DL, VT, N0, N1);
+ if (SDValue Combined = visitADDLike(N0, N1, N))
+ return Combined;
+
+ if (SDValue Combined = visitADDLike(N1, N0, N))
+ return Combined;
+
+ return SDValue();
+}
+
+static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
+ bool Masked = false;
+
+ // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
+ while (true) {
+ if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
+ V = V.getOperand(0);
+ continue;
+ }
+
+ if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
+ Masked = true;
+ V = V.getOperand(0);
+ continue;
+ }
+
+ break;
+ }
+
+ // If this is not a carry, return.
+ if (V.getResNo() != 1)
+ return SDValue();
+
+ if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
+ V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
+ return SDValue();
+
+ // If the result is masked, then no matter what kind of bool it is we can
+ // return. If it isn't, then we need to make sure the bool type is either 0 or
+ // 1 and not other values.
+ if (Masked ||
+ TLI.getBooleanContents(V.getValueType()) ==
+ TargetLoweringBase::ZeroOrOneBooleanContent)
+ return V;
+
+ return SDValue();
+}
+
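[Illustration — editorial sketch, not part of the patch. How getAsCarry peels legalization wrappers, on a made-up pattern.]
  //   (and (trunc (uaddo X, Y):1), 1)
  //     -> peels the AND (setting Masked) and the TRUNCATE
  //     -> lands on result 1 of the uaddo, i.e. the carry bit
  // If the value was masked with 1, any boolean representation is usable;
  // otherwise the target must use zero-or-one booleans for the bit to be a
  // valid carry.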
+SDValue DAGCombiner::visitADDLike(SDValue N0, SDValue N1, SDNode *LocReference) {
+ EVT VT = N0.getValueType();
+ SDLoc DL(LocReference);
+
// fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
isNullConstantOrNullSplatConstant(N1.getOperand(0).getOperand(0)))
@@ -1781,12 +2032,6 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
DAG.getNode(ISD::SHL, DL, VT,
N1.getOperand(0).getOperand(1),
N1.getOperand(1)));
- if (N0.getOpcode() == ISD::SHL && N0.getOperand(0).getOpcode() == ISD::SUB &&
- isNullConstantOrNullSplatConstant(N0.getOperand(0).getOperand(0)))
- return DAG.getNode(ISD::SUB, DL, VT, N1,
- DAG.getNode(ISD::SHL, DL, VT,
- N0.getOperand(0).getOperand(1),
- N0.getOperand(1)));
if (N1.getOpcode() == ISD::AND) {
SDValue AndOp0 = N1.getOperand(0);
@@ -1797,7 +2042,7 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
// and similar xforms where the inner op is either ~0 or 0.
if (NumSignBits == DestBits &&
isOneConstantOrOneSplatConstant(N1->getOperand(1)))
- return DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0), AndOp0);
+ return DAG.getNode(ISD::SUB, DL, VT, N0, AndOp0);
}
// add (sext i1), X -> sub X, (zext i1)
@@ -1818,6 +2063,18 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
}
}
+ // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
+ if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)))
+ return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
+ N0, N1.getOperand(0), N1.getOperand(2));
+
+ // (add X, Carry) -> (addcarry X, 0, Carry)
+ if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
+ if (SDValue Carry = getAsCarry(TLI, N1))
+ return DAG.getNode(ISD::ADDCARRY, DL,
+ DAG.getVTList(VT, Carry.getValueType()), N0,
+ DAG.getConstant(0, DL, VT), Carry);
+
return SDValue();
}
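[Illustration — editorial sketch, not part of the patch. The last two folds above let a plain add absorb a carry when the target supports ADDCARRY.]
  //   (add X, (addcarry Y, 0, C))  -->  (addcarry X, Y, C)
  //   (add X, CarryBit)            -->  (addcarry X, 0, CarryBit):0
  // turning an explicitly materialized carry bit back into a carry-in
  // operand.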
@@ -1825,40 +2082,90 @@ SDValue DAGCombiner::visitADDC(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
+ SDLoc DL(N);
// If the flag result is dead, turn this into an ADD.
if (!N->hasAnyUseOfValue(1))
- return CombineTo(N, DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N1),
- DAG.getNode(ISD::CARRY_FALSE,
- SDLoc(N), MVT::Glue));
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
// canonicalize constant to RHS.
ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
if (N0C && !N1C)
- return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N1, N0);
+ return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
// fold (addc x, 0) -> x + no carry out
if (isNullConstant(N1))
return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
- SDLoc(N), MVT::Glue));
+ DL, MVT::Glue));
+
+ // If it cannot overflow, transform into an add.
+ if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitUADDO(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ if (VT.isVector())
+ return SDValue();
+
+ EVT CarryVT = N->getValueType(1);
+ SDLoc DL(N);
+
+ // If the flag result is dead, turn this into an ADD.
+ if (!N->hasAnyUseOfValue(1))
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getUNDEF(CarryVT));
+
+ // canonicalize constant to RHS.
+ ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
+ ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+ if (N0C && !N1C)
+ return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N1, N0);
- // fold (addc a, b) -> (or a, b), CARRY_FALSE iff a and b share no bits.
- APInt LHSZero, LHSOne;
- APInt RHSZero, RHSOne;
- DAG.computeKnownBits(N0, LHSZero, LHSOne);
+ // fold (uaddo x, 0) -> x + no carry out
+ if (isNullConstant(N1))
+ return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
- if (LHSZero.getBoolValue()) {
- DAG.computeKnownBits(N1, RHSZero, RHSOne);
+ // If it cannot overflow, transform into an add.
+ if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
+ return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
+ DAG.getConstant(0, DL, CarryVT));
+
+ if (SDValue Combined = visitUADDOLike(N0, N1, N))
+ return Combined;
- // If all possibly-set bits on the LHS are clear on the RHS, return an OR.
- // If all possibly-set bits on the RHS are clear on the LHS, return an OR.
- if ((RHSZero & ~LHSZero) == ~LHSZero || (LHSZero & ~RHSZero) == ~RHSZero)
- return CombineTo(N, DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1),
- DAG.getNode(ISD::CARRY_FALSE,
- SDLoc(N), MVT::Glue));
+ if (SDValue Combined = visitUADDOLike(N1, N0, N))
+ return Combined;
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
+ auto VT = N0.getValueType();
+
+ // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
+ // If Y + 1 cannot overflow.
+ if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
+ SDValue Y = N1.getOperand(0);
+ SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
+ if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
+ N1.getOperand(2));
}
+ // (uaddo X, Carry) -> (addcarry X, 0, Carry)
+ if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
+ if (SDValue Carry = getAsCarry(TLI, N1))
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
+ DAG.getConstant(0, SDLoc(N), VT), Carry);
+
return SDValue();
}
@@ -1881,6 +2188,90 @@ SDValue DAGCombiner::visitADDE(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue CarryIn = N->getOperand(2);
+ SDLoc DL(N);
+
+ // canonicalize constant to RHS
+ ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
+ ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+ if (N0C && !N1C)
+ return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
+
+ // fold (addcarry x, y, false) -> (uaddo x, y)
+ if (isNullConstant(CarryIn))
+ return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
+
+ // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
+ if (isNullConstant(N0) && isNullConstant(N1)) {
+ EVT VT = N0.getValueType();
+ EVT CarryVT = CarryIn.getValueType();
+ SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
+ AddToWorklist(CarryExt.getNode());
+ return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
+ DAG.getConstant(1, DL, VT)),
+ DAG.getConstant(0, DL, CarryVT));
+ }
+
+ if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
+ return Combined;
+
+ if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
+ return Combined;
+
+ return SDValue();
+}
+
+SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
+ SDNode *N) {
+ // Iff the flag result is dead:
+ // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
+ if ((N0.getOpcode() == ISD::ADD ||
+ (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0)) &&
+ isNullConstant(N1) && !N->hasAnyUseOfValue(1))
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
+ N0.getOperand(0), N0.getOperand(1), CarryIn);
+
+ /**
+ * When one of the addcarry arguments is itself a carry, we may be facing
+ * a diamond carry propagation. In that case we try to transform the DAG
+ * to ensure linear carry propagation if that is possible.
+ *
+ * We are trying to get:
+ * (addcarry X, 0, (addcarry A, B, Z):Carry)
+ */
+ if (auto Y = getAsCarry(TLI, N1)) {
+ /**
+ * (uaddo A, B)
+ * / \
+ * Carry Sum
+ * | \
+ * | (addcarry *, 0, Z)
+ * | /
+ * \ Carry
+ * | /
+ * (addcarry X, *, *)
+ */
+ if (Y.getOpcode() == ISD::UADDO &&
+ CarryIn.getResNo() == 1 &&
+ CarryIn.getOpcode() == ISD::ADDCARRY &&
+ isNullConstant(CarryIn.getOperand(1)) &&
+ CarryIn.getOperand(0) == Y.getValue(0)) {
+ auto NewY = DAG.getNode(ISD::ADDCARRY, SDLoc(N), Y->getVTList(),
+ Y.getOperand(0), Y.getOperand(1),
+ CarryIn.getOperand(2));
+ AddToWorklist(NewY.getNode());
+ return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
+ DAG.getConstant(0, SDLoc(N), N0.getValueType()),
+ NewY.getValue(1));
+ }
+ }
+
+ return SDValue();
+}
+
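[Illustration — editorial sketch, not part of the patch. The diamond the doc comment above draws, written linearly.]
  //   S = (uaddo A, B)
  //   T = (addcarry S:0, 0, Z)
  //   N = (addcarry X, S:1, T:1)     ; carries from two different adders
  // becomes
  //   S' = (addcarry A, B, Z)
  //   N' = (addcarry X, 0, S':1)     ; one linear carry chain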
// Since it may not be valid to emit a fold to zero for vector initializers
// check if we can before folding.
static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
@@ -1920,6 +2311,9 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
N1.getNode());
}
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
// fold (sub x, c) -> (add x, -c)
@@ -1944,13 +2338,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
}
// 0 - X --> 0 if the sub is NUW.
- if (N->getFlags()->hasNoUnsignedWrap())
+ if (N->getFlags().hasNoUnsignedWrap())
return N0;
- if (DAG.MaskedValueIsZero(N1, ~APInt::getSignBit(BitWidth))) {
+ if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
// N1 is either 0 or the minimum signed value. If the sub is NSW, then
// N1 must be 0 because negating the minimum signed value is undefined.
- if (N->getFlags()->hasNoSignedWrap())
+ if (N->getFlags().hasNoSignedWrap())
return N0;
// 0 - X --> X if X is 0 or the minimum signed value.
@@ -2066,6 +2460,38 @@ SDValue DAGCombiner::visitSUBC(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitUSUBO(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N0.getValueType();
+ if (VT.isVector())
+ return SDValue();
+
+ EVT CarryVT = N->getValueType(1);
+ SDLoc DL(N);
+
+ // If the flag result is dead, turn this into a SUB.
+ if (!N->hasAnyUseOfValue(1))
+ return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
+ DAG.getUNDEF(CarryVT));
+
+ // fold (usubo x, x) -> 0 + no borrow
+ if (N0 == N1)
+ return CombineTo(N, DAG.getConstant(0, DL, VT),
+ DAG.getConstant(0, DL, CarryVT));
+
+ // fold (usubo x, 0) -> x + no borrow
+ if (isNullConstant(N1))
+ return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
+
+ // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
+ if (isAllOnesConstant(N0))
+ return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
+ DAG.getConstant(0, DL, CarryVT));
+
+ return SDValue();
+}
+
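[Illustration — editorial sketch, not part of the patch. The (usubo -1, x) canonicalization holds because subtracting from the all-ones value can never borrow; in 8 bits:]
  //   0xFF - x == 0xFF ^ x == ~x   for every x
  // hence the XOR result paired with a constant-0 carry.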
SDValue DAGCombiner::visitSUBE(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -2078,6 +2504,18 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue CarryIn = N->getOperand(2);
+
+ // fold (subcarry x, y, false) -> (usubo x, y)
+ if (isNullConstant(CarryIn))
+ return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
+
+ return SDValue();
+}
+
SDValue DAGCombiner::visitMUL(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -2122,15 +2560,19 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
!DAG.isConstantIntBuildVectorOrConstantInt(N1))
return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
// fold (mul x, 0) -> 0
- if (N1IsConst && ConstValue1 == 0)
+ if (N1IsConst && ConstValue1.isNullValue())
return N1;
// We require a splat of the entire scalar bit width for non-contiguous
// bit patterns.
bool IsFullSplat =
ConstValue1.getBitWidth() == VT.getScalarSizeInBits();
// fold (mul x, 1) -> x
- if (N1IsConst && ConstValue1 == 1 && IsFullSplat)
+ if (N1IsConst && ConstValue1.isOneValue() && IsFullSplat)
return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (mul x, -1) -> 0-x
if (N1IsConst && ConstValue1.isAllOnesValue()) {
SDLoc DL(N);
@@ -2297,6 +2739,23 @@ SDValue DAGCombiner::useDivRem(SDNode *Node) {
return combined;
}
+static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ if (DAG.isUndef(N->getOpcode(), {N0, N1}))
+ return DAG.getUNDEF(VT);
+
+ // undef / X -> 0
+ // undef % X -> 0
+ if (N0.isUndef())
+ return DAG.getConstant(0, DL, VT);
+
+ return SDValue();
+}
+
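[Illustration — editorial sketch, not part of the patch. simplifyDivRem centralizes the undef rules that the individual div/rem visitors below stop duplicating.]
  //   X / undef, X % undef  -> undef   (via DAG.isUndef)
  //   undef / X, undef % X  -> 0       (0 is a valid refinement of undef)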
SDValue DAGCombiner::visitSDIV(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -2319,8 +2778,13 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
return N0;
// fold (sdiv X, -1) -> 0-X
if (N1C && N1C->isAllOnesValue())
- return DAG.getNode(ISD::SUB, DL, VT,
- DAG.getConstant(0, DL, VT), N0);
+ return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
+
+ if (SDValue V = simplifyDivRem(N, DAG))
+ return V;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
// If we know the sign bits of both operands are zero, strength reduce to a
// udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
@@ -2332,9 +2796,8 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
// better results in that case. The target-specific lowering should learn how
// to handle exact sdivs efficiently.
if (N1C && !N1C->isNullValue() && !N1C->isOpaque() &&
- !cast<BinaryWithFlagsSDNode>(N)->Flags.hasExact() &&
- (N1C->getAPIntValue().isPowerOf2() ||
- (-N1C->getAPIntValue()).isPowerOf2())) {
+ !N->getFlags().hasExact() && (N1C->getAPIntValue().isPowerOf2() ||
+ (-N1C->getAPIntValue()).isPowerOf2())) {
// Target-specific implementation of sdiv x, pow2.
if (SDValue Res = BuildSDIVPow2(N))
return Res;
@@ -2372,7 +2835,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
// If integer divide is expensive and we satisfy the requirements, emit an
// alternate sequence. Targets may check function attributes for size/speed
// trade-offs.
- AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
+ AttributeList Attr = DAG.getMachineFunction().getFunction()->getAttributes();
if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr))
if (SDValue Op = BuildSDIV(N))
return Op;
@@ -2384,13 +2847,6 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
if (SDValue DivRem = useDivRem(N))
return DivRem;
- // undef / X -> 0
- if (N0.isUndef())
- return DAG.getConstant(0, DL, VT);
- // X / undef -> undef
- if (N1.isUndef())
- return N1;
-
return SDValue();
}
@@ -2414,6 +2870,12 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
N0C, N1C))
return Folded;
+ if (SDValue V = simplifyDivRem(N, DAG))
+ return V;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (udiv x, (1 << c)) -> x >>u c
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1)) {
@@ -2444,7 +2906,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
}
// fold (udiv x, c) -> alternate
- AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
+ AttributeList Attr = DAG.getMachineFunction().getFunction()->getAttributes();
if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr))
if (SDValue Op = BuildUDIV(N))
return Op;
@@ -2456,13 +2918,6 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
if (SDValue DivRem = useDivRem(N))
return DivRem;
- // undef / X -> 0
- if (N0.isUndef())
- return DAG.getConstant(0, DL, VT);
- // X / undef -> undef
- if (N1.isUndef())
- return N1;
-
return SDValue();
}
@@ -2482,32 +2937,35 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C))
return Folded;
+ if (SDValue V = simplifyDivRem(N, DAG))
+ return V;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
if (isSigned) {
// If we know the sign bits of both operands are zero, strength reduce to a
// urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
} else {
- // fold (urem x, pow2) -> (and x, pow2-1)
+ SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
if (DAG.isKnownToBeAPowerOfTwo(N1)) {
- APInt NegOne = APInt::getAllOnesValue(VT.getScalarSizeInBits());
- SDValue Add =
- DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getConstant(NegOne, DL, VT));
+ // fold (urem x, pow2) -> (and x, pow2-1)
+ SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
AddToWorklist(Add.getNode());
return DAG.getNode(ISD::AND, DL, VT, N0, Add);
}
- // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
if (N1.getOpcode() == ISD::SHL &&
DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
- APInt NegOne = APInt::getAllOnesValue(VT.getScalarSizeInBits());
- SDValue Add =
- DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getConstant(NegOne, DL, VT));
+ // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
+ SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
AddToWorklist(Add.getNode());
return DAG.getNode(ISD::AND, DL, VT, N0, Add);
}
}
- AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
+ AttributeList Attr = DAG.getMachineFunction().getFunction()->getAttributes();
// If X/C can be simplified by the division-by-constant logic, lower
// X%C to the equivalent of X-X/C*C.
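[Illustration — editorial sketch, not part of the patch. The power-of-two urem folds above, with concrete values.]
  //   x % 8        -->  x & 7
  //   x % (8 << y) -->  x & ((8 << y) - 1)
  // Both are built as and(x, add(pow2, NegOne)) using the hoisted all-ones
  // constant.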
@@ -2536,13 +2994,6 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
if (SDValue DivRem = useDivRem(N))
return DivRem.getValue(1);
- // undef % X -> 0
- if (N0.isUndef())
- return DAG.getConstant(0, DL, VT);
- // X % undef -> undef
- if (N1.isUndef())
- return N1;
-
return SDValue();
}
@@ -2932,95 +3383,139 @@ SDValue DAGCombiner::SimplifyBinOpWithSameOpcodeHands(SDNode *N) {
return SDValue();
}
+/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
+SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
+ const SDLoc &DL) {
+ SDValue LL, LR, RL, RR, N0CC, N1CC;
+ if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
+ !isSetCCEquivalent(N1, RL, RR, N1CC))
+ return SDValue();
+
+ assert(N0.getValueType() == N1.getValueType() &&
+ "Unexpected operand types for bitwise logic op");
+ assert(LL.getValueType() == LR.getValueType() &&
+ RL.getValueType() == RR.getValueType() &&
+ "Unexpected operand types for setcc");
+
+ // If we're here post-legalization or the logic op type is not i1, the logic
+ // op type must match a setcc result type. Also, all folds require new
+ // operations on the left and right operands, so those types must match.
+ EVT VT = N0.getValueType();
+ EVT OpVT = LL.getValueType();
+ if (LegalOperations || VT != MVT::i1)
+ if (VT != getSetCCResultType(OpVT))
+ return SDValue();
+ if (OpVT != RL.getValueType())
+ return SDValue();
+
+ ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
+ ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
+ bool IsInteger = OpVT.isInteger();
+ if (LR == RR && CC0 == CC1 && IsInteger) {
+ bool IsZero = isNullConstantOrNullSplatConstant(LR);
+ bool IsNeg1 = isAllOnesConstantOrAllOnesSplatConstant(LR);
+
+ // All bits clear?
+ bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
+ // All sign bits clear?
+ bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
+ // Any bits set?
+ bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
+ // Any sign bits set?
+ bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
+
+ // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
+ // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
+ // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
+ // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
+ if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
+ SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
+ AddToWorklist(Or.getNode());
+ return DAG.getSetCC(DL, VT, Or, LR, CC1);
+ }
+
+ // All bits set?
+ bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
+ // All sign bits set?
+ bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
+ // Any bits clear?
+ bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
+ // Any sign bits clear?
+ bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
+
+ // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
+ // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
+ // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
+ // (or (setgt X, -1), (setgt Y, -1)) --> (setgt (and X, Y), -1)
+ if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
+ SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
+ AddToWorklist(And.getNode());
+ return DAG.getSetCC(DL, VT, And, LR, CC1);
+ }
+ }
+
+ // TODO: What is the 'or' equivalent of this fold?
+ // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
+ if (IsAnd && LL == RL && CC0 == CC1 && IsInteger && CC0 == ISD::SETNE &&
+ ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
+ (isAllOnesConstant(LR) && isNullConstant(RR)))) {
+ SDValue One = DAG.getConstant(1, DL, OpVT);
+ SDValue Two = DAG.getConstant(2, DL, OpVT);
+ SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
+ AddToWorklist(Add.getNode());
+ return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
+ }
+
+ // Try more general transforms if the predicates match and the only user of
+ // the compares is the 'and' or 'or'.
+ if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
+ N0.hasOneUse() && N1.hasOneUse()) {
+ // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
+ // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
+ if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
+ SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
+ SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
+ SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
+ SDValue Zero = DAG.getConstant(0, DL, OpVT);
+ return DAG.getSetCC(DL, VT, Or, Zero, CC1);
+ }
+ }
+
+ // Canonicalize equivalent operands to LL == RL.
+ if (LL == RR && LR == RL) {
+ CC1 = ISD::getSetCCSwappedOperands(CC1);
+ std::swap(RL, RR);
+ }
+
+ // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
+ // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
+ if (LL == RL && LR == RR) {
+ ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, IsInteger)
+ : ISD::getSetCCOrOperation(CC0, CC1, IsInteger);
+ if (NewCC != ISD::SETCC_INVALID &&
+ (!LegalOperations ||
+ (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
+ TLI.isOperationLegal(ISD::SETCC, OpVT))))
+ return DAG.getSetCC(DL, VT, LL, LR, NewCC);
+ }
+
+ return SDValue();
+}
+
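[Illustration — editorial sketch, not part of the patch. One row of the fold tables above, spelled out.]
  //   and (seteq X, 0), (seteq Y, 0)  -->  seteq (or X, Y), 0
  // X and Y are both zero exactly when X|Y is zero, so two compares plus a
  // logic op become one OR plus one compare; the other AndEqZero/OrNeZero/
  // ... flags pick the matching OR/AND form for each predicate-and-constant
  // pair.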
/// This contains all DAGCombine rules which reduce two values combined by
/// an And operation to a single value. This makes them reusable in the context
/// of visitSELECT(). Rules involving constants are not included as
/// visitSELECT() already handles those cases.
-SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1,
- SDNode *LocReference) {
+SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
EVT VT = N1.getValueType();
+ SDLoc DL(N);
// fold (and x, undef) -> 0
if (N0.isUndef() || N1.isUndef())
- return DAG.getConstant(0, SDLoc(LocReference), VT);
- // fold (and (setcc x), (setcc y)) -> (setcc (and x, y))
- SDValue LL, LR, RL, RR, CC0, CC1;
- if (isSetCCEquivalent(N0, LL, LR, CC0) && isSetCCEquivalent(N1, RL, RR, CC1)){
- ISD::CondCode Op0 = cast<CondCodeSDNode>(CC0)->get();
- ISD::CondCode Op1 = cast<CondCodeSDNode>(CC1)->get();
-
- if (LR == RR && isa<ConstantSDNode>(LR) && Op0 == Op1 &&
- LL.getValueType().isInteger()) {
- // fold (and (seteq X, 0), (seteq Y, 0)) -> (seteq (or X, Y), 0)
- if (isNullConstant(LR) && Op1 == ISD::SETEQ) {
- EVT CCVT = getSetCCResultType(LR.getValueType());
- if (VT == CCVT || (!LegalOperations && VT == MVT::i1)) {
- SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(N0),
- LR.getValueType(), LL, RL);
- AddToWorklist(ORNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1);
- }
- }
- if (isAllOnesConstant(LR)) {
- // fold (and (seteq X, -1), (seteq Y, -1)) -> (seteq (and X, Y), -1)
- if (Op1 == ISD::SETEQ) {
- EVT CCVT = getSetCCResultType(LR.getValueType());
- if (VT == CCVT || (!LegalOperations && VT == MVT::i1)) {
- SDValue ANDNode = DAG.getNode(ISD::AND, SDLoc(N0),
- LR.getValueType(), LL, RL);
- AddToWorklist(ANDNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ANDNode, LR, Op1);
- }
- }
- // fold (and (setgt X, -1), (setgt Y, -1)) -> (setgt (or X, Y), -1)
- if (Op1 == ISD::SETGT) {
- EVT CCVT = getSetCCResultType(LR.getValueType());
- if (VT == CCVT || (!LegalOperations && VT == MVT::i1)) {
- SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(N0),
- LR.getValueType(), LL, RL);
- AddToWorklist(ORNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1);
- }
- }
- }
- }
- // Simplify (and (setne X, 0), (setne X, -1)) -> (setuge (add X, 1), 2)
- if (LL == RL && isa<ConstantSDNode>(LR) && isa<ConstantSDNode>(RR) &&
- Op0 == Op1 && LL.getValueType().isInteger() &&
- Op0 == ISD::SETNE && ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
- (isAllOnesConstant(LR) && isNullConstant(RR)))) {
- EVT CCVT = getSetCCResultType(LL.getValueType());
- if (VT == CCVT || (!LegalOperations && VT == MVT::i1)) {
- SDLoc DL(N0);
- SDValue ADDNode = DAG.getNode(ISD::ADD, DL, LL.getValueType(),
- LL, DAG.getConstant(1, DL,
- LL.getValueType()));
- AddToWorklist(ADDNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ADDNode,
- DAG.getConstant(2, DL, LL.getValueType()),
- ISD::SETUGE);
- }
- }
- // canonicalize equivalent to ll == rl
- if (LL == RR && LR == RL) {
- Op1 = ISD::getSetCCSwappedOperands(Op1);
- std::swap(RL, RR);
- }
- if (LL == RL && LR == RR) {
- bool isInteger = LL.getValueType().isInteger();
- ISD::CondCode Result = ISD::getSetCCAndOperation(Op0, Op1, isInteger);
- if (Result != ISD::SETCC_INVALID &&
- (!LegalOperations ||
- (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) &&
- TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) {
- EVT CCVT = getSetCCResultType(LL.getValueType());
- if (N0.getValueType() == CCVT ||
- (!LegalOperations && N0.getValueType() == MVT::i1))
- return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
- LL, LR, Result);
- }
- }
- }
+ return DAG.getConstant(0, DL, VT);
+
+ if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
+ return V;
if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
VT.getSizeInBits() <= 64) {
@@ -3037,13 +3532,13 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1,
if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
ADDC |= Mask;
if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
- SDLoc DL(N0);
+ SDLoc DL0(N0);
SDValue NewAdd =
- DAG.getNode(ISD::ADD, DL, VT,
+ DAG.getNode(ISD::ADD, DL0, VT,
N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
CombineTo(N0.getNode(), NewAdd);
// Return N so it doesn't get rechecked!
- return SDValue(LocReference, 0);
+ return SDValue(N, 0);
}
}
}
@@ -3068,7 +3563,7 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1,
unsigned MaskBits = AndMask.countTrailingOnes();
EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
- if (APIntOps::isMask(AndMask) &&
+ if (AndMask.isMask() &&
// Required bits must not span the two halves of the integer and
// must fit in the half size type.
(ShiftBits + MaskBits <= Size / 2) &&
@@ -3108,7 +3603,7 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
bool &NarrowLoad) {
uint32_t ActiveBits = AndC->getAPIntValue().getActiveBits();
- if (ActiveBits == 0 || !APIntOps::isMask(ActiveBits, AndC->getAPIntValue()))
+ if (ActiveBits == 0 || !AndC->getAPIntValue().isMask(ActiveBits))
return false;
ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
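[Illustration — editorial sketch, not part of the patch. The new APInt::isMask member replaces the old APIntOps free function; it asks whether the value is exactly N low set bits.]
  //   APInt(32, 0x0000FFFF).isMask(16)  -> true
  //   APInt(32, 0x0000FF00).isMask(16)  -> false  (ones not at the bottom)
  // which is what allows the AND here to be folded into a narrower
  // extending load.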
@@ -3191,13 +3686,17 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(BitWidth)))
return DAG.getConstant(0, SDLoc(N), VT);
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// reassociate and
if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1))
return RAND;
// fold (and (or x, C), D) -> D if (C & D) == D
if (N1C && N0.getOpcode() == ISD::OR)
if (ConstantSDNode *ORI = isConstOrConstSplat(N0.getOperand(1)))
- if ((ORI->getAPIntValue() & N1C->getAPIntValue()) == N1C->getAPIntValue())
+ if (N1C->getAPIntValue().isSubsetOf(ORI->getAPIntValue()))
return N1;
// fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
@@ -3299,6 +3798,10 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
// preserve semantics once we get rid of the AND.
SDValue NewLoad(Load, 0);
+
+ // Fold the AND away. NewLoad may get replaced immediately.
+ CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
+
if (Load->getExtensionType() == ISD::EXTLOAD) {
NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
Load->getValueType(0), SDLoc(Load),
@@ -3316,10 +3819,6 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
}
}
- // Fold the AND away, taking care not to fold to the old load node if we
- // replaced it.
- CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
-
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
@@ -3398,9 +3897,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// Note: the SimplifyDemandedBits fold below can make an information-losing
// transform, and then we have no way to find this better fold.
if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
- ConstantSDNode *SubLHS = isConstOrConstSplat(N0.getOperand(0));
- SDValue SubRHS = N0.getOperand(1);
- if (SubLHS && SubLHS->isNullValue()) {
+ if (isNullConstantOrNullSplatConstant(N0.getOperand(0))) {
+ SDValue SubRHS = N0.getOperand(1);
if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
return SubRHS;
@@ -3412,7 +3910,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
// fold (and (sra)) -> (and (srl)) when possible.
- if (!VT.isVector() && SimplifyDemandedBits(SDValue(N, 0)))
+ if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
// fold (zext_inreg (extload x)) -> (zextload x)
@@ -3473,7 +3971,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
EVT VT = N->getValueType(0);
if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
return SDValue();
- if (!TLI.isOperationLegal(ISD::BSWAP, VT))
+ if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
return SDValue();
// Recognize (and (shl a, 8), 0xff), (and (srl a, 8), 0xff00)
@@ -3585,27 +4083,36 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
return false;
- ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
+ SDValue N0 = N.getOperand(0);
+ unsigned Opc0 = N0.getOpcode();
+ if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
+ return false;
+
+ ConstantSDNode *N1C = nullptr;
+ // SHL or SRL: look upstream for AND mask operand
+ if (Opc == ISD::AND)
+ N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
+ else if (Opc0 == ISD::AND)
+ N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
if (!N1C)
return false;
- unsigned Num;
+ unsigned MaskByteOffset;
switch (N1C->getZExtValue()) {
default:
return false;
- case 0xFF: Num = 0; break;
- case 0xFF00: Num = 1; break;
- case 0xFF0000: Num = 2; break;
- case 0xFF000000: Num = 3; break;
+ case 0xFF: MaskByteOffset = 0; break;
+ case 0xFF00: MaskByteOffset = 1; break;
+ case 0xFF0000: MaskByteOffset = 2; break;
+ case 0xFF000000: MaskByteOffset = 3; break;
}
// Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
- SDValue N0 = N.getOperand(0);
if (Opc == ISD::AND) {
- if (Num == 0 || Num == 2) {
+ if (MaskByteOffset == 0 || MaskByteOffset == 2) {
// (x >> 8) & 0xff
// (x >> 8) & 0xff0000
- if (N0.getOpcode() != ISD::SRL)
+ if (Opc0 != ISD::SRL)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
if (!C || C->getZExtValue() != 8)
@@ -3613,7 +4120,7 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
} else {
// (x << 8) & 0xff00
// (x << 8) & 0xff000000
- if (N0.getOpcode() != ISD::SHL)
+ if (Opc0 != ISD::SHL)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
if (!C || C->getZExtValue() != 8)
@@ -3622,7 +4129,7 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
} else if (Opc == ISD::SHL) {
// (x & 0xff) << 8
// (x & 0xff0000) << 8
- if (Num != 0 && Num != 2)
+ if (MaskByteOffset != 0 && MaskByteOffset != 2)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
if (!C || C->getZExtValue() != 8)
@@ -3630,17 +4137,17 @@ static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
} else { // Opc == ISD::SRL
// (x & 0xff00) >> 8
// (x & 0xff000000) >> 8
- if (Num != 1 && Num != 3)
+ if (MaskByteOffset != 1 && MaskByteOffset != 3)
return false;
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
if (!C || C->getZExtValue() != 8)
return false;
}
- if (Parts[Num])
+ if (Parts[MaskByteOffset])
return false;
- Parts[Num] = N0.getOperand(0).getNode();
+ Parts[MaskByteOffset] = N0.getOperand(0).getNode();
return true;
}
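
Aside (editor's sketch, not part of this patch): the (and/shl/srl) elements matched above assemble into a half-word byte swap on i32. A minimal standalone instance of that pattern:

    #include <cassert>
    #include <cstdint>

    static uint32_t bswapHWords(uint32_t X) {
      // Swap bytes within each 16-bit half: bytes 0<->1 and 2<->3.
      return ((X & 0x00FF00FFu) << 8) | ((X & 0xFF00FF00u) >> 8);
    }

    int main() {
      assert(bswapHWords(0x11223344u) == 0x22114433u);
      return 0;
    }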
@@ -3657,7 +4164,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
EVT VT = N->getValueType(0);
if (VT != MVT::i32)
return SDValue();
- if (!TLI.isOperationLegal(ISD::BSWAP, VT))
+ if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
return SDValue();
// Look for either
@@ -3672,18 +4179,16 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
if (N1.getOpcode() == ISD::OR &&
N00.getNumOperands() == 2 && N01.getNumOperands() == 2) {
// (or (or (and), (and)), (or (and), (and)))
- SDValue N000 = N00.getOperand(0);
- if (!isBSwapHWordElement(N000, Parts))
+ if (!isBSwapHWordElement(N00, Parts))
return SDValue();
- SDValue N001 = N00.getOperand(1);
- if (!isBSwapHWordElement(N001, Parts))
+ if (!isBSwapHWordElement(N01, Parts))
return SDValue();
- SDValue N010 = N01.getOperand(0);
- if (!isBSwapHWordElement(N010, Parts))
+ SDValue N10 = N1.getOperand(0);
+ if (!isBSwapHWordElement(N10, Parts))
return SDValue();
- SDValue N011 = N01.getOperand(1);
- if (!isBSwapHWordElement(N011, Parts))
+ SDValue N11 = N1.getOperand(1);
+ if (!isBSwapHWordElement(N11, Parts))
return SDValue();
} else {
// (or (or (or (and), (and)), (and)), (and))
@@ -3723,65 +4228,16 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
/// This contains all DAGCombine rules which reduce two values combined by
/// an Or operation to a single value \see visitANDLike().
-SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) {
+SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
EVT VT = N1.getValueType();
+ SDLoc DL(N);
+
// fold (or x, undef) -> -1
- if (!LegalOperations &&
- (N0.isUndef() || N1.isUndef())) {
- EVT EltVT = VT.isVector() ? VT.getVectorElementType() : VT;
- return DAG.getConstant(APInt::getAllOnesValue(EltVT.getSizeInBits()),
- SDLoc(LocReference), VT);
- }
- // fold (or (setcc x), (setcc y)) -> (setcc (or x, y))
- SDValue LL, LR, RL, RR, CC0, CC1;
- if (isSetCCEquivalent(N0, LL, LR, CC0) && isSetCCEquivalent(N1, RL, RR, CC1)){
- ISD::CondCode Op0 = cast<CondCodeSDNode>(CC0)->get();
- ISD::CondCode Op1 = cast<CondCodeSDNode>(CC1)->get();
-
- if (LR == RR && Op0 == Op1 && LL.getValueType().isInteger()) {
- // fold (or (setne X, 0), (setne Y, 0)) -> (setne (or X, Y), 0)
- // fold (or (setlt X, 0), (setlt Y, 0)) -> (setne (or X, Y), 0)
- if (isNullConstant(LR) && (Op1 == ISD::SETNE || Op1 == ISD::SETLT)) {
- EVT CCVT = getSetCCResultType(LR.getValueType());
- if (VT == CCVT || (!LegalOperations && VT == MVT::i1)) {
- SDValue ORNode = DAG.getNode(ISD::OR, SDLoc(LR),
- LR.getValueType(), LL, RL);
- AddToWorklist(ORNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ORNode, LR, Op1);
- }
- }
- // fold (or (setne X, -1), (setne Y, -1)) -> (setne (and X, Y), -1)
- // fold (or (setgt X, -1), (setgt Y -1)) -> (setgt (and X, Y), -1)
- if (isAllOnesConstant(LR) && (Op1 == ISD::SETNE || Op1 == ISD::SETGT)) {
- EVT CCVT = getSetCCResultType(LR.getValueType());
- if (VT == CCVT || (!LegalOperations && VT == MVT::i1)) {
- SDValue ANDNode = DAG.getNode(ISD::AND, SDLoc(LR),
- LR.getValueType(), LL, RL);
- AddToWorklist(ANDNode.getNode());
- return DAG.getSetCC(SDLoc(LocReference), VT, ANDNode, LR, Op1);
- }
- }
- }
- // canonicalize equivalent to ll == rl
- if (LL == RR && LR == RL) {
- Op1 = ISD::getSetCCSwappedOperands(Op1);
- std::swap(RL, RR);
- }
- if (LL == RL && LR == RR) {
- bool isInteger = LL.getValueType().isInteger();
- ISD::CondCode Result = ISD::getSetCCOrOperation(Op0, Op1, isInteger);
- if (Result != ISD::SETCC_INVALID &&
- (!LegalOperations ||
- (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) &&
- TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) {
- EVT CCVT = getSetCCResultType(LL.getValueType());
- if (N0.getValueType() == CCVT ||
- (!LegalOperations && N0.getValueType() == MVT::i1))
- return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
- LL, LR, Result);
- }
- }
- }
+ if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
+ return DAG.getAllOnesConstant(DL, VT);
+
+ if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
+ return V;
// (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
@@ -3802,7 +4258,6 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) {
DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
N0.getOperand(0), N1.getOperand(0));
- SDLoc DL(LocReference);
return DAG.getNode(ISD::AND, DL, VT, X,
DAG.getConstant(LHSMask | RHSMask, DL, VT));
}
@@ -3818,7 +4273,7 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) {
(N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) {
SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
N0.getOperand(1), N1.getOperand(1));
- return DAG.getNode(ISD::AND, SDLoc(LocReference), VT, N0.getOperand(0), X);
+ return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
}
return SDValue();
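
Aside (editor's sketch, not part of this patch): the (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C1|C2) rewrite above is only sound when the MaskedValueIsZero checks hold. A minimal standalone instance where they do, using disjoint masks:

    #include <cassert>
    #include <cstdint>

    int main() {
      // X carries no bits outside C1 and Y none outside C2, so both
      // MaskedValueIsZero preconditions of the fold hold.
      uint32_t X = 0x000000ABu, C1 = 0x000000FFu;
      uint32_t Y = 0x00CD0000u, C2 = 0x00FF0000u;
      assert(((X & C1) | (Y & C2)) == ((X | Y) & (C1 | C2)));
      return 0;
    }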
@@ -3847,14 +4302,10 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
// fold (or x, -1) -> -1, vector edition
if (ISD::isBuildVectorAllOnes(N0.getNode()))
// do not return N0, because undef node may exist in N0
- return DAG.getConstant(
- APInt::getAllOnesValue(N0.getScalarValueSizeInBits()), SDLoc(N),
- N0.getValueType());
+ return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
if (ISD::isBuildVectorAllOnes(N1.getNode()))
// do not return N1, because undef node may exist in N1
- return DAG.getConstant(
- APInt::getAllOnesValue(N1.getScalarValueSizeInBits()), SDLoc(N),
- N1.getValueType());
+ return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
// fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
// Do this only if the resulting shuffle is legal.
@@ -3867,7 +4318,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
// Ensure both shuffles have a zero input.
- if ((ZeroN00 || ZeroN01) && (ZeroN10 || ZeroN11)) {
+ if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
const ShuffleVectorSDNode *SV0 = cast<ShuffleVectorSDNode>(N0);
@@ -3939,6 +4390,10 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
// fold (or x, -1) -> -1
if (isAllOnesConstant(N1))
return N1;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (or x, c) -> c iff (x & ~c) == 0
if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
return N1;
@@ -3955,20 +4410,22 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
// reassociate or
if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1))
return ROR;
+
// Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
- // iff (c1 & c2) == 0.
- if (N1C && N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
- isa<ConstantSDNode>(N0.getOperand(1))) {
- ConstantSDNode *C1 = cast<ConstantSDNode>(N0.getOperand(1));
- if ((C1->getAPIntValue() & N1C->getAPIntValue()) != 0) {
- if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
- N1C, C1))
- return DAG.getNode(
- ISD::AND, SDLoc(N), VT,
- DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1), COR);
- return SDValue();
+ // iff (c1 & c2) != 0.
+ if (N1C && N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse()) {
+ if (ConstantSDNode *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
+ if (C1->getAPIntValue().intersects(N1C->getAPIntValue())) {
+ if (SDValue COR =
+ DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT, N1C, C1))
+ return DAG.getNode(
+ ISD::AND, SDLoc(N), VT,
+ DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1), COR);
+ return SDValue();
+ }
}
}
+
// Simplify: (or (op x...), (op y...)) -> (op (or x, y))
if (N0.getOpcode() == N1.getOpcode())
if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
@@ -3978,9 +4435,11 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
return SDValue(Rot, 0);
+ if (SDValue Load = MatchLoadCombine(N))
+ return Load;
+
// Simplify the operands using demanded-bits information.
- if (!VT.isVector() &&
- SimplifyDemandedBits(SDValue(N, 0)))
+ if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
return SDValue();
@@ -4134,6 +4593,20 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
return nullptr;
}
+// Return true if Left + Right == Sum (constant or constant splat vector).
+static bool sumMatchConstant(SDValue Left, SDValue Right, unsigned Sum,
+ SelectionDAG &DAG, const SDLoc &DL) {
+ EVT ShiftVT = Left.getValueType();
+ if (ShiftVT != Right.getValueType()) return false;
+
+ SDValue ShiftSum = DAG.FoldConstantArithmetic(ISD::ADD, DL, ShiftVT,
+ Left.getNode(), Right.getNode());
+ if (!ShiftSum) return false;
+
+ ConstantSDNode *CSum = isConstOrConstSplat(ShiftSum);
+ return CSum && CSum->getZExtValue() == Sum;
+}
+
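Aside (editor's sketch, not part of this patch): the rotate idiom this helper validates only works when the two constant shift amounts sum to the element width, as in this standalone i32 sketch:

    #include <cassert>
    #include <cstdint>

    static uint32_t rotlViaShifts(uint32_t X, unsigned C1) {
      assert(C1 > 0 && C1 < 32 && "shift amounts must stay in range");
      unsigned C2 = 32 - C1;           // C1 + C2 == EltSizeInBits
      return (X << C1) | (X >> C2);    // (or (shl x, C1), (srl x, C2))
    }

    int main() {
      assert(rotlViaShifts(0x80000001u, 1) == 0x00000003u);
      assert(rotlViaShifts(0x12345678u, 8) == 0x34567812u);
      return 0;
    }
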
// MatchRotate - Handle an 'or' of two operands. If this is one of the many
// idioms for rotate, and if the target supports rotation instructions, generate
// a rot[lr].
@@ -4179,31 +4652,24 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
// fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
// fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
- if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) {
- uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue();
- uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue();
- if ((LShVal + RShVal) != EltSizeInBits)
- return nullptr;
-
+ if (sumMatchConstant(LHSShiftAmt, RHSShiftAmt, EltSizeInBits, DAG, DL)) {
SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
// If there is an AND of either shifted operand, apply it to the result.
if (LHSMask.getNode() || RHSMask.getNode()) {
- APInt AllBits = APInt::getAllOnesValue(EltSizeInBits);
- SDValue Mask = DAG.getConstant(AllBits, DL, VT);
+ SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
+ SDValue Mask = AllOnes;
if (LHSMask.getNode()) {
- APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal);
+ SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
- DAG.getNode(ISD::OR, DL, VT, LHSMask,
- DAG.getConstant(RHSBits, DL, VT)));
+ DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
}
if (RHSMask.getNode()) {
- APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal);
+ SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
- DAG.getNode(ISD::OR, DL, VT, RHSMask,
- DAG.getConstant(LHSBits, DL, VT)));
+ DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
}
Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask);
@@ -4246,109 +4712,299 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
}
namespace {
-/// Helper struct to parse and store a memory address as base + index + offset.
-/// We ignore sign extensions when it is safe to do so.
-/// The following two expressions are not equivalent. To differentiate we need
-/// to store whether there was a sign extension involved in the index
-/// computation.
-/// (load (i64 add (i64 copyfromreg %c)
-/// (i64 signextend (add (i8 load %index)
-/// (i8 1))))
-/// vs
-///
-/// (load (i64 add (i64 copyfromreg %c)
-/// (i64 signextend (i32 add (i32 signextend (i8 load %index))
-/// (i32 1)))))
-struct BaseIndexOffset {
- SDValue Base;
- SDValue Index;
- int64_t Offset;
- bool IsIndexSignExt;
-
- BaseIndexOffset() : Offset(0), IsIndexSignExt(false) {}
-
- BaseIndexOffset(SDValue Base, SDValue Index, int64_t Offset,
- bool IsIndexSignExt) :
- Base(Base), Index(Index), Offset(Offset), IsIndexSignExt(IsIndexSignExt) {}
-
- bool equalBaseIndex(const BaseIndexOffset &Other) {
- return Other.Base == Base && Other.Index == Index &&
- Other.IsIndexSignExt == IsIndexSignExt;
- }
-
- /// Parses tree in Ptr for base, index, offset addresses.
- static BaseIndexOffset match(SDValue Ptr, SelectionDAG &DAG,
- int64_t PartialOffset = 0) {
- bool IsIndexSignExt = false;
-
- // Split up a folded GlobalAddress+Offset into its component parts.
- if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ptr))
- if (GA->getOpcode() == ISD::GlobalAddress && GA->getOffset() != 0) {
- return BaseIndexOffset(DAG.getGlobalAddress(GA->getGlobal(),
- SDLoc(GA),
- GA->getValueType(0),
- /*Offset=*/PartialOffset,
- /*isTargetGA=*/false,
- GA->getTargetFlags()),
- SDValue(),
- GA->getOffset(),
- IsIndexSignExt);
- }
-
- // We only can pattern match BASE + INDEX + OFFSET. If Ptr is not an ADD
- // instruction, then it could be just the BASE or everything else we don't
- // know how to handle. Just use Ptr as BASE and give up.
- if (Ptr->getOpcode() != ISD::ADD)
- return BaseIndexOffset(Ptr, SDValue(), PartialOffset, IsIndexSignExt);
-
- // We know that we have at least an ADD instruction. Try to pattern match
- // the simple case of BASE + OFFSET.
- if (isa<ConstantSDNode>(Ptr->getOperand(1))) {
- int64_t Offset = cast<ConstantSDNode>(Ptr->getOperand(1))->getSExtValue();
- return match(Ptr->getOperand(0), DAG, Offset + PartialOffset);
- }
-
- // Inside a loop the current BASE pointer is calculated using an ADD and a
- // MUL instruction. In this case Ptr is the actual BASE pointer.
- // (i64 add (i64 %array_ptr)
- // (i64 mul (i64 %induction_var)
- // (i64 %element_size)))
- if (Ptr->getOperand(1)->getOpcode() == ISD::MUL)
- return BaseIndexOffset(Ptr, SDValue(), PartialOffset, IsIndexSignExt);
-
- // Look at Base + Index + Offset cases.
- SDValue Base = Ptr->getOperand(0);
- SDValue IndexOffset = Ptr->getOperand(1);
-
- // Skip signextends.
- if (IndexOffset->getOpcode() == ISD::SIGN_EXTEND) {
- IndexOffset = IndexOffset->getOperand(0);
- IsIndexSignExt = true;
- }
-
- // Either the case of Base + Index (no offset) or something else.
- if (IndexOffset->getOpcode() != ISD::ADD)
- return BaseIndexOffset(Base, IndexOffset, PartialOffset, IsIndexSignExt);
-
- // Now we have the case of Base + Index + offset.
- SDValue Index = IndexOffset->getOperand(0);
- SDValue Offset = IndexOffset->getOperand(1);
-
- if (!isa<ConstantSDNode>(Offset))
- return BaseIndexOffset(Ptr, SDValue(), PartialOffset, IsIndexSignExt);
-
- // Ignore signextends.
- if (Index->getOpcode() == ISD::SIGN_EXTEND) {
- Index = Index->getOperand(0);
- IsIndexSignExt = true;
- } else IsIndexSignExt = false;
-
- int64_t Off = cast<ConstantSDNode>(Offset)->getSExtValue();
- return BaseIndexOffset(Base, Index, Off + PartialOffset, IsIndexSignExt);
+/// Represents the known origin of an individual byte in a load combine
+/// pattern. The byte's value is either constant zero or comes from memory.
+struct ByteProvider {
+  // For constant-zero providers, Load is set to nullptr. For memory providers,
+  // Load represents the node which loads the byte from memory.
+  // ByteOffset is the offset of the byte in the value produced by the load.
+ LoadSDNode *Load;
+ unsigned ByteOffset;
+
+ ByteProvider() : Load(nullptr), ByteOffset(0) {}
+
+ static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
+ return ByteProvider(Load, ByteOffset);
}
+ static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
+
+ bool isConstantZero() const { return !Load; }
+ bool isMemory() const { return Load; }
+
+ bool operator==(const ByteProvider &Other) const {
+ return Other.Load == Load && Other.ByteOffset == ByteOffset;
+ }
+
+private:
+ ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
+ : Load(Load), ByteOffset(ByteOffset) {}
};
+
+/// Recursively traverses the expression calculating the origin of the requested
+/// byte of the given value. Returns None if the provider can't be calculated.
+///
+/// For every value except the root of the expression, verifies that the value
+/// has exactly one use; if not, returns None. This way, if a byte's origin is
+/// returned, it is guaranteed that the values which contribute to that byte
+/// are not used outside of this expression.
+///
+/// Because the parts of the expression are not allowed to have more than one
+/// use this function iterates over trees, not DAGs. So it never visits the same
+/// node more than once.
+const Optional<ByteProvider> calculateByteProvider(SDValue Op, unsigned Index,
+ unsigned Depth,
+ bool Root = false) {
+  // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
+ if (Depth == 10)
+ return None;
+
+ if (!Root && !Op.hasOneUse())
+ return None;
+
+ assert(Op.getValueType().isScalarInteger() && "can't handle other types");
+ unsigned BitWidth = Op.getValueSizeInBits();
+ if (BitWidth % 8 != 0)
+ return None;
+ unsigned ByteWidth = BitWidth / 8;
+ assert(Index < ByteWidth && "invalid index requested");
+ (void) ByteWidth;
+
+ switch (Op.getOpcode()) {
+ case ISD::OR: {
+ auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
+ if (!LHS)
+ return None;
+ auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
+ if (!RHS)
+ return None;
+
+ if (LHS->isConstantZero())
+ return RHS;
+ if (RHS->isConstantZero())
+ return LHS;
+ return None;
+ }
+ case ISD::SHL: {
+ auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+ if (!ShiftOp)
+ return None;
+
+ uint64_t BitShift = ShiftOp->getZExtValue();
+ if (BitShift % 8 != 0)
+ return None;
+ uint64_t ByteShift = BitShift / 8;
+
+ return Index < ByteShift
+ ? ByteProvider::getConstantZero()
+ : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
+ Depth + 1);
+ }
+ case ISD::ANY_EXTEND:
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND: {
+ SDValue NarrowOp = Op->getOperand(0);
+ unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
+ if (NarrowBitWidth % 8 != 0)
+ return None;
+ uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+
+ if (Index >= NarrowByteWidth)
+ return Op.getOpcode() == ISD::ZERO_EXTEND
+ ? Optional<ByteProvider>(ByteProvider::getConstantZero())
+ : None;
+ return calculateByteProvider(NarrowOp, Index, Depth + 1);
+ }
+ case ISD::BSWAP:
+ return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
+ Depth + 1);
+ case ISD::LOAD: {
+ auto L = cast<LoadSDNode>(Op.getNode());
+ if (L->isVolatile() || L->isIndexed())
+ return None;
+
+ unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
+ if (NarrowBitWidth % 8 != 0)
+ return None;
+ uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+
+ if (Index >= NarrowByteWidth)
+ return L->getExtensionType() == ISD::ZEXTLOAD
+ ? Optional<ByteProvider>(ByteProvider::getConstantZero())
+ : None;
+ return ByteProvider::getMemory(L, Index);
+ }
+ }
+
+ return None;
+}
} // namespace
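
Aside (editor's sketch, not part of this patch): a worked scalar instance of the SHL rule in calculateByteProvider: a left shift by 8*K makes the low K bytes constant zero and re-indexes the remaining bytes into the shifted operand.

    #include <cassert>
    #include <cstdint>

    int main() {
      // For (shl x, 16) on i32: bytes 0-1 are getConstantZero(), and byte 2
      // of the result is byte (2 - ByteShift) == 0 of x.
      uint32_t X = 0x000000ABu;
      uint32_t Shifted = X << 16;
      assert(((Shifted >> 0) & 0xFF) == 0x00);
      assert(((Shifted >> 8) & 0xFF) == 0x00);
      assert(((Shifted >> 16) & 0xFF) == 0xAB);
      return 0;
    }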
+/// Match a pattern where a wide-typed scalar value is loaded by several narrow
+/// loads and combined by shifts and ors. Fold it into a single load, or a load
+/// and a BSWAP if the target supports it.
+///
+/// Assuming little endian target:
+/// i8 *a = ...
+/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
+/// =>
+/// i32 val = *((i32)a)
+///
+/// i8 *a = ...
+/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
+/// =>
+/// i32 val = BSWAP(*((i32)a))
+///
+/// TODO: This rule matches complex patterns with OR node roots and doesn't
+/// interact well with the worklist mechanism. When a part of the pattern is
+/// updated (e.g. one of the loads), its direct users are put into the
+/// worklist, but the root node of the pattern which triggers the load combine
+/// is not necessarily a direct user of the changed node. For example, once the
+/// address of the t28 load is reassociated, load combine won't be triggered:
+/// t25: i32 = add t4, Constant:i32<2>
+/// t26: i64 = sign_extend t25
+/// t27: i64 = add t2, t26
+/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
+/// t29: i32 = zero_extend t28
+/// t32: i32 = shl t29, Constant:i8<8>
+/// t33: i32 = or t23, t32
+/// As a possible fix visitLoad can check if the load can be a part of a load
+/// combine pattern and add corresponding OR roots to the worklist.
+SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
+ assert(N->getOpcode() == ISD::OR &&
+ "Can only match load combining against OR nodes");
+
+ // Handles simple types only
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+ unsigned ByteWidth = VT.getSizeInBits() / 8;
+
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  // Before legalization we can introduce too-wide illegal loads, which will
+  // later be split into legal-sized loads. This enables us to combine
+  // i64-by-i8 load patterns into a couple of i32 loads on 32-bit targets.
+ if (LegalOperations && !TLI.isOperationLegal(ISD::LOAD, VT))
+ return SDValue();
+
+ std::function<unsigned(unsigned, unsigned)> LittleEndianByteAt = [](
+ unsigned BW, unsigned i) { return i; };
+ std::function<unsigned(unsigned, unsigned)> BigEndianByteAt = [](
+ unsigned BW, unsigned i) { return BW - i - 1; };
+
+ bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
+ auto MemoryByteOffset = [&] (ByteProvider P) {
+ assert(P.isMemory() && "Must be a memory byte provider");
+ unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
+ assert(LoadBitWidth % 8 == 0 &&
+ "can only analyze providers for individual bytes not bit");
+ unsigned LoadByteWidth = LoadBitWidth / 8;
+ return IsBigEndianTarget
+ ? BigEndianByteAt(LoadByteWidth, P.ByteOffset)
+ : LittleEndianByteAt(LoadByteWidth, P.ByteOffset);
+ };
+
+ Optional<BaseIndexOffset> Base;
+ SDValue Chain;
+
+ SmallSet<LoadSDNode *, 8> Loads;
+ Optional<ByteProvider> FirstByteProvider;
+ int64_t FirstOffset = INT64_MAX;
+
+ // Check if all the bytes of the OR we are looking at are loaded from the same
+  // base address. Collect byte offsets from the Base address in ByteOffsets.
+ SmallVector<int64_t, 4> ByteOffsets(ByteWidth);
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
+ if (!P || !P->isMemory()) // All the bytes must be loaded from memory
+ return SDValue();
+
+ LoadSDNode *L = P->Load;
+ assert(L->hasNUsesOfValue(1, 0) && !L->isVolatile() && !L->isIndexed() &&
+ "Must be enforced by calculateByteProvider");
+ assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
+
+ // All loads must share the same chain
+ SDValue LChain = L->getChain();
+ if (!Chain)
+ Chain = LChain;
+ else if (Chain != LChain)
+ return SDValue();
+
+ // Loads must share the same base address
+ BaseIndexOffset Ptr = BaseIndexOffset::match(L->getBasePtr(), DAG);
+ int64_t ByteOffsetFromBase = 0;
+ if (!Base)
+ Base = Ptr;
+ else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
+ return SDValue();
+
+ // Calculate the offset of the current byte from the base address
+ ByteOffsetFromBase += MemoryByteOffset(*P);
+ ByteOffsets[i] = ByteOffsetFromBase;
+
+ // Remember the first byte load
+ if (ByteOffsetFromBase < FirstOffset) {
+ FirstByteProvider = P;
+ FirstOffset = ByteOffsetFromBase;
+ }
+
+ Loads.insert(L);
+ }
+ assert(Loads.size() > 0 && "All the bytes of the value must be loaded from "
+ "memory, so there must be at least one load which produces the value");
+ assert(Base && "Base address of the accessed memory location must be set");
+ assert(FirstOffset != INT64_MAX && "First byte offset must be set");
+
+ // Check if the bytes of the OR we are looking at match with either big or
+ // little endian value load
+ bool BigEndian = true, LittleEndian = true;
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
+ LittleEndian &= CurrentByteOffset == LittleEndianByteAt(ByteWidth, i);
+ BigEndian &= CurrentByteOffset == BigEndianByteAt(ByteWidth, i);
+ if (!BigEndian && !LittleEndian)
+ return SDValue();
+ }
+  assert((BigEndian != LittleEndian) && "should match exactly one endianness");
+ assert(FirstByteProvider && "must be set");
+
+  // Ensure that the first byte is loaded from offset zero of the first load,
+  // so that the combined value can be loaded from the first load's address.
+ if (MemoryByteOffset(*FirstByteProvider) != 0)
+ return SDValue();
+ LoadSDNode *FirstLoad = FirstByteProvider->Load;
+
+  // The node we are looking at matches the pattern; check if we can replace it
+  // with a single load, plus a bswap if needed.
+
+ // If the load needs byte swap check if the target supports it
+ bool NeedsBswap = IsBigEndianTarget != BigEndian;
+
+  // Before legalization we can introduce illegal bswaps, which will later be
+  // converted to an explicit bswap sequence. This way we end up with a single
+ // load and byte shuffling instead of several loads and byte shuffling.
+ if (NeedsBswap && LegalOperations && !TLI.isOperationLegal(ISD::BSWAP, VT))
+ return SDValue();
+
+ // Check that a load of the wide type is both allowed and fast on the target
+ bool Fast = false;
+ bool Allowed = TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+ VT, FirstLoad->getAddressSpace(),
+ FirstLoad->getAlignment(), &Fast);
+ if (!Allowed || !Fast)
+ return SDValue();
+
+ SDValue NewLoad =
+ DAG.getLoad(VT, SDLoc(N), Chain, FirstLoad->getBasePtr(),
+ FirstLoad->getPointerInfo(), FirstLoad->getAlignment());
+
+ // Transfer chain users from old loads to the new load.
+ for (LoadSDNode *L : Loads)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
+
+ return NeedsBswap ? DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad) : NewLoad;
+}
+
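Aside (editor's sketch, not part of this patch): a standalone demonstration of the equivalence MatchLoadCombine exploits, assuming a GCC/Clang host for __builtin_bswap32. The shift/or combination of four byte loads equals one wide load, up to a byte swap depending on host endianness:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    int main() {
      uint8_t A[4] = {0x11, 0x22, 0x33, 0x44};
      uint32_t Combined = (uint32_t)A[0] | ((uint32_t)A[1] << 8) |
                          ((uint32_t)A[2] << 16) | ((uint32_t)A[3] << 24);
      uint32_t Wide;
      std::memcpy(&Wide, A, sizeof(Wide));  // the single wide load
      // Little-endian host: equal as-is; big-endian host: equal after bswap.
      assert(Combined == Wide || Combined == __builtin_bswap32(Wide));
      return 0;
    }
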
SDValue DAGCombiner::visitXOR(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -4386,6 +5042,10 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
// fold (xor x, 0) -> x
if (isNullConstant(N1))
return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// reassociate xor
if (SDValue RXOR = ReassociateOps(ISD::XOR, SDLoc(N), N0, N1))
return RXOR;
@@ -4403,9 +5063,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
default:
llvm_unreachable("Unhandled SetCC Equivalent!");
case ISD::SETCC:
- return DAG.getSetCC(SDLoc(N), VT, LHS, RHS, NotCC);
+ return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
case ISD::SELECT_CC:
- return DAG.getSelectCC(SDLoc(N), LHS, RHS, N0.getOperand(2),
+ return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
N0.getOperand(3), NotCC);
}
}
@@ -4470,6 +5130,17 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
N01C->getAPIntValue(), DL, VT));
}
}
+
+ // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
+ unsigned OpSizeInBits = VT.getScalarSizeInBits();
+ if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1 &&
+ N1.getOpcode() == ISD::SRA && N1.getOperand(0) == N0.getOperand(0) &&
+ TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
+ if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
+ if (C->getAPIntValue() == (OpSizeInBits - 1))
+ return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0.getOperand(0));
+ }
+
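Aside (editor's sketch, not part of this patch): a worked instance of why the sra/add/xor sequence above computes abs, assuming the host's >> on a negative int is an arithmetic shift (true for mainstream compilers):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Y = X >> 31 is 0 for X >= 0 and -1 for X < 0, so (X + Y) ^ Y is
      // X when Y == 0 and ~(X - 1) == -X when Y == -1.
      int32_t X = -42;
      int32_t Y = X >> 31;
      assert(((X + Y) ^ Y) == 42);
      X = 7;
      Y = X >> 31;
      assert(((X + Y) ^ Y) == 7);
      return 0;
    }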
// fold (xor x, x) -> 0
if (N0 == N1)
return tryFoldToZero(SDLoc(N), TLI, VT, DAG, LegalOperations, LegalTypes);
@@ -4505,8 +5176,7 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
return Tmp;
// Simplify the expression using non-local knowledge.
- if (!VT.isVector() &&
- SimplifyDemandedBits(SDValue(N, 0)))
+ if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
return SDValue();
@@ -4613,13 +5283,51 @@ SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
}
SDValue DAGCombiner::visitRotate(SDNode *N) {
+ SDLoc dl(N);
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ unsigned Bitsize = VT.getScalarSizeInBits();
+
+ // fold (rot x, 0) -> x
+ if (isNullConstantOrNullSplatConstant(N1))
+ return N0;
+
+ // fold (rot x, c) -> (rot x, c % BitSize)
+ if (ConstantSDNode *Cst = isConstOrConstSplat(N1)) {
+ if (Cst->getAPIntValue().uge(Bitsize)) {
+ uint64_t RotAmt = Cst->getAPIntValue().urem(Bitsize);
+ return DAG.getNode(N->getOpcode(), dl, VT, N0,
+ DAG.getConstant(RotAmt, dl, N1.getValueType()));
+ }
+ }
+
// fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
- if (N->getOperand(1).getOpcode() == ISD::TRUNCATE &&
- N->getOperand(1).getOperand(0).getOpcode() == ISD::AND) {
- if (SDValue NewOp1 =
- distributeTruncateThroughAnd(N->getOperand(1).getNode()))
- return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
- N->getOperand(0), NewOp1);
+ if (N1.getOpcode() == ISD::TRUNCATE &&
+ N1.getOperand(0).getOpcode() == ISD::AND) {
+ if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
+ return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
+ }
+
+ unsigned NextOp = N0.getOpcode();
+ // fold (rot* (rot* x, c2), c1) -> (rot* x, c1 +- c2 % bitsize)
+ if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
+ SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
+ SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
+ if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
+ EVT ShiftVT = C1->getValueType(0);
+ bool SameSide = (N->getOpcode() == NextOp);
+ unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
+ if (SDValue CombinedShift =
+ DAG.FoldConstantArithmetic(CombineOp, dl, ShiftVT, C1, C2)) {
+ SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
+ SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
+ ISD::SREM, dl, ShiftVT, CombinedShift.getNode(),
+ BitsizeC.getNode());
+ return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
+ CombinedShiftNorm);
+ }
+ }
}
return SDValue();
}
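
Aside (editor's sketch, not part of this patch): a standalone check of the rotate-composition fold above; same-direction rotate amounts add modulo the bit width.

    #include <cassert>
    #include <cstdint>

    static uint32_t rotl(uint32_t X, unsigned C) {
      C %= 32;
      return C ? (X << C) | (X >> (32 - C)) : X;
    }

    int main() {
      uint32_t X = 0xDEADBEEFu;
      assert(rotl(rotl(X, 20), 20) == rotl(X, (20 + 20) % 32));
      return 0;
    }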
@@ -4662,7 +5370,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, N0C, N1C);
// fold (shl 0, x) -> 0
- if (isNullConstant(N0))
+ if (isNullConstantOrNullSplatConstant(N0))
return N0;
// fold (shl x, c >= size(x)) -> undef
if (N1C && N1C->getAPIntValue().uge(OpSizeInBits))
@@ -4673,6 +5381,10 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// fold (shl undef, x) -> 0
if (N0.isUndef())
return DAG.getConstant(0, SDLoc(N), VT);
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// if (shl x, c) is known to be zero, return 0
if (DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(OpSizeInBits)))
@@ -4763,7 +5475,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
// fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 > C2
if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
- cast<BinaryWithFlagsSDNode>(N0)->Flags.hasExact()) {
+ N0->getFlags().hasExact()) {
if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
uint64_t C1 = N0C1->getZExtValue();
uint64_t C2 = N1C->getZExtValue();
@@ -4788,12 +5500,12 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
APInt Mask = APInt::getHighBitsSet(OpSizeInBits, OpSizeInBits - c1);
SDValue Shift;
if (c2 > c1) {
- Mask = Mask.shl(c2 - c1);
+ Mask <<= c2 - c1;
SDLoc DL(N);
Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
DAG.getConstant(c2 - c1, DL, N1.getValueType()));
} else {
- Mask = Mask.lshr(c1 - c2);
+ Mask.lshrInPlace(c1 - c2);
SDLoc DL(N);
Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
DAG.getConstant(c1 - c2, DL, N1.getValueType()));
@@ -4808,9 +5520,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
isConstantOrConstantVector(N1, /* No Opaques */ true)) {
- unsigned BitSize = VT.getScalarSizeInBits();
SDLoc DL(N);
- SDValue AllBits = DAG.getConstant(APInt::getAllOnesValue(BitSize), DL, VT);
+ SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
}
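
Aside (editor's sketch, not part of this patch): a worked scalar instance of the (shl (sra x, c1), c1) fold above. Shifting right then left by the same amount simply clears the low c1 bits; a non-negative value is used to keep the C shifts well-defined:

    #include <cassert>
    #include <cstdint>

    int main() {
      int32_t X = 0xABCD;
      unsigned C1 = 4;
      assert(((X >> C1) << C1) == (X & (int32_t)(~0u << C1)));
      return 0;
    }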
@@ -4851,6 +5562,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
unsigned OpSizeInBits = VT.getScalarSizeInBits();
// Arithmetic shifting an all-sign-bit value is a no-op.
+ // fold (sra 0, x) -> 0
+ // fold (sra -1, x) -> -1
if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
return N0;
@@ -4865,18 +5578,16 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C);
- // fold (sra 0, x) -> 0
- if (isNullConstant(N0))
- return N0;
- // fold (sra -1, x) -> -1
- if (isAllOnesConstant(N0))
- return N0;
// fold (sra x, c >= size(x)) -> undef
if (N1C && N1C->getAPIntValue().uge(OpSizeInBits))
return DAG.getUNDEF(VT);
// fold (sra x, 0) -> x
if (N1C && N1C->isNullValue())
return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
// sext_inreg.
if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
@@ -5016,7 +5727,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
if (N0C && N1C && !N1C->isOpaque())
return DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, N0C, N1C);
// fold (srl 0, x) -> 0
- if (isNullConstant(N0))
+ if (isNullConstantOrNullSplatConstant(N0))
return N0;
// fold (srl x, c >= size(x)) -> undef
if (N1C && N1C->getAPIntValue().uge(OpSizeInBits))
@@ -5024,6 +5735,10 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
// fold (srl x, 0) -> x
if (N1C && N1C->isNullValue())
return N0;
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// if (srl x, c) is known to be zero, return 0
if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(OpSizeInBits)))
@@ -5049,24 +5764,24 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
// fold (srl (trunc (srl x, c1)), c2) -> 0 or (trunc (srl x, (add c1, c2)))
if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
- N0.getOperand(0).getOpcode() == ISD::SRL &&
- isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) {
- uint64_t c1 =
- cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue();
- uint64_t c2 = N1C->getZExtValue();
- EVT InnerShiftVT = N0.getOperand(0).getValueType();
- EVT ShiftCountVT = N0.getOperand(0)->getOperand(1).getValueType();
- uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
- // This is only valid if the OpSizeInBits + c1 = size of inner shift.
- if (c1 + OpSizeInBits == InnerShiftSize) {
- SDLoc DL(N0);
- if (c1 + c2 >= InnerShiftSize)
- return DAG.getConstant(0, DL, VT);
- return DAG.getNode(ISD::TRUNCATE, DL, VT,
- DAG.getNode(ISD::SRL, DL, InnerShiftVT,
- N0.getOperand(0)->getOperand(0),
- DAG.getConstant(c1 + c2, DL,
- ShiftCountVT)));
+ N0.getOperand(0).getOpcode() == ISD::SRL) {
+ if (auto N001C = isConstOrConstSplat(N0.getOperand(0).getOperand(1))) {
+ uint64_t c1 = N001C->getZExtValue();
+ uint64_t c2 = N1C->getZExtValue();
+ EVT InnerShiftVT = N0.getOperand(0).getValueType();
+ EVT ShiftCountVT = N0.getOperand(0).getOperand(1).getValueType();
+ uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
+ // This is only valid if the OpSizeInBits + c1 = size of inner shift.
+ if (c1 + OpSizeInBits == InnerShiftSize) {
+ SDLoc DL(N0);
+ if (c1 + c2 >= InnerShiftSize)
+ return DAG.getConstant(0, DL, VT);
+ return DAG.getNode(ISD::TRUNCATE, DL, VT,
+ DAG.getNode(ISD::SRL, DL, InnerShiftVT,
+ N0.getOperand(0).getOperand(0),
+ DAG.getConstant(c1 + c2, DL,
+ ShiftCountVT)));
+ }
}
}
@@ -5074,9 +5789,8 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
isConstantOrConstantVector(N1, /* NoOpaques */ true)) {
SDLoc DL(N);
- APInt AllBits = APInt::getAllOnesValue(N0.getScalarValueSizeInBits());
SDValue Mask =
- DAG.getNode(ISD::SRL, DL, VT, DAG.getConstant(AllBits, DL, VT), N1);
+ DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
AddToWorklist(Mask.getNode());
return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
}
@@ -5097,7 +5811,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
DAG.getConstant(ShiftAmt, DL0,
getShiftAmountTy(SmallVT)));
AddToWorklist(SmallShift.getNode());
- APInt Mask = APInt::getAllOnesValue(OpSizeInBits).lshr(ShiftAmt);
+ APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
SDLoc DL(N);
return DAG.getNode(ISD::AND, DL, VT,
DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
@@ -5115,20 +5829,20 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
// fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit).
if (N1C && N0.getOpcode() == ISD::CTLZ &&
N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
- APInt KnownZero, KnownOne;
- DAG.computeKnownBits(N0.getOperand(0), KnownZero, KnownOne);
+ KnownBits Known;
+ DAG.computeKnownBits(N0.getOperand(0), Known);
// If any of the input bits are KnownOne, then the input couldn't be all
// zeros, thus the result of the srl will always be zero.
- if (KnownOne.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
+ if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
// If all of the bits input the to ctlz node are known to be zero, then
// the result of the ctlz is "32" and the result of the shift is one.
- APInt UnknownBits = ~KnownZero;
+ APInt UnknownBits = ~Known.Zero;
if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
// Otherwise, check to see if there is exactly one bit input to the ctlz.
- if ((UnknownBits & (UnknownBits - 1)) == 0) {
+ if (UnknownBits.isPowerOf2()) {
// Okay, we know that only that the single bit specified by UnknownBits
// could be set on input to the CTLZ node. If this bit is set, the SRL
// will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
@@ -5202,6 +5916,22 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitABS(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+
+ // fold (abs c1) -> c2
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
+ return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
+ // fold (abs (abs x)) -> (abs x)
+ if (N0.getOpcode() == ISD::ABS)
+ return N0;
+  // fold (abs x) -> x iff x is known non-negative
+ if (DAG.SignBitIsZero(N0))
+ return N0;
+ return SDValue();
+}
+
SDValue DAGCombiner::visitBSWAP(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
@@ -5217,7 +5947,11 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) {
SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
SDValue N0 = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+ // fold (bitreverse c1) -> c2
+ if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
+ return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
// fold (bitreverse (bitreverse x)) -> x
if (N0.getOpcode() == ISD::BITREVERSE)
return N0.getOperand(0);
@@ -5311,7 +6045,6 @@ static SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
}
}
-// TODO: We should handle other cases of selecting between {-1,0,1} here.
SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
SDValue Cond = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -5320,6 +6053,67 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
EVT CondVT = Cond.getValueType();
SDLoc DL(N);
+ if (!VT.isInteger())
+ return SDValue();
+
+ auto *C1 = dyn_cast<ConstantSDNode>(N1);
+ auto *C2 = dyn_cast<ConstantSDNode>(N2);
+ if (!C1 || !C2)
+ return SDValue();
+
+ // Only do this before legalization to avoid conflicting with target-specific
+ // transforms in the other direction (create a select from a zext/sext). There
+ // is also a target-independent combine here in DAGCombiner in the other
+ // direction for (select Cond, -1, 0) when the condition is not i1.
+ if (CondVT == MVT::i1 && !LegalOperations) {
+ if (C1->isNullValue() && C2->isOne()) {
+ // select Cond, 0, 1 --> zext (!Cond)
+ SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
+ if (VT != MVT::i1)
+ NotCond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, NotCond);
+ return NotCond;
+ }
+ if (C1->isNullValue() && C2->isAllOnesValue()) {
+ // select Cond, 0, -1 --> sext (!Cond)
+ SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
+ if (VT != MVT::i1)
+ NotCond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NotCond);
+ return NotCond;
+ }
+ if (C1->isOne() && C2->isNullValue()) {
+ // select Cond, 1, 0 --> zext (Cond)
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
+ return Cond;
+ }
+ if (C1->isAllOnesValue() && C2->isNullValue()) {
+ // select Cond, -1, 0 --> sext (Cond)
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
+ return Cond;
+ }
+
+ // For any constants that differ by 1, we can transform the select into an
+ // extend and add. Use a target hook because some targets may prefer to
+ // transform in the other direction.
+ if (TLI.convertSelectOfConstantsToMath()) {
+ if (C1->getAPIntValue() - 1 == C2->getAPIntValue()) {
+ // select Cond, C1, C1-1 --> add (zext Cond), C1-1
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond);
+ return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
+ }
+ if (C1->getAPIntValue() + 1 == C2->getAPIntValue()) {
+ // select Cond, C1, C1+1 --> add (sext Cond), C1+1
+ if (VT != MVT::i1)
+ Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond);
+ return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
+ }
+ }
+
+ return SDValue();
+ }
+
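Aside (editor's sketch, not part of this patch): a worked instance of the constants-differ-by-one case above, e.g. select Cond, 5, 4 becomes add (zext Cond), 4:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (bool Cond : {false, true}) {
        uint32_t Sel = Cond ? 5u : 4u;
        uint32_t Math = (uint32_t)Cond + 4u;  // add (zext Cond), C1 - 1
        assert(Sel == Math);
      }
      return 0;
    }
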
// fold (select Cond, 0, 1) -> (xor Cond, 1)
// We can't do this reliably if integer based booleans have different contents
// to floating point based booleans. This is because we can't tell whether we
@@ -5329,15 +6123,14 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
// undiscoverable (or not reasonably discoverable). For example, it could be
// in another basic block or it could require searching a complicated
// expression.
- if (VT.isInteger() &&
- (CondVT == MVT::i1 || (CondVT.isInteger() &&
- TLI.getBooleanContents(false, true) ==
- TargetLowering::ZeroOrOneBooleanContent &&
- TLI.getBooleanContents(false, false) ==
- TargetLowering::ZeroOrOneBooleanContent)) &&
- isNullConstant(N1) && isOneConstant(N2)) {
- SDValue NotCond = DAG.getNode(ISD::XOR, DL, CondVT, Cond,
- DAG.getConstant(1, DL, CondVT));
+ if (CondVT.isInteger() &&
+ TLI.getBooleanContents(false, true) ==
+ TargetLowering::ZeroOrOneBooleanContent &&
+ TLI.getBooleanContents(false, false) ==
+ TargetLowering::ZeroOrOneBooleanContent &&
+ C1->isNullValue() && C2->isOne()) {
+ SDValue NotCond =
+ DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
if (VT.bitsEq(CondVT))
return NotCond;
return DAG.getZExtOrTrunc(NotCond, DL, VT);
@@ -5352,19 +6145,22 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
SDValue N2 = N->getOperand(2);
EVT VT = N->getValueType(0);
EVT VT0 = N0.getValueType();
+ SDLoc DL(N);
// fold (select C, X, X) -> X
if (N1 == N2)
return N1;
+
if (const ConstantSDNode *N0C = dyn_cast<const ConstantSDNode>(N0)) {
// fold (select true, X, Y) -> X
// fold (select false, X, Y) -> Y
return !N0C->isNullValue() ? N1 : N2;
}
+
// fold (select X, X, Y) -> (or X, Y)
// fold (select X, 1, Y) -> (or C, Y)
if (VT == VT0 && VT == MVT::i1 && (N0 == N1 || isOneConstant(N1)))
- return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N2);
+ return DAG.getNode(ISD::OR, DL, VT, N0, N2);
if (SDValue V = foldSelectOfConstants(N))
return V;
@@ -5373,22 +6169,22 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (VT == VT0 && VT == MVT::i1 && isNullConstant(N1)) {
SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
AddToWorklist(NOTNode.getNode());
- return DAG.getNode(ISD::AND, SDLoc(N), VT, NOTNode, N2);
+ return DAG.getNode(ISD::AND, DL, VT, NOTNode, N2);
}
// fold (select C, X, 1) -> (or (not C), X)
if (VT == VT0 && VT == MVT::i1 && isOneConstant(N2)) {
SDValue NOTNode = DAG.getNOT(SDLoc(N0), N0, VT);
AddToWorklist(NOTNode.getNode());
- return DAG.getNode(ISD::OR, SDLoc(N), VT, NOTNode, N1);
+ return DAG.getNode(ISD::OR, DL, VT, NOTNode, N1);
}
// fold (select X, Y, X) -> (and X, Y)
// fold (select X, Y, 0) -> (and X, Y)
if (VT == VT0 && VT == MVT::i1 && (N0 == N2 || isNullConstant(N2)))
- return DAG.getNode(ISD::AND, SDLoc(N), VT, N0, N1);
+ return DAG.getNode(ISD::AND, DL, VT, N0, N1);
// If we can fold this based on the true/false value, do so.
if (SimplifySelectOps(N, N1, N2))
- return SDValue(N, 0); // Don't revisit N.
+ return SDValue(N, 0); // Don't revisit N.
if (VT0 == MVT::i1) {
// The code in this block deals with the following 2 equivalences:
@@ -5399,27 +6195,27 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
// to the right anyway if we find the inner select exists in the DAG anyway
// and we always transform to the left side if we know that we can further
// optimize the combination of the conditions.
- bool normalizeToSequence
- = TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
+ bool normalizeToSequence =
+ TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
// select (and Cond0, Cond1), X, Y
// -> select Cond0, (select Cond1, X, Y), Y
if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
SDValue Cond0 = N0->getOperand(0);
SDValue Cond1 = N0->getOperand(1);
- SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
- N1.getValueType(), Cond1, N1, N2);
+ SDValue InnerSelect =
+ DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2);
if (normalizeToSequence || !InnerSelect.use_empty())
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0,
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
InnerSelect, N2);
}
// select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
SDValue Cond0 = N0->getOperand(0);
SDValue Cond1 = N0->getOperand(1);
- SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
- N1.getValueType(), Cond1, N1, N2);
+ SDValue InnerSelect =
+ DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2);
if (normalizeToSequence || !InnerSelect.use_empty())
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0, N1,
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
InnerSelect);
}
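
Aside (editor's sketch, not part of this patch): the two nested-select equivalences used above, checked exhaustively on scalars:

    #include <cassert>

    int main() {
      const int X = 10, Y = 20;
      for (bool C0 : {false, true})
        for (bool C1 : {false, true}) {
          // select (and C0, C1), X, Y == select C0, (select C1, X, Y), Y
          assert(((C0 && C1) ? X : Y) == (C0 ? (C1 ? X : Y) : Y));
          // select (or C0, C1), X, Y == select C0, X, (select C1, X, Y)
          assert(((C0 || C1) ? X : Y) == (C0 ? X : (C1 ? X : Y)));
        }
      return 0;
    }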
@@ -5431,15 +6227,13 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
// Create the actual and node if we can generate good code for it.
if (!normalizeToSequence) {
- SDValue And = DAG.getNode(ISD::AND, SDLoc(N), N0.getValueType(),
- N0, N1_0);
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), And,
- N1_1, N2);
+ SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1, N2);
}
// Otherwise see if we can optimize the "and" to a better pattern.
if (SDValue Combined = visitANDLike(N0, N1_0, N))
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Combined,
- N1_1, N2);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
+ N2);
}
}
// select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
@@ -5450,15 +6244,13 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
// Create the actual or node if we can generate good code for it.
if (!normalizeToSequence) {
- SDValue Or = DAG.getNode(ISD::OR, SDLoc(N), N0.getValueType(),
- N0, N2_0);
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Or,
- N1, N2_2);
+ SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1, N2_2);
}
// Otherwise see if we can optimize to a better pattern.
if (SDValue Combined = visitORLike(N0, N2_0, N))
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Combined,
- N1, N2_2);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
+ N2_2);
}
}
}
@@ -5469,8 +6261,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (auto *C = dyn_cast<ConstantSDNode>(N0->getOperand(1))) {
SDValue Cond0 = N0->getOperand(0);
if (C->isOne())
- return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(),
- Cond0, N2, N1);
+ return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N2, N1);
}
}
}
@@ -5487,24 +6278,21 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
// FIXME: Instead of testing for UnsafeFPMath, this should be checking for
// no signed zeros as well as no nans.
const TargetOptions &Options = DAG.getTarget().Options;
- if (Options.UnsafeFPMath &&
- VT.isFloatingPoint() && N0.hasOneUse() &&
+ if (Options.UnsafeFPMath && VT.isFloatingPoint() && N0.hasOneUse() &&
DAG.isKnownNeverNaN(N1) && DAG.isKnownNeverNaN(N2)) {
ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
- if (SDValue FMinMax = combineMinNumMaxNum(SDLoc(N), VT, N0.getOperand(0),
- N0.getOperand(1), N1, N2, CC,
- TLI, DAG))
+ if (SDValue FMinMax = combineMinNumMaxNum(
+ DL, VT, N0.getOperand(0), N0.getOperand(1), N1, N2, CC, TLI, DAG))
return FMinMax;
}
if ((!LegalOperations &&
TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT)) ||
TLI.isOperationLegal(ISD::SELECT_CC, VT))
- return DAG.getNode(ISD::SELECT_CC, SDLoc(N), VT,
- N0.getOperand(0), N0.getOperand(1),
- N1, N2, N0.getOperand(2));
- return SimplifySelect(SDLoc(N), N0, N1, N2);
+ return DAG.getNode(ISD::SELECT_CC, DL, VT, N0.getOperand(0),
+ N0.getOperand(1), N1, N2, N0.getOperand(2));
+ return SimplifySelect(DL, N0, N1, N2);
}
return SDValue();
@@ -5847,7 +6635,7 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
ISD::NON_EXTLOAD, MLD->isExpandingLoad());
Ptr = TLI.IncrementMemoryAddress(Ptr, MaskLo, DL, LoMemVT, DAG,
- MLD->isExpandingLoad());
+ MLD->isExpandingLoad());
MMO = DAG.getMachineFunction().
getMachineMemOperand(MLD->getPointerInfo(),
@@ -5908,6 +6696,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (isAbs) {
EVT VT = LHS.getValueType();
+ if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
+ return DAG.getNode(ISD::ABS, DL, VT, LHS);
+
SDValue Shift = DAG.getNode(
ISD::SRA, DL, VT, LHS,
DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT));
@@ -5921,34 +6712,6 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (SimplifySelectOps(N, N1, N2))
return SDValue(N, 0); // Don't revisit N.
- // If the VSELECT result requires splitting and the mask is provided by a
- // SETCC, then split both nodes and its operands before legalization. This
- // prevents the type legalizer from unrolling SETCC into scalar comparisons
- // and enables future optimizations (e.g. min/max pattern matching on X86).
- if (N0.getOpcode() == ISD::SETCC) {
- EVT VT = N->getValueType(0);
-
- // Check if any splitting is required.
- if (TLI.getTypeAction(*DAG.getContext(), VT) !=
- TargetLowering::TypeSplitVector)
- return SDValue();
-
- SDValue Lo, Hi, CCLo, CCHi, LL, LH, RL, RH;
- std::tie(CCLo, CCHi) = SplitVSETCC(N0.getNode(), DAG);
- std::tie(LL, LH) = DAG.SplitVectorOperand(N, 1);
- std::tie(RL, RH) = DAG.SplitVectorOperand(N, 2);
-
- Lo = DAG.getNode(N->getOpcode(), DL, LL.getValueType(), CCLo, LL, RL);
- Hi = DAG.getNode(N->getOpcode(), DL, LH.getValueType(), CCHi, LH, RH);
-
- // Add the new VSELECT nodes to the work list in case they need to be split
- // again.
- AddToWorklist(Lo.getNode());
- AddToWorklist(Hi.getNode());
-
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
- }
-
// Fold (vselect (build_vector all_ones), N1, N2) -> N1
if (ISD::isBuildVectorAllOnes(N0.getNode()))
return N1;
@@ -6030,6 +6793,19 @@ SDValue DAGCombiner::visitSETCCE(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ SDValue Carry = N->getOperand(2);
+ SDValue Cond = N->getOperand(3);
+
+ // If Carry is false, fold to a regular SETCC.
+ if (isNullConstant(Carry))
+ return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
+
+ return SDValue();
+}
+
/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
/// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext
@@ -6258,6 +7034,9 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
+ // Simplify TF.
+ AddToWorklist(NewChain.getNode());
+
CombineTo(N, NewValue);
// Replace uses of the original load (before extension)
@@ -6270,9 +7049,55 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
+/// If we're narrowing or widening the result of a vector select and the final
+/// size is the same size as a setcc (compare) feeding the select, then try to
+/// apply the cast operation to the select's operands because matching vector
+/// sizes for a select condition and other operands should be more efficient.
+SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
+ unsigned CastOpcode = Cast->getOpcode();
+ assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
+ CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
+ CastOpcode == ISD::FP_ROUND) &&
+ "Unexpected opcode for vector select narrowing/widening");
+
+ // We only do this transform before legal ops because the pattern may be
+ // obfuscated by target-specific operations after legalization. Do not create
+ // an illegal select op, however, because that may be difficult to lower.
+ EVT VT = Cast->getValueType(0);
+ if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
+ return SDValue();
+
+ SDValue VSel = Cast->getOperand(0);
+ if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
+ VSel.getOperand(0).getOpcode() != ISD::SETCC)
+ return SDValue();
+
+ // Does the setcc have the same vector size as the casted select?
+ SDValue SetCC = VSel.getOperand(0);
+ EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
+ if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
+ return SDValue();
+
+ // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
+ SDValue A = VSel.getOperand(1);
+ SDValue B = VSel.getOperand(2);
+ SDValue CastA, CastB;
+ SDLoc DL(Cast);
+ if (CastOpcode == ISD::FP_ROUND) {
+ // FP_ROUND (fptrunc) has an extra flag operand to pass along.
+ CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
+ CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
+ } else {
+ CastA = DAG.getNode(CastOpcode, DL, VT, A);
+ CastB = DAG.getNode(CastOpcode, DL, VT, B);
+ }
+ return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
+}
+
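The transform added above rests on casts distributing over a select: casting the selected value equals selecting between cast values. A scalar sketch of that equivalence in plain C++ (the fold applies the same identity lane-wise to vselect):

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  // cast (select C, A, B) == select C, (cast A), (cast B)
  for (int32_t A : {-5, 0, 123})
    for (int32_t B : {-1, 42})
      for (bool C : {false, true}) {
        int64_t Folded = (int64_t)(C ? A : B);
        int64_t Distributed = C ? (int64_t)A : (int64_t)B;
        assert(Folded == Distributed);
      }
}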
SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
+ SDLoc DL(N);
if (SDNode *Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes,
LegalOperations))
@@ -6281,8 +7106,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// fold (sext (sext x)) -> (sext x)
// fold (sext (aext x)) -> (sext x)
if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
- return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
- N0.getOperand(0));
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
if (N0.getOpcode() == ISD::TRUNCATE) {
// fold (sext (truncate (load x))) -> (sext (smaller load x))
@@ -6314,12 +7138,12 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
// bits, just sext from i32.
if (NumSignBits > OpBits-MidBits)
- return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, Op);
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
} else {
// Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
// bits, just truncate to i32.
if (NumSignBits > OpBits-MidBits)
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op);
+ return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
}
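A concrete instance of the sign-bit arithmetic above, checked standalone: each test value fits in i8, that is, it has more than 24 redundant sign bits as an i32, so extending through the narrow "Mid" type loses nothing:

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  for (int32_t Op : {-5, 0, 100, -128, 127}) {
    int64_t ViaMid = (int64_t)(int8_t)Op;  // sext_i64(sext_i8(trunc_i8(Op)))
    int64_t Direct = (int64_t)Op;          // sext_i64(Op)
    assert(ViaMid == Direct);              // holds because Op fits in i8
  }
}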
// fold (sext (truncate x)) -> (sextinreg x).
@@ -6329,7 +7153,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
else if (OpBits > DestBits)
Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
- return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, Op,
+ return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
DAG.getValueType(N0.getValueType()));
}
}
@@ -6349,17 +7173,20 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
if (DoXform) {
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
- SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
- LN0->getChain(),
+ SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
LN0->getBasePtr(), N0.getValueType(),
LN0->getMemOperand());
- CombineTo(N, ExtLoad);
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
N0.getValueType(), ExtLoad);
- CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N),
- ISD::SIGN_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, ISD::SIGN_EXTEND);
+ // If the load value is used only by N, replace it via CombineTo N.
+ bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
+ CombineTo(N, ExtLoad);
+ if (NoReplaceTrunc)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ else
+ CombineTo(LN0, Trunc, ExtLoad.getValue(1));
+ return SDValue(N, 0);
}
}
@@ -6376,8 +7203,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
EVT MemVT = LN0->getMemoryVT();
if ((!LegalOperations && !LN0->isVolatile()) ||
TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT)) {
- SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
- LN0->getChain(),
+ SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
LN0->getBasePtr(), MemVT,
LN0->getMemOperand());
CombineTo(N, ExtLoad);
@@ -6411,32 +7237,38 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
LN0->getMemOperand());
APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
Mask = Mask.sext(VT.getSizeInBits());
- SDLoc DL(N);
SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
ExtLoad, DAG.getConstant(Mask, DL, VT));
SDValue Trunc = DAG.getNode(ISD::TRUNCATE,
SDLoc(N0.getOperand(0)),
N0.getOperand(0).getValueType(), ExtLoad);
+ ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, ISD::SIGN_EXTEND);
+ bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
CombineTo(N, And);
- CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1));
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL,
- ISD::SIGN_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ if (NoReplaceTrunc)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ else
+ CombineTo(LN0, Trunc, ExtLoad.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
}
if (N0.getOpcode() == ISD::SETCC) {
- EVT N0VT = N0.getOperand(0).getValueType();
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+ ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
+ EVT N00VT = N0.getOperand(0).getValueType();
+
// sext(setcc) -> sext_in_reg(vsetcc) for vectors.
// Only do this before legalize for now.
if (VT.isVector() && !LegalOperations &&
- TLI.getBooleanContents(N0VT) ==
+ TLI.getBooleanContents(N00VT) ==
TargetLowering::ZeroOrNegativeOneBooleanContent) {
// On some architectures (such as SSE/NEON/etc) the SETCC result type is
// of the same size as the compared operands. Only optimize sext(setcc())
// if this is the case.
- EVT SVT = getSetCCResultType(N0VT);
+ EVT SVT = getSetCCResultType(N00VT);
// We know that the # elements of the results is the same as the
// # elements of the compare (and the # elements of the compare result
@@ -6444,19 +7276,15 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// we know that the element size of the sext'd result matches the
// element size of the compare operands.
if (VT.getSizeInBits() == SVT.getSizeInBits())
- return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
- N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get());
+ return DAG.getSetCC(DL, VT, N00, N01, CC);
// If the desired elements are smaller or larger than the source
- // elements we can use a matching integer vector type and then
- // truncate/sign extend
- EVT MatchingVectorType = N0VT.changeVectorElementTypeToInteger();
- if (SVT == MatchingVectorType) {
- SDValue VsetCC = DAG.getSetCC(SDLoc(N), MatchingVectorType,
- N0.getOperand(0), N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get());
- return DAG.getSExtOrTrunc(VsetCC, SDLoc(N), VT);
+ // elements, we can use a matching integer vector type and then
+ // truncate/sign extend.
+ EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
+ if (SVT == MatchingVecType) {
+ SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
+ return DAG.getSExtOrTrunc(VsetCC, DL, VT);
}
}
@@ -6465,36 +7293,30 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// getBooleanContents().
unsigned SetCCWidth = N0.getScalarValueSizeInBits();
- SDLoc DL(N);
// To determine the "true" side of the select, we need to know the high bit
// of the value returned by the setcc if it evaluates to true.
// If the type of the setcc is i1, then the true case of the select is just
// sext(i1 1), that is, -1.
// If the type of the setcc is larger (say, i8) then the value of the high
- // bit depends on getBooleanContents(). So, ask TLI for a real "true" value
+ // bit depends on getBooleanContents(), so ask TLI for a real "true" value
// of the appropriate width.
- SDValue ExtTrueVal =
- (SetCCWidth == 1)
- ? DAG.getConstant(APInt::getAllOnesValue(VT.getScalarSizeInBits()),
- DL, VT)
- : TLI.getConstTrueVal(DAG, VT, DL);
-
- if (SDValue SCC = SimplifySelectCC(
- DL, N0.getOperand(0), N0.getOperand(1), ExtTrueVal,
- DAG.getConstant(0, DL, VT),
- cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
+ SDValue ExtTrueVal = (SetCCWidth == 1) ? DAG.getAllOnesConstant(DL, VT)
+ : TLI.getConstTrueVal(DAG, VT, DL);
+ SDValue Zero = DAG.getConstant(0, DL, VT);
+ if (SDValue SCC =
+ SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
return SCC;
if (!VT.isVector()) {
- EVT SetCCVT = getSetCCResultType(N0.getOperand(0).getValueType());
- if (!LegalOperations ||
- TLI.isOperationLegal(ISD::SETCC, N0.getOperand(0).getValueType())) {
- SDLoc DL(N);
- ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
- SDValue SetCC =
- DAG.getSetCC(DL, SetCCVT, N0.getOperand(0), N0.getOperand(1), CC);
- return DAG.getSelect(DL, VT, SetCC, ExtTrueVal,
- DAG.getConstant(0, DL, VT));
+ EVT SetCCVT = getSetCCResultType(N00VT);
+ // Don't do this transform for i1 because there's a select transform
+ // that would reverse it.
+ // TODO: We should not do this transform at all without a target hook,
+ // because a sext is likely cheaper than a select.
+ if (SetCCVT.getScalarSizeInBits() != 1 &&
+ (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
+ SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
+ return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
}
}
}
@@ -6502,21 +7324,23 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
// fold (sext x) -> (zext x) if the sign bit is known zero.
if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
DAG.SignBitIsZero(N0))
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0);
+ return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);
+
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
return SDValue();
}
// isTruncateOf - If N is a truncate of some other value, return true, record
-// the value being truncated in Op and which of Op's bits are zero in KnownZero.
-// This function computes KnownZero to avoid a duplicated call to
+// the value being truncated in Op and which of Op's bits are zero/one in Known.
+// This function computes KnownBits to avoid a duplicated call to
// computeKnownBits in the caller.
static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
- APInt &KnownZero) {
- APInt KnownOne;
+ KnownBits &Known) {
if (N->getOpcode() == ISD::TRUNCATE) {
Op = N->getOperand(0);
- DAG.computeKnownBits(Op, KnownZero, KnownOne);
+ DAG.computeKnownBits(Op, Known);
return true;
}
@@ -6535,9 +7359,9 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
else
return false;
- DAG.computeKnownBits(Op, KnownZero, KnownOne);
+ DAG.computeKnownBits(Op, Known);
- if (!(KnownZero | APInt(Op.getValueSizeInBits(), 1)).isAllOnesValue())
+ if (!(Known.Zero | 1).isAllOnesValue())
return false;
return true;
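The simplified predicate above reads: once bit 0 is ORed into the known-zero mask, every bit must be known zero, so the source value can only be 0 or 1. A standalone sketch with 32-bit masks in place of APInt/KnownBits:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t KnownZero = ~1u;         // bits [31:1] known zero
  assert((KnownZero | 1u) == ~0u);  // accepted: value is 0 or 1
  uint32_t Weaker = ~3u;            // bit 1 not known zero
  assert((Weaker | 1u) != ~0u);     // rejected: value could be 2 or 3
}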
@@ -6562,8 +7386,8 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
// This is valid when the truncated bits of x are already zero.
// FIXME: We should extend this to work for vectors too.
SDValue Op;
- APInt KnownZero;
- if (!VT.isVector() && isTruncateOf(DAG, N0, Op, KnownZero)) {
+ KnownBits Known;
+ if (!VT.isVector() && isTruncateOf(DAG, N0, Op, Known)) {
APInt TruncatedBits =
(Op.getValueSizeInBits() == N0.getValueSizeInBits()) ?
APInt(Op.getValueSizeInBits(), 0) :
@@ -6571,14 +7395,8 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
N0.getValueSizeInBits(),
std::min(Op.getValueSizeInBits(),
VT.getSizeInBits()));
- if (TruncatedBits == (KnownZero & TruncatedBits)) {
- if (VT.bitsGT(Op.getValueType()))
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, Op);
- if (VT.bitsLT(Op.getValueType()))
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op);
-
- return Op;
- }
+ if (TruncatedBits.isSubsetOf(Known.Zero))
+ return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
}
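isSubsetOf is the mask test behind the classic identity zext(trunc(x)) == x whenever the dropped bits are already zero. A standalone sketch, modeling Known.Zero as the exact complement of the value:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t X = 0x00000000DEADBEEFull;             // high 32 bits are zero
  uint64_t TruncatedBits = 0xFFFFFFFF00000000ull; // bits trunc-to-i32 drops
  uint64_t KnownZero = ~X;                        // exact known-zero mask
  assert((TruncatedBits & ~KnownZero) == 0);      // isSubsetOf(Known.Zero)
  assert((uint64_t)(uint32_t)X == X);             // so zext(trunc(X)) == X
}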
// fold (zext (truncate (load x))) -> (zext (smaller load x))
@@ -6625,14 +7443,8 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
}
if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
- SDValue Op = N0.getOperand(0);
- if (SrcVT.bitsLT(VT)) {
- Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, Op);
- AddToWorklist(Op.getNode());
- } else if (SrcVT.bitsGT(VT)) {
- Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op);
- AddToWorklist(Op.getNode());
- }
+ SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
+ AddToWorklist(Op.getNode());
return DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
}
}
@@ -6646,11 +7458,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
N0.getValueType()) ||
!TLI.isZExtFree(N0.getValueType(), VT))) {
SDValue X = N0.getOperand(0).getOperand(0);
- if (X.getValueType().bitsLT(VT)) {
- X = DAG.getNode(ISD::ANY_EXTEND, SDLoc(X), VT, X);
- } else if (X.getValueType().bitsGT(VT)) {
- X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
- }
+ X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
Mask = Mask.zext(VT.getSizeInBits());
SDLoc DL(N);
@@ -6677,14 +7485,18 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
LN0->getChain(),
LN0->getBasePtr(), N0.getValueType(),
LN0->getMemOperand());
- CombineTo(N, ExtLoad);
+
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
N0.getValueType(), ExtLoad);
- CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
-
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N),
- ISD::ZERO_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N), ISD::ZERO_EXTEND);
+ // If the load value is used only by N, replace it via CombineTo N.
+ bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
+ CombineTo(N, ExtLoad);
+ if (NoReplaceTrunc)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ else
+ CombineTo(LN0, Trunc, ExtLoad.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
@@ -6734,11 +7546,14 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
SDValue Trunc = DAG.getNode(ISD::TRUNCATE,
SDLoc(N0.getOperand(0)),
N0.getOperand(0).getValueType(), ExtLoad);
+ ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL, ISD::ZERO_EXTEND);
+ bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
CombineTo(N, And);
- CombineTo(N0.getOperand(0).getNode(), Trunc, ExtLoad.getValue(1));
- ExtendSetCCUses(SetCCs, Trunc, ExtLoad, DL,
- ISD::ZERO_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ if (NoReplaceTrunc)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ else
+ CombineTo(LN0, Trunc, ExtLoad.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
}
@@ -6837,6 +7652,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
ShAmt);
}
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
return SDValue();
}
@@ -6871,14 +7689,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
}
// fold (aext (truncate x))
- if (N0.getOpcode() == ISD::TRUNCATE) {
- SDValue TruncOp = N0.getOperand(0);
- if (TruncOp.getValueType() == VT)
- return TruncOp; // x iff x size == zext size.
- if (TruncOp.getValueType().bitsGT(VT))
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, TruncOp);
- return DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, TruncOp);
- }
+ if (N0.getOpcode() == ISD::TRUNCATE)
+ return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
// Fold (aext (and (trunc x), cst)) -> (and x, cst)
// if the trunc is not free.
@@ -6889,11 +7701,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
N0.getValueType())) {
SDLoc DL(N);
SDValue X = N0.getOperand(0).getOperand(0);
- if (X.getValueType().bitsLT(VT)) {
- X = DAG.getNode(ISD::ANY_EXTEND, DL, VT, X);
- } else if (X.getValueType().bitsGT(VT)) {
- X = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
- }
+ X = DAG.getAnyExtOrTrunc(X, DL, VT);
APInt Mask = cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
Mask = Mask.zext(VT.getSizeInBits());
return DAG.getNode(ISD::AND, DL, VT,
@@ -6917,13 +7725,18 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
LN0->getChain(),
LN0->getBasePtr(), N0.getValueType(),
LN0->getMemOperand());
- CombineTo(N, ExtLoad);
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(N0),
N0.getValueType(), ExtLoad);
- CombineTo(N0.getNode(), Trunc, ExtLoad.getValue(1));
ExtendSetCCUses(SetCCs, Trunc, ExtLoad, SDLoc(N),
ISD::ANY_EXTEND);
- return SDValue(N, 0); // Return N so it doesn't get rechecked!
+ // If the load value is used only by N, replace it via CombineTo N.
+ bool NoReplaceTrunc = N0.hasOneUse();
+ CombineTo(N, ExtLoad);
+ if (NoReplaceTrunc)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
+ else
+ CombineTo(LN0, Trunc, ExtLoad.getValue(1));
+ return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
@@ -6991,9 +7804,25 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitAssertZext(SDNode *N) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT EVT = cast<VTSDNode>(N1)->getVT();
+
+ // fold (assertzext (assertzext x, vt), vt) -> (assertzext x, vt)
+ if (N0.getOpcode() == ISD::AssertZext &&
+ EVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
+ return N0;
+
+ return SDValue();
+}
+
/// See if the specified operand can be simplified with the knowledge that only
/// the bits specified by Mask are used. If so, return the simpler operand,
/// otherwise return a null SDValue.
+///
+/// (This exists alongside SimplifyDemandedBits because GetDemandedBits can
+/// simplify nodes with multiple uses more aggressively.)
SDValue DAGCombiner::GetDemandedBits(SDValue V, const APInt &Mask) {
switch (V.getOpcode()) {
default: break;
@@ -7029,6 +7858,14 @@ SDValue DAGCombiner::GetDemandedBits(SDValue V, const APInt &Mask) {
return DAG.getNode(ISD::SRL, SDLoc(V), V.getValueType(),
SimplifyLHS, V.getOperand(1));
}
+ break;
+ case ISD::AND: {
+ // X & -1 -> X (ignoring bits which aren't demanded).
+ ConstantSDNode *AndVal = isConstOrConstSplat(V.getOperand(1));
+ if (AndVal && (AndVal->getAPIntValue() & Mask) == Mask)
+ return V.getOperand(0);
+ break;
+ }
}
return SDValue();
}
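The new AND case above implements: (X & C) agrees with X on every demanded bit whenever the constant covers the demanded mask, so the AND can be dropped. A standalone check:

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  uint32_t Mask = 0x0000FFFF; // the caller demands only the low 16 bits
  uint32_t C    = 0x00FFFFFF; // (C & Mask) == Mask, so the AND is droppable
  for (uint32_t X : {0u, 0x12345678u, ~0u})
    assert(((X & C) & Mask) == (X & Mask));
}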
@@ -7169,7 +8006,7 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
SDValue NewPtr = DAG.getNode(ISD::ADD, DL,
PtrType, LN0->getBasePtr(),
DAG.getConstant(PtrOff, DL, PtrType),
- &Flags);
+ Flags);
AddToWorklist(NewPtr.getNode());
SDValue Load;
@@ -7244,6 +8081,16 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1);
}
+ // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
+ if ((N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
+ N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
+ N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) &&
+ N0.getOperand(0).getScalarValueSizeInBits() == EVTBits) {
+ if (!LegalOperations ||
+ TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT))
+ return DAG.getSignExtendVectorInReg(N0.getOperand(0), SDLoc(N), VT);
+ }
+
// fold (sext_in_reg (zext x)) -> (sext x)
// iff we are extending the source sign bit.
if (N0.getOpcode() == ISD::ZERO_EXTEND) {
@@ -7254,7 +8101,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
}
// fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
- if (DAG.MaskedValueIsZero(N0, APInt::getBitsSet(VTBits, EVTBits-1, EVTBits)))
+ if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, EVTBits - 1)))
return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT.getScalarType());
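The mask here is exactly the sign bit of the narrow type (the single bit getOneBitSet selects): when that bit is known zero, sign- and zero-extension from the narrow width coincide. A standalone sketch with i8 inside i32:

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  for (uint32_t X : {0x00u, 0x7Fu, 0x155u}) { // low byte has bit 7 clear
    int32_t Sext = (int8_t)(X & 0xFF);        // sext_in_reg from i8
    uint32_t Zext = X & 0xFF;                 // zext_in_reg from i8
    assert((uint32_t)Sext == Zext);
  }
}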
// fold operands of sext_in_reg based on knowledge that the top bits are not
@@ -7439,18 +8286,20 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
(!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) &&
TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
- if (const ConstantSDNode *CAmt = isConstOrConstSplat(N0.getOperand(1))) {
- uint64_t Amt = CAmt->getZExtValue();
- unsigned Size = VT.getScalarSizeInBits();
-
- if (Amt < Size) {
- SDLoc SL(N);
- EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
+ SDValue Amt = N0.getOperand(1);
+ KnownBits Known;
+ DAG.computeKnownBits(Amt, Known);
+ unsigned Size = VT.getScalarSizeInBits();
+ if (Known.getBitWidth() - Known.countMinLeadingZeros() <= Log2_32(Size)) {
+ SDLoc SL(N);
+ EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
- SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
- return DAG.getNode(ISD::SHL, SL, VT, Trunc,
- DAG.getConstant(Amt, SL, AmtVT));
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
+ if (AmtVT != Amt.getValueType()) {
+ Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
+ AddToWorklist(Amt.getNode());
}
+ return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
}
}
@@ -7496,6 +8345,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
VT.getSizeInBits())))
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Shorter);
}
+
// fold (truncate (load x)) -> (smaller load x)
// fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
@@ -7517,6 +8367,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}
}
}
+
// fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
// where ... are all 'undef'.
if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
@@ -7582,6 +8433,22 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
+ // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
+ // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
+ // These folds are valid only when the carry-out of the add is unused.
+ if ((N0.getOpcode() == ISD::ADDE || N0.getOpcode() == ISD::ADDCARRY) &&
+ N0.hasOneUse() && !N0.getNode()->hasAnyUseOfValue(1) &&
+ (!LegalOperations || TLI.isOperationLegal(N0.getOpcode(), VT))) {
+ SDLoc SL(N);
+ auto X = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
+ auto Y = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
+ auto VTs = DAG.getVTList(VT, N0->getValueType(1));
+ return DAG.getNode(N0.getOpcode(), SL, VTs, X, Y, N0.getOperand(2));
+ }
+
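This fold is sound because truncation commutes with addition modulo 2^n, so when the carry-out is unused the add can be performed at the narrow width. A one-line check in plain C++:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t X = 0x1234567890ABCDEFull, Y = 0xFEDCBA0987654321ull;
  assert((uint32_t)(X + Y) == (uint32_t)((uint32_t)X + (uint32_t)Y));
}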
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
return SDValue();
}
@@ -7645,11 +8512,11 @@ static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
switch (N0.getOpcode()) {
case ISD::AND:
FPOpcode = ISD::FABS;
- SignMask = ~APInt::getSignBit(SourceVT.getSizeInBits());
+ SignMask = ~APInt::getSignMask(SourceVT.getSizeInBits());
break;
case ISD::XOR:
FPOpcode = ISD::FNEG;
- SignMask = APInt::getSignBit(SourceVT.getSizeInBits());
+ SignMask = APInt::getSignMask(SourceVT.getSizeInBits());
break;
// TODO: ISD::OR --> ISD::FNABS?
default:
@@ -7672,6 +8539,9 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
+ if (N0.isUndef())
+ return DAG.getUNDEF(VT);
+
// If the input is a BUILD_VECTOR with all constant elements, fold this now.
// Only do this before legalize, since afterward the target may be depending
// on the bitconvert.
@@ -7757,7 +8627,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
assert(VT.getSizeInBits() == 128);
SDValue SignBit = DAG.getConstant(
- APInt::getSignBit(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
+ APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
SDValue FlipBit;
if (N0.getOpcode() == ISD::FNEG) {
FlipBit = SignBit;
@@ -7777,7 +8647,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
AddToWorklist(FlipBits.getNode());
return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
}
- APInt SignBit = APInt::getSignBit(VT.getSizeInBits());
+ APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
if (N0.getOpcode() == ISD::FNEG)
return DAG.getNode(ISD::XOR, DL, VT,
NewConv, DAG.getConstant(SignBit, DL, VT));
@@ -7825,7 +8695,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
}
if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
- APInt SignBit = APInt::getSignBit(VT.getSizeInBits() / 2);
+ APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
AddToWorklist(Cst.getNode());
SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
@@ -7846,7 +8716,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
AddToWorklist(FlipBits.getNode());
return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
}
- APInt SignBit = APInt::getSignBit(VT.getSizeInBits());
+ APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
X = DAG.getNode(ISD::AND, SDLoc(X), VT,
X, DAG.getConstant(SignBit, SDLoc(X), VT));
AddToWorklist(X.getNode());
@@ -8029,7 +8899,7 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
for (unsigned j = 0; j != NumOutputsPerInput; ++j) {
APInt ThisVal = OpVal.trunc(DstBitSize);
Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT));
- OpVal = OpVal.lshr(DstBitSize);
+ OpVal.lshrInPlace(DstBitSize);
}
// For big endian targets, swap the order of the pieces of each element.
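A standalone sketch of the trunc-then-shift loop above, splitting one 32-bit constant into four 8-bit pieces in little-endian piece order (a big-endian target reverses them afterwards, per the comment):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t OpVal = 0xAABBCCDDu;
  uint8_t Pieces[4];
  for (int J = 0; J != 4; ++J) {
    Pieces[J] = (uint8_t)OpVal; // APInt::trunc(DstBitSize)
    OpVal >>= 8;                // OpVal.lshrInPlace(DstBitSize)
  }
  assert(Pieces[0] == 0xDD && Pieces[3] == 0xAA);
}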
@@ -8040,6 +8910,11 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
return DAG.getBuildVector(VT, DL, Ops);
}
+static bool isContractable(SDNode *N) {
+ SDNodeFlags F = N->getFlags();
+ return F.hasAllowContract() || F.hasUnsafeAlgebra();
+}
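isContractable gates FMA formation because fusion changes rounding: fma(x, y, z) rounds once where x*y + z rounds twice. A standalone demonstration in plain C++ (build with -ffp-contract=off so the compiler does not contract the expression itself; hex float literals need C++17):

#include <cassert>
#include <cmath>

int main() {
  double X = 1.0 + 0x1.0p-27, Y = 1.0 - 0x1.0p-27, Z = -1.0;
  double Separate = X * Y + Z;      // product rounds to 1.0 first: 0.0
  double Fused = std::fma(X, Y, Z); // keeps the -0x1.0p-54 residual
  assert(Separate == 0.0 && Fused != 0.0);
}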
+
/// Try to perform FMA combining on a given FADD node.
SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDValue N0 = N->getOperand(0);
@@ -8048,24 +8923,27 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDLoc SL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- bool AllowFusion =
- (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
// Floating-point multiply-add with intermediate rounding.
bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
// Floating-point multiply-add without intermediate rounding.
bool HasFMA =
- AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+ TLI.isFMAFasterThanFMulAndFAdd(VT) &&
(!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
// No valid opcode, do not combine.
if (!HasFMAD && !HasFMA)
return SDValue();
+ bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
+ Options.UnsafeFPMath || HasFMAD);
+ // If the addition is not contractable, do not combine.
+ if (!AllowFusionGlobally && !isContractable(N))
+ return SDValue();
+
const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
- ;
- if (AllowFusion && STI && STI->generateFMAsInMachineCombiner(OptLevel))
+ if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
return SDValue();
// Always prefer FMAD to FMA for precision.
@@ -8073,35 +8951,39 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
bool LookThroughFPExt = TLI.isFPExtFree(VT);
+ // Is the node an FMUL that is contractable, either due to global flags
+ // or its own SDNodeFlags?
+ auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
+ if (N.getOpcode() != ISD::FMUL)
+ return false;
+ return AllowFusionGlobally || isContractable(N.getNode());
+ };
// If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
// prefer to fold the multiply with fewer uses.
- if (Aggressive && N0.getOpcode() == ISD::FMUL &&
- N1.getOpcode() == ISD::FMUL) {
+ if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
if (N0.getNode()->use_size() > N1.getNode()->use_size())
std::swap(N0, N1);
}
// fold (fadd (fmul x, y), z) -> (fma x, y, z)
- if (N0.getOpcode() == ISD::FMUL &&
- (Aggressive || N0->hasOneUse())) {
+ if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(0), N0.getOperand(1), N1);
}
// fold (fadd x, (fmul y, z)) -> (fma y, z, x)
// Note: Commutes FADD operands.
- if (N1.getOpcode() == ISD::FMUL &&
- (Aggressive || N1->hasOneUse())) {
+ if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N1.getOperand(0), N1.getOperand(1), N0);
}
// Look through FP_EXTEND nodes to do more combining.
- if (AllowFusion && LookThroughFPExt) {
+ if (LookThroughFPExt) {
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
if (N0.getOpcode() == ISD::FP_EXTEND) {
SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N00))
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
N00.getOperand(0)),
@@ -8113,7 +8995,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
// Note: Commutes FADD operands.
if (N1.getOpcode() == ISD::FP_EXTEND) {
SDValue N10 = N1.getOperand(0);
- if (N10.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N10))
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
N10.getOperand(0)),
@@ -8154,7 +9036,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
N0));
}
- if (AllowFusion && LookThroughFPExt) {
+ if (LookThroughFPExt) {
// fold (fadd (fma x, y, (fpext (fmul u, v))), z)
// -> (fma x, y, (fma (fpext u), (fpext v), z))
auto FoldFAddFMAFPExtFMul = [&] (
@@ -8169,7 +9051,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDValue N02 = N0.getOperand(2);
if (N02.getOpcode() == ISD::FP_EXTEND) {
SDValue N020 = N02.getOperand(0);
- if (N020.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N020))
return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
N020.getOperand(0), N020.getOperand(1),
N1);
@@ -8195,7 +9077,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDValue N00 = N0.getOperand(0);
if (N00.getOpcode() == PreferredFusedOpcode) {
SDValue N002 = N00.getOperand(2);
- if (N002.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N002))
return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
N002.getOperand(0), N002.getOperand(1),
N1);
@@ -8208,7 +9090,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDValue N12 = N1.getOperand(2);
if (N12.getOpcode() == ISD::FP_EXTEND) {
SDValue N120 = N12.getOperand(0);
- if (N120.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N120))
return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
N120.getOperand(0), N120.getOperand(1),
N0);
@@ -8224,7 +9106,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDValue N10 = N1.getOperand(0);
if (N10.getOpcode() == PreferredFusedOpcode) {
SDValue N102 = N10.getOperand(2);
- if (N102.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N102))
return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
N102.getOperand(0), N102.getOperand(1),
N0);
@@ -8244,23 +9126,26 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
SDLoc SL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- bool AllowFusion =
- (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
-
// Floating-point multiply-add with intermediate rounding.
bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
// Floating-point multiply-add without intermediate rounding.
bool HasFMA =
- AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+ TLI.isFMAFasterThanFMulAndFAdd(VT) &&
(!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
// No valid opcode, do not combine.
if (!HasFMAD && !HasFMA)
return SDValue();
+ bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
+ Options.UnsafeFPMath || HasFMAD);
+ // If the subtraction is not contractable, do not combine.
+ if (!AllowFusionGlobally && !isContractable(N))
+ return SDValue();
+
const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
- if (AllowFusion && STI && STI->generateFMAsInMachineCombiner(OptLevel))
+ if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
return SDValue();
// Always prefer FMAD to FMA for precision.
@@ -8268,9 +9153,16 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
bool LookThroughFPExt = TLI.isFPExtFree(VT);
+ // Is the node an FMUL that is contractable, either due to global flags
+ // or its own SDNodeFlags?
+ auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
+ if (N.getOpcode() != ISD::FMUL)
+ return false;
+ return AllowFusionGlobally || isContractable(N.getNode());
+ };
+
// fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
- if (N0.getOpcode() == ISD::FMUL &&
- (Aggressive || N0->hasOneUse())) {
+ if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(0), N0.getOperand(1),
DAG.getNode(ISD::FNEG, SL, VT, N1));
@@ -8278,16 +9170,14 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
// fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
// Note: Commutes FSUB operands.
- if (N1.getOpcode() == ISD::FMUL &&
- (Aggressive || N1->hasOneUse()))
+ if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse()))
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FNEG, SL, VT,
N1.getOperand(0)),
N1.getOperand(1), N0);
// fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
- if (N0.getOpcode() == ISD::FNEG &&
- N0.getOperand(0).getOpcode() == ISD::FMUL &&
+ if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
(Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
SDValue N00 = N0.getOperand(0).getOperand(0);
SDValue N01 = N0.getOperand(0).getOperand(1);
@@ -8297,12 +9187,12 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
}
// Look through FP_EXTEND nodes to do more combining.
- if (AllowFusion && LookThroughFPExt) {
+ if (LookThroughFPExt) {
// fold (fsub (fpext (fmul x, y)), z)
// -> (fma (fpext x), (fpext y), (fneg z))
if (N0.getOpcode() == ISD::FP_EXTEND) {
SDValue N00 = N0.getOperand(0);
- if (N00.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N00))
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
N00.getOperand(0)),
@@ -8316,7 +9206,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
// Note: Commutes FSUB operands.
if (N1.getOpcode() == ISD::FP_EXTEND) {
SDValue N10 = N1.getOperand(0);
- if (N10.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N10))
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FNEG, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
@@ -8336,7 +9226,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
SDValue N00 = N0.getOperand(0);
if (N00.getOpcode() == ISD::FNEG) {
SDValue N000 = N00.getOperand(0);
- if (N000.getOpcode() == ISD::FMUL) {
+ if (isContractableFMUL(N000)) {
return DAG.getNode(ISD::FNEG, SL, VT,
DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
@@ -8358,7 +9248,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
SDValue N00 = N0.getOperand(0);
if (N00.getOpcode() == ISD::FP_EXTEND) {
SDValue N000 = N00.getOperand(0);
- if (N000.getOpcode() == ISD::FMUL) {
+ if (isContractableFMUL(N000)) {
return DAG.getNode(ISD::FNEG, SL, VT,
DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
@@ -8378,10 +9268,9 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
// -> (fma x, y (fma u, v, (fneg z)))
// FIXME: The UnsafeAlgebra flag should be propagated to FMA/FMAD, but FMF
// are currently only supported on binary nodes.
- if (Options.UnsafeFPMath &&
- N0.getOpcode() == PreferredFusedOpcode &&
- N0.getOperand(2).getOpcode() == ISD::FMUL &&
- N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
+ if (Options.UnsafeFPMath && N0.getOpcode() == PreferredFusedOpcode &&
+ isContractableFMUL(N0.getOperand(2)) && N0->hasOneUse() &&
+ N0.getOperand(2)->hasOneUse()) {
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(0), N0.getOperand(1),
DAG.getNode(PreferredFusedOpcode, SL, VT,
@@ -8395,9 +9284,8 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
// -> (fma (fneg y), z, (fma (fneg u), v, x))
// FIXME: The UnsafeAlgebra flag should be propagated to FMA/FMAD, but FMF
// are currently only supported on binary nodes.
- if (Options.UnsafeFPMath &&
- N1.getOpcode() == PreferredFusedOpcode &&
- N1.getOperand(2).getOpcode() == ISD::FMUL) {
+ if (Options.UnsafeFPMath && N1.getOpcode() == PreferredFusedOpcode &&
+ isContractableFMUL(N1.getOperand(2))) {
SDValue N20 = N1.getOperand(2).getOperand(0);
SDValue N21 = N1.getOperand(2).getOperand(1);
return DAG.getNode(PreferredFusedOpcode, SL, VT,
@@ -8410,14 +9298,14 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
N21, N0));
}
- if (AllowFusion && LookThroughFPExt) {
+ if (LookThroughFPExt) {
// fold (fsub (fma x, y, (fpext (fmul u, v))), z)
// -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
if (N0.getOpcode() == PreferredFusedOpcode) {
SDValue N02 = N0.getOperand(2);
if (N02.getOpcode() == ISD::FP_EXTEND) {
SDValue N020 = N02.getOperand(0);
- if (N020.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N020))
return DAG.getNode(PreferredFusedOpcode, SL, VT,
N0.getOperand(0), N0.getOperand(1),
DAG.getNode(PreferredFusedOpcode, SL, VT,
@@ -8440,7 +9328,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
SDValue N00 = N0.getOperand(0);
if (N00.getOpcode() == PreferredFusedOpcode) {
SDValue N002 = N00.getOperand(2);
- if (N002.getOpcode() == ISD::FMUL)
+ if (isContractableFMUL(N002))
return DAG.getNode(PreferredFusedOpcode, SL, VT,
DAG.getNode(ISD::FP_EXTEND, SL, VT,
N00.getOperand(0)),
@@ -8461,7 +9349,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
if (N1.getOpcode() == PreferredFusedOpcode &&
N1.getOperand(2).getOpcode() == ISD::FP_EXTEND) {
SDValue N120 = N1.getOperand(2).getOperand(0);
- if (N120.getOpcode() == ISD::FMUL) {
+ if (isContractableFMUL(N120)) {
SDValue N1200 = N120.getOperand(0);
SDValue N1201 = N120.getOperand(1);
return DAG.getNode(PreferredFusedOpcode, SL, VT,
@@ -8488,7 +9376,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
SDValue N100 = N1.getOperand(0).getOperand(0);
SDValue N101 = N1.getOperand(0).getOperand(1);
SDValue N102 = N1.getOperand(0).getOperand(2);
- if (N102.getOpcode() == ISD::FMUL) {
+ if (isContractableFMUL(N102)) {
SDValue N1020 = N102.getOperand(0);
SDValue N1021 = N102.getOperand(1);
return DAG.getNode(PreferredFusedOpcode, SL, VT,
@@ -8601,6 +9489,14 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
return SDValue();
}
+static bool isFMulNegTwo(SDValue &N) {
+ if (N.getOpcode() != ISD::FMUL)
+ return false;
+ if (ConstantFPSDNode *CFP = isConstOrConstSplatFP(N.getOperand(1)))
+ return CFP->isExactlyValue(-2.0);
+ return false;
+}
+
SDValue DAGCombiner::visitFADD(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -8609,7 +9505,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ const SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector())
@@ -8624,6 +9520,9 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
if (N0CFP && !N1CFP)
return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags);
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (fadd A, (fneg B)) -> (fsub A, B)
if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
isNegatibleForFree(N1, LegalOperations, TLI, &Options) == 2)
@@ -8636,8 +9535,18 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
return DAG.getNode(ISD::FSUB, DL, VT, N1,
GetNegatedExpression(N0, DAG, LegalOperations), Flags);
+ // fold (fadd A, (fmul B, -2.0)) -> (fsub A, (fadd B, B))
+ // fold (fadd (fmul B, -2.0), A) -> (fsub A, (fadd B, B))
+ if ((isFMulNegTwo(N0) && N0.hasOneUse()) ||
+ (isFMulNegTwo(N1) && N1.hasOneUse())) {
+ bool N1IsFMul = isFMulNegTwo(N1);
+ SDValue AddOp = N1IsFMul ? N1.getOperand(0) : N0.getOperand(0);
+ SDValue Add = DAG.getNode(ISD::FADD, DL, VT, AddOp, AddOp, Flags);
+ return DAG.getNode(ISD::FSUB, DL, VT, N1IsFMul ? N0 : N1, Add, Flags);
+ }
+
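The rewrite above is exact, not an approximation: both B * -2.0 and B + B are exact in IEEE arithmetic (barring overflow), so each side reduces to the same single rounded subtraction. A standalone check:

#include <cassert>
#include <initializer_list>

int main() {
  for (double A : {0.0, 1.5, -3.25})
    for (double B : {0.0, 2.0, 0.1, -7.75})
      assert(A + B * -2.0 == A - (B + B));
}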
// FIXME: Auto-upgrade the target/function-level option.
- if (Options.UnsafeFPMath || N->getFlags()->hasNoSignedZeros()) {
+ if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros()) {
// fold (fadd A, 0) -> A
if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1))
if (N1C->isZero())
@@ -8760,7 +9669,7 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ const SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector())
@@ -8771,13 +9680,16 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
if (N0CFP && N1CFP)
return DAG.getNode(ISD::FSUB, DL, VT, N0, N1, Flags);
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
// fold (fsub A, (fneg B)) -> (fadd A, B)
if (isNegatibleForFree(N1, LegalOperations, TLI, &Options))
return DAG.getNode(ISD::FADD, DL, VT, N0,
GetNegatedExpression(N1, DAG, LegalOperations), Flags);
// FIXME: Auto-upgrade the target/function-level option.
- if (Options.UnsafeFPMath || N->getFlags()->hasNoSignedZeros()) {
+ if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros()) {
// (fsub 0, B) -> -B
if (N0CFP && N0CFP->isZero()) {
if (isNegatibleForFree(N1, LegalOperations, TLI, &Options))
@@ -8828,7 +9740,7 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ const SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector()) {
@@ -8850,6 +9762,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
if (N1CFP && N1CFP->isExactlyValue(1.0))
return N0;
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
if (Options.UnsafeFPMath) {
// fold (fmul A, 0) -> 0
if (N1CFP && N1CFP->isZero())
@@ -8914,6 +9829,52 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
}
}
+ // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
+ // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
+ if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
+ (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
+ TLI.isOperationLegal(ISD::FABS, VT)) {
+ SDValue Select = N0, X = N1;
+ if (Select.getOpcode() != ISD::SELECT)
+ std::swap(Select, X);
+
+ SDValue Cond = Select.getOperand(0);
+ auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
+ auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
+
+ if (TrueOpnd && FalseOpnd &&
+ Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
+ isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
+ cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
+ ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+ switch (CC) {
+ default: break;
+ case ISD::SETOLT:
+ case ISD::SETULT:
+ case ISD::SETOLE:
+ case ISD::SETULE:
+ case ISD::SETLT:
+ case ISD::SETLE:
+ std::swap(TrueOpnd, FalseOpnd);
+ // Fall through
+ case ISD::SETOGT:
+ case ISD::SETUGT:
+ case ISD::SETOGE:
+ case ISD::SETUGE:
+ case ISD::SETGT:
+ case ISD::SETGE:
+ if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
+ TLI.isOperationLegal(ISD::FNEG, VT))
+ return DAG.getNode(ISD::FNEG, DL, VT,
+ DAG.getNode(ISD::FABS, DL, VT, X));
+ if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
+ return DAG.getNode(ISD::FABS, DL, VT, X);
+
+ break;
+ }
+ }
+ }
+
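The identity behind this fold, checked standalone for ordinary finite values; the nnan/nsz requirement in the code exists precisely for the NaN and -0.0 cases the plain identity does not cover:

#include <cassert>
#include <cmath>
#include <initializer_list>

int main() {
  // X * ((X > 0.0) ? 1.0 : -1.0) == fabs(X); the -1.0/1.0 variant of the
  // select yields -fabs(X) the same way.
  for (double X : {-5.5, -1.0, 2.0, 42.0})
    assert(X * ((X > 0.0) ? 1.0 : -1.0) == std::fabs(X));
}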
// FMUL -> FMA combines:
if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
AddToWorklist(Fused.getNode());
@@ -8969,7 +9930,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
return DAG.getNode(ISD::FMUL, DL, VT, N0,
DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1),
- &Flags), &Flags);
+ Flags), Flags);
}
// (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
@@ -8979,7 +9940,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
return DAG.getNode(ISD::FMA, DL, VT,
N0.getOperand(0),
DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1),
- &Flags),
+ Flags),
N2);
}
}
@@ -9005,16 +9966,16 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
if (N1CFP && N0 == N2) {
return DAG.getNode(ISD::FMUL, DL, VT, N0,
DAG.getNode(ISD::FADD, DL, VT, N1,
- DAG.getConstantFP(1.0, DL, VT), &Flags),
- &Flags);
+ DAG.getConstantFP(1.0, DL, VT), Flags),
+ Flags);
}
// (fma x, c, (fneg x)) -> (fmul x, (c-1))
if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
return DAG.getNode(ISD::FMUL, DL, VT, N0,
DAG.getNode(ISD::FADD, DL, VT, N1,
- DAG.getConstantFP(-1.0, DL, VT), &Flags),
- &Flags);
+ DAG.getConstantFP(-1.0, DL, VT), Flags),
+ Flags);
}
}
@@ -9030,8 +9991,8 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
// is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
- const SDNodeFlags *Flags = N->getFlags();
- if (!UnsafeMath && !Flags->hasAllowReciprocal())
+ const SDNodeFlags Flags = N->getFlags();
+ if (!UnsafeMath && !Flags.hasAllowReciprocal())
return SDValue();
// Skip if current node is a reciprocal.
@@ -9054,7 +10015,7 @@ SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
// This division is eligible for optimization only if global unsafe math
// is enabled or if this division allows reciprocal formation.
- if (UnsafeMath || U->getFlags()->hasAllowReciprocal())
+ if (UnsafeMath || U->getFlags().hasAllowReciprocal())
Users.insert(U);
}
}
@@ -9093,7 +10054,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
EVT VT = N->getValueType(0);
SDLoc DL(N);
const TargetOptions &Options = DAG.getTarget().Options;
- SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
+ SDNodeFlags Flags = N->getFlags();
// fold vector ops
if (VT.isVector())
@@ -9104,6 +10065,9 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (N0CFP && N1CFP)
return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags);
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
+
if (Options.UnsafeFPMath) {
// fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
if (N1CFP) {
@@ -9204,8 +10168,10 @@ SDValue DAGCombiner::visitFREM(SDNode *N) {
// fold (frem c1, c2) -> fmod(c1,c2)
if (N0CFP && N1CFP)
- return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1,
- &cast<BinaryWithFlagsSDNode>(N)->Flags);
+ return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, N->getFlags());
+
+ if (SDValue NewSel = foldBinOpIntoSelect(N))
+ return NewSel;
return SDValue();
}
@@ -9222,7 +10188,7 @@ SDValue DAGCombiner::visitFSQRT(SDNode *N) {
// For now, create a Flags object for use with all unsafe math transforms.
SDNodeFlags Flags;
Flags.setUnsafeAlgebra(true);
- return buildSqrtEstimate(N0, &Flags);
+ return buildSqrtEstimate(N0, Flags);
}
/// copysign(x, fp_extend(y)) -> copysign(x, y)
@@ -9497,6 +10463,9 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
Tmp, N0.getOperand(1));
}
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
return SDValue();
}
@@ -9563,6 +10532,9 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
+ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
+ return NewVSel;
+
return SDValue();
}
@@ -9624,11 +10596,11 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) {
if (N0.getValueType().isVector()) {
// For a vector, get a mask such as 0x80... per scalar element
// and splat it.
- SignMask = APInt::getSignBit(N0.getScalarValueSizeInBits());
+ SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
} else {
// For a scalar, just generate 0x80...
- SignMask = APInt::getSignBit(IntVT.getSizeInBits());
+ SignMask = APInt::getSignMask(IntVT.getSizeInBits());
}
SDLoc DL0(N0);
Int = DAG.getNode(ISD::XOR, DL0, IntVT, Int,
@@ -9648,10 +10620,10 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) {
if (Level >= AfterLegalizeDAG &&
(TLI.isFPImmLegal(CVal, VT) ||
TLI.isOperationLegal(ISD::ConstantFP, VT)))
- return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0.getOperand(0),
- DAG.getNode(ISD::FNEG, SDLoc(N), VT,
- N0.getOperand(1)),
- &cast<BinaryWithFlagsSDNode>(N0)->Flags);
+ return DAG.getNode(
+ ISD::FMUL, SDLoc(N), VT, N0.getOperand(0),
+ DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0.getOperand(1)),
+ N0->getFlags());
}
}
@@ -9729,11 +10701,11 @@ SDValue DAGCombiner::visitFABS(SDNode *N) {
if (N0.getValueType().isVector()) {
// For a vector, get a mask such as 0x7f... per scalar element
// and splat it.
- SignMask = ~APInt::getSignBit(N0.getScalarValueSizeInBits());
+ SignMask = ~APInt::getSignMask(N0.getScalarValueSizeInBits());
SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
} else {
// For a scalar, just generate 0x7f...
- SignMask = ~APInt::getSignBit(IntVT.getSizeInBits());
+ SignMask = ~APInt::getSignMask(IntVT.getSizeInBits());
}
SDLoc DL(N0);
Int = DAG.getNode(ISD::AND, DL, IntVT, Int,
@@ -10149,7 +11121,7 @@ bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
// x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
//
// where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
- // indexed load/store and the expresion that needs to be re-written.
+ // indexed load/store and the expression that needs to be re-written.
//
// Therefore, we have:
// t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1
@@ -10361,7 +11333,7 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
dbgs() << "\n");
WorklistRemover DeadNodes(*this);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
-
+ AddUsersToWorklist(Chain.getNode());
if (N->use_empty())
deleteAndRecombine(N);
@@ -10414,7 +11386,7 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
StoreSDNode *PrevST = cast<StoreSDNode>(Chain);
if (PrevST->getBasePtr() == Ptr &&
PrevST->getValue().getValueType() == N->getValueType(0))
- return CombineTo(N, Chain.getOperand(1), Chain);
+ return CombineTo(N, PrevST->getOperand(1), Chain);
}
}
@@ -10432,14 +11404,7 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
}
}
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
-#ifndef NDEBUG
- if (CombinerAAOnlyFunc.getNumOccurrences() &&
- CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
- UseAA = false;
-#endif
- if (UseAA && LD->isUnindexed()) {
+ if (LD->isUnindexed()) {
// Walk up chain skipping non-aliasing memory nodes.
SDValue BetterChain = FindBetterChain(N, Chain);
@@ -10462,12 +11427,8 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
MVT::Other, Chain, ReplLoad.getValue(1));
- // Make sure the new and old chains are cleaned up.
- AddToWorklist(Token.getNode());
-
- // Replace uses with load result and token factor. Don't add users
- // to work list.
- return CombineTo(N, ReplLoad.getValue(0), Token, false);
+ // Replace uses with load result and token factor
+ return CombineTo(N, ReplLoad.getValue(0), Token);
}
}
@@ -10490,7 +11451,7 @@ namespace {
/// Shift = srl Ty1 Origin, CstTy Amount
/// Inst = trunc Shift to Ty2
///
-/// Then, it will be rewriten into:
+/// Then, it will be rewritten into:
/// Slice = load SliceTy, Base + SliceOffset
/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
///
@@ -10959,7 +11920,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
// Check if this is a trunc(lshr).
if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
isa<ConstantSDNode>(User->getOperand(1))) {
- Shift = cast<ConstantSDNode>(User->getOperand(1))->getZExtValue();
+ Shift = User->getConstantOperandVal(1);
User = *User->use_begin();
}
@@ -11021,6 +11982,7 @@ bool DAGCombiner::SliceUpLoad(SDNode *N) {
SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
ArgChains);
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
+ AddToWorklist(Chain.getNode());
return true;
}
@@ -11414,44 +12376,36 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
return false;
}
-SDValue DAGCombiner::getMergedConstantVectorStore(
- SelectionDAG &DAG, const SDLoc &SL, ArrayRef<MemOpLink> Stores,
- SmallVectorImpl<SDValue> &Chains, EVT Ty) const {
- SmallVector<SDValue, 8> BuildVector;
+SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
+ unsigned NumStores) {
+ SmallVector<SDValue, 8> Chains;
+ SmallPtrSet<const SDNode *, 8> Visited;
+ SDLoc StoreDL(StoreNodes[0].MemNode);
+
+ for (unsigned i = 0; i < NumStores; ++i) {
+ Visited.insert(StoreNodes[i].MemNode);
+ }
- for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) {
- StoreSDNode *St = cast<StoreSDNode>(Stores[I].MemNode);
- Chains.push_back(St->getChain());
- BuildVector.push_back(St->getValue());
+ // Don't include chains that are themselves other candidate stores.
+ for (unsigned i = 0; i < NumStores; ++i) {
+ if (Visited.count(StoreNodes[i].MemNode->getChain().getNode()) == 0)
+ Chains.push_back(StoreNodes[i].MemNode->getChain());
}
- return DAG.getBuildVector(Ty, SL, BuildVector);
+ assert(Chains.size() > 0 && "Store merge should have produced at least one chain");
+ return DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, Chains);
}
bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
- SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT,
- unsigned NumStores, bool IsConstantSrc, bool UseVector) {
+ SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
+ bool IsConstantSrc, bool UseVector, bool UseTrunc) {
// Make sure we have something to merge.
if (NumStores < 2)
return false;
int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8;
- LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
- unsigned LatestNodeUsed = 0;
-
- for (unsigned i=0; i < NumStores; ++i) {
- // Find a chain for the new wide-store operand. Notice that some
- // of the store nodes that we found may not be selected for inclusion
- // in the wide store. The chain we use needs to be the chain of the
- // latest store node which is *used* and replaced by the wide store.
- if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum)
- LatestNodeUsed = i;
- }
-
- SmallVector<SDValue, 8> Chains;
// The latest Node in the DAG.
- LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode;
SDLoc DL(StoreNodes[0].MemNode);
SDValue StoredVal;
@@ -11467,7 +12421,18 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
assert(TLI.isTypeLegal(Ty) && "Illegal vector store");
if (IsConstantSrc) {
- StoredVal = getMergedConstantVectorStore(DAG, DL, StoreNodes, Chains, Ty);
+ SmallVector<SDValue, 8> BuildVector;
+ for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) {
+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
+ SDValue Val = St->getValue();
+ if (MemVT.getScalarType().isInteger())
+ if (auto *CFP = dyn_cast<ConstantFPSDNode>(St->getValue()))
+ Val = DAG.getConstant(
+ (uint32_t)CFP->getValueAPF().bitcastToAPInt().getZExtValue(),
+ SDLoc(CFP), MemVT);
+ BuildVector.push_back(Val);
+ }
+ StoredVal = DAG.getBuildVector(Ty, DL, BuildVector);
} else {
SmallVector<SDValue, 8> Ops;
for (unsigned i = 0; i < NumStores; ++i) {
@@ -11477,7 +12442,6 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
if (Val.getValueType() != MemVT)
return false;
Ops.push_back(Val);
- Chains.push_back(St->getChain());
}
// Build the extracted vector elements back into a vector.
@@ -11497,14 +12461,13 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
for (unsigned i = 0; i < NumStores; ++i) {
unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
- Chains.push_back(St->getChain());
SDValue Val = St->getValue();
StoreInt <<= ElementSizeBytes * 8;
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
- StoreInt |= C->getAPIntValue().zext(SizeInBits);
+ StoreInt |= C->getAPIntValue().zextOrTrunc(SizeInBits);
} else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
- StoreInt |= C->getValueAPF().bitcastToAPInt().zext(SizeInBits);
+ StoreInt |= C->getValueAPF().bitcastToAPInt().zextOrTrunc(SizeInBits);
} else {
llvm_unreachable("Invalid constant element type");
}
@@ -11515,194 +12478,181 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
}
- assert(!Chains.empty());
-
- SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
- SDValue NewStore = DAG.getStore(NewChain, DL, StoredVal,
- FirstInChain->getBasePtr(),
- FirstInChain->getPointerInfo(),
- FirstInChain->getAlignment());
-
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
- if (UseAA) {
- // Replace all merged stores with the new store.
- for (unsigned i = 0; i < NumStores; ++i)
- CombineTo(StoreNodes[i].MemNode, NewStore);
- } else {
- // Replace the last store with the new store.
- CombineTo(LatestOp, NewStore);
- // Erase all other stores.
- for (unsigned i = 0; i < NumStores; ++i) {
- if (StoreNodes[i].MemNode == LatestOp)
- continue;
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- // ReplaceAllUsesWith will replace all uses that existed when it was
- // called, but graph optimizations may cause new ones to appear. For
- // example, the case in pr14333 looks like
- //
- // St's chain -> St -> another store -> X
- //
- // And the only difference from St to the other store is the chain.
- // When we change it's chain to be St's chain they become identical,
- // get CSEed and the net result is that X is now a use of St.
- // Since we know that St is redundant, just iterate.
- while (!St->use_empty())
- DAG.ReplaceAllUsesWith(SDValue(St, 0), St->getChain());
- deleteAndRecombine(St);
- }
- }
-
- StoreNodes.erase(StoreNodes.begin() + NumStores, StoreNodes.end());
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
+
+ // Make sure we use a truncating store where necessary for legality.
+ SDValue NewStore;
+ if (UseVector || !UseTrunc) {
+ NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(),
+ FirstInChain->getAlignment());
+ } else { // Must be realized as a trunc store
+ EVT LegalizedStoredValueTy =
+ TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
+ unsigned LegalizedStoreSize = LegalizedStoredValueTy.getSizeInBits();
+ ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
+ SDValue ExtendedStoreVal =
+ DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
+ LegalizedStoredValueTy);
+ NewStore = DAG.getTruncStore(
+ NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
+ FirstInChain->getAlignment(),
+ FirstInChain->getMemOperand()->getFlags());
+ }
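+ // For example (a sketch, assuming a target that promotes i16 to i32):
+ // merging two i8 constant stores yields an i16 value, which is widened
+ // with zextOrTrunc to an i32 constant and emitted as a truncating i16
+ // store.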
+
+ // Replace all merged stores with the new store.
+ for (unsigned i = 0; i < NumStores; ++i)
+ CombineTo(StoreNodes[i].MemNode, NewStore);
+
+ AddToWorklist(NewChain.getNode());
return true;
}
-void DAGCombiner::getStoreMergeAndAliasCandidates(
- StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes,
- SmallVectorImpl<LSBaseSDNode*> &AliasLoadNodes) {
+void DAGCombiner::getStoreMergeCandidates(
+ StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes) {
// This holds the base pointer, index, and the offset in bytes from the base
// pointer.
BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr(), DAG);
+ EVT MemVT = St->getMemoryVT();
// We must have a base and an offset.
- if (!BasePtr.Base.getNode())
+ if (!BasePtr.getBase().getNode())
return;
// Do not handle stores to undef base pointers.
- if (BasePtr.Base.isUndef())
+ if (BasePtr.getBase().isUndef())
return;
- // Walk up the chain and look for nodes with offsets from the same
- // base pointer. Stop when reaching an instruction with a different kind
- // or instruction which has a different base pointer.
- EVT MemVT = St->getMemoryVT();
- unsigned Seq = 0;
- StoreSDNode *Index = St;
-
-
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
-
- if (UseAA) {
- // Look at other users of the same chain. Stores on the same chain do not
- // alias. If combiner-aa is enabled, non-aliasing stores are canonicalized
- // to be on the same chain, so don't bother looking at adjacent chains.
-
- SDValue Chain = St->getChain();
- for (auto I = Chain->use_begin(), E = Chain->use_end(); I != E; ++I) {
- if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) {
- if (I.getOperandNo() != 0)
- continue;
-
- if (OtherST->isVolatile() || OtherST->isIndexed())
- continue;
-
- if (OtherST->getMemoryVT() != MemVT)
- continue;
-
- BaseIndexOffset Ptr = BaseIndexOffset::match(OtherST->getBasePtr(), DAG);
-
- if (Ptr.equalBaseIndex(BasePtr))
- StoreNodes.push_back(MemOpLink(OtherST, Ptr.Offset, Seq++));
- }
+ bool IsConstantSrc = isa<ConstantSDNode>(St->getValue()) ||
+ isa<ConstantFPSDNode>(St->getValue());
+ bool IsExtractVecSrc =
+ (St->getValue().getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+ St->getValue().getOpcode() == ISD::EXTRACT_SUBVECTOR);
+ bool IsLoadSrc = isa<LoadSDNode>(St->getValue());
+ BaseIndexOffset LBasePtr;
+ // Match on loadbaseptr if relevant.
+ if (IsLoadSrc)
+ LBasePtr = BaseIndexOffset::match(
+ cast<LoadSDNode>(St->getValue())->getBasePtr(), DAG);
+
+ auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
+ int64_t &Offset) -> bool {
+ if (Other->isVolatile() || Other->isIndexed())
+ return false;
+ // We can merge constant floats to equivalent integers
+ if (Other->getMemoryVT() != MemVT)
+ if (!(MemVT.isInteger() && MemVT.bitsEq(Other->getMemoryVT()) &&
+ isa<ConstantFPSDNode>(Other->getValue())))
+ return false;
+ if (IsLoadSrc) {
+ // The Load's Base Ptr must also match
+ if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(Other->getValue())) {
+ auto LPtr = BaseIndexOffset::match(OtherLd->getBasePtr(), DAG);
+ if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
+ return false;
+ } else
+ return false;
}
-
- return;
- }
-
- while (Index) {
- // If the chain has more than one use, then we can't reorder the mem ops.
- if (Index != St && !SDValue(Index, 0)->hasOneUse())
- break;
-
- // Find the base pointer and offset for this memory node.
- BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr(), DAG);
-
- // Check that the base pointer is the same as the original one.
- if (!Ptr.equalBaseIndex(BasePtr))
- break;
-
- // The memory operands must not be volatile.
- if (Index->isVolatile() || Index->isIndexed())
- break;
-
- // No truncation.
- if (Index->isTruncatingStore())
- break;
-
- // The stored memory type must be the same.
- if (Index->getMemoryVT() != MemVT)
- break;
-
- // We do not allow under-aligned stores in order to prevent
- // overriding stores. NOTE: this is a bad hack. Alignment SHOULD
- // be irrelevant here; what MATTERS is that we not move memory
- // operations that potentially overlap past each-other.
- if (Index->getAlignment() < MemVT.getStoreSize())
- break;
-
- // We found a potential memory operand to merge.
- StoreNodes.push_back(MemOpLink(Index, Ptr.Offset, Seq++));
-
- // Find the next memory operand in the chain. If the next operand in the
- // chain is a store then move up and continue the scan with the next
- // memory operand. If the next operand is a load save it and use alias
- // information to check if it interferes with anything.
- SDNode *NextInChain = Index->getChain().getNode();
- while (1) {
- if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) {
- // We found a store node. Use it for the next iteration.
- Index = STn;
- break;
- } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) {
- if (Ldn->isVolatile()) {
- Index = nullptr;
- break;
+ if (IsConstantSrc)
+ if (!(isa<ConstantSDNode>(Other->getValue()) ||
+ isa<ConstantFPSDNode>(Other->getValue())))
+ return false;
+ if (IsExtractVecSrc)
+ if (!(Other->getValue().getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+ Other->getValue().getOpcode() == ISD::EXTRACT_SUBVECTOR))
+ return false;
+ Ptr = BaseIndexOffset::match(Other->getBasePtr(), DAG);
+ return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
+ };
+ // We are looking for a root node that is an ancestor of all mergeable
+ // stores. We search up through a load, to our root, and then down
+ // through all children. For instance, we will find Store{1,2,3} if
+ // St is Store1, Store2, or Store3, where the root is not a load,
+ // which is always true for non-volatile ops. TODO: Expand the search
+ // to find all valid candidates through multiple layers of loads.
+ //
+ // Root
+ // |-------|-------|
+ // Load Load Store3
+ // | |
+ // Store1 Store2
+ //
+ // FIXME: We should be able to climb and
+ // descend TokenFactors to find candidates as well.
+
+ SDNode *RootNode = (St->getChain()).getNode();
+
+ if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
+ RootNode = Ldn->getChain().getNode();
+ for (auto I = RootNode->use_begin(), E = RootNode->use_end(); I != E; ++I)
+ if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) // walk down chain
+ for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
+ if (I2.getOperandNo() == 0)
+ if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I2)) {
+ BaseIndexOffset Ptr;
+ int64_t PtrDiff;
+ if (CandidateMatch(OtherST, Ptr, PtrDiff))
+ StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
+ }
+ } else
+ for (auto I = RootNode->use_begin(), E = RootNode->use_end(); I != E; ++I)
+ if (I.getOperandNo() == 0)
+ if (StoreSDNode *OtherST = dyn_cast<StoreSDNode>(*I)) {
+ BaseIndexOffset Ptr;
+ int64_t PtrDiff;
+ if (CandidateMatch(OtherST, Ptr, PtrDiff))
+ StoreNodes.push_back(MemOpLink(OtherST, PtrDiff));
}
-
- // Save the load node for later. Continue the scan.
- AliasLoadNodes.push_back(Ldn);
- NextInChain = Ldn->getChain().getNode();
- continue;
- } else {
- Index = nullptr;
- break;
- }
- }
- }
}
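+ // For example (illustrative): given "store i32 %a, p" and
+ // "store i32 %b, p+4" hanging off the same root chain, CandidateMatch
+ // accepts the second store with PtrDiff == 4 relative to BasePtr, so
+ // both stores land in StoreNodes as merge candidates.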
-// We need to check that merging these stores does not cause a loop
-// in the DAG. Any store candidate may depend on another candidate
+// We need to check that merging these stores does not cause a loop in
+// the DAG. Any store candidate may depend on another candidate
// indirectly through its operand (we already consider dependencies
// through the chain). Check in parallel by searching up from
// non-chain operands of candidates.
+
bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
- SmallVectorImpl<MemOpLink> &StoreNodes) {
+ SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores) {
+
+ // FIXME: We should be able to truncate a full search of
+ // predecessors by doing a BFS and keeping tabs on the originating
+ // stores from which worklist nodes come, in a similar way to
+ // TokenFactor simplification.
+
SmallPtrSet<const SDNode *, 16> Visited;
SmallVector<const SDNode *, 8> Worklist;
- // search ops of store candidates
- for (unsigned i = 0; i < StoreNodes.size(); ++i) {
+ unsigned Max = 8192;
+ // Search Ops of store candidates.
+ for (unsigned i = 0; i < NumStores; ++i) {
SDNode *n = StoreNodes[i].MemNode;
// Potential loops may happen only through non-chain operands
for (unsigned j = 1; j < n->getNumOperands(); ++j)
Worklist.push_back(n->getOperand(j).getNode());
}
- // search through DAG. We can stop early if we find a storenode
- for (unsigned i = 0; i < StoreNodes.size(); ++i) {
- if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist))
+ // Search through DAG. We can stop early if we find a store node.
+ for (unsigned i = 0; i < NumStores; ++i) {
+ if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
+ Max))
+ return false;
+ // Check if we ended early, failing conservatively if so.
+ if (Visited.size() >= Max)
return false;
}
return true;
}
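+ // For example (illustrative): if candidate St1 stores a value loaded
+ // through a chain that passes through candidate St0, then St0 is a
+ // predecessor of St1's value operand; folding both stores behind one
+ // TokenFactor would form a cycle, so the search above rejects the merge.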
-bool DAGCombiner::MergeConsecutiveStores(
- StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes) {
+bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
if (OptLevel == CodeGenOpt::None)
return false;
EVT MemVT = St->getMemoryVT();
int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8;
+
+ if (MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
+ return false;
+
bool NoVectors = DAG.getMachineFunction().getFunction()->hasFnAttribute(
Attribute::NoImplicitFloat);
@@ -11731,376 +12681,437 @@ bool DAGCombiner::MergeConsecutiveStores(
if (MemVT.isVector() && IsLoadSrc)
return false;
- // Only look at ends of store sequences.
- SDValue Chain = SDValue(St, 0);
- if (Chain->hasOneUse() && Chain->use_begin()->getOpcode() == ISD::STORE)
- return false;
-
- // Save the LoadSDNodes that we find in the chain.
- // We need to make sure that these nodes do not interfere with
- // any of the store nodes.
- SmallVector<LSBaseSDNode*, 8> AliasLoadNodes;
-
- getStoreMergeAndAliasCandidates(St, StoreNodes, AliasLoadNodes);
+ SmallVector<MemOpLink, 8> StoreNodes;
+ // Find potential store merge candidates by searching through chain sub-DAG
+ getStoreMergeCandidates(St, StoreNodes);
// Check if there is anything to merge.
if (StoreNodes.size() < 2)
return false;
- // only do dependence check in AA case
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
- if (UseAA && !checkMergeStoreCandidatesForDependencies(StoreNodes))
- return false;
-
// Sort the memory operands according to their distance from the
- // base pointer. As a secondary criteria: make sure stores coming
- // later in the code come first in the list. This is important for
- // the non-UseAA case, because we're merging stores into the FINAL
- // store along a chain which potentially contains aliasing stores.
- // Thus, if there are multiple stores to the same address, the last
- // one can be considered for merging but not the others.
+ // base pointer.
std::sort(StoreNodes.begin(), StoreNodes.end(),
[](MemOpLink LHS, MemOpLink RHS) {
- return LHS.OffsetFromBase < RHS.OffsetFromBase ||
- (LHS.OffsetFromBase == RHS.OffsetFromBase &&
- LHS.SequenceNum < RHS.SequenceNum);
- });
-
- // Scan the memory operations on the chain and find the first non-consecutive
- // store memory address.
- unsigned LastConsecutiveStore = 0;
- int64_t StartAddress = StoreNodes[0].OffsetFromBase;
- for (unsigned i = 0, e = StoreNodes.size(); i < e; ++i) {
-
+ return LHS.OffsetFromBase < RHS.OffsetFromBase;
+ });
+
+ // Store merging attempts to merge the lowest stores first. This generally
+ // works out: if the merge succeeds, the remaining stores are checked
+ // after the first collection of stores is merged. However, in the case
+ // that a non-mergeable store is found first, e.g., {p[-2], p[0], p[1],
+ // p[2], p[3]}, we would fail and miss the subsequent mergeable cases.
+ // To prevent this, we prune such stores from the front of StoreNodes
+ // here.
+
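+ // For instance (illustrative): with 4-byte elements and sorted offsets
+ // {-8, 0, 4, 8, 12}, the StartIdx scan below steps past -8 (since
+ // -8 + 4 != 0) and the consecutive run {0, 4, 8, 12} is still merged.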
+ bool RV = false;
+ while (StoreNodes.size() > 1) {
+ unsigned StartIdx = 0;
+ while ((StartIdx + 1 < StoreNodes.size()) &&
+ StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
+ StoreNodes[StartIdx + 1].OffsetFromBase)
+ ++StartIdx;
+
+ // Bail if we don't have enough candidates to merge.
+ if (StartIdx + 1 >= StoreNodes.size())
+ return RV;
+
+ if (StartIdx)
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
+
+ // Scan the memory operations on the chain and find the first
+ // non-consecutive store memory address.
+ unsigned NumConsecutiveStores = 1;
+ int64_t StartAddress = StoreNodes[0].OffsetFromBase;
// Check that the addresses are consecutive starting from the second
// element in the list of stores.
- if (i > 0) {
+ for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
if (CurrAddress - StartAddress != (ElementSizeBytes * i))
break;
+ NumConsecutiveStores = i + 1;
}
- // Check if this store interferes with any of the loads that we found.
- // If we find a load that alias with this store. Stop the sequence.
- if (any_of(AliasLoadNodes, [&](LSBaseSDNode *Ldn) {
- return isAlias(Ldn, StoreNodes[i].MemNode);
- }))
- break;
+ if (NumConsecutiveStores < 2) {
+ StoreNodes.erase(StoreNodes.begin(),
+ StoreNodes.begin() + NumConsecutiveStores);
+ continue;
+ }
- // Mark this node as useful.
- LastConsecutiveStore = i;
- }
+ // Check that we can merge these candidates without causing a cycle
+ if (!checkMergeStoreCandidatesForDependencies(StoreNodes,
+ NumConsecutiveStores)) {
+ StoreNodes.erase(StoreNodes.begin(),
+ StoreNodes.begin() + NumConsecutiveStores);
+ continue;
+ }
- // The node with the lowest store address.
- LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
- unsigned FirstStoreAS = FirstInChain->getAddressSpace();
- unsigned FirstStoreAlign = FirstInChain->getAlignment();
- LLVMContext &Context = *DAG.getContext();
- const DataLayout &DL = DAG.getDataLayout();
-
- // Store the constants into memory as one consecutive store.
- if (IsConstantSrc) {
- unsigned LastLegalType = 0;
- unsigned LastLegalVectorType = 0;
- bool NonZero = false;
- for (unsigned i=0; i<LastConsecutiveStore+1; ++i) {
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- SDValue StoredVal = St->getValue();
-
- if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) {
- NonZero |= !C->isNullValue();
- } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal)) {
- NonZero |= !C->getConstantFPValue()->isNullValue();
- } else {
- // Non-constant.
- break;
- }
+ // The node with the lowest store address.
+ LLVMContext &Context = *DAG.getContext();
+ const DataLayout &DL = DAG.getDataLayout();
- // Find a legal type for the constant store.
- unsigned SizeInBits = (i+1) * ElementSizeBytes * 8;
- EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
- bool IsFast;
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFast) && IsFast) {
- LastLegalType = i+1;
- // Or check whether a truncstore is legal.
- } else if (TLI.getTypeAction(Context, StoreTy) ==
- TargetLowering::TypePromoteInteger) {
- EVT LegalizedStoredValueTy =
- TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
- if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
- FirstStoreAS, FirstStoreAlign, &IsFast) &&
+ // Store the constants into memory as one consecutive store.
+ if (IsConstantSrc) {
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ unsigned FirstStoreAS = FirstInChain->getAddressSpace();
+ unsigned FirstStoreAlign = FirstInChain->getAlignment();
+ unsigned LastLegalType = 1;
+ unsigned LastLegalVectorType = 1;
+ bool LastIntegerTrunc = false;
+ bool NonZero = false;
+ for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
+ StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
+ SDValue StoredVal = ST->getValue();
+
+ if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal)) {
+ NonZero |= !C->isNullValue();
+ } else if (ConstantFPSDNode *C =
+ dyn_cast<ConstantFPSDNode>(StoredVal)) {
+ NonZero |= !C->getConstantFPValue()->isNullValue();
+ } else {
+ // Non-constant.
+ break;
+ }
+
+ // Find a legal type for the constant store.
+ unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
+ EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
+ bool IsFast = false;
+ if (TLI.isTypeLegal(StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFast) &&
IsFast) {
+ LastIntegerTrunc = false;
LastLegalType = i + 1;
+ // Or check whether a truncstore is legal.
+ } else if (TLI.getTypeAction(Context, StoreTy) ==
+ TargetLowering::TypePromoteInteger) {
+ EVT LegalizedStoredValueTy =
+ TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
+ if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValueTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
+ FirstStoreAS, FirstStoreAlign, &IsFast) &&
+ IsFast) {
+ LastIntegerTrunc = true;
+ LastLegalType = i + 1;
+ }
}
- }
- // We only use vectors if the constant is known to be zero or the target
- // allows it and the function is not marked with the noimplicitfloat
- // attribute.
- if ((!NonZero || TLI.storeOfVectorConstantIsCheap(MemVT, i+1,
- FirstStoreAS)) &&
- !NoVectors) {
- // Find a legal type for the vector store.
- EVT Ty = EVT::getVectorVT(Context, MemVT, i+1);
- if (TLI.isTypeLegal(Ty) &&
- TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
- FirstStoreAlign, &IsFast) && IsFast)
- LastLegalVectorType = i + 1;
+ // We only use vectors if the constant is known to be zero or the target
+ // allows it and the function is not marked with the noimplicitfloat
+ // attribute.
+ if ((!NonZero ||
+ TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
+ !NoVectors) {
+ // Find a legal type for the vector store.
+ unsigned Elts = i + 1;
+ if (MemVT.isVector()) {
+ // When merging vector stores, get the total number of elements.
+ Elts *= MemVT.getVectorNumElements();
+ }
+ EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
+ if (TLI.isTypeLegal(Ty) &&
+ TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
+ FirstStoreAlign, &IsFast) &&
+ IsFast)
+ LastLegalVectorType = i + 1;
+ }
}
- }
- // Check if we found a legal integer type to store.
- if (LastLegalType == 0 && LastLegalVectorType == 0)
- return false;
-
- bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors;
- unsigned NumElem = UseVector ? LastLegalVectorType : LastLegalType;
-
- return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
- true, UseVector);
- }
+ // Check if we found a legal integer type that creates a meaningful merge.
+ if (LastLegalType < 2 && LastLegalVectorType < 2) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
+ continue;
+ }
- // When extracting multiple vector elements, try to store them
- // in one vector store rather than a sequence of scalar stores.
- if (IsExtractVecSrc) {
- unsigned NumStoresToMerge = 0;
- bool IsVec = MemVT.isVector();
- for (unsigned i = 0; i < LastConsecutiveStore + 1; ++i) {
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- unsigned StoreValOpcode = St->getValue().getOpcode();
- // This restriction could be loosened.
- // Bail out if any stored values are not elements extracted from a vector.
- // It should be possible to handle mixed sources, but load sources need
- // more careful handling (see the block of code below that handles
- // consecutive loads).
- if (StoreValOpcode != ISD::EXTRACT_VECTOR_ELT &&
- StoreValOpcode != ISD::EXTRACT_SUBVECTOR)
- return false;
+ bool UseVector = (LastLegalVectorType > LastLegalType) && !NoVectors;
+ unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
- // Find a legal type for the vector store.
- unsigned Elts = i + 1;
- if (IsVec) {
- // When merging vector stores, get the total number of elements.
- Elts *= MemVT.getVectorNumElements();
+ bool Merged = MergeStoresOfConstantsOrVecElts(
+ StoreNodes, MemVT, NumElem, true, UseVector, LastIntegerTrunc);
+ if (!Merged) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+ continue;
}
- EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
- bool IsFast;
- if (TLI.isTypeLegal(Ty) &&
- TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
- FirstStoreAlign, &IsFast) && IsFast)
- NumStoresToMerge = i + 1;
+ // Remove merged stores for next iteration.
+ RV = true;
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+ continue;
}
- return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStoresToMerge,
- false, true);
- }
-
- // Below we handle the case of multiple consecutive stores that
- // come from multiple consecutive loads. We merge them into a single
- // wide load and a single wide store.
-
- // Look for load nodes which are used by the stored values.
- SmallVector<MemOpLink, 8> LoadNodes;
+ // When extracting multiple vector elements, try to store them
+ // in one vector store rather than a sequence of scalar stores.
+ if (IsExtractVecSrc) {
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ unsigned FirstStoreAS = FirstInChain->getAddressSpace();
+ unsigned FirstStoreAlign = FirstInChain->getAlignment();
+ unsigned NumStoresToMerge = 1;
+ bool IsVec = MemVT.isVector();
+ for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
+ unsigned StoreValOpcode = St->getValue().getOpcode();
+ // This restriction could be loosened.
+ // Bail out if any stored values are not elements extracted from a
+ // vector. It should be possible to handle mixed sources, but load
+ // sources need more careful handling (see the block of code below that
+ // handles consecutive loads).
+ if (StoreValOpcode != ISD::EXTRACT_VECTOR_ELT &&
+ StoreValOpcode != ISD::EXTRACT_SUBVECTOR)
+ return RV;
- // Find acceptable loads. Loads need to have the same chain (token factor),
- // must not be zext, volatile, indexed, and they must be consecutive.
- BaseIndexOffset LdBasePtr;
- for (unsigned i=0; i<LastConsecutiveStore+1; ++i) {
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- LoadSDNode *Ld = dyn_cast<LoadSDNode>(St->getValue());
- if (!Ld) break;
+ // Find a legal type for the vector store.
+ unsigned Elts = i + 1;
+ if (IsVec) {
+ // When merging vector stores, get the total number of elements.
+ Elts *= MemVT.getVectorNumElements();
+ }
+ EVT Ty =
+ EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
+ bool IsFast;
+ if (TLI.isTypeLegal(Ty) &&
+ TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS,
+ FirstStoreAlign, &IsFast) &&
+ IsFast)
+ NumStoresToMerge = i + 1;
+ }
- // Loads must only have one use.
- if (!Ld->hasNUsesOfValue(1, 0))
- break;
+ bool Merged = MergeStoresOfConstantsOrVecElts(
+ StoreNodes, MemVT, NumStoresToMerge, false, true, false);
+ if (!Merged) {
+ StoreNodes.erase(StoreNodes.begin(),
+ StoreNodes.begin() + NumStoresToMerge);
+ continue;
+ }
+ // Remove merged stores for next iteration.
+ StoreNodes.erase(StoreNodes.begin(),
+ StoreNodes.begin() + NumStoresToMerge);
+ RV = true;
+ continue;
+ }
- // The memory operands must not be volatile.
- if (Ld->isVolatile() || Ld->isIndexed())
- break;
+ // Below we handle the case of multiple consecutive stores that
+ // come from multiple consecutive loads. We merge them into a single
+ // wide load and a single wide store.
- // We do not accept ext loads.
- if (Ld->getExtensionType() != ISD::NON_EXTLOAD)
- break;
+ // Look for load nodes which are used by the stored values.
+ SmallVector<MemOpLink, 8> LoadNodes;
- // The stored memory type must be the same.
- if (Ld->getMemoryVT() != MemVT)
- break;
+ // Find acceptable loads. Loads need to have the same chain (token factor),
+ // must not be zext, volatile, indexed, and they must be consecutive.
+ BaseIndexOffset LdBasePtr;
+ for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
+ StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
+ LoadSDNode *Ld = dyn_cast<LoadSDNode>(St->getValue());
+ if (!Ld)
+ break;
- BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld->getBasePtr(), DAG);
- // If this is not the first ptr that we check.
- if (LdBasePtr.Base.getNode()) {
- // The base ptr must be the same.
- if (!LdPtr.equalBaseIndex(LdBasePtr))
+ // Loads must only have one use.
+ if (!Ld->hasNUsesOfValue(1, 0))
break;
- } else {
- // Check that all other base pointers are the same as this one.
- LdBasePtr = LdPtr;
- }
- // We found a potential memory operand to merge.
- LoadNodes.push_back(MemOpLink(Ld, LdPtr.Offset, 0));
- }
+ // The memory operands must not be volatile.
+ if (Ld->isVolatile() || Ld->isIndexed())
+ break;
- if (LoadNodes.size() < 2)
- return false;
+ // We do not accept ext loads.
+ if (Ld->getExtensionType() != ISD::NON_EXTLOAD)
+ break;
- // If we have load/store pair instructions and we only have two values,
- // don't bother.
- unsigned RequiredAlignment;
- if (LoadNodes.size() == 2 && TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
- St->getAlignment() >= RequiredAlignment)
- return false;
+ // The stored memory type must be the same.
+ if (Ld->getMemoryVT() != MemVT)
+ break;
- LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
- unsigned FirstLoadAS = FirstLoad->getAddressSpace();
- unsigned FirstLoadAlign = FirstLoad->getAlignment();
-
- // Scan the memory operations on the chain and find the first non-consecutive
- // load memory address. These variables hold the index in the store node
- // array.
- unsigned LastConsecutiveLoad = 0;
- // This variable refers to the size and not index in the array.
- unsigned LastLegalVectorType = 0;
- unsigned LastLegalIntegerType = 0;
- StartAddress = LoadNodes[0].OffsetFromBase;
- SDValue FirstChain = FirstLoad->getChain();
- for (unsigned i = 1; i < LoadNodes.size(); ++i) {
- // All loads must share the same chain.
- if (LoadNodes[i].MemNode->getChain() != FirstChain)
- break;
+ BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld->getBasePtr(), DAG);
+ // If this is not the first ptr that we check.
+ int64_t LdOffset = 0;
+ if (LdBasePtr.getBase().getNode()) {
+ // The base ptr must be the same.
+ if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
+ break;
+ } else {
+ // Check that all other base pointers are the same as this one.
+ LdBasePtr = LdPtr;
+ }
- int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
- if (CurrAddress - StartAddress != (ElementSizeBytes * i))
- break;
- LastConsecutiveLoad = i;
- // Find a legal type for the vector store.
- EVT StoreTy = EVT::getVectorVT(Context, MemVT, i+1);
- bool IsFastSt, IsFastLd;
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFastSt) && IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
- FirstLoadAlign, &IsFastLd) && IsFastLd) {
- LastLegalVectorType = i + 1;
- }
-
- // Find a legal type for the integer store.
- unsigned SizeInBits = (i+1) * ElementSizeBytes * 8;
- StoreTy = EVT::getIntegerVT(Context, SizeInBits);
- if (TLI.isTypeLegal(StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
- FirstStoreAlign, &IsFastSt) && IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
- FirstLoadAlign, &IsFastLd) && IsFastLd)
- LastLegalIntegerType = i + 1;
- // Or check whether a truncstore and extload is legal.
- else if (TLI.getTypeAction(Context, StoreTy) ==
- TargetLowering::TypePromoteInteger) {
- EVT LegalizedStoredValueTy =
- TLI.getTypeToTransformTo(Context, StoreTy);
- if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) &&
- TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValueTy, StoreTy) &&
- TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValueTy, StoreTy) &&
- TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValueTy, StoreTy) &&
- TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
- FirstStoreAS, FirstStoreAlign, &IsFastSt) &&
- IsFastSt &&
- TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
- FirstLoadAS, FirstLoadAlign, &IsFastLd) &&
- IsFastLd)
- LastLegalIntegerType = i+1;
+ // We found a potential memory operand to merge.
+ LoadNodes.push_back(MemOpLink(Ld, LdOffset));
}
- }
-
- // Only use vector types if the vector type is larger than the integer type.
- // If they are the same, use integers.
- bool UseVectorTy = LastLegalVectorType > LastLegalIntegerType && !NoVectors;
- unsigned LastLegalType = std::max(LastLegalVectorType, LastLegalIntegerType);
- // We add +1 here because the LastXXX variables refer to location while
- // the NumElem refers to array/index size.
- unsigned NumElem = std::min(LastConsecutiveStore, LastConsecutiveLoad) + 1;
- NumElem = std::min(LastLegalType, NumElem);
-
- if (NumElem < 2)
- return false;
-
- // Collect the chains from all merged stores.
- SmallVector<SDValue, 8> MergeStoreChains;
- MergeStoreChains.push_back(StoreNodes[0].MemNode->getChain());
+ if (LoadNodes.size() < 2) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
+ continue;
+ }
- // The latest Node in the DAG.
- unsigned LatestNodeUsed = 0;
- for (unsigned i=1; i<NumElem; ++i) {
- // Find a chain for the new wide-store operand. Notice that some
- // of the store nodes that we found may not be selected for inclusion
- // in the wide store. The chain we use needs to be the chain of the
- // latest store node which is *used* and replaced by the wide store.
- if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum)
- LatestNodeUsed = i;
+ // If we have load/store pair instructions and we only have two values,
+ // don't bother merging.
+ unsigned RequiredAlignment;
+ if (LoadNodes.size() == 2 && TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
+ StoreNodes[0].MemNode->getAlignment() >= RequiredAlignment) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
+ continue;
+ }
+ LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
+ unsigned FirstStoreAS = FirstInChain->getAddressSpace();
+ unsigned FirstStoreAlign = FirstInChain->getAlignment();
+ LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
+ unsigned FirstLoadAS = FirstLoad->getAddressSpace();
+ unsigned FirstLoadAlign = FirstLoad->getAlignment();
+
+ // Scan the memory operations on the chain and find the first
+ // non-consecutive load memory address. These variables hold the index in
+ // the store node array.
+ unsigned LastConsecutiveLoad = 1;
+ // This variable refers to the size and not index in the array.
+ unsigned LastLegalVectorType = 1;
+ unsigned LastLegalIntegerType = 1;
+ bool isDereferenceable = true;
+ bool DoIntegerTruncate = false;
+ StartAddress = LoadNodes[0].OffsetFromBase;
+ SDValue FirstChain = FirstLoad->getChain();
+ for (unsigned i = 1; i < LoadNodes.size(); ++i) {
+ // All loads must share the same chain.
+ if (LoadNodes[i].MemNode->getChain() != FirstChain)
+ break;
- MergeStoreChains.push_back(StoreNodes[i].MemNode->getChain());
- }
+ int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
+ if (CurrAddress - StartAddress != (ElementSizeBytes * i))
+ break;
+ LastConsecutiveLoad = i;
- LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode;
+ if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
+ isDereferenceable = false;
- // Find if it is better to use vectors or integers to load and store
- // to memory.
- EVT JointMemOpVT;
- if (UseVectorTy) {
- JointMemOpVT = EVT::getVectorVT(Context, MemVT, NumElem);
- } else {
- unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
- JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
- }
+ // Find a legal type for the vector store.
+ EVT StoreTy = EVT::getVectorVT(Context, MemVT, i + 1);
+ bool IsFastSt, IsFastLd;
+ if (TLI.isTypeLegal(StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFastSt) &&
+ IsFastSt &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
+ FirstLoadAlign, &IsFastLd) &&
+ IsFastLd) {
+ LastLegalVectorType = i + 1;
+ }
- SDLoc LoadDL(LoadNodes[0].MemNode);
- SDLoc StoreDL(StoreNodes[0].MemNode);
+ // Find a legal type for the integer store.
+ unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
+ StoreTy = EVT::getIntegerVT(Context, SizeInBits);
+ if (TLI.isTypeLegal(StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, StoreTy, DAG) &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS,
+ FirstStoreAlign, &IsFastSt) &&
+ IsFastSt &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
+ FirstLoadAlign, &IsFastLd) &&
+ IsFastLd) {
+ LastLegalIntegerType = i + 1;
+ DoIntegerTruncate = false;
+ // Or check whether a truncstore and extload is legal.
+ } else if (TLI.getTypeAction(Context, StoreTy) ==
+ TargetLowering::TypePromoteInteger) {
+ EVT LegalizedStoredValueTy = TLI.getTypeToTransformTo(Context, StoreTy);
+ if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) &&
+ TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValueTy, DAG) &&
+ TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValueTy,
+ StoreTy) &&
+ TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValueTy,
+ StoreTy) &&
+ TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValueTy, StoreTy) &&
+ TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy,
+ FirstStoreAS, FirstStoreAlign, &IsFastSt) &&
+ IsFastSt &&
+ TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS,
+ FirstLoadAlign, &IsFastLd) &&
+ IsFastLd) {
+ LastLegalIntegerType = i + 1;
+ DoIntegerTruncate = true;
+ }
+ }
+ }
- // The merged loads are required to have the same incoming chain, so
- // using the first's chain is acceptable.
- SDValue NewLoad = DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
- FirstLoad->getBasePtr(),
- FirstLoad->getPointerInfo(), FirstLoadAlign);
+ // Only use vector types if the vector type is larger than the integer type.
+ // If they are the same, use integers.
+ bool UseVectorTy = LastLegalVectorType > LastLegalIntegerType && !NoVectors;
+ unsigned LastLegalType =
+ std::max(LastLegalVectorType, LastLegalIntegerType);
- SDValue NewStoreChain =
- DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, MergeStoreChains);
+ // We add +1 here because the LastXXX variables refer to location while
+ // the NumElem refers to array/index size.
+ unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
+ NumElem = std::min(LastLegalType, NumElem);
- SDValue NewStore =
- DAG.getStore(NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
- FirstInChain->getPointerInfo(), FirstStoreAlign);
+ if (NumElem < 2) {
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
+ continue;
+ }
- // Transfer chain users from old loads to the new load.
- for (unsigned i = 0; i < NumElem; ++i) {
- LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
- DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
- SDValue(NewLoad.getNode(), 1));
- }
+ // Find if it is better to use vectors or integers to load and store
+ // to memory.
+ EVT JointMemOpVT;
+ if (UseVectorTy) {
+ JointMemOpVT = EVT::getVectorVT(Context, MemVT, NumElem);
+ } else {
+ unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
+ JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
+ }
+
+ SDLoc LoadDL(LoadNodes[0].MemNode);
+ SDLoc StoreDL(StoreNodes[0].MemNode);
+
+ // The merged loads are required to have the same incoming chain, so
+ // using the first's chain is acceptable.
+
+ SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
+ AddToWorklist(NewStoreChain.getNode());
+
+ MachineMemOperand::Flags MMOFlags = isDereferenceable ?
+ MachineMemOperand::MODereferenceable:
+ MachineMemOperand::MONone;
+
+ SDValue NewLoad, NewStore;
+ if (UseVectorTy || !DoIntegerTruncate) {
+ NewLoad = DAG.getLoad(JointMemOpVT, LoadDL, FirstLoad->getChain(),
+ FirstLoad->getBasePtr(),
+ FirstLoad->getPointerInfo(), FirstLoadAlign,
+ MMOFlags);
+ NewStore = DAG.getStore(NewStoreChain, StoreDL, NewLoad,
+ FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(), FirstStoreAlign);
+ } else { // This must be the truncstore/extload case
+ EVT ExtendedTy =
+ TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
+ NewLoad =
+ DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy, FirstLoad->getChain(),
+ FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(),
+ JointMemOpVT, FirstLoadAlign, MMOFlags);
+ NewStore = DAG.getTruncStore(NewStoreChain, StoreDL, NewLoad,
+ FirstInChain->getBasePtr(),
+ FirstInChain->getPointerInfo(), JointMemOpVT,
+ FirstInChain->getAlignment(),
+ FirstInChain->getMemOperand()->getFlags());
+ }
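+ // For example (a sketch, assuming i16 promotes to i32): two merged
+ // i8 load/store pairs give JointMemOpVT == i16, so we emit
+ // "i32 = extload i16" from the first load followed by a truncating
+ // i16 store of that i32 value.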
+
+ // Transfer chain users from old loads to the new load.
+ for (unsigned i = 0; i < NumElem; ++i) {
+ LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
+ SDValue(NewLoad.getNode(), 1));
+ }
- if (UseAA) {
// Replace the all stores with the new store.
for (unsigned i = 0; i < NumElem; ++i)
CombineTo(StoreNodes[i].MemNode, NewStore);
- } else {
- // Replace the last store with the new store.
- CombineTo(LatestOp, NewStore);
- // Erase all other stores.
- for (unsigned i = 0; i < NumElem; ++i) {
- // Remove all Store nodes.
- if (StoreNodes[i].MemNode == LatestOp)
- continue;
- StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
- DAG.ReplaceAllUsesOfValueWith(SDValue(St, 0), St->getChain());
- deleteAndRecombine(St);
- }
+ RV = true;
+ StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
+ continue;
}
-
- StoreNodes.erase(StoreNodes.begin() + NumElem, StoreNodes.end());
- return true;
+ return RV;
}
SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
@@ -12256,19 +13267,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
if (SDValue NewST = TransformFPLoadStorePair(N))
return NewST;
- bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA
- : DAG.getSubtarget().useAA();
-#ifndef NDEBUG
- if (CombinerAAOnlyFunc.getNumOccurrences() &&
- CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
- UseAA = false;
-#endif
- if (UseAA && ST->isUnindexed()) {
- // FIXME: We should do this even without AA enabled. AA will just allow
- // FindBetterChain to work in more situations. The problem with this is that
- // any combine that expects memory operations to be on consecutive chains
- // first needs to be updated to look for users of the same chain.
-
+ if (ST->isUnindexed()) {
// Walk up chain skipping non-aliasing memory nodes, on this store and any
// adjacent stores.
if (findBetterNeighborChains(ST)) {
@@ -12279,10 +13278,6 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
Chain = ST->getChain();
}
- // Try transforming N to an indexed store.
- if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
- return SDValue(N, 0);
-
// FIXME: is there such a thing as a truncating indexed store?
if (ST->isTruncatingStore() && ST->isUnindexed() &&
Value.getValueType().isInteger()) {
@@ -12302,8 +13297,15 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
if (SimplifyDemandedBits(
Value,
APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
- ST->getMemoryVT().getScalarSizeInBits())))
+ ST->getMemoryVT().getScalarSizeInBits()))) {
+ // Re-visit the store if anything changed and the store hasn't been merged
+ // with another node (N is deleted) SimplifyDemandedBits will add Value's
+ // node back to the worklist if necessary, but we also need to re-visit
+ // the Store node itself.
+ if (N->getOpcode() != ISD::DELETED_NODE)
+ AddToWorklist(N);
return SDValue(N, 0);
+ }
}
// If this is a load followed by a store to the same location, then the store
@@ -12319,14 +13321,28 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
}
}
- // If this is a store followed by a store with the same value to the same
- // location, then the store is dead/noop.
if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
- if (ST1->getBasePtr() == Ptr && ST->getMemoryVT() == ST1->getMemoryVT() &&
- ST1->getValue() == Value && ST->isUnindexed() && !ST->isVolatile() &&
- ST1->isUnindexed() && !ST1->isVolatile()) {
- // The store is dead, remove it.
- return Chain;
+ if (ST->isUnindexed() && !ST->isVolatile() && ST1->isUnindexed() &&
+ !ST1->isVolatile() && ST1->getBasePtr() == Ptr &&
+ ST->getMemoryVT() == ST1->getMemoryVT()) {
+ // If this is a store followed by a store with the same value to the same
+ // location, then the store is dead/noop.
+ if (ST1->getValue() == Value) {
+ // The store is dead, remove it.
+ return Chain;
+ }
+
+ // If this store fully overwrites the immediately preceding store to
+ // the same location, and no other node is chained to that store, we
+ // can effectively drop the earlier store. Do not remove stores to
+ // undef, as they may be used as data sinks.
+ if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
+ !ST1->getBasePtr().isUndef()) {
+ // ST1 is fully overwritten and can be elided. Combine it with its
+ // chain value.
+ CombineTo(ST1, ST1->getChain());
+ return SDValue();
+ }
}
}
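+ // For example (illustrative): in "store i32 %a, p; store i32 %b, p",
+ // where the first store's only use is the second store's chain, the
+ // first store is fully overwritten and is combined away to its
+ // incoming chain.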
@@ -12342,29 +13358,31 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
// Only perform this optimization before the types are legal, because we
// don't want to perform this optimization on every DAGCombine invocation.
- if (!LegalTypes) {
+ if ((TLI.mergeStoresAfterLegalization()) ? Level == AfterLegalizeDAG
+ : !LegalTypes) {
for (;;) {
// There can be multiple store sequences on the same chain.
// Keep trying to merge store sequences until we are unable to do so
// or until we merge the last store on the chain.
- SmallVector<MemOpLink, 8> StoreNodes;
- bool Changed = MergeConsecutiveStores(ST, StoreNodes);
+ bool Changed = MergeConsecutiveStores(ST);
if (!Changed) break;
-
- if (any_of(StoreNodes,
- [ST](const MemOpLink &Link) { return Link.MemNode == ST; })) {
- // ST has been merged and no longer exists.
+ // Return N, as the merge only uses CombineTo and no worklist
+ // cleanup is necessary.
+ if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
return SDValue(N, 0);
- }
}
}
+ // Try transforming N to an indexed store.
+ if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
+ return SDValue(N, 0);
+
// Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
//
// Make sure to do this only after attempting to merge stores in order to
// avoid changing the types of some subset of stores due to visit order,
// preventing their merging.
- if (isa<ConstantFPSDNode>(Value)) {
+ if (isa<ConstantFPSDNode>(ST->getValue())) {
if (SDValue NewSt = replaceStoreOfFPConstant(ST))
return NewSt;
}
@@ -12493,10 +13511,6 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
EVT VT = InVec.getValueType();
- // If we can't generate a legal BUILD_VECTOR, exit
- if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
- return SDValue();
-
// Check that we know which element is being inserted
if (!isa<ConstantSDNode>(EltNo))
return SDValue();
@@ -12511,8 +13525,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
// do this only if indices are both constants and Idx1 < Idx0.
if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
&& isa<ConstantSDNode>(InVec.getOperand(2))) {
- unsigned OtherElt =
- cast<ConstantSDNode>(InVec.getOperand(2))->getZExtValue();
+ unsigned OtherElt = InVec.getConstantOperandVal(2);
if (Elt < OtherElt) {
// Swap nodes.
SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
@@ -12523,6 +13536,10 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
}
}
+ // If we can't generate a legal BUILD_VECTOR, exit
+ if (LegalOperations && !TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
+ return SDValue();
+
// Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially
// be converted to a BUILD_VECTOR). Fill in the Ops vector with the
// vector elements.
@@ -12544,11 +13561,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
// All the operands of BUILD_VECTOR must have the same type;
// we enforce that here.
EVT OpVT = Ops[0].getValueType();
- if (InVal.getValueType() != OpVT)
- InVal = OpVT.bitsGT(InVal.getValueType()) ?
- DAG.getNode(ISD::ANY_EXTEND, DL, OpVT, InVal) :
- DAG.getNode(ISD::TRUNCATE, DL, OpVT, InVal);
- Ops[Elt] = InVal;
+ Ops[Elt] = OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal;
}
// Return the new vector
@@ -12568,6 +13581,11 @@ SDValue DAGCombiner::ReplaceExtractVectorEltOfLoadWithNarrowedLoad(
if (NewAlign > Align || !TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT))
return SDValue();
+ ISD::LoadExtType ExtTy = ResultVT.bitsGT(VecEltVT) ?
+ ISD::NON_EXTLOAD : ISD::EXTLOAD;
+ if (!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
+ return SDValue();
+
Align = NewAlign;
SDValue NewPtr = OriginalLoad->getBasePtr();
@@ -12639,6 +13657,9 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
EVT VT = InVec.getValueType();
EVT NVT = N->getValueType(0);
+ if (InVec.isUndef())
+ return DAG.getUNDEF(NVT);
+
if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR) {
// Check if the result type doesn't match the inserted element type. A
// SCALAR_TO_VECTOR may truncate the inserted element and the
@@ -13022,7 +14043,7 @@ SDValue DAGCombiner::reduceBuildVecConvertToConvertBuildVec(SDNode *N) {
return DAG.getNode(Opcode, DL, VT, BV);
}
-SDValue DAGCombiner::createBuildVecShuffle(SDLoc DL, SDNode *N,
+SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
ArrayRef<int> VectorMask,
SDValue VecIn1, SDValue VecIn2,
unsigned LeftIdx) {
@@ -13088,6 +14109,11 @@ SDValue DAGCombiner::createBuildVecShuffle(SDLoc DL, SDNode *N,
// when we start sorting the vectors by type.
return SDValue();
}
+ } else if (InVT2.getSizeInBits() * 2 == VT.getSizeInBits() &&
+ InVT1.getSizeInBits() == VT.getSizeInBits()) {
+ SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
+ ConcatOps[0] = VecIn2;
+ VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
} else {
// TODO: Support cases where the length mismatch isn't exactly by a
// factor of 2.
@@ -13293,6 +14319,73 @@ SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
return Shuffles[0];
}
+// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
+// operations which can be matched to a truncate.
+SDValue DAGCombiner::reduceBuildVecToTrunc(SDNode *N) {
+ // TODO: Add support for big-endian.
+ if (DAG.getDataLayout().isBigEndian())
+ return SDValue();
+ if (N->getNumOperands() < 2)
+ return SDValue();
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ unsigned NumElems = N->getNumOperands();
+
+ if (!isTypeLegal(VT))
+ return SDValue();
+
+ // If the input is something other than an EXTRACT_VECTOR_ELT with a constant
+ // index, bail out.
+ // TODO: Allow undef elements in some cases?
+ if (any_of(N->ops(), [VT](SDValue Op) {
+ return Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ !isa<ConstantSDNode>(Op.getOperand(1)) ||
+ Op.getValueType() != VT.getVectorElementType();
+ }))
+ return SDValue();
+
+ // Helper for obtaining an EXTRACT_VECTOR_ELT's constant index
+ auto GetExtractIdx = [](SDValue Extract) {
+ return cast<ConstantSDNode>(Extract.getOperand(1))->getSExtValue();
+ };
+
+ // The first BUILD_VECTOR operand must be an extract from index zero
+ // (assuming no undef and little-endian).
+ if (GetExtractIdx(N->getOperand(0)) != 0)
+ return SDValue();
+
+ // Compute the stride from the first index.
+ int Stride = GetExtractIdx(N->getOperand(1));
+ SDValue ExtractedFromVec = N->getOperand(0).getOperand(0);
+
+ // Proceed only if the stride and the types can be matched to a truncate.
+ if ((Stride == 1 || !isPowerOf2_32(Stride)) ||
+ (ExtractedFromVec.getValueType().getVectorNumElements() !=
+ Stride * NumElems) ||
+ (VT.getScalarSizeInBits() * Stride > 64))
+ return SDValue();
+
+ // Check remaining operands are consistent with the computed stride.
+ for (unsigned i = 1; i != NumElems; ++i) {
+ SDValue Op = N->getOperand(i);
+
+ if ((Op.getOperand(0) != ExtractedFromVec) ||
+ (GetExtractIdx(Op) != Stride * i))
+ return SDValue();
+ }
+
+ // All checks were ok, construct the truncate.
+ LLVMContext &Ctx = *DAG.getContext();
+ EVT NewVT = VT.getVectorVT(
+ Ctx, EVT::getIntegerVT(Ctx, VT.getScalarSizeInBits() * Stride), NumElems);
+ EVT TruncVT =
+ VT.isFloatingPoint() ? VT.changeVectorElementTypeToInteger() : VT;
+
+ SDValue Res = DAG.getBitcast(NewVT, ExtractedFromVec);
+ Res = DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, Res);
+ return DAG.getBitcast(VT, Res);
+}
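+// For example (illustrative): (v4i16 build_vector x[0], x[2], x[4], x[6])
+// with x : v8i16 has Stride == 2, so it is rewritten as a bitcast of x to
+// v4i32 followed by (v4i16 truncate v4i32), keeping the low 16 bits of
+// each lane on a little-endian target.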
+
SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
EVT VT = N->getValueType(0);
@@ -13300,12 +14393,45 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
if (ISD::allOperandsUndef(N))
return DAG.getUNDEF(VT);
+ // Check if we can express the BUILD_VECTOR via a subvector extract.
+ if (!LegalTypes && (N->getNumOperands() > 1)) {
+ SDValue Op0 = N->getOperand(0);
+ auto checkElem = [&](SDValue Op) -> uint64_t {
+ if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
+ (Op0.getOperand(0) == Op.getOperand(0)))
+ if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
+ return CNode->getZExtValue();
+ return -1;
+ };
+
+ int Offset = checkElem(Op0);
+ for (unsigned i = 0; i < N->getNumOperands(); ++i) {
+ if (Offset + i != checkElem(N->getOperand(i))) {
+ Offset = -1;
+ break;
+ }
+ }
+
+ if ((Offset == 0) &&
+ (Op0.getOperand(0).getValueType() == N->getValueType(0)))
+ return Op0.getOperand(0);
+ if ((Offset != -1) &&
+ ((Offset % N->getValueType(0).getVectorNumElements()) ==
+ 0)) // IDX must be multiple of output size.
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
+ Op0.getOperand(0), Op0.getOperand(1));
+ }
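+ // For example (illustrative): (v4i32 build_vector x[4], x[5], x[6], x[7])
+ // with x : v8i32 gives Offset == 4, a multiple of the 4-element result
+ // type, so the whole node folds to (v4i32 extract_subvector x, 4).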
+
if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
return V;
if (SDValue V = reduceBuildVecConvertToConvertBuildVec(N))
return V;
+ if (TLI.isDesirableToCombineBuildVectorToTruncate())
+ if (SDValue V = reduceBuildVecToTrunc(N))
+ return V;
+
if (SDValue V = reduceBuildVecToShuffle(N))
return V;
@@ -13419,7 +14545,7 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
if (!isa<ConstantSDNode>(Op.getOperand(1)))
return SDValue();
- int ExtIdx = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue();
+ int ExtIdx = Op.getConstantOperandVal(1);
// Ensure that we are extracting a subvector from a vector the same
// size as the result.
@@ -13491,8 +14617,11 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
return SDValue();
- EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy,
- VT.getSizeInBits() / SclTy.getSizeInBits());
+ unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
+ if (VNTNumElms < 2)
+ return SDValue();
+
+ EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
return SDValue();
@@ -13607,19 +14736,153 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
return SDValue();
}
+/// If we are extracting a subvector produced by a wide binary operator with
+/// at least one operand that was the result of a vector concatenation, then
+/// try to use the narrow vector operands directly to avoid the concatenation
+/// and extraction.
+static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) {
+ // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
+ // some of these bailouts with other transforms.
+
+ // The extract index must be a constant, so we can map it to a concat operand.
+ auto *ExtractIndex = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
+ if (!ExtractIndex)
+ return SDValue();
+
+ // Only handle the case where we are doubling and then halving. A larger ratio
+ // may require more than two narrow binops to replace the wide binop.
+ EVT VT = Extract->getValueType(0);
+ unsigned NumElems = VT.getVectorNumElements();
+ assert((ExtractIndex->getZExtValue() % NumElems) == 0 &&
+ "Extract index is not a multiple of the vector length.");
+ if (Extract->getOperand(0).getValueSizeInBits() != VT.getSizeInBits() * 2)
+ return SDValue();
+
+ // We are looking for an optionally bitcasted wide vector binary operator
+ // feeding an extract subvector.
+ SDValue BinOp = Extract->getOperand(0);
+ if (BinOp.getOpcode() == ISD::BITCAST)
+ BinOp = BinOp.getOperand(0);
+
+ // TODO: The motivating case for this transform is an x86 AVX1 target. That
+ // target has temptingly almost legal versions of bitwise logic ops in 256-bit
+ // flavors, but no other 256-bit integer support. This could be extended to
+ // handle any binop, but that may require fixing/adding other folds to avoid
+ // codegen regressions.
+ unsigned BOpcode = BinOp.getOpcode();
+ if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
+ return SDValue();
+
+ // The binop must be a vector type, so we can chop it in half.
+ EVT WideBVT = BinOp.getValueType();
+ if (!WideBVT.isVector())
+ return SDValue();
+
+ // Bail out if the target does not support a narrower version of the binop.
+ EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
+ WideBVT.getVectorNumElements() / 2);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
+ return SDValue();
+
+ // Peek through bitcasts of the binary operator operands if needed.
+ SDValue LHS = BinOp.getOperand(0);
+ if (LHS.getOpcode() == ISD::BITCAST)
+ LHS = LHS.getOperand(0);
+
+ SDValue RHS = BinOp.getOperand(1);
+ if (RHS.getOpcode() == ISD::BITCAST)
+ RHS = RHS.getOperand(0);
+
+ // We need at least one concatenation operation of a binop operand to make
+ // this transform worthwhile. The concat must double the input vector sizes.
+ // TODO: Should we also handle INSERT_SUBVECTOR patterns?
+ bool ConcatL =
+ LHS.getOpcode() == ISD::CONCAT_VECTORS && LHS.getNumOperands() == 2;
+ bool ConcatR =
+ RHS.getOpcode() == ISD::CONCAT_VECTORS && RHS.getNumOperands() == 2;
+ if (!ConcatL && !ConcatR)
+ return SDValue();
+
+ // If one of the binop operands was not the result of a concat, we must
+ // extract a half-sized operand for our new narrow binop. We can't just reuse
+ // the original extract index operand because we may have bitcasted.
+ unsigned ConcatOpNum = ExtractIndex->getZExtValue() / NumElems;
+ unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
+ EVT ExtBOIdxVT = Extract->getOperand(1).getValueType();
+ SDLoc DL(Extract);
+
+ // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
+ // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, N)
+ // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, N), YN
+ SDValue X = ConcatL ? DAG.getBitcast(NarrowBVT, LHS.getOperand(ConcatOpNum))
+ : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
+ BinOp.getOperand(0),
+ DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT));
+
+ SDValue Y = ConcatR ? DAG.getBitcast(NarrowBVT, RHS.getOperand(ConcatOpNum))
+ : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
+ BinOp.getOperand(1),
+ DAG.getConstant(ExtBOIdx, DL, ExtBOIdxVT));
+
+ SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
+ return DAG.getBitcast(VT, NarrowBinOp);
+}
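As a standalone illustration of the identity this combine relies on (an
editorial sketch, not LLVM code; the helper names are invented): an
elementwise bitwise op distributes over concatenation, so extracting one half
of the wide result equals applying the op to the matching narrow halves.

#include <array>
#include <cassert>
#include <cstdint>

using V2 = std::array<uint32_t, 2>;
using V4 = std::array<uint32_t, 4>;

static V4 concat(V2 Lo, V2 Hi) { return {Lo[0], Lo[1], Hi[0], Hi[1]}; }
static V4 and4(V4 A, V4 B) {
  return {A[0] & B[0], A[1] & B[1], A[2] & B[2], A[3] & B[3]};
}
static V2 and2(V2 A, V2 B) { return {A[0] & B[0], A[1] & B[1]}; }
static V2 extract(V4 A, unsigned Idx) { return {A[Idx], A[Idx + 1]}; }

int main() {
  V2 X1{1, 2}, X2{3, 4}, Y1{7, 7}, Y2{6, 5};
  // Wide form: extract (and (concat X1, X2), (concat Y1, Y2)), 2
  V2 Wide = extract(and4(concat(X1, X2), concat(Y1, Y2)), 2);
  // Narrow form produced by the combine: and X2, Y2
  V2 Narrow = and2(X2, Y2);
  assert(Wide == Narrow);
  return 0;
}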
+
+/// If we are extracting a subvector from a wide vector load, convert to a
+/// narrow load to eliminate the extraction:
+/// (extract_subvector (load wide vector)) --> (load narrow vector)
+static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
+ // TODO: Add support for big-endian. The offset calculation must be adjusted.
+ if (DAG.getDataLayout().isBigEndian())
+ return SDValue();
+
+ // TODO: The one-use check is overly conservative. Check the cost of the
+ // extract instead or remove that condition entirely.
+ auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
+ auto *ExtIdx = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
+ if (!Ld || !Ld->hasOneUse() || Ld->getExtensionType() || Ld->isVolatile() ||
+ !ExtIdx)
+ return SDValue();
+
+ // The narrow load will be offset from the base address of the old load if
+ // we are extracting from something besides index 0 (little-endian).
+ EVT VT = Extract->getValueType(0);
+ SDLoc DL(Extract);
+ SDValue BaseAddr = Ld->getOperand(1);
+ unsigned Offset = ExtIdx->getZExtValue() * VT.getScalarType().getStoreSize();
+
+ // TODO: Use "BaseIndexOffset" to make this more effective.
+ SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL);
+ MachineFunction &MF = DAG.getMachineFunction();
+ MachineMemOperand *MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset,
+ VT.getStoreSize());
+ SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
+ DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
+ return NewLd;
+}
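The byte offset above is just the extract index scaled by the element store
size. A tiny self-contained sketch of that arithmetic (illustrative only; the
buffer and values are made up):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t Wide[8] = {10, 11, 12, 13, 14, 15, 16, 17}; // the wide v8i32 load
  unsigned ExtIdx = 4;                         // extract_subvector index
  unsigned Offset = ExtIdx * sizeof(uint32_t); // 16 bytes from the base
  uint32_t Narrow[4];                          // the narrow v4i32 load
  std::memcpy(Narrow, reinterpret_cast<const char *>(Wide) + Offset,
              sizeof(Narrow));
  assert(Narrow[0] == 14 && Narrow[3] == 17);
  return 0;
}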
+
SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) {
EVT NVT = N->getValueType(0);
SDValue V = N->getOperand(0);
- if (V->getOpcode() == ISD::CONCAT_VECTORS) {
- // Combine:
- // (extract_subvec (concat V1, V2, ...), i)
- // Into:
- // Vi if possible
- // Only operand 0 is checked as 'concat' assumes all inputs of the same
- // type.
- if (V->getOperand(0).getValueType() != NVT)
- return SDValue();
+ // Extract from UNDEF is UNDEF.
+ if (V.isUndef())
+ return DAG.getUNDEF(NVT);
+
+ if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
+ if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
+ return NarrowLoad;
+
+ // Combine:
+ // (extract_subvec (concat V1, V2, ...), i)
+ // Into:
+ // Vi if possible
+ // Only operand 0 is checked as 'concat' assumes all inputs of the same
+ // type.
+ if (V->getOpcode() == ISD::CONCAT_VECTORS &&
+ isa<ConstantSDNode>(N->getOperand(1)) &&
+ V->getOperand(0).getValueType() == NVT) {
unsigned Idx = N->getConstantOperandVal(1);
unsigned NumElems = NVT.getVectorNumElements();
assert((Idx % NumElems) == 0 &&
@@ -13633,19 +14896,16 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) {
if (V->getOpcode() == ISD::INSERT_SUBVECTOR) {
// Handle only simple case where vector being inserted and vector
- // being extracted are of same type, and are half size of larger vectors.
- EVT BigVT = V->getOperand(0).getValueType();
+ // being extracted are of same size.
EVT SmallVT = V->getOperand(1).getValueType();
- if (!NVT.bitsEq(SmallVT) || NVT.getSizeInBits()*2 != BigVT.getSizeInBits())
+ if (!NVT.bitsEq(SmallVT))
return SDValue();
- // Only handle cases where both indexes are constants with the same type.
+ // Only handle cases where both indexes are constants.
ConstantSDNode *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1));
ConstantSDNode *InsIdx = dyn_cast<ConstantSDNode>(V->getOperand(2));
- if (InsIdx && ExtIdx &&
- InsIdx->getValueType(0).getSizeInBits() <= 64 &&
- ExtIdx->getValueType(0).getSizeInBits() <= 64) {
+ if (InsIdx && ExtIdx) {
// Combine:
// (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
// Into:
@@ -13661,6 +14921,9 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) {
}
}
+ if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG))
+ return NarrowBOp;
+
return SDValue();
}
@@ -13892,6 +15155,167 @@ static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
}
+// Match shuffles that can be converted to any_vector_extend_in_reg.
+// This is often generated during legalization.
+// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
+// TODO Add support for ZERO_EXTEND_VECTOR_INREG when we have a test case.
+static SDValue combineShuffleToVectorExtend(ShuffleVectorSDNode *SVN,
+ SelectionDAG &DAG,
+ const TargetLowering &TLI,
+ bool LegalOperations) {
+ EVT VT = SVN->getValueType(0);
+ bool IsBigEndian = DAG.getDataLayout().isBigEndian();
+
+ // TODO Add support for big-endian when we have a test case.
+ if (!VT.isInteger() || IsBigEndian)
+ return SDValue();
+
+ unsigned NumElts = VT.getVectorNumElements();
+ unsigned EltSizeInBits = VT.getScalarSizeInBits();
+ ArrayRef<int> Mask = SVN->getMask();
+ SDValue N0 = SVN->getOperand(0);
+
+ // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
+ auto isAnyExtend = [&Mask, &NumElts](unsigned Scale) {
+ for (unsigned i = 0; i != NumElts; ++i) {
+ if (Mask[i] < 0)
+ continue;
+ if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
+ continue;
+ return false;
+ }
+ return true;
+ };
+
+ // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
+ // power-of-2 extensions as they are the most likely.
+ for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
+ if (!isAnyExtend(Scale))
+ continue;
+
+ EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
+ EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
+ if (!LegalOperations ||
+ TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND_VECTOR_INREG, OutVT))
+ return DAG.getBitcast(VT,
+ DAG.getAnyExtendVectorInReg(N0, SDLoc(SVN), OutVT));
+ }
+
+ return SDValue();
+}
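The isAnyExtend lambda above generalizes to any scale; here is the same mask
test in standalone form (a sketch with an invented name, not the LLVM helper):

#include <cassert>
#include <vector>

static bool isAnyExtendMask(const std::vector<int> &Mask, unsigned Scale) {
  // Defined elements must sit at multiples of Scale and read source
  // element i / Scale; everything else must be undef (negative).
  for (unsigned i = 0; i != Mask.size(); ++i) {
    if (Mask[i] < 0)
      continue;
    if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
      continue;
    return false;
  }
  return true;
}

int main() {
  // v4i32 <0,u,1,u> matches a v2i64 any_extend_vector_inreg of a v4i32 source.
  assert((isAnyExtendMask({0, -1, 1, -1}, 2)));
  // The identity shuffle <0,1,2,3> is not an extension.
  assert((!isAnyExtendMask({0, 1, 2, 3}, 2)));
  return 0;
}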
+
+// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
+// each source element of a large type into the lowest elements of a smaller
+// destination type. This is often generated during legalization.
+// If the source node itself was a '*_extend_vector_inreg' node then we should
+// then be able to remove it.
+static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
+ SelectionDAG &DAG) {
+ EVT VT = SVN->getValueType(0);
+ bool IsBigEndian = DAG.getDataLayout().isBigEndian();
+
+ // TODO Add support for big-endian when we have a test case.
+ if (!VT.isInteger() || IsBigEndian)
+ return SDValue();
+
+ SDValue N0 = SVN->getOperand(0);
+ while (N0.getOpcode() == ISD::BITCAST)
+ N0 = N0.getOperand(0);
+
+ unsigned Opcode = N0.getOpcode();
+ if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
+ Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
+ Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ ArrayRef<int> Mask = SVN->getMask();
+ unsigned NumElts = VT.getVectorNumElements();
+ unsigned EltSizeInBits = VT.getScalarSizeInBits();
+ unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
+ unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
+
+ if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
+ return SDValue();
+ unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
+
+ // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
+ // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
+ // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
+ auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
+ for (unsigned i = 0; i != NumElts; ++i) {
+ if (Mask[i] < 0)
+ continue;
+ if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
+ continue;
+ return false;
+ }
+ return true;
+ };
+
+ // At the moment we just handle the case where we've truncated back to the
+ // same size as before the extension.
+ // TODO: handle more extension/truncation cases as they arise.
+ if (EltSizeInBits != ExtSrcSizeInBits)
+ return SDValue();
+
+ // We can remove *extend_vector_inreg only if the truncation happens at
+ // the same scale as the extension.
+ if (isTruncate(ExtScale))
+ return DAG.getBitcast(VT, N00);
+
+ return SDValue();
+}
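For reference, the mirrored mask test used by this combine in standalone form
(again an editorial sketch with an invented name):

#include <cassert>
#include <vector>

static bool isTruncateMask(const std::vector<int> &Mask, unsigned Scale) {
  // Element i of the result must read element i * Scale of the source;
  // positions past the source width and undef elements are allowed.
  unsigned NumElts = Mask.size();
  for (unsigned i = 0; i != NumElts; ++i) {
    if (Mask[i] < 0)
      continue;
    if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
      continue;
    return false;
  }
  return true;
}

int main() {
  // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,u,u,u,u>
  assert((isTruncateMask({0, 2, 4, 6, -1, -1, -1, -1}, 2)));
  // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,u,u,u,u,u,u>
  assert((isTruncateMask({0, 4, -1, -1, -1, -1, -1, -1}, 4)));
  return 0;
}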
+
+// Combine shuffles of splat-shuffles of the form:
+// shuffle (shuffle V, undef, splat-mask), undef, M
+// If splat-mask contains undef elements, we need to be careful about
+// introducing undef's in the folded mask which are not the result of composing
+// the masks of the shuffles.
+static SDValue combineShuffleOfSplat(ArrayRef<int> UserMask,
+ ShuffleVectorSDNode *Splat,
+ SelectionDAG &DAG) {
+ ArrayRef<int> SplatMask = Splat->getMask();
+ assert(UserMask.size() == SplatMask.size() && "Mask length mismatch");
+
+ // Prefer simplifying to the splat-shuffle, if possible. This is legal if
+ // every undef mask element in the splat-shuffle has a corresponding undef
+ // element in the user-shuffle's mask or if the composition of mask elements
+ // would result in undef.
+ // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
+ // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
+ // In this case it is not legal to simplify to the splat-shuffle because we
+ // may expose the users of the shuffle to an undef element at index 1
+ // that was not there before the combine.
+ // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
+ // In this case the composition of masks yields SplatMask, so it's ok to
+ // simplify to the splat-shuffle.
+ // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
+ // In this case the composed mask includes all undef elements of SplatMask
+ // and in addition sets element zero to undef. It is safe to simplify to
+ // the splat-shuffle.
+ auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
+ ArrayRef<int> SplatMask) {
+ for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
+ if (UserMask[i] != -1 && SplatMask[i] == -1 &&
+ SplatMask[UserMask[i]] != -1)
+ return false;
+ return true;
+ };
+ if (CanSimplifyToExistingSplat(UserMask, SplatMask))
+ return SDValue(Splat, 0);
+
+ // Create a new shuffle with a mask that is composed of the two shuffles'
+ // masks.
+ SmallVector<int, 32> NewMask;
+ for (int Idx : UserMask)
+ NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
+
+ return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
+ Splat->getOperand(0), Splat->getOperand(1),
+ NewMask);
+}
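The composition rule is small enough to check by hand; a standalone sketch of
the fold (invented helper name), reproducing the first example from the
comment above:

#include <cassert>
#include <vector>

static std::vector<int> composeMasks(const std::vector<int> &UserMask,
                                     const std::vector<int> &SplatMask) {
  // A defined element of the outer mask looks through to the splat mask;
  // an undef outer element stays undef.
  std::vector<int> NewMask;
  for (int Idx : UserMask)
    NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
  return NewMask;
}

int main() {
  // UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]; element 1 is now
  // defined, which is why simplifying to the splat-shuffle alone is illegal.
  std::vector<int> Composed = composeMasks({0, 2, -1, -1}, {2, -1, 2, -1});
  assert((Composed == std::vector<int>{2, 2, -1, -1}));
  return 0;
}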
+
SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
EVT VT = N->getValueType(0);
unsigned NumElts = VT.getVectorNumElements();
@@ -13938,6 +15362,11 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
}
+ // A shuffle of a single vector that is a splat can always be folded.
+ if (auto *N0Shuf = dyn_cast<ShuffleVectorSDNode>(N0))
+ if (N1->isUndef() && N0Shuf->isSplat())
+ return combineShuffleOfSplat(SVN->getMask(), N0Shuf, DAG);
+
// If it is a splat, check if the argument vector is another splat or a
// build_vector.
if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
@@ -13996,6 +15425,14 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
if (SDValue S = simplifyShuffleOperands(SVN, N0, N1, DAG))
return S;
+ // Match shuffles that can be converted to any_vector_extend_in_reg.
+ if (SDValue V = combineShuffleToVectorExtend(SVN, DAG, TLI, LegalOperations))
+ return V;
+
+ // Combine "truncate_vector_in_reg" style shuffles.
+ if (SDValue V = combineTruncationShuffle(SVN, DAG))
+ return V;
+
if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
Level < AfterLegalizeVectorOps &&
(N1.isUndef() ||
@@ -14253,6 +15690,16 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
+ // If inserting an UNDEF, just return the original vector.
+ if (N1.isUndef())
+ return N0;
+
+ // If this is an insert of an extracted vector into an undef vector, we can
+ // just use the input to the extract.
+ if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
+ return N1.getOperand(0);
+
// Combine INSERT_SUBVECTORs where we are inserting to the same index.
// INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
// --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
@@ -14262,26 +15709,39 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
N1, N2);
- if (N0.getValueType() != N1.getValueType())
+ if (!isa<ConstantSDNode>(N2))
return SDValue();
+ unsigned InsIdx = cast<ConstantSDNode>(N2)->getZExtValue();
+
+ // Canonicalize insert_subvector dag nodes.
+ // Example:
+ // (insert_subvector (insert_subvector A, B, Idx0), C, Idx1)
+ // -> (insert_subvector (insert_subvector A, C, Idx1), B, Idx0)
+ if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
+ N1.getValueType() == N0.getOperand(1).getValueType() &&
+ isa<ConstantSDNode>(N0.getOperand(2))) {
+ unsigned OtherIdx = N0.getConstantOperandVal(2);
+ if (InsIdx < OtherIdx) {
+ // Swap nodes.
+ SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
+ N0.getOperand(0), N1, N2);
+ AddToWorklist(NewOp.getNode());
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
+ VT, NewOp, N0.getOperand(1), N0.getOperand(2));
+ }
+ }
+
// If the input vector is a concatenation, and the insert replaces
- // one of the halves, we can optimize into a single concat_vectors.
- if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0->getNumOperands() == 2 &&
- N2.getOpcode() == ISD::Constant) {
- APInt InsIdx = cast<ConstantSDNode>(N2)->getAPIntValue();
-
- // Lower half: fold (insert_subvector (concat_vectors X, Y), Z) ->
- // (concat_vectors Z, Y)
- if (InsIdx == 0)
- return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N1,
- N0.getOperand(1));
+ // one of the pieces, we can optimize into a single concat_vectors.
+ if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
+ N0.getOperand(0).getValueType() == N1.getValueType()) {
+ unsigned Factor = N1.getValueType().getVectorNumElements();
+
+ SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
+ Ops[cast<ConstantSDNode>(N2)->getZExtValue() / Factor] = N1;
- // Upper half: fold (insert_subvector (concat_vectors X, Y), Z) ->
- // (concat_vectors X, Z)
- if (InsIdx == VT.getVectorNumElements() / 2)
- return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0.getOperand(0),
- N1);
+ return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
}
return SDValue();
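The generalized concat rewrite above amounts to replacing one concat operand
in place. A toy standalone sketch (arrays standing in for vector pieces; the
names are invented):

#include <array>
#include <cassert>

int main() {
  using V2 = std::array<int, 2>;
  // (concat_vectors X, Y, Z) with 2-element pieces.
  std::array<V2, 3> Ops = {V2{0, 1}, V2{2, 3}, V2{4, 5}};
  V2 SubNew{9, 9};
  unsigned InsIdx = 2;           // insert_subvector element index
  unsigned Factor = 2;           // elements per concat operand
  Ops[InsIdx / Factor] = SubNew; // piece Y is replaced wholesale
  assert((Ops[1] == V2{9, 9}) && (Ops[0] == V2{0, 1}) && (Ops[2] == V2{4, 5}));
  return 0;
}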
@@ -14366,9 +15826,9 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
// Extract the sub element from the constant bit mask.
if (DAG.getDataLayout().isBigEndian()) {
- Bits = Bits.lshr((Split - SubIdx - 1) * NumSubBits);
+ Bits.lshrInPlace((Split - SubIdx - 1) * NumSubBits);
} else {
- Bits = Bits.lshr(SubIdx * NumSubBits);
+ Bits.lshrInPlace(SubIdx * NumSubBits);
}
if (Split > 1)
@@ -15041,7 +16501,7 @@ SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
/// =>
/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
/// does not require additional intermediate precision]
-SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) {
+SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags Flags) {
if (Level >= AfterLegalizeDAG)
return SDValue();
@@ -15096,7 +16556,7 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) {
/// As a result, we precompute A/2 prior to the iteration loop.
SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
unsigned Iterations,
- SDNodeFlags *Flags, bool Reciprocal) {
+ SDNodeFlags Flags, bool Reciprocal) {
EVT VT = Arg.getValueType();
SDLoc DL(Arg);
SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
@@ -15140,7 +16600,7 @@ SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
/// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
unsigned Iterations,
- SDNodeFlags *Flags, bool Reciprocal) {
+ SDNodeFlags Flags, bool Reciprocal) {
EVT VT = Arg.getValueType();
SDLoc DL(Arg);
SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
@@ -15185,7 +16645,7 @@ SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
/// Op can be zero.
-SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags,
+SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
bool Reciprocal) {
if (Level >= AfterLegalizeDAG)
return SDValue();
@@ -15238,17 +16698,17 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags,
return SDValue();
}
-SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
return buildSqrtEstimateImpl(Op, Flags, true);
}
-SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
return buildSqrtEstimateImpl(Op, Flags, false);
}
/// Return true if base is a frame index, which is known not to alias with
/// anything but itself. Provides base object and offset as results.
-static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
+static bool findBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
const GlobalValue *&GV, const void *&CV) {
// Assume it is a primitive operation.
Base = Ptr; Offset = 0; GV = nullptr; CV = nullptr;
@@ -15257,7 +16717,7 @@ static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
if (Base.getOpcode() == ISD::ADD) {
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Base.getOperand(1))) {
Base = Base.getOperand(0);
- Offset += C->getZExtValue();
+ Offset += C->getSExtValue();
}
}
@@ -15300,54 +16760,82 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const {
if (Op1->isInvariant() && Op0->writeMem())
return false;
+ unsigned NumBytes0 = Op0->getMemoryVT().getSizeInBits() >> 3;
+ unsigned NumBytes1 = Op1->getMemoryVT().getSizeInBits() >> 3;
+
+ // Check for BaseIndexOffset matching.
+ BaseIndexOffset BasePtr0 = BaseIndexOffset::match(Op0->getBasePtr(), DAG);
+ BaseIndexOffset BasePtr1 = BaseIndexOffset::match(Op1->getBasePtr(), DAG);
+ int64_t PtrDiff;
+ if (BasePtr0.equalBaseIndex(BasePtr1, DAG, PtrDiff))
+ return !((NumBytes0 <= PtrDiff) || (PtrDiff + NumBytes1 <= 0));
+
+ // If both BasePtr0 and BasePtr1 are FrameIndexes, we will not be
+ // able to calculate their relative offset if at least one arises
+ // from an alloca. However, these allocas cannot overlap and we
+ // can infer there is no alias.
+ if (auto *A = dyn_cast<FrameIndexSDNode>(BasePtr0.getBase()))
+ if (auto *B = dyn_cast<FrameIndexSDNode>(BasePtr1.getBase())) {
+ MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
+ // Distinct frame indices cannot overlap unless both are fixed
+ // objects, whose slots may be reused; be conservative in that case.
+ if (A != B && (!MFI.isFixedObjectIndex(A->getIndex()) ||
+ !MFI.isFixedObjectIndex(B->getIndex())))
+ return false;
+ }
+
+ // FIXME: findBaseOffset and ConstantValue/GlobalValue/FrameIndex analysis
+ // modified to use BaseIndexOffset.
+
// Gather base node and offset information.
- SDValue Base1, Base2;
- int64_t Offset1, Offset2;
- const GlobalValue *GV1, *GV2;
- const void *CV1, *CV2;
- bool isFrameIndex1 = FindBaseOffset(Op0->getBasePtr(),
+ SDValue Base0, Base1;
+ int64_t Offset0, Offset1;
+ const GlobalValue *GV0, *GV1;
+ const void *CV0, *CV1;
+ bool IsFrameIndex0 = findBaseOffset(Op0->getBasePtr(),
+ Base0, Offset0, GV0, CV0);
+ bool IsFrameIndex1 = findBaseOffset(Op1->getBasePtr(),
Base1, Offset1, GV1, CV1);
- bool isFrameIndex2 = FindBaseOffset(Op1->getBasePtr(),
- Base2, Offset2, GV2, CV2);
- // If they have a same base address then check to see if they overlap.
- if (Base1 == Base2 || (GV1 && (GV1 == GV2)) || (CV1 && (CV1 == CV2)))
- return !((Offset1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= Offset2 ||
- (Offset2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= Offset1);
+ // If they have the same base address, then check to see if they overlap.
+ if (Base0 == Base1 || (GV0 && (GV0 == GV1)) || (CV0 && (CV0 == CV1)))
+ return !((Offset0 + NumBytes0) <= Offset1 ||
+ (Offset1 + NumBytes1) <= Offset0);
// It is possible for different frame indices to alias each other, mostly
// when tail call optimization reuses return address slots for arguments.
// To catch this case, look up the actual index of frame indices to compute
// the real alias relationship.
- if (isFrameIndex1 && isFrameIndex2) {
+ if (IsFrameIndex0 && IsFrameIndex1) {
MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
+ Offset0 += MFI.getObjectOffset(cast<FrameIndexSDNode>(Base0)->getIndex());
Offset1 += MFI.getObjectOffset(cast<FrameIndexSDNode>(Base1)->getIndex());
- Offset2 += MFI.getObjectOffset(cast<FrameIndexSDNode>(Base2)->getIndex());
- return !((Offset1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= Offset2 ||
- (Offset2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= Offset1);
+ return !((Offset0 + NumBytes0) <= Offset1 ||
+ (Offset1 + NumBytes1) <= Offset0);
}
// Otherwise, if we know what the bases are, and they aren't identical, then
// we know they cannot alias.
- if ((isFrameIndex1 || CV1 || GV1) && (isFrameIndex2 || CV2 || GV2))
+ if ((IsFrameIndex0 || CV0 || GV0) && (IsFrameIndex1 || CV1 || GV1))
return false;
// If we know SrcValue1 and SrcValue2 have relatively large alignment
// compared to the size and offset of the accesses, we may be able to prove they
- // do not alias. This check is conservative for now to catch cases created by
+ // do not alias. This check is conservative for now to catch cases created by
// splitting vector types.
- if ((Op0->getOriginalAlignment() == Op1->getOriginalAlignment()) &&
- (Op0->getSrcValueOffset() != Op1->getSrcValueOffset()) &&
- (Op0->getMemoryVT().getSizeInBits() >> 3 ==
- Op1->getMemoryVT().getSizeInBits() >> 3) &&
- (Op0->getOriginalAlignment() > (Op0->getMemoryVT().getSizeInBits() >> 3))) {
- int64_t OffAlign1 = Op0->getSrcValueOffset() % Op0->getOriginalAlignment();
- int64_t OffAlign2 = Op1->getSrcValueOffset() % Op1->getOriginalAlignment();
+ int64_t SrcValOffset0 = Op0->getSrcValueOffset();
+ int64_t SrcValOffset1 = Op1->getSrcValueOffset();
+ unsigned OrigAlignment0 = Op0->getOriginalAlignment();
+ unsigned OrigAlignment1 = Op1->getOriginalAlignment();
+ if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
+ NumBytes0 == NumBytes1 && OrigAlignment0 > NumBytes0) {
+ int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0;
+ int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1;
// There is no overlap between these relatively aligned accesses of similar
- // size, return no alias.
- if ((OffAlign1 + (Op0->getMemoryVT().getSizeInBits() >> 3)) <= OffAlign2 ||
- (OffAlign2 + (Op1->getMemoryVT().getSizeInBits() >> 3)) <= OffAlign1)
+ // size. Return no alias.
+ if ((OffAlign0 + NumBytes0) <= OffAlign1 ||
+ (OffAlign1 + NumBytes1) <= OffAlign0)
return false;
}
@@ -15359,20 +16847,18 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const {
CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
UseAA = false;
#endif
- if (UseAA &&
+
+ if (UseAA && AA &&
Op0->getMemOperand()->getValue() && Op1->getMemOperand()->getValue()) {
// Use alias analysis information.
- int64_t MinOffset = std::min(Op0->getSrcValueOffset(),
- Op1->getSrcValueOffset());
- int64_t Overlap1 = (Op0->getMemoryVT().getSizeInBits() >> 3) +
- Op0->getSrcValueOffset() - MinOffset;
- int64_t Overlap2 = (Op1->getMemoryVT().getSizeInBits() >> 3) +
- Op1->getSrcValueOffset() - MinOffset;
+ int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
+ int64_t Overlap0 = NumBytes0 + SrcValOffset0 - MinOffset;
+ int64_t Overlap1 = NumBytes1 + SrcValOffset1 - MinOffset;
AliasResult AAResult =
- AA.alias(MemoryLocation(Op0->getMemOperand()->getValue(), Overlap1,
- UseTBAA ? Op0->getAAInfo() : AAMDNodes()),
- MemoryLocation(Op1->getMemOperand()->getValue(), Overlap2,
- UseTBAA ? Op1->getAAInfo() : AAMDNodes()));
+ AA->alias(MemoryLocation(Op0->getMemOperand()->getValue(), Overlap0,
+ UseTBAA ? Op0->getAAInfo() : AAMDNodes()),
+ MemoryLocation(Op1->getMemOperand()->getValue(), Overlap1,
+ UseTBAA ? Op1->getAAInfo() : AAMDNodes()));
if (AAResult == NoAlias)
return false;
}
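All of the overlap tests in isAlias reduce to the same half-open interval
check; a standalone restatement of it (an editorial sketch with an invented
helper name):

#include <cassert>
#include <cstdint>

static bool mayOverlap(int64_t Off0, int64_t Num0,
                       int64_t Off1, int64_t Num1) {
  // Byte ranges [Off0, Off0 + Num0) and [Off1, Off1 + Num1) are disjoint
  // exactly when one ends at or before the other begins.
  return !((Off0 + Num0) <= Off1 || (Off1 + Num1) <= Off0);
}

int main() {
  assert(!mayOverlap(0, 4, 4, 4)); // adjacent 4-byte accesses: no alias
  assert(mayOverlap(0, 8, 4, 4));  // the second access lies inside the first
  return 0;
}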
@@ -15454,6 +16940,12 @@ void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
++Depth;
break;
+ case ISD::CopyFromReg:
+ // Forward past CopyFromReg.
+ Chains.push_back(Chain.getOperand(0));
+ ++Depth;
+ break;
+
default:
// For all other instructions we will just have to take what we can get.
Aliases.push_back(Chain);
@@ -15482,17 +16974,29 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
return DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Aliases);
}
+// This function tries to collect a bunch of potentially interesting
+// nodes whose chains we can improve, all at once. This might seem
+// redundant, as this function gets called when visiting every store
+// node, so why not let the work be done on each store as it's visited?
+//
+// I believe this is mainly important because MergeConsecutiveStores
+// is unable to deal with merging stores of different sizes, so unless
+// we improve the chains of all the potential candidates up-front
+// before running MergeConsecutiveStores, it might only see some of
+// the nodes that will eventually be candidates, and then not be able
+// to go from a partially-merged state to the desired final
+// fully-merged state.
bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
// This holds the base pointer, index, and the offset in bytes from the base
// pointer.
BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr(), DAG);
// We must have a base and an offset.
- if (!BasePtr.Base.getNode())
+ if (!BasePtr.getBase().getNode())
return false;
// Do not handle stores to undef base pointers.
- if (BasePtr.Base.isUndef())
+ if (BasePtr.getBase().isUndef())
return false;
SmallVector<StoreSDNode *, 8> ChainedStores;
@@ -15514,13 +17018,11 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr(), DAG);
// Check that the base pointer is the same as the original one.
- if (!Ptr.equalBaseIndex(BasePtr))
+ if (!BasePtr.equalBaseIndex(Ptr, DAG))
break;
- // Find the next memory operand in the chain. If the next operand in the
- // chain is a store then move up and continue the scan with the next
- // memory operand. If the next operand is a load save it and use alias
- // information to check if it interferes with anything.
+ // Walk up the chain to find the next store node, ignoring any
+ // intermediate loads. Any other kind of node will halt the loop.
SDNode *NextInChain = Index->getChain().getNode();
while (true) {
if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) {
@@ -15539,9 +17041,14 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
Index = nullptr;
break;
}
- }
+ } // end while
}
+ // At this point, ChainedStores lists all of the Store nodes
+ // reachable by iterating up through chain nodes matching the above
+ // conditions. For each such store identified, try to find an
+ // earlier chain to attach the store to which won't violate the
+ // required ordering.
bool MadeChangeToSt = false;
SmallVector<std::pair<StoreSDNode *, SDValue>, 8> BetterChains;
@@ -15565,7 +17072,7 @@ bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
}
/// This is the entry point for the file.
-void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis &AA,
+void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
CodeGenOpt::Level OptLevel) {
/// This is the main entry point to this class.
DAGCombiner(*this, AA, OptLevel).Run(Level);