Diffstat (limited to 'contrib/llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | contrib/llvm/lib/Target/X86/X86ISelLowering.cpp | 8595
1 file changed, 5759 insertions, 2836 deletions
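A recurring theme in this patch is lowering v64i1 mask arguments on 32-bit targets by splitting them across two GR32 registers (see the new Passv64i1ArgInRegs and getv64i1Argument helpers in the diff below). The following is a minimal standalone sketch, not part of the patch, that mirrors the same split-and-rejoin idea on a plain 64-bit mask value; splitMask64 and joinMask64 are hypothetical names used only for illustration.

#include <cassert>
#include <cstdint>

// Hypothetical illustration: split a 64-bit mask into two 32-bit halves,
// analogous to passing a v64i1 value in two GR32 registers on a 32-bit
// target, then rebuild it on the receiving side.
static void splitMask64(uint64_t Mask, uint32_t &Lo, uint32_t &Hi) {
  Lo = static_cast<uint32_t>(Mask);        // low half (EXTRACT_ELEMENT 0)
  Hi = static_cast<uint32_t>(Mask >> 32);  // high half (EXTRACT_ELEMENT 1)
}

static uint64_t joinMask64(uint32_t Lo, uint32_t Hi) {
  // Rejoin the two halves, analogous to CONCAT_VECTORS of two v32i1 values
  // back into a single v64i1.
  return (static_cast<uint64_t>(Hi) << 32) | Lo;
}

int main() {
  uint32_t Lo = 0, Hi = 0;
  splitMask64(0x0123456789ABCDEFULL, Lo, Hi);
  assert(joinMask64(Lo, Hi) == 0x0123456789ABCDEFULL);
  return 0;
}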
diff --git a/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp b/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp index f499e56..08fe2ba 100644 --- a/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/contrib/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -17,6 +17,7 @@ #include "X86CallingConv.h" #include "X86FrameLowering.h" #include "X86InstrBuilder.h" +#include "X86IntrinsicsInfo.h" #include "X86MachineFunctionInfo.h" #include "X86ShuffleDecodeConstantPool.h" #include "X86TargetMachine.h" @@ -53,10 +54,10 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Target/TargetOptions.h" -#include "X86IntrinsicsInfo.h" +#include <algorithm> #include <bitset> -#include <numeric> #include <cctype> +#include <numeric> using namespace llvm; #define DEBUG_TYPE "x86-isel" @@ -96,15 +97,16 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo(); setStackPointerRegisterToSaveRestore(RegInfo->getStackRegister()); - // Bypass expensive divides on Atom when compiling with O2. + // Bypass expensive divides and use cheaper ones. if (TM.getOptLevel() >= CodeGenOpt::Default) { if (Subtarget.hasSlowDivide32()) addBypassSlowDiv(32, 8); if (Subtarget.hasSlowDivide64() && Subtarget.is64Bit()) - addBypassSlowDiv(64, 16); + addBypassSlowDiv(64, 32); } - if (Subtarget.isTargetKnownWindowsMSVC()) { + if (Subtarget.isTargetKnownWindowsMSVC() || + Subtarget.isTargetWindowsItanium()) { // Setup Windows compiler runtime calls. setLibcallName(RTLIB::SDIV_I64, "_alldiv"); setLibcallName(RTLIB::UDIV_I64, "_aulldiv"); @@ -286,7 +288,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UDIV, VT, Expand); setOperationAction(ISD::SREM, VT, Expand); setOperationAction(ISD::UREM, VT, Expand); + } + for (auto VT : { MVT::i8, MVT::i16, MVT::i32, MVT::i64 }) { + if (VT == MVT::i64 && !Subtarget.is64Bit()) + continue; // Add/Sub overflow ops with MVT::Glues are lowered to EFLAGS dependences. setOperationAction(ISD::ADDC, VT, Custom); setOperationAction(ISD::ADDE, VT, Custom); @@ -349,7 +355,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // Special handling for half-precision floating point conversions. // If we don't have F16C support, then lower half float conversions // into library calls. - if (Subtarget.useSoftFloat() || !Subtarget.hasF16C()) { + if (Subtarget.useSoftFloat() || + (!Subtarget.hasF16C() && !Subtarget.hasAVX512())) { setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand); setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand); } @@ -484,8 +491,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, if (!Subtarget.useSoftFloat() && X86ScalarSSEf64) { // f32 and f64 use SSE. // Set up the FP register classes. - addRegisterClass(MVT::f32, &X86::FR32RegClass); - addRegisterClass(MVT::f64, &X86::FR64RegClass); + addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass + : &X86::FR32RegClass); + addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass + : &X86::FR64RegClass); for (auto VT : { MVT::f32, MVT::f64 }) { // Use ANDPD to simulate FABS. @@ -514,7 +523,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } else if (UseX87 && X86ScalarSSEf32) { // Use SSE for f32, x87 for f64. // Set up the FP register classes. - addRegisterClass(MVT::f32, &X86::FR32RegClass); + addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? 
&X86::FR32XRegClass + : &X86::FR32RegClass); addRegisterClass(MVT::f64, &X86::RFP64RegClass); // Use ANDPS to simulate FABS. @@ -590,14 +600,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UNDEF, MVT::f80, Expand); setOperationAction(ISD::FCOPYSIGN, MVT::f80, Expand); { - APFloat TmpFlt = APFloat::getZero(APFloat::x87DoubleExtended); + APFloat TmpFlt = APFloat::getZero(APFloat::x87DoubleExtended()); addLegalFPImmediate(TmpFlt); // FLD0 TmpFlt.changeSign(); addLegalFPImmediate(TmpFlt); // FLD0/FCHS bool ignored; APFloat TmpFlt2(+1.0); - TmpFlt2.convert(APFloat::x87DoubleExtended, APFloat::rmNearestTiesToEven, + TmpFlt2.convert(APFloat::x87DoubleExtended(), APFloat::rmNearestTiesToEven, &ignored); addLegalFPImmediate(TmpFlt2); // FLD1 TmpFlt2.changeSign(); @@ -717,10 +727,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } if (!Subtarget.useSoftFloat() && Subtarget.hasSSE1()) { - addRegisterClass(MVT::v4f32, &X86::VR128RegClass); + addRegisterClass(MVT::v4f32, Subtarget.hasVLX() ? &X86::VR128XRegClass + : &X86::VR128RegClass); setOperationAction(ISD::FNEG, MVT::v4f32, Custom); setOperationAction(ISD::FABS, MVT::v4f32, Custom); + setOperationAction(ISD::FCOPYSIGN, MVT::v4f32, Custom); setOperationAction(ISD::BUILD_VECTOR, MVT::v4f32, Custom); setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4f32, Custom); setOperationAction(ISD::VSELECT, MVT::v4f32, Custom); @@ -730,14 +742,19 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) { - addRegisterClass(MVT::v2f64, &X86::VR128RegClass); + addRegisterClass(MVT::v2f64, Subtarget.hasVLX() ? &X86::VR128XRegClass + : &X86::VR128RegClass); // FIXME: Unfortunately, -soft-float and -no-implicit-float mean XMM // registers cannot be used even for integer operations. - addRegisterClass(MVT::v16i8, &X86::VR128RegClass); - addRegisterClass(MVT::v8i16, &X86::VR128RegClass); - addRegisterClass(MVT::v4i32, &X86::VR128RegClass); - addRegisterClass(MVT::v2i64, &X86::VR128RegClass); + addRegisterClass(MVT::v16i8, Subtarget.hasVLX() ? &X86::VR128XRegClass + : &X86::VR128RegClass); + addRegisterClass(MVT::v8i16, Subtarget.hasVLX() ? &X86::VR128XRegClass + : &X86::VR128RegClass); + addRegisterClass(MVT::v4i32, Subtarget.hasVLX() ? &X86::VR128XRegClass + : &X86::VR128RegClass); + addRegisterClass(MVT::v2i64, Subtarget.hasVLX() ? &X86::VR128XRegClass + : &X86::VR128RegClass); setOperationAction(ISD::MUL, MVT::v16i8, Custom); setOperationAction(ISD::MUL, MVT::v4i32, Custom); @@ -751,6 +768,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MUL, MVT::v8i16, Legal); setOperationAction(ISD::FNEG, MVT::v2f64, Custom); setOperationAction(ISD::FABS, MVT::v2f64, Custom); + setOperationAction(ISD::FCOPYSIGN, MVT::v2f64, Custom); setOperationAction(ISD::SMAX, MVT::v8i16, Legal); setOperationAction(ISD::UMAX, MVT::v16i8, Legal); @@ -776,7 +794,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CTTZ, MVT::v16i8, Custom); setOperationAction(ISD::CTTZ, MVT::v8i16, Custom); setOperationAction(ISD::CTTZ, MVT::v4i32, Custom); - // ISD::CTTZ v2i64 - scalarization is faster. + setOperationAction(ISD::CTTZ, MVT::v2i64, Custom); // Custom lower build_vector, vector_shuffle, and extract_vector_elt. 
for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32 }) { @@ -828,16 +846,17 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SELECT, MVT::v2i64, Custom); setOperationAction(ISD::FP_TO_SINT, MVT::v4i32, Legal); - setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Legal); + setOperationAction(ISD::FP_TO_SINT, MVT::v2i32, Custom); + setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Legal); setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Custom); setOperationAction(ISD::UINT_TO_FP, MVT::v4i8, Custom); setOperationAction(ISD::UINT_TO_FP, MVT::v4i16, Custom); - // As there is no 64-bit GPR available, we need build a special custom - // sequence to convert from v2i32 to v2f32. - if (!Subtarget.is64Bit()) - setOperationAction(ISD::UINT_TO_FP, MVT::v2f32, Custom); + setOperationAction(ISD::UINT_TO_FP, MVT::v2i32, Custom); + + // Fast v2f32 UINT_TO_FP( v2i32 ) custom conversion. + setOperationAction(ISD::UINT_TO_FP, MVT::v2f32, Custom); setOperationAction(ISD::FP_EXTEND, MVT::v2f32, Custom); setOperationAction(ISD::FP_ROUND, MVT::v2f32, Custom); @@ -872,8 +891,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::BITREVERSE, MVT::v16i8, Custom); setOperationAction(ISD::CTLZ, MVT::v16i8, Custom); setOperationAction(ISD::CTLZ, MVT::v8i16, Custom); - // ISD::CTLZ v4i32 - scalarization is faster. - // ISD::CTLZ v2i64 - scalarization is faster. + setOperationAction(ISD::CTLZ, MVT::v4i32, Custom); + setOperationAction(ISD::CTLZ, MVT::v2i64, Custom); } if (!Subtarget.useSoftFloat() && Subtarget.hasSSE41()) { @@ -946,12 +965,18 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, if (!Subtarget.useSoftFloat() && Subtarget.hasFp256()) { bool HasInt256 = Subtarget.hasInt256(); - addRegisterClass(MVT::v32i8, &X86::VR256RegClass); - addRegisterClass(MVT::v16i16, &X86::VR256RegClass); - addRegisterClass(MVT::v8i32, &X86::VR256RegClass); - addRegisterClass(MVT::v8f32, &X86::VR256RegClass); - addRegisterClass(MVT::v4i64, &X86::VR256RegClass); - addRegisterClass(MVT::v4f64, &X86::VR256RegClass); + addRegisterClass(MVT::v32i8, Subtarget.hasVLX() ? &X86::VR256XRegClass + : &X86::VR256RegClass); + addRegisterClass(MVT::v16i16, Subtarget.hasVLX() ? &X86::VR256XRegClass + : &X86::VR256RegClass); + addRegisterClass(MVT::v8i32, Subtarget.hasVLX() ? &X86::VR256XRegClass + : &X86::VR256RegClass); + addRegisterClass(MVT::v8f32, Subtarget.hasVLX() ? &X86::VR256XRegClass + : &X86::VR256RegClass); + addRegisterClass(MVT::v4i64, Subtarget.hasVLX() ? &X86::VR256XRegClass + : &X86::VR256RegClass); + addRegisterClass(MVT::v4f64, Subtarget.hasVLX() ? &X86::VR256XRegClass + : &X86::VR256RegClass); for (auto VT : { MVT::v8f32, MVT::v4f64 }) { setOperationAction(ISD::FFLOOR, VT, Legal); @@ -961,6 +986,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FNEARBYINT, VT, Legal); setOperationAction(ISD::FNEG, VT, Custom); setOperationAction(ISD::FABS, VT, Custom); + setOperationAction(ISD::FCOPYSIGN, VT, Custom); } // (fp_to_int:v8i16 (v8f32 ..)) requires the result type to be promoted @@ -1011,16 +1037,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) { setOperationAction(ISD::CTPOP, VT, Custom); setOperationAction(ISD::CTTZ, VT, Custom); - } - - // ISD::CTLZ v8i32/v4i64 - scalarization is faster without AVX2 - // as we end up splitting the 256-bit vectors. 
- for (auto VT : { MVT::v32i8, MVT::v16i16 }) setOperationAction(ISD::CTLZ, VT, Custom); - - if (HasInt256) - for (auto VT : { MVT::v8i32, MVT::v4i64 }) - setOperationAction(ISD::CTLZ, VT, Custom); + } if (Subtarget.hasAnyFMA()) { for (auto VT : { MVT::f32, MVT::f64, MVT::v4f32, MVT::v8f32, @@ -1171,12 +1189,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FNEG, VT, Custom); setOperationAction(ISD::FABS, VT, Custom); setOperationAction(ISD::FMA, VT, Legal); + setOperationAction(ISD::FCOPYSIGN, VT, Custom); } setOperationAction(ISD::FP_TO_SINT, MVT::v16i32, Legal); setOperationAction(ISD::FP_TO_UINT, MVT::v16i32, Legal); setOperationAction(ISD::FP_TO_UINT, MVT::v8i32, Legal); setOperationAction(ISD::FP_TO_UINT, MVT::v4i32, Legal); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i32, Custom); setOperationAction(ISD::SINT_TO_FP, MVT::v16i32, Legal); setOperationAction(ISD::SINT_TO_FP, MVT::v8i1, Custom); setOperationAction(ISD::SINT_TO_FP, MVT::v16i1, Custom); @@ -1216,10 +1236,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setTruncStoreAction(MVT::v4i32, MVT::v4i8, Legal); setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal); } else { - setOperationAction(ISD::MLOAD, MVT::v8i32, Custom); - setOperationAction(ISD::MLOAD, MVT::v8f32, Custom); - setOperationAction(ISD::MSTORE, MVT::v8i32, Custom); - setOperationAction(ISD::MSTORE, MVT::v8f32, Custom); + for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, + MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) { + setOperationAction(ISD::MLOAD, VT, Custom); + setOperationAction(ISD::MSTORE, VT, Custom); + } } setOperationAction(ISD::TRUNCATE, MVT::i1, Custom); setOperationAction(ISD::TRUNCATE, MVT::v16i8, Custom); @@ -1230,18 +1251,23 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::VSELECT, MVT::v16i1, Expand); if (Subtarget.hasDQI()) { setOperationAction(ISD::SINT_TO_FP, MVT::v8i64, Legal); + setOperationAction(ISD::SINT_TO_FP, MVT::v4i64, Legal); + setOperationAction(ISD::SINT_TO_FP, MVT::v2i64, Legal); setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal); + setOperationAction(ISD::UINT_TO_FP, MVT::v4i64, Legal); + setOperationAction(ISD::UINT_TO_FP, MVT::v2i64, Legal); setOperationAction(ISD::FP_TO_SINT, MVT::v8i64, Legal); + setOperationAction(ISD::FP_TO_SINT, MVT::v4i64, Legal); + setOperationAction(ISD::FP_TO_SINT, MVT::v2i64, Legal); setOperationAction(ISD::FP_TO_UINT, MVT::v8i64, Legal); + setOperationAction(ISD::FP_TO_UINT, MVT::v4i64, Legal); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i64, Legal); + if (Subtarget.hasVLX()) { - setOperationAction(ISD::SINT_TO_FP, MVT::v4i64, Legal); - setOperationAction(ISD::SINT_TO_FP, MVT::v2i64, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v4i64, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v2i64, Legal); - setOperationAction(ISD::FP_TO_SINT, MVT::v4i64, Legal); - setOperationAction(ISD::FP_TO_SINT, MVT::v2i64, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v4i64, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v2i64, Legal); + // Fast v2f32 SINT_TO_FP( v2i32 ) custom conversion. 
+ setOperationAction(ISD::SINT_TO_FP, MVT::v2f32, Custom); + setOperationAction(ISD::FP_TO_SINT, MVT::v2f32, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v2f32, Custom); } } if (Subtarget.hasVLX()) { @@ -1250,11 +1276,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FP_TO_SINT, MVT::v8i32, Legal); setOperationAction(ISD::FP_TO_UINT, MVT::v8i32, Legal); setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v4i32, Legal); setOperationAction(ISD::FP_TO_SINT, MVT::v4i32, Legal); setOperationAction(ISD::FP_TO_UINT, MVT::v4i32, Legal); setOperationAction(ISD::ZERO_EXTEND, MVT::v4i32, Custom); setOperationAction(ISD::ZERO_EXTEND, MVT::v2i64, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v2i64, Custom); // FIXME. This commands are available on SSE/AVX2, add relevant patterns. setLoadExtAction(ISD::EXTLOAD, MVT::v8i32, MVT::v8i8, Legal); @@ -1281,10 +1308,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SIGN_EXTEND, MVT::v16i8, Custom); setOperationAction(ISD::SIGN_EXTEND, MVT::v8i16, Custom); setOperationAction(ISD::SIGN_EXTEND, MVT::v16i16, Custom); - if (Subtarget.hasDQI()) { - setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v2i64, Custom); - } + for (auto VT : { MVT::v16f32, MVT::v8f64 }) { setOperationAction(ISD::FFLOOR, VT, Legal); setOperationAction(ISD::FCEIL, VT, Legal); @@ -1293,6 +1317,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::FNEARBYINT, VT, Legal); } + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v8i64, Custom); + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v16i32, Custom); + + // Without BWI we need to use custom lowering to handle MVT::v64i8 input. + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v64i8, Custom); + setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, MVT::v64i8, Custom); + setOperationAction(ISD::CONCAT_VECTORS, MVT::v8f64, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i64, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v16f32, Custom); @@ -1339,13 +1370,17 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SRL, VT, Custom); setOperationAction(ISD::SHL, VT, Custom); setOperationAction(ISD::SRA, VT, Custom); - setOperationAction(ISD::AND, VT, Legal); - setOperationAction(ISD::OR, VT, Legal); - setOperationAction(ISD::XOR, VT, Legal); setOperationAction(ISD::CTPOP, VT, Custom); setOperationAction(ISD::CTTZ, VT, Custom); } + // Need to promote to 64-bit even though we have 32-bit masked instructions + // because the IR optimizers rearrange bitcasts around logic ops leaving + // too many variations to handle if we don't promote them. 
+ setOperationPromotedToType(ISD::AND, MVT::v16i32, MVT::v8i64); + setOperationPromotedToType(ISD::OR, MVT::v16i32, MVT::v8i64); + setOperationPromotedToType(ISD::XOR, MVT::v16i32, MVT::v8i64); + if (Subtarget.hasCDI()) { setOperationAction(ISD::CTLZ, MVT::v8i64, Legal); setOperationAction(ISD::CTLZ, MVT::v16i32, Legal); @@ -1377,12 +1412,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } // Subtarget.hasCDI() if (Subtarget.hasDQI()) { - if (Subtarget.hasVLX()) { - setOperationAction(ISD::MUL, MVT::v2i64, Legal); - setOperationAction(ISD::MUL, MVT::v4i64, Legal); - } + // NonVLX sub-targets extend 128/256 vectors to use the 512 version. + setOperationAction(ISD::MUL, MVT::v2i64, Legal); + setOperationAction(ISD::MUL, MVT::v4i64, Legal); setOperationAction(ISD::MUL, MVT::v8i64, Legal); } + // Custom lower several nodes. for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) { @@ -1413,6 +1448,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MSCATTER, VT, Custom); } for (auto VT : { MVT::v64i8, MVT::v32i16, MVT::v16i32 }) { + setOperationPromotedToType(ISD::LOAD, VT, MVT::v8i64); setOperationPromotedToType(ISD::SELECT, VT, MVT::v8i64); } }// has AVX-512 @@ -1447,6 +1483,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64i8, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v32i16, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v64i8, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v32i1, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v64i1, Custom); setOperationAction(ISD::SCALAR_TO_VECTOR, MVT::v32i16, Custom); setOperationAction(ISD::SCALAR_TO_VECTOR, MVT::v64i8, Custom); setOperationAction(ISD::SELECT, MVT::v32i1, Custom); @@ -1486,10 +1524,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UMIN, MVT::v64i8, Legal); setOperationAction(ISD::UMIN, MVT::v32i16, Legal); + setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v32i16, Custom); + setTruncStoreAction(MVT::v32i16, MVT::v32i8, Legal); - setTruncStoreAction(MVT::v16i16, MVT::v16i8, Legal); - if (Subtarget.hasVLX()) + if (Subtarget.hasVLX()) { + setTruncStoreAction(MVT::v16i16, MVT::v16i8, Legal); setTruncStoreAction(MVT::v8i16, MVT::v8i8, Legal); + } LegalizeAction Action = Subtarget.hasVLX() ? 
Legal : Custom; for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) { @@ -1532,35 +1573,25 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, addRegisterClass(MVT::v4i1, &X86::VK4RegClass); addRegisterClass(MVT::v2i1, &X86::VK2RegClass); - setOperationAction(ISD::ADD, MVT::v2i1, Expand); - setOperationAction(ISD::ADD, MVT::v4i1, Expand); - setOperationAction(ISD::SUB, MVT::v2i1, Expand); - setOperationAction(ISD::SUB, MVT::v4i1, Expand); - setOperationAction(ISD::MUL, MVT::v2i1, Expand); - setOperationAction(ISD::MUL, MVT::v4i1, Expand); - - setOperationAction(ISD::TRUNCATE, MVT::v2i1, Custom); - setOperationAction(ISD::TRUNCATE, MVT::v4i1, Custom); - setOperationAction(ISD::SETCC, MVT::v4i1, Custom); - setOperationAction(ISD::SETCC, MVT::v2i1, Custom); - setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom); + for (auto VT : { MVT::v2i1, MVT::v4i1 }) { + setOperationAction(ISD::ADD, VT, Expand); + setOperationAction(ISD::SUB, VT, Expand); + setOperationAction(ISD::MUL, VT, Expand); + setOperationAction(ISD::VSELECT, VT, Expand); + + setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); + setOperationAction(ISD::BUILD_VECTOR, VT, Custom); + setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); + } + setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom); + setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom); - setOperationAction(ISD::SELECT, MVT::v4i1, Custom); - setOperationAction(ISD::SELECT, MVT::v2i1, Custom); - setOperationAction(ISD::BUILD_VECTOR, MVT::v4i1, Custom); - setOperationAction(ISD::BUILD_VECTOR, MVT::v2i1, Custom); - setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i1, Custom); - setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i1, Custom); - setOperationAction(ISD::VSELECT, MVT::v2i1, Expand); - setOperationAction(ISD::VSELECT, MVT::v4i1, Expand); - - for (auto VT : { MVT::v4i32, MVT::v8i32 }) { - setOperationAction(ISD::AND, VT, Legal); - setOperationAction(ISD::OR, VT, Legal); - setOperationAction(ISD::XOR, VT, Legal); - } for (auto VT : { MVT::v2i64, MVT::v4i64 }) { setOperationAction(ISD::SMAX, VT, Legal); @@ -1629,7 +1660,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // is. We should promote the value to 64-bits to solve this. // This is what the CRT headers do - `fmodf` is an inline header // function casting to f64 and calling `fmod`. 
- if (Subtarget.is32Bit() && Subtarget.isTargetKnownWindowsMSVC()) + if (Subtarget.is32Bit() && (Subtarget.isTargetKnownWindowsMSVC() || + Subtarget.isTargetWindowsItanium())) for (ISD::NodeType Op : {ISD::FCEIL, ISD::FCOS, ISD::FEXP, ISD::FFLOOR, ISD::FREM, ISD::FLOG, ISD::FLOG10, ISD::FPOW, ISD::FSIN}) @@ -1953,9 +1985,11 @@ X86TargetLowering::findRepresentativeClass(const TargetRegisterInfo *TRI, case MVT::f32: case MVT::f64: case MVT::v16i8: case MVT::v8i16: case MVT::v4i32: case MVT::v2i64: case MVT::v4f32: case MVT::v2f64: - case MVT::v32i8: case MVT::v8i32: case MVT::v4i64: case MVT::v8f32: - case MVT::v4f64: - RRC = &X86::VR128RegClass; + case MVT::v32i8: case MVT::v16i16: case MVT::v8i32: case MVT::v4i64: + case MVT::v8f32: case MVT::v4f64: + case MVT::v64i8: case MVT::v32i16: case MVT::v16i32: case MVT::v8i64: + case MVT::v16f32: case MVT::v8f64: + RRC = &X86::VR128XRegClass; break; } return std::make_pair(RRC, Cost); @@ -2019,6 +2053,9 @@ Value *X86TargetLowering::getSSPStackGuardCheck(const Module &M) const { } Value *X86TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) const { + if (Subtarget.getTargetTriple().isOSContiki()) + return getDefaultSafeStackPointerLocation(IRB, false); + if (!Subtarget.isTargetAndroid()) return TargetLowering::getSafeStackPointerLocation(IRB); @@ -2062,6 +2099,58 @@ const MCPhysReg *X86TargetLowering::getScratchRegisters(CallingConv::ID) const { return ScratchRegs; } +/// Lowers masks values (v*i1) to the local register values +/// \returns DAG node after lowering to register type +static SDValue lowerMasksToReg(const SDValue &ValArg, const EVT &ValLoc, + const SDLoc &Dl, SelectionDAG &DAG) { + EVT ValVT = ValArg.getValueType(); + + if ((ValVT == MVT::v8i1 && (ValLoc == MVT::i8 || ValLoc == MVT::i32)) || + (ValVT == MVT::v16i1 && (ValLoc == MVT::i16 || ValLoc == MVT::i32))) { + // Two stage lowering might be required + // bitcast: v8i1 -> i8 / v16i1 -> i16 + // anyextend: i8 -> i32 / i16 -> i32 + EVT TempValLoc = ValVT == MVT::v8i1 ? 
MVT::i8 : MVT::i16; + SDValue ValToCopy = DAG.getBitcast(TempValLoc, ValArg); + if (ValLoc == MVT::i32) + ValToCopy = DAG.getNode(ISD::ANY_EXTEND, Dl, ValLoc, ValToCopy); + return ValToCopy; + } else if ((ValVT == MVT::v32i1 && ValLoc == MVT::i32) || + (ValVT == MVT::v64i1 && ValLoc == MVT::i64)) { + // One stage lowering is required + // bitcast: v32i1 -> i32 / v64i1 -> i64 + return DAG.getBitcast(ValLoc, ValArg); + } else + return DAG.getNode(ISD::SIGN_EXTEND, Dl, ValLoc, ValArg); +} + +/// Breaks v64i1 value into two registers and adds the new node to the DAG +static void Passv64i1ArgInRegs( + const SDLoc &Dl, SelectionDAG &DAG, SDValue Chain, SDValue &Arg, + SmallVector<std::pair<unsigned, SDValue>, 8> &RegsToPass, CCValAssign &VA, + CCValAssign &NextVA, const X86Subtarget &Subtarget) { + assert((Subtarget.hasBWI() || Subtarget.hasBMI()) && + "Expected AVX512BW or AVX512BMI target!"); + assert(Subtarget.is32Bit() && "Expecting 32 bit target"); + assert(Arg.getValueType() == MVT::i64 && "Expecting 64 bit value"); + assert(VA.isRegLoc() && NextVA.isRegLoc() && + "The value should reside in two registers"); + + // Before splitting the value we cast it to i64 + Arg = DAG.getBitcast(MVT::i64, Arg); + + // Splitting the value into two i32 types + SDValue Lo, Hi; + Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, Dl, MVT::i32, Arg, + DAG.getConstant(0, Dl, MVT::i32)); + Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, Dl, MVT::i32, Arg, + DAG.getConstant(1, Dl, MVT::i32)); + + // Attach the two i32 types into corresponding registers + RegsToPass.push_back(std::make_pair(VA.getLocReg(), Lo)); + RegsToPass.push_back(std::make_pair(NextVA.getLocReg(), Hi)); +} + SDValue X86TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, @@ -2086,10 +2175,11 @@ X86TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, MVT::i32)); // Copy the result values into the output registers. - for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { - CCValAssign &VA = RVLocs[i]; + for (unsigned I = 0, OutsIndex = 0, E = RVLocs.size(); I != E; + ++I, ++OutsIndex) { + CCValAssign &VA = RVLocs[I]; assert(VA.isRegLoc() && "Can only return in registers!"); - SDValue ValToCopy = OutVals[i]; + SDValue ValToCopy = OutVals[OutsIndex]; EVT ValVT = ValToCopy.getValueType(); // Promote values to the appropriate types. 
@@ -2099,7 +2189,7 @@ X86TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, ValToCopy = DAG.getNode(ISD::ZERO_EXTEND, dl, VA.getLocVT(), ValToCopy); else if (VA.getLocInfo() == CCValAssign::AExt) { if (ValVT.isVector() && ValVT.getVectorElementType() == MVT::i1) - ValToCopy = DAG.getNode(ISD::SIGN_EXTEND, dl, VA.getLocVT(), ValToCopy); + ValToCopy = lowerMasksToReg(ValToCopy, VA.getLocVT(), dl, DAG); else ValToCopy = DAG.getNode(ISD::ANY_EXTEND, dl, VA.getLocVT(), ValToCopy); } @@ -2152,9 +2242,27 @@ X86TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, } } - Chain = DAG.getCopyToReg(Chain, dl, VA.getLocReg(), ValToCopy, Flag); - Flag = Chain.getValue(1); - RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT())); + SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass; + + if (VA.needsCustom()) { + assert(VA.getValVT() == MVT::v64i1 && + "Currently the only custom case is when we split v64i1 to 2 regs"); + + Passv64i1ArgInRegs(dl, DAG, Chain, ValToCopy, RegsToPass, VA, RVLocs[++I], + Subtarget); + + assert(2 == RegsToPass.size() && + "Expecting two registers after Pass64BitArgInRegs"); + } else { + RegsToPass.push_back(std::make_pair(VA.getLocReg(), ValToCopy)); + } + + // Add nodes to the DAG and add the values into the RetOps list + for (auto &Reg : RegsToPass) { + Chain = DAG.getCopyToReg(Chain, dl, Reg.first, Reg.second, Flag); + Flag = Chain.getValue(1); + RetOps.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType())); + } } // Swift calling convention does not require we copy the sret argument @@ -2282,6 +2390,98 @@ EVT X86TargetLowering::getTypeForExtReturn(LLVMContext &Context, EVT VT, return VT.bitsLT(MinVT) ? MinVT : VT; } +/// Reads two 32 bit registers and creates a 64 bit mask value. +/// \param VA The current 32 bit value that need to be assigned. +/// \param NextVA The next 32 bit value that need to be assigned. +/// \param Root The parent DAG node. +/// \param [in,out] InFlag Represents SDvalue in the parent DAG node for +/// glue purposes. In the case the DAG is already using +/// physical register instead of virtual, we should glue +/// our new SDValue to InFlag SDvalue. +/// \return a new SDvalue of size 64bit. +static SDValue getv64i1Argument(CCValAssign &VA, CCValAssign &NextVA, + SDValue &Root, SelectionDAG &DAG, + const SDLoc &Dl, const X86Subtarget &Subtarget, + SDValue *InFlag = nullptr) { + assert((Subtarget.hasBWI()) && "Expected AVX512BW target!"); + assert(Subtarget.is32Bit() && "Expecting 32 bit target"); + assert(VA.getValVT() == MVT::v64i1 && + "Expecting first location of 64 bit width type"); + assert(NextVA.getValVT() == VA.getValVT() && + "The locations should have the same type"); + assert(VA.isRegLoc() && NextVA.isRegLoc() && + "The values should reside in two registers"); + + SDValue Lo, Hi; + unsigned Reg; + SDValue ArgValueLo, ArgValueHi; + + MachineFunction &MF = DAG.getMachineFunction(); + const TargetRegisterClass *RC = &X86::GR32RegClass; + + // Read a 32 bit value from the registers + if (nullptr == InFlag) { + // When no physical register is present, + // create an intermediate virtual register + Reg = MF.addLiveIn(VA.getLocReg(), RC); + ArgValueLo = DAG.getCopyFromReg(Root, Dl, Reg, MVT::i32); + Reg = MF.addLiveIn(NextVA.getLocReg(), RC); + ArgValueHi = DAG.getCopyFromReg(Root, Dl, Reg, MVT::i32); + } else { + // When a physical register is available read the value from it and glue + // the reads together. 
+ ArgValueLo = + DAG.getCopyFromReg(Root, Dl, VA.getLocReg(), MVT::i32, *InFlag); + *InFlag = ArgValueLo.getValue(2); + ArgValueHi = + DAG.getCopyFromReg(Root, Dl, NextVA.getLocReg(), MVT::i32, *InFlag); + *InFlag = ArgValueHi.getValue(2); + } + + // Convert the i32 type into v32i1 type + Lo = DAG.getBitcast(MVT::v32i1, ArgValueLo); + + // Convert the i32 type into v32i1 type + Hi = DAG.getBitcast(MVT::v32i1, ArgValueHi); + + // Concantenate the two values together + return DAG.getNode(ISD::CONCAT_VECTORS, Dl, MVT::v64i1, Lo, Hi); +} + +/// The function will lower a register of various sizes (8/16/32/64) +/// to a mask value of the expected size (v8i1/v16i1/v32i1/v64i1) +/// \returns a DAG node contains the operand after lowering to mask type. +static SDValue lowerRegToMasks(const SDValue &ValArg, const EVT &ValVT, + const EVT &ValLoc, const SDLoc &Dl, + SelectionDAG &DAG) { + SDValue ValReturned = ValArg; + + if (ValVT == MVT::v64i1) { + // In 32 bit machine, this case is handled by getv64i1Argument + assert(ValLoc == MVT::i64 && "Expecting only i64 locations"); + // In 64 bit machine, There is no need to truncate the value only bitcast + } else { + MVT maskLen; + switch (ValVT.getSimpleVT().SimpleTy) { + case MVT::v8i1: + maskLen = MVT::i8; + break; + case MVT::v16i1: + maskLen = MVT::i16; + break; + case MVT::v32i1: + maskLen = MVT::i32; + break; + default: + llvm_unreachable("Expecting a vector of i1 types"); + } + + ValReturned = DAG.getNode(ISD::TRUNCATE, Dl, maskLen, ValReturned); + } + + return DAG.getBitcast(ValVT, ValReturned); +} + /// Lower the result values of a call into the /// appropriate copies out of appropriate physical registers. /// @@ -2298,13 +2498,14 @@ SDValue X86TargetLowering::LowerCallResult( CCInfo.AnalyzeCallResult(Ins, RetCC_X86); // Copy all of the result registers out of their specified physreg. - for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { - CCValAssign &VA = RVLocs[i]; + for (unsigned I = 0, InsIndex = 0, E = RVLocs.size(); I != E; + ++I, ++InsIndex) { + CCValAssign &VA = RVLocs[I]; EVT CopyVT = VA.getLocVT(); // If this is x86-64, and we disabled SSE, we can't return FP values if ((CopyVT == MVT::f32 || CopyVT == MVT::f64 || CopyVT == MVT::f128) && - ((Is64Bit || Ins[i].Flags.isInReg()) && !Subtarget.hasSSE1())) { + ((Is64Bit || Ins[InsIndex].Flags.isInReg()) && !Subtarget.hasSSE1())) { report_fatal_error("SSE register return with SSE disabled"); } @@ -2319,19 +2520,34 @@ SDValue X86TargetLowering::LowerCallResult( RoundAfterCopy = (CopyVT != VA.getLocVT()); } - Chain = DAG.getCopyFromReg(Chain, dl, VA.getLocReg(), - CopyVT, InFlag).getValue(1); - SDValue Val = Chain.getValue(0); + SDValue Val; + if (VA.needsCustom()) { + assert(VA.getValVT() == MVT::v64i1 && + "Currently the only custom case is when we split v64i1 to 2 regs"); + Val = + getv64i1Argument(VA, RVLocs[++I], Chain, DAG, dl, Subtarget, &InFlag); + } else { + Chain = DAG.getCopyFromReg(Chain, dl, VA.getLocReg(), CopyVT, InFlag) + .getValue(1); + Val = Chain.getValue(0); + InFlag = Chain.getValue(2); + } if (RoundAfterCopy) Val = DAG.getNode(ISD::FP_ROUND, dl, VA.getValVT(), Val, // This truncation won't change the value. 
DAG.getIntPtrConstant(1, dl)); - if (VA.isExtInLoc() && VA.getValVT().getScalarType() == MVT::i1) - Val = DAG.getNode(ISD::TRUNCATE, dl, VA.getValVT(), Val); + if (VA.isExtInLoc() && (VA.getValVT().getScalarType() == MVT::i1)) { + if (VA.getValVT().isVector() && + ((VA.getLocVT() == MVT::i64) || (VA.getLocVT() == MVT::i32) || + (VA.getLocVT() == MVT::i16) || (VA.getLocVT() == MVT::i8))) { + // promoting a mask type (v*i1) into a register of type i64/i32/i16/i8 + Val = lowerRegToMasks(Val, VA.getValVT(), VA.getLocVT(), dl, DAG); + } else + Val = DAG.getNode(ISD::TRUNCATE, dl, VA.getValVT(), Val); + } - InFlag = Chain.getValue(2); InVals.push_back(Val); } @@ -2399,7 +2615,8 @@ static SDValue CreateCopyOfByValArgument(SDValue Src, SDValue Dst, /// Return true if the calling convention is one that we can guarantee TCO for. static bool canGuaranteeTCO(CallingConv::ID CC) { return (CC == CallingConv::Fast || CC == CallingConv::GHC || - CC == CallingConv::HiPE || CC == CallingConv::HHVM); + CC == CallingConv::X86_RegCall || CC == CallingConv::HiPE || + CC == CallingConv::HHVM); } /// Return true if we might ever do TCO for calls with this calling convention. @@ -2445,7 +2662,7 @@ X86TargetLowering::LowerMemArgument(SDValue Chain, CallingConv::ID CallConv, const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl, SelectionDAG &DAG, const CCValAssign &VA, - MachineFrameInfo *MFI, unsigned i) const { + MachineFrameInfo &MFI, unsigned i) const { // Create the nodes corresponding to a load from this parameter slot. ISD::ArgFlagsTy Flags = Ins[i].Flags; bool AlwaysUseMutable = shouldGuaranteeTCO( @@ -2454,9 +2671,11 @@ X86TargetLowering::LowerMemArgument(SDValue Chain, CallingConv::ID CallConv, EVT ValVT; // If value is passed by pointer we have address passed instead of the value - // itself. - bool ExtendedInMem = VA.isExtInLoc() && - VA.getValVT().getScalarType() == MVT::i1; + // itself. No need to extend if the mask value and location share the same + // absolute size. + bool ExtendedInMem = + VA.isExtInLoc() && VA.getValVT().getScalarType() == MVT::i1 && + VA.getValVT().getSizeInBits() != VA.getLocVT().getSizeInBits(); if (VA.getLocInfo() == CCValAssign::Indirect || ExtendedInMem) ValVT = VA.getLocVT(); @@ -2483,26 +2702,26 @@ X86TargetLowering::LowerMemArgument(SDValue Chain, CallingConv::ID CallConv, if (Flags.isByVal()) { unsigned Bytes = Flags.getByValSize(); if (Bytes == 0) Bytes = 1; // Don't create zero-sized stack objects. - int FI = MFI->CreateFixedObject(Bytes, VA.getLocMemOffset(), isImmutable); + int FI = MFI.CreateFixedObject(Bytes, VA.getLocMemOffset(), isImmutable); // Adjust SP offset of interrupt parameter. if (CallConv == CallingConv::X86_INTR) { - MFI->setObjectOffset(FI, Offset); + MFI.setObjectOffset(FI, Offset); } return DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout())); } else { - int FI = MFI->CreateFixedObject(ValVT.getSizeInBits()/8, - VA.getLocMemOffset(), isImmutable); + int FI = MFI.CreateFixedObject(ValVT.getSizeInBits()/8, + VA.getLocMemOffset(), isImmutable); // Set SExt or ZExt flag. if (VA.getLocInfo() == CCValAssign::ZExt) { - MFI->setObjectZExt(FI, true); + MFI.setObjectZExt(FI, true); } else if (VA.getLocInfo() == CCValAssign::SExt) { - MFI->setObjectSExt(FI, true); + MFI.setObjectSExt(FI, true); } // Adjust SP offset of interrupt parameter. 
if (CallConv == CallingConv::X86_INTR) { - MFI->setObjectOffset(FI, Offset); + MFI.setObjectOffset(FI, Offset); } SDValue FIN = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout())); @@ -2562,6 +2781,13 @@ static ArrayRef<MCPhysReg> get64BitArgumentXMMs(MachineFunction &MF, return makeArrayRef(std::begin(XMMArgRegs64Bit), std::end(XMMArgRegs64Bit)); } +static bool isSortedByValueNo(const SmallVectorImpl<CCValAssign> &ArgLocs) { + return std::is_sorted(ArgLocs.begin(), ArgLocs.end(), + [](const CCValAssign &A, const CCValAssign &B) -> bool { + return A.getValNo() < B.getValNo(); + }); +} + SDValue X86TargetLowering::LowerFormalArguments( SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl, @@ -2576,12 +2802,13 @@ SDValue X86TargetLowering::LowerFormalArguments( Fn->getName() == "main") FuncInfo->setForceFramePointer(true); - MachineFrameInfo *MFI = MF.getFrameInfo(); + MachineFrameInfo &MFI = MF.getFrameInfo(); bool Is64Bit = Subtarget.is64Bit(); bool IsWin64 = Subtarget.isCallingConvWin64(CallConv); - assert(!(isVarArg && canGuaranteeTCO(CallConv)) && - "Var args not supported with calling convention fastcc, ghc or hipe"); + assert( + !(isVarArg && canGuaranteeTCO(CallConv)) && + "Var args not supported with calling conv' regcall, fastcc, ghc or hipe"); if (CallConv == CallingConv::X86_INTR) { bool isLegal = Ins.size() == 1 || @@ -2595,59 +2822,78 @@ SDValue X86TargetLowering::LowerFormalArguments( SmallVector<CCValAssign, 16> ArgLocs; CCState CCInfo(CallConv, isVarArg, MF, ArgLocs, *DAG.getContext()); - // Allocate shadow area for Win64 + // Allocate shadow area for Win64. if (IsWin64) CCInfo.AllocateStack(32, 8); - CCInfo.AnalyzeFormalArguments(Ins, CC_X86); + CCInfo.AnalyzeArguments(Ins, CC_X86); + + // In vectorcall calling convention a second pass is required for the HVA + // types. + if (CallingConv::X86_VectorCall == CallConv) { + CCInfo.AnalyzeArgumentsSecondPass(Ins, CC_X86); + } + + // The next loop assumes that the locations are in the same order of the + // input arguments. + if (!isSortedByValueNo(ArgLocs)) + llvm_unreachable("Argument Location list must be sorted before lowering"); - unsigned LastVal = ~0U; SDValue ArgValue; - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { - CCValAssign &VA = ArgLocs[i]; - // TODO: If an arg is passed in two places (e.g. reg and stack), skip later - // places. 
- assert(VA.getValNo() != LastVal && - "Don't support value assigned to multiple locs yet"); - (void)LastVal; - LastVal = VA.getValNo(); + for (unsigned I = 0, InsIndex = 0, E = ArgLocs.size(); I != E; + ++I, ++InsIndex) { + assert(InsIndex < Ins.size() && "Invalid Ins index"); + CCValAssign &VA = ArgLocs[I]; if (VA.isRegLoc()) { EVT RegVT = VA.getLocVT(); - const TargetRegisterClass *RC; - if (RegVT == MVT::i32) - RC = &X86::GR32RegClass; - else if (Is64Bit && RegVT == MVT::i64) - RC = &X86::GR64RegClass; - else if (RegVT == MVT::f32) - RC = &X86::FR32RegClass; - else if (RegVT == MVT::f64) - RC = &X86::FR64RegClass; - else if (RegVT == MVT::f128) - RC = &X86::FR128RegClass; - else if (RegVT.is512BitVector()) - RC = &X86::VR512RegClass; - else if (RegVT.is256BitVector()) - RC = &X86::VR256RegClass; - else if (RegVT.is128BitVector()) - RC = &X86::VR128RegClass; - else if (RegVT == MVT::x86mmx) - RC = &X86::VR64RegClass; - else if (RegVT == MVT::i1) - RC = &X86::VK1RegClass; - else if (RegVT == MVT::v8i1) - RC = &X86::VK8RegClass; - else if (RegVT == MVT::v16i1) - RC = &X86::VK16RegClass; - else if (RegVT == MVT::v32i1) - RC = &X86::VK32RegClass; - else if (RegVT == MVT::v64i1) - RC = &X86::VK64RegClass; - else - llvm_unreachable("Unknown argument type!"); + if (VA.needsCustom()) { + assert( + VA.getValVT() == MVT::v64i1 && + "Currently the only custom case is when we split v64i1 to 2 regs"); + + // v64i1 values, in regcall calling convention, that are + // compiled to 32 bit arch, are splited up into two registers. + ArgValue = + getv64i1Argument(VA, ArgLocs[++I], Chain, DAG, dl, Subtarget); + } else { + const TargetRegisterClass *RC; + if (RegVT == MVT::i32) + RC = &X86::GR32RegClass; + else if (Is64Bit && RegVT == MVT::i64) + RC = &X86::GR64RegClass; + else if (RegVT == MVT::f32) + RC = Subtarget.hasAVX512() ? &X86::FR32XRegClass : &X86::FR32RegClass; + else if (RegVT == MVT::f64) + RC = Subtarget.hasAVX512() ? &X86::FR64XRegClass : &X86::FR64RegClass; + else if (RegVT == MVT::f80) + RC = &X86::RFP80RegClass; + else if (RegVT == MVT::f128) + RC = &X86::FR128RegClass; + else if (RegVT.is512BitVector()) + RC = &X86::VR512RegClass; + else if (RegVT.is256BitVector()) + RC = Subtarget.hasVLX() ? &X86::VR256XRegClass : &X86::VR256RegClass; + else if (RegVT.is128BitVector()) + RC = Subtarget.hasVLX() ? &X86::VR128XRegClass : &X86::VR128RegClass; + else if (RegVT == MVT::x86mmx) + RC = &X86::VR64RegClass; + else if (RegVT == MVT::i1) + RC = &X86::VK1RegClass; + else if (RegVT == MVT::v8i1) + RC = &X86::VK8RegClass; + else if (RegVT == MVT::v16i1) + RC = &X86::VK16RegClass; + else if (RegVT == MVT::v32i1) + RC = &X86::VK32RegClass; + else if (RegVT == MVT::v64i1) + RC = &X86::VK64RegClass; + else + llvm_unreachable("Unknown argument type!"); - unsigned Reg = MF.addLiveIn(VA.getLocReg(), RC); - ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT); + unsigned Reg = MF.addLiveIn(VA.getLocReg(), RC); + ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT); + } // If this is an 8 or 16-bit value, it is really passed promoted to 32 // bits. Insert an assert[sz]ext to capture this, then truncate to the @@ -2665,12 +2911,19 @@ SDValue X86TargetLowering::LowerFormalArguments( // Handle MMX values passed in XMM regs. 
if (RegVT.isVector() && VA.getValVT().getScalarType() != MVT::i1) ArgValue = DAG.getNode(X86ISD::MOVDQ2Q, dl, VA.getValVT(), ArgValue); - else + else if (VA.getValVT().isVector() && + VA.getValVT().getScalarType() == MVT::i1 && + ((VA.getLocVT() == MVT::i64) || (VA.getLocVT() == MVT::i32) || + (VA.getLocVT() == MVT::i16) || (VA.getLocVT() == MVT::i8))) { + // Promoting a mask type (v*i1) into a register of type i64/i32/i16/i8 + ArgValue = lowerRegToMasks(ArgValue, VA.getValVT(), RegVT, dl, DAG); + } else ArgValue = DAG.getNode(ISD::TRUNCATE, dl, VA.getValVT(), ArgValue); } } else { assert(VA.isMemLoc()); - ArgValue = LowerMemArgument(Chain, CallConv, Ins, dl, DAG, VA, MFI, i); + ArgValue = + LowerMemArgument(Chain, CallConv, Ins, dl, DAG, VA, MFI, InsIndex); } // If value is passed via pointer - do a load. @@ -2681,7 +2934,7 @@ SDValue X86TargetLowering::LowerFormalArguments( InVals.push_back(ArgValue); } - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { + for (unsigned I = 0, E = Ins.size(); I != E; ++I) { // Swift calling convention does not require we copy the sret argument // into %rax/%eax for the return. We don't set SRetReturnReg for Swift. if (CallConv == CallingConv::Swift) @@ -2691,14 +2944,14 @@ SDValue X86TargetLowering::LowerFormalArguments( // sret argument into %rax/%eax (depending on ABI) for the return. Save // the argument into a virtual register so that we can access it from the // return points. - if (Ins[i].Flags.isSRet()) { + if (Ins[I].Flags.isSRet()) { unsigned Reg = FuncInfo->getSRetReturnReg(); if (!Reg) { MVT PtrTy = getPointerTy(DAG.getDataLayout()); Reg = MF.getRegInfo().createVirtualRegister(getRegClassFor(PtrTy)); FuncInfo->setSRetReturnReg(Reg); } - SDValue Copy = DAG.getCopyToReg(DAG.getEntryNode(), dl, Reg, InVals[i]); + SDValue Copy = DAG.getCopyToReg(DAG.getEntryNode(), dl, Reg, InVals[I]); Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Copy, Chain); break; } @@ -2713,11 +2966,10 @@ SDValue X86TargetLowering::LowerFormalArguments( // If the function takes variable number of arguments, make a frame index for // the start of the first vararg value... for expansion of llvm.va_start. We // can skip this if there are no va_start calls. - if (MFI->hasVAStart() && + if (MFI.hasVAStart() && (Is64Bit || (CallConv != CallingConv::X86_FastCall && CallConv != CallingConv::X86_ThisCall))) { - FuncInfo->setVarArgsFrameIndex( - MFI->CreateFixedObject(1, StackSize, true)); + FuncInfo->setVarArgsFrameIndex(MFI.CreateFixedObject(1, StackSize, true)); } // Figure out if XMM registers are in use. @@ -2727,7 +2979,7 @@ SDValue X86TargetLowering::LowerFormalArguments( // 64-bit calling conventions support varargs and register parameters, so we // have to do extra work to spill them in the prologue. - if (Is64Bit && isVarArg && MFI->hasVAStart()) { + if (Is64Bit && isVarArg && MFI.hasVAStart()) { // Find the first unallocated argument registers. ArrayRef<MCPhysReg> ArgGPRs = get64BitArgumentGPRs(CallConv, Subtarget); ArrayRef<MCPhysReg> ArgXMMs = get64BitArgumentXMMs(MF, CallConv, Subtarget); @@ -2760,7 +3012,7 @@ SDValue X86TargetLowering::LowerFormalArguments( // for the return address. int HomeOffset = TFI.getOffsetOfLocalArea() + 8; FuncInfo->setRegSaveFrameIndex( - MFI->CreateFixedObject(1, NumIntRegs * 8 + HomeOffset, false)); + MFI.CreateFixedObject(1, NumIntRegs * 8 + HomeOffset, false)); // Fixup to set vararg frame on shadow area (4 x i64). 
if (NumIntRegs < 4) FuncInfo->setVarArgsFrameIndex(FuncInfo->getRegSaveFrameIndex()); @@ -2770,7 +3022,7 @@ SDValue X86TargetLowering::LowerFormalArguments( // they may be loaded by dereferencing the result of va_next. FuncInfo->setVarArgsGPOffset(NumIntRegs * 8); FuncInfo->setVarArgsFPOffset(ArgGPRs.size() * 8 + NumXMMRegs * 16); - FuncInfo->setRegSaveFrameIndex(MFI->CreateStackObject( + FuncInfo->setRegSaveFrameIndex(MFI.CreateStackObject( ArgGPRs.size() * 8 + ArgXMMs.size() * 16, 16, false)); } @@ -2810,7 +3062,7 @@ SDValue X86TargetLowering::LowerFormalArguments( Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, MemOps); } - if (isVarArg && MFI->hasMustTailInVarArgFunc()) { + if (isVarArg && MFI.hasMustTailInVarArgFunc()) { // Find the largest legal vector type. MVT VecVT = MVT::Other; // FIXME: Only some x86_32 calling conventions support AVX512. @@ -2889,7 +3141,7 @@ SDValue X86TargetLowering::LowerFormalArguments( // same, so the size of funclets' (mostly empty) frames is dictated by // how far this slot is from the bottom (since they allocate just enough // space to accommodate holding this slot at the correct offset). - int PSPSymFI = MFI->CreateStackObject(8, 8, /*isSS=*/false); + int PSPSymFI = MFI.CreateStackObject(8, 8, /*isSS=*/false); EHInfo->PSPSymFrameIdx = PSPSymFI; } } @@ -2938,7 +3190,7 @@ static SDValue EmitTailCallStoreRetAddr(SelectionDAG &DAG, MachineFunction &MF, if (!FPDiff) return Chain; // Calculate the new stack slot for the return address. int NewReturnAddrFI = - MF.getFrameInfo()->CreateFixedObject(SlotSize, (int64_t)FPDiff - SlotSize, + MF.getFrameInfo().CreateFixedObject(SlotSize, (int64_t)FPDiff - SlotSize, false); SDValue NewRetAddrFrIdx = DAG.getFrameIndex(NewReturnAddrFI, PtrVT); Chain = DAG.getStore(Chain, dl, RetAddrFrIdx, NewRetAddrFrIdx, @@ -3029,11 +3281,17 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SmallVector<CCValAssign, 16> ArgLocs; CCState CCInfo(CallConv, isVarArg, MF, ArgLocs, *DAG.getContext()); - // Allocate shadow area for Win64 + // Allocate shadow area for Win64. if (IsWin64) CCInfo.AllocateStack(32, 8); - CCInfo.AnalyzeCallOperands(Outs, CC_X86); + CCInfo.AnalyzeArguments(Outs, CC_X86); + + // In vectorcall calling convention a second pass is required for the HVA + // types. + if (CallingConv::X86_VectorCall == CallConv) { + CCInfo.AnalyzeArgumentsSecondPass(Outs, CC_X86); + } // Get a count of how many bytes are to be pushed on the stack. unsigned NumBytes = CCInfo.getAlignedCallFrameSize(); @@ -3088,18 +3346,25 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SmallVector<SDValue, 8> MemOpChains; SDValue StackPtr; + // The next loop assumes that the locations are in the same order of the + // input arguments. + if (!isSortedByValueNo(ArgLocs)) + llvm_unreachable("Argument Location list must be sorted before lowering"); + // Walk the register/memloc assignments, inserting copies/loads. In the case // of tail call optimization arguments are handle later. const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo(); - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { + for (unsigned I = 0, OutIndex = 0, E = ArgLocs.size(); I != E; + ++I, ++OutIndex) { + assert(OutIndex < Outs.size() && "Invalid Out index"); // Skip inalloca arguments, they have already been written. 
- ISD::ArgFlagsTy Flags = Outs[i].Flags; + ISD::ArgFlagsTy Flags = Outs[OutIndex].Flags; if (Flags.isInAlloca()) continue; - CCValAssign &VA = ArgLocs[i]; + CCValAssign &VA = ArgLocs[I]; EVT RegVT = VA.getLocVT(); - SDValue Arg = OutVals[i]; + SDValue Arg = OutVals[OutIndex]; bool isByVal = Flags.isByVal(); // Promote the value if needed. @@ -3115,7 +3380,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, case CCValAssign::AExt: if (Arg.getValueType().isVector() && Arg.getValueType().getVectorElementType() == MVT::i1) - Arg = DAG.getNode(ISD::SIGN_EXTEND, dl, RegVT, Arg); + Arg = lowerMasksToReg(Arg, RegVT, dl, DAG); else if (RegVT.is128BitVector()) { // Special case: passing MMX values in XMM registers. Arg = DAG.getBitcast(MVT::i64, Arg); @@ -3139,7 +3404,13 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } } - if (VA.isRegLoc()) { + if (VA.needsCustom()) { + assert(VA.getValVT() == MVT::v64i1 && + "Currently the only custom case is when we split v64i1 to 2 regs"); + // Split v64i1 value into two registers + Passv64i1ArgInRegs(dl, DAG, Chain, Arg, RegsToPass, VA, ArgLocs[++I], + Subtarget); + } else if (VA.isRegLoc()) { RegsToPass.push_back(std::make_pair(VA.getLocReg(), Arg)); if (isVarArg && IsWin64) { // Win64 ABI requires argument XMM reg to be copied to the corresponding @@ -3239,20 +3510,32 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SmallVector<SDValue, 8> MemOpChains2; SDValue FIN; int FI = 0; - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { - CCValAssign &VA = ArgLocs[i]; - if (VA.isRegLoc()) + for (unsigned I = 0, OutsIndex = 0, E = ArgLocs.size(); I != E; + ++I, ++OutsIndex) { + CCValAssign &VA = ArgLocs[I]; + + if (VA.isRegLoc()) { + if (VA.needsCustom()) { + assert((CallConv == CallingConv::X86_RegCall) && + "Expecting custome case only in regcall calling convention"); + // This means that we are in special case where one argument was + // passed through two register locations - Skip the next location + ++I; + } + continue; + } + assert(VA.isMemLoc()); - SDValue Arg = OutVals[i]; - ISD::ArgFlagsTy Flags = Outs[i].Flags; + SDValue Arg = OutVals[OutsIndex]; + ISD::ArgFlagsTy Flags = Outs[OutsIndex].Flags; // Skip inalloca arguments. They don't require any work. if (Flags.isInAlloca()) continue; // Create frame index. int32_t Offset = VA.getLocMemOffset()+FPDiff; uint32_t OpSize = (VA.getLocVT().getSizeInBits()+7)/8; - FI = MF.getFrameInfo()->CreateFixedObject(OpSize, Offset, true); + FI = MF.getFrameInfo().CreateFixedObject(OpSize, Offset, true); FIN = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout())); if (Flags.isByVal()) { @@ -3391,7 +3674,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // This isn't right, although it's probably harmless on x86; liveouts // should be computed from returns not tail calls. Consider a void // function making a tail call to a function returning int. - MF.getFrameInfo()->setHasTailCall(); + MF.getFrameInfo().setHasTailCall(); return DAG.getNode(X86ISD::TC_RETURN, dl, NodeTys, Ops); } @@ -3493,9 +3776,9 @@ X86TargetLowering::GetAlignedArgumentStackSize(unsigned StackSize, /// same position (relatively) of the caller's incoming argument stack. 
static bool MatchingStackOffset(SDValue Arg, unsigned Offset, ISD::ArgFlagsTy Flags, - MachineFrameInfo *MFI, const MachineRegisterInfo *MRI, + MachineFrameInfo &MFI, const MachineRegisterInfo *MRI, const X86InstrInfo *TII, const CCValAssign &VA) { - unsigned Bytes = Arg.getValueType().getSizeInBits() / 8; + unsigned Bytes = Arg.getValueSizeInBits() / 8; for (;;) { // Look through nodes that don't alter the bits of the incoming value. @@ -3558,22 +3841,22 @@ bool MatchingStackOffset(SDValue Arg, unsigned Offset, ISD::ArgFlagsTy Flags, return false; assert(FI != INT_MAX); - if (!MFI->isFixedObjectIndex(FI)) + if (!MFI.isFixedObjectIndex(FI)) return false; - if (Offset != MFI->getObjectOffset(FI)) + if (Offset != MFI.getObjectOffset(FI)) return false; - if (VA.getLocVT().getSizeInBits() > Arg.getValueType().getSizeInBits()) { + if (VA.getLocVT().getSizeInBits() > Arg.getValueSizeInBits()) { // If the argument location is wider than the argument type, check that any // extension flags match. - if (Flags.isZExt() != MFI->isObjectZExt(FI) || - Flags.isSExt() != MFI->isObjectSExt(FI)) { + if (Flags.isZExt() != MFI.isObjectZExt(FI) || + Flags.isSExt() != MFI.isObjectSExt(FI)) { return false; } } - return Bytes == MFI->getObjectSize(FI); + return Bytes == MFI.getObjectSize(FI); } /// Check whether the call is eligible for tail call optimization. Targets @@ -3700,7 +3983,7 @@ bool X86TargetLowering::IsEligibleForTailCallOptimization( if (CCInfo.getNextStackOffset()) { // Check if the arguments are already laid out in the right way as // the caller's fixed stack objects. - MachineFrameInfo *MFI = MF.getFrameInfo(); + MachineFrameInfo &MFI = MF.getFrameInfo(); const MachineRegisterInfo *MRI = &MF.getRegInfo(); const X86InstrInfo *TII = Subtarget.getInstrInfo(); for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { @@ -3787,6 +4070,14 @@ static bool MayFoldIntoStore(SDValue Op) { return Op.hasOneUse() && ISD::isNormalStore(*Op.getNode()->use_begin()); } +static bool MayFoldIntoZeroExtend(SDValue Op) { + if (Op.hasOneUse()) { + unsigned Opcode = Op.getNode()->use_begin()->getOpcode(); + return (ISD::ZERO_EXTEND == Opcode); + } + return false; +} + static bool isTargetShuffle(unsigned Opcode) { switch(Opcode) { default: return false; @@ -3821,6 +4112,7 @@ static bool isTargetShuffle(unsigned Opcode) { case X86ISD::VPPERM: case X86ISD::VPERMV: case X86ISD::VPERMV3: + case X86ISD::VPERMIV3: case X86ISD::VZEXT_MOVL: return true; } @@ -3829,41 +4121,18 @@ static bool isTargetShuffle(unsigned Opcode) { static bool isTargetShuffleVariableMask(unsigned Opcode) { switch (Opcode) { default: return false; + // Target Shuffles. case X86ISD::PSHUFB: case X86ISD::VPERMILPV: + case X86ISD::VPERMIL2: + case X86ISD::VPPERM: + case X86ISD::VPERMV: + case X86ISD::VPERMV3: + case X86ISD::VPERMIV3: + return true; + // 'Faux' Target Shuffles. 
+ case ISD::AND: return true; - } -} - -static SDValue getTargetShuffleNode(unsigned Opc, const SDLoc &dl, MVT VT, - SDValue V1, unsigned TargetMask, - SelectionDAG &DAG) { - switch(Opc) { - default: llvm_unreachable("Unknown x86 shuffle node"); - case X86ISD::PSHUFD: - case X86ISD::PSHUFHW: - case X86ISD::PSHUFLW: - case X86ISD::VPERMILPI: - case X86ISD::VPERMI: - return DAG.getNode(Opc, dl, VT, V1, - DAG.getConstant(TargetMask, dl, MVT::i8)); - } -} - -static SDValue getTargetShuffleNode(unsigned Opc, const SDLoc &dl, MVT VT, - SDValue V1, SDValue V2, SelectionDAG &DAG) { - switch(Opc) { - default: llvm_unreachable("Unknown x86 shuffle node"); - case X86ISD::MOVLHPS: - case X86ISD::MOVLHPD: - case X86ISD::MOVHLPS: - case X86ISD::MOVLPS: - case X86ISD::MOVLPD: - case X86ISD::MOVSS: - case X86ISD::MOVSD: - case X86ISD::UNPCKL: - case X86ISD::UNPCKH: - return DAG.getNode(Opc, dl, VT, V1, V2); } } @@ -3876,9 +4145,9 @@ SDValue X86TargetLowering::getReturnAddressFrameIndex(SelectionDAG &DAG) const { if (ReturnAddrIndex == 0) { // Set up a frame object for the return address. unsigned SlotSize = RegInfo->getSlotSize(); - ReturnAddrIndex = MF.getFrameInfo()->CreateFixedObject(SlotSize, - -(int64_t)SlotSize, - false); + ReturnAddrIndex = MF.getFrameInfo().CreateFixedObject(SlotSize, + -(int64_t)SlotSize, + false); FuncInfo->setRAIndex(ReturnAddrIndex); } @@ -3974,7 +4243,7 @@ static X86::CondCode TranslateIntegerX86CC(ISD::CondCode SetCCOpcode) { /// Do a one-to-one translation of a ISD::CondCode to the X86-specific /// condition code, returning the condition code and the LHS/RHS of the /// comparison to make. -static unsigned TranslateX86CC(ISD::CondCode SetCCOpcode, const SDLoc &DL, +static X86::CondCode TranslateX86CC(ISD::CondCode SetCCOpcode, const SDLoc &DL, bool isFP, SDValue &LHS, SDValue &RHS, SelectionDAG &DAG) { if (!isFP) { @@ -4175,6 +4444,10 @@ bool X86TargetLowering::isCheapToSpeculateCtlz() const { return Subtarget.hasLZCNT(); } +bool X86TargetLowering::isCtlzFast() const { + return Subtarget.hasFastLZCNT(); +} + bool X86TargetLowering::hasAndNotCompare(SDValue Y) const { if (!Subtarget.hasBMI()) return false; @@ -4187,11 +4460,21 @@ bool X86TargetLowering::hasAndNotCompare(SDValue Y) const { return true; } +/// Val is the undef sentinel value or equal to the specified value. +static bool isUndefOrEqual(int Val, int CmpVal) { + return ((Val == SM_SentinelUndef) || (Val == CmpVal)); +} + +/// Val is either the undef or zero sentinel value. +static bool isUndefOrZero(int Val) { + return ((Val == SM_SentinelUndef) || (Val == SM_SentinelZero)); +} + /// Return true if every element in Mask, beginning -/// from position Pos and ending in Pos+Size is undef. +/// from position Pos and ending in Pos+Size is the undef sentinel value. static bool isUndefInRange(ArrayRef<int> Mask, unsigned Pos, unsigned Size) { for (unsigned i = Pos, e = Pos + Size; i != e; ++i) - if (0 <= Mask[i]) + if (Mask[i] != SM_SentinelUndef) return false; return true; } @@ -4199,7 +4482,7 @@ static bool isUndefInRange(ArrayRef<int> Mask, unsigned Pos, unsigned Size) { /// Return true if Val is undef or if its value falls within the /// specified range (L, H]. 
static bool isUndefOrInRange(int Val, int Low, int Hi) { - return (Val < 0) || (Val >= Low && Val < Hi); + return (Val == SM_SentinelUndef) || (Val >= Low && Val < Hi); } /// Return true if every element in Mask is undef or if its value @@ -4212,14 +4495,19 @@ static bool isUndefOrInRange(ArrayRef<int> Mask, return true; } -/// Val is either less than zero (undef) or equal to the specified value. -static bool isUndefOrEqual(int Val, int CmpVal) { - return (Val < 0 || Val == CmpVal); +/// Return true if Val is undef, zero or if its value falls within the +/// specified range (L, H]. +static bool isUndefOrZeroOrInRange(int Val, int Low, int Hi) { + return isUndefOrZero(Val) || (Val >= Low && Val < Hi); } -/// Val is either the undef or zero sentinel value. -static bool isUndefOrZero(int Val) { - return (Val == SM_SentinelUndef || Val == SM_SentinelZero); +/// Return true if every element in Mask is undef, zero or if its value +/// falls within the specified range (L, H]. +static bool isUndefOrZeroOrInRange(ArrayRef<int> Mask, int Low, int Hi) { + for (int M : Mask) + if (!isUndefOrZeroOrInRange(M, Low, Hi)) + return false; + return true; } /// Return true if every element in Mask, beginning @@ -4244,6 +4532,100 @@ static bool isSequentialOrUndefOrZeroInRange(ArrayRef<int> Mask, unsigned Pos, return true; } +/// Return true if every element in Mask, beginning +/// from position Pos and ending in Pos+Size is undef or is zero. +static bool isUndefOrZeroInRange(ArrayRef<int> Mask, unsigned Pos, + unsigned Size) { + for (unsigned i = Pos, e = Pos + Size; i != e; ++i) + if (!isUndefOrZero(Mask[i])) + return false; + return true; +} + +/// \brief Helper function to test whether a shuffle mask could be +/// simplified by widening the elements being shuffled. +/// +/// Appends the mask for wider elements in WidenedMask if valid. Otherwise +/// leaves it in an unspecified state. +/// +/// NOTE: This must handle normal vector shuffle masks and *target* vector +/// shuffle masks. The latter have the special property of a '-2' representing +/// a zero-ed lane of a vector. +static bool canWidenShuffleElements(ArrayRef<int> Mask, + SmallVectorImpl<int> &WidenedMask) { + WidenedMask.assign(Mask.size() / 2, 0); + for (int i = 0, Size = Mask.size(); i < Size; i += 2) { + // If both elements are undef, its trivial. + if (Mask[i] == SM_SentinelUndef && Mask[i + 1] == SM_SentinelUndef) { + WidenedMask[i / 2] = SM_SentinelUndef; + continue; + } + + // Check for an undef mask and a mask value properly aligned to fit with + // a pair of values. If we find such a case, use the non-undef mask's value. + if (Mask[i] == SM_SentinelUndef && Mask[i + 1] >= 0 && + Mask[i + 1] % 2 == 1) { + WidenedMask[i / 2] = Mask[i + 1] / 2; + continue; + } + if (Mask[i + 1] == SM_SentinelUndef && Mask[i] >= 0 && Mask[i] % 2 == 0) { + WidenedMask[i / 2] = Mask[i] / 2; + continue; + } + + // When zeroing, we need to spread the zeroing across both lanes to widen. + if (Mask[i] == SM_SentinelZero || Mask[i + 1] == SM_SentinelZero) { + if ((Mask[i] == SM_SentinelZero || Mask[i] == SM_SentinelUndef) && + (Mask[i + 1] == SM_SentinelZero || Mask[i + 1] == SM_SentinelUndef)) { + WidenedMask[i / 2] = SM_SentinelZero; + continue; + } + return false; + } + + // Finally check if the two mask values are adjacent and aligned with + // a pair. 
+ if (Mask[i] != SM_SentinelUndef && Mask[i] % 2 == 0 && + Mask[i] + 1 == Mask[i + 1]) { + WidenedMask[i / 2] = Mask[i] / 2; + continue; + } + + // Otherwise we can't safely widen the elements used in this shuffle. + return false; + } + assert(WidenedMask.size() == Mask.size() / 2 && + "Incorrect size of mask after widening the elements!"); + + return true; +} + +/// Helper function to scale a shuffle or target shuffle mask, replacing each +/// mask index with the scaled sequential indices for an equivalent narrowed +/// mask. This is the reverse process to canWidenShuffleElements, but can always +/// succeed. +static void scaleShuffleMask(int Scale, ArrayRef<int> Mask, + SmallVectorImpl<int> &ScaledMask) { + assert(0 < Scale && "Unexpected scaling factor"); + int NumElts = Mask.size(); + ScaledMask.assign(NumElts * Scale, -1); + + for (int i = 0; i != NumElts; ++i) { + int M = Mask[i]; + + // Repeat sentinel values in every mask element. + if (M < 0) { + for (int s = 0; s != Scale; ++s) + ScaledMask[(Scale * i) + s] = M; + continue; + } + + // Scale mask element and increment across each mask element. + for (int s = 0; s != Scale; ++s) + ScaledMask[(Scale * i) + s] = (Scale * M) + s; + } +} + /// Return true if the specified EXTRACT_SUBVECTOR operand specifies a vector /// extract that is suitable for instruction that extract 128 or 256 bit vectors static bool isVEXTRACTIndex(SDNode *N, unsigned vecWidth) { @@ -4256,7 +4638,7 @@ static bool isVEXTRACTIndex(SDNode *N, unsigned vecWidth) { cast<ConstantSDNode>(N->getOperand(1).getNode())->getZExtValue(); MVT VT = N->getSimpleValueType(0); - unsigned ElSize = VT.getVectorElementType().getSizeInBits(); + unsigned ElSize = VT.getScalarSizeInBits(); bool Result = (Index * ElSize) % vecWidth == 0; return Result; @@ -4274,7 +4656,7 @@ static bool isVINSERTIndex(SDNode *N, unsigned vecWidth) { cast<ConstantSDNode>(N->getOperand(2).getNode())->getZExtValue(); MVT VT = N->getSimpleValueType(0); - unsigned ElSize = VT.getVectorElementType().getSizeInBits(); + unsigned ElSize = VT.getScalarSizeInBits(); bool Result = (Index * ElSize) % vecWidth == 0; return Result; @@ -4388,6 +4770,46 @@ static SDValue getConstVector(ArrayRef<int> Values, MVT VT, SelectionDAG &DAG, return ConstsNode; } +static SDValue getConstVector(ArrayRef<APInt> Bits, SmallBitVector &Undefs, + MVT VT, SelectionDAG &DAG, const SDLoc &dl) { + assert(Bits.size() == Undefs.size() && "Unequal constant and undef arrays"); + SmallVector<SDValue, 32> Ops; + bool Split = false; + + MVT ConstVecVT = VT; + unsigned NumElts = VT.getVectorNumElements(); + bool In64BitMode = DAG.getTargetLoweringInfo().isTypeLegal(MVT::i64); + if (!In64BitMode && VT.getVectorElementType() == MVT::i64) { + ConstVecVT = MVT::getVectorVT(MVT::i32, NumElts * 2); + Split = true; + } + + MVT EltVT = ConstVecVT.getVectorElementType(); + for (unsigned i = 0, e = Bits.size(); i != e; ++i) { + if (Undefs[i]) { + Ops.append(Split ? 
2 : 1, DAG.getUNDEF(EltVT)); + continue; + } + const APInt &V = Bits[i]; + assert(V.getBitWidth() == VT.getScalarSizeInBits() && "Unexpected sizes"); + if (Split) { + Ops.push_back(DAG.getConstant(V.trunc(32), dl, EltVT)); + Ops.push_back(DAG.getConstant(V.lshr(32).trunc(32), dl, EltVT)); + } else if (EltVT == MVT::f32) { + APFloat FV(APFloat::IEEEsingle(), V); + Ops.push_back(DAG.getConstantFP(FV, dl, EltVT)); + } else if (EltVT == MVT::f64) { + APFloat FV(APFloat::IEEEdouble(), V); + Ops.push_back(DAG.getConstantFP(FV, dl, EltVT)); + } else { + Ops.push_back(DAG.getConstant(V, dl, EltVT)); + } + } + + SDValue ConstsNode = DAG.getBuildVector(ConstVecVT, dl, Ops); + return DAG.getBitcast(VT, ConstsNode); +} + /// Returns a vector of specified type with all zero elements. static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget, SelectionDAG &DAG, const SDLoc &dl) { @@ -4416,8 +4838,6 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget, static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG, const SDLoc &dl, unsigned vectorWidth) { - assert((vectorWidth == 128 || vectorWidth == 256) && - "Unsupported vector width"); EVT VT = Vec.getValueType(); EVT ElVT = VT.getVectorElementType(); unsigned Factor = VT.getSizeInBits()/vectorWidth; @@ -4438,8 +4858,8 @@ static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG, // If the input is a buildvector just emit a smaller one. if (Vec.getOpcode() == ISD::BUILD_VECTOR) - return DAG.getNode(ISD::BUILD_VECTOR, - dl, ResultVT, makeArrayRef(Vec->op_begin() + IdxVal, ElemsPerChunk)); + return DAG.getNode(ISD::BUILD_VECTOR, dl, ResultVT, + makeArrayRef(Vec->op_begin() + IdxVal, ElemsPerChunk)); SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, dl); return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResultVT, Vec, VecIdx); @@ -4694,29 +5114,35 @@ static SDValue getOnesVector(EVT VT, const X86Subtarget &Subtarget, return DAG.getBitcast(VT, Vec); } +/// Generate unpacklo/unpackhi shuffle mask. +static void createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, bool Lo, + bool Unary) { + assert(Mask.empty() && "Expected an empty shuffle mask vector"); + int NumElts = VT.getVectorNumElements(); + int NumEltsInLane = 128 / VT.getScalarSizeInBits(); + + for (int i = 0; i < NumElts; ++i) { + unsigned LaneStart = (i / NumEltsInLane) * NumEltsInLane; + int Pos = (i % NumEltsInLane) / 2 + LaneStart; + Pos += (Unary ? 0 : NumElts * (i % 2)); + Pos += (Lo ? 0 : NumEltsInLane / 2); + Mask.push_back(Pos); + } +} + /// Returns a vector_shuffle node for an unpackl operation. static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT, SDValue V1, SDValue V2) { - assert(VT.is128BitVector() && "Expected a 128-bit vector type"); - unsigned NumElems = VT.getVectorNumElements(); - SmallVector<int, 8> Mask(NumElems); - for (unsigned i = 0, e = NumElems/2; i != e; ++i) { - Mask[i * 2] = i; - Mask[i * 2 + 1] = i + NumElems; - } + SmallVector<int, 8> Mask; + createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false); return DAG.getVectorShuffle(VT, dl, V1, V2, Mask); } /// Returns a vector_shuffle node for an unpackh operation. 
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, MVT VT, SDValue V1, SDValue V2) { - assert(VT.is128BitVector() && "Expected a 128-bit vector type"); - unsigned NumElems = VT.getVectorNumElements(); - SmallVector<int, 8> Mask(NumElems); - for (unsigned i = 0, Half = NumElems/2; i != Half; ++i) { - Mask[i * 2] = i + Half; - Mask[i * 2 + 1] = i + NumElems + Half; - } + SmallVector<int, 8> Mask; + createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false); return DAG.getVectorShuffle(VT, dl, V1, V2, Mask); } @@ -4745,6 +5171,135 @@ static SDValue peekThroughBitcasts(SDValue V) { return V; } +static SDValue peekThroughOneUseBitcasts(SDValue V) { + while (V.getNode() && V.getOpcode() == ISD::BITCAST && + V.getOperand(0).hasOneUse()) + V = V.getOperand(0); + return V; +} + +static const Constant *getTargetConstantFromNode(SDValue Op) { + Op = peekThroughBitcasts(Op); + + auto *Load = dyn_cast<LoadSDNode>(Op); + if (!Load) + return nullptr; + + SDValue Ptr = Load->getBasePtr(); + if (Ptr->getOpcode() == X86ISD::Wrapper || + Ptr->getOpcode() == X86ISD::WrapperRIP) + Ptr = Ptr->getOperand(0); + + auto *CNode = dyn_cast<ConstantPoolSDNode>(Ptr); + if (!CNode || CNode->isMachineConstantPoolEntry()) + return nullptr; + + return dyn_cast<Constant>(CNode->getConstVal()); +} + +// Extract raw constant bits from constant pools. +static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits, + SmallBitVector &UndefElts, + SmallVectorImpl<APInt> &EltBits) { + assert(UndefElts.empty() && "Expected an empty UndefElts vector"); + assert(EltBits.empty() && "Expected an empty EltBits vector"); + + Op = peekThroughBitcasts(Op); + + EVT VT = Op.getValueType(); + unsigned SizeInBits = VT.getSizeInBits(); + assert((SizeInBits % EltSizeInBits) == 0 && "Can't split constant!"); + unsigned NumElts = SizeInBits / EltSizeInBits; + + // Extract all the undef/constant element data and pack into single bitsets. + APInt UndefBits(SizeInBits, 0); + APInt MaskBits(SizeInBits, 0); + + // Split the undef/constant single bitset data into the target elements. + auto SplitBitData = [&]() { + UndefElts = SmallBitVector(NumElts, false); + EltBits.resize(NumElts, APInt(EltSizeInBits, 0)); + + for (unsigned i = 0; i != NumElts; ++i) { + APInt UndefEltBits = UndefBits.lshr(i * EltSizeInBits); + UndefEltBits = UndefEltBits.zextOrTrunc(EltSizeInBits); + + // Only treat an element as UNDEF if all bits are UNDEF, otherwise + // treat it as zero. + if (UndefEltBits.isAllOnesValue()) { + UndefElts[i] = true; + continue; + } + + APInt Bits = MaskBits.lshr(i * EltSizeInBits); + Bits = Bits.zextOrTrunc(EltSizeInBits); + EltBits[i] = Bits.getZExtValue(); + } + return true; + }; + + auto ExtractConstantBits = [SizeInBits](const Constant *Cst, APInt &Mask, + APInt &Undefs) { + if (!Cst) + return false; + unsigned CstSizeInBits = Cst->getType()->getPrimitiveSizeInBits(); + if (isa<UndefValue>(Cst)) { + Mask = APInt::getNullValue(SizeInBits); + Undefs = APInt::getLowBitsSet(SizeInBits, CstSizeInBits); + return true; + } + if (auto *CInt = dyn_cast<ConstantInt>(Cst)) { + Mask = CInt->getValue().zextOrTrunc(SizeInBits); + Undefs = APInt::getNullValue(SizeInBits); + return true; + } + if (auto *CFP = dyn_cast<ConstantFP>(Cst)) { + Mask = CFP->getValueAPF().bitcastToAPInt().zextOrTrunc(SizeInBits); + Undefs = APInt::getNullValue(SizeInBits); + return true; + } + return false; + }; + + // Extract constant bits from constant pool vector. 
+ if (auto *Cst = getTargetConstantFromNode(Op)) { + Type *CstTy = Cst->getType(); + if (!CstTy->isVectorTy() || (SizeInBits != CstTy->getPrimitiveSizeInBits())) + return false; + + unsigned CstEltSizeInBits = CstTy->getScalarSizeInBits(); + for (unsigned i = 0, e = CstTy->getVectorNumElements(); i != e; ++i) { + APInt Bits, Undefs; + if (!ExtractConstantBits(Cst->getAggregateElement(i), Bits, Undefs)) + return false; + MaskBits |= Bits.shl(i * CstEltSizeInBits); + UndefBits |= Undefs.shl(i * CstEltSizeInBits); + } + + return SplitBitData(); + } + + // Extract constant bits from a broadcasted constant pool scalar. + if (Op.getOpcode() == X86ISD::VBROADCAST && + EltSizeInBits <= Op.getScalarValueSizeInBits()) { + if (auto *Broadcast = getTargetConstantFromNode(Op.getOperand(0))) { + APInt Bits, Undefs; + if (ExtractConstantBits(Broadcast, Bits, Undefs)) { + unsigned NumBroadcastBits = Op.getScalarValueSizeInBits(); + unsigned NumBroadcastElts = SizeInBits / NumBroadcastBits; + for (unsigned i = 0; i != NumBroadcastElts; ++i) { + MaskBits |= Bits.shl(i * NumBroadcastBits); + UndefBits |= Undefs.shl(i * NumBroadcastBits); + } + return SplitBitData(); + } + } + } + + return false; +} + +// TODO: Merge more of this with getTargetConstantBitsFromNode. static bool getTargetShuffleMaskIndices(SDValue MaskNode, unsigned MaskEltSizeInBits, SmallVectorImpl<uint64_t> &RawMask) { @@ -4752,6 +5307,7 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode, MVT VT = MaskNode.getSimpleValueType(); assert(VT.isVector() && "Can't produce a non-vector with a build_vector!"); + unsigned NumMaskElts = VT.getSizeInBits() / MaskEltSizeInBits; // Split an APInt element into MaskEltSizeInBits sized pieces and // insert into the shuffle mask. @@ -4783,17 +5339,20 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode, if (MaskNode.getOpcode() == X86ISD::VZEXT_MOVL && MaskNode.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR) { - - // TODO: Handle (MaskEltSizeInBits % VT.getScalarSizeInBits()) == 0 - if ((VT.getScalarSizeInBits() % MaskEltSizeInBits) != 0) - return false; - unsigned ElementSplit = VT.getScalarSizeInBits() / MaskEltSizeInBits; - SDValue MaskOp = MaskNode.getOperand(0).getOperand(0); if (auto *CN = dyn_cast<ConstantSDNode>(MaskOp)) { - SplitElementToMask(CN->getAPIntValue()); - RawMask.append((VT.getVectorNumElements() - 1) * ElementSplit, 0); - return true; + if ((MaskEltSizeInBits % VT.getScalarSizeInBits()) == 0) { + RawMask.push_back(CN->getZExtValue()); + RawMask.append(NumMaskElts - 1, 0); + return true; + } + + if ((VT.getScalarSizeInBits() % MaskEltSizeInBits) == 0) { + unsigned ElementSplit = VT.getScalarSizeInBits() / MaskEltSizeInBits; + SplitElementToMask(CN->getAPIntValue()); + RawMask.append((VT.getVectorNumElements() - 1) * ElementSplit, 0); + return true; + } } return false; } @@ -4803,8 +5362,8 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode, // We can always decode if the buildvector is all zero constants, // but can't use isBuildVectorAllZeros as it might contain UNDEFs. 
- if (llvm::all_of(MaskNode->ops(), X86::isZeroNode)) { - RawMask.append(VT.getSizeInBits() / MaskEltSizeInBits, 0); + if (all_of(MaskNode->ops(), X86::isZeroNode)) { + RawMask.append(NumMaskElts, 0); return true; } @@ -4824,25 +5383,6 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode, return true; } -static const Constant *getTargetShuffleMaskConstant(SDValue MaskNode) { - MaskNode = peekThroughBitcasts(MaskNode); - - auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode); - if (!MaskLoad) - return nullptr; - - SDValue Ptr = MaskLoad->getBasePtr(); - if (Ptr->getOpcode() == X86ISD::Wrapper || - Ptr->getOpcode() == X86ISD::WrapperRIP) - Ptr = Ptr->getOperand(0); - - auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr); - if (!MaskCP || MaskCP->isMachineConstantPoolEntry()) - return nullptr; - - return dyn_cast<Constant>(MaskCP->getConstVal()); -} - /// Calculates the shuffle mask corresponding to the target-specific opcode. /// If the mask could be calculated, returns it in \p Mask, returns the shuffle /// operands in \p Ops, and returns true. @@ -4896,6 +5436,9 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); ImmN = N->getOperand(N->getNumOperands()-1); DecodePALIGNRMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); + IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); + Ops.push_back(N->getOperand(1)); + Ops.push_back(N->getOperand(0)); break; case X86ISD::VSHLDQ: assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); @@ -4947,7 +5490,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPERMILPMask(VT, RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPERMILPMask(C, MaskEltSize, Mask); break; } @@ -4961,7 +5504,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodePSHUFBMask(RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodePSHUFBMask(C, Mask); break; } @@ -5010,7 +5553,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPERMIL2PMask(VT, CtrlImm, RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPERMIL2PMask(C, CtrlImm, MaskEltSize, Mask); break; } @@ -5025,7 +5568,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPPERMMask(RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPPERMMask(C, Mask); break; } @@ -5042,8 +5585,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPERMVMask(RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { - DecodeVPERMVMask(C, VT, Mask); + if (auto *C = getTargetConstantFromNode(MaskNode)) { + DecodeVPERMVMask(C, MaskEltSize, Mask); break; } return false; @@ -5054,8 +5597,22 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, Ops.push_back(N->getOperand(0)); Ops.push_back(N->getOperand(2)); SDValue MaskNode = N->getOperand(1); - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { - DecodeVPERMV3Mask(C, VT, Mask); + unsigned MaskEltSize = VT.getScalarSizeInBits(); + if (auto *C = getTargetConstantFromNode(MaskNode)) { + 
DecodeVPERMV3Mask(C, MaskEltSize, Mask); + break; + } + return false; + } + case X86ISD::VPERMIV3: { + IsUnary = IsFakeUnary = N->getOperand(1) == N->getOperand(2); + // Unlike most shuffle nodes, VPERMIV3's mask operand is the first one. + Ops.push_back(N->getOperand(1)); + Ops.push_back(N->getOperand(2)); + SDValue MaskNode = N->getOperand(0); + unsigned MaskEltSize = VT.getScalarSizeInBits(); + if (auto *C = getTargetConstantFromNode(MaskNode)) { + DecodeVPERMV3Mask(C, MaskEltSize, Mask); break; } return false; @@ -5069,7 +5626,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, // Check if we're getting a shuffle mask with zero'd elements. if (!AllowSentinelZero) - if (llvm::any_of(Mask, [](int M) { return M == SM_SentinelZero; })) + if (any_of(Mask, [](int M) { return M == SM_SentinelZero; })) return false; // If we have a fake unary shuffle, the shuffle mask is spread across two @@ -5101,8 +5658,9 @@ static bool setTargetShuffleZeroElements(SDValue N, bool IsUnary; if (!isTargetShuffle(N.getOpcode())) return false; - if (!getTargetShuffleMask(N.getNode(), N.getSimpleValueType(), true, Ops, - Mask, IsUnary)) + + MVT VT = N.getSimpleValueType(); + if (!getTargetShuffleMask(N.getNode(), VT, true, Ops, Mask, IsUnary)) return false; SDValue V1 = Ops[0]; @@ -5164,9 +5722,94 @@ static bool setTargetShuffleZeroElements(SDValue N, } } + assert(VT.getVectorNumElements() == Mask.size() && + "Different mask size from vector size!"); return true; } +// Attempt to decode ops that could be represented as a shuffle mask. +// The decoded shuffle mask may contain a different number of elements to the +// destination value type. +static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, + SmallVectorImpl<SDValue> &Ops) { + Mask.clear(); + Ops.clear(); + + MVT VT = N.getSimpleValueType(); + unsigned NumElts = VT.getVectorNumElements(); + unsigned NumSizeInBits = VT.getSizeInBits(); + unsigned NumBitsPerElt = VT.getScalarSizeInBits(); + assert((NumBitsPerElt % 8) == 0 && (NumSizeInBits % 8) == 0 && + "Expected byte aligned value types"); + + unsigned Opcode = N.getOpcode(); + switch (Opcode) { + case ISD::AND: { + // Attempt to decode as a per-byte mask. + SmallBitVector UndefElts; + SmallVector<APInt, 32> EltBits; + if (!getTargetConstantBitsFromNode(N.getOperand(1), 8, UndefElts, EltBits)) + return false; + for (int i = 0, e = (int)EltBits.size(); i != e; ++i) { + if (UndefElts[i]) { + Mask.push_back(SM_SentinelUndef); + continue; + } + uint64_t ByteBits = EltBits[i].getZExtValue(); + if (ByteBits != 0 && ByteBits != 255) + return false; + Mask.push_back(ByteBits == 0 ? SM_SentinelZero : i); + } + Ops.push_back(N.getOperand(0)); + return true; + } + case X86ISD::VSHLI: + case X86ISD::VSRLI: { + uint64_t ShiftVal = N.getConstantOperandVal(1); + // Out of range bit shifts are guaranteed to be zero. + if (NumBitsPerElt <= ShiftVal) { + Mask.append(NumElts, SM_SentinelZero); + return true; + } + + // We can only decode 'whole byte' bit shifts as shuffles. + if ((ShiftVal % 8) != 0) + break; + + uint64_t ByteShift = ShiftVal / 8; + unsigned NumBytes = NumSizeInBits / 8; + unsigned NumBytesPerElt = NumBitsPerElt / 8; + Ops.push_back(N.getOperand(0)); + + // Clear mask to all zeros and insert the shifted byte indices. 
+ Mask.append(NumBytes, SM_SentinelZero); + + if (X86ISD::VSHLI == Opcode) { + for (unsigned i = 0; i != NumBytes; i += NumBytesPerElt) + for (unsigned j = ByteShift; j != NumBytesPerElt; ++j) + Mask[i + j] = i + j - ByteShift; + } else { + for (unsigned i = 0; i != NumBytes; i += NumBytesPerElt) + for (unsigned j = ByteShift; j != NumBytesPerElt; ++j) + Mask[i + j - ByteShift] = i + j; + } + return true; + } + case X86ISD::VZEXT: { + // TODO - add support for VPMOVZX with smaller input vector types. + SDValue Src = N.getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); + if (NumSizeInBits != SrcVT.getSizeInBits()) + break; + DecodeZeroExtendMask(SrcVT.getScalarType(), VT, Mask); + Ops.push_back(Src); + return true; + } + } + + return false; +} + /// Calls setTargetShuffleZeroElements to resolve a target shuffle mask's inputs /// and set the SM_SentinelUndef and SM_SentinelZero values. Then check the /// remaining input indices in case we now have a unary shuffle and adjust the @@ -5176,14 +5819,14 @@ static bool resolveTargetShuffleInputs(SDValue Op, SDValue &Op0, SDValue &Op1, SmallVectorImpl<int> &Mask) { SmallVector<SDValue, 2> Ops; if (!setTargetShuffleZeroElements(Op, Mask, Ops)) - return false; + if (!getFauxShuffleMask(Op, Mask, Ops)) + return false; int NumElts = Mask.size(); - bool Op0InUse = std::any_of(Mask.begin(), Mask.end(), [NumElts](int Idx) { + bool Op0InUse = any_of(Mask, [NumElts](int Idx) { return 0 <= Idx && Idx < NumElts; }); - bool Op1InUse = std::any_of(Mask.begin(), Mask.end(), - [NumElts](int Idx) { return NumElts <= Idx; }); + bool Op1InUse = any_of(Mask, [NumElts](int Idx) { return NumElts <= Idx; }); Op0 = Op0InUse ? Ops[0] : SDValue(); Op1 = Op1InUse ? Ops[1] : SDValue(); @@ -5523,15 +6166,15 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl, unsigned RequiredAlign = VT.getSizeInBits()/8; SDValue Chain = LD->getChain(); // Make sure the stack object alignment is at least 16 or 32. - MachineFrameInfo *MFI = DAG.getMachineFunction().getFrameInfo(); + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); if (DAG.InferPtrAlignment(Ptr) < RequiredAlign) { - if (MFI->isFixedObjectIndex(FI)) { + if (MFI.isFixedObjectIndex(FI)) { // Can't change the alignment. FIXME: It's possible to compute // the exact stack offset and reference FI + adjust offset instead. // If someone *really* cares about this. That's the way to implement it. return SDValue(); } else { - MFI->setObjectAlignment(FI, RequiredAlign); + MFI.setObjectAlignment(FI, RequiredAlign); } } @@ -5697,11 +6340,13 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts, int LoadSize = (1 + LastLoadedElt - FirstLoadedElt) * LDBaseVT.getStoreSizeInBits(); - // VZEXT_LOAD - consecutive load/undefs followed by zeros/undefs. - if (IsConsecutiveLoad && FirstLoadedElt == 0 && LoadSize == 64 && + // VZEXT_LOAD - consecutive 32/64-bit load/undefs followed by zeros/undefs. + if (IsConsecutiveLoad && FirstLoadedElt == 0 && + (LoadSize == 32 || LoadSize == 64) && ((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()))) { - MVT VecSVT = VT.isFloatingPoint() ? MVT::f64 : MVT::i64; - MVT VecVT = MVT::getVectorVT(VecSVT, VT.getSizeInBits() / 64); + MVT VecSVT = VT.isFloatingPoint() ? 
MVT::getFloatingPointVT(LoadSize) + : MVT::getIntegerVT(LoadSize); + MVT VecVT = MVT::getVectorVT(VecSVT, VT.getSizeInBits() / LoadSize); if (TLI.isTypeLegal(VecVT)) { SDVTList Tys = DAG.getVTList(VecVT, MVT::Other); SDValue Ops[] = { LDBase->getChain(), LDBase->getBasePtr() }; @@ -5728,31 +6373,53 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts, } } - // VZEXT_MOVL - consecutive 32-bit load/undefs followed by zeros/undefs. - if (IsConsecutiveLoad && FirstLoadedElt == 0 && LoadSize == 32 && - ((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()))) { - MVT VecSVT = VT.isFloatingPoint() ? MVT::f32 : MVT::i32; - MVT VecVT = MVT::getVectorVT(VecSVT, VT.getSizeInBits() / 32); - if (TLI.isTypeLegal(VecVT)) { - SDValue V = LastLoadedElt != 0 ? CreateLoad(VecSVT, LDBase) - : DAG.getBitcast(VecSVT, EltBase); - V = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VecVT, V); - V = DAG.getNode(X86ISD::VZEXT_MOVL, DL, VecVT, V); - return DAG.getBitcast(VT, V); - } + return SDValue(); +} + +static Constant *getConstantVector(MVT VT, APInt SplatValue, + unsigned SplatBitSize, LLVMContext &C) { + unsigned ScalarSize = VT.getScalarSizeInBits(); + unsigned NumElm = SplatBitSize / ScalarSize; + + SmallVector<Constant *, 32> ConstantVec; + for (unsigned i = 0; i < NumElm; i++) { + APInt Val = SplatValue.lshr(ScalarSize * i).trunc(ScalarSize); + Constant *Const; + if (VT.isFloatingPoint()) { + assert((ScalarSize == 32 || ScalarSize == 64) && + "Unsupported floating point scalar size"); + if (ScalarSize == 32) + Const = ConstantFP::get(Type::getFloatTy(C), Val.bitsToFloat()); + else + Const = ConstantFP::get(Type::getDoubleTy(C), Val.bitsToDouble()); + } else + Const = Constant::getIntegerValue(Type::getIntNTy(C, ScalarSize), Val); + ConstantVec.push_back(Const); } + return ConstantVector::get(ArrayRef<Constant *>(ConstantVec)); +} - return SDValue(); +static bool isUseOfShuffle(SDNode *N) { + for (auto *U : N->uses()) { + if (isTargetShuffle(U->getOpcode())) + return true; + if (U->getOpcode() == ISD::BITCAST) // Ignore bitcasts + return isUseOfShuffle(U); + } + return false; } /// Attempt to use the vbroadcast instruction to generate a splat value for the /// following cases: -/// 1. A splat BUILD_VECTOR which uses a single scalar load, or a constant. +/// 1. A splat BUILD_VECTOR which uses: +/// a. A single scalar load, or a constant. +/// b. Repeated pattern of constants (e.g. <0,1,0,1> or <0,1,2,3,0,1,2,3>). /// 2. A splat shuffle which uses a scalar_to_vector node which comes from /// a scalar load, or a constant. +/// /// The VBROADCAST node is returned when a pattern is found, /// or SDValue() otherwise. -static SDValue LowerVectorBroadcast(SDValue Op, const X86Subtarget &Subtarget, +static SDValue LowerVectorBroadcast(BuildVectorSDNode *BVOp, const X86Subtarget &Subtarget, SelectionDAG &DAG) { // VBROADCAST requires AVX. // TODO: Splats could be generated for non-AVX CPUs using SSE @@ -5760,81 +6427,103 @@ static SDValue LowerVectorBroadcast(SDValue Op, const X86Subtarget &Subtarget, if (!Subtarget.hasAVX()) return SDValue(); - MVT VT = Op.getSimpleValueType(); - SDLoc dl(Op); + MVT VT = BVOp->getSimpleValueType(0); + SDLoc dl(BVOp); assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) && "Unsupported vector type for broadcast."); - SDValue Ld; - bool ConstSplatVal; - - switch (Op.getOpcode()) { - default: - // Unknown pattern found. 
- return SDValue(); - - case ISD::BUILD_VECTOR: { - auto *BVOp = cast<BuildVectorSDNode>(Op.getNode()); - BitVector UndefElements; - SDValue Splat = BVOp->getSplatValue(&UndefElements); - - // We need a splat of a single value to use broadcast, and it doesn't - // make any sense if the value is only in one element of the vector. - if (!Splat || (VT.getVectorNumElements() - UndefElements.count()) <= 1) + BitVector UndefElements; + SDValue Ld = BVOp->getSplatValue(&UndefElements); + + // We need a splat of a single value to use broadcast, and it doesn't + // make any sense if the value is only in one element of the vector. + if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1) { + APInt SplatValue, Undef; + unsigned SplatBitSize; + bool HasUndef; + // Check if this is a repeated constant pattern suitable for broadcasting. + if (BVOp->isConstantSplat(SplatValue, Undef, SplatBitSize, HasUndef) && + SplatBitSize > VT.getScalarSizeInBits() && + SplatBitSize < VT.getSizeInBits()) { + // Avoid replacing with broadcast when it's a use of a shuffle + // instruction to preserve the present custom lowering of shuffles. + if (isUseOfShuffle(BVOp) || BVOp->hasOneUse()) return SDValue(); - - Ld = Splat; - ConstSplatVal = (Ld.getOpcode() == ISD::Constant || - Ld.getOpcode() == ISD::ConstantFP); - - // Make sure that all of the users of a non-constant load are from the - // BUILD_VECTOR node. - if (!ConstSplatVal && !BVOp->isOnlyUserOf(Ld.getNode())) - return SDValue(); - break; - } - - case ISD::VECTOR_SHUFFLE: { - ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(Op); - - // Shuffles must have a splat mask where the first element is - // broadcasted. - if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0) - return SDValue(); - - SDValue Sc = Op.getOperand(0); - if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR && - Sc.getOpcode() != ISD::BUILD_VECTOR) { - - if (!Subtarget.hasInt256()) - return SDValue(); - - // Use the register form of the broadcast instruction available on AVX2. - if (VT.getSizeInBits() >= 256) - Sc = extract128BitVector(Sc, 0, DAG, dl); - return DAG.getNode(X86ISD::VBROADCAST, dl, VT, Sc); + // replace BUILD_VECTOR with broadcast of the repeated constants. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + LLVMContext *Ctx = DAG.getContext(); + MVT PVT = TLI.getPointerTy(DAG.getDataLayout()); + if (Subtarget.hasAVX()) { + if (SplatBitSize <= 64 && Subtarget.hasAVX2() && + !(SplatBitSize == 64 && Subtarget.is32Bit())) { + // Splatted value can fit in one INTEGER constant in constant pool. + // Load the constant and broadcast it. + MVT CVT = MVT::getIntegerVT(SplatBitSize); + Type *ScalarTy = Type::getIntNTy(*Ctx, SplatBitSize); + Constant *C = Constant::getIntegerValue(ScalarTy, SplatValue); + SDValue CP = DAG.getConstantPool(C, PVT); + unsigned Repeat = VT.getSizeInBits() / SplatBitSize; + + unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment(); + Ld = DAG.getLoad( + CVT, dl, DAG.getEntryNode(), CP, + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), + Alignment); + SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl, + MVT::getVectorVT(CVT, Repeat), Ld); + return DAG.getBitcast(VT, Brdcst); + } else if (SplatBitSize == 32 || SplatBitSize == 64) { + // Splatted value can fit in one FLOAT constant in constant pool. + // Load the constant and broadcast it. + // AVX have support for 32 and 64 bit broadcast for floats only. + // No 64bit integer in 32bit subtarget. 
+ MVT CVT = MVT::getFloatingPointVT(SplatBitSize); + Constant *C = SplatBitSize == 32 + ? ConstantFP::get(Type::getFloatTy(*Ctx), + SplatValue.bitsToFloat()) + : ConstantFP::get(Type::getDoubleTy(*Ctx), + SplatValue.bitsToDouble()); + SDValue CP = DAG.getConstantPool(C, PVT); + unsigned Repeat = VT.getSizeInBits() / SplatBitSize; + + unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment(); + Ld = DAG.getLoad( + CVT, dl, DAG.getEntryNode(), CP, + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), + Alignment); + SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl, + MVT::getVectorVT(CVT, Repeat), Ld); + return DAG.getBitcast(VT, Brdcst); + } else if (SplatBitSize > 64) { + // Load the vector of constants and broadcast it. + MVT CVT = VT.getScalarType(); + Constant *VecC = getConstantVector(VT, SplatValue, SplatBitSize, + *Ctx); + SDValue VCP = DAG.getConstantPool(VecC, PVT); + unsigned NumElm = SplatBitSize / VT.getScalarSizeInBits(); + unsigned Alignment = cast<ConstantPoolSDNode>(VCP)->getAlignment(); + Ld = DAG.getLoad( + MVT::getVectorVT(CVT, NumElm), dl, DAG.getEntryNode(), VCP, + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), + Alignment); + SDValue Brdcst = DAG.getNode(X86ISD::SUBV_BROADCAST, dl, VT, Ld); + return DAG.getBitcast(VT, Brdcst); + } } - - Ld = Sc.getOperand(0); - ConstSplatVal = (Ld.getOpcode() == ISD::Constant || - Ld.getOpcode() == ISD::ConstantFP); - - // The scalar_to_vector node and the suspected - // load node must have exactly one user. - // Constants may have multiple users. - - // AVX-512 has register version of the broadcast - bool hasRegVer = Subtarget.hasAVX512() && VT.is512BitVector() && - Ld.getValueType().getSizeInBits() >= 32; - if (!ConstSplatVal && ((!Sc.hasOneUse() || !Ld.hasOneUse()) && - !hasRegVer)) - return SDValue(); - break; } + return SDValue(); } - unsigned ScalarSize = Ld.getValueType().getSizeInBits(); + bool ConstSplatVal = + (Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP); + + // Make sure that all of the users of a non-constant load are from the + // BUILD_VECTOR node. + if (!ConstSplatVal && !BVOp->isOnlyUserOf(Ld.getNode())) + return SDValue(); + + unsigned ScalarSize = Ld.getValueSizeInBits(); bool IsGE256 = (VT.getSizeInBits() >= 256); // When optimizing for size, generate up to 5 extra bytes for a broadcast @@ -6025,8 +6714,7 @@ static SDValue ConvertI1VectorToInteger(SDValue Op, SelectionDAG &DAG) { Immediate |= cast<ConstantSDNode>(In)->getZExtValue() << idx; } SDLoc dl(Op); - MVT VT = - MVT::getIntegerVT(std::max((int)Op.getValueType().getSizeInBits(), 8)); + MVT VT = MVT::getIntegerVT(std::max((int)Op.getValueSizeInBits(), 8)); return DAG.getConstant(Immediate, dl, VT); } // Lower BUILD_VECTOR operation for v8i1 and v16i1 types. @@ -6273,23 +6961,24 @@ static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1, return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI); } -/// Try to fold a build_vector that performs an 'addsub' to an X86ISD::ADDSUB -/// node. -static SDValue LowerToAddSub(const BuildVectorSDNode *BV, - const X86Subtarget &Subtarget, SelectionDAG &DAG) { +/// Returns true iff \p BV builds a vector with the result equivalent to +/// the result of ADDSUB operation. +/// If true is returned then the operands of ADDSUB = Opnd0 +- Opnd1 operation +/// are written to the parameters \p Opnd0 and \p Opnd1. 
+static bool isAddSub(const BuildVectorSDNode *BV, + const X86Subtarget &Subtarget, SelectionDAG &DAG, + SDValue &Opnd0, SDValue &Opnd1) { + MVT VT = BV->getSimpleValueType(0); if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) && - (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64))) - return SDValue(); + (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64)) && + (!Subtarget.hasAVX512() || (VT != MVT::v16f32 && VT != MVT::v8f64))) + return false; - SDLoc DL(BV); unsigned NumElts = VT.getVectorNumElements(); SDValue InVec0 = DAG.getUNDEF(VT); SDValue InVec1 = DAG.getUNDEF(VT); - assert((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v4f32 || - VT == MVT::v2f64) && "build_vector with an invalid type found!"); - // Odd-numbered elements in the input build vector are obtained from // adding two integer/float elements. // Even-numbered elements in the input build vector are obtained from @@ -6311,7 +7000,7 @@ static SDValue LowerToAddSub(const BuildVectorSDNode *BV, // Early exit if we found an unexpected opcode. if (Opcode != ExpectedOpcode) - return SDValue(); + return false; SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); @@ -6324,11 +7013,11 @@ static SDValue LowerToAddSub(const BuildVectorSDNode *BV, !isa<ConstantSDNode>(Op0.getOperand(1)) || !isa<ConstantSDNode>(Op1.getOperand(1)) || Op0.getOperand(1) != Op1.getOperand(1)) - return SDValue(); + return false; unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue(); if (I0 != i) - return SDValue(); + return false; // We found a valid add/sub node. Update the information accordingly. if (i & 1) @@ -6340,39 +7029,118 @@ static SDValue LowerToAddSub(const BuildVectorSDNode *BV, if (InVec0.isUndef()) { InVec0 = Op0.getOperand(0); if (InVec0.getSimpleValueType() != VT) - return SDValue(); + return false; } if (InVec1.isUndef()) { InVec1 = Op1.getOperand(0); if (InVec1.getSimpleValueType() != VT) - return SDValue(); + return false; } // Make sure that operands in input to each add/sub node always // come from a same pair of vectors. if (InVec0 != Op0.getOperand(0)) { if (ExpectedOpcode == ISD::FSUB) - return SDValue(); + return false; // FADD is commutable. Try to commute the operands // and then test again. std::swap(Op0, Op1); if (InVec0 != Op0.getOperand(0)) - return SDValue(); + return false; } if (InVec1 != Op1.getOperand(0)) - return SDValue(); + return false; // Update the pair of expected opcodes. std::swap(ExpectedOpcode, NextExpectedOpcode); } // Don't try to fold this build_vector into an ADDSUB if the inputs are undef. - if (AddFound && SubFound && !InVec0.isUndef() && !InVec1.isUndef()) - return DAG.getNode(X86ISD::ADDSUB, DL, VT, InVec0, InVec1); + if (!AddFound || !SubFound || InVec0.isUndef() || InVec1.isUndef()) + return false; - return SDValue(); + Opnd0 = InVec0; + Opnd1 = InVec1; + return true; +} + +/// Returns true if is possible to fold MUL and an idiom that has already been +/// recognized as ADDSUB(\p Opnd0, \p Opnd1) into FMADDSUB(x, y, \p Opnd1). +/// If (and only if) true is returned, the operands of FMADDSUB are written to +/// parameters \p Opnd0, \p Opnd1, \p Opnd2. +/// +/// Prior to calling this function it should be known that there is some +/// SDNode that potentially can be replaced with an X86ISD::ADDSUB operation +/// using \p Opnd0 and \p Opnd1 as operands. Also, this method is called +/// before replacement of such SDNode with ADDSUB operation. Thus the number +/// of \p Opnd0 uses is expected to be equal to 2. 
+/// For example, this function may be called for the following IR: +/// %AB = fmul fast <2 x double> %A, %B +/// %Sub = fsub fast <2 x double> %AB, %C +/// %Add = fadd fast <2 x double> %AB, %C +/// %Addsub = shufflevector <2 x double> %Sub, <2 x double> %Add, +/// <2 x i32> <i32 0, i32 3> +/// There is a def for %Addsub here, which potentially can be replaced by +/// X86ISD::ADDSUB operation: +/// %Addsub = X86ISD::ADDSUB %AB, %C +/// and such ADDSUB can further be replaced with FMADDSUB: +/// %Addsub = FMADDSUB %A, %B, %C. +/// +/// The main reason why this method is called before the replacement of the +/// recognized ADDSUB idiom with ADDSUB operation is that such replacement +/// is illegal sometimes. E.g. 512-bit ADDSUB is not available, while 512-bit +/// FMADDSUB is. +static bool isFMAddSub(const X86Subtarget &Subtarget, SelectionDAG &DAG, + SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2) { + if (Opnd0.getOpcode() != ISD::FMUL || Opnd0->use_size() != 2 || + !Subtarget.hasAnyFMA()) + return false; + + // FIXME: These checks must match the similar ones in + // DAGCombiner::visitFADDForFMACombine. It would be good to have one + // function that would answer if it is Ok to fuse MUL + ADD to FMADD + // or MUL + ADDSUB to FMADDSUB. + const TargetOptions &Options = DAG.getTarget().Options; + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); + if (!AllowFusion) + return false; + + Opnd2 = Opnd1; + Opnd1 = Opnd0.getOperand(1); + Opnd0 = Opnd0.getOperand(0); + + return true; +} + +/// Try to fold a build_vector that performs an 'addsub' or 'fmaddsub' operation +/// accordingly to X86ISD::ADDSUB or X86ISD::FMADDSUB node. +static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue Opnd0, Opnd1; + if (!isAddSub(BV, Subtarget, DAG, Opnd0, Opnd1)) + return SDValue(); + + MVT VT = BV->getSimpleValueType(0); + SDLoc DL(BV); + + // Try to generate X86ISD::FMADDSUB node here. + SDValue Opnd2; + if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2)) + return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); + + // Do not generate X86ISD::ADDSUB node for 512-bit types even though + // the ADDSUB idiom has been successfully recognized. There are no known + // X86 targets with 512-bit ADDSUB instructions! + // 512-bit ADDSUB idiom recognition was needed only as part of FMADDSUB idiom + // recognition. + if (VT.is512BitVector()) + return SDValue(); + + return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1); } /// Lower BUILD_VECTOR to a horizontal add/sub operation if possible. @@ -6510,17 +7278,18 @@ static SDValue LowerToHorizontalOp(const BuildVectorSDNode *BV, /// NOTE: Its not in our interest to start make a general purpose vectorizer /// from this, but enough scalar bit operations are created from the later /// legalization + scalarization stages to need basic support. -static SDValue lowerBuildVectorToBitOp(SDValue Op, SelectionDAG &DAG) { +static SDValue lowerBuildVectorToBitOp(BuildVectorSDNode *Op, + SelectionDAG &DAG) { SDLoc DL(Op); - MVT VT = Op.getSimpleValueType(); + MVT VT = Op->getSimpleValueType(0); unsigned NumElems = VT.getVectorNumElements(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // Check that all elements have the same opcode. // TODO: Should we allow UNDEFS and if so how many? 
- unsigned Opcode = Op.getOperand(0).getOpcode(); + unsigned Opcode = Op->getOperand(0).getOpcode(); for (unsigned i = 1; i < NumElems; ++i) - if (Opcode != Op.getOperand(i).getOpcode()) + if (Opcode != Op->getOperand(i).getOpcode()) return SDValue(); // TODO: We may be able to add support for other Ops (ADD/SUB + shifts). @@ -6600,13 +7369,13 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { return VectorConstant; BuildVectorSDNode *BV = cast<BuildVectorSDNode>(Op.getNode()); - if (SDValue AddSub = LowerToAddSub(BV, Subtarget, DAG)) + if (SDValue AddSub = lowerToAddSubOrFMAddSub(BV, Subtarget, DAG)) return AddSub; if (SDValue HorizontalOp = LowerToHorizontalOp(BV, Subtarget, DAG)) return HorizontalOp; - if (SDValue Broadcast = LowerVectorBroadcast(Op, Subtarget, DAG)) + if (SDValue Broadcast = LowerVectorBroadcast(BV, Subtarget, DAG)) return Broadcast; - if (SDValue BitOp = lowerBuildVectorToBitOp(Op, DAG)) + if (SDValue BitOp = lowerBuildVectorToBitOp(BV, DAG)) return BitOp; unsigned EVTBits = ExtVT.getSizeInBits(); @@ -6673,12 +7442,8 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { if (ExtVT == MVT::i32 || ExtVT == MVT::f32 || ExtVT == MVT::f64 || (ExtVT == MVT::i64 && Subtarget.is64Bit())) { - if (VT.is512BitVector()) { - SDValue ZeroVec = getZeroVector(VT, Subtarget, DAG, dl); - return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, ZeroVec, - Item, DAG.getIntPtrConstant(0, dl)); - } - assert((VT.is128BitVector() || VT.is256BitVector()) && + assert((VT.is128BitVector() || VT.is256BitVector() || + VT.is512BitVector()) && "Expected an SSE value type!"); Item = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Item); // Turn it into a MOVL (i.e. movss, movsd, or movd) to a zero vector. @@ -7088,6 +7853,7 @@ static bool isRepeatedShuffleMask(unsigned LaneSizeInBits, MVT VT, RepeatedMask.assign(LaneSize, -1); int Size = Mask.size(); for (int i = 0; i < Size; ++i) { + assert(Mask[i] == SM_SentinelUndef || Mask[i] >= 0); if (Mask[i] < 0) continue; if ((Mask[i] % Size) / LaneSize != i / LaneSize) @@ -7122,26 +7888,40 @@ is256BitLaneRepeatedShuffleMask(MVT VT, ArrayRef<int> Mask, return isRepeatedShuffleMask(256, VT, Mask, RepeatedMask); } -static void scaleShuffleMask(int Scale, ArrayRef<int> Mask, - SmallVectorImpl<int> &ScaledMask) { - assert(0 < Scale && "Unexpected scaling factor"); - int NumElts = Mask.size(); - ScaledMask.assign(NumElts * Scale, -1); - - for (int i = 0; i != NumElts; ++i) { - int M = Mask[i]; - - // Repeat sentinel values in every mask element. - if (M < 0) { - for (int s = 0; s != Scale; ++s) - ScaledMask[(Scale * i) + s] = M; +/// Test whether a target shuffle mask is equivalent within each sub-lane. +/// Unlike isRepeatedShuffleMask we must respect SM_SentinelZero. +static bool isRepeatedTargetShuffleMask(unsigned LaneSizeInBits, MVT VT, + ArrayRef<int> Mask, + SmallVectorImpl<int> &RepeatedMask) { + int LaneSize = LaneSizeInBits / VT.getScalarSizeInBits(); + RepeatedMask.assign(LaneSize, SM_SentinelUndef); + int Size = Mask.size(); + for (int i = 0; i < Size; ++i) { + assert(isUndefOrZero(Mask[i]) || (Mask[i] >= 0)); + if (Mask[i] == SM_SentinelUndef) + continue; + if (Mask[i] == SM_SentinelZero) { + if (!isUndefOrZero(RepeatedMask[i % LaneSize])) + return false; + RepeatedMask[i % LaneSize] = SM_SentinelZero; continue; } + if ((Mask[i] % Size) / LaneSize != i / LaneSize) + // This entry crosses lanes, so there is no way to model this shuffle. 
+      return false;
 
-    // Scale mask element and increment across each mask element.
-    for (int s = 0; s != Scale; ++s)
-      ScaledMask[(Scale * i) + s] = (Scale * M) + s;
+    // Ok, handle the in-lane shuffles by detecting if and when they repeat.
+    // Adjust second vector indices to start at LaneSize instead of Size.
+    int LocalM =
+        Mask[i] < Size ? Mask[i] % LaneSize : Mask[i] % LaneSize + LaneSize;
+    if (RepeatedMask[i % LaneSize] == SM_SentinelUndef)
+      // This is the first non-undef entry in this slot of a 128-bit lane.
+      RepeatedMask[i % LaneSize] = LocalM;
+    else if (RepeatedMask[i % LaneSize] != LocalM)
+      // Found a mismatch with the repeated mask.
+      return false;
   }
+  return true;
 }
 
 /// \brief Checks whether a shuffle mask is equivalent to an explicit list of
@@ -7251,7 +8031,7 @@ static SmallBitVector computeZeroableShuffleElements(ArrayRef<int> Mask,
   bool V1IsZero = ISD::isBuildVectorAllZeros(V1.getNode());
   bool V2IsZero = ISD::isBuildVectorAllZeros(V2.getNode());
 
-  int VectorSizeInBits = V1.getValueType().getSizeInBits();
+  int VectorSizeInBits = V1.getValueSizeInBits();
   int ScalarSizeInBits = VectorSizeInBits / Mask.size();
   assert(!(VectorSizeInBits % ScalarSizeInBits) && "Illegal shuffle mask size");
 
@@ -7309,11 +8089,42 @@ static SmallBitVector computeZeroableShuffleElements(ArrayRef<int> Mask,
   return Zeroable;
 }
 
-/// Try to lower a shuffle with a single PSHUFB of V1.
-/// This is only possible if V2 is unused (at all, or only for zero elements).
+// The shuffle result is as follows:
+// 0*a[0]0*a[1]...0*a[n], n >= 0, where the a[] elements are in ascending order.
+// Each element of Zeroable corresponds to a particular element of Mask, as
+// described in the computeZeroableShuffleElements function.
+//
+// The function looks for a sub-mask whose non-zero elements are in
+// increasing order. If such a sub-mask exists, the function returns true.
+static bool isNonZeroElementsInOrder(const SmallBitVector Zeroable,
+                                     ArrayRef<int> Mask, const EVT &VectorType,
+                                     bool &IsZeroSideLeft) {
+  int NextElement = -1;
+  // Check if the Mask's nonzero elements are in increasing order.
+  for (int i = 0, e = Zeroable.size(); i < e; i++) {
+    // Check that the mask's zero elements are built from only zeros.
+    if (Mask[i] == -1)
+      return false;
+    if (Zeroable[i])
+      continue;
+    // Find the lowest non-zero element.
+    if (NextElement == -1) {
+      NextElement = Mask[i] != 0 ? VectorType.getVectorNumElements() : 0;
+      IsZeroSideLeft = NextElement != 0;
+    }
+    // Exit if the mask's non-zero elements are not in increasing order.
+    if (NextElement != Mask[i])
+      return false;
+    NextElement++;
+  }
+  return true;
+}
+
+/// Try to lower a shuffle with a single PSHUFB of V1 or V2.
 static SDValue lowerVectorShuffleWithPSHUFB(const SDLoc &DL, MVT VT,
                                             ArrayRef<int> Mask, SDValue V1,
                                             SDValue V2,
+                                            const SmallBitVector &Zeroable,
                                             const X86Subtarget &Subtarget,
                                             SelectionDAG &DAG) {
   int Size = Mask.size();
@@ -7325,12 +8136,11 @@ static SDValue lowerVectorShuffleWithPSHUFB(const SDLoc &DL, MVT VT,
          (Subtarget.hasAVX2() && VT.is256BitVector()) ||
          (Subtarget.hasBWI() && VT.is512BitVector()));
 
-  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
-
   SmallVector<SDValue, 64> PSHUFBMask(NumBytes);
 
   // Sign bit set in i8 mask means zero element.
   SDValue ZeroMask = DAG.getConstant(0x80, DL, MVT::i8);
+  SDValue V;
   for (int i = 0; i < NumBytes; ++i) {
     int M = Mask[i / NumEltBytes];
     if (M < 0) {
@@ -7341,9 +8151,13 @@ static SDValue lowerVectorShuffleWithPSHUFB(const SDLoc &DL, MVT VT,
       PSHUFBMask[i] = ZeroMask;
       continue;
     }
-    // Only allow V1.
-    if (M >= Size)
+
+    // We can only use a single input of V1 or V2.
+    SDValue SrcV = (M >= Size ? V2 : V1);
+    if (V && V != SrcV)
       return SDValue();
+    V = SrcV;
+    M %= Size;
 
     // PSHUFB can't cross lanes, ensure this doesn't happen.
     if ((M / LaneSize) != ((i / NumEltBytes) / LaneSize))
@@ -7353,33 +8167,66 @@ static SDValue lowerVectorShuffleWithPSHUFB(const SDLoc &DL, MVT VT,
     M = M * NumEltBytes + (i % NumEltBytes);
     PSHUFBMask[i] = DAG.getConstant(M, DL, MVT::i8);
   }
+  assert(V && "Failed to find a source input");
 
   MVT I8VT = MVT::getVectorVT(MVT::i8, NumBytes);
   return DAG.getBitcast(
-      VT, DAG.getNode(X86ISD::PSHUFB, DL, I8VT, DAG.getBitcast(I8VT, V1),
+      VT, DAG.getNode(X86ISD::PSHUFB, DL, I8VT, DAG.getBitcast(I8VT, V),
                       DAG.getBuildVector(I8VT, DL, PSHUFBMask)));
 }
 
+static SDValue getMaskNode(SDValue Mask, MVT MaskVT,
+                           const X86Subtarget &Subtarget, SelectionDAG &DAG,
+                           const SDLoc &dl);
+
+// Convert a SmallBitVector to an unsigned integer mask: bit i of the result
+// is set when element i is *not* zeroable, i.e. the returned value is
+// not(Zeroable) packed into an unsigned.
+static unsigned convertBitVectorToUnsiged(const SmallBitVector &Zeroable) {
+  unsigned convertBit = 0;
+  for (int i = 0, e = Zeroable.size(); i < e; i++)
+    convertBit |= !(Zeroable[i]) << i;
+  return convertBit;
+}
+
+// X86 has a dedicated shuffle pattern that can be lowered to VEXPAND.
+static SDValue lowerVectorShuffleToEXPAND(const SDLoc &DL, MVT VT,
+                                          const SmallBitVector &Zeroable,
+                                          ArrayRef<int> Mask, SDValue &V1,
+                                          SDValue &V2, SelectionDAG &DAG,
+                                          const X86Subtarget &Subtarget) {
+  bool IsLeftZeroSide = true;
+  if (!isNonZeroElementsInOrder(Zeroable, Mask, V1.getValueType(),
+                                IsLeftZeroSide))
+    return SDValue();
+  unsigned VEXPANDMask = convertBitVectorToUnsiged(Zeroable);
+  MVT IntegerType =
+      MVT::getIntegerVT(std::max((int)VT.getVectorNumElements(), 8));
+  SDValue MaskNode = DAG.getConstant(VEXPANDMask, DL, IntegerType);
+  unsigned NumElts = VT.getVectorNumElements();
+  assert((NumElts == 4 || NumElts == 8 || NumElts == 16) &&
+         "Unexpected number of vector elements");
+  SDValue VMask = getMaskNode(MaskNode, MVT::getVectorVT(MVT::i1, NumElts),
+                              Subtarget, DAG, DL);
+  SDValue ZeroVector = getZeroVector(VT, Subtarget, DAG, DL);
+  SDValue ExpandedVector = IsLeftZeroSide ? V2 : V1;
+  return DAG.getNode(ISD::VSELECT, DL, VT, VMask,
+                     DAG.getNode(X86ISD::EXPAND, DL, VT, ExpandedVector),
+                     ZeroVector);
+}
+
 // X86 has dedicated unpack instructions that can handle specific blend
 // operations: UNPCKH and UNPCKL.
static SDValue lowerVectorShuffleWithUNPCK(const SDLoc &DL, MVT VT, ArrayRef<int> Mask, SDValue V1, SDValue V2, SelectionDAG &DAG) { - int NumElts = VT.getVectorNumElements(); - int NumEltsInLane = 128 / VT.getScalarSizeInBits(); - SmallVector<int, 8> Unpckl(NumElts); - SmallVector<int, 8> Unpckh(NumElts); - - for (int i = 0; i < NumElts; ++i) { - unsigned LaneStart = (i / NumEltsInLane) * NumEltsInLane; - int LoPos = (i % NumEltsInLane) / 2 + LaneStart + NumElts * (i % 2); - int HiPos = LoPos + NumEltsInLane / 2; - Unpckl[i] = LoPos; - Unpckh[i] = HiPos; - } - + SmallVector<int, 8> Unpckl; + createUnpackShuffleMask(VT, Unpckl, /* Lo = */ true, /* Unary = */ false); if (isShuffleEquivalent(V1, V2, Mask, Unpckl)) return DAG.getNode(X86ISD::UNPCKL, DL, VT, V1, V2); + + SmallVector<int, 8> Unpckh; + createUnpackShuffleMask(VT, Unpckh, /* Lo = */ false, /* Unary = */ false); if (isShuffleEquivalent(V1, V2, Mask, Unpckh)) return DAG.getNode(X86ISD::UNPCKH, DL, VT, V1, V2); @@ -7401,19 +8248,14 @@ static SDValue lowerVectorShuffleWithUNPCK(const SDLoc &DL, MVT VT, /// one of the inputs being zeroable. static SDValue lowerVectorShuffleAsBitMask(const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SelectionDAG &DAG) { + assert(!VT.isFloatingPoint() && "Floating point types are not supported"); MVT EltVT = VT.getVectorElementType(); - int NumEltBits = EltVT.getSizeInBits(); - MVT IntEltVT = MVT::getIntegerVT(NumEltBits); - SDValue Zero = DAG.getConstant(0, DL, IntEltVT); - SDValue AllOnes = DAG.getConstant(APInt::getAllOnesValue(NumEltBits), DL, - IntEltVT); - if (EltVT.isFloatingPoint()) { - Zero = DAG.getBitcast(EltVT, Zero); - AllOnes = DAG.getBitcast(EltVT, AllOnes); - } + SDValue Zero = DAG.getConstant(0, DL, EltVT); + SDValue AllOnes = + DAG.getConstant(APInt::getAllOnesValue(EltVT.getSizeInBits()), DL, EltVT); SmallVector<SDValue, 16> VMaskOps(Mask.size(), Zero); - SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2); SDValue V; for (int i = 0, Size = Mask.size(); i < Size; ++i) { if (Zeroable[i]) @@ -7431,10 +8273,7 @@ static SDValue lowerVectorShuffleAsBitMask(const SDLoc &DL, MVT VT, SDValue V1, return SDValue(); // No non-zeroable elements! SDValue VMask = DAG.getBuildVector(VT, DL, VMaskOps); - V = DAG.getNode(VT.isFloatingPoint() - ? (unsigned) X86ISD::FAND : (unsigned) ISD::AND, - DL, VT, V, VMask); - return V; + return DAG.getNode(ISD::AND, DL, VT, V, VMask); } /// \brief Try to emit a blend instruction for a shuffle using bit math. @@ -7476,12 +8315,12 @@ static SDValue lowerVectorShuffleAsBitBlend(const SDLoc &DL, MVT VT, SDValue V1, /// that the shuffle mask is a blend, or convertible into a blend with zero. static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Original, + const SmallBitVector &Zeroable, const X86Subtarget &Subtarget, SelectionDAG &DAG) { bool V1IsZero = ISD::isBuildVectorAllZeros(V1.getNode()); bool V2IsZero = ISD::isBuildVectorAllZeros(V2.getNode()); SmallVector<int, 8> Mask(Original.begin(), Original.end()); - SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2); bool ForceV1Zero = false, ForceV2Zero = false; // Attempt to generate the binary blend mask. 
If an input is zero then @@ -7540,7 +8379,7 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, case MVT::v4i64: case MVT::v8i32: assert(Subtarget.hasAVX2() && "256-bit integer blends require AVX2!"); - // FALLTHROUGH + LLVM_FALLTHROUGH; case MVT::v2i64: case MVT::v4i32: // If we have AVX2 it is faster to use VPBLENDD when the shuffle fits into @@ -7556,7 +8395,7 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, VT, DAG.getNode(X86ISD::BLENDI, DL, BlendVT, V1, V2, DAG.getConstant(BlendMask, DL, MVT::i8))); } - // FALLTHROUGH + LLVM_FALLTHROUGH; case MVT::v8i16: { // For integer shuffles we need to expand the mask and cast the inputs to // v8i16s prior to blending. @@ -7582,15 +8421,16 @@ static SDValue lowerVectorShuffleAsBlend(const SDLoc &DL, MVT VT, SDValue V1, return DAG.getNode(X86ISD::BLENDI, DL, MVT::v16i16, V1, V2, DAG.getConstant(BlendMask, DL, MVT::i8)); } + LLVM_FALLTHROUGH; } - // FALLTHROUGH case MVT::v16i8: case MVT::v32i8: { assert((VT.is128BitVector() || Subtarget.hasAVX2()) && "256-bit byte-blends require AVX2 support!"); // Attempt to lower to a bitmask if we can. VPAND is faster than VPBLENDVB. - if (SDValue Masked = lowerVectorShuffleAsBitMask(DL, VT, V1, V2, Mask, DAG)) + if (SDValue Masked = + lowerVectorShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable, DAG)) return Masked; // Scale the blend by the number of bytes per element. @@ -7704,32 +8544,12 @@ static SDValue lowerVectorShuffleAsDecomposedShuffleBlend(const SDLoc &DL, return DAG.getVectorShuffle(VT, DL, V1, V2, BlendMask); } -/// \brief Try to lower a vector shuffle as a byte rotation. -/// -/// SSSE3 has a generic PALIGNR instruction in x86 that will do an arbitrary -/// byte-rotation of the concatenation of two vectors; pre-SSSE3 can use -/// a PSRLDQ/PSLLDQ/POR pattern to get a similar effect. This routine will -/// try to generically lower a vector shuffle through such an pattern. It -/// does not check for the profitability of lowering either as PALIGNR or -/// PSRLDQ/PSLLDQ/POR, only whether the mask is valid to lower in that form. -/// This matches shuffle vectors that look like: -/// -/// v8i16 [11, 12, 13, 14, 15, 0, 1, 2] +/// \brief Try to lower a vector shuffle as a rotation. /// -/// Essentially it concatenates V1 and V2, shifts right by some number of -/// elements, and takes the low elements as the result. Note that while this is -/// specified as a *right shift* because x86 is little-endian, it is a *left -/// rotate* of the vector lanes. -static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT, - SDValue V1, SDValue V2, - ArrayRef<int> Mask, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - assert(!isNoopShuffleMask(Mask) && "We shouldn't lower no-op shuffles!"); - +/// This is used for support PALIGNR for SSSE3 or VALIGND/Q for AVX512. +static int matchVectorShuffleAsRotate(SDValue &V1, SDValue &V2, + ArrayRef<int> Mask) { int NumElts = Mask.size(); - int NumLanes = VT.getSizeInBits() / 128; - int NumLaneElts = NumElts / NumLanes; // We need to detect various ways of spelling a rotation: // [11, 12, 13, 14, 15, 0, 1, 2] @@ -7740,51 +8560,46 @@ static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT, // [-1, 4, 5, 6, -1, -1, -1, -1] int Rotation = 0; SDValue Lo, Hi; - for (int l = 0; l < NumElts; l += NumLaneElts) { - for (int i = 0; i < NumLaneElts; ++i) { - if (Mask[l + i] < 0) - continue; - - // Get the mod-Size index and lane correct it. 
-      int LaneIdx = (Mask[l + i] % NumElts) - l;
-      // Make sure it was in this lane.
-      if (LaneIdx < 0 || LaneIdx >= NumLaneElts)
-        return SDValue();
+  for (int i = 0; i < NumElts; ++i) {
+    int M = Mask[i];
+    assert((M == SM_SentinelUndef || (0 <= M && M < (2*NumElts))) &&
+           "Unexpected mask index.");
+    if (M < 0)
+      continue;
 
-      // Determine where a rotated vector would have started.
-      int StartIdx = i - LaneIdx;
-      if (StartIdx == 0)
-        // The identity rotation isn't interesting, stop.
-        return SDValue();
+    // Determine where a rotated vector would have started.
+    int StartIdx = i - (M % NumElts);
+    if (StartIdx == 0)
+      // The identity rotation isn't interesting, stop.
+      return -1;
 
-      // If we found the tail of a vector the rotation must be the missing
-      // front. If we found the head of a vector, it must be how much of the
-      // head.
-      int CandidateRotation = StartIdx < 0 ? -StartIdx : NumLaneElts - StartIdx;
+    // If we found the tail of a vector the rotation must be the missing
+    // front. If we found the head of a vector, it must be how much of the
+    // head.
+    int CandidateRotation = StartIdx < 0 ? -StartIdx : NumElts - StartIdx;
 
-      if (Rotation == 0)
-        Rotation = CandidateRotation;
-      else if (Rotation != CandidateRotation)
-        // The rotations don't match, so we can't match this mask.
-        return SDValue();
+    if (Rotation == 0)
+      Rotation = CandidateRotation;
+    else if (Rotation != CandidateRotation)
+      // The rotations don't match, so we can't match this mask.
+      return -1;
 
-      // Compute which value this mask is pointing at.
-      SDValue MaskV = Mask[l + i] < NumElts ? V1 : V2;
-
-      // Compute which of the two target values this index should be assigned
-      // to. This reflects whether the high elements are remaining or the low
-      // elements are remaining.
-      SDValue &TargetV = StartIdx < 0 ? Hi : Lo;
-
-      // Either set up this value if we've not encountered it before, or check
-      // that it remains consistent.
-      if (!TargetV)
-        TargetV = MaskV;
-      else if (TargetV != MaskV)
-        // This may be a rotation, but it pulls from the inputs in some
-        // unsupported interleaving.
-        return SDValue();
-    }
+    // Compute which value this mask is pointing at.
+    SDValue MaskV = M < NumElts ? V1 : V2;
+
+    // Compute which of the two target values this index should be assigned
+    // to. This reflects whether the high elements are remaining or the low
+    // elements are remaining.
+    SDValue &TargetV = StartIdx < 0 ? Hi : Lo;
+
+    // Either set up this value if we've not encountered it before, or check
+    // that it remains consistent.
+    if (!TargetV)
+      TargetV = MaskV;
+    else if (TargetV != MaskV)
+      // This may be a rotation, but it pulls from the inputs in some
+      // unsupported interleaving.
+      return -1;
   }
 
   // Check that we successfully analyzed the mask, and normalize the results.
@@ -7795,23 +8610,75 @@ static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT,
   else if (!Hi)
     Hi = Lo;
 
+  V1 = Lo;
+  V2 = Hi;
+
+  return Rotation;
+}
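The index math above is easier to see on a bare integer mask. A simplified sketch in plain C++ (it keeps only the rotation-amount logic and drops the Lo/Hi operand tracking the real routine performs):

#include <cstdio>
#include <vector>

// Return the rotation amount R such that the mask reads elements
// [R, R+1, ...] of the concatenation V1:V2, or -1 if it is not a rotation.
static int matchRotate(const std::vector<int> &Mask) {
  int N = (int)Mask.size();
  int Rotation = 0;
  for (int i = 0; i < N; ++i) {
    if (Mask[i] < 0)
      continue;                        // undef lanes match anything
    int StartIdx = i - (Mask[i] % N);  // where a rotated vector would begin
    if (StartIdx == 0)
      return -1;                       // identity rotation, not interesting
    int Candidate = StartIdx < 0 ? -StartIdx : N - StartIdx;
    if (Rotation == 0)
      Rotation = Candidate;
    else if (Rotation != Candidate)
      return -1;                       // inconsistent, not a rotation
  }
  return Rotation;
}

int main() {
  // v8i16 [11, 12, 13, 14, 15, 0, 1, 2] rotates V1:V2 right by 3 elements.
  std::printf("%d\n", matchRotate({11, 12, 13, 14, 15, 0, 1, 2})); // 3
  // Undef positions are free, so this also matches rotation 3.
  std::printf("%d\n", matchRotate({-1, 4, 5, 6, -1, -1, -1, -1})); // 3
}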
+
+/// \brief Try to lower a vector shuffle as a byte rotation.
+///
+/// SSSE3 has a generic PALIGNR instruction in x86 that will do an arbitrary
+/// byte-rotation of the concatenation of two vectors; pre-SSSE3 can use
+/// a PSRLDQ/PSLLDQ/POR pattern to get a similar effect. This routine will
+/// try to generically lower a vector shuffle through such a pattern. It
+/// does not check for the profitability of lowering either as PALIGNR or
+/// PSRLDQ/PSLLDQ/POR, only whether the mask is valid to lower in that form.
+/// This matches shuffle vectors that look like:
+///
+///   v8i16 [11, 12, 13, 14, 15, 0, 1, 2]
+///
+/// Essentially it concatenates V1 and V2, shifts right by some number of
+/// elements, and takes the low elements as the result. Note that while this is
+/// specified as a *right shift* because x86 is little-endian, it is a *left
+/// rotate* of the vector lanes.
+static int matchVectorShuffleAsByteRotate(MVT VT, SDValue &V1, SDValue &V2,
+                                          ArrayRef<int> Mask) {
+  // Don't accept any shuffles with zero elements.
+  if (any_of(Mask, [](int M) { return M == SM_SentinelZero; }))
+    return -1;
+
+  // PALIGNR works on 128-bit lanes.
+  SmallVector<int, 16> RepeatedMask;
+  if (!is128BitLaneRepeatedShuffleMask(VT, Mask, RepeatedMask))
+    return -1;
+
+  int Rotation = matchVectorShuffleAsRotate(V1, V2, RepeatedMask);
+  if (Rotation <= 0)
+    return -1;
+
+  // PALIGNR rotates bytes, so we need to scale the
+  // rotation based on how many bytes are in the vector lane.
+  int NumElts = RepeatedMask.size();
+  int Scale = 16 / NumElts;
+  return Rotation * Scale;
+}
+
+static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT,
+                                              SDValue V1, SDValue V2,
+                                              ArrayRef<int> Mask,
+                                              const X86Subtarget &Subtarget,
+                                              SelectionDAG &DAG) {
+  assert(!isNoopShuffleMask(Mask) && "We shouldn't lower no-op shuffles!");
+
+  SDValue Lo = V1, Hi = V2;
+  int ByteRotation = matchVectorShuffleAsByteRotate(VT, Lo, Hi, Mask);
+  if (ByteRotation <= 0)
+    return SDValue();
+
   // Cast the inputs to i8 vector of correct length to match PALIGNR or
   // PSLLDQ/PSRLDQ.
-  MVT ByteVT = MVT::getVectorVT(MVT::i8, 16 * NumLanes);
+  MVT ByteVT = MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8);
   Lo = DAG.getBitcast(ByteVT, Lo);
   Hi = DAG.getBitcast(ByteVT, Hi);
 
-  // The actual rotate instruction rotates bytes, so we need to scale the
-  // rotation based on how many bytes are in the vector lane.
-  int Scale = 16 / NumLaneElts;
-
   // SSSE3 targets can use the palignr instruction.
   if (Subtarget.hasSSSE3()) {
     assert((!VT.is512BitVector() || Subtarget.hasBWI()) &&
            "512-bit PALIGNR requires BWI instructions");
     return DAG.getBitcast(
         VT, DAG.getNode(X86ISD::PALIGNR, DL, ByteVT, Lo, Hi,
-                        DAG.getConstant(Rotation * Scale, DL, MVT::i8)));
+                        DAG.getConstant(ByteRotation, DL, MVT::i8)));
   }
 
   assert(VT.is128BitVector() &&
@@ -7822,8 +8689,8 @@ static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT,
          "SSE2 rotate lowering only needed for v16i8!");
 
   // Default SSE2 implementation
-  int LoByteShift = 16 - Rotation * Scale;
-  int HiByteShift = Rotation * Scale;
+  int LoByteShift = 16 - ByteRotation;
+  int HiByteShift = ByteRotation;
 
   SDValue LoShift = DAG.getNode(X86ISD::VSHLDQ, DL, MVT::v16i8, Lo,
                                 DAG.getConstant(LoByteShift, DL, MVT::i8));
@@ -7833,6 +8700,37 @@ static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT,
                          DAG.getNode(ISD::OR, DL, MVT::v16i8, LoShift, HiShift));
 }
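Since PALIGNR counts bytes while the repeated lane mask counts elements, the matched rotation has to be scaled before it becomes the immediate. A worked example with illustrative numbers:

#include <cassert>

int main() {
  // v8i16: 8 elements repeat per 128-bit lane, so each element is 2 bytes.
  int NumElts = 8;                 // elements in the repeated lane mask
  int Rotation = 3;                // element rotation found by the matcher
  int Scale = 16 / NumElts;        // bytes per element in a 128-bit lane
  int PalignrImm = Rotation * Scale;
  assert(PalignrImm == 6);         // palignr $6, %xmm1, %xmm0
  (void)PalignrImm;
}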
 
+/// \brief Try to lower a vector shuffle as a dword/qword rotation.
+///
+/// AVX512 has VALIGND/VALIGNQ instructions that will do an arbitrary
+/// rotation of the concatenation of two vectors; this routine will
+/// try to generically lower a vector shuffle through such a pattern.
+///
+/// Essentially it concatenates V1 and V2, shifts right by some number of
+/// elements, and takes the low elements as the result. Note that while this is
+/// specified as a *right shift* because x86 is little-endian, it is a *left
+/// rotate* of the vector lanes.
+static SDValue lowerVectorShuffleAsRotate(const SDLoc &DL, MVT VT,
+                                          SDValue V1, SDValue V2,
+                                          ArrayRef<int> Mask,
+                                          const X86Subtarget &Subtarget,
+                                          SelectionDAG &DAG) {
+  assert((VT.getScalarType() == MVT::i32 || VT.getScalarType() == MVT::i64) &&
+         "Only 32-bit and 64-bit elements are supported!");
+
+  // 128/256-bit vectors are only supported with VLX.
+  assert((Subtarget.hasVLX() || (!VT.is128BitVector() && !VT.is256BitVector()))
+         && "VLX required for 128/256-bit vectors");
+
+  SDValue Lo = V1, Hi = V2;
+  int Rotation = matchVectorShuffleAsRotate(Lo, Hi, Mask);
+  if (Rotation <= 0)
+    return SDValue();
+
+  return DAG.getNode(X86ISD::VALIGN, DL, VT, Lo, Hi,
+                     DAG.getConstant(Rotation, DL, MVT::i8));
+}
+
 /// \brief Try to lower a vector shuffle as a bit shift (shifts in zeros).
 ///
 /// Attempts to match a shuffle mask against the PSLL(W/D/Q/DQ) and
@@ -7856,14 +8754,13 @@ static SDValue lowerVectorShuffleAsByteRotate(const SDLoc &DL, MVT VT,
 ///   [  5, 6,  7, zz, zz, zz, zz, zz]
 ///   [ -1, 5,  6,  7, zz, zz, zz, zz]
 ///   [  1, 2, -1, -1, -1, -1, zz, zz]
-static SDValue lowerVectorShuffleAsShift(const SDLoc &DL, MVT VT, SDValue V1,
-                                         SDValue V2, ArrayRef<int> Mask,
-                                         const X86Subtarget &Subtarget,
-                                         SelectionDAG &DAG) {
-  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
-
+static int matchVectorShuffleAsShift(MVT &ShiftVT, unsigned &Opcode,
+                                     unsigned ScalarSizeInBits,
+                                     ArrayRef<int> Mask, int MaskOffset,
+                                     const SmallBitVector &Zeroable,
+                                     const X86Subtarget &Subtarget) {
   int Size = Mask.size();
-  assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
+  unsigned SizeInBits = Size * ScalarSizeInBits;
 
   auto CheckZeros = [&](int Shift, int Scale, bool Left) {
     for (int i = 0; i < Size; i += Scale)
@@ -7874,37 +8771,30 @@ static SDValue lowerVectorShuffleAsShift(const SDLoc &DL, MVT VT, SDValue V1,
     return true;
   };
 
-  auto MatchShift = [&](int Shift, int Scale, bool Left, SDValue V) {
+  auto MatchShift = [&](int Shift, int Scale, bool Left) {
     for (int i = 0; i != Size; i += Scale) {
       unsigned Pos = Left ? i + Shift : i;
       unsigned Low = Left ? i : i + Shift;
       unsigned Len = Scale - Shift;
-      if (!isSequentialOrUndefInRange(Mask, Pos, Len,
-                                      Low + (V == V1 ? 0 : Size)))
-        return SDValue();
+      if (!isSequentialOrUndefInRange(Mask, Pos, Len, Low + MaskOffset))
+        return -1;
     }
 
-    int ShiftEltBits = VT.getScalarSizeInBits() * Scale;
+    int ShiftEltBits = ScalarSizeInBits * Scale;
     bool ByteShift = ShiftEltBits > 64;
-    unsigned OpCode = Left ? (ByteShift ? X86ISD::VSHLDQ : X86ISD::VSHLI)
-                           : (ByteShift ? X86ISD::VSRLDQ : X86ISD::VSRLI);
-    int ShiftAmt = Shift * VT.getScalarSizeInBits() / (ByteShift ? 8 : 1);
+    Opcode = Left ? (ByteShift ? X86ISD::VSHLDQ : X86ISD::VSHLI)
+                  : (ByteShift ? X86ISD::VSRLDQ : X86ISD::VSRLI);
+    int ShiftAmt = Shift * ScalarSizeInBits / (ByteShift ? 8 : 1);
 
     // Normalize the scale for byte shifts to still produce an i64 element
    // type.
     Scale = ByteShift ? Scale / 2 : Scale;
 
     // We need to round trip through the appropriate type for the shift.
-    MVT ShiftSVT = MVT::getIntegerVT(VT.getScalarSizeInBits() * Scale);
-    MVT ShiftVT = ByteShift ? MVT::getVectorVT(MVT::i8, VT.getSizeInBits() / 8)
-                            : MVT::getVectorVT(ShiftSVT, Size / Scale);
-    assert(DAG.getTargetLoweringInfo().isTypeLegal(ShiftVT) &&
-           "Illegal integer vector type");
-    V = DAG.getBitcast(ShiftVT, V);
-
-    V = DAG.getNode(OpCode, DL, ShiftVT, V,
-                    DAG.getConstant(ShiftAmt, DL, MVT::i8));
-    return DAG.getBitcast(VT, V);
+    MVT ShiftSVT = MVT::getIntegerVT(ScalarSizeInBits * Scale);
+    ShiftVT = ByteShift ? MVT::getVectorVT(MVT::i8, SizeInBits / 8)
+                        : MVT::getVectorVT(ShiftSVT, Size / Scale);
+    return (int)ShiftAmt;
   };
 
   // SSE/AVX supports logical shifts up to 64-bit integers - so we can just
@@ -7913,29 +8803,64 @@ static SDValue lowerVectorShuffleAsShift(const SDLoc &DL, MVT VT, SDValue V1,
   // their width within the elements of the larger integer vector. Test each
   // multiple to see if we can find a match with the moved element indices
   // and that the shifted in elements are all zeroable.
-  unsigned MaxWidth = (VT.is512BitVector() && !Subtarget.hasBWI() ? 64 : 128);
-  for (int Scale = 2; Scale * VT.getScalarSizeInBits() <= MaxWidth; Scale *= 2)
+  unsigned MaxWidth = ((SizeInBits == 512) && !Subtarget.hasBWI() ? 64 : 128);
+  for (int Scale = 2; Scale * ScalarSizeInBits <= MaxWidth; Scale *= 2)
     for (int Shift = 1; Shift != Scale; ++Shift)
      for (bool Left : {true, false})
-        if (CheckZeros(Shift, Scale, Left))
-          for (SDValue V : {V1, V2})
-            if (SDValue Match = MatchShift(Shift, Scale, Left, V))
-              return Match;
+        if (CheckZeros(Shift, Scale, Left)) {
+          int ShiftAmt = MatchShift(Shift, Scale, Left);
+          if (0 < ShiftAmt)
+            return ShiftAmt;
+        }
 
   // no match
-  return SDValue();
+  return -1;
+}
+
+static SDValue lowerVectorShuffleAsShift(const SDLoc &DL, MVT VT, SDValue V1,
+                                         SDValue V2, ArrayRef<int> Mask,
+                                         const SmallBitVector &Zeroable,
+                                         const X86Subtarget &Subtarget,
+                                         SelectionDAG &DAG) {
+  int Size = Mask.size();
+  assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
+
+  MVT ShiftVT;
+  SDValue V = V1;
+  unsigned Opcode;
+
+  // Try to match shuffle against V1 shift.
+  int ShiftAmt = matchVectorShuffleAsShift(
+      ShiftVT, Opcode, VT.getScalarSizeInBits(), Mask, 0, Zeroable, Subtarget);
+
+  // If V1 failed, try to match shuffle against V2 shift.
+  if (ShiftAmt < 0) {
+    ShiftAmt =
+        matchVectorShuffleAsShift(ShiftVT, Opcode, VT.getScalarSizeInBits(),
+                                  Mask, Size, Zeroable, Subtarget);
+    V = V2;
+  }
+
+  if (ShiftAmt < 0)
+    return SDValue();
+
+  assert(DAG.getTargetLoweringInfo().isTypeLegal(ShiftVT) &&
+         "Illegal integer vector type");
+  V = DAG.getBitcast(ShiftVT, V);
+  V = DAG.getNode(Opcode, DL, ShiftVT, V,
+                  DAG.getConstant(ShiftAmt, DL, MVT::i8));
+  return DAG.getBitcast(VT, V);
 }
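The matcher asks two questions per candidate: are the vacated lanes zeroable, and are the surviving lanes sequential? A stripped-down sketch of that test for an element-granular left shift (plain C++; the sentinels mirror SM_SentinelUndef/SM_SentinelZero, and the Scale/opcode bookkeeping of the real routine is elided):

#include <cstdio>
#include <vector>

// -1 means undef, -2 means known-zero, as in the target shuffle sentinels.
static bool isZeroFillLeftShift(const std::vector<int> &Mask, int Shift) {
  int N = (int)Mask.size();
  for (int i = 0; i < Shift; ++i)
    if (Mask[i] >= 0)                          // shifted-in slots must be
      return false;                            // zero or undef
  for (int i = Shift; i < N; ++i)
    if (Mask[i] >= 0 && Mask[i] != i - Shift)  // the rest must be sequential
      return false;
  return true;
}

int main() {
  // [zz, zz, 0, 1, 2, 3, 4, 5] is a left shift by 2 elements with zero fill
  // (4 bytes of PSLLDQ for a v8i16 shuffle).
  std::printf("%d\n", isZeroFillLeftShift({-2, -2, 0, 1, 2, 3, 4, 5}, 2)); // 1
}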
 
 /// \brief Try to lower a vector shuffle using SSE4a EXTRQ/INSERTQ.
 static SDValue lowerVectorShuffleWithSSE4A(const SDLoc &DL, MVT VT, SDValue V1,
                                            SDValue V2, ArrayRef<int> Mask,
+                                           const SmallBitVector &Zeroable,
                                            SelectionDAG &DAG) {
-  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
-  assert(!Zeroable.all() && "Fully zeroable shuffle mask");
-
   int Size = Mask.size();
   int HalfSize = Size / 2;
   assert(Size == (int)VT.getVectorNumElements() && "Unexpected mask size");
+  assert(!Zeroable.all() && "Fully zeroable shuffle mask");
 
   // Upper half must be undefined.
   if (!isUndefInRange(Mask, HalfSize, HalfSize))
@@ -8111,8 +9036,10 @@ static SDValue lowerVectorShuffleAsSpecificZeroOrAnyExtend(
       InputV = ShuffleOffset(InputV);
 
   // For 256-bit vectors, we only need the lower (128-bit) input half.
-  if (VT.is256BitVector())
-    InputV = extract128BitVector(InputV, 0, DAG, DL);
+  // For 512-bit vectors, we only need the lower input half or quarter.
+  if (VT.getSizeInBits() > 128)
+    InputV = extractSubVector(InputV, 0, DAG, DL,
+                              std::max(128, (int)VT.getSizeInBits() / Scale));
 
   InputV = DAG.getNode(X86ISD::VZEXT, DL, ExtVT, InputV);
   return DAG.getBitcast(VT, InputV);
@@ -8231,9 +9158,8 @@ static SDValue lowerVectorShuffleAsSpecificZeroOrAnyExtend(
 /// are both incredibly common and often quite performance sensitive.
 static SDValue lowerVectorShuffleAsZeroOrAnyExtend(
     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
-    const X86Subtarget &Subtarget, SelectionDAG &DAG) {
-  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
-
+    const SmallBitVector &Zeroable, const X86Subtarget &Subtarget,
+    SelectionDAG &DAG) {
   int Bits = VT.getSizeInBits();
   int NumLanes = Bits / 128;
   int NumElements = VT.getVectorNumElements();
@@ -8388,14 +9314,14 @@ static bool isShuffleFoldableLoad(SDValue V) {
 /// across all subtarget feature sets.
 static SDValue lowerVectorShuffleAsElementInsertion(
     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
-    const X86Subtarget &Subtarget, SelectionDAG &DAG) {
-  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
+    const SmallBitVector &Zeroable, const X86Subtarget &Subtarget,
+    SelectionDAG &DAG) {
   MVT ExtVT = VT;
   MVT EltVT = VT.getVectorElementType();
 
-  int V2Index = std::find_if(Mask.begin(), Mask.end(),
-                             [&Mask](int M) { return M >= (int)Mask.size(); }) -
-                Mask.begin();
+  int V2Index =
+      find_if(Mask, [&Mask](int M) { return M >= (int)Mask.size(); }) -
+      Mask.begin();
   bool IsV1Zeroable = true;
   for (int i = 0, Size = Mask.size(); i < Size; ++i)
    if (i != V2Index && !Zeroable[i]) {
@@ -8709,6 +9635,13 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT,
     V = DAG.getBitcast(SrcVT, V);
   }
 
+  // 32-bit targets need to load i64 as a f64 and then bitcast the result.
+  if (!Subtarget.is64Bit() && SrcVT == MVT::i64) {
+    V = DAG.getBitcast(MVT::f64, V);
+    unsigned NumBroadcastElts = BroadcastVT.getVectorNumElements();
+    BroadcastVT = MVT::getVectorVT(MVT::f64, NumBroadcastElts);
+  }
+
   return DAG.getBitcast(VT, DAG.getNode(Opcode, DL, BroadcastVT, V));
 }
 
@@ -8726,71 +9659,93 @@ static bool matchVectorShuffleAsInsertPS(SDValue &V1, SDValue &V2,
   assert(V1.getSimpleValueType().is128BitVector() && "Bad operand type!");
   assert(V2.getSimpleValueType().is128BitVector() && "Bad operand type!");
   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
-  unsigned ZMask = 0;
-  int V1DstIndex = -1;
-  int V2DstIndex = -1;
-  bool V1UsedInPlace = false;
 
-  for (int i = 0; i < 4; ++i) {
-    // Synthesize a zero mask from the zeroable elements (includes undefs).
-    if (Zeroable[i]) {
-      ZMask |= 1 << i;
-      continue;
-    }
+  // Attempt to match INSERTPS with one element from VA or VB being
+  // inserted into VA (or undef). If successful, V1, V2 and InsertPSMask
+  // are updated.
+  auto matchAsInsertPS = [&](SDValue VA, SDValue VB,
+                             ArrayRef<int> CandidateMask) {
+    unsigned ZMask = 0;
+    int VADstIndex = -1;
+    int VBDstIndex = -1;
+    bool VAUsedInPlace = false;
+
+    for (int i = 0; i < 4; ++i) {
+      // Synthesize a zero mask from the zeroable elements (includes undefs).
+      if (Zeroable[i]) {
+        ZMask |= 1 << i;
+        continue;
+      }
 
-    // Flag if we use any V1 inputs in place.
-    if (i == Mask[i]) {
-      V1UsedInPlace = true;
-      continue;
+      // Flag if we use any VA inputs in place.
+      if (i == CandidateMask[i]) {
+        VAUsedInPlace = true;
+        continue;
+      }
+
+      // We can only insert a single non-zeroable element.
+      if (VADstIndex >= 0 || VBDstIndex >= 0)
+        return false;
+
+      if (CandidateMask[i] < 4) {
+        // VA input out of place for insertion.
+        VADstIndex = i;
+      } else {
+        // VB input for insertion.
+        VBDstIndex = i;
+      }
     }
 
-    // We can only insert a single non-zeroable element.
-    if (V1DstIndex >= 0 || V2DstIndex >= 0)
+    // Don't bother if we have no (non-zeroable) element for insertion.
+    if (VADstIndex < 0 && VBDstIndex < 0)
       return false;
 
-    if (Mask[i] < 4) {
-      // V1 input out of place for insertion.
-      V1DstIndex = i;
+    // Determine element insertion src/dst indices. The src index is from the
+    // start of the inserted vector, not the start of the concatenated vector.
+    unsigned VBSrcIndex = 0;
+    if (VADstIndex >= 0) {
+      // If we have a VA input out of place, we use VA as the V2 element
+      // insertion and don't use the original V2 at all.
+      VBSrcIndex = CandidateMask[VADstIndex];
+      VBDstIndex = VADstIndex;
+      VB = VA;
     } else {
-      // V2 input for insertion.
-      V2DstIndex = i;
+      VBSrcIndex = CandidateMask[VBDstIndex] - 4;
     }
-  }
 
-  // Don't bother if we have no (non-zeroable) element for insertion.
-  if (V1DstIndex < 0 && V2DstIndex < 0)
-    return false;
+    // If no V1 inputs are used in place, then the result is created only from
+    // the zero mask and the V2 insertion - so remove V1 dependency.
+    if (!VAUsedInPlace)
+      VA = DAG.getUNDEF(MVT::v4f32);
 
-  // Determine element insertion src/dst indices. The src index is from the
-  // start of the inserted vector, not the start of the concatenated vector.
-  unsigned V2SrcIndex = 0;
-  if (V1DstIndex >= 0) {
-    // If we have a V1 input out of place, we use V1 as the V2 element insertion
-    // and don't use the original V2 at all.
-    V2SrcIndex = Mask[V1DstIndex];
-    V2DstIndex = V1DstIndex;
-    V2 = V1;
-  } else {
-    V2SrcIndex = Mask[V2DstIndex] - 4;
-  }
+    // Update V1, V2 and InsertPSMask accordingly.
+    V1 = VA;
+    V2 = VB;
 
-  // If no V1 inputs are used in place, then the result is created only from
-  // the zero mask and the V2 insertion - so remove V1 dependency.
-  if (!V1UsedInPlace)
-    V1 = DAG.getUNDEF(MVT::v4f32);
+    // Insert the V2 element into the desired position.
+    InsertPSMask = VBSrcIndex << 6 | VBDstIndex << 4 | ZMask;
+    assert((InsertPSMask & ~0xFFu) == 0 && "Invalid mask!");
+    return true;
+  };
 
-  // Insert the V2 element into the desired position.
-  InsertPSMask = V2SrcIndex << 6 | V2DstIndex << 4 | ZMask;
-  assert((InsertPSMask & ~0xFFu) == 0 && "Invalid mask!");
-  return true;
+  if (matchAsInsertPS(V1, V2, Mask))
+    return true;
+
+  // Commute and try again.
+  SmallVector<int, 4> CommutedMask(Mask.begin(), Mask.end());
+  ShuffleVectorSDNode::commuteMask(CommutedMask);
+  if (matchAsInsertPS(V2, V1, CommutedMask))
+    return true;
+
+  return false;
 }
 
 static SDValue lowerVectorShuffleAsInsertPS(const SDLoc &DL, SDValue V1,
                                             SDValue V2, ArrayRef<int> Mask,
+                                            const SmallBitVector &Zeroable,
                                             SelectionDAG &DAG) {
   assert(V1.getSimpleValueType() == MVT::v4f32 && "Bad operand type!");
   assert(V2.getSimpleValueType() == MVT::v4f32 && "Bad operand type!");
-  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
 
   // Attempt to match the insertps pattern.
   unsigned InsertPSMask;
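Both the direct and the commuted INSERTPS match funnel into the same imm8 layout. A tiny sketch of that encoding with illustrative operand choices:

#include <cassert>

int main() {
  // INSERTPS imm8: [7:6] source element, [5:4] destination slot, [3:0] zmask.
  unsigned SrcIdx = 2;   // take element 2 of the source vector
  unsigned DstIdx = 0;   // insert it into slot 0 of the destination
  unsigned ZMask = 0x8;  // additionally zero slot 3
  unsigned Imm = SrcIdx << 6 | DstIdx << 4 | ZMask;
  assert(Imm == 0x88);   // insertps $0x88, %xmm1, %xmm0
  (void)Imm;
}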
@@ -8922,6 +9877,7 @@ static SDValue lowerVectorShuffleAsPermuteAndUnpack(const SDLoc &DL, MVT VT,
 /// it is better to avoid lowering through this for integer vectors where
 /// possible.
 static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -8946,8 +9902,11 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
                          DAG.getConstant(SHUFPDMask, DL, MVT::i8));
     }
 
-    return DAG.getNode(X86ISD::SHUFP, DL, MVT::v2f64, V1, V1,
-                       DAG.getConstant(SHUFPDMask, DL, MVT::i8));
+    return DAG.getNode(
+        X86ISD::SHUFP, DL, MVT::v2f64,
+        Mask[0] == SM_SentinelUndef ? DAG.getUNDEF(MVT::v2f64) : V1,
+        Mask[1] == SM_SentinelUndef ? DAG.getUNDEF(MVT::v2f64) : V1,
+        DAG.getConstant(SHUFPDMask, DL, MVT::i8));
   }
   assert(Mask[0] >= 0 && Mask[0] < 2 && "Non-canonicalized blend!");
   assert(Mask[1] >= 2 && "Non-canonicalized blend!");
@@ -8955,14 +9914,14 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // If we have a single input, insert that into V1 if we can do so cheaply.
   if ((Mask[0] >= 2) + (Mask[1] >= 2) == 1) {
     if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
-            DL, MVT::v2f64, V1, V2, Mask, Subtarget, DAG))
+            DL, MVT::v2f64, V1, V2, Mask, Zeroable, Subtarget, DAG))
       return Insertion;
     // Try inverting the insertion since for v2 masks it is easy to do and we
     // can't reliably sort the mask one way or the other.
     int InverseMask[2] = {Mask[0] < 0 ? -1 : (Mask[0] ^ 2),
                           Mask[1] < 0 ? -1 : (Mask[1] ^ 2)};
     if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
-            DL, MVT::v2f64, V2, V1, InverseMask, Subtarget, DAG))
+            DL, MVT::v2f64, V2, V1, InverseMask, Zeroable, Subtarget, DAG))
       return Insertion;
   }
 
@@ -8980,7 +9939,7 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   if (Subtarget.hasSSE41())
     if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v2f64, V1, V2, Mask,
-                                                  Subtarget, DAG))
+                                                  Zeroable, Subtarget, DAG))
       return Blend;
 
   // Use dedicated unpack instructions for masks that match their pattern.
@@ -9000,6 +9959,7 @@ static SDValue lowerV2F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 /// it falls back to the floating point shuffle operation with appropriate bit
 /// casting.
 static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -9052,19 +10012,19 @@ static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // Try to use shift instructions.
   if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v2i64, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Shift;
 
   // When loading a scalar and then shuffling it into a vector we can often do
   // the insertion cheaply.
   if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
-          DL, MVT::v2i64, V1, V2, Mask, Subtarget, DAG))
+          DL, MVT::v2i64, V1, V2, Mask, Zeroable, Subtarget, DAG))
     return Insertion;
   // Try inverting the insertion since for v2 masks it is easy to do and we
   // can't reliably sort the mask one way or the other.
   int InverseMask[2] = {Mask[0] ^ 2, Mask[1] ^ 2};
   if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
-          DL, MVT::v2i64, V2, V1, InverseMask, Subtarget, DAG))
+          DL, MVT::v2i64, V2, V1, InverseMask, Zeroable, Subtarget, DAG))
     return Insertion;
 
   // We have different paths for blend lowering, but they all must use the
@@ -9072,7 +10032,7 @@ static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   bool IsBlendSupported = Subtarget.hasSSE41();
   if (IsBlendSupported)
     if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v2i64, V1, V2, Mask,
-                                                  Subtarget, DAG))
+                                                  Zeroable, Subtarget, DAG))
       return Blend;
 
   // Use dedicated unpack instructions for masks that match their pattern.
@@ -9139,9 +10099,7 @@ static SDValue lowerVectorShuffleWithSHUFPS(const SDLoc &DL, MVT VT,
   int NumV2Elements = count_if(Mask, [](int M) { return M >= 4; });
 
   if (NumV2Elements == 1) {
-    int V2Index =
-        std::find_if(Mask.begin(), Mask.end(), [](int M) { return M >= 4; }) -
-        Mask.begin();
+    int V2Index = find_if(Mask, [](int M) { return M >= 4; }) - Mask.begin();
 
     // Compute the index adjacent to V2Index and in the same half by toggling
     // the low bit.
@@ -9220,6 +10178,7 @@ static SDValue lowerVectorShuffleWithSHUFPS(const SDLoc &DL, MVT VT,
 /// domain crossing penalties, as these are sufficient to implement all v4f32
 /// shuffles.
 static SDValue lowerV4F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -9262,17 +10221,18 @@ static SDValue lowerV4F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // when the V2 input is targeting element 0 of the mask -- that is the fast
   // case here.
   if (NumV2Elements == 1 && Mask[0] >= 4)
-    if (SDValue V = lowerVectorShuffleAsElementInsertion(DL, MVT::v4f32, V1, V2,
-                                                         Mask, Subtarget, DAG))
+    if (SDValue V = lowerVectorShuffleAsElementInsertion(
+            DL, MVT::v4f32, V1, V2, Mask, Zeroable, Subtarget, DAG))
       return V;
 
   if (Subtarget.hasSSE41()) {
     if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v4f32, V1, V2, Mask,
-                                                  Subtarget, DAG))
+                                                  Zeroable, Subtarget, DAG))
       return Blend;
 
     // Use INSERTPS if we can complete the shuffle efficiently.
-    if (SDValue V = lowerVectorShuffleAsInsertPS(DL, V1, V2, Mask, DAG))
+    if (SDValue V =
+            lowerVectorShuffleAsInsertPS(DL, V1, V2, Mask, Zeroable, DAG))
      return V;
 
     if (!isSingleSHUFPSMask(Mask))
@@ -9301,6 +10261,7 @@ static SDValue lowerV4F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 /// We try to handle these with integer-domain shuffles where we can, but for
 /// blends we use the floating point domain blend instructions.
 static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -9311,8 +10272,8 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // Whenever we can lower this as a zext, that instruction is strictly faster
   // than any alternative. It also allows us to fold memory operands into the
   // shuffle in many cases.
-  if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend(DL, MVT::v4i32, V1, V2,
-                                                         Mask, Subtarget, DAG))
+  if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend(
+          DL, MVT::v4i32, V1, V2, Mask, Zeroable, Subtarget, DAG))
     return ZExt;
 
   int NumV2Elements = count_if(Mask, [](int M) { return M >= 4; });
@@ -9341,13 +10302,13 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // Try to use shift instructions.
   if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v4i32, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Shift;
 
   // There are special ways we can lower some single-element blends.
   if (NumV2Elements == 1)
-    if (SDValue V = lowerVectorShuffleAsElementInsertion(DL, MVT::v4i32, V1, V2,
-                                                         Mask, Subtarget, DAG))
+    if (SDValue V = lowerVectorShuffleAsElementInsertion(
+            DL, MVT::v4i32, V1, V2, Mask, Zeroable, Subtarget, DAG))
      return V;
 
   // We have different paths for blend lowering, but they all must use the
@@ -9355,11 +10316,11 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   bool IsBlendSupported = Subtarget.hasSSE41();
   if (IsBlendSupported)
     if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v4i32, V1, V2, Mask,
-                                                  Subtarget, DAG))
+                                                  Zeroable, Subtarget, DAG))
       return Blend;
 
-  if (SDValue Masked =
-          lowerVectorShuffleAsBitMask(DL, MVT::v4i32, V1, V2, Mask, DAG))
+  if (SDValue Masked = lowerVectorShuffleAsBitMask(DL, MVT::v4i32, V1, V2, Mask,
+                                                   Zeroable, DAG))
     return Masked;
 
   // Use dedicated unpack instructions for masks that match their pattern.
@@ -9374,26 +10335,31 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
           DL, MVT::v4i32, V1, V2, Mask, Subtarget, DAG))
     return Rotate;
 
-  // If we have direct support for blends, we should lower by decomposing into
-  // a permute. That will be faster than the domain cross.
-  if (IsBlendSupported)
-    return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v4i32, V1, V2,
-                                                      Mask, DAG);
-
-  // Try to lower by permuting the inputs into an unpack instruction.
-  if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack(DL, MVT::v4i32, V1,
-                                                            V2, Mask, DAG))
-    return Unpack;
+  // Assume that a single SHUFPS is faster than an alternative sequence of
+  // multiple instructions (even if the CPU has a domain penalty).
+  // If some CPU is harmed by the domain switch, we can fix it in a later pass.
+  if (!isSingleSHUFPSMask(Mask)) {
+    // If we have direct support for blends, we should lower by decomposing into
+    // a permute. That will be faster than the domain cross.
+    if (IsBlendSupported)
+      return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v4i32, V1, V2,
+                                                        Mask, DAG);
+
+    // Try to lower by permuting the inputs into an unpack instruction.
+    if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack(
+            DL, MVT::v4i32, V1, V2, Mask, DAG))
      return Unpack;
+  }
 
   // We implement this with SHUFPS because it can blend from two vectors.
   // Because we're going to eventually use SHUFPS, we use SHUFPS even to build
   // up the inputs, bypassing domain shift penalties that we would encur if we
   // directly used PSHUFD on Nehalem and older. For newer chips, this isn't
   // relevant.
-  return DAG.getBitcast(
-      MVT::v4i32,
-      DAG.getVectorShuffle(MVT::v4f32, DL, DAG.getBitcast(MVT::v4f32, V1),
-                           DAG.getBitcast(MVT::v4f32, V2), Mask));
+  SDValue CastV1 = DAG.getBitcast(MVT::v4f32, V1);
+  SDValue CastV2 = DAG.getBitcast(MVT::v4f32, V2);
+  SDValue ShufPS = DAG.getVectorShuffle(MVT::v4f32, DL, CastV1, CastV2, Mask);
+  return DAG.getBitcast(MVT::v4i32, ShufPS);
 }
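The shape the isSingleSHUFPSMask guard above is looking for: SHUFPS feeds each half of its result from a single source register, so any v4 mask of that shape is one instruction even after crossing into the float domain. A rough standalone approximation in plain C++ (illustrative; the in-tree predicate differs in details such as undef handling and canonicalization):

#include <vector>

// SHUFPS produces [a, b, c, d] where a,b come from one source register and
// c,d from the other. A v4 mask qualifies if the low two lanes read one
// input and the high two lanes read the other (undefs, here -1, are free).
static bool isSingleShufpsShape(const std::vector<int> &Mask) {
  auto fromV1 = [](int M) { return M < 4; };   // -1 (undef) also passes
  auto fromV2 = [](int M) { return M < 0 || M >= 4; };
  bool LoV1HiV2 = fromV1(Mask[0]) && fromV1(Mask[1]) &&
                  fromV2(Mask[2]) && fromV2(Mask[3]);
  bool LoV2HiV1 = fromV2(Mask[0]) && fromV2(Mask[1]) &&
                  fromV1(Mask[2]) && fromV1(Mask[3]);
  return LoV1HiV2 || LoV2HiV1;
}

int main() { return isSingleShufpsShape({0, 2, 5, 7}) ? 0 : 1; }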
 
 /// \brief Lowering of single-input v8i16 shuffles is the cornerstone of SSE2
@@ -9551,18 +10517,15 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle(
   auto FixFlippedInputs = [&V, &DL, &Mask, &DAG](int PinnedIdx, int DWord,
                                                  ArrayRef<int> Inputs) {
     int FixIdx = PinnedIdx ^ 1; // The adjacent slot to the pinned slot.
-    bool IsFixIdxInput = std::find(Inputs.begin(), Inputs.end(),
-                                   PinnedIdx ^ 1) != Inputs.end();
+    bool IsFixIdxInput = is_contained(Inputs, PinnedIdx ^ 1);
     // Determine whether the free index is in the flipped dword or the
     // unflipped dword based on where the pinned index is. We use this bit
    // in an xor to conditionally select the adjacent dword.
     int FixFreeIdx = 2 * (DWord ^ (PinnedIdx / 2 == DWord));
-    bool IsFixFreeIdxInput = std::find(Inputs.begin(), Inputs.end(),
-                                       FixFreeIdx) != Inputs.end();
+    bool IsFixFreeIdxInput = is_contained(Inputs, FixFreeIdx);
     if (IsFixIdxInput == IsFixFreeIdxInput)
       FixFreeIdx += 1;
-    IsFixFreeIdxInput = std::find(Inputs.begin(), Inputs.end(),
-                                  FixFreeIdx) != Inputs.end();
+    IsFixFreeIdxInput = is_contained(Inputs, FixFreeIdx);
     assert(IsFixIdxInput != IsFixFreeIdxInput &&
            "We need to be changing the number of flipped inputs!");
     int PSHUFHalfMask[] = {0, 1, 2, 3};
@@ -9734,9 +10697,8 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle(
   // by inputs being moved and *staying* in that half.
   if (IncomingInputs.size() == 1) {
     if (isWordClobbered(SourceHalfMask, IncomingInputs[0] - SourceOffset)) {
-      int InputFixed = std::find(std::begin(SourceHalfMask),
-                                 std::end(SourceHalfMask), -1) -
-                       std::begin(SourceHalfMask) + SourceOffset;
+      int InputFixed = find(SourceHalfMask, -1) - std::begin(SourceHalfMask) +
+                       SourceOffset;
       SourceHalfMask[InputFixed - SourceOffset] =
           IncomingInputs[0] - SourceOffset;
       std::replace(HalfMask.begin(), HalfMask.end(), IncomingInputs[0],
@@ -9868,8 +10830,8 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle(
 /// blend if only one input is used.
 static SDValue lowerVectorShuffleAsBlendOfPSHUFBs(
     const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask,
-    SelectionDAG &DAG, bool &V1InUse, bool &V2InUse) {
-  SmallBitVector Zeroable = computeZeroableShuffleElements(Mask, V1, V2);
+    const SmallBitVector &Zeroable, SelectionDAG &DAG, bool &V1InUse,
+    bool &V2InUse) {
   SDValue V1Mask[16];
   SDValue V2Mask[16];
   V1InUse = false;
@@ -9929,6 +10891,7 @@ static SDValue lowerVectorShuffleAsBlendOfPSHUFBs(
 /// halves of the inputs separately (making them have relatively few inputs)
 /// and then concatenate them.
 static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -9939,7 +10902,7 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // Whenever we can lower this as a zext, that instruction is strictly faster
   // than any alternative.
   if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend(
-          DL, MVT::v8i16, V1, V2, Mask, Subtarget, DAG))
+          DL, MVT::v8i16, V1, V2, Mask, Zeroable, Subtarget, DAG))
     return ZExt;
 
   int NumV2Inputs = count_if(Mask, [](int M) { return M >= 8; });
@@ -9952,7 +10915,7 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // Try to use shift instructions.
   if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v8i16, V1, V1, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Shift;
 
   // Use dedicated unpack instructions for masks that match their pattern.
@@ -9978,18 +10941,19 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // Try to use shift instructions.
   if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v8i16, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Shift;
 
   // See if we can use SSE4A Extraction / Insertion.
   if (Subtarget.hasSSE4A())
-    if (SDValue V = lowerVectorShuffleWithSSE4A(DL, MVT::v8i16, V1, V2, Mask, DAG))
+    if (SDValue V = lowerVectorShuffleWithSSE4A(DL, MVT::v8i16, V1, V2, Mask,
+                                                Zeroable, DAG))
       return V;
 
   // There are special ways we can lower some single-element blends.
   if (NumV2Inputs == 1)
-    if (SDValue V = lowerVectorShuffleAsElementInsertion(DL, MVT::v8i16, V1, V2,
-                                                         Mask, Subtarget, DAG))
+    if (SDValue V = lowerVectorShuffleAsElementInsertion(
+            DL, MVT::v8i16, V1, V2, Mask, Zeroable, Subtarget, DAG))
      return V;
 
   // We have different paths for blend lowering, but they all must use the
@@ -9997,11 +10961,11 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   bool IsBlendSupported = Subtarget.hasSSE41();
   if (IsBlendSupported)
     if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v8i16, V1, V2, Mask,
-                                                  Subtarget, DAG))
+                                                  Zeroable, Subtarget, DAG))
       return Blend;
 
-  if (SDValue Masked =
-          lowerVectorShuffleAsBitMask(DL, MVT::v8i16, V1, V2, Mask, DAG))
+  if (SDValue Masked = lowerVectorShuffleAsBitMask(DL, MVT::v8i16, V1, V2, Mask,
+                                                   Zeroable, DAG))
     return Masked;
 
   // Use dedicated unpack instructions for masks that match their pattern.
@@ -10027,14 +10991,14 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // can both shuffle and set up the inefficient blend.
   if (!IsBlendSupported && Subtarget.hasSSSE3()) {
     bool V1InUse, V2InUse;
-    return lowerVectorShuffleAsBlendOfPSHUFBs(DL, MVT::v8i16, V1, V2, Mask, DAG,
-                                              V1InUse, V2InUse);
+    return lowerVectorShuffleAsBlendOfPSHUFBs(DL, MVT::v8i16, V1, V2, Mask,
+                                              Zeroable, DAG, V1InUse, V2InUse);
   }
 
   // We can always bit-blend if we have to so the fallback strategy is to
   // decompose into single-input permutes and blends.
   return lowerVectorShuffleAsDecomposedShuffleBlend(DL, MVT::v8i16, V1, V2,
-                                                    Mask, DAG);
+                                                    Mask, DAG);
 }
 
 /// \brief Check whether a compaction lowering can be done by dropping even
@@ -10111,6 +11075,7 @@ static int canLowerByDroppingEvenElements(ArrayRef<int> Mask,
 /// the existing lowering for v8i16 blends on each half, finally PACK-ing them
 /// back together.
 static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -10120,7 +11085,7 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // Try to use shift instructions.
   if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v16i8, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Shift;
 
   // Try to use byte rotation instructions.
@@ -10130,12 +11095,13 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // Try to use a zext lowering.
   if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend(
-          DL, MVT::v16i8, V1, V2, Mask, Subtarget, DAG))
+          DL, MVT::v16i8, V1, V2, Mask, Zeroable, Subtarget, DAG))
     return ZExt;
 
   // See if we can use SSE4A Extraction / Insertion.
   if (Subtarget.hasSSE4A())
-    if (SDValue V = lowerVectorShuffleWithSSE4A(DL, MVT::v16i8, V1, V2, Mask, DAG))
+    if (SDValue V = lowerVectorShuffleWithSSE4A(DL, MVT::v16i8, V1, V2, Mask,
+                                                Zeroable, DAG))
      return V;
 
   int NumV2Elements = count_if(Mask, [](int M) { return M >= 16; });
@@ -10238,8 +11204,8 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
       return V;
   }
 
-  if (SDValue Masked =
-          lowerVectorShuffleAsBitMask(DL, MVT::v16i8, V1, V2, Mask, DAG))
+  if (SDValue Masked = lowerVectorShuffleAsBitMask(DL, MVT::v16i8, V1, V2, Mask,
+                                                   Zeroable, DAG))
     return Masked;
 
   // Use dedicated unpack instructions for masks that match their pattern.
@@ -10265,15 +11231,15 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
     bool V2InUse = false;
 
     SDValue PSHUFB = lowerVectorShuffleAsBlendOfPSHUFBs(
-        DL, MVT::v16i8, V1, V2, Mask, DAG, V1InUse, V2InUse);
+        DL, MVT::v16i8, V1, V2, Mask, Zeroable, DAG, V1InUse, V2InUse);
 
     // If both V1 and V2 are in use and we can use a direct blend or an unpack,
     // do so. This avoids using them to handle blends-with-zero which is
     // important as a single pshufb is significantly faster for that.
     if (V1InUse && V2InUse) {
       if (Subtarget.hasSSE41())
-        if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v16i8, V1, V2,
-                                                      Mask, Subtarget, DAG))
+        if (SDValue Blend = lowerVectorShuffleAsBlend(
+                DL, MVT::v16i8, V1, V2, Mask, Zeroable, Subtarget, DAG))
          return Blend;
 
       // We can use an unpack to do the blending rather than an or in some
@@ -10294,8 +11260,8 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // There are special ways we can lower some single-element blends.
   if (NumV2Elements == 1)
-    if (SDValue V = lowerVectorShuffleAsElementInsertion(DL, MVT::v16i8, V1, V2,
-                                                         Mask, Subtarget, DAG))
+    if (SDValue V = lowerVectorShuffleAsElementInsertion(
+            DL, MVT::v16i8, V1, V2, Mask, Zeroable, Subtarget, DAG))
      return V;
 
   if (SDValue BitBlend =
@@ -10349,22 +11315,18 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // with a pack.
   SDValue V = V1;
 
-  int LoBlendMask[8] = {-1, -1, -1, -1, -1, -1, -1, -1};
-  int HiBlendMask[8] = {-1, -1, -1, -1, -1, -1, -1, -1};
+  std::array<int, 8> LoBlendMask = {{-1, -1, -1, -1, -1, -1, -1, -1}};
+  std::array<int, 8> HiBlendMask = {{-1, -1, -1, -1, -1, -1, -1, -1}};
   for (int i = 0; i < 16; ++i)
     if (Mask[i] >= 0)
       (i < 8 ? LoBlendMask[i] : HiBlendMask[i % 8]) = Mask[i];
 
-  SDValue Zero = getZeroVector(MVT::v8i16, Subtarget, DAG, DL);
-
   SDValue VLoHalf, VHiHalf;
 
   // Check if any of the odd lanes in the v16i8 are used. If not, we can mask
   // them out and avoid using UNPCK{L,H} to extract the elements of V as
   // i16s.
-  if (std::none_of(std::begin(LoBlendMask), std::end(LoBlendMask),
-                   [](int M) { return M >= 0 && M % 2 == 1; }) &&
-      std::none_of(std::begin(HiBlendMask), std::end(HiBlendMask),
-                   [](int M) { return M >= 0 && M % 2 == 1; })) {
+  if (none_of(LoBlendMask, [](int M) { return M >= 0 && M % 2 == 1; }) &&
+      none_of(HiBlendMask, [](int M) { return M >= 0 && M % 2 == 1; })) {
     // Use a mask to drop the high bytes.
     VLoHalf = DAG.getBitcast(MVT::v8i16, V);
     VLoHalf = DAG.getNode(ISD::AND, DL, MVT::v8i16, VLoHalf,
@@ -10383,6 +11345,8 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   } else {
     // Otherwise just unpack the low half of V into VLoHalf and the high half into
     // VHiHalf so that we can blend them as i16s.
+    SDValue Zero = getZeroVector(MVT::v16i8, Subtarget, DAG, DL);
+
     VLoHalf = DAG.getBitcast(
         MVT::v8i16, DAG.getNode(X86ISD::UNPCKL, DL, MVT::v16i8, V, Zero));
     VHiHalf = DAG.getBitcast(
@@ -10401,83 +11365,28 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 /// dispatches to the lowering routines accordingly.
 static SDValue lower128BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                         MVT VT, SDValue V1, SDValue V2,
+                                        const SmallBitVector &Zeroable,
                                         const X86Subtarget &Subtarget,
                                         SelectionDAG &DAG) {
   switch (VT.SimpleTy) {
   case MVT::v2i64:
-    return lowerV2I64VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG);
+    return lowerV2I64VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v2f64:
-    return lowerV2F64VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG);
+    return lowerV2F64VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v4i32:
-    return lowerV4I32VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG);
+    return lowerV4I32VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v4f32:
-    return lowerV4F32VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG);
+    return lowerV4F32VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8i16:
-    return lowerV8I16VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG);
+    return lowerV8I16VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i8:
-    return lowerV16I8VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG);
+    return lowerV16I8VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
 
   default:
     llvm_unreachable("Unimplemented!");
   }
 }
 
-/// \brief Helper function to test whether a shuffle mask could be
-/// simplified by widening the elements being shuffled.
-///
-/// Appends the mask for wider elements in WidenedMask if valid. Otherwise
-/// leaves it in an unspecified state.
-///
-/// NOTE: This must handle normal vector shuffle masks and *target* vector
-/// shuffle masks. The latter have the special property of a '-2' representing
-/// a zero-ed lane of a vector.
-static bool canWidenShuffleElements(ArrayRef<int> Mask,
-                                    SmallVectorImpl<int> &WidenedMask) {
-  WidenedMask.assign(Mask.size() / 2, 0);
-  for (int i = 0, Size = Mask.size(); i < Size; i += 2) {
-    // If both elements are undef, its trivial.
-    if (Mask[i] == SM_SentinelUndef && Mask[i + 1] == SM_SentinelUndef) {
-      WidenedMask[i/2] = SM_SentinelUndef;
-      continue;
-    }
-
-    // Check for an undef mask and a mask value properly aligned to fit with
-    // a pair of values. If we find such a case, use the non-undef mask's value.
-    if (Mask[i] == SM_SentinelUndef && Mask[i + 1] >= 0 && Mask[i + 1] % 2 == 1) {
-      WidenedMask[i/2] = Mask[i + 1] / 2;
-      continue;
-    }
-    if (Mask[i + 1] == SM_SentinelUndef && Mask[i] >= 0 && Mask[i] % 2 == 0) {
-      WidenedMask[i/2] = Mask[i] / 2;
-      continue;
-    }
-
-    // When zeroing, we need to spread the zeroing across both lanes to widen.
-    if (Mask[i] == SM_SentinelZero || Mask[i + 1] == SM_SentinelZero) {
-      if ((Mask[i] == SM_SentinelZero || Mask[i] == SM_SentinelUndef) &&
-          (Mask[i + 1] == SM_SentinelZero || Mask[i + 1] == SM_SentinelUndef)) {
-        WidenedMask[i/2] = SM_SentinelZero;
-        continue;
-      }
-      return false;
-    }
-
-    // Finally check if the two mask values are adjacent and aligned with
-    // a pair.
-    if (Mask[i] != SM_SentinelUndef && Mask[i] % 2 == 0 && Mask[i] + 1 == Mask[i + 1]) {
-      WidenedMask[i/2] = Mask[i] / 2;
-      continue;
-    }
-
-    // Otherwise we can't safely widen the elements used in this shuffle.
-    return false;
-  }
-  assert(WidenedMask.size() == Mask.size() / 2 &&
-         "Incorrect size of mask after widening the elements!");
-
-  return true;
-}
-
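The removed helper is evidently relocated rather than deleted — lowerV2X128VectorShuffle below now calls canWidenShuffleElements directly. Its pairing rule is easy to demonstrate standalone (plain C++ sketch; the SM_SentinelZero spreading case is elided):

#include <cstdio>
#include <vector>

// Try to fold adjacent mask pairs into a mask with twice-as-wide elements:
// the v4 mask [0, 1, 6, 7] becomes the v2 mask [0, 3]. -1 means undef.
static bool widenMask(const std::vector<int> &Mask, std::vector<int> &Wide) {
  Wide.assign(Mask.size() / 2, 0);
  for (size_t i = 0; i < Mask.size(); i += 2) {
    int Lo = Mask[i], Hi = Mask[i + 1];
    if (Lo < 0 && Hi < 0)
      Wide[i / 2] = -1;                       // both undef
    else if (Lo >= 0 && Lo % 2 == 0 && Hi == Lo + 1)
      Wide[i / 2] = Lo / 2;                   // aligned adjacent pair
    else if (Lo < 0 && Hi >= 0 && Hi % 2 == 1)
      Wide[i / 2] = Hi / 2;                   // undef low, aligned high
    else if (Hi < 0 && Lo >= 0 && Lo % 2 == 0)
      Wide[i / 2] = Lo / 2;                   // aligned low, undef high
    else
      return false;                           // can't widen this pair
  }
  return true;
}

int main() {
  std::vector<int> W;
  if (widenMask({0, 1, 6, 7}, W))
    std::printf("[%d, %d]\n", W[0], W[1]);    // prints [0, 3]
}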
 /// \brief Generic routine to split vector shuffle into half-sized shuffles.
 ///
 /// This routine just extracts two subvectors, shuffles them independently, and
@@ -10712,15 +11621,20 @@ static SDValue lowerVectorShuffleAsLanePermuteAndBlend(const SDLoc &DL, MVT VT,
 /// \brief Handle lowering 2-lane 128-bit shuffles.
 static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
                                         SDValue V2, ArrayRef<int> Mask,
+                                        const SmallBitVector &Zeroable,
                                         const X86Subtarget &Subtarget,
                                         SelectionDAG &DAG) {
+  SmallVector<int, 4> WidenedMask;
+  if (!canWidenShuffleElements(Mask, WidenedMask))
+    return SDValue();
+
   // TODO: If minimizing size and one of the inputs is a zero vector and the
   // the zero vector has only one use, we could use a VPERM2X128 to save the
   // instruction bytes needed to explicitly generate the zero vector.
 
   // Blends are faster and handle all the non-lane-crossing cases.
   if (SDValue Blend = lowerVectorShuffleAsBlend(DL, VT, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Blend;
 
   bool IsV1Zero = ISD::isBuildVectorAllZeros(V1.getNode());
@@ -10761,15 +11675,10 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
   //    [6]   - ignore
   //    [7]   - zero high half of destination
 
-  int MaskLO = Mask[0];
-  if (MaskLO == SM_SentinelUndef)
-    MaskLO = Mask[1] == SM_SentinelUndef ? 0 : Mask[1];
-
-  int MaskHI = Mask[2];
-  if (MaskHI == SM_SentinelUndef)
-    MaskHI = Mask[3] == SM_SentinelUndef ? 0 : Mask[3];
+  int MaskLO = WidenedMask[0] < 0 ? 0 : WidenedMask[0];
+  int MaskHI = WidenedMask[1] < 0 ? 0 : WidenedMask[1];
 
-  unsigned PermMask = MaskLO / 2 | (MaskHI / 2) << 4;
+  unsigned PermMask = MaskLO | (MaskHI << 4);
 
   // If either input is a zero vector, replace it with an undef input.
   // Shuffle mask values < 4 are selecting elements of V1.
@@ -10778,16 +11687,16 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1,
   // selecting the zero vector and setting the zero mask bit.
   if (IsV1Zero) {
     V1 = DAG.getUNDEF(VT);
-    if (MaskLO < 4)
+    if (MaskLO < 2)
       PermMask = (PermMask & 0xf0) | 0x08;
-    if (MaskHI < 4)
+    if (MaskHI < 2)
      PermMask = (PermMask & 0x0f) | 0x80;
   }
   if (IsV2Zero) {
     V2 = DAG.getUNDEF(VT);
-    if (MaskLO >= 4)
+    if (MaskLO >= 2)
      PermMask = (PermMask & 0xf0) | 0x08;
-    if (MaskHI >= 4)
+    if (MaskHI >= 2)
       PermMask = (PermMask & 0x0f) | 0x80;
   }
 
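The widened 2x128 mask now feeds VPERM2X128's immediate directly. A sketch of the encoding and of the zero-half override used above (illustrative lane choices):

#include <cassert>

int main() {
  // VPERM2F128/VPERM2I128 imm8: bits [1:0] select the source 128-bit lane for
  // the low destination half (0-1 = V1's lanes, 2-3 = V2's), bits [5:4] do
  // the same for the high half, and bits 3 / 7 zero the respective half.
  int MaskLO = 1;                              // low half <- V1's high lane
  int MaskHI = 2;                              // high half <- V2's low lane
  unsigned PermMask = MaskLO | (MaskHI << 4);
  assert(PermMask == 0x21);
  unsigned ZeroHi = (PermMask & 0x0f) | 0x80;  // zero the high half instead
  assert(ZeroHi == 0x81);
  (void)ZeroHi;
}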
@@ -11178,35 +12087,65 @@ static SDValue lowerShuffleAsRepeatedMaskAndLanePermute(
       SubLaneMask);
 }
 
-static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT,
-                                            ArrayRef<int> Mask, SDValue V1,
-                                            SDValue V2, SelectionDAG &DAG) {
+static bool matchVectorShuffleWithSHUFPD(MVT VT, SDValue &V1, SDValue &V2,
+                                         unsigned &ShuffleImm,
+                                         ArrayRef<int> Mask) {
+  int NumElts = VT.getVectorNumElements();
+  assert(VT.getScalarType() == MVT::f64 &&
+         (NumElts == 2 || NumElts == 4 || NumElts == 8) &&
+         "Unexpected data type for VSHUFPD");
+
   // Mask for V8F64: 0/1,  8/9,  2/3,  10/11, 4/5, ..
   // Mask for V4F64; 0/1,  4/5,  2/3,  6/7..
-  assert(VT.getScalarSizeInBits() == 64 && "Unexpected data type for VSHUFPD");
-  int NumElts = VT.getVectorNumElements();
+  ShuffleImm = 0;
   bool ShufpdMask = true;
   bool CommutableMask = true;
-  unsigned Immediate = 0;
   for (int i = 0; i < NumElts; ++i) {
-    if (Mask[i] < 0)
+    if (Mask[i] == SM_SentinelUndef)
       continue;
+    if (Mask[i] < 0)
+      return false;
     int Val = (i & 6) + NumElts * (i & 1);
-    int CommutVal = (i & 0xe) + NumElts * ((i & 1)^1);
-    if (Mask[i] < Val || Mask[i] > Val + 1)
+    int CommutVal = (i & 0xe) + NumElts * ((i & 1) ^ 1);
+    if (Mask[i] < Val || Mask[i] > Val + 1)
       ShufpdMask = false;
-    if (Mask[i] < CommutVal || Mask[i] > CommutVal + 1)
+    if (Mask[i] < CommutVal || Mask[i] > CommutVal + 1)
       CommutableMask = false;
-    Immediate |= (Mask[i] % 2) << i;
+    ShuffleImm |= (Mask[i] % 2) << i;
   }
+
   if (ShufpdMask)
-    return DAG.getNode(X86ISD::SHUFP, DL, VT, V1, V2,
-                       DAG.getConstant(Immediate, DL, MVT::i8));
-  if (CommutableMask)
-    return DAG.getNode(X86ISD::SHUFP, DL, VT, V2, V1,
-                       DAG.getConstant(Immediate, DL, MVT::i8));
-  return SDValue();
+    return true;
+  if (CommutableMask) {
+    std::swap(V1, V2);
+    return true;
+  }
+
+  return false;
+}
+
+static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT,
+                                            ArrayRef<int> Mask, SDValue V1,
+                                            SDValue V2, SelectionDAG &DAG) {
+  unsigned Immediate = 0;
+  if (!matchVectorShuffleWithSHUFPD(VT, V1, V2, Immediate, Mask))
+    return SDValue();
+
+  return DAG.getNode(X86ISD::SHUFP, DL, VT, V1, V2,
+                     DAG.getConstant(Immediate, DL, MVT::i8));
+}
+
+static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT,
+                                           ArrayRef<int> Mask, SDValue V1,
+                                           SDValue V2, SelectionDAG &DAG) {
+  MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits());
+  MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements());
+
+  SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true);
+  if (V2.isUndef())
+    return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1);
+
+  return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2);
+}
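The SHUFPD immediate assembled by matchVectorShuffleWithSHUFPD is one low/high selection bit per f64 lane of the per-lane source pair. A worked example with an illustrative v4f64 mask:

#include <cassert>

int main() {
  // SHUFPD/VSHUFPD imm8: bit i picks the odd (1) or even (0) element of the
  // source pair feeding result lane i. For v4f64 the valid mask [1, 5, 2, 7]
  // (V1/V2 alternate within each 128-bit lane) encodes as:
  int Mask[4] = {1, 5, 2, 7};
  unsigned Imm = 0;
  for (int i = 0; i < 4; ++i)
    Imm |= (Mask[i] % 2) << i;
  assert(Imm == 0xb);            // 0b1011
  (void)Imm;
}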
 
 /// \brief Handle lowering of 4-lane 64-bit floating point shuffles.
@@ -11214,6 +12153,7 @@ static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT,
 /// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2
 /// isn't available.
 static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -11221,11 +12161,9 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   assert(V2.getSimpleValueType() == MVT::v4f64 && "Bad operand type!");
   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
 
-  SmallVector<int, 4> WidenedMask;
-  if (canWidenShuffleElements(Mask, WidenedMask))
-    if (SDValue V = lowerV2X128VectorShuffle(DL, MVT::v4f64, V1, V2, Mask,
-                                             Subtarget, DAG))
-      return V;
+  if (SDValue V = lowerV2X128VectorShuffle(DL, MVT::v4f64, V1, V2, Mask,
+                                           Zeroable, Subtarget, DAG))
+    return V;
 
   if (V2.isUndef()) {
     // Check for being able to broadcast a single element.
@@ -11268,7 +12206,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
       return V;
 
   if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v4f64, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Blend;
 
   // Check if the blend happens to exactly fit that of SHUFPD.
@@ -11280,7 +12218,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // the results into the target lanes.
   if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute(
           DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG))
-      return V;
+    return V;
 
   // Try to simplify this by merging 128-bit lanes to enable a lane-based
   // shuffle. However, if we have AVX2 and either inputs are already in place,
@@ -11291,6 +12229,11 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
     if (SDValue Result = lowerVectorShuffleByMerging128BitLanes(
             DL, MVT::v4f64, V1, V2, Mask, Subtarget, DAG))
       return Result;
+  // If we have VLX support, we can use VEXPAND.
+  if (Subtarget.hasVLX())
+    if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v4f64, Zeroable, Mask,
+                                               V1, V2, DAG, Subtarget))
+      return V;
 
   // If we have AVX2 then we always want to lower with a blend because an v4 we
   // can fully permute the elements.
@@ -11307,6 +12250,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 /// This routine is only called when we have AVX2 and thus a reasonable
 /// instruction set for v4i64 shuffling..
 static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -11315,14 +12259,12 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   assert(Mask.size() == 4 && "Unexpected mask size for v4 shuffle!");
   assert(Subtarget.hasAVX2() && "We can only lower v4i64 with AVX2!");
 
-  SmallVector<int, 4> WidenedMask;
-  if (canWidenShuffleElements(Mask, WidenedMask))
-    if (SDValue V = lowerV2X128VectorShuffle(DL, MVT::v4i64, V1, V2, Mask,
-                                             Subtarget, DAG))
-      return V;
+  if (SDValue V = lowerV2X128VectorShuffle(DL, MVT::v4i64, V1, V2, Mask,
+                                           Zeroable, Subtarget, DAG))
+    return V;
 
   if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v4i64, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Blend;
 
   // Check for being able to broadcast a single element.
@@ -11352,9 +12294,25 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 
   // Try to use shift instructions.
   if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v4i64, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Shift;
 
+  // If we have VLX support, we can use VALIGN or VEXPAND.
+  if (Subtarget.hasVLX()) {
+    if (SDValue Rotate = lowerVectorShuffleAsRotate(DL, MVT::v4i64, V1, V2,
+                                                    Mask, Subtarget, DAG))
+      return Rotate;
+
+    if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v4i64, Zeroable, Mask,
+                                               V1, V2, DAG, Subtarget))
+      return V;
+  }
+
+  // Try to use PALIGNR.
+  if (SDValue Rotate = lowerVectorShuffleAsByteRotate(DL, MVT::v4i64, V1, V2,
+                                                      Mask, Subtarget, DAG))
+    return Rotate;
+
   // Use dedicated unpack instructions for masks that match their pattern.
   if (SDValue V =
           lowerVectorShuffleWithUNPCK(DL, MVT::v4i64, Mask, V1, V2, DAG))
@@ -11364,8 +12322,8 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // shuffle. However, if we have AVX2 and either inputs are already in place,
   // we will be able to shuffle even across lanes the other input in a single
   // instruction so skip this pattern.
-  if (!(Subtarget.hasAVX2() && (isShuffleMaskInputInPlace(0, Mask) ||
-                                isShuffleMaskInputInPlace(1, Mask))))
+  if (!isShuffleMaskInputInPlace(0, Mask) &&
+      !isShuffleMaskInputInPlace(1, Mask))
     if (SDValue Result = lowerVectorShuffleByMerging128BitLanes(
             DL, MVT::v4i64, V1, V2, Mask, Subtarget, DAG))
       return Result;
@@ -11380,6 +12338,7 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 /// Also ends up handling lowering of 8-lane 32-bit integer shuffles when AVX2
 /// isn't available.
 static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -11388,7 +12347,7 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
 
   if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v8f32, V1, V2, Mask,
-                                                Subtarget, DAG))
+                                                Zeroable, Subtarget, DAG))
     return Blend;
 
   // Check for being able to broadcast a single element.
@@ -11432,17 +12391,12 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // If we have a single input shuffle with different shuffle patterns in the
   // two 128-bit lanes use the variable mask to VPERMILPS.
   if (V2.isUndef()) {
-    SDValue VPermMask[8];
-    for (int i = 0; i < 8; ++i)
-      VPermMask[i] = Mask[i] < 0 ? DAG.getUNDEF(MVT::i32)
-                                 : DAG.getConstant(Mask[i], DL, MVT::i32);
+    SDValue VPermMask = getConstVector(Mask, MVT::v8i32, DAG, DL, true);
     if (!is128BitLaneCrossingShuffleMask(MVT::v8f32, Mask))
-      return DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, V1,
-                         DAG.getBuildVector(MVT::v8i32, DL, VPermMask));
+      return DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v8f32, V1, VPermMask);
 
     if (Subtarget.hasAVX2())
-      return DAG.getNode(X86ISD::VPERMV, DL, MVT::v8f32,
-                         DAG.getBuildVector(MVT::v8i32, DL, VPermMask), V1);
+      return DAG.getNode(X86ISD::VPERMV, DL, MVT::v8f32, VPermMask, V1);
 
     // Otherwise, fall back.
     return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v8f32, V1, V2, Mask,
@@ -11454,6 +12408,11 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   if (SDValue Result = lowerVectorShuffleByMerging128BitLanes(
           DL, MVT::v8f32, V1, V2, Mask, Subtarget, DAG))
     return Result;
+  // If we have VLX support, we can use VEXPAND.
+  if (Subtarget.hasVLX())
+    if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v8f32, Zeroable, Mask,
+                                               V1, V2, DAG, Subtarget))
+      return V;
 
   // If we have AVX2 then we always want to lower with a blend because at v8 we
   // can fully permute the elements.
@@ -11470,6 +12429,7 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
 /// This routine is only called when we have AVX2 and thus a reasonable
 /// instruction set for v8i32 shuffling..
 static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                       const SmallBitVector &Zeroable,
                                        SDValue V1, SDValue V2,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG) {
@@ -11481,12 +12441,12 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   // Whenever we can lower this as a zext, that instruction is strictly faster
   // than any alternative. It also allows us to fold memory operands into the
   // shuffle in many cases.
- if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend(DL, MVT::v8i32, V1, V2, - Mask, Subtarget, DAG)) + if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend( + DL, MVT::v8i32, V1, V2, Mask, Zeroable, Subtarget, DAG)) return ZExt; if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v8i32, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Blend; // Check for being able to broadcast a single element. @@ -11498,7 +12458,9 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // efficient instructions that mirror the shuffles across the two 128-bit // lanes. SmallVector<int, 4> RepeatedMask; - if (is128BitLaneRepeatedShuffleMask(MVT::v8i32, Mask, RepeatedMask)) { + bool Is128BitLaneRepeatedShuffle = + is128BitLaneRepeatedShuffleMask(MVT::v8i32, Mask, RepeatedMask); + if (Is128BitLaneRepeatedShuffle) { assert(RepeatedMask.size() == 4 && "Unexpected repeated mask size!"); if (V2.isUndef()) return DAG.getNode(X86ISD::PSHUFD, DL, MVT::v8i32, V1, @@ -11512,16 +12474,27 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v8i32, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Shift; + // If we have VLX support, we can use VALIGN or EXPAND. + if (Subtarget.hasVLX()) { + if (SDValue Rotate = lowerVectorShuffleAsRotate(DL, MVT::v8i32, V1, V2, + Mask, Subtarget, DAG)) + return Rotate; + + if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v8i32, Zeroable, Mask, + V1, V2, DAG, Subtarget)) + return V; + } + // Try to use byte rotation instructions. if (SDValue Rotate = lowerVectorShuffleAsByteRotate( DL, MVT::v8i32, V1, V2, Mask, Subtarget, DAG)) return Rotate; // Try to create an in-lane repeating shuffle mask and then shuffle the - // the results into the target lanes. + // results into the target lanes. if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( DL, MVT::v8i32, V1, V2, Mask, Subtarget, DAG)) return V; @@ -11529,12 +12502,19 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // If the shuffle patterns aren't repeated but it is a single input, directly // generate a cross-lane VPERMD instruction. if (V2.isUndef()) { - SDValue VPermMask[8]; - for (int i = 0; i < 8; ++i) - VPermMask[i] = Mask[i] < 0 ? DAG.getUNDEF(MVT::i32) - : DAG.getConstant(Mask[i], DL, MVT::i32); - return DAG.getNode(X86ISD::VPERMV, DL, MVT::v8i32, - DAG.getBuildVector(MVT::v8i32, DL, VPermMask), V1); + SDValue VPermMask = getConstVector(Mask, MVT::v8i32, DAG, DL, true); + return DAG.getNode(X86ISD::VPERMV, DL, MVT::v8i32, VPermMask, V1); + } + + // Assume that a single SHUFPS is faster than an alternative sequence of + // multiple instructions (even if the CPU has a domain penalty). + // If some CPU is harmed by the domain switch, we can fix it in a later pass. + if (Is128BitLaneRepeatedShuffle && isSingleSHUFPSMask(RepeatedMask)) { + SDValue CastV1 = DAG.getBitcast(MVT::v8f32, V1); + SDValue CastV2 = DAG.getBitcast(MVT::v8f32, V2); + SDValue ShufPS = lowerVectorShuffleWithSHUFPS(DL, MVT::v8f32, RepeatedMask, + CastV1, CastV2, DAG); + return DAG.getBitcast(MVT::v8i32, ShufPS); } // Try to simplify this by merging 128-bit lanes to enable a lane-based @@ -11553,6 +12533,7 @@ static SDValue lowerV8I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, /// This routine is only called when we have AVX2 and thus a reasonable /// instruction set for v16i16 shuffling.. 
static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -11564,8 +12545,8 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Whenever we can lower this as a zext, that instruction is strictly faster // than any alternative. It also allows us to fold memory operands into the // shuffle in many cases. - if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend(DL, MVT::v16i16, V1, V2, - Mask, Subtarget, DAG)) + if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend( + DL, MVT::v16i16, V1, V2, Mask, Zeroable, Subtarget, DAG)) return ZExt; // Check for being able to broadcast a single element. @@ -11574,7 +12555,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return Broadcast; if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v16i16, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Blend; // Use dedicated unpack instructions for masks that match their pattern. @@ -11584,7 +12565,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v16i16, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Shift; // Try to use byte rotation instructions. @@ -11615,10 +12596,14 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, } } - if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB(DL, MVT::v16i16, Mask, V1, - V2, Subtarget, DAG)) + if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB( + DL, MVT::v16i16, Mask, V1, V2, Zeroable, Subtarget, DAG)) return PSHUFB; + // AVX512BWVL can lower to VPERMW. + if (Subtarget.hasBWI() && Subtarget.hasVLX()) + return lowerVectorShuffleWithPERMV(DL, MVT::v16i16, Mask, V1, V2, DAG); + // Try to simplify this by merging 128-bit lanes to enable a lane-based // shuffle. if (SDValue Result = lowerVectorShuffleByMerging128BitLanes( @@ -11634,6 +12619,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, /// This routine is only called when we have AVX2 and thus a reasonable /// instruction set for v32i8 shuffling.. static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -11645,8 +12631,8 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Whenever we can lower this as a zext, that instruction is strictly faster // than any alternative. It also allows us to fold memory operands into the // shuffle in many cases. - if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend(DL, MVT::v32i8, V1, V2, - Mask, Subtarget, DAG)) + if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend( + DL, MVT::v32i8, V1, V2, Mask, Zeroable, Subtarget, DAG)) return ZExt; // Check for being able to broadcast a single element. @@ -11655,7 +12641,7 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return Broadcast; if (SDValue Blend = lowerVectorShuffleAsBlend(DL, MVT::v32i8, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Blend; // Use dedicated unpack instructions for masks that match their pattern. @@ -11665,7 +12651,7 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use shift instructions. 
if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v32i8, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Shift; // Try to use byte rotation instructions. @@ -11685,8 +12671,8 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v32i8, V1, V2, Mask, DAG); - if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB(DL, MVT::v32i8, Mask, V1, - V2, Subtarget, DAG)) + if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB( + DL, MVT::v32i8, Mask, V1, V2, Zeroable, Subtarget, DAG)) return PSHUFB; // Try to simplify this by merging 128-bit lanes to enable a lane-based @@ -11706,6 +12692,7 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, /// together based on the available instructions. static SDValue lower256BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT, SDValue V1, SDValue V2, + const SmallBitVector &Zeroable, const X86Subtarget &Subtarget, SelectionDAG &DAG) { // If we have a single input to the zero element, insert that into V1 if we @@ -11715,7 +12702,7 @@ static SDValue lower256BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, if (NumV2Elements == 1 && Mask[0] >= NumElts) if (SDValue Insertion = lowerVectorShuffleAsElementInsertion( - DL, VT, V1, V2, Mask, Subtarget, DAG)) + DL, VT, V1, V2, Mask, Zeroable, Subtarget, DAG)) return Insertion; // Handle special cases where the lower or upper half is UNDEF. @@ -11734,7 +12721,8 @@ static SDValue lower256BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, if (ElementBits < 32) { // No floating point type available, if we can't use the bit operations // for masking/blending then decompose into 128-bit vectors. - if (SDValue V = lowerVectorShuffleAsBitMask(DL, VT, V1, V2, Mask, DAG)) + if (SDValue V = + lowerVectorShuffleAsBitMask(DL, VT, V1, V2, Mask, Zeroable, DAG)) return V; if (SDValue V = lowerVectorShuffleAsBitBlend(DL, VT, V1, V2, Mask, DAG)) return V; @@ -11750,17 +12738,17 @@ static SDValue lower256BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, switch (VT.SimpleTy) { case MVT::v4f64: - return lowerV4F64VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV4F64VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v4i64: - return lowerV4I64VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV4I64VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v8f32: - return lowerV8F32VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV8F32VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v8i32: - return lowerV8I32VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV8I32VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v16i16: - return lowerV16I16VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV16I16VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v32i8: - return lowerV32I8VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV32I8VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); default: llvm_unreachable("Not a valid 256-bit x86 vector type!"); @@ -11782,57 +12770,81 @@ static SDValue lowerV4X128VectorShuffle(const SDLoc &DL, MVT VT, if (!canWidenShuffleElements(Mask, WidenedMask)) return SDValue(); + // Check for patterns which can be matched with a single insert of a 256-bit + // subvector. 
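The equivalence checks that follow treat undef mask entries as wildcards; a standalone sketch of that core test (the real isShuffleEquivalent also canonicalizes its operands, which is omitted here):

#include <cstddef>
#include <vector>

// A candidate mask matches the expected pattern if every defined entry
// agrees; undef entries (< 0) match anything.
static bool masksEquivalent(const std::vector<int> &Mask,
                            const std::vector<int> &Expected) {
  if (Mask.size() != Expected.size())
    return false;
  for (std::size_t i = 0; i < Mask.size(); ++i)
    if (Mask[i] >= 0 && Mask[i] != Expected[i])
      return false;
  return true;
}
// {0, -1, 2, 3, -1, 1, 2, 3} matches {0, 1, 2, 3, 0, 1, 2, 3}: the shuffle
// only uses V1, so one 256-bit concat of V1's low half suffices.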
+ bool OnlyUsesV1 = isShuffleEquivalent(V1, V2, Mask, + {0, 1, 2, 3, 0, 1, 2, 3}); + if (OnlyUsesV1 || isShuffleEquivalent(V1, V2, Mask, + {0, 1, 2, 3, 8, 9, 10, 11})) { + MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 4); + SDValue LoV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V1, + DAG.getIntPtrConstant(0, DL)); + SDValue HiV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, + OnlyUsesV1 ? V1 : V2, + DAG.getIntPtrConstant(0, DL)); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LoV, HiV); + } + + assert(WidenedMask.size() == 4); + + // See if this is an insertion of the lower 128 bits of V2 into V1. + bool IsInsert = true; + int V2Index = -1; + for (int i = 0; i < 4; ++i) { + assert(WidenedMask[i] >= -1); + if (WidenedMask[i] < 0) + continue; + + // Make sure all V1 subvectors are in place. + if (WidenedMask[i] < 4) { + if (WidenedMask[i] != i) { + IsInsert = false; + break; + } + } else { + // Make sure we only have a single V2 index and it's the lowest 128 bits. + if (V2Index >= 0 || WidenedMask[i] != 4) { + IsInsert = false; + break; + } + V2Index = i; + } + } + if (IsInsert && V2Index >= 0) { + MVT SubVT = MVT::getVectorVT(VT.getVectorElementType(), 2); + SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, V2, + DAG.getIntPtrConstant(0, DL)); + return insert128BitVector(V1, Subvec, V2Index * 2, DAG, DL); + } + + // Try to lower to vshuf64x2/vshuf32x4. SDValue Ops[2] = {DAG.getUNDEF(VT), DAG.getUNDEF(VT)}; + unsigned PermMask = 0; // Ensure elements came from the same Op. - int MaxOp1Index = VT.getVectorNumElements()/2 - 1; - for (int i = 0, Size = WidenedMask.size(); i < Size; ++i) { - if (WidenedMask[i] == SM_SentinelZero) - return SDValue(); - if (WidenedMask[i] == SM_SentinelUndef) + for (int i = 0; i < 4; ++i) { + assert(WidenedMask[i] >= -1); + if (WidenedMask[i] < 0) continue; - SDValue Op = WidenedMask[i] > MaxOp1Index ? V2 : V1; - unsigned OpIndex = (i < Size/2) ? 0 : 1; + SDValue Op = WidenedMask[i] >= 4 ? V2 : V1; + unsigned OpIndex = i / 2; if (Ops[OpIndex].isUndef()) Ops[OpIndex] = Op; else if (Ops[OpIndex] != Op) return SDValue(); - } - - // Form a 128-bit permutation. - // Convert the 64-bit shuffle mask selection values into 128-bit selection - // bits defined by a vshuf64x2 instruction's immediate control byte. - unsigned PermMask = 0, Imm = 0; - unsigned ControlBitsNum = WidenedMask.size() / 2; - for (int i = 0, Size = WidenedMask.size(); i < Size; ++i) { - // Use first element in place of undef mask. - Imm = (WidenedMask[i] == SM_SentinelUndef) ? 0 : WidenedMask[i]; - PermMask |= (Imm % WidenedMask.size()) << (i * ControlBitsNum); + // Convert the 128-bit shuffle mask selection values into 128-bit selection + // bits defined by a vshuf64x2 instruction's immediate control byte.
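The immediate assembled below packs two selection bits per 128-bit destination lane; a standalone model with a worked value (the helper name is hypothetical):

#include <cassert>
#include <cstdint>

// Destination lane i contributes (WidenedMask[i] % 4) at bit position 2*i;
// undef lanes contribute zero bits.
static uint8_t shuf128Immediate(const int WidenedMask[4]) {
  uint8_t Imm = 0;
  for (int i = 0; i < 4; ++i) {
    assert(WidenedMask[i] >= -1 && WidenedMask[i] < 8);
    if (WidenedMask[i] >= 0)
      Imm |= (WidenedMask[i] % 4) << (i * 2);
  }
  return Imm;
}
// {0, 1, 4, 5} (low half of V1, then low half of V2) encodes as
// 0b01000100 = 0x44.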
+ PermMask |= (WidenedMask[i] % 4) << (i * 2); } return DAG.getNode(X86ISD::SHUF128, DL, VT, Ops[0], Ops[1], DAG.getConstant(PermMask, DL, MVT::i8)); } -static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT, - ArrayRef<int> Mask, SDValue V1, - SDValue V2, SelectionDAG &DAG) { - - assert(VT.getScalarSizeInBits() >= 16 && "Unexpected data type for PERMV"); - - MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits()); - MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements()); - - SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true); - if (V2.isUndef()) - return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1); - - return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2); -} - /// \brief Handle lowering of 8-lane 64-bit floating point shuffles. static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -11875,11 +12887,16 @@ static SDValue lowerV8F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, lowerVectorShuffleWithSHUFPD(DL, MVT::v8f64, Mask, V1, V2, DAG)) return Op; + if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v8f64, Zeroable, Mask, V1, + V2, DAG, Subtarget)) + return V; + return lowerVectorShuffleWithPERMV(DL, MVT::v8f64, Mask, V1, V2, DAG); } /// \brief Handle lowering of 16-lane 32-bit floating point shuffles. static SDValue lowerV16F32VectorShuffle(SDLoc DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -11911,12 +12928,17 @@ static SDValue lowerV16F32VectorShuffle(SDLoc DL, ArrayRef<int> Mask, // Otherwise, fall back to a SHUFPS sequence. return lowerVectorShuffleWithSHUFPS(DL, MVT::v16f32, RepeatedMask, V1, V2, DAG); } + // If we have AVX512F support, we can use VEXPAND. + if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v16f32, Zeroable, Mask, + V1, V2, DAG, Subtarget)) + return V; return lowerVectorShuffleWithPERMV(DL, MVT::v16f32, Mask, V1, V2, DAG); } /// \brief Handle lowering of 8-lane 64-bit integer shuffles. static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -11951,18 +12973,33 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v8i64, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Shift; + // Try to use VALIGN. + if (SDValue Rotate = lowerVectorShuffleAsRotate(DL, MVT::v8i64, V1, V2, + Mask, Subtarget, DAG)) + return Rotate; + + // Try to use PALIGNR. + if (SDValue Rotate = lowerVectorShuffleAsByteRotate(DL, MVT::v8i64, V1, V2, + Mask, Subtarget, DAG)) + return Rotate; + if (SDValue Unpck = lowerVectorShuffleWithUNPCK(DL, MVT::v8i64, Mask, V1, V2, DAG)) return Unpck; + // If we have AVX512F support, we can use VEXPAND. + if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v8i64, Zeroable, Mask, V1, + V2, DAG, Subtarget)) + return V; return lowerVectorShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, DAG); } /// \brief Handle lowering of 16-lane 32-bit integer shuffles. 
static SDValue lowerV16I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -11970,11 +13007,20 @@ static SDValue lowerV16I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, assert(V2.getSimpleValueType() == MVT::v16i32 && "Bad operand type!"); assert(Mask.size() == 16 && "Unexpected mask size for v16 shuffle!"); + // Whenever we can lower this as a zext, that instruction is strictly faster + // than any alternative. It also allows us to fold memory operands into the + // shuffle in many cases. + if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend( + DL, MVT::v16i32, V1, V2, Mask, Zeroable, Subtarget, DAG)) + return ZExt; + // If the shuffle mask is repeated in each 128-bit lane we can use more // efficient instructions that mirror the shuffles across the four 128-bit // lanes. SmallVector<int, 4> RepeatedMask; - if (is128BitLaneRepeatedShuffleMask(MVT::v16i32, Mask, RepeatedMask)) { + bool Is128BitLaneRepeatedShuffle = + is128BitLaneRepeatedShuffleMask(MVT::v16i32, Mask, RepeatedMask); + if (Is128BitLaneRepeatedShuffle) { assert(RepeatedMask.size() == 4 && "Unexpected repeated mask size!"); if (V2.isUndef()) return DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32, V1, @@ -11988,20 +13034,40 @@ static SDValue lowerV16I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v16i32, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Shift; + // Try to use VALIGN. + if (SDValue Rotate = lowerVectorShuffleAsRotate(DL, MVT::v16i32, V1, V2, + Mask, Subtarget, DAG)) + return Rotate; + // Try to use byte rotation instructions. if (Subtarget.hasBWI()) if (SDValue Rotate = lowerVectorShuffleAsByteRotate( DL, MVT::v16i32, V1, V2, Mask, Subtarget, DAG)) return Rotate; + // Assume that a single SHUFPS is faster than using a permv shuffle. + // If some CPU is harmed by the domain switch, we can fix it in a later pass. + if (Is128BitLaneRepeatedShuffle && isSingleSHUFPSMask(RepeatedMask)) { + SDValue CastV1 = DAG.getBitcast(MVT::v16f32, V1); + SDValue CastV2 = DAG.getBitcast(MVT::v16f32, V2); + SDValue ShufPS = lowerVectorShuffleWithSHUFPS(DL, MVT::v16f32, RepeatedMask, + CastV1, CastV2, DAG); + return DAG.getBitcast(MVT::v16i32, ShufPS); + } + // If we have AVX512F support, we can use VEXPAND. + if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v16i32, Zeroable, Mask, + V1, V2, DAG, Subtarget)) + return V; + return lowerVectorShuffleWithPERMV(DL, MVT::v16i32, Mask, V1, V2, DAG); } /// \brief Handle lowering of 32-lane 16-bit integer shuffles. static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -12010,6 +13076,13 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, assert(Mask.size() == 32 && "Unexpected mask size for v32 shuffle!"); assert(Subtarget.hasBWI() && "We can only lower v32i16 with AVX-512-BWI!"); + // Whenever we can lower this as a zext, that instruction is strictly faster + // than any alternative. It also allows us to fold memory operands into the + // shuffle in many cases. + if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend( + DL, MVT::v32i16, V1, V2, Mask, Zeroable, Subtarget, DAG)) + return ZExt; + // Use dedicated unpack instructions for masks that match their pattern. 
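As an aside, the UNPCK pattern matched here is fully determined by the element count; a standalone sketch of the low-half interleave (128-bit shape; wider vectors repeat it within each 128-bit lane):

#include <vector>

// UNPCKL interleaves the low halves of V1 (elements 0..N/2-1) and V2
// (elements N..N+N/2-1).
static std::vector<int> unpackLoMask(int NumElts) {
  std::vector<int> Mask;
  for (int i = 0; i < NumElts / 2; ++i) {
    Mask.push_back(i);           // from V1
    Mask.push_back(NumElts + i); // from V2
  }
  return Mask;
}
// unpackLoMask(8) == {0, 8, 1, 9, 2, 10, 3, 11}, the PUNPCKLWD shape.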
if (SDValue V = lowerVectorShuffleWithUNPCK(DL, MVT::v32i16, Mask, V1, V2, DAG)) @@ -12017,7 +13090,7 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v32i16, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Shift; // Try to use byte rotation instructions. @@ -12041,6 +13114,7 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, /// \brief Handle lowering of 64-lane 8-bit integer shuffles. static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, + const SmallBitVector &Zeroable, SDValue V1, SDValue V2, const X86Subtarget &Subtarget, SelectionDAG &DAG) { @@ -12049,6 +13123,13 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, assert(Mask.size() == 64 && "Unexpected mask size for v64 shuffle!"); assert(Subtarget.hasBWI() && "We can only lower v64i8 with AVX-512-BWI!"); + // Whenever we can lower this as a zext, that instruction is strictly faster + // than any alternative. It also allows us to fold memory operands into the + // shuffle in many cases. + if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend( + DL, MVT::v64i8, V1, V2, Mask, Zeroable, Subtarget, DAG)) + return ZExt; + // Use dedicated unpack instructions for masks that match their pattern. if (SDValue V = lowerVectorShuffleWithUNPCK(DL, MVT::v64i8, Mask, V1, V2, DAG)) @@ -12056,7 +13137,7 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v64i8, V1, V2, Mask, - Subtarget, DAG)) + Zeroable, Subtarget, DAG)) return Shift; // Try to use byte rotation instructions. @@ -12064,10 +13145,20 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, MVT::v64i8, V1, V2, Mask, Subtarget, DAG)) return Rotate; - if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB(DL, MVT::v64i8, Mask, V1, - V2, Subtarget, DAG)) + if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB( + DL, MVT::v64i8, Mask, V1, V2, Zeroable, Subtarget, DAG)) return PSHUFB; + // VBMI can use VPERMV/VPERMV3 byte shuffles. + if (Subtarget.hasVBMI()) + return lowerVectorShuffleWithPERMV(DL, MVT::v64i8, Mask, V1, V2, DAG); + + // Try to create an in-lane repeating shuffle mask and then shuffle the + // results into the target lanes. + if (SDValue V = lowerShuffleAsRepeatedMaskAndLanePermute( + DL, MVT::v64i8, V1, V2, Mask, Subtarget, DAG)) + return V; + // FIXME: Implement direct support for this type! return splitAndLowerVectorShuffle(DL, MVT::v64i8, V1, V2, Mask, DAG); } @@ -12079,11 +13170,22 @@ static SDValue lowerV64I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, /// together based on the available instructions. static SDValue lower512BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT, SDValue V1, SDValue V2, + const SmallBitVector &Zeroable, const X86Subtarget &Subtarget, SelectionDAG &DAG) { assert(Subtarget.hasAVX512() && "Cannot lower 512-bit vectors w/ basic ISA!"); + // If we have a single input to the zero element, insert that into V1 if we + // can do so cheaply.
+ int NumElts = Mask.size(); + int NumV2Elements = count_if(Mask, [NumElts](int M) { return M >= NumElts; }); + + if (NumV2Elements == 1 && Mask[0] >= NumElts) + if (SDValue Insertion = lowerVectorShuffleAsElementInsertion( + DL, VT, V1, V2, Mask, Zeroable, Subtarget, DAG)) + return Insertion; + // Check for being able to broadcast a single element. if (SDValue Broadcast = lowerVectorShuffleAsBroadcast(DL, VT, V1, V2, Mask, Subtarget, DAG)) @@ -12095,17 +13197,17 @@ static SDValue lower512BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // the requisite ISA extensions for that element type are available. switch (VT.SimpleTy) { case MVT::v8f64: - return lowerV8F64VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV8F64VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v16f32: - return lowerV16F32VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV16F32VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v8i64: - return lowerV8I64VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV8I64VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v16i32: - return lowerV16I32VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV16I32VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v32i16: - return lowerV32I16VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV32I16VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); case MVT::v64i8: - return lowerV64I8VectorShuffle(DL, Mask, V1, V2, Subtarget, DAG); + return lowerV64I8VectorShuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG); default: llvm_unreachable("Not a valid 512-bit x86 vector type!"); @@ -12161,9 +13263,81 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, V2 = getOnesVector(ExtVT, Subtarget, DAG, DL); else V2 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, V2); - return DAG.getNode(ISD::TRUNCATE, DL, VT, - DAG.getVectorShuffle(ExtVT, DL, V1, V2, Mask)); + + SDValue Shuffle = DAG.getVectorShuffle(ExtVT, DL, V1, V2, Mask); + // i1 was sign extended we can use X86ISD::CVT2MASK. + int NumElems = VT.getVectorNumElements(); + if ((Subtarget.hasBWI() && (NumElems >= 32)) || + (Subtarget.hasDQI() && (NumElems < 32))) + return DAG.getNode(X86ISD::CVT2MASK, DL, VT, Shuffle); + + return DAG.getNode(ISD::TRUNCATE, DL, VT, Shuffle); } + +/// Helper function that returns true if the shuffle mask should be +/// commuted to improve canonicalization. +static bool canonicalizeShuffleMaskWithCommute(ArrayRef<int> Mask) { + int NumElements = Mask.size(); + + int NumV1Elements = 0, NumV2Elements = 0; + for (int M : Mask) + if (M < 0) + continue; + else if (M < NumElements) + ++NumV1Elements; + else + ++NumV2Elements; + + // Commute the shuffle as needed such that more elements come from V1 than + // V2. This allows us to match the shuffle pattern strictly on how many + // elements come from V1 without handling the symmetric cases. + if (NumV2Elements > NumV1Elements) + return true; + + assert(NumV1Elements > 0 && "No V1 indices"); + + if (NumV2Elements == 0) + return false; + + // When the number of V1 and V2 elements are the same, try to minimize the + // number of uses of V2 in the low half of the vector. When that is tied, + // ensure that the sum of indices for V1 is equal to or lower than the sum + // indices for V2. When those are equal, try to ensure that the number of odd + // indices for V1 is lower than the number of odd indices for V2. 
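A worked example of the first two tiers of this heuristic, as a standalone sketch; the later sum-of-indices and odd-index tie-breakers are deliberately omitted:

#include <vector>

// Commute when V2 supplies more elements, or, on a tie, when V2 dominates
// the low half of the mask.
static bool preferCommute(const std::vector<int> &Mask) {
  int N = static_cast<int>(Mask.size());
  int NumV1 = 0, NumV2 = 0, LowV1 = 0, LowV2 = 0;
  for (int i = 0; i < N; ++i) {
    if (Mask[i] < 0)
      continue;
    (Mask[i] < N ? NumV1 : NumV2)++;
    if (i < N / 2)
      (Mask[i] < N ? LowV1 : LowV2)++;
  }
  if (NumV2 != NumV1)
    return NumV2 > NumV1;
  return LowV2 > LowV1; // remaining tie-breakers omitted
}
// preferCommute({4, 5, 2, 3}) is true: two elements come from each input,
// but the low half {4, 5} is all V2, so swapping the operands canonicalizes.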
+ if (NumV1Elements == NumV2Elements) { + int LowV1Elements = 0, LowV2Elements = 0; + for (int M : Mask.slice(0, NumElements / 2)) + if (M >= NumElements) + ++LowV2Elements; + else if (M >= 0) + ++LowV1Elements; + if (LowV2Elements > LowV1Elements) + return true; + if (LowV2Elements == LowV1Elements) { + int SumV1Indices = 0, SumV2Indices = 0; + for (int i = 0, Size = Mask.size(); i < Size; ++i) + if (Mask[i] >= NumElements) + SumV2Indices += i; + else if (Mask[i] >= 0) + SumV1Indices += i; + if (SumV2Indices < SumV1Indices) + return true; + if (SumV2Indices == SumV1Indices) { + int NumV1OddIndices = 0, NumV2OddIndices = 0; + for (int i = 0, Size = Mask.size(); i < Size; ++i) + if (Mask[i] >= NumElements) + NumV2OddIndices += i % 2; + else if (Mask[i] >= 0) + NumV1OddIndices += i % 2; + if (NumV2OddIndices < NumV1OddIndices) + return true; + } + } + } + + return false; +} + /// \brief Top-level lowering for x86 vector shuffles. /// /// This handles decomposition, canonicalization, and lowering of all x86 @@ -12209,6 +13383,12 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget, return DAG.getVectorShuffle(VT, DL, V1, V2, NewMask); } + // Check for illegal shuffle mask element index values. + int MaskUpperLimit = Mask.size() * (V2IsUndef ? 1 : 2); (void)MaskUpperLimit; + assert(llvm::all_of(Mask, + [&](int M) { return -1 <= M && M < MaskUpperLimit; }) && + "Out of bounds shuffle index"); + // We actually see shuffles that are entirely re-arrangements of a set of // zero inputs. This mostly happens while decomposing complex shuffles into // simple ones. Directly lower these as a buildvector of zeros. @@ -12237,69 +13417,22 @@ static SDValue lowerVectorShuffle(SDValue Op, const X86Subtarget &Subtarget, } } - int NumV1Elements = 0, NumUndefElements = 0, NumV2Elements = 0; - for (int M : Mask) - if (M < 0) - ++NumUndefElements; - else if (M < NumElements) - ++NumV1Elements; - else - ++NumV2Elements; - - // Commute the shuffle as needed such that more elements come from V1 than - // V2. This allows us to match the shuffle pattern strictly on how many - // elements come from V1 without handling the symmetric cases. - if (NumV2Elements > NumV1Elements) + // Commute the shuffle if it will improve canonicalization. + if (canonicalizeShuffleMaskWithCommute(Mask)) return DAG.getCommutedVectorShuffle(*SVOp); - assert(NumV1Elements > 0 && "No V1 indices"); - assert((NumV2Elements > 0 || V2IsUndef) && "V2 not undef, but not used"); - - // When the number of V1 and V2 elements are the same, try to minimize the - // number of uses of V2 in the low half of the vector. When that is tied, - // ensure that the sum of indices for V1 is equal to or lower than the sum - // indices for V2. When those are equal, try to ensure that the number of odd - // indices for V1 is lower than the number of odd indices for V2. 
- if (NumV1Elements == NumV2Elements) { - int LowV1Elements = 0, LowV2Elements = 0; - for (int M : Mask.slice(0, NumElements / 2)) - if (M >= NumElements) - ++LowV2Elements; - else if (M >= 0) - ++LowV1Elements; - if (LowV2Elements > LowV1Elements) - return DAG.getCommutedVectorShuffle(*SVOp); - if (LowV2Elements == LowV1Elements) { - int SumV1Indices = 0, SumV2Indices = 0; - for (int i = 0, Size = Mask.size(); i < Size; ++i) - if (Mask[i] >= NumElements) - SumV2Indices += i; - else if (Mask[i] >= 0) - SumV1Indices += i; - if (SumV2Indices < SumV1Indices) - return DAG.getCommutedVectorShuffle(*SVOp); - if (SumV2Indices == SumV1Indices) { - int NumV1OddIndices = 0, NumV2OddIndices = 0; - for (int i = 0, Size = Mask.size(); i < Size; ++i) - if (Mask[i] >= NumElements) - NumV2OddIndices += i % 2; - else if (Mask[i] >= 0) - NumV1OddIndices += i % 2; - if (NumV2OddIndices < NumV1OddIndices) - return DAG.getCommutedVectorShuffle(*SVOp); - } - } - } - // For each vector width, delegate to a specialized lowering routine. if (VT.is128BitVector()) - return lower128BitVectorShuffle(DL, Mask, VT, V1, V2, Subtarget, DAG); + return lower128BitVectorShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, + DAG); if (VT.is256BitVector()) - return lower256BitVectorShuffle(DL, Mask, VT, V1, V2, Subtarget, DAG); + return lower256BitVectorShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, + DAG); if (VT.is512BitVector()) - return lower512BitVectorShuffle(DL, Mask, VT, V1, V2, Subtarget, DAG); + return lower512BitVectorShuffle(DL, Mask, VT, V1, V2, Zeroable, Subtarget, + DAG); if (Is1BitVector) return lower1BitVectorShuffle(DL, Mask, VT, V1, V2, Subtarget, DAG); @@ -12392,21 +13525,6 @@ static SDValue LowerEXTRACT_VECTOR_ELT_SSE4(SDValue Op, SelectionDAG &DAG) { return DAG.getNode(ISD::TRUNCATE, dl, VT, Assert); } - if (VT.getSizeInBits() == 16) { - // If Idx is 0, it's cheaper to do a move instead of a pextrw. - if (isNullConstant(Op.getOperand(1))) - return DAG.getNode( - ISD::TRUNCATE, dl, MVT::i16, - DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, - DAG.getBitcast(MVT::v4i32, Op.getOperand(0)), - Op.getOperand(1))); - SDValue Extract = DAG.getNode(X86ISD::PEXTRW, dl, MVT::i32, - Op.getOperand(0), Op.getOperand(1)); - SDValue Assert = DAG.getNode(ISD::AssertZext, dl, MVT::i32, Extract, - DAG.getValueType(VT)); - return DAG.getNode(ISD::TRUNCATE, dl, VT, Assert); - } - if (VT == MVT::f32) { // EXTRACTPS outputs to a GPR32 register which will require a movd to copy // the result back to FR32 register. It's only worth matching if the @@ -12432,6 +13550,7 @@ static SDValue LowerEXTRACT_VECTOR_ELT_SSE4(SDValue Op, SelectionDAG &DAG) { if (isa<ConstantSDNode>(Op.getOperand(1))) return Op; } + return SDValue(); } @@ -12460,7 +13579,8 @@ X86TargetLowering::ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG) const } unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - if (!Subtarget.hasDQI() && (VecVT.getVectorNumElements() <= 8)) { + if ((!Subtarget.hasDQI() && (VecVT.getVectorNumElements() == 8)) || + (VecVT.getVectorNumElements() < 8)) { // Use kshiftlw/rw instruction. 
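The kshift pair used below has a direct scalar analogue; a minimal model on a 16-bit mask value (illustrative only):

#include <cassert>
#include <cstdint>

// Move bit IdxVal to the MSB with a left shift, then logically shift it down
// to bit 0 - the same dance VSHLI/VSRLI perform on the k-register.
static uint16_t extractMaskBit(uint16_t Mask, unsigned IdxVal) {
  assert(IdxVal < 16);
  const unsigned MaxShift = 15; // NumElements - 1 for v16i1
  uint16_t V = static_cast<uint16_t>(Mask << (MaxShift - IdxVal));
  return static_cast<uint16_t>(V >> MaxShift);
}
// extractMaskBit(0b0100, 2) == 1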
VecVT = MVT::v16i1; Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT, @@ -12469,8 +13589,9 @@ X86TargetLowering::ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG) const DAG.getIntPtrConstant(0, dl)); } unsigned MaxSift = VecVT.getVectorNumElements() - 1; - Vec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, Vec, - DAG.getConstant(MaxSift - IdxVal, dl, MVT::i8)); + if (MaxSift - IdxVal) + Vec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, Vec, + DAG.getConstant(MaxSift - IdxVal, dl, MVT::i8)); Vec = DAG.getNode(X86ISD::VSRLI, dl, VecVT, Vec, DAG.getConstant(MaxSift, dl, MVT::i8)); return DAG.getNode(X86ISD::VEXTRACT, dl, MVT::i1, Vec, @@ -12491,10 +13612,10 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, if (!isa<ConstantSDNode>(Idx)) { if (VecVT.is512BitVector() || (VecVT.is256BitVector() && Subtarget.hasInt256() && - VecVT.getVectorElementType().getSizeInBits() == 32)) { + VecVT.getScalarSizeInBits() == 32)) { MVT MaskEltVT = - MVT::getIntegerVT(VecVT.getVectorElementType().getSizeInBits()); + MVT::getIntegerVT(VecVT.getScalarSizeInBits()); MVT MaskVT = MVT::getVectorVT(MaskEltVT, VecVT.getSizeInBits() / MaskEltVT.getSizeInBits()); @@ -12531,26 +13652,31 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, assert(VecVT.is128BitVector() && "Unexpected vector length"); - if (Subtarget.hasSSE41()) - if (SDValue Res = LowerEXTRACT_VECTOR_ELT_SSE4(Op, DAG)) - return Res; - MVT VT = Op.getSimpleValueType(); - // TODO: handle v16i8. + if (VT.getSizeInBits() == 16) { - if (IdxVal == 0) + // If IdxVal is 0, it's cheaper to do a move instead of a pextrw, unless + // we're going to zero extend the register or fold the store (SSE41 only). + if (IdxVal == 0 && !MayFoldIntoZeroExtend(Op) && + !(Subtarget.hasSSE41() && MayFoldIntoStore(Op))) return DAG.getNode(ISD::TRUNCATE, dl, MVT::i16, DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, DAG.getBitcast(MVT::v4i32, Vec), Idx)); // Transform it so it matches pextrw which produces a 32-bit result. - MVT EltVT = MVT::i32; - SDValue Extract = DAG.getNode(X86ISD::PEXTRW, dl, EltVT, Vec, Idx); - SDValue Assert = DAG.getNode(ISD::AssertZext, dl, EltVT, Extract, + SDValue Extract = DAG.getNode(X86ISD::PEXTRW, dl, MVT::i32, + Op.getOperand(0), Op.getOperand(1)); + SDValue Assert = DAG.getNode(ISD::AssertZext, dl, MVT::i32, Extract, DAG.getValueType(VT)); return DAG.getNode(ISD::TRUNCATE, dl, VT, Assert); } + if (Subtarget.hasSSE41()) + if (SDValue Res = LowerEXTRACT_VECTOR_ELT_SSE4(Op, DAG)) + return Res; + + // TODO: handle v16i8. + if (VT.getSizeInBits() == 32) { if (IdxVal == 0) return Op; @@ -12604,12 +13730,46 @@ X86TargetLowering::InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG) const { unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Elt); - if (IdxVal) + unsigned NumElems = VecVT.getVectorNumElements(); + + if (Vec.isUndef()) { + if (IdxVal) + EltInVec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, EltInVec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + return EltInVec; + } + + // Insertion of one bit into first or last position + // can be done with two SHIFTs + OR. + if (IdxVal == 0) { + // EltInVec already at correct index and other bits are 0. + // Clean the first bit in source vector.
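The IdxVal == 0 path below is the k-register form of a simple scalar trick; a sketch (illustrative):

#include <cstdint>

// Clear bit 0 of the source mask with a logical right/left shift pair, then
// OR in the new element - mirroring the VSRLI/VSHLI/OR sequence emitted below.
static uint16_t insertMaskBit0(uint16_t Vec, bool Bit) {
  Vec = static_cast<uint16_t>((Vec >> 1) << 1); // clear bit 0
  return static_cast<uint16_t>(Vec | Bit);      // OR in the element
}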
+ Vec = DAG.getNode(X86ISD::VSRLI, dl, VecVT, Vec, + DAG.getConstant(1, dl, MVT::i8)); + Vec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, Vec, + DAG.getConstant(1, dl, MVT::i8)); + + return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec); + } + if (IdxVal == NumElems - 1) { + // Move the bit to the last position inside the vector. EltInVec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, EltInVec, DAG.getConstant(IdxVal, dl, MVT::i8)); - if (Vec.isUndef()) - return EltInVec; - return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec); + // Clean the last bit in the source vector. + Vec = DAG.getNode(X86ISD::VSHLI, dl, VecVT, Vec, + DAG.getConstant(1, dl, MVT::i8)); + Vec = DAG.getNode(X86ISD::VSRLI, dl, VecVT, Vec, + DAG.getConstant(1, dl, MVT::i8)); + + return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec); + } + + // Use shuffle to insert element. + SmallVector<int, 64> MaskVec(NumElems); + for (unsigned i = 0; i != NumElems; ++i) + MaskVec[i] = (i == IdxVal) ? NumElems : i; + + return DAG.getVectorShuffle(VecVT, dl, Vec, EltInVec, MaskVec); } SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, @@ -12764,10 +13924,6 @@ static SDValue LowerSCALAR_TO_VECTOR(SDValue Op, SelectionDAG &DAG) { return insert128BitVector(DAG.getUNDEF(OpVT), Op, 0, DAG, dl); } - if (OpVT == MVT::v1i64 && - Op.getOperand(0).getValueType() == MVT::i64) - return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i64, Op.getOperand(0)); - SDValue AnyExt = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Op.getOperand(0)); assert(OpVT.is128BitVector() && "Expected an SSE type!"); return DAG.getBitcast( @@ -12779,25 +13935,32 @@ static SDValue LowerSCALAR_TO_VECTOR(SDValue Op, SelectionDAG &DAG) { // upper bits of a vector. static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + assert(Subtarget.hasAVX() && "EXTRACT_SUBVECTOR requires AVX"); + SDLoc dl(Op); SDValue In = Op.getOperand(0); SDValue Idx = Op.getOperand(1); unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - MVT ResVT = Op.getSimpleValueType(); - MVT InVT = In.getSimpleValueType(); + MVT ResVT = Op.getSimpleValueType(); - if (Subtarget.hasFp256()) { - if (ResVT.is128BitVector() && - (InVT.is256BitVector() || InVT.is512BitVector()) && - isa<ConstantSDNode>(Idx)) { - return extract128BitVector(In, IdxVal, DAG, dl); - } - if (ResVT.is256BitVector() && InVT.is512BitVector() && - isa<ConstantSDNode>(Idx)) { - return extract256BitVector(In, IdxVal, DAG, dl); - } - } - return SDValue(); + assert((In.getSimpleValueType().is256BitVector() || + In.getSimpleValueType().is512BitVector()) && + "Can only extract from 256-bit or 512-bit vectors"); + + if (ResVT.is128BitVector()) + return extract128BitVector(In, IdxVal, DAG, dl); + if (ResVT.is256BitVector()) + return extract256BitVector(In, IdxVal, DAG, dl); + + llvm_unreachable("Unimplemented!"); +} + +static bool areOnlyUsersOf(SDNode *N, ArrayRef<SDValue> ValidUsers) { + for (SDNode::use_iterator I = N->use_begin(), E = N->use_end(); I != E; ++I) + if (llvm::all_of(ValidUsers, + [&I](SDValue V) { return V.getNode() != *I; })) + return false; + return true; } // Lower a node with an INSERT_SUBVECTOR opcode. This may result in a @@ -12805,58 +13968,97 @@ static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, // the upper bits of a vector.
static SDValue LowerINSERT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - if (!Subtarget.hasAVX()) - return SDValue(); + assert(Subtarget.hasAVX() && "INSERT_SUBVECTOR requires AVX"); SDLoc dl(Op); SDValue Vec = Op.getOperand(0); SDValue SubVec = Op.getOperand(1); SDValue Idx = Op.getOperand(2); - if (!isa<ConstantSDNode>(Idx)) - return SDValue(); - unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); MVT OpVT = Op.getSimpleValueType(); MVT SubVecVT = SubVec.getSimpleValueType(); - // Fold two 16-byte subvector loads into one 32-byte load: - // (insert_subvector (insert_subvector undef, (load addr), 0), - // (load addr + 16), Elts/2) + if (OpVT.getVectorElementType() == MVT::i1) + return insert1BitVector(Op, DAG, Subtarget); + + assert((OpVT.is256BitVector() || OpVT.is512BitVector()) && + "Can only insert into 256-bit or 512-bit vectors"); + + // Fold two 16-byte or 32-byte subvector loads into one 32-byte or 64-byte + // load: + // (insert_subvector (insert_subvector undef, (load16 addr), 0), + // (load16 addr + 16), Elts/2) // --> load32 addr + // or: + // (insert_subvector (insert_subvector undef, (load32 addr), 0), + // (load32 addr + 32), Elts/2) + // --> load64 addr + // or a 16-byte or 32-byte broadcast: + // (insert_subvector (insert_subvector undef, (load16 addr), 0), + // (load16 addr), Elts/2) + // --> X86SubVBroadcast(load16 addr) + // or: + // (insert_subvector (insert_subvector undef, (load32 addr), 0), + // (load32 addr), Elts/2) + // --> X86SubVBroadcast(load32 addr) if ((IdxVal == OpVT.getVectorNumElements() / 2) && Vec.getOpcode() == ISD::INSERT_SUBVECTOR && - OpVT.is256BitVector() && SubVecVT.is128BitVector()) { + OpVT.getSizeInBits() == SubVecVT.getSizeInBits() * 2) { auto *Idx2 = dyn_cast<ConstantSDNode>(Vec.getOperand(2)); if (Idx2 && Idx2->getZExtValue() == 0) { + SDValue SubVec2 = Vec.getOperand(1); // If needed, look through bitcasts to get to the load. - SDValue SubVec2 = peekThroughBitcasts(Vec.getOperand(1)); - if (auto *FirstLd = dyn_cast<LoadSDNode>(SubVec2)) { + if (auto *FirstLd = dyn_cast<LoadSDNode>(peekThroughBitcasts(SubVec2))) { bool Fast; unsigned Alignment = FirstLd->getAlignment(); unsigned AS = FirstLd->getAddressSpace(); const X86TargetLowering *TLI = Subtarget.getTargetLowering(); if (TLI->allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), OpVT, AS, Alignment, &Fast) && Fast) { - SDValue Ops[] = { SubVec2, SubVec }; + SDValue Ops[] = {SubVec2, SubVec}; if (SDValue Ld = EltsFromConsecutiveLoads(OpVT, Ops, dl, DAG, false)) return Ld; } } + // If lower/upper loads are the same and the only users of the load, then + // lower to a VBROADCASTF128/VBROADCASTI128/etc. + if (auto *Ld = dyn_cast<LoadSDNode>(peekThroughOneUseBitcasts(SubVec2))) { + if (SubVec2 == SubVec && ISD::isNormalLoad(Ld) && + areOnlyUsersOf(SubVec2.getNode(), {Op, Vec})) { + return DAG.getNode(X86ISD::SUBV_BROADCAST, dl, OpVT, SubVec); + } + } + // If this is subv_broadcast insert into both halves, use a larger + // subv_broadcast. 
+ if (SubVec.getOpcode() == X86ISD::SUBV_BROADCAST && SubVec == SubVec2) { + return DAG.getNode(X86ISD::SUBV_BROADCAST, dl, OpVT, + SubVec.getOperand(0)); + } } } - if ((OpVT.is256BitVector() || OpVT.is512BitVector()) && - SubVecVT.is128BitVector()) + if (SubVecVT.is128BitVector()) return insert128BitVector(Vec, SubVec, IdxVal, DAG, dl); - if (OpVT.is512BitVector() && SubVecVT.is256BitVector()) + if (SubVecVT.is256BitVector()) return insert256BitVector(Vec, SubVec, IdxVal, DAG, dl); - if (OpVT.getVectorElementType() == MVT::i1) - return insert1BitVector(Op, DAG, Subtarget); + llvm_unreachable("Unimplemented!"); +} - return SDValue(); +// Returns the appropriate wrapper opcode for a global reference. +unsigned X86TargetLowering::getGlobalWrapperKind(const GlobalValue *GV) const { + // References to absolute symbols are never PC-relative. + if (GV && GV->isAbsoluteSymbolRef()) + return X86ISD::Wrapper; + + CodeModel::Model M = getTargetMachine().getCodeModel(); + if (Subtarget.isPICStyleRIPRel() && + (M == CodeModel::Small || M == CodeModel::Kernel)) + return X86ISD::WrapperRIP; + + return X86ISD::Wrapper; } // ConstantPool, JumpTable, GlobalAddress, and ExternalSymbol are lowered as @@ -12872,18 +14074,12 @@ X86TargetLowering::LowerConstantPool(SDValue Op, SelectionDAG &DAG) const { // In PIC mode (unless we're in RIPRel PIC mode) we add an offset to the // global base reg. unsigned char OpFlag = Subtarget.classifyLocalReference(nullptr); - unsigned WrapperKind = X86ISD::Wrapper; - CodeModel::Model M = DAG.getTarget().getCodeModel(); - - if (Subtarget.isPICStyleRIPRel() && - (M == CodeModel::Small || M == CodeModel::Kernel)) - WrapperKind = X86ISD::WrapperRIP; auto PtrVT = getPointerTy(DAG.getDataLayout()); SDValue Result = DAG.getTargetConstantPool( CP->getConstVal(), PtrVT, CP->getAlignment(), CP->getOffset(), OpFlag); SDLoc DL(CP); - Result = DAG.getNode(WrapperKind, DL, PtrVT, Result); + Result = DAG.getNode(getGlobalWrapperKind(), DL, PtrVT, Result); // With PIC, the address is actually $g + Offset. if (OpFlag) { Result = @@ -12900,17 +14096,11 @@ SDValue X86TargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const { // In PIC mode (unless we're in RIPRel PIC mode) we add an offset to the // global base reg. unsigned char OpFlag = Subtarget.classifyLocalReference(nullptr); - unsigned WrapperKind = X86ISD::Wrapper; - CodeModel::Model M = DAG.getTarget().getCodeModel(); - - if (Subtarget.isPICStyleRIPRel() && - (M == CodeModel::Small || M == CodeModel::Kernel)) - WrapperKind = X86ISD::WrapperRIP; auto PtrVT = getPointerTy(DAG.getDataLayout()); SDValue Result = DAG.getTargetJumpTable(JT->getIndex(), PtrVT, OpFlag); SDLoc DL(JT); - Result = DAG.getNode(WrapperKind, DL, PtrVT, Result); + Result = DAG.getNode(getGlobalWrapperKind(), DL, PtrVT, Result); // With PIC, the address is actually $g + Offset. if (OpFlag) @@ -12929,18 +14119,12 @@ X86TargetLowering::LowerExternalSymbol(SDValue Op, SelectionDAG &DAG) const { // global base reg. 
const Module *Mod = DAG.getMachineFunction().getFunction()->getParent(); unsigned char OpFlag = Subtarget.classifyGlobalReference(nullptr, *Mod); - unsigned WrapperKind = X86ISD::Wrapper; - CodeModel::Model M = DAG.getTarget().getCodeModel(); - - if (Subtarget.isPICStyleRIPRel() && - (M == CodeModel::Small || M == CodeModel::Kernel)) - WrapperKind = X86ISD::WrapperRIP; auto PtrVT = getPointerTy(DAG.getDataLayout()); SDValue Result = DAG.getTargetExternalSymbol(Sym, PtrVT, OpFlag); SDLoc DL(Op); - Result = DAG.getNode(WrapperKind, DL, PtrVT, Result); + Result = DAG.getNode(getGlobalWrapperKind(), DL, PtrVT, Result); // With PIC, the address is actually $g + Offset. if (isPositionIndependent() && !Subtarget.is64Bit()) { @@ -12963,18 +14147,12 @@ X86TargetLowering::LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const { // Create the TargetBlockAddressAddress node. unsigned char OpFlags = Subtarget.classifyBlockAddressReference(); - CodeModel::Model M = DAG.getTarget().getCodeModel(); const BlockAddress *BA = cast<BlockAddressSDNode>(Op)->getBlockAddress(); int64_t Offset = cast<BlockAddressSDNode>(Op)->getOffset(); SDLoc dl(Op); auto PtrVT = getPointerTy(DAG.getDataLayout()); SDValue Result = DAG.getTargetBlockAddress(BA, PtrVT, Offset, OpFlags); - - if (Subtarget.isPICStyleRIPRel() && - (M == CodeModel::Small || M == CodeModel::Kernel)) - Result = DAG.getNode(X86ISD::WrapperRIP, dl, PtrVT, Result); - else - Result = DAG.getNode(X86ISD::Wrapper, dl, PtrVT, Result); + Result = DAG.getNode(getGlobalWrapperKind(), dl, PtrVT, Result); // With PIC, the address is actually $g + Offset. if (isGlobalRelativeToPICBase(OpFlags)) { @@ -13003,11 +14181,7 @@ SDValue X86TargetLowering::LowerGlobalAddress(const GlobalValue *GV, Result = DAG.getTargetGlobalAddress(GV, dl, PtrVT, 0, OpFlags); } - if (Subtarget.isPICStyleRIPRel() && - (M == CodeModel::Small || M == CodeModel::Kernel)) - Result = DAG.getNode(X86ISD::WrapperRIP, dl, PtrVT, Result); - else - Result = DAG.getNode(X86ISD::Wrapper, dl, PtrVT, Result); + Result = DAG.getNode(getGlobalWrapperKind(GV), dl, PtrVT, Result); // With PIC, the address is actually $g + Offset. if (isGlobalRelativeToPICBase(OpFlags)) { @@ -13041,7 +14215,7 @@ static SDValue GetTLSADDR(SelectionDAG &DAG, SDValue Chain, GlobalAddressSDNode *GA, SDValue *InFlag, const EVT PtrVT, unsigned ReturnReg, unsigned char OperandFlags, bool LocalDynamic = false) { - MachineFrameInfo *MFI = DAG.getMachineFunction().getFrameInfo(); + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue); SDLoc dl(GA); SDValue TGA = DAG.getTargetGlobalAddress(GA->getGlobal(), dl, @@ -13061,8 +14235,8 @@ GetTLSADDR(SelectionDAG &DAG, SDValue Chain, GlobalAddressSDNode *GA, } // TLSADDR will be codegen'ed as call. Inform MFI that function has calls. - MFI->setAdjustsStack(true); - MFI->setHasCalls(true); + MFI.setAdjustsStack(true); + MFI.setHasCalls(true); SDValue Flag = Chain.getValue(1); return DAG.getCopyFromReg(Chain, dl, ReturnReg, PtrVT, Flag); @@ -13097,7 +14271,7 @@ static SDValue LowerToTLSLocalDynamicModel(GlobalAddressSDNode *GA, SDLoc dl(GA); // Get the start address of the TLS block for this module. 
- X86MachineFunctionInfo* MFI = DAG.getMachineFunction() + X86MachineFunctionInfo *MFI = DAG.getMachineFunction() .getInfo<X86MachineFunctionInfo>(); MFI->incNumLocalDynamicTLSAccesses(); @@ -13251,8 +14425,8 @@ X86TargetLowering::LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const { Chain.getValue(1), DL); // TLSCALL will be codegen'ed as call. Inform MFI that function has calls. - MachineFrameInfo *MFI = DAG.getMachineFunction().getFrameInfo(); - MFI->setAdjustsStack(true); + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); + MFI.setAdjustsStack(true); // And our return value (tls address) is in the standard call return value // location. @@ -13395,9 +14569,9 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op, const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (SrcVT.isVector()) { if (SrcVT == MVT::v2i32 && VT == MVT::v2f64) { - return DAG.getNode(X86ISD::CVTDQ2PD, dl, VT, + return DAG.getNode(X86ISD::CVTSI2P, dl, VT, DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src, - DAG.getUNDEF(SrcVT))); + DAG.getUNDEF(SrcVT))); } if (SrcVT.getVectorElementType() == MVT::i1) { if (SrcVT == MVT::v2i1 && TLI.isTypeLegal(SrcVT)) @@ -13433,7 +14607,7 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op, unsigned Size = SrcVT.getSizeInBits()/8; MachineFunction &MF = DAG.getMachineFunction(); auto PtrVT = getPointerTy(MF.getDataLayout()); - int SSFI = MF.getFrameInfo()->CreateStackObject(Size, Size, false); + int SSFI = MF.getFrameInfo().CreateStackObject(Size, Size, false); SDValue StackSlot = DAG.getFrameIndex(SSFI, PtrVT); SDValue Chain = DAG.getStore( DAG.getEntryNode(), dl, ValueToStore, StackSlot, @@ -13479,8 +14653,8 @@ SDValue X86TargetLowering::BuildFILD(SDValue Op, EVT SrcVT, SDValue Chain, // shouldn't be necessary except that RFP cannot be live across // multiple blocks. When stackifier is fixed, they can be uncoupled. 
MachineFunction &MF = DAG.getMachineFunction(); - unsigned SSFISize = Op.getValueType().getSizeInBits()/8; - int SSFI = MF.getFrameInfo()->CreateStackObject(SSFISize, SSFISize, false); + unsigned SSFISize = Op.getValueSizeInBits()/8; + int SSFI = MF.getFrameInfo().CreateStackObject(SSFISize, SSFISize, false); auto PtrVT = getPointerTy(MF.getDataLayout()); SDValue StackSlot = DAG.getFrameIndex(SSFI, PtrVT); Tys = DAG.getVTList(MVT::Other); @@ -13528,10 +14702,10 @@ SDValue X86TargetLowering::LowerUINT_TO_FP_i64(SDValue Op, SmallVector<Constant*,2> CV1; CV1.push_back( - ConstantFP::get(*Context, APFloat(APFloat::IEEEdouble, + ConstantFP::get(*Context, APFloat(APFloat::IEEEdouble(), APInt(64, 0x4330000000000000ULL)))); CV1.push_back( - ConstantFP::get(*Context, APFloat(APFloat::IEEEdouble, + ConstantFP::get(*Context, APFloat(APFloat::IEEEdouble(), APInt(64, 0x4530000000000000ULL)))); Constant *C1 = ConstantVector::get(CV1); SDValue CPIdx1 = DAG.getConstantPool(C1, PtrVT, 16); @@ -13560,8 +14734,7 @@ SDValue X86TargetLowering::LowerUINT_TO_FP_i64(SDValue Op, Result = DAG.getNode(X86ISD::FHADD, dl, MVT::v2f64, Sub, Sub); } else { SDValue S2F = DAG.getBitcast(MVT::v4i32, Sub); - SDValue Shuffle = getTargetShuffleNode(X86ISD::PSHUFD, dl, MVT::v4i32, - S2F, 0x4E, DAG); + SDValue Shuffle = DAG.getVectorShuffle(MVT::v4i32, dl, S2F, S2F, {2,3,0,1}); Result = DAG.getNode(ISD::FADD, dl, MVT::v2f64, DAG.getBitcast(MVT::v2f64, Shuffle), Sub); } @@ -13617,6 +14790,41 @@ SDValue X86TargetLowering::LowerUINT_TO_FP_i32(SDValue Op, return Sub; } +static SDValue lowerUINT_TO_FP_v2i32(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget, SDLoc &DL) { + if (Op.getSimpleValueType() != MVT::v2f64) + return SDValue(); + + SDValue N0 = Op.getOperand(0); + assert(N0.getSimpleValueType() == MVT::v2i32 && "Unexpected input type"); + + // Legalize to v4i32 type. + N0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4i32, N0, + DAG.getUNDEF(MVT::v2i32)); + + if (Subtarget.hasAVX512()) + return DAG.getNode(X86ISD::CVTUI2P, DL, MVT::v2f64, N0); + + // Same implementation as VectorLegalizer::ExpandUINT_TO_FLOAT, + // but using v2i32 to v2f64 with X86ISD::CVTSI2P. + SDValue HalfWord = DAG.getConstant(16, DL, MVT::v4i32); + SDValue HalfWordMask = DAG.getConstant(0x0000FFFF, DL, MVT::v4i32); + + // Two to the power of half-word-size. + SDValue TWOHW = DAG.getConstantFP(1 << 16, DL, MVT::v2f64); + + // Clear upper part of LO, lower HI. + SDValue HI = DAG.getNode(ISD::SRL, DL, MVT::v4i32, N0, HalfWord); + SDValue LO = DAG.getNode(ISD::AND, DL, MVT::v4i32, N0, HalfWordMask); + + SDValue fHI = DAG.getNode(X86ISD::CVTSI2P, DL, MVT::v2f64, HI); + fHI = DAG.getNode(ISD::FMUL, DL, MVT::v2f64, fHI, TWOHW); + SDValue fLO = DAG.getNode(X86ISD::CVTSI2P, DL, MVT::v2f64, LO); + + // Add the two halves. + return DAG.getNode(ISD::FADD, DL, MVT::v2f64, fHI, fLO); +} + static SDValue lowerUINT_TO_FP_vXi32(SDValue Op, SelectionDAG &DAG, const X86Subtarget &Subtarget) { // The algorithm is the following: @@ -13699,7 +14907,7 @@ static SDValue lowerUINT_TO_FP_vXi32(SDValue Op, SelectionDAG &DAG, // Create the vector constant for -(0x1.0p39f + 0x1.0p23f). 
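The magic constant created below can be sanity-checked with plain IEEE-754 arithmetic; a standalone check (relies on C++17 hex-float literals):

#include <cstdint>
#include <cstring>

// 0xD3000080: sign 1, exponent 166 (unbiased 39), mantissa 2^-16, i.e.
// -(1 + 2^-16) * 2^39 = -(2^39 + 2^23) = -(0x1.0p39f + 0x1.0p23f).
static bool checkFAddConstant() {
  uint32_t Bits = 0xD3000080;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F == -(0x1.0p39f + 0x1.0p23f);
}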
SDValue VecCstFAdd = DAG.getConstantFP( - APFloat(APFloat::IEEEsingle, APInt(32, 0xD3000080)), DL, VecFloatVT); + APFloat(APFloat::IEEEsingle(), APInt(32, 0xD3000080)), DL, VecFloatVT); // float4 fhi = (float4) hi - (0x1.0p39f + 0x1.0p23f); SDValue HighBitcast = DAG.getBitcast(VecFloatVT, High); @@ -13714,29 +14922,31 @@ static SDValue lowerUINT_TO_FP_vXi32(SDValue Op, SelectionDAG &DAG, SDValue X86TargetLowering::lowerUINT_TO_FP_vec(SDValue Op, SelectionDAG &DAG) const { SDValue N0 = Op.getOperand(0); - MVT SVT = N0.getSimpleValueType(); + MVT SrcVT = N0.getSimpleValueType(); SDLoc dl(Op); - if (SVT.getVectorElementType() == MVT::i1) { - if (SVT == MVT::v2i1) + if (SrcVT.getVectorElementType() == MVT::i1) { + if (SrcVT == MVT::v2i1) return DAG.getNode(ISD::UINT_TO_FP, dl, Op.getValueType(), DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v2i64, N0)); - MVT IntegerVT = MVT::getVectorVT(MVT::i32, SVT.getVectorNumElements()); + MVT IntegerVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements()); return DAG.getNode(ISD::UINT_TO_FP, dl, Op.getValueType(), DAG.getNode(ISD::ZERO_EXTEND, dl, IntegerVT, N0)); } - switch (SVT.SimpleTy) { + switch (SrcVT.SimpleTy) { default: llvm_unreachable("Custom UINT_TO_FP is not supported!"); case MVT::v4i8: case MVT::v4i16: case MVT::v8i8: case MVT::v8i16: { - MVT NVT = MVT::getVectorVT(MVT::i32, SVT.getVectorNumElements()); + MVT NVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements()); return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(), DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, N0)); } + case MVT::v2i32: + return lowerUINT_TO_FP_v2i32(Op, DAG, Subtarget, dl); case MVT::v4i32: case MVT::v8i32: return lowerUINT_TO_FP_vXi32(Op, DAG, Subtarget); @@ -13754,15 +14964,15 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op, SDLoc dl(Op); auto PtrVT = getPointerTy(DAG.getDataLayout()); - if (Op.getSimpleValueType().isVector()) - return lowerUINT_TO_FP_vec(Op, DAG); - // Since UINT_TO_FP is legal (it's marked custom), dag combiner won't // optimize it to a SINT_TO_FP when the sign bit is known zero. Perform // the optimization here. if (DAG.SignBitIsZero(N0)) return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(), N0); + if (Op.getSimpleValueType().isVector()) + return lowerUINT_TO_FP_vec(Op, DAG); + MVT SrcVT = N0.getSimpleValueType(); MVT DstVT = Op.getSimpleValueType(); @@ -13903,7 +15113,7 @@ X86TargetLowering::FP_TO_INTHelper(SDValue Op, SelectionDAG &DAG, // stack slot. MachineFunction &MF = DAG.getMachineFunction(); unsigned MemSize = DstTy.getSizeInBits()/8; - int SSFI = MF.getFrameInfo()->CreateStackObject(MemSize, MemSize, false); + int SSFI = MF.getFrameInfo().CreateStackObject(MemSize, MemSize, false); SDValue StackSlot = DAG.getFrameIndex(SSFI, PtrVT); unsigned Opc; @@ -13935,15 +15145,15 @@ X86TargetLowering::FP_TO_INTHelper(SDValue Op, SelectionDAG &DAG, // For X87 we'd like to use the smallest FP type for this constant, but // for DAG type consistency we have to match the FP operand type. - APFloat Thresh(APFloat::IEEEsingle, APInt(32, 0x5f000000)); + APFloat Thresh(APFloat::IEEEsingle(), APInt(32, 0x5f000000)); LLVM_ATTRIBUTE_UNUSED APFloat::opStatus Status = APFloat::opOK; bool LosesInfo = false; if (TheVT == MVT::f64) // The rounding mode is irrelevant as the conversion should be exact. 
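For the same reason, the threshold constant used here decodes to exactly 2^63, the first value that no longer fits a signed fist result; a standalone check (C++17 hex-float literal):

#include <cstdint>
#include <cstring>

// 0x5f000000: sign 0, exponent 190 (unbiased 63), mantissa 0 - exactly 2^63.
// Inputs at or above this threshold get the sign-bit adjustment first.
static bool checkFPToUIntThreshold() {
  uint32_t Bits = 0x5f000000;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F == 0x1.0p63f;
}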
- Status = Thresh.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, + Status = Thresh.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &LosesInfo); else if (TheVT == MVT::f80) - Status = Thresh.convert(APFloat::x87DoubleExtended, + Status = Thresh.convert(APFloat::x87DoubleExtended(), APFloat::rmNearestTiesToEven, &LosesInfo); assert(Status == APFloat::opOK && !LosesInfo && @@ -13981,7 +15191,7 @@ X86TargetLowering::FP_TO_INTHelper(SDValue Op, SelectionDAG &DAG, MachineMemOperand::MOLoad, MemSize, MemSize); Value = DAG.getMemIntrinsicNode(X86ISD::FLD, DL, Tys, Ops, DstTy, MMO); Chain = Value.getValue(1); - SSFI = MF.getFrameInfo()->CreateStackObject(MemSize, MemSize, false); + SSFI = MF.getFrameInfo().CreateStackObject(MemSize, MemSize, false); StackSlot = DAG.getFrameIndex(SSFI, PtrVT); } @@ -14084,14 +15294,14 @@ static SDValue LowerZERO_EXTEND_AVX512(SDValue Op, SDValue In = Op->getOperand(0); MVT InVT = In.getSimpleValueType(); SDLoc DL(Op); - unsigned int NumElts = VT.getVectorNumElements(); - if (NumElts != 8 && NumElts != 16 && !Subtarget.hasBWI()) - return SDValue(); + unsigned NumElts = VT.getVectorNumElements(); - if (VT.is512BitVector() && InVT.getVectorElementType() != MVT::i1) + if (VT.is512BitVector() && InVT.getVectorElementType() != MVT::i1 && + (NumElts == 8 || NumElts == 16 || Subtarget.hasBWI())) return DAG.getNode(X86ISD::VZEXT, DL, VT, In); - assert(InVT.getVectorElementType() == MVT::i1); + if (InVT.getVectorElementType() != MVT::i1) + return SDValue(); // Extend VT if the target is 256 or 128bit vector and VLX is not supported. MVT ExtVT = VT; @@ -14137,6 +15347,85 @@ static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget, return SDValue(); } +/// Helper to recursively truncate vector elements in half with PACKSS. +/// It makes use of the fact that vector comparison results will be all-zeros +/// or all-ones to use (vXi8 PACKSS(vYi16, vYi16)) instead of matching types. +/// AVX2 (Int256) sub-targets require extra shuffling as the PACKSS operates +/// within each 128-bit lane. +static SDValue truncateVectorCompareWithPACKSS(EVT DstVT, SDValue In, + const SDLoc &DL, + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // Requires SSE2 but AVX512 has fast truncate. + if (!Subtarget.hasSSE2() || Subtarget.hasAVX512()) + return SDValue(); + + EVT SrcVT = In.getValueType(); + + // No truncation required, we might get here due to recursive calls. + if (SrcVT == DstVT) + return In; + + // We only support vector truncation to 128bits or greater from a + // 256bits or greater source. + if ((DstVT.getSizeInBits() % 128) != 0) + return SDValue(); + if ((SrcVT.getSizeInBits() % 256) != 0) + return SDValue(); + + unsigned NumElems = SrcVT.getVectorNumElements(); + assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation"); + assert(SrcVT.getSizeInBits() > DstVT.getSizeInBits() && "Illegal truncation"); + + EVT PackedSVT = + EVT::getIntegerVT(*DAG.getContext(), SrcVT.getScalarSizeInBits() / 2); + + // Extract lower/upper subvectors. + unsigned NumSubElts = NumElems / 2; + unsigned SrcSizeInBits = SrcVT.getSizeInBits(); + SDValue Lo = extractSubVector(In, 0 * NumSubElts, DAG, DL, SrcSizeInBits / 2); + SDValue Hi = extractSubVector(In, 1 * NumSubElts, DAG, DL, SrcSizeInBits / 2); + + // 256bit -> 128bit truncate - PACKSS lower/upper 128-bit subvectors. 
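Why PACKSS is lossless for comparison results deserves spelling out; a scalar model of one lane (illustrative):

#include <cstdint>

// Comparison lanes are either 0 or -1 (all ones). Signed saturation maps
// 0 -> 0 and -1 -> -1 when narrowing i16 -> i8, so all-zeros/all-ones masks
// survive every pack stage unchanged.
static int8_t packssLane(int16_t X) {
  if (X > INT8_MAX)
    return INT8_MAX;
  if (X < INT8_MIN)
    return INT8_MIN;
  return static_cast<int8_t>(X);
}
// packssLane(0) == 0 and packssLane(-1) == -1.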
+ if (SrcVT.is256BitVector()) { + Lo = DAG.getBitcast(MVT::v8i16, Lo); + Hi = DAG.getBitcast(MVT::v8i16, Hi); + SDValue Res = DAG.getNode(X86ISD::PACKSS, DL, MVT::v16i8, Lo, Hi); + return DAG.getBitcast(DstVT, Res); + } + + // AVX2: 512bit -> 256bit truncate - PACKSS lower/upper 256-bit subvectors. + // AVX2: 512bit -> 128bit truncate - PACKSS(PACKSS, PACKSS). + if (SrcVT.is512BitVector() && Subtarget.hasInt256()) { + Lo = DAG.getBitcast(MVT::v16i16, Lo); + Hi = DAG.getBitcast(MVT::v16i16, Hi); + SDValue Res = DAG.getNode(X86ISD::PACKSS, DL, MVT::v32i8, Lo, Hi); + + // 256-bit PACKSS(ARG0, ARG1) leaves us with ((LO0,LO1),(HI0,HI1)), + // so we need to shuffle to get ((LO0,HI0),(LO1,HI1)). + Res = DAG.getBitcast(MVT::v4i64, Res); + Res = DAG.getVectorShuffle(MVT::v4i64, DL, Res, Res, {0, 2, 1, 3}); + + if (DstVT.is256BitVector()) + return DAG.getBitcast(DstVT, Res); + + // If 512bit -> 128bit truncate another stage. + EVT PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems); + Res = DAG.getBitcast(PackedVT, Res); + return truncateVectorCompareWithPACKSS(DstVT, Res, DL, DAG, Subtarget); + } + + // Recursively pack lower/upper subvectors, concat result and pack again. + assert(SrcVT.getSizeInBits() >= 512 && "Expected 512-bit vector or greater"); + EVT PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems / 2); + Lo = truncateVectorCompareWithPACKSS(PackedVT, Lo, DL, DAG, Subtarget); + Hi = truncateVectorCompareWithPACKSS(PackedVT, Hi, DL, DAG, Subtarget); + + PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems); + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi); + return truncateVectorCompareWithPACKSS(DstVT, Res, DL, DAG, Subtarget); +} + static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG, const X86Subtarget &Subtarget) { @@ -14203,6 +15492,22 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { DAG.getNode(X86ISD::VSEXT, DL, MVT::v16i32, In)); return DAG.getNode(X86ISD::VTRUNC, DL, VT, In); } + + // Truncate with PACKSS if we are truncating a vector comparison result. + // TODO: We should be able to support other operations as long as we + // we are saturating+packing zero/all bits only. + auto IsPackableComparison = [](SDValue V) { + unsigned Opcode = V.getOpcode(); + return (Opcode == X86ISD::PCMPGT || Opcode == X86ISD::PCMPEQ || + Opcode == X86ISD::CMPP); + }; + + if (IsPackableComparison(In) || (In.getOpcode() == ISD::CONCAT_VECTORS && + all_of(In->ops(), IsPackableComparison))) { + if (SDValue V = truncateVectorCompareWithPACKSS(VT, In, DL, DAG, Subtarget)) + return V; + } + if ((VT == MVT::v4i32) && (InVT == MVT::v4i64)) { // On AVX2, v4i64 -> v4i32 becomes VPERMD. if (Subtarget.hasInt256()) { @@ -14299,30 +15604,31 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { DAG.getIntPtrConstant(0, DL)); } -SDValue X86TargetLowering::LowerFP_TO_SINT(SDValue Op, - SelectionDAG &DAG) const { - assert(!Op.getSimpleValueType().isVector()); +SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) const { + bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT; - std::pair<SDValue,SDValue> Vals = FP_TO_INTHelper(Op, DAG, - /*IsSigned=*/ true, /*IsReplace=*/ false); - SDValue FIST = Vals.first, StackSlot = Vals.second; - // If FP_TO_INTHelper failed, the node is actually supposed to be Legal. - if (!FIST.getNode()) - return Op; + MVT VT = Op.getSimpleValueType(); - if (StackSlot.getNode()) - // Load the result. 
- return DAG.getLoad(Op.getValueType(), SDLoc(Op), FIST, StackSlot, - MachinePointerInfo()); + if (VT.isVector()) { + assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!"); + SDValue Src = Op.getOperand(0); + SDLoc dl(Op); + if (VT == MVT::v2i64 && Src.getSimpleValueType() == MVT::v2f32) { + return DAG.getNode(IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI, + dl, VT, + DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, + DAG.getUNDEF(MVT::v2f32))); + } - // The node is the result. - return FIST; -} + return SDValue(); + } + + assert(!VT.isVector()); -SDValue X86TargetLowering::LowerFP_TO_UINT(SDValue Op, - SelectionDAG &DAG) const { std::pair<SDValue,SDValue> Vals = FP_TO_INTHelper(Op, DAG, - /*IsSigned=*/ false, /*IsReplace=*/ false); + IsSigned, /*IsReplace=*/ false); SDValue FIST = Vals.first, StackSlot = Vals.second; // If FP_TO_INTHelper failed, the node is actually supposed to be Legal. if (!FIST.getNode()) @@ -14330,8 +15636,7 @@ SDValue X86TargetLowering::LowerFP_TO_UINT(SDValue Op, if (StackSlot.getNode()) // Load the result. - return DAG.getLoad(Op.getValueType(), SDLoc(Op), FIST, StackSlot, - MachinePointerInfo()); + return DAG.getLoad(VT, SDLoc(Op), FIST, StackSlot, MachinePointerInfo()); // The node is the result. return FIST; @@ -14376,17 +15681,14 @@ static SDValue LowerFABSorFNEG(SDValue Op, SelectionDAG &DAG) { MVT LogicVT; MVT EltVT; - unsigned NumElts; if (VT.isVector()) { LogicVT = VT; EltVT = VT.getVectorElementType(); - NumElts = VT.getVectorNumElements(); } else if (IsF128) { // SSE instructions are used for optimized f128 logical operations. LogicVT = MVT::f128; EltVT = VT; - NumElts = 1; } else { // There are no scalar bitwise logical SSE/AVX instructions, so we // generate a 16-byte vector constant and logic op even for the scalar case. @@ -14394,22 +15696,16 @@ static SDValue LowerFABSorFNEG(SDValue Op, SelectionDAG &DAG) { // the logic op, so it can save (~4 bytes) on code size. LogicVT = (VT == MVT::f64) ? MVT::v2f64 : MVT::v4f32; EltVT = VT; - NumElts = (VT == MVT::f64) ? 2 : 4; } unsigned EltBits = EltVT.getSizeInBits(); - LLVMContext *Context = DAG.getContext(); // For FABS, mask is 0x7f...; for FNEG, mask is 0x80... APInt MaskElt = IsFABS ? APInt::getSignedMaxValue(EltBits) : APInt::getSignBit(EltBits); - Constant *C = ConstantInt::get(*Context, MaskElt); - C = ConstantVector::getSplat(NumElts, C); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - SDValue CPIdx = DAG.getConstantPool(C, TLI.getPointerTy(DAG.getDataLayout())); - unsigned Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlignment(); - SDValue Mask = DAG.getLoad( - LogicVT, dl, DAG.getEntryNode(), CPIdx, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), Alignment); + const fltSemantics &Sem = + EltVT == MVT::f64 ? APFloat::IEEEdouble() : + (IsF128 ? APFloat::IEEEquad() : APFloat::IEEEsingle()); + SDValue Mask = DAG.getConstantFP(APFloat(Sem, MaskElt), dl, LogicVT); SDValue Op0 = Op.getOperand(0); bool IsFNABS = !IsFABS && (Op0.getOpcode() == ISD::FABS); @@ -14429,92 +15725,73 @@ static SDValue LowerFABSorFNEG(SDValue Op, SelectionDAG &DAG) { } static SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) { - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - LLVMContext *Context = DAG.getContext(); - SDValue Op0 = Op.getOperand(0); - SDValue Op1 = Op.getOperand(1); + SDValue Mag = Op.getOperand(0); + SDValue Sign = Op.getOperand(1); SDLoc dl(Op); + + // If the sign operand is smaller, extend it first. 
MVT VT = Op.getSimpleValueType(); - MVT SrcVT = Op1.getSimpleValueType(); - bool IsF128 = (VT == MVT::f128); + if (Sign.getSimpleValueType().bitsLT(VT)) + Sign = DAG.getNode(ISD::FP_EXTEND, dl, VT, Sign); - // If second operand is smaller, extend it first. - if (SrcVT.bitsLT(VT)) { - Op1 = DAG.getNode(ISD::FP_EXTEND, dl, VT, Op1); - SrcVT = VT; - } // And if it is bigger, shrink it first. - if (SrcVT.bitsGT(VT)) { - Op1 = DAG.getNode(ISD::FP_ROUND, dl, VT, Op1, DAG.getIntPtrConstant(1, dl)); - SrcVT = VT; - } + if (Sign.getSimpleValueType().bitsGT(VT)) + Sign = DAG.getNode(ISD::FP_ROUND, dl, VT, Sign, DAG.getIntPtrConstant(1, dl)); // At this point the operands and the result should have the same // type, and that won't be f80 since that is not custom lowered. - assert((VT == MVT::f64 || VT == MVT::f32 || IsF128) && + bool IsF128 = (VT == MVT::f128); + assert((VT == MVT::f64 || VT == MVT::f32 || VT == MVT::f128 || + VT == MVT::v2f64 || VT == MVT::v4f64 || VT == MVT::v4f32 || + VT == MVT::v8f32 || VT == MVT::v8f64 || VT == MVT::v16f32) && "Unexpected type in LowerFCOPYSIGN"); + MVT EltVT = VT.getScalarType(); const fltSemantics &Sem = - VT == MVT::f64 ? APFloat::IEEEdouble : - (IsF128 ? APFloat::IEEEquad : APFloat::IEEEsingle); - const unsigned SizeInBits = VT.getSizeInBits(); + EltVT == MVT::f64 ? APFloat::IEEEdouble() + : (IsF128 ? APFloat::IEEEquad() : APFloat::IEEEsingle()); + + // Perform all scalar logic operations as 16-byte vectors because there are no + // scalar FP logic instructions in SSE. + // TODO: This isn't necessary. If we used scalar types, we might avoid some + // unnecessary splats, but we might miss load folding opportunities. Should + // this decision be based on OptimizeForSize? + bool IsFakeVector = !VT.isVector() && !IsF128; + MVT LogicVT = VT; + if (IsFakeVector) + LogicVT = (VT == MVT::f64) ? MVT::v2f64 : MVT::v4f32; - SmallVector<Constant *, 4> CV( - VT == MVT::f64 ? 2 : (IsF128 ? 1 : 4), - ConstantFP::get(*Context, APFloat(Sem, APInt(SizeInBits, 0)))); + // The mask constants are automatically splatted for vector types. + unsigned EltSizeInBits = VT.getScalarSizeInBits(); + SDValue SignMask = DAG.getConstantFP( + APFloat(Sem, APInt::getSignBit(EltSizeInBits)), dl, LogicVT); + SDValue MagMask = DAG.getConstantFP( + APFloat(Sem, ~APInt::getSignBit(EltSizeInBits)), dl, LogicVT); // First, clear all bits but the sign bit from the second operand (sign). - CV[0] = ConstantFP::get(*Context, - APFloat(Sem, APInt::getHighBitsSet(SizeInBits, 1))); - Constant *C = ConstantVector::get(CV); - auto PtrVT = TLI.getPointerTy(DAG.getDataLayout()); - SDValue CPIdx = DAG.getConstantPool(C, PtrVT, 16); - - // Perform all logic operations as 16-byte vectors because there are no - // scalar FP logic instructions in SSE. This allows load folding of the - // constants into the logic instructions. - MVT LogicVT = (VT == MVT::f64) ? MVT::v2f64 : (IsF128 ? MVT::f128 : MVT::v4f32); - SDValue Mask1 = - DAG.getLoad(LogicVT, dl, DAG.getEntryNode(), CPIdx, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), - /* Alignment = */ 16); - if (!IsF128) - Op1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LogicVT, Op1); - SDValue SignBit = DAG.getNode(X86ISD::FAND, dl, LogicVT, Op1, Mask1); + if (IsFakeVector) + Sign = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LogicVT, Sign); + SDValue SignBit = DAG.getNode(X86ISD::FAND, dl, LogicVT, Sign, SignMask); // Next, clear the sign bit from the first operand (magnitude). - // If it's a constant, we can clear it here. 
- if (ConstantFPSDNode *Op0CN = dyn_cast<ConstantFPSDNode>(Op0)) { + // TODO: If we had general constant folding for FP logic ops, this check + // wouldn't be necessary. + SDValue MagBits; + if (ConstantFPSDNode *Op0CN = dyn_cast<ConstantFPSDNode>(Mag)) { APFloat APF = Op0CN->getValueAPF(); - // If the magnitude is a positive zero, the sign bit alone is enough. - if (APF.isPosZero()) - return IsF128 ? SignBit : - DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SrcVT, SignBit, - DAG.getIntPtrConstant(0, dl)); APF.clearSign(); - CV[0] = ConstantFP::get(*Context, APF); + MagBits = DAG.getConstantFP(APF, dl, LogicVT); } else { - CV[0] = ConstantFP::get( - *Context, - APFloat(Sem, APInt::getLowBitsSet(SizeInBits, SizeInBits - 1))); - } - C = ConstantVector::get(CV); - CPIdx = DAG.getConstantPool(C, PtrVT, 16); - SDValue Val = - DAG.getLoad(LogicVT, dl, DAG.getEntryNode(), CPIdx, - MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), - /* Alignment = */ 16); - // If the magnitude operand wasn't a constant, we need to AND out the sign. - if (!isa<ConstantFPSDNode>(Op0)) { - if (!IsF128) - Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LogicVT, Op0); - Val = DAG.getNode(X86ISD::FAND, dl, LogicVT, Op0, Val); + // If the magnitude operand wasn't a constant, we need to AND out the sign. + if (IsFakeVector) + Mag = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LogicVT, Mag); + MagBits = DAG.getNode(X86ISD::FAND, dl, LogicVT, Mag, MagMask); } + // OR the magnitude value with the sign bit. - Val = DAG.getNode(X86ISD::FOR, dl, LogicVT, Val, SignBit); - return IsF128 ? Val : - DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, SrcVT, Val, - DAG.getIntPtrConstant(0, dl)); + SDValue Or = DAG.getNode(X86ISD::FOR, dl, LogicVT, MagBits, SignBit); + return !IsFakeVector ? Or : DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Or, + DAG.getIntPtrConstant(0, dl)); } static SDValue LowerFGETSIGN(SDValue Op, SelectionDAG &DAG) { @@ -14741,6 +16018,12 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, } } + // Sometimes flags can be set either with an AND or with an SRL/SHL + // instruction. SRL/SHL variant should be preferred for masks longer than this + // number of bits. + const int ShiftToAndMaxMaskWidth = 32; + const bool ZeroCheck = (X86CC == X86::COND_E || X86CC == X86::COND_NE); + // NOTICE: In the code below we use ArithOp to hold the arithmetic operation // which may be the result of a CAST. We use the variable 'Op', which is the // non-casted variable when we check for possible users. @@ -14764,7 +16047,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, goto default_case; if (ConstantSDNode *C = - dyn_cast<ConstantSDNode>(ArithOp.getNode()->getOperand(1))) { + dyn_cast<ConstantSDNode>(ArithOp.getOperand(1))) { // An add of one will be selected as an INC. if (C->isOne() && !Subtarget.slowIncDec()) { Opcode = X86ISD::INC; @@ -14789,7 +16072,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, // If we have a constant logical shift that's only used in a comparison // against zero turn it into an equivalent AND. This allows turning it into // a TEST instruction later. 
- if ((X86CC == X86::COND_E || X86CC == X86::COND_NE) && Op->hasOneUse() && + if (ZeroCheck && Op->hasOneUse() && isa<ConstantSDNode>(Op->getOperand(1)) && !hasNonFlagsUse(Op)) { EVT VT = Op.getValueType(); unsigned BitWidth = VT.getSizeInBits(); @@ -14799,7 +16082,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, APInt Mask = ArithOp.getOpcode() == ISD::SRL ? APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt) : APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt); - if (!Mask.isSignedIntN(32)) // Avoid large immediates. + if (!Mask.isSignedIntN(ShiftToAndMaxMaskWidth)) break; Op = DAG.getNode(ISD::AND, dl, VT, Op->getOperand(0), DAG.getConstant(Mask, dl, VT)); @@ -14808,20 +16091,61 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, case ISD::AND: // If the primary 'and' result isn't used, don't bother using X86ISD::AND, - // because a TEST instruction will be better. + // because a TEST instruction will be better. However, AND should be + // preferred if the instruction can be combined into ANDN. if (!hasNonFlagsUse(Op)) { SDValue Op0 = ArithOp->getOperand(0); SDValue Op1 = ArithOp->getOperand(1); EVT VT = ArithOp.getValueType(); bool isAndn = isBitwiseNot(Op0) || isBitwiseNot(Op1); bool isLegalAndnType = VT == MVT::i32 || VT == MVT::i64; + bool isProperAndn = isAndn && isLegalAndnType && Subtarget.hasBMI(); + + // If we cannot select an ANDN instruction, check if we can replace + // AND+IMM64 with a shift before giving up. This is possible for masks + // like 0xFF000000 or 0x00FFFFFF and if we care only about the zero flag. + if (!isProperAndn) { + if (!ZeroCheck) + break; + + assert(!isa<ConstantSDNode>(Op0) && "AND node isn't canonicalized"); + auto *CN = dyn_cast<ConstantSDNode>(Op1); + if (!CN) + break; + + const APInt &Mask = CN->getAPIntValue(); + if (Mask.isSignedIntN(ShiftToAndMaxMaskWidth)) + break; // Prefer TEST instruction. + + unsigned BitWidth = Mask.getBitWidth(); + unsigned LeadingOnes = Mask.countLeadingOnes(); + unsigned TrailingZeros = Mask.countTrailingZeros(); + + if (LeadingOnes + TrailingZeros == BitWidth) { + assert(TrailingZeros < VT.getSizeInBits() && + "Shift amount should be less than the type width"); + MVT ShTy = getScalarShiftAmountTy(DAG.getDataLayout(), VT); + SDValue ShAmt = DAG.getConstant(TrailingZeros, dl, ShTy); + Op = DAG.getNode(ISD::SRL, dl, VT, Op0, ShAmt); + break; + } + + unsigned LeadingZeros = Mask.countLeadingZeros(); + unsigned TrailingOnes = Mask.countTrailingOnes(); + + if (LeadingZeros + TrailingOnes == BitWidth) { + assert(LeadingZeros < VT.getSizeInBits() && + "Shift amount should be less than the type width"); + MVT ShTy = getScalarShiftAmountTy(DAG.getDataLayout(), VT); + SDValue ShAmt = DAG.getConstant(LeadingZeros, dl, ShTy); + Op = DAG.getNode(ISD::SHL, dl, VT, Op0, ShAmt); + break; + } - // But if we can combine this into an ANDN operation, then create an AND - // now and allow it to be pattern matched into an ANDN. 
- if (!Subtarget.hasBMI() || !isAndn || !isLegalAndnType) break; + } } - // FALL THROUGH + LLVM_FALLTHROUGH; case ISD::SUB: case ISD::OR: case ISD::XOR: @@ -14839,7 +16163,7 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, case ISD::XOR: Opcode = X86ISD::XOR; break; case ISD::AND: Opcode = X86ISD::AND; break; case ISD::OR: { - if (!NeedTruncation && (X86CC == X86::COND_E || X86CC == X86::COND_NE)) { + if (!NeedTruncation && ZeroCheck) { if (SDValue EFLAGS = LowerVectorAllZeroTest(Op, Subtarget, DAG)) return EFLAGS; } @@ -14968,14 +16292,27 @@ SDValue X86TargetLowering::ConvertCmpIfNecessary(SDValue Cmp, return DAG.getNode(X86ISD::SAHF, dl, MVT::i32, TruncSrl); } +/// Check if replacement of SQRT with RSQRT should be disabled. +bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + + // We never want to use both SQRT and RSQRT instructions for the same input. + if (DAG.getNodeIfExists(X86ISD::FRSQRT, DAG.getVTList(VT), Op)) + return false; + + if (VT.isVector()) + return Subtarget.hasFastVectorFSQRT(); + return Subtarget.hasFastScalarFSQRT(); +} + /// The minimum architected relative accuracy is 2^-12. We need one /// Newton-Raphson step to have a good float result (24 bits of precision). -SDValue X86TargetLowering::getRsqrtEstimate(SDValue Op, - DAGCombinerInfo &DCI, - unsigned &RefinementSteps, - bool &UseOneConstNR) const { +SDValue X86TargetLowering::getSqrtEstimate(SDValue Op, + SelectionDAG &DAG, int Enabled, + int &RefinementSteps, + bool &UseOneConstNR, + bool Reciprocal) const { EVT VT = Op.getValueType(); - const char *RecipOp; // SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps. // TODO: Add support for AVX512 (v16f32). @@ -14984,30 +16321,24 @@ SDValue X86TargetLowering::getRsqrtEstimate(SDValue Op, // instructions: convert to single, rsqrtss, convert back to double, refine // (3 steps = at least 13 insts). If an 'rsqrtsd' variant was added to the ISA // along with FMA, this could be a throughput win. - if (VT == MVT::f32 && Subtarget.hasSSE1()) - RecipOp = "sqrtf"; - else if ((VT == MVT::v4f32 && Subtarget.hasSSE1()) || - (VT == MVT::v8f32 && Subtarget.hasAVX())) - RecipOp = "vec-sqrtf"; - else - return SDValue(); - - TargetRecip Recips = DCI.DAG.getTarget().Options.Reciprocals; - if (!Recips.isEnabled(RecipOp)) - return SDValue(); + if ((VT == MVT::f32 && Subtarget.hasSSE1()) || + (VT == MVT::v4f32 && Subtarget.hasSSE1()) || + (VT == MVT::v8f32 && Subtarget.hasAVX())) { + if (RefinementSteps == ReciprocalEstimate::Unspecified) + RefinementSteps = 1; - RefinementSteps = Recips.getRefinementSteps(RecipOp); - UseOneConstNR = false; - return DCI.DAG.getNode(X86ISD::FRSQRT, SDLoc(Op), VT, Op); + UseOneConstNR = false; + return DAG.getNode(X86ISD::FRSQRT, SDLoc(Op), VT, Op); + } + return SDValue(); } /// The minimum architected relative accuracy is 2^-12. We need one /// Newton-Raphson step to have a good float result (24 bits of precision). -SDValue X86TargetLowering::getRecipEstimate(SDValue Op, - DAGCombinerInfo &DCI, - unsigned &RefinementSteps) const { +SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG, + int Enabled, + int &RefinementSteps) const { EVT VT = Op.getValueType(); - const char *RecipOp; // SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps. // TODO: Add support for AVX512 (v16f32). 
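The estimate instructions are only accurate to about 2^-12, so the single Newton-Raphson step mentioned above is what brings the result to roughly full float precision. A minimal scalar sketch of that step (the refinement itself is expanded by the generic combiner, not by this code, and the function names here are illustrative):
// x0 is the hardware estimate of 1/sqrt(a) (RSQRTSS); one NR step:
static float refineRsqrt(float a, float x0) {
  return x0 * (1.5f - 0.5f * a * x0 * x0);
}
// x0 is the hardware estimate of 1/a (RCPSS); one NR step:
static float refineRcp(float a, float x0) {
  return x0 * (2.0f - a * x0);
}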
@@ -15016,20 +16347,22 @@ SDValue X86TargetLowering::getRecipEstimate(SDValue Op, // 15 instructions: convert to single, rcpss, convert back to double, refine // (3 steps = 12 insts). If an 'rcpsd' variant was added to the ISA // along with FMA, this could be a throughput win. - if (VT == MVT::f32 && Subtarget.hasSSE1()) - RecipOp = "divf"; - else if ((VT == MVT::v4f32 && Subtarget.hasSSE1()) || - (VT == MVT::v8f32 && Subtarget.hasAVX())) - RecipOp = "vec-divf"; - else - return SDValue(); - TargetRecip Recips = DCI.DAG.getTarget().Options.Reciprocals; - if (!Recips.isEnabled(RecipOp)) - return SDValue(); + if ((VT == MVT::f32 && Subtarget.hasSSE1()) || + (VT == MVT::v4f32 && Subtarget.hasSSE1()) || + (VT == MVT::v8f32 && Subtarget.hasAVX())) { + // Enable estimate codegen with 1 refinement step for vector division. + // Scalar division estimates are disabled because they break too much + // real-world code. These defaults are intended to match GCC behavior. + if (VT == MVT::f32 && Enabled == ReciprocalEstimate::Unspecified) + return SDValue(); + + if (RefinementSteps == ReciprocalEstimate::Unspecified) + RefinementSteps = 1; - RefinementSteps = Recips.getRefinementSteps(RecipOp); - return DCI.DAG.getNode(X86ISD::FRCP, SDLoc(Op), VT, Op); + return DAG.getNode(X86ISD::FRCP, SDLoc(Op), VT, Op); + } + return SDValue(); } /// If we have at least two divisions that use the same divisor, convert to @@ -15042,9 +16375,46 @@ unsigned X86TargetLowering::combineRepeatedFPDivisors() const { return 2; } +/// Helper for creating a X86ISD::SETCC node. +static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl, + SelectionDAG &DAG) { + return DAG.getNode(X86ISD::SETCC, dl, MVT::i8, + DAG.getConstant(Cond, dl, MVT::i8), EFLAGS); +} + +/// Create a BT (Bit Test) node - Test bit \p BitNo in \p Src and set condition +/// according to equal/not-equal condition code \p CC. +static SDValue getBitTestCondition(SDValue Src, SDValue BitNo, ISD::CondCode CC, + const SDLoc &dl, SelectionDAG &DAG) { + // If Src is i8, promote it to i32 with any_extend. There is no i8 BT + // instruction. Since the shift amount is in-range-or-undefined, we know + // that doing a bittest on the i32 value is ok. We extend to i32 because + // the encoding for the i16 version is larger than the i32 version. + // Also promote i16 to i32 for performance / code size reason. + if (Src.getValueType() == MVT::i8 || Src.getValueType() == MVT::i16) + Src = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Src); + + // See if we can use the 32-bit instruction instead of the 64-bit one for a + // shorter encoding. Since the former takes the modulo 32 of BitNo and the + // latter takes the modulo 64, this is only valid if the 5th bit of BitNo is + // known to be zero. + if (Src.getValueType() == MVT::i64 && + DAG.MaskedValueIsZero(BitNo, APInt(BitNo.getValueSizeInBits(), 32))) + Src = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, Src); + + // If the operand types disagree, extend the shift amount to match. Since + // BT ignores high bits (like shifts) we can use anyextend. + if (Src.getValueType() != BitNo.getValueType()) + BitNo = DAG.getNode(ISD::ANY_EXTEND, dl, Src.getValueType(), BitNo); + + SDValue BT = DAG.getNode(X86ISD::BT, dl, MVT::i32, Src, BitNo); + X86::CondCode Cond = CC == ISD::SETEQ ? X86::COND_AE : X86::COND_B; + return getSETCC(Cond, BT, dl , DAG); +} + /// Result of 'and' is compared against zero. Change to a BT node if possible. 
-SDValue X86TargetLowering::LowerToBT(SDValue And, ISD::CondCode CC, - const SDLoc &dl, SelectionDAG &DAG) const { +static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, + const SDLoc &dl, SelectionDAG &DAG) { SDValue Op0 = And.getOperand(0); SDValue Op1 = And.getOperand(1); if (Op0.getOpcode() == ISD::TRUNCATE) @@ -15087,27 +16457,35 @@ SDValue X86TargetLowering::LowerToBT(SDValue And, ISD::CondCode CC, } } - if (LHS.getNode()) { - // If LHS is i8, promote it to i32 with any_extend. There is no i8 BT - // instruction. Since the shift amount is in-range-or-undefined, we know - // that doing a bittest on the i32 value is ok. We extend to i32 because - // the encoding for the i16 version is larger than the i32 version. - // Also promote i16 to i32 for performance / code size reason. - if (LHS.getValueType() == MVT::i8 || - LHS.getValueType() == MVT::i16) - LHS = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, LHS); + if (LHS.getNode()) + return getBitTestCondition(LHS, RHS, CC, dl, DAG); - // If the operand types disagree, extend the shift amount to match. Since - // BT ignores high bits (like shifts) we can use anyextend. - if (LHS.getValueType() != RHS.getValueType()) - RHS = DAG.getNode(ISD::ANY_EXTEND, dl, LHS.getValueType(), RHS); + return SDValue(); +} - SDValue BT = DAG.getNode(X86ISD::BT, dl, MVT::i32, LHS, RHS); - X86::CondCode Cond = CC == ISD::SETEQ ? X86::COND_AE : X86::COND_B; - return DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(Cond, dl, MVT::i8), BT); - } +// Convert (truncate (srl X, N) to i1) to (bt X, N) +static SDValue LowerTruncateToBT(SDValue Op, ISD::CondCode CC, + const SDLoc &dl, SelectionDAG &DAG) { + + assert(Op.getOpcode() == ISD::TRUNCATE && Op.getValueType() == MVT::i1 && + "Expected TRUNCATE to i1 node"); + if (Op.getOperand(0).getOpcode() != ISD::SRL) + return SDValue(); + + SDValue ShiftRight = Op.getOperand(0); + return getBitTestCondition(ShiftRight.getOperand(0), ShiftRight.getOperand(1), + CC, dl, DAG); +} + +/// Result of 'and' or 'trunc to i1' is compared against zero. +/// Change to a BT node if possible. 
+SDValue X86TargetLowering::LowerToBT(SDValue Op, ISD::CondCode CC, + const SDLoc &dl, SelectionDAG &DAG) const { + if (Op.getOpcode() == ISD::AND) + return LowerAndToBT(Op, CC, dl, DAG); + if (Op.getOpcode() == ISD::TRUNCATE && Op.getValueType() == MVT::i1) + return LowerTruncateToBT(Op, CC, dl, DAG); return SDValue(); } @@ -15132,19 +16510,19 @@ static int translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0, case ISD::SETOEQ: case ISD::SETEQ: SSECC = 0; break; case ISD::SETOGT: - case ISD::SETGT: Swap = true; // Fallthrough + case ISD::SETGT: Swap = true; LLVM_FALLTHROUGH; case ISD::SETLT: case ISD::SETOLT: SSECC = 1; break; case ISD::SETOGE: - case ISD::SETGE: Swap = true; // Fallthrough + case ISD::SETGE: Swap = true; LLVM_FALLTHROUGH; case ISD::SETLE: case ISD::SETOLE: SSECC = 2; break; case ISD::SETUO: SSECC = 3; break; case ISD::SETUNE: case ISD::SETNE: SSECC = 4; break; - case ISD::SETULE: Swap = true; // Fallthrough + case ISD::SETULE: Swap = true; LLVM_FALLTHROUGH; case ISD::SETUGE: SSECC = 5; break; - case ISD::SETULT: Swap = true; // Fallthrough + case ISD::SETULT: Swap = true; LLVM_FALLTHROUGH; case ISD::SETUGT: SSECC = 6; break; case ISD::SETO: SSECC = 7; break; case ISD::SETUEQ: @@ -15250,12 +16628,12 @@ static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) { case ISD::SETNE: SSECC = 4; break; case ISD::SETEQ: Opc = X86ISD::PCMPEQM; break; case ISD::SETUGT: SSECC = 6; Unsigned = true; break; - case ISD::SETLT: Swap = true; //fall-through + case ISD::SETLT: Swap = true; LLVM_FALLTHROUGH; case ISD::SETGT: Opc = X86ISD::PCMPGTM; break; case ISD::SETULT: SSECC = 1; Unsigned = true; break; case ISD::SETUGE: SSECC = 5; Unsigned = true; break; //NLT case ISD::SETGE: Swap = true; SSECC = 2; break; // LE + swap - case ISD::SETULE: Unsigned = true; //fall-through + case ISD::SETULE: Unsigned = true; LLVM_FALLTHROUGH; case ISD::SETLE: SSECC = 2; break; } @@ -15414,7 +16792,7 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, // In this case use SSE compare bool UseAVX512Inst = (OpVT.is512BitVector() || - OpVT.getVectorElementType().getSizeInBits() >= 32 || + OpVT.getScalarSizeInBits() >= 32 || (Subtarget.hasBWI() && Subtarget.hasVLX())); if (UseAVX512Inst) @@ -15638,15 +17016,12 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { // Lower (X & (1 << N)) == 0 to BT(X, N). // Lower ((X >>u N) & 1) != 0 to BT(X, N). // Lower ((X >>s N) & 1) != 0 to BT(X, N). - if (Op0.getOpcode() == ISD::AND && Op0.hasOneUse() && - isNullConstant(Op1) && + // Lower (trunc (X >> N) to i1) to BT(X, N). 
+ if (Op0.hasOneUse() && isNullConstant(Op1) && (CC == ISD::SETEQ || CC == ISD::SETNE)) { if (SDValue NewSetCC = LowerToBT(Op0, CC, dl, DAG)) { - if (VT == MVT::i1) { - NewSetCC = DAG.getNode(ISD::AssertZext, dl, MVT::i8, NewSetCC, - DAG.getValueType(MVT::i1)); + if (VT == MVT::i1) return DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, NewSetCC); - } return NewSetCC; } } @@ -15665,14 +17040,9 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { return Op0; CCode = X86::GetOppositeBranchCondition(CCode); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(CCode, dl, MVT::i8), - Op0.getOperand(1)); - if (VT == MVT::i1) { - SetCC = DAG.getNode(ISD::AssertZext, dl, MVT::i8, SetCC, - DAG.getValueType(MVT::i1)); + SDValue SetCC = getSETCC(CCode, Op0.getOperand(1), dl, DAG); + if (VT == MVT::i1) return DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, SetCC); - } return SetCC; } } @@ -15687,20 +17057,16 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { } } - bool isFP = Op1.getSimpleValueType().isFloatingPoint(); - unsigned X86CC = TranslateX86CC(CC, dl, isFP, Op0, Op1, DAG); + bool IsFP = Op1.getSimpleValueType().isFloatingPoint(); + X86::CondCode X86CC = TranslateX86CC(CC, dl, IsFP, Op0, Op1, DAG); if (X86CC == X86::COND_INVALID) return SDValue(); SDValue EFLAGS = EmitCmp(Op0, Op1, X86CC, dl, DAG); EFLAGS = ConvertCmpIfNecessary(EFLAGS, DAG); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86CC, dl, MVT::i8), EFLAGS); - if (VT == MVT::i1) { - SetCC = DAG.getNode(ISD::AssertZext, dl, MVT::i8, SetCC, - DAG.getValueType(MVT::i1)); + SDValue SetCC = getSETCC(X86CC, EFLAGS, dl, DAG); + if (VT == MVT::i1) return DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, SetCC); - } return SetCC; } @@ -15717,34 +17083,23 @@ SDValue X86TargetLowering::LowerSETCCE(SDValue Op, SelectionDAG &DAG) const { assert(Carry.getOpcode() != ISD::CARRY_FALSE); SDVTList VTs = DAG.getVTList(LHS.getValueType(), MVT::i32); SDValue Cmp = DAG.getNode(X86ISD::SBB, DL, VTs, LHS, RHS, Carry); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, DL, MVT::i8, - DAG.getConstant(CC, DL, MVT::i8), Cmp.getValue(1)); - if (Op.getSimpleValueType() == MVT::i1) { - SetCC = DAG.getNode(ISD::AssertZext, DL, MVT::i8, SetCC, - DAG.getValueType(MVT::i1)); + SDValue SetCC = getSETCC(CC, Cmp.getValue(1), DL, DAG); + if (Op.getSimpleValueType() == MVT::i1) return DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC); - } return SetCC; } /// Return true if opcode is a X86 logical comparison. static bool isX86LogicalCmp(SDValue Op) { - unsigned Opc = Op.getNode()->getOpcode(); + unsigned Opc = Op.getOpcode(); if (Opc == X86ISD::CMP || Opc == X86ISD::COMI || Opc == X86ISD::UCOMI || Opc == X86ISD::SAHF) return true; if (Op.getResNo() == 1 && - (Opc == X86ISD::ADD || - Opc == X86ISD::SUB || - Opc == X86ISD::ADC || - Opc == X86ISD::SBB || - Opc == X86ISD::SMUL || - Opc == X86ISD::UMUL || - Opc == X86ISD::INC || - Opc == X86ISD::DEC || - Opc == X86ISD::OR || - Opc == X86ISD::XOR || - Opc == X86ISD::AND)) + (Opc == X86ISD::ADD || Opc == X86ISD::SUB || Opc == X86ISD::ADC || + Opc == X86ISD::SBB || Opc == X86ISD::SMUL || Opc == X86ISD::UMUL || + Opc == X86ISD::INC || Opc == X86ISD::DEC || Opc == X86ISD::OR || + Opc == X86ISD::XOR || Opc == X86ISD::AND)) return true; if (Op.getResNo() == 2 && Opc == X86ISD::UMUL) @@ -15753,27 +17108,18 @@ static bool isX86LogicalCmp(SDValue Op) { return false; } -/// Returns the "condition" node, that may be wrapped with "truncate". 
-/// Like this: (i1 (trunc (i8 X86ISD::SETCC))). -static SDValue getCondAfterTruncWithZeroHighBitsInput(SDValue V, SelectionDAG &DAG) { +static bool isTruncWithZeroHighBitsInput(SDValue V, SelectionDAG &DAG) { if (V.getOpcode() != ISD::TRUNCATE) - return V; + return false; SDValue VOp0 = V.getOperand(0); - if (VOp0.getOpcode() == ISD::AssertZext && - V.getValueSizeInBits() == - cast<VTSDNode>(VOp0.getOperand(1))->getVT().getSizeInBits()) - return VOp0.getOperand(0); - unsigned InBits = VOp0.getValueSizeInBits(); unsigned Bits = V.getValueSizeInBits(); - if (DAG.MaskedValueIsZero(VOp0, APInt::getHighBitsSet(InBits,InBits-Bits))) - return V.getOperand(0); - return V; + return DAG.MaskedValueIsZero(VOp0, APInt::getHighBitsSet(InBits,InBits-Bits)); } SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { - bool addTest = true; + bool AddTest = true; SDValue Cond = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); SDValue Op2 = Op.getOperand(2); @@ -15794,9 +17140,10 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { if (SSECC != 8) { if (Subtarget.hasAVX512()) { - SDValue Cmp = DAG.getNode(X86ISD::FSETCC, DL, MVT::i1, CondOp0, CondOp1, - DAG.getConstant(SSECC, DL, MVT::i8)); - return DAG.getNode(X86ISD::SELECT, DL, VT, Cmp, Op1, Op2); + SDValue Cmp = DAG.getNode(X86ISD::FSETCCM, DL, MVT::i1, CondOp0, + CondOp1, DAG.getConstant(SSECC, DL, MVT::i8)); + return DAG.getNode(VT.isVector() ? X86ISD::SELECT : X86ISD::SELECTS, + DL, VT, Cmp, Op1, Op2); } SDValue Cmp = DAG.getNode(X86ISD::FSETCC, DL, VT, CondOp0, CondOp1, @@ -15840,6 +17187,11 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { } } + // AVX512 fallback is to lower selects of scalar floats to masked moves. + if (Cond.getValueType() == MVT::i1 && (VT == MVT::f64 || VT == MVT::f32) && + Subtarget.hasAVX512()) + return DAG.getNode(X86ISD::SELECTS, DL, VT, Cond, Op1, Op2); + if (VT.isVector() && VT.getVectorElementType() == MVT::i1) { SDValue Op1Scalar; if (ISD::isBuildVectorOfConstantSDNodes(Op1.getNode())) @@ -15875,8 +17227,14 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { } if (Cond.getOpcode() == ISD::SETCC) { - if (SDValue NewCond = LowerSETCC(Cond, DAG)) + if (SDValue NewCond = LowerSETCC(Cond, DAG)) { Cond = NewCond; + // If the condition was updated, it's possible that the operands of the + // select were also updated (for example, EmitTest has a RAUW). Refresh + // the local references to the select operands in case they got stale. + Op1 = Op.getOperand(1); + Op2 = Op.getOperand(2); + } } // (select (x == 0), -1, y) -> (sign_bit (x - 1)) | y @@ -15953,7 +17311,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { if ((isX86LogicalCmp(Cmp) && !IllegalFPCMov) || Opc == X86ISD::BT) { // FIXME Cond = Cmp; - addTest = false; + AddTest = false; } } else if (CondOpcode == ISD::USUBO || CondOpcode == ISD::SSUBO || CondOpcode == ISD::UADDO || CondOpcode == ISD::SADDO || @@ -15987,12 +17345,13 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { Cond = X86Op.getValue(1); CC = DAG.getConstant(X86Cond, DL, MVT::i8); - addTest = false; + AddTest = false; } - if (addTest) { + if (AddTest) { // Look past the truncate if the high bits are known zero. - Cond = getCondAfterTruncWithZeroHighBitsInput(Cond, DAG); + if (isTruncWithZeroHighBitsInput(Cond, DAG)) + Cond = Cond.getOperand(0); // We know the result of AND is compared against zero. Try to match // it to BT. 
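For reference, the shape of source code these BT paths target, plus the instruction sequence they try to produce (a rough sketch; exact registers and condition sense depend on the surrounding compare):
// Typical pattern: test a single bit with a variable index.
static bool isBitSet(unsigned long long X, unsigned N) {
  return (X >> N) & 1;            // equivalently: X & (1ULL << N)
}
// Intended lowering, roughly:
//   bt   rdi, rsi    ; CF = bit N of X
//   setb al          ; COND_B when the compare is "!= 0",
//                    ; COND_AE when it is "== 0"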
@@ -16000,12 +17359,12 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { if (SDValue NewSetCC = LowerToBT(Cond, ISD::SETNE, DL, DAG)) { CC = NewSetCC.getOperand(0); Cond = NewSetCC.getOperand(1); - addTest = false; + AddTest = false; } } } - if (addTest) { + if (AddTest) { CC = DAG.getConstant(X86::COND_NE, DL, MVT::i8); Cond = EmitTest(Cond, X86::COND_NE, DL, DAG); } @@ -16077,34 +17436,44 @@ static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op, VTElt.getSizeInBits() >= 32)))) return DAG.getNode(X86ISD::VSEXT, dl, VT, In); - unsigned int NumElts = VT.getVectorNumElements(); - - if (NumElts != 8 && NumElts != 16 && !Subtarget.hasBWI()) - return SDValue(); + unsigned NumElts = VT.getVectorNumElements(); - if (VT.is512BitVector() && InVT.getVectorElementType() != MVT::i1) { + if (VT.is512BitVector() && InVTElt != MVT::i1 && + (NumElts == 8 || NumElts == 16 || Subtarget.hasBWI())) { if (In.getOpcode() == X86ISD::VSEXT || In.getOpcode() == X86ISD::VZEXT) return DAG.getNode(In.getOpcode(), dl, VT, In.getOperand(0)); return DAG.getNode(X86ISD::VSEXT, dl, VT, In); } - assert (InVT.getVectorElementType() == MVT::i1 && "Unexpected vector type"); - MVT ExtVT = NumElts == 8 ? MVT::v8i64 : MVT::v16i32; - SDValue NegOne = - DAG.getConstant(APInt::getAllOnesValue(ExtVT.getScalarSizeInBits()), dl, - ExtVT); - SDValue Zero = - DAG.getConstant(APInt::getNullValue(ExtVT.getScalarSizeInBits()), dl, ExtVT); + if (InVTElt != MVT::i1) + return SDValue(); + + MVT ExtVT = VT; + if (!VT.is512BitVector() && !Subtarget.hasVLX()) + ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts); + + SDValue V; + if (Subtarget.hasDQI()) { + V = DAG.getNode(X86ISD::VSEXT, dl, ExtVT, In); + assert(!VT.is512BitVector() && "Unexpected vector type"); + } else { + SDValue NegOne = getOnesVector(ExtVT, Subtarget, DAG, dl); + SDValue Zero = getZeroVector(ExtVT, Subtarget, DAG, dl); + V = DAG.getNode(ISD::VSELECT, dl, ExtVT, In, NegOne, Zero); + if (ExtVT == VT) + return V; + } - SDValue V = DAG.getNode(ISD::VSELECT, dl, ExtVT, In, NegOne, Zero); - if (VT.is512BitVector()) - return V; return DAG.getNode(X86ISD::VTRUNC, dl, VT, V); } -static SDValue LowerSIGN_EXTEND_VECTOR_INREG(SDValue Op, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { +// Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG. +// For sign extend this needs to handle all vector sizes and SSE4.1 and +// non-SSE4.1 targets. For zero extend this should only handle inputs of +// MVT::v64i8 when BWI is not supported, but AVX512 is. +static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { SDValue In = Op->getOperand(0); MVT VT = Op->getSimpleValueType(0); MVT InVT = In.getSimpleValueType(); @@ -16119,20 +17488,33 @@ static SDValue LowerSIGN_EXTEND_VECTOR_INREG(SDValue Op, if (InSVT != MVT::i32 && InSVT != MVT::i16 && InSVT != MVT::i8) return SDValue(); if (!(VT.is128BitVector() && Subtarget.hasSSE2()) && - !(VT.is256BitVector() && Subtarget.hasInt256())) + !(VT.is256BitVector() && Subtarget.hasInt256()) && + !(VT.is512BitVector() && Subtarget.hasAVX512())) return SDValue(); SDLoc dl(Op); // For 256-bit vectors, we only need the lower (128-bit) half of the input. - if (VT.is256BitVector()) - In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, - MVT::getVectorVT(InSVT, InVT.getVectorNumElements() / 2), - In, DAG.getIntPtrConstant(0, dl)); + // For 512-bit vectors, we need 128-bits or 256-bits. 
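// [Editor's worked example, not part of this patch.] For a
// zero_extend_vector_inreg of a v64i8 input to v16i32: InSize = 8 * 16 = 128,
// so extractSubVector grabs only the low 128 bits (the 16 bytes that are
// actually consumed) before the VZEXT. A hypothetical v8i64 result would need
// only 64 bits of input, but the extraction is still padded up to the 128-bit
// minimum.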
+ if (VT.getSizeInBits() > 128) { + // Input needs to be at least the same number of elements as output, and + // at least 128-bits. + int InSize = InSVT.getSizeInBits() * VT.getVectorNumElements(); + In = extractSubVector(In, 0, DAG, dl, std::max(InSize, 128)); + } + + assert((Op.getOpcode() != ISD::ZERO_EXTEND_VECTOR_INREG || + InVT == MVT::v64i8) && "Zero extend only for v64i8 input!"); // SSE41 targets can use the pmovsx* instructions directly. + unsigned ExtOpc = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ? + X86ISD::VSEXT : X86ISD::VZEXT; if (Subtarget.hasSSE41()) - return DAG.getNode(X86ISD::VSEXT, dl, VT, In); + return DAG.getNode(ExtOpc, dl, VT, In); + + // We should only get here for sign extend. + assert(Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG && + "Unexpected opcode!"); // pre-SSE41 targets unpack lower lanes and then sign-extend using SRAI. SDValue Curr = In; @@ -16150,7 +17532,7 @@ static SDValue LowerSIGN_EXTEND_VECTOR_INREG(SDValue Op, SDValue SignExt = Curr; if (CurrVT != InVT) { unsigned SignExtShift = - CurrVT.getVectorElementType().getSizeInBits() - InSVT.getSizeInBits(); + CurrVT.getScalarSizeInBits() - InSVT.getSizeInBits(); SignExt = DAG.getNode(X86ISD::VSRAI, dl, CurrVT, Curr, DAG.getConstant(SignExtShift, dl, MVT::i8)); } @@ -16211,7 +17593,7 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget, SDValue OpHi = DAG.getVectorShuffle(InVT, dl, In, Undef, ShufMask2); MVT HalfVT = MVT::getVectorVT(VT.getVectorElementType(), - VT.getVectorNumElements()/2); + VT.getVectorNumElements() / 2); OpLo = DAG.getNode(X86ISD::VSEXT, dl, HalfVT, OpLo); OpHi = DAG.getNode(X86ISD::VSEXT, dl, HalfVT, OpHi); @@ -16643,7 +18025,7 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const { case X86::COND_B: // These can only come from an arithmetic instruction with overflow, // e.g. SADDO, UADDO. - Cond = Cond.getNode()->getOperand(1); + Cond = Cond.getOperand(1); addTest = false; break; } @@ -16828,11 +18210,11 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const { if (addTest) { // Look pass the truncate if the high bits are known zero. - Cond = getCondAfterTruncWithZeroHighBitsInput(Cond, DAG); + if (isTruncWithZeroHighBitsInput(Cond, DAG)) + Cond = Cond.getOperand(0); - // We know the result of AND is compared against zero. Try to match - // it to BT. - if (Cond.getOpcode() == ISD::AND && Cond.hasOneUse()) { + // We know the result is compared against zero. Try to match it to BT. + if (Cond.hasOneUse()) { if (SDValue NewSetCC = LowerToBT(Cond, ISD::SETNE, dl, DAG)) { CC = NewSetCC.getOperand(0); Cond = NewSetCC.getOperand(1); @@ -17000,7 +18382,7 @@ SDValue X86TargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const { SDValue X86TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { assert(Subtarget.is64Bit() && "LowerVAARG only handles 64-bit va_arg!"); - assert(Op.getNode()->getNumOperands() == 4); + assert(Op.getNumOperands() == 4); MachineFunction &MF = DAG.getMachineFunction(); if (Subtarget.isCallingConvWin64(MF.getFunction()->getCallingConv())) @@ -17161,6 +18543,7 @@ static SDValue getTargetVShiftByConstNode(unsigned Opc, const SDLoc &dl, MVT VT, /// constant. Takes immediate version of shift as input. 
static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT, SDValue SrcOp, SDValue ShAmt, + const X86Subtarget &Subtarget, SelectionDAG &DAG) { MVT SVT = ShAmt.getSimpleValueType(); assert((SVT == MVT::i32 || SVT == MVT::i64) && "Unexpected value type!"); @@ -17178,27 +18561,32 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT, case X86ISD::VSRAI: Opc = X86ISD::VSRA; break; } - const X86Subtarget &Subtarget = - static_cast<const X86Subtarget &>(DAG.getSubtarget()); - if (Subtarget.hasSSE41() && ShAmt.getOpcode() == ISD::ZERO_EXTEND && - ShAmt.getOperand(0).getSimpleValueType() == MVT::i16) { - // Let the shuffle legalizer expand this shift amount node. - SDValue Op0 = ShAmt.getOperand(0); - Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(Op0), MVT::v8i16, Op0); - ShAmt = getShuffleVectorZeroOrUndef(Op0, 0, true, Subtarget, DAG); + // Need to build a vector containing shift amount. + // SSE/AVX packed shifts only use the lower 64-bit of the shift count. + // +=================+============+=======================================+ + // | ShAmt is | HasSSE4.1? | Construct ShAmt vector as | + // +=================+============+=======================================+ + // | i64 | Yes, No | Use ShAmt as lowest elt | + // | i32 | Yes | zero-extend in-reg | + // | (i32 zext(i16)) | Yes | zero-extend in-reg | + // | i16/i32 | No | v4i32 build_vector(ShAmt, 0, ud, ud)) | + // +=================+============+=======================================+ + + if (SVT == MVT::i64) + ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v2i64, ShAmt); + else if (Subtarget.hasSSE41() && ShAmt.getOpcode() == ISD::ZERO_EXTEND && + ShAmt.getOperand(0).getSimpleValueType() == MVT::i16) { + ShAmt = ShAmt.getOperand(0); + ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v8i16, ShAmt); + ShAmt = DAG.getNode(X86ISD::VZEXT, SDLoc(ShAmt), MVT::v2i64, ShAmt); + } else if (Subtarget.hasSSE41() && + ShAmt.getOpcode() == ISD::EXTRACT_VECTOR_ELT) { + ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v4i32, ShAmt); + ShAmt = DAG.getNode(X86ISD::VZEXT, SDLoc(ShAmt), MVT::v2i64, ShAmt); } else { - // Need to build a vector containing shift amount. - // SSE/AVX packed shifts only use the lower 64-bit of the shift count. - SmallVector<SDValue, 4> ShOps; - ShOps.push_back(ShAmt); - if (SVT == MVT::i32) { - ShOps.push_back(DAG.getConstant(0, dl, SVT)); - ShOps.push_back(DAG.getUNDEF(SVT)); - } - ShOps.push_back(DAG.getUNDEF(SVT)); - - MVT BVT = SVT == MVT::i32 ? MVT::v4i32 : MVT::v2i64; - ShAmt = DAG.getBuildVector(BVT, dl, ShOps); + SmallVector<SDValue, 4> ShOps = {ShAmt, DAG.getConstant(0, dl, SVT), + DAG.getUNDEF(SVT), DAG.getUNDEF(SVT)}; + ShAmt = DAG.getBuildVector(MVT::v4i32, dl, ShOps); } // The return type has to be a 128-bit type with the same element @@ -17290,7 +18678,7 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, case X86ISD::VTRUNC: case X86ISD::VTRUNCS: case X86ISD::VTRUNCUS: - case ISD::FP_TO_FP16: + case X86ISD::CVTPS2PH: // We can't use ISD::VSELECT here because it is not always "Legal" // for the destination type. 
For example vpmovqb require only AVX512 // and vselect that can operate on byte element type require BWI @@ -17321,7 +18709,8 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask, // The mask should be of type MVT::i1 SDValue IMask = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Mask); - if (Op.getOpcode() == X86ISD::FSETCC) + if (Op.getOpcode() == X86ISD::FSETCCM || + Op.getOpcode() == X86ISD::FSETCCM_RND) return DAG.getNode(ISD::AND, dl, VT, Op, IMask); if (Op.getOpcode() == X86ISD::VFPCLASS || Op.getOpcode() == X86ISD::VFPCLASSS) @@ -17329,7 +18718,7 @@ static SDValue getScalarMaskingNode(SDValue Op, SDValue Mask, if (PreservedSrc.isUndef()) PreservedSrc = getZeroVector(VT, Subtarget, DAG, dl); - return DAG.getNode(X86ISD::SELECT, dl, VT, IMask, Op, PreservedSrc); + return DAG.getNode(X86ISD::SELECTS, dl, VT, IMask, Op, PreservedSrc); } static int getSEHRegistrationNodeSize(const Function *Fn) { @@ -17395,6 +18784,15 @@ static SDValue recoverFramePointer(SelectionDAG &DAG, const Function *Fn, static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + // Helper to detect if the operand is CUR_DIRECTION rounding mode. + auto isRoundModeCurDirection = [](SDValue Rnd) { + if (!isa<ConstantSDNode>(Rnd)) + return false; + + unsigned Round = cast<ConstantSDNode>(Rnd)->getZExtValue(); + return Round == X86::STATIC_ROUNDING::CUR_DIRECTION; + }; + SDLoc dl(Op); unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue(); MVT VT = Op.getSimpleValueType(); @@ -17406,9 +18804,6 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget case INTR_TYPE_2OP: return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); - case INTR_TYPE_2OP_IMM8: - return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), - DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Op.getOperand(2))); case INTR_TYPE_3OP: return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); @@ -17420,7 +18815,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue PassThru = Op.getOperand(2); SDValue Mask = Op.getOperand(3); SDValue RoundingMode; - // We allways add rounding mode to the Node. + // We always add rounding mode to the Node. // If the rounding mode is not specified, we add the // "current direction" mode. 
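// [Editor's illustration, not from this patch.] CUR_DIRECTION does not encode
// a fixed rounding mode in the instruction; it defers to whatever MXCSR.RC
// holds when the instruction executes. A quick way to inspect that field
// (helper name is made up):
#include <xmmintrin.h>
static inline unsigned currentRoundingControl() {
  // MXCSR bits 13-14: 0 = nearest, 1 = toward -inf, 2 = toward +inf,
  // 3 = toward zero.
  return (_mm_getcsr() >> 13) & 0x3;
}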
if (Op.getNumOperands() == 4) @@ -17428,13 +18823,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget DAG.getConstant(X86::STATIC_ROUNDING::CUR_DIRECTION, dl, MVT::i32); else RoundingMode = Op.getOperand(4); - unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; - if (IntrWithRoundingModeOpcode != 0) - if (cast<ConstantSDNode>(RoundingMode)->getZExtValue() != - X86::STATIC_ROUNDING::CUR_DIRECTION) - return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, - dl, Op.getValueType(), Src, RoundingMode), - Mask, PassThru, Subtarget, DAG); + assert(IntrData->Opc1 == 0 && "Unexpected second opcode!"); return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src, RoundingMode), Mask, PassThru, Subtarget, DAG); @@ -17449,8 +18838,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; if (IntrWithRoundingModeOpcode != 0) { SDValue Rnd = Op.getOperand(4); - unsigned Round = cast<ConstantSDNode>(Rnd)->getZExtValue(); - if (Round != X86::STATIC_ROUNDING::CUR_DIRECTION) { + if (!isRoundModeCurDirection(Rnd)) { return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(), Src, Rnd), @@ -17478,8 +18866,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget // (2) With rounding mode and sae - 7 operands. if (Op.getNumOperands() == 6) { SDValue Sae = Op.getOperand(5); - unsigned Opc = IntrData->Opc1 ? IntrData->Opc1 : IntrData->Opc0; - return getScalarMaskingNode(DAG.getNode(Opc, dl, VT, Src1, Src2, + return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src1, Src2, Sae), Mask, Src0, Subtarget, DAG); } @@ -17506,8 +18893,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; if (IntrWithRoundingModeOpcode != 0) { SDValue Rnd = Op.getOperand(5); - unsigned Round = cast<ConstantSDNode>(Rnd)->getZExtValue(); - if (Round != X86::STATIC_ROUNDING::CUR_DIRECTION) { + if (!isRoundModeCurDirection(Rnd)) { return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(), Src1, Src2, Rnd), @@ -17564,12 +18950,11 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget else Rnd = DAG.getConstant(X86::STATIC_ROUNDING::CUR_DIRECTION, dl, MVT::i32); return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, - Src1, Src2, Imm, Rnd), - Mask, PassThru, Subtarget, DAG); + Src1, Src2, Imm, Rnd), + Mask, PassThru, Subtarget, DAG); } case INTR_TYPE_3OP_IMM8_MASK: - case INTR_TYPE_3OP_MASK: - case INSERT_SUBVEC: { + case INTR_TYPE_3OP_MASK: { SDValue Src1 = Op.getOperand(1); SDValue Src2 = Op.getOperand(2); SDValue Src3 = Op.getOperand(3); @@ -17578,13 +18963,6 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget if (IntrData->Type == INTR_TYPE_3OP_IMM8_MASK) Src3 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Src3); - else if (IntrData->Type == INSERT_SUBVEC) { - // imm should be adapted to ISD::INSERT_SUBVECTOR behavior - assert(isa<ConstantSDNode>(Src3) && "Expected a ConstantSDNode here!"); - unsigned Imm = cast<ConstantSDNode>(Src3)->getZExtValue(); - Imm *= Src2.getSimpleValueType().getVectorNumElements(); - Src3 = DAG.getTargetConstant(Imm, dl, MVT::i32); - } // We specify 2 possible opcodes for intrinsics with rounding modes. 
// First, we check if the intrinsic may have non-default rounding mode, @@ -17592,8 +18970,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; if (IntrWithRoundingModeOpcode != 0) { SDValue Rnd = Op.getOperand(6); - unsigned Round = cast<ConstantSDNode>(Rnd)->getZExtValue(); - if (Round != X86::STATIC_ROUNDING::CUR_DIRECTION) { + if (!isRoundModeCurDirection(Rnd)) { return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(), Src1, Src2, Src3, Rnd), @@ -17616,19 +18993,21 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget } case VPERM_3OP_MASKZ: case VPERM_3OP_MASK:{ + MVT VT = Op.getSimpleValueType(); // Src2 is the PassThru SDValue Src1 = Op.getOperand(1); - SDValue Src2 = Op.getOperand(2); + // PassThru needs to be the same type as the destination in order + // to pattern match correctly. + SDValue Src2 = DAG.getBitcast(VT, Op.getOperand(2)); SDValue Src3 = Op.getOperand(3); SDValue Mask = Op.getOperand(4); - MVT VT = Op.getSimpleValueType(); SDValue PassThru = SDValue(); // set PassThru element if (IntrData->Type == VPERM_3OP_MASKZ) PassThru = getZeroVector(VT, Subtarget, DAG, dl); else - PassThru = DAG.getBitcast(VT, Src2); + PassThru = Src2; // Swap Src1 and Src2 in the node creation return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, @@ -17660,8 +19039,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; if (IntrWithRoundingModeOpcode != 0) { SDValue Rnd = Op.getOperand(5); - if (cast<ConstantSDNode>(Rnd)->getZExtValue() != - X86::STATIC_ROUNDING::CUR_DIRECTION) + if (!isRoundModeCurDirection(Rnd)) return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, dl, Op.getValueType(), Src1, Src2, Src3, Rnd), @@ -17713,6 +19091,35 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget Src1, Src2, Src3, Src4), Mask, PassThru, Subtarget, DAG); } + case CVTPD2PS: + // ISD::FP_ROUND has a second argument that indicates if the truncation + // does not change the value. Set it to 0 since it can change. + return DAG.getNode(IntrData->Opc0, dl, VT, Op.getOperand(1), + DAG.getIntPtrConstant(0, dl)); + case CVTPD2PS_MASK: { + SDValue Src = Op.getOperand(1); + SDValue PassThru = Op.getOperand(2); + SDValue Mask = Op.getOperand(3); + // We add rounding mode to the Node when + // - RM Opcode is specified and + // - RM is not "current direction". + unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; + if (IntrWithRoundingModeOpcode != 0) { + SDValue Rnd = Op.getOperand(4); + if (!isRoundModeCurDirection(Rnd)) { + return getVectorMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, + dl, Op.getValueType(), + Src, Rnd), + Mask, PassThru, Subtarget, DAG); + } + } + assert(IntrData->Opc0 == ISD::FP_ROUND && "Unexpected opcode!"); + // ISD::FP_ROUND has a second argument that indicates if the truncation + // does not change the value. Set it to 0 since it can change. 
+ return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src, + DAG.getIntPtrConstant(0, dl)), + Mask, PassThru, Subtarget, DAG); + } case FPCLASS: { // FPclass intrinsics with mask SDValue Src1 = Op.getOperand(1); @@ -17738,7 +19145,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MVT::i1, Src1, Imm); SDValue FPclassMask = getScalarMaskingNode(FPclass, Mask, DAG.getTargetConstant(0, dl, MVT::i1), Subtarget, DAG); - return DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::i8, FPclassMask); + return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, FPclassMask); } case CMP_MASK: case CMP_MASK_CC: { @@ -17765,8 +19172,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget // (IntrData->Opc1 != 0), then we check the rounding mode operand. if (IntrData->Opc1 != 0) { SDValue Rnd = Op.getOperand(5); - if (cast<ConstantSDNode>(Rnd)->getZExtValue() != - X86::STATIC_ROUNDING::CUR_DIRECTION) + if (!isRoundModeCurDirection(Rnd)) Cmp = DAG.getNode(IntrData->Opc1, dl, MaskVT, Op.getOperand(1), Op.getOperand(2), CC, Rnd); } @@ -17798,8 +19204,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue Cmp; if (IntrData->Opc1 != 0) { SDValue Rnd = Op.getOperand(5); - if (cast<ConstantSDNode>(Rnd)->getZExtValue() != - X86::STATIC_ROUNDING::CUR_DIRECTION) + if (!isRoundModeCurDirection(Rnd)) Cmp = DAG.getNode(IntrData->Opc1, dl, MVT::i1, Src1, Src2, CC, Rnd); } //default rounding mode @@ -17822,39 +19227,29 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue SetCC; switch (CC) { case ISD::SETEQ: { // (ZF = 0 and PF = 0) - SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_E, dl, MVT::i8), Comi); - SDValue SetNP = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_NP, dl, MVT::i8), - Comi); + SetCC = getSETCC(X86::COND_E, Comi, dl, DAG); + SDValue SetNP = getSETCC(X86::COND_NP, Comi, dl, DAG); SetCC = DAG.getNode(ISD::AND, dl, MVT::i8, SetCC, SetNP); break; } case ISD::SETNE: { // (ZF = 1 or PF = 1) - SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_NE, dl, MVT::i8), Comi); - SDValue SetP = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_P, dl, MVT::i8), - Comi); + SetCC = getSETCC(X86::COND_NE, Comi, dl, DAG); + SDValue SetP = getSETCC(X86::COND_P, Comi, dl, DAG); SetCC = DAG.getNode(ISD::OR, dl, MVT::i8, SetCC, SetP); break; } case ISD::SETGT: // (CF = 0 and ZF = 0) - SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_A, dl, MVT::i8), Comi); + SetCC = getSETCC(X86::COND_A, Comi, dl, DAG); break; case ISD::SETLT: { // The condition is opposite to GT. Swap the operands. - SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_A, dl, MVT::i8), InvComi); + SetCC = getSETCC(X86::COND_A, InvComi, dl, DAG); break; } case ISD::SETGE: // CF = 0 - SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_AE, dl, MVT::i8), Comi); + SetCC = getSETCC(X86::COND_AE, Comi, dl, DAG); break; case ISD::SETLE: // The condition is opposite to GE. Swap the operands. 
- SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_AE, dl, MVT::i8), InvComi); + SetCC = getSETCC(X86::COND_AE, InvComi, dl, DAG); break; default: llvm_unreachable("Unexpected illegal condition!"); @@ -17868,19 +19263,19 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue Sae = Op.getOperand(4); SDValue FCmp; - if (cast<ConstantSDNode>(Sae)->getZExtValue() == - X86::STATIC_ROUNDING::CUR_DIRECTION) - FCmp = DAG.getNode(X86ISD::FSETCC, dl, MVT::i1, LHS, RHS, + if (isRoundModeCurDirection(Sae)) + FCmp = DAG.getNode(X86ISD::FSETCCM, dl, MVT::i1, LHS, RHS, DAG.getConstant(CondVal, dl, MVT::i8)); else - FCmp = DAG.getNode(X86ISD::FSETCC, dl, MVT::i1, LHS, RHS, + FCmp = DAG.getNode(X86ISD::FSETCCM_RND, dl, MVT::i1, LHS, RHS, DAG.getConstant(CondVal, dl, MVT::i8), Sae); // AnyExt just uses KMOVW %kreg, %r32; ZeroExt emits "and $1, %reg" return DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, FCmp); } case VSHIFT: return getTargetVShiftNode(IntrData->Opc0, dl, Op.getSimpleValueType(), - Op.getOperand(1), Op.getOperand(2), DAG); + Op.getOperand(1), Op.getOperand(2), Subtarget, + DAG); case COMPRESS_EXPAND_IN_REG: { SDValue Mask = Op.getOperand(3); SDValue DataToCompress = Op.getOperand(1); @@ -18027,14 +19422,15 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget case Intrinsic::x86_avx_vtestc_pd_256: case Intrinsic::x86_avx_vtestnzc_pd_256: { bool IsTestPacked = false; - unsigned X86CC; + X86::CondCode X86CC; switch (IntNo) { default: llvm_unreachable("Bad fallthrough in Intrinsic lowering."); case Intrinsic::x86_avx_vtestz_ps: case Intrinsic::x86_avx_vtestz_pd: case Intrinsic::x86_avx_vtestz_ps_256: case Intrinsic::x86_avx_vtestz_pd_256: - IsTestPacked = true; // Fallthrough + IsTestPacked = true; + LLVM_FALLTHROUGH; case Intrinsic::x86_sse41_ptestz: case Intrinsic::x86_avx_ptestz_256: // ZF = 1 @@ -18044,7 +19440,8 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget case Intrinsic::x86_avx_vtestc_pd: case Intrinsic::x86_avx_vtestc_ps_256: case Intrinsic::x86_avx_vtestc_pd_256: - IsTestPacked = true; // Fallthrough + IsTestPacked = true; + LLVM_FALLTHROUGH; case Intrinsic::x86_sse41_ptestc: case Intrinsic::x86_avx_ptestc_256: // CF = 1 @@ -18054,7 +19451,8 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget case Intrinsic::x86_avx_vtestnzc_pd: case Intrinsic::x86_avx_vtestnzc_ps_256: case Intrinsic::x86_avx_vtestnzc_pd_256: - IsTestPacked = true; // Fallthrough + IsTestPacked = true; + LLVM_FALLTHROUGH; case Intrinsic::x86_sse41_ptestnzc: case Intrinsic::x86_avx_ptestnzc_256: // ZF and CF = 0 @@ -18066,18 +19464,17 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue RHS = Op.getOperand(2); unsigned TestOpc = IsTestPacked ? X86ISD::TESTP : X86ISD::PTEST; SDValue Test = DAG.getNode(TestOpc, dl, MVT::i32, LHS, RHS); - SDValue CC = DAG.getConstant(X86CC, dl, MVT::i8); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, CC, Test); + SDValue SetCC = getSETCC(X86CC, Test, dl, DAG); return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC); } case Intrinsic::x86_avx512_kortestz_w: case Intrinsic::x86_avx512_kortestc_w: { - unsigned X86CC = (IntNo == Intrinsic::x86_avx512_kortestz_w)? X86::COND_E: X86::COND_B; + X86::CondCode X86CC = + (IntNo == Intrinsic::x86_avx512_kortestz_w) ? 
X86::COND_E : X86::COND_B; SDValue LHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(1)); SDValue RHS = DAG.getBitcast(MVT::v16i1, Op.getOperand(2)); - SDValue CC = DAG.getConstant(X86CC, dl, MVT::i8); SDValue Test = DAG.getNode(X86ISD::KORTEST, dl, MVT::i32, LHS, RHS); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, CC, Test); + SDValue SetCC = getSETCC(X86CC, Test, dl, DAG); return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC); } @@ -18092,7 +19489,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget case Intrinsic::x86_sse42_pcmpistriz128: case Intrinsic::x86_sse42_pcmpestriz128: { unsigned Opcode; - unsigned X86CC; + X86::CondCode X86CC; switch (IntNo) { default: llvm_unreachable("Impossible intrinsic"); // Can't reach here. case Intrinsic::x86_sse42_pcmpistria128: @@ -18139,9 +19536,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SmallVector<SDValue, 5> NewOps(Op->op_begin()+1, Op->op_end()); SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32); SDValue PCMP = DAG.getNode(Opcode, dl, VTs, NewOps); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86CC, dl, MVT::i8), - SDValue(PCMP.getNode(), 1)); + SDValue SetCC = getSETCC(X86CC, SDValue(PCMP.getNode(), 1), dl, DAG); return DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, SetCC); } @@ -18267,6 +19662,51 @@ static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG, return SDValue(Res, 0); } +/// Handles the lowering of builtin intrinsics that return the value +/// of the extended control register. +static void getExtendedControlRegister(SDNode *N, const SDLoc &DL, + SelectionDAG &DAG, + const X86Subtarget &Subtarget, + SmallVectorImpl<SDValue> &Results) { + assert(N->getNumOperands() == 3 && "Unexpected number of operands!"); + SDVTList Tys = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue LO, HI; + + // The ECX register is used to select the index of the XCR register to + // return. + SDValue Chain = + DAG.getCopyToReg(N->getOperand(0), DL, X86::ECX, N->getOperand(2)); + SDNode *N1 = DAG.getMachineNode(X86::XGETBV, DL, Tys, Chain); + Chain = SDValue(N1, 0); + + // Reads the content of XCR and returns it in registers EDX:EAX. + if (Subtarget.is64Bit()) { + LO = DAG.getCopyFromReg(Chain, DL, X86::RAX, MVT::i64, SDValue(N1, 1)); + HI = DAG.getCopyFromReg(LO.getValue(1), DL, X86::RDX, MVT::i64, + LO.getValue(2)); + } else { + LO = DAG.getCopyFromReg(Chain, DL, X86::EAX, MVT::i32, SDValue(N1, 1)); + HI = DAG.getCopyFromReg(LO.getValue(1), DL, X86::EDX, MVT::i32, + LO.getValue(2)); + } + Chain = HI.getValue(1); + + if (Subtarget.is64Bit()) { + // Merge the two 32-bit values into a 64-bit one. + SDValue Tmp = DAG.getNode(ISD::SHL, DL, MVT::i64, HI, + DAG.getConstant(32, DL, MVT::i8)); + Results.push_back(DAG.getNode(ISD::OR, DL, MVT::i64, LO, Tmp)); + Results.push_back(Chain); + return; + } + + // Use a buildpair to merge the two 32-bit values into a 64-bit one. + SDValue Ops[] = { LO, HI }; + SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, Ops); + Results.push_back(Pair); + Results.push_back(Chain); +} + /// Handles the lowering of builtin intrinsics that read performance monitor /// counters (x86_rdpmc). static void getReadPerformanceCounter(SDNode *N, const SDLoc &DL, @@ -18413,6 +19853,33 @@ static SDValue MarkEHGuard(SDValue Op, SelectionDAG &DAG) { return Chain; } +/// Emit Truncating Store with signed or unsigned saturation.
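// [Editor's sketch - not part of the patch] Scalar model of the saturation
// that the truncating-store helpers defined next encode via the new
// TruncSStore/TruncUSStore node types (i32 -> i8 shown as the example;
// VPMOVUS* treats the source as unsigned, VPMOVS* as signed):
#include <algorithm>
#include <cstdint>

namespace sketch {
inline int8_t truncSatSigned(int32_t V) {    // X86ISD::VTRUNCS-style narrowing
  return static_cast<int8_t>(std::min(127, std::max(-128, V)));
}
inline uint8_t truncSatUnsigned(uint32_t V) { // X86ISD::VTRUNCUS-style
  return static_cast<uint8_t>(std::min<uint32_t>(V, 255));
}
} // namespace sketch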
+static SDValue +EmitTruncSStore(bool SignedSat, SDValue Chain, const SDLoc &Dl, SDValue Val, + SDValue Ptr, EVT MemVT, MachineMemOperand *MMO, + SelectionDAG &DAG) { + + SDVTList VTs = DAG.getVTList(MVT::Other); + SDValue Undef = DAG.getUNDEF(Ptr.getValueType()); + SDValue Ops[] = { Chain, Val, Ptr, Undef }; + return SignedSat ? + DAG.getTargetMemSDNode<TruncSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO) : + DAG.getTargetMemSDNode<TruncUSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO); +} + +/// Emit Masked Truncating Store with signed or unsigned saturation. +static SDValue +EmitMaskedTruncSStore(bool SignedSat, SDValue Chain, const SDLoc &Dl, + SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, + MachineMemOperand *MMO, SelectionDAG &DAG) { + + SDVTList VTs = DAG.getVTList(MVT::Other); + SDValue Ops[] = { Chain, Ptr, Mask, Val }; + return SignedSat ? + DAG.getTargetMemSDNode<MaskedTruncSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO) : + DAG.getTargetMemSDNode<MaskedTruncUSStoreSDNode>(VTs, Ops, Dl, MemVT, MMO); +} + static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue(); @@ -18429,8 +19896,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, IntNo == llvm::Intrinsic::x86_flags_write_u64) { // We need a frame pointer because this will get lowered to a PUSH/POP // sequence. - MachineFrameInfo *MFI = DAG.getMachineFunction().getFrameInfo(); - MFI->setHasCopyImplyingStackAdjustment(true); + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); + MFI.setHasCopyImplyingStackAdjustment(true); // Don't do anything here, we will expand these intrinsics out later // during ExpandISelPseudos in EmitInstrWithCustomInserter. return SDValue(); @@ -18509,13 +19976,18 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, getReadPerformanceCounter(Op.getNode(), dl, DAG, Subtarget, Results); return DAG.getMergeValues(Results, dl); } + // Get Extended Control Register. + case XGETBV: { + SmallVector<SDValue, 2> Results; + getExtendedControlRegister(Op.getNode(), dl, DAG, Subtarget, Results); + return DAG.getMergeValues(Results, dl); + } // XTEST intrinsics. 
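// [Editor's sketch - not part of the patch] getExtendedControlRegister()
// above reads XGETBV's result out of EDX:EAX; the merge it performs is
// exactly this, with the 32-bit path using ISD::BUILD_PAIR instead of
// the shift-and-or:
#include <cstdint>

namespace sketch {
inline uint64_t mergeXcrHalves(uint32_t Lo, uint32_t Hi) {
  // 64-bit path: SHL + OR; 32-bit path: the equivalent ISD::BUILD_PAIR.
  return (static_cast<uint64_t>(Hi) << 32) | Lo;
}
} // namespace sketch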
case XTEST: { SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::Other); SDValue InTrans = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(0)); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_NE, dl, MVT::i8), - InTrans); + + SDValue SetCC = getSETCC(X86::COND_NE, InTrans, dl, DAG); SDValue Ret = DAG.getNode(ISD::ZERO_EXTEND, dl, Op->getValueType(0), SetCC); return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Ret, SDValue(InTrans.getNode(), 1)); @@ -18530,9 +20002,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, Op.getOperand(4), GenCF.getValue(1)); SDValue Store = DAG.getStore(Op.getOperand(0), dl, Res.getValue(0), Op.getOperand(5), MachinePointerInfo()); - SDValue SetCC = DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_B, dl, MVT::i8), - Res.getValue(1)); + SDValue SetCC = getSETCC(X86::COND_B, Res.getValue(1), dl, DAG); SDValue Results[] = { SetCC, Store }; return DAG.getMergeValues(Results, dl); } @@ -18550,11 +20020,12 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, return DAG.getStore(Chain, dl, DataToCompress, Addr, MemIntr->getMemOperand()); - SDValue Compressed = - getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, DataToCompress), - Mask, DAG.getUNDEF(VT), Subtarget, DAG); - return DAG.getStore(Chain, dl, Compressed, Addr, - MemIntr->getMemOperand()); + MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); + SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); + + return DAG.getMaskedStore(Chain, dl, DataToCompress, Addr, VMask, VT, + MemIntr->getMemOperand(), + false /* truncating */, true /* compressing */); } case TRUNCATE_TO_MEM_VI8: case TRUNCATE_TO_MEM_VI16: @@ -18567,18 +20038,39 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op); assert(MemIntr && "Expected MemIntrinsicSDNode!"); - EVT VT = MemIntr->getMemoryVT(); + EVT MemVT = MemIntr->getMemoryVT(); - if (isAllOnesConstant(Mask)) // return just a truncate store - return DAG.getTruncStore(Chain, dl, DataToTruncate, Addr, VT, - MemIntr->getMemOperand()); + uint16_t TruncationOp = IntrData->Opc0; + switch (TruncationOp) { + case X86ISD::VTRUNC: { + if (isAllOnesConstant(Mask)) // return just a truncate store + return DAG.getTruncStore(Chain, dl, DataToTruncate, Addr, MemVT, + MemIntr->getMemOperand()); - MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); - SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); + MVT MaskVT = MVT::getVectorVT(MVT::i1, MemVT.getVectorNumElements()); + SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); - return DAG.getMaskedStore(Chain, dl, DataToTruncate, Addr, VMask, VT, - MemIntr->getMemOperand(), true); + return DAG.getMaskedStore(Chain, dl, DataToTruncate, Addr, VMask, MemVT, + MemIntr->getMemOperand(), true /* truncating */); + } + case X86ISD::VTRUNCUS: + case X86ISD::VTRUNCS: { + bool IsSigned = (TruncationOp == X86ISD::VTRUNCS); + if (isAllOnesConstant(Mask)) + return EmitTruncSStore(IsSigned, Chain, dl, DataToTruncate, Addr, MemVT, + MemIntr->getMemOperand(), DAG); + + MVT MaskVT = MVT::getVectorVT(MVT::i1, MemVT.getVectorNumElements()); + SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); + + return EmitMaskedTruncSStore(IsSigned, Chain, dl, DataToTruncate, Addr, + VMask, MemVT, MemIntr->getMemOperand(), DAG); + } + default: + llvm_unreachable("Unsupported truncstore 
intrinsic"); + } } + case EXPAND_FROM_MEM: { SDValue Mask = Op.getOperand(4); SDValue PassThru = Op.getOperand(3); @@ -18589,24 +20081,24 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op); assert(MemIntr && "Expected MemIntrinsicSDNode!"); - SDValue DataToExpand = DAG.getLoad(VT, dl, Chain, Addr, - MemIntr->getMemOperand()); + if (isAllOnesConstant(Mask)) // Return a regular (unmasked) vector load. + return DAG.getLoad(VT, dl, Chain, Addr, MemIntr->getMemOperand()); + if (X86::isZeroNode(Mask)) + return DAG.getUNDEF(VT); - if (isAllOnesConstant(Mask)) // return just a load - return DataToExpand; - - SDValue Results[] = { - getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, DataToExpand), - Mask, PassThru, Subtarget, DAG), Chain}; - return DAG.getMergeValues(Results, dl); + MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); + SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl); + return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT, + MemIntr->getMemOperand(), ISD::NON_EXTLOAD, + true /* expanding */); } } } SDValue X86TargetLowering::LowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const { - MachineFrameInfo *MFI = DAG.getMachineFunction().getFrameInfo(); - MFI->setReturnAddressIsTaken(true); + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); + MFI.setReturnAddressIsTaken(true); if (verifyReturnAddressArgumentIsConstant(Op, DAG)) return SDValue(); @@ -18630,14 +20122,20 @@ SDValue X86TargetLowering::LowerRETURNADDR(SDValue Op, MachinePointerInfo()); } +SDValue X86TargetLowering::LowerADDROFRETURNADDR(SDValue Op, + SelectionDAG &DAG) const { + DAG.getMachineFunction().getFrameInfo().setReturnAddressIsTaken(true); + return getReturnAddressFrameIndex(DAG); +} + SDValue X86TargetLowering::LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const { MachineFunction &MF = DAG.getMachineFunction(); - MachineFrameInfo *MFI = MF.getFrameInfo(); + MachineFrameInfo &MFI = MF.getFrameInfo(); X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>(); const X86RegisterInfo *RegInfo = Subtarget.getRegisterInfo(); EVT VT = Op.getValueType(); - MFI->setFrameAddressIsTaken(true); + MFI.setFrameAddressIsTaken(true); if (MF.getTarget().getMCAsmInfo()->usesWindowsCFI()) { // Depth > 0 makes no sense on targets which use Windows unwind codes. It @@ -18647,7 +20145,7 @@ SDValue X86TargetLowering::LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const { if (!FrameAddrIndex) { // Set up a frame object for the return address. 
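// [Editor's sketch - not part of the patch] Scalar model of the expanding
// masked load that the EXPAND_FROM_MEM case above now emits as a single
// masked-load node: consecutive memory elements fill the enabled lanes,
// disabled lanes take the pass-through value.
#include <cstddef>

namespace sketch {
inline void expandLoad(const int *Mem, const bool *Mask, const int *PassThru,
                       int *Dst, std::size_t NumLanes) {
  std::size_t MemIdx = 0; // memory is read densely, lanes are filled sparsely
  for (std::size_t I = 0; I != NumLanes; ++I)
    Dst[I] = Mask[I] ? Mem[MemIdx++] : PassThru[I];
}
} // namespace sketch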
unsigned SlotSize = RegInfo->getSlotSize(); - FrameAddrIndex = MF.getFrameInfo()->CreateFixedObject( + FrameAddrIndex = MF.getFrameInfo().CreateFixedObject( SlotSize, /*Offset=*/0, /*IsImmutable=*/false); FuncInfo->setFAIndex(FrameAddrIndex); } @@ -18965,7 +20463,7 @@ SDValue X86TargetLowering::LowerFLT_ROUNDS_(SDValue Op, SDLoc DL(Op); // Save FP Control Word to stack slot - int SSFI = MF.getFrameInfo()->CreateStackObject(2, StackAlignment, false); + int SSFI = MF.getFrameInfo().CreateStackObject(2, StackAlignment, false); SDValue StackSlot = DAG.getFrameIndex(SSFI, getPointerTy(DAG.getDataLayout())); @@ -19083,7 +20581,7 @@ static SDValue LowerVectorCTLZInRegLUT(SDValue Op, const SDLoc &DL, SmallVector<SDValue, 64> LUTVec; for (int i = 0; i < NumBytes; ++i) LUTVec.push_back(DAG.getConstant(LUT[i % 16], DL, MVT::i8)); - SDValue InRegLUT = DAG.getNode(ISD::BUILD_VECTOR, DL, CurrVT, LUTVec); + SDValue InRegLUT = DAG.getBuildVector(CurrVT, DL, LUTVec); // Begin by bitcasting the input to byte vector, then split those bytes // into lo/hi nibbles and use the PSHUFB LUT to perform CLTZ on each of them. @@ -19444,43 +20942,63 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, assert((VT == MVT::v2i64 || VT == MVT::v4i64 || VT == MVT::v8i64) && "Only know how to lower V2I64/V4I64/V8I64 multiply"); + // 32-bit vector types used for MULDQ/MULUDQ. + MVT MulVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32); + + // MULDQ returns the 64-bit result of the signed multiplication of the lower + // 32-bits. We can lower with this if the sign bits stretch that far. + if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(A) > 32 && + DAG.ComputeNumSignBits(B) > 32) { + return DAG.getNode(X86ISD::PMULDQ, dl, VT, DAG.getBitcast(MulVT, A), + DAG.getBitcast(MulVT, B)); + } + // Ahi = psrlqi(a, 32); // Bhi = psrlqi(b, 32); // // AloBlo = pmuludq(a, b); // AloBhi = pmuludq(a, Bhi); // AhiBlo = pmuludq(Ahi, b); + // + // Hi = psllqi(AloBhi + AhiBlo, 32); + // return AloBlo + Hi; + APInt LowerBitsMask = APInt::getLowBitsSet(64, 32); + bool ALoIsZero = DAG.MaskedValueIsZero(A, LowerBitsMask); + bool BLoIsZero = DAG.MaskedValueIsZero(B, LowerBitsMask); + + APInt UpperBitsMask = APInt::getHighBitsSet(64, 32); + bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask); + bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask); - // AloBhi = psllqi(AloBhi, 32); - // AhiBlo = psllqi(AhiBlo, 32); - // return AloBlo + AloBhi + AhiBlo; + // Bit cast to 32-bit vectors for MULUDQ. + SDValue Alo = DAG.getBitcast(MulVT, A); + SDValue Blo = DAG.getBitcast(MulVT, B); - SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG); - SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG); + SDValue Zero = getZeroVector(VT, Subtarget, DAG, dl); - SDValue AhiBlo = Ahi; - SDValue AloBhi = Bhi; - // Bit cast to 32-bit vectors for MULUDQ - MVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 : - (VT == MVT::v4i64) ? MVT::v8i32 : MVT::v16i32; - A = DAG.getBitcast(MulVT, A); - B = DAG.getBitcast(MulVT, B); - Ahi = DAG.getBitcast(MulVT, Ahi); - Bhi = DAG.getBitcast(MulVT, Bhi); + // Only multiply lo/hi halves that aren't known to be zero. + SDValue AloBlo = Zero; + if (!ALoIsZero && !BLoIsZero) + AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Blo); - SDValue AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, B); - // After shifting right const values the result may be all-zero. 
- if (!ISD::isBuildVectorAllZeros(Ahi.getNode())) { - AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B); - AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG); + SDValue AloBhi = Zero; + if (!ALoIsZero && !BHiIsZero) { + SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG); + Bhi = DAG.getBitcast(MulVT, Bhi); + AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Bhi); } - if (!ISD::isBuildVectorAllZeros(Bhi.getNode())) { - AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi); - AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG); + + SDValue AhiBlo = Zero; + if (!AHiIsZero && !BLoIsZero) { + SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG); + Ahi = DAG.getBitcast(MulVT, Ahi); + AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, Blo); } - SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi); - return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo); + SDValue Hi = DAG.getNode(ISD::ADD, dl, VT, AloBhi, AhiBlo); + Hi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Hi, 32, DAG); + + return DAG.getNode(ISD::ADD, dl, VT, AloBlo, Hi); } static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, @@ -19905,7 +21423,8 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, // Special case in 32-bit mode, where i64 is expanded into high and low parts. if (!Subtarget.is64Bit() && !Subtarget.hasXOP() && - (VT == MVT::v2i64 || (Subtarget.hasInt256() && VT == MVT::v4i64))) { + (VT == MVT::v2i64 || (Subtarget.hasInt256() && VT == MVT::v4i64) || + (Subtarget.hasAVX512() && VT == MVT::v8i64))) { // Peek through any splat that was introduced for i64 shift vectorization. int SplatIndex = -1; @@ -20018,7 +21537,7 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, else if (EltVT.bitsLT(MVT::i32)) BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt); - return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, DAG); + return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, Subtarget, DAG); } } @@ -20147,7 +21666,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, } // If possible, lower this shift as a sequence of two shifts by - // constant plus a MOVSS/MOVSD instead of scalarizing it. + // constant plus a MOVSS/MOVSD/PBLEND instead of scalarizing it. // Example: // (v4i32 (srl A, (build_vector < X, Y, Y, Y>))) // @@ -20167,7 +21686,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, SDValue Amt2 = (VT == MVT::v4i32) ? Amt->getOperand(1) : Amt->getOperand(2); // See if it is possible to replace this node with a sequence of - // two shifts followed by a MOVSS/MOVSD + // two shifts followed by a MOVSS/MOVSD/PBLEND. if (VT == MVT::v4i32) { // Check if it is legal to use a MOVSS. CanBeSimplified = Amt2 == Amt->getOperand(2) && @@ -20199,21 +21718,21 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, if (CanBeSimplified && isa<ConstantSDNode>(Amt1) && isa<ConstantSDNode>(Amt2)) { - // Replace this node with two shifts followed by a MOVSS/MOVSD. + // Replace this node with two shifts followed by a MOVSS/MOVSD/PBLEND. 
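// [Editor's sketch - not part of the patch] The scalar identity behind the
// reworked LowerMUL above. Splitting A = Ahi*2^32 + Alo and B = Bhi*2^32 +
// Blo, the Ahi*Bhi term vanishes mod 2^64, leaving
// A*B = Alo*Blo + ((Alo*Bhi + Ahi*Blo) << 32), i.e. three PMULUDQ results
// plus one PSLLQ and two ADDs per lane:
#include <cstdint>

namespace sketch {
inline uint64_t mul64ViaPmuludq(uint64_t A, uint64_t B) {
  uint64_t Alo = A & 0xFFFFFFFFu, Ahi = A >> 32;
  uint64_t Blo = B & 0xFFFFFFFFu, Bhi = B >> 32;
  uint64_t AloBlo = Alo * Blo;                // PMULUDQ(A, B)
  uint64_t AloBhi = Alo * Bhi;                // PMULUDQ(A, PSRLQ(B, 32))
  uint64_t AhiBlo = Ahi * Blo;                // PMULUDQ(PSRLQ(A, 32), B)
  return AloBlo + ((AloBhi + AhiBlo) << 32);  // PSLLQ(Hi, 32); ADD; ADD
}
} // namespace sketch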
MVT CastVT = MVT::v4i32; SDValue Splat1 = - DAG.getConstant(cast<ConstantSDNode>(Amt1)->getAPIntValue(), dl, VT); + DAG.getConstant(cast<ConstantSDNode>(Amt1)->getAPIntValue(), dl, VT); SDValue Shift1 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat1); SDValue Splat2 = - DAG.getConstant(cast<ConstantSDNode>(Amt2)->getAPIntValue(), dl, VT); + DAG.getConstant(cast<ConstantSDNode>(Amt2)->getAPIntValue(), dl, VT); SDValue Shift2 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat2); - if (TargetOpcode == X86ISD::MOVSD) - CastVT = MVT::v2i64; SDValue BitCast1 = DAG.getBitcast(CastVT, Shift1); SDValue BitCast2 = DAG.getBitcast(CastVT, Shift2); - SDValue Result = getTargetShuffleNode(TargetOpcode, dl, CastVT, BitCast2, - BitCast1, DAG); - return DAG.getBitcast(VT, Result); + if (TargetOpcode == X86ISD::MOVSD) + return DAG.getBitcast(VT, DAG.getVectorShuffle(CastVT, dl, BitCast1, + BitCast2, {0, 1, 6, 7})); + return DAG.getBitcast(VT, DAG.getVectorShuffle(CastVT, dl, BitCast1, + BitCast2, {0, 5, 6, 7})); } } @@ -20264,15 +21783,44 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, return DAG.getVectorShuffle(VT, dl, R02, R13, {0, 5, 2, 7}); } + // It's worth extending once and using the vXi16/vXi32 shifts for smaller + // types, but without AVX512 the extra overheads to get from vXi8 to vXi32 + // make the existing SSE solution better. + if ((Subtarget.hasInt256() && VT == MVT::v8i16) || + (Subtarget.hasAVX512() && VT == MVT::v16i16) || + (Subtarget.hasAVX512() && VT == MVT::v16i8) || + (Subtarget.hasBWI() && VT == MVT::v32i8)) { + MVT EvtSVT = (VT == MVT::v32i8 ? MVT::i16 : MVT::i32); + MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements()); + unsigned ExtOpc = + Op.getOpcode() == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + R = DAG.getNode(ExtOpc, dl, ExtVT, R); + Amt = DAG.getNode(ISD::ANY_EXTEND, dl, ExtVT, Amt); + return DAG.getNode(ISD::TRUNCATE, dl, VT, + DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt)); + } + if (VT == MVT::v16i8 || - (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP())) { + (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) || + (VT == MVT::v64i8 && Subtarget.hasBWI())) { MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2); unsigned ShiftOpcode = Op->getOpcode(); auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) { - // On SSE41 targets we make use of the fact that VSELECT lowers - // to PBLENDVB which selects bytes based just on the sign bit. - if (Subtarget.hasSSE41()) { + if (VT.is512BitVector()) { + // On AVX512BW targets we make use of the fact that VSELECT lowers + // to a masked blend which selects bytes based just on the sign bit + // extracted to a mask. + MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); + V0 = DAG.getBitcast(VT, V0); + V1 = DAG.getBitcast(VT, V1); + Sel = DAG.getBitcast(VT, Sel); + Sel = DAG.getNode(X86ISD::CVT2MASK, dl, MaskVT, Sel); + return DAG.getBitcast(SelVT, + DAG.getNode(ISD::VSELECT, dl, VT, Sel, V0, V1)); + } else if (Subtarget.hasSSE41()) { + // On SSE41 targets we make use of the fact that VSELECT lowers + // to PBLENDVB which selects bytes based just on the sign bit. 
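// [Editor's sketch - not part of the patch] One lane of the extend-once
// strategy added above for small-element variable shifts: widen the value
// (sign-extend for SRA, zero-extend otherwise), shift at the wider width,
// then truncate back.
#include <cstdint>

namespace sketch {
inline uint8_t shl8ViaWiden(uint8_t V, unsigned Amt) {
  if (Amt >= 32)
    return 0;                          // out-of-range amounts shift out fully
  uint32_t Wide = V;                   // ZERO_EXTEND (ANY_EXTEND for amounts)
  Wide <<= Amt;                        // shift at i32 width
  return static_cast<uint8_t>(Wide);   // TRUNCATE discards the widened bits
}
} // namespace sketch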
V0 = DAG.getBitcast(VT, V0); V1 = DAG.getBitcast(VT, V1); Sel = DAG.getBitcast(VT, Sel); @@ -20372,19 +21920,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, } } - // It's worth extending once and using the v8i32 shifts for 16-bit types, but - // the extra overheads to get from v16i8 to v8i32 make the existing SSE - // solution better. - if (Subtarget.hasInt256() && VT == MVT::v8i16) { - MVT ExtVT = MVT::v8i32; - unsigned ExtOpc = - Op.getOpcode() == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; - R = DAG.getNode(ExtOpc, dl, ExtVT, R); - Amt = DAG.getNode(ISD::ANY_EXTEND, dl, ExtVT, Amt); - return DAG.getNode(ISD::TRUNCATE, dl, VT, - DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt)); - } - if (Subtarget.hasInt256() && !Subtarget.hasXOP() && VT == MVT::v16i16) { MVT ExtVT = MVT::v8i32; SDValue Z = getZeroVector(VT, Subtarget, DAG, dl); @@ -20519,7 +22054,7 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); unsigned BaseOp = 0; - unsigned Cond = 0; + X86::CondCode Cond; SDLoc DL(Op); switch (Op.getOpcode()) { default: llvm_unreachable("Unknown ovf instruction!"); @@ -20567,16 +22102,11 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { MVT::i32); SDValue Sum = DAG.getNode(X86ISD::UMUL, DL, VTs, LHS, RHS); - SDValue SetCC = - DAG.getNode(X86ISD::SETCC, DL, MVT::i8, - DAG.getConstant(X86::COND_O, DL, MVT::i32), - SDValue(Sum.getNode(), 2)); + SDValue SetCC = getSETCC(X86::COND_O, SDValue(Sum.getNode(), 2), DL, DAG); - if (N->getValueType(1) == MVT::i1) { - SetCC = DAG.getNode(ISD::AssertZext, DL, MVT::i8, SetCC, - DAG.getValueType(MVT::i1)); + if (N->getValueType(1) == MVT::i1) SetCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC); - } + return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC); } } @@ -20585,16 +22115,11 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { SDVTList VTs = DAG.getVTList(N->getValueType(0), MVT::i32); SDValue Sum = DAG.getNode(BaseOp, DL, VTs, LHS, RHS); - SDValue SetCC = - DAG.getNode(X86ISD::SETCC, DL, MVT::i8, - DAG.getConstant(Cond, DL, MVT::i32), - SDValue(Sum.getNode(), 1)); + SDValue SetCC = getSETCC(Cond, SDValue(Sum.getNode(), 1), DL, DAG); - if (N->getValueType(1) == MVT::i1) { - SetCC = DAG.getNode(ISD::AssertZext, DL, MVT::i8, SetCC, - DAG.getValueType(MVT::i1)); + if (N->getValueType(1) == MVT::i1) SetCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC); - } + return DAG.getNode(ISD::MERGE_VALUES, DL, N->getVTList(), Sum, SetCC); } @@ -20790,9 +22315,7 @@ static SDValue LowerCMP_SWAP(SDValue Op, const X86Subtarget &Subtarget, DAG.getCopyFromReg(Result.getValue(0), DL, Reg, T, Result.getValue(1)); SDValue EFLAGS = DAG.getCopyFromReg(cpOut.getValue(1), DL, X86::EFLAGS, MVT::i32, cpOut.getValue(2)); - SDValue Success = DAG.getNode(X86ISD::SETCC, DL, Op->getValueType(1), - DAG.getConstant(X86::COND_E, DL, MVT::i8), - EFLAGS); + SDValue Success = getSETCC(X86::COND_E, EFLAGS, DL, DAG); DAG.ReplaceAllUsesOfValueWith(Op.getValue(0), cpOut); DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Success); @@ -20898,8 +22421,9 @@ static SDValue LowerHorizontalByteSum(SDValue V, MVT VT, // two v2i64 vectors which concatenated are the 4 population counts. We can // then use PACKUSWB to shrink and concatenate them into a v4i32 again. 
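// [Editor's sketch - not part of the patch] The flag choice in LowerXALUO
// above: UMULO reads CF via COND_B because x86 MUL sets CF (and OF) exactly
// when the high half of the double-width product is non-zero.
#include <cstdint>

namespace sketch {
inline bool umul32Overflows(uint32_t A, uint32_t B, uint32_t &Lo) {
  uint64_t Wide = static_cast<uint64_t>(A) * B;
  Lo = static_cast<uint32_t>(Wide);
  return (Wide >> 32) != 0; // MUL raises CF/OF iff the high half is non-zero
}
} // namespace sketch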
SDValue Zeros = getZeroVector(VT, Subtarget, DAG, DL); - SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, V, Zeros); - SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, V, Zeros); + SDValue V32 = DAG.getBitcast(VT, V); + SDValue Low = DAG.getNode(X86ISD::UNPCKL, DL, VT, V32, Zeros); + SDValue High = DAG.getNode(X86ISD::UNPCKH, DL, VT, V32, Zeros); // Do the horizontal sums into two v2i64s. Zeros = getZeroVector(ByteVecVT, Subtarget, DAG, DL); @@ -21054,6 +22578,8 @@ static SDValue LowerVectorCTPOPBitmath(SDValue Op, const SDLoc &DL, DAG); } +// Please ensure that any codegen change from LowerVectorCTPOP is reflected in +// updated cost models in X86TTIImpl::getIntrinsicInstrCost. static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { MVT VT = Op.getSimpleValueType(); @@ -21260,8 +22786,7 @@ static SDValue lowerAtomicArith(SDValue N, SelectionDAG &DAG, AtomicSDNode *AN = cast<AtomicSDNode>(N.getNode()); RHS = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), RHS); return DAG.getAtomic(ISD::ATOMIC_LOAD_ADD, DL, VT, Chain, LHS, - RHS, AN->getMemOperand(), AN->getOrdering(), - AN->getSynchScope()); + RHS, AN->getMemOperand()); } assert(Opc == ISD::ATOMIC_LOAD_ADD && "Used AtomicRMW ops other than Add should have been expanded!"); @@ -21292,9 +22817,7 @@ static SDValue LowerATOMIC_STORE(SDValue Op, SelectionDAG &DAG) { cast<AtomicSDNode>(Node)->getMemoryVT(), Node->getOperand(0), Node->getOperand(1), Node->getOperand(2), - cast<AtomicSDNode>(Node)->getMemOperand(), - cast<AtomicSDNode>(Node)->getOrdering(), - cast<AtomicSDNode>(Node)->getSynchScope()); + cast<AtomicSDNode>(Node)->getMemOperand()); return Swap.getValue(1); } // Other atomic stores have a simple pattern. @@ -21534,26 +23057,48 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget, SDValue Mask = N->getMask(); SDLoc dl(Op); + assert((!N->isExpandingLoad() || Subtarget.hasAVX512()) && + "Expanding masked load is supported on AVX-512 target only!"); + + assert((!N->isExpandingLoad() || ScalarVT.getSizeInBits() >= 32) && + "Expanding masked load is supported for 32 and 64-bit types only!"); + + // 4x32, 4x64 and 2x64 vectors of non-expanding loads are legal regardless of + // VLX. These types for exp-loads are handled here. + if (!N->isExpandingLoad() && VT.getVectorNumElements() <= 4) + return Op; + assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() && "Cannot lower masked load op."); - assert(((ScalarVT == MVT::i32 || ScalarVT == MVT::f32) || + assert((ScalarVT.getSizeInBits() >= 32 || (Subtarget.hasBWI() && (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) && "Unsupported masked load op."); // This operation is legal for targets with VLX, but without // VLX the vector should be widened to 512 bit - unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits(); + unsigned NumEltsInWideVec = 512 / VT.getScalarSizeInBits(); MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec); - MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec); SDValue Src0 = N->getSrc0(); Src0 = ExtendToType(Src0, WideDataVT, DAG); + + // Mask element has to be i1. 
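// [Editor's sketch - not part of the patch] Scalar model of one group of the
// "horizontal sums into two v2i64s" performed above; assuming, as elsewhere
// in this file, that those sums are PSADBW against zero, each eight-byte
// group reduces to one i64 lane holding sum(|b - 0|):
#include <cstdint>

namespace sketch {
inline uint64_t psadbwGroupAgainstZero(const uint8_t Bytes[8]) {
  uint64_t Sum = 0;
  for (int I = 0; I != 8; ++I)
    Sum += Bytes[I]; // |Bytes[I] - 0| summed across the group
  return Sum;
}
} // namespace sketch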
+ MVT MaskEltTy = Mask.getSimpleValueType().getScalarType(); + assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) && + "We handle 4x32, 4x64 and 2x64 vectors only in this case"); + + MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec); + Mask = ExtendToType(Mask, WideMaskVT, DAG, true); + if (MaskEltTy != MVT::i1) + Mask = DAG.getNode(ISD::TRUNCATE, dl, + MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask); SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(), N->getBasePtr(), Mask, Src0, N->getMemoryVT(), N->getMemOperand(), - N->getExtensionType()); + N->getExtensionType(), + N->isExpandingLoad()); SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, NewLoad.getValue(0), @@ -21571,10 +23116,20 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget, SDValue Mask = N->getMask(); SDLoc dl(Op); + assert((!N->isCompressingStore() || Subtarget.hasAVX512()) && + "Compressing masked store is supported on AVX-512 target only!"); + + assert((!N->isCompressingStore() || ScalarVT.getSizeInBits() >= 32) && + "Compressing masked store is supported for 32 and 64-bit types only!"); + + // 4x32 and 2x64 vectors of non-compressing stores are legal regardless of VLX. + if (!N->isCompressingStore() && VT.getVectorNumElements() <= 4) + return Op; + assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() && "Cannot lower masked store op."); - assert(((ScalarVT == MVT::i32 || ScalarVT == MVT::f32) || + assert((ScalarVT.getSizeInBits() >= 32 || (Subtarget.hasBWI() && (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) && "Unsupported masked store op."); @@ -21583,12 +23138,22 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget, // VLX the vector should be widened to 512 bit unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits(); MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec); - MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec); + + // Mask element has to be i1.
+ MVT MaskEltTy = Mask.getSimpleValueType().getScalarType(); + assert((MaskEltTy == MVT::i1 || VT.getVectorNumElements() <= 4) && + "We handle 4x32, 4x64 and 2x64 vectors only in this case"); + + MVT WideMaskVT = MVT::getVectorVT(MaskEltTy, NumEltsInWideVec); + DataToStore = ExtendToType(DataToStore, WideDataVT, DAG); Mask = ExtendToType(Mask, WideMaskVT, DAG, true); + if (MaskEltTy != MVT::i1) + Mask = DAG.getNode(ISD::TRUNCATE, dl, + MVT::getVectorVT(MVT::i1, NumEltsInWideVec), Mask); return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(), Mask, N->getMemoryVT(), N->getMemOperand(), - N->isTruncatingStore()); + N->isTruncatingStore(), N->isCompressingStore()); } static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, @@ -21734,10 +23299,11 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::ZERO_EXTEND: return LowerZERO_EXTEND(Op, Subtarget, DAG); case ISD::SIGN_EXTEND: return LowerSIGN_EXTEND(Op, Subtarget, DAG); case ISD::ANY_EXTEND: return LowerANY_EXTEND(Op, Subtarget, DAG); + case ISD::ZERO_EXTEND_VECTOR_INREG: case ISD::SIGN_EXTEND_VECTOR_INREG: - return LowerSIGN_EXTEND_VECTOR_INREG(Op, Subtarget, DAG); - case ISD::FP_TO_SINT: return LowerFP_TO_SINT(Op, DAG); - case ISD::FP_TO_UINT: return LowerFP_TO_UINT(Op, DAG); + return LowerEXTEND_VECTOR_INREG(Op, Subtarget, DAG); + case ISD::FP_TO_SINT: + case ISD::FP_TO_UINT: return LowerFP_TO_INT(Op, Subtarget, DAG); case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG); case ISD::LOAD: return LowerExtendedLoad(Op, Subtarget, DAG); case ISD::FABS: @@ -21756,6 +23322,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::INTRINSIC_VOID: case ISD::INTRINSIC_W_CHAIN: return LowerINTRINSIC_W_CHAIN(Op, Subtarget, DAG); case ISD::RETURNADDR: return LowerRETURNADDR(Op, DAG); + case ISD::ADDROFRETURNADDR: return LowerADDROFRETURNADDR(Op, DAG); case ISD::FRAMEADDR: return LowerFRAMEADDR(Op, DAG); case ISD::FRAME_TO_ARGS_OFFSET: return LowerFRAME_TO_ARGS_OFFSET(Op, DAG); @@ -21830,7 +23397,7 @@ void X86TargetLowering::LowerOperationWrapper(SDNode *N, // In some cases (LowerSINT_TO_FP for example) Res has more result values // than the original node; the chain should be dropped (last value). for (unsigned I = 0, E = N->getNumValues(); I != E; ++I) - Results.push_back(Res.getValue(I)); + Results.push_back(Res.getValue(I)); } /// Replace a node with an illegal result type with a new node built out of @@ -21851,9 +23418,9 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, auto InVTSize = InVT.getSizeInBits(); const unsigned RegSize = (InVTSize > 128) ? ((InVTSize > 256) ? 512 : 256) : 128; - assert((!Subtarget.hasAVX512() || RegSize < 512) && - "512-bit vector requires AVX512"); - assert((!Subtarget.hasAVX2() || RegSize < 256) && + assert((Subtarget.hasBWI() || RegSize < 512) && + "512-bit vector requires AVX512BW"); + assert((Subtarget.hasAVX2() || RegSize < 256) && "256-bit vector requires AVX2"); auto ElemVT = InVT.getVectorElementType(); @@ -21888,13 +23455,6 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, Results.push_back(DAG.getNode(N->getOpcode(), dl, MVT::v4f32, LHS, RHS)); return; } - case ISD::SIGN_EXTEND_INREG: - case ISD::ADDC: - case ISD::ADDE: - case ISD::SUBC: - case ISD::SUBE: - // We don't want to expand or promote these.
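// [Editor's sketch - not part of the patch] Scalar model of the compressing
// masked store that LowerMSTORE above now forwards via the new
// isCompressingStore() flag: enabled lanes are packed and written to
// consecutive memory locations (the mirror image of the expanding load).
#include <cstddef>

namespace sketch {
inline void compressStore(int *Mem, const bool *Mask, const int *Src,
                          std::size_t NumLanes) {
  std::size_t MemIdx = 0; // lanes are read sparsely, memory is written densely
  for (std::size_t I = 0; I != NumLanes; ++I)
    if (Mask[I])
      Mem[MemIdx++] = Src[I];
}
} // namespace sketch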
- return; case ISD::SDIV: case ISD::UDIV: case ISD::SREM: @@ -21909,6 +23469,36 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, case ISD::FP_TO_UINT: { bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT; + if (N->getValueType(0) == MVT::v2i32) { + assert((IsSigned || Subtarget.hasAVX512()) && + "Can only handle signed conversion without AVX512"); + assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); + SDValue Src = N->getOperand(0); + if (Src.getValueType() == MVT::v2f64) { + SDValue Idx = DAG.getIntPtrConstant(0, dl); + SDValue Res = DAG.getNode(IsSigned ? X86ISD::CVTTP2SI + : X86ISD::CVTTP2UI, + dl, MVT::v4i32, Src); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); + Results.push_back(Res); + return; + } + if (Src.getValueType() == MVT::v2f32) { + SDValue Idx = DAG.getIntPtrConstant(0, dl); + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, + DAG.getUNDEF(MVT::v2f32)); + Res = DAG.getNode(IsSigned ? ISD::FP_TO_SINT + : ISD::FP_TO_UINT, dl, MVT::v4i32, Res); + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); + Results.push_back(Res); + return; + } + + // The FP_TO_INTHelper below only handles f32/f64/f80 scalar inputs, + // so early out here. + return; + } + std::pair<SDValue,SDValue> Vals = FP_TO_INTHelper(SDValue(N, 0), DAG, IsSigned, /*IsReplace=*/ true); SDValue FIST = Vals.first, StackSlot = Vals.second; @@ -21923,13 +23513,28 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, } return; } + case ISD::SINT_TO_FP: { + assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!"); + SDValue Src = N->getOperand(0); + if (N->getValueType(0) != MVT::v2f32 || Src.getValueType() != MVT::v2i64) + return; + Results.push_back(DAG.getNode(X86ISD::CVTSI2P, dl, MVT::v4f32, Src)); + return; + } case ISD::UINT_TO_FP: { assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); - if (N->getOperand(0).getValueType() != MVT::v2i32 || - N->getValueType(0) != MVT::v2f32) + EVT VT = N->getValueType(0); + if (VT != MVT::v2f32) return; - SDValue ZExtIn = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v2i64, - N->getOperand(0)); + SDValue Src = N->getOperand(0); + EVT SrcVT = Src.getValueType(); + if (Subtarget.hasDQI() && Subtarget.hasVLX() && SrcVT == MVT::v2i64) { + Results.push_back(DAG.getNode(X86ISD::CVTUI2P, dl, MVT::v4f32, Src)); + return; + } + if (SrcVT != MVT::v2i32) + return; + SDValue ZExtIn = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v2i64, Src); SDValue VBias = DAG.getConstantFP(BitsToDouble(0x4330000000000000ULL), dl, MVT::v2f64); SDValue Or = DAG.getNode(ISD::OR, dl, MVT::v2i64, ZExtIn, @@ -21967,6 +23572,9 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, Results); case Intrinsic::x86_rdpmc: return getReadPerformanceCounter(N, dl, DAG, Subtarget, Results); + + case Intrinsic::x86_xgetbv: + return getExtendedControlRegister(N, dl, DAG, Subtarget, Results); } } case ISD::INTRINSIC_WO_CHAIN: { @@ -22052,9 +23660,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, SDValue EFLAGS = DAG.getCopyFromReg(cpOutH.getValue(1), dl, X86::EFLAGS, MVT::i32, cpOutH.getValue(2)); - SDValue Success = - DAG.getNode(X86ISD::SETCC, dl, MVT::i8, - DAG.getConstant(X86::COND_E, dl, MVT::i8), EFLAGS); + SDValue Success = getSETCC(X86::COND_E, EFLAGS, dl, DAG); Success = DAG.getZExtOrTrunc(Success, dl, N->getValueType(1)); Results.push_back(DAG.getNode(ISD::BUILD_PAIR, dl, T, OpsF)); @@ -22143,6 +23749,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::SETCC: return "X86ISD::SETCC"; 
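// [Editor's sketch - not part of the patch] Why the UINT_TO_FP path above ORs
// the zero-extended input with BitsToDouble(0x4330000000000000ULL): that
// constant is the bit pattern of 2^52, so OR-ing a 32-bit value into the low
// mantissa bits yields exactly the double 2^52 + v, and one FSUB of 2^52
// recovers (double)v with no rounding.
#include <cstdint>
#include <cstring>

namespace sketch {
inline double uintToDoubleViaBias(uint32_t V) {
  uint64_t Bits = 0x4330000000000000ULL | V; // 2^52 + V, exactly representable
  double D;
  std::memcpy(&D, &Bits, sizeof(D));
  return D - 4503599627370496.0;             // subtract 2^52
}
} // namespace sketch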
case X86ISD::SETCC_CARRY: return "X86ISD::SETCC_CARRY"; case X86ISD::FSETCC: return "X86ISD::FSETCC"; + case X86ISD::FSETCCM: return "X86ISD::FSETCCM"; + case X86ISD::FSETCCM_RND: return "X86ISD::FSETCCM_RND"; case X86ISD::CMOV: return "X86ISD::CMOV"; case X86ISD::BRCOND: return "X86ISD::BRCOND"; case X86ISD::RET_FLAG: return "X86ISD::RET_FLAG"; @@ -22215,11 +23823,17 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::VTRUNC: return "X86ISD::VTRUNC"; case X86ISD::VTRUNCS: return "X86ISD::VTRUNCS"; case X86ISD::VTRUNCUS: return "X86ISD::VTRUNCUS"; + case X86ISD::VTRUNCSTORES: return "X86ISD::VTRUNCSTORES"; + case X86ISD::VTRUNCSTOREUS: return "X86ISD::VTRUNCSTOREUS"; + case X86ISD::VMTRUNCSTORES: return "X86ISD::VMTRUNCSTORES"; + case X86ISD::VMTRUNCSTOREUS: return "X86ISD::VMTRUNCSTOREUS"; case X86ISD::VINSERT: return "X86ISD::VINSERT"; case X86ISD::VFPEXT: return "X86ISD::VFPEXT"; + case X86ISD::VFPEXT_RND: return "X86ISD::VFPEXT_RND"; + case X86ISD::VFPEXTS_RND: return "X86ISD::VFPEXTS_RND"; case X86ISD::VFPROUND: return "X86ISD::VFPROUND"; - case X86ISD::CVTDQ2PD: return "X86ISD::CVTDQ2PD"; - case X86ISD::CVTUDQ2PD: return "X86ISD::CVTUDQ2PD"; + case X86ISD::VFPROUND_RND: return "X86ISD::VFPROUND_RND"; + case X86ISD::VFPROUNDS_RND: return "X86ISD::VFPROUNDS_RND"; case X86ISD::CVT2MASK: return "X86ISD::CVT2MASK"; case X86ISD::VSHLDQ: return "X86ISD::VSHLDQ"; case X86ISD::VSRLDQ: return "X86ISD::VSRLDQ"; @@ -22332,27 +23946,43 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::FNMSUB_RND: return "X86ISD::FNMSUB_RND"; case X86ISD::FMADDSUB_RND: return "X86ISD::FMADDSUB_RND"; case X86ISD::FMSUBADD_RND: return "X86ISD::FMSUBADD_RND"; + case X86ISD::FMADDS1_RND: return "X86ISD::FMADDS1_RND"; + case X86ISD::FNMADDS1_RND: return "X86ISD::FNMADDS1_RND"; + case X86ISD::FMSUBS1_RND: return "X86ISD::FMSUBS1_RND"; + case X86ISD::FNMSUBS1_RND: return "X86ISD::FNMSUBS1_RND"; + case X86ISD::FMADDS3_RND: return "X86ISD::FMADDS3_RND"; + case X86ISD::FNMADDS3_RND: return "X86ISD::FNMADDS3_RND"; + case X86ISD::FMSUBS3_RND: return "X86ISD::FMSUBS3_RND"; + case X86ISD::FNMSUBS3_RND: return "X86ISD::FNMSUBS3_RND"; case X86ISD::VPMADD52H: return "X86ISD::VPMADD52H"; case X86ISD::VPMADD52L: return "X86ISD::VPMADD52L"; case X86ISD::VRNDSCALE: return "X86ISD::VRNDSCALE"; + case X86ISD::VRNDSCALES: return "X86ISD::VRNDSCALES"; case X86ISD::VREDUCE: return "X86ISD::VREDUCE"; + case X86ISD::VREDUCES: return "X86ISD::VREDUCES"; case X86ISD::VGETMANT: return "X86ISD::VGETMANT"; + case X86ISD::VGETMANTS: return "X86ISD::VGETMANTS"; case X86ISD::PCMPESTRI: return "X86ISD::PCMPESTRI"; case X86ISD::PCMPISTRI: return "X86ISD::PCMPISTRI"; case X86ISD::XTEST: return "X86ISD::XTEST"; case X86ISD::COMPRESS: return "X86ISD::COMPRESS"; case X86ISD::EXPAND: return "X86ISD::EXPAND"; case X86ISD::SELECT: return "X86ISD::SELECT"; + case X86ISD::SELECTS: return "X86ISD::SELECTS"; case X86ISD::ADDSUB: return "X86ISD::ADDSUB"; case X86ISD::RCP28: return "X86ISD::RCP28"; + case X86ISD::RCP28S: return "X86ISD::RCP28S"; case X86ISD::EXP2: return "X86ISD::EXP2"; case X86ISD::RSQRT28: return "X86ISD::RSQRT28"; + case X86ISD::RSQRT28S: return "X86ISD::RSQRT28S"; case X86ISD::FADD_RND: return "X86ISD::FADD_RND"; case X86ISD::FSUB_RND: return "X86ISD::FSUB_RND"; case X86ISD::FMUL_RND: return "X86ISD::FMUL_RND"; case X86ISD::FDIV_RND: return "X86ISD::FDIV_RND"; case X86ISD::FSQRT_RND: return "X86ISD::FSQRT_RND"; + case X86ISD::FSQRTS_RND: return "X86ISD::FSQRTS_RND"; 
case X86ISD::FGETEXP_RND: return "X86ISD::FGETEXP_RND"; + case X86ISD::FGETEXPS_RND: return "X86ISD::FGETEXPS_RND"; case X86ISD::SCALEF: return "X86ISD::SCALEF"; case X86ISD::SCALEFS: return "X86ISD::SCALEFS"; case X86ISD::ADDS: return "X86ISD::ADDS"; @@ -22361,13 +23991,27 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::MULHRS: return "X86ISD::MULHRS"; case X86ISD::SINT_TO_FP_RND: return "X86ISD::SINT_TO_FP_RND"; case X86ISD::UINT_TO_FP_RND: return "X86ISD::UINT_TO_FP_RND"; - case X86ISD::FP_TO_SINT_RND: return "X86ISD::FP_TO_SINT_RND"; - case X86ISD::FP_TO_UINT_RND: return "X86ISD::FP_TO_UINT_RND"; + case X86ISD::CVTTP2SI: return "X86ISD::CVTTP2SI"; + case X86ISD::CVTTP2UI: return "X86ISD::CVTTP2UI"; + case X86ISD::CVTTP2SI_RND: return "X86ISD::CVTTP2SI_RND"; + case X86ISD::CVTTP2UI_RND: return "X86ISD::CVTTP2UI_RND"; + case X86ISD::CVTTS2SI_RND: return "X86ISD::CVTTS2SI_RND"; + case X86ISD::CVTTS2UI_RND: return "X86ISD::CVTTS2UI_RND"; + case X86ISD::CVTSI2P: return "X86ISD::CVTSI2P"; + case X86ISD::CVTUI2P: return "X86ISD::CVTUI2P"; case X86ISD::VFPCLASS: return "X86ISD::VFPCLASS"; case X86ISD::VFPCLASSS: return "X86ISD::VFPCLASSS"; case X86ISD::MULTISHIFT: return "X86ISD::MULTISHIFT"; - case X86ISD::SCALAR_FP_TO_SINT_RND: return "X86ISD::SCALAR_FP_TO_SINT_RND"; - case X86ISD::SCALAR_FP_TO_UINT_RND: return "X86ISD::SCALAR_FP_TO_UINT_RND"; + case X86ISD::SCALAR_SINT_TO_FP_RND: return "X86ISD::SCALAR_SINT_TO_FP_RND"; + case X86ISD::SCALAR_UINT_TO_FP_RND: return "X86ISD::SCALAR_UINT_TO_FP_RND"; + case X86ISD::CVTPS2PH: return "X86ISD::CVTPS2PH"; + case X86ISD::CVTPH2PS: return "X86ISD::CVTPH2PS"; + case X86ISD::CVTP2SI: return "X86ISD::CVTP2SI"; + case X86ISD::CVTP2UI: return "X86ISD::CVTP2UI"; + case X86ISD::CVTP2SI_RND: return "X86ISD::CVTP2SI_RND"; + case X86ISD::CVTP2UI_RND: return "X86ISD::CVTP2UI_RND"; + case X86ISD::CVTS2SI_RND: return "X86ISD::CVTS2SI_RND"; + case X86ISD::CVTS2UI_RND: return "X86ISD::CVTS2UI_RND"; } return nullptr; } @@ -24031,11 +25675,10 @@ X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, MachineBasicBlock *BB) const { DebugLoc DL = MI.getDebugLoc(); MachineFunction *MF = BB->getParent(); - MachineModuleInfo *MMI = &MF->getMMI(); - MachineFrameInfo *MFI = MF->getFrameInfo(); + MachineFrameInfo &MFI = MF->getFrameInfo(); MachineRegisterInfo *MRI = &MF->getRegInfo(); const TargetInstrInfo *TII = Subtarget.getInstrInfo(); - int FI = MFI->getFunctionContextIndex(); + int FI = MFI.getFunctionContextIndex(); // Get a mapping of the call site numbers to all of the landing pads they're // associated with. @@ -24055,10 +25698,10 @@ X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, break; } - if (!MMI->hasCallSiteLandingPad(Sym)) + if (!MF->hasCallSiteLandingPad(Sym)) continue; - for (unsigned CSI : MMI->getCallSiteLandingPad(Sym)) { + for (unsigned CSI : MF->getCallSiteLandingPad(Sym)) { CallSiteNumToLPad[CSI].push_back(&MBB); MaxCSNum = std::max(MaxCSNum, CSI); } @@ -24208,173 +25851,18 @@ X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, return BB; } -// Replace 213-type (isel default) FMA3 instructions with 231-type for -// accumulator loops. Writing back to the accumulator allows the coalescer -// to remove extra copies in the loop. -// FIXME: Do this on AVX512. We don't support 231 variants yet (PR23937). 
-MachineBasicBlock * -X86TargetLowering::emitFMA3Instr(MachineInstr &MI, - MachineBasicBlock *MBB) const { - MachineOperand &AddendOp = MI.getOperand(3); - - // Bail out early if the addend isn't a register - we can't switch these. - if (!AddendOp.isReg()) - return MBB; - - MachineFunction &MF = *MBB->getParent(); - MachineRegisterInfo &MRI = MF.getRegInfo(); - - // Check whether the addend is defined by a PHI: - assert(MRI.hasOneDef(AddendOp.getReg()) && "Multiple defs in SSA?"); - MachineInstr &AddendDef = *MRI.def_instr_begin(AddendOp.getReg()); - if (!AddendDef.isPHI()) - return MBB; - - // Look for the following pattern: - // loop: - // %addend = phi [%entry, 0], [%loop, %result] - // ... - // %result<tied1> = FMA213 %m2<tied0>, %m1, %addend - - // Replace with: - // loop: - // %addend = phi [%entry, 0], [%loop, %result] - // ... - // %result<tied1> = FMA231 %addend<tied0>, %m1, %m2 - - for (unsigned i = 1, e = AddendDef.getNumOperands(); i < e; i += 2) { - assert(AddendDef.getOperand(i).isReg()); - MachineOperand PHISrcOp = AddendDef.getOperand(i); - MachineInstr &PHISrcInst = *MRI.def_instr_begin(PHISrcOp.getReg()); - if (&PHISrcInst == &MI) { - // Found a matching instruction. - unsigned NewFMAOpc = 0; - switch (MI.getOpcode()) { - case X86::VFMADDPDr213r: - NewFMAOpc = X86::VFMADDPDr231r; - break; - case X86::VFMADDPSr213r: - NewFMAOpc = X86::VFMADDPSr231r; - break; - case X86::VFMADDSDr213r: - NewFMAOpc = X86::VFMADDSDr231r; - break; - case X86::VFMADDSSr213r: - NewFMAOpc = X86::VFMADDSSr231r; - break; - case X86::VFMSUBPDr213r: - NewFMAOpc = X86::VFMSUBPDr231r; - break; - case X86::VFMSUBPSr213r: - NewFMAOpc = X86::VFMSUBPSr231r; - break; - case X86::VFMSUBSDr213r: - NewFMAOpc = X86::VFMSUBSDr231r; - break; - case X86::VFMSUBSSr213r: - NewFMAOpc = X86::VFMSUBSSr231r; - break; - case X86::VFNMADDPDr213r: - NewFMAOpc = X86::VFNMADDPDr231r; - break; - case X86::VFNMADDPSr213r: - NewFMAOpc = X86::VFNMADDPSr231r; - break; - case X86::VFNMADDSDr213r: - NewFMAOpc = X86::VFNMADDSDr231r; - break; - case X86::VFNMADDSSr213r: - NewFMAOpc = X86::VFNMADDSSr231r; - break; - case X86::VFNMSUBPDr213r: - NewFMAOpc = X86::VFNMSUBPDr231r; - break; - case X86::VFNMSUBPSr213r: - NewFMAOpc = X86::VFNMSUBPSr231r; - break; - case X86::VFNMSUBSDr213r: - NewFMAOpc = X86::VFNMSUBSDr231r; - break; - case X86::VFNMSUBSSr213r: - NewFMAOpc = X86::VFNMSUBSSr231r; - break; - case X86::VFMADDSUBPDr213r: - NewFMAOpc = X86::VFMADDSUBPDr231r; - break; - case X86::VFMADDSUBPSr213r: - NewFMAOpc = X86::VFMADDSUBPSr231r; - break; - case X86::VFMSUBADDPDr213r: - NewFMAOpc = X86::VFMSUBADDPDr231r; - break; - case X86::VFMSUBADDPSr213r: - NewFMAOpc = X86::VFMSUBADDPSr231r; - break; - - case X86::VFMADDPDr213rY: - NewFMAOpc = X86::VFMADDPDr231rY; - break; - case X86::VFMADDPSr213rY: - NewFMAOpc = X86::VFMADDPSr231rY; - break; - case X86::VFMSUBPDr213rY: - NewFMAOpc = X86::VFMSUBPDr231rY; - break; - case X86::VFMSUBPSr213rY: - NewFMAOpc = X86::VFMSUBPSr231rY; - break; - case X86::VFNMADDPDr213rY: - NewFMAOpc = X86::VFNMADDPDr231rY; - break; - case X86::VFNMADDPSr213rY: - NewFMAOpc = X86::VFNMADDPSr231rY; - break; - case X86::VFNMSUBPDr213rY: - NewFMAOpc = X86::VFNMSUBPDr231rY; - break; - case X86::VFNMSUBPSr213rY: - NewFMAOpc = X86::VFNMSUBPSr231rY; - break; - case X86::VFMADDSUBPDr213rY: - NewFMAOpc = X86::VFMADDSUBPDr231rY; - break; - case X86::VFMADDSUBPSr213rY: - NewFMAOpc = X86::VFMADDSUBPSr231rY; - break; - case X86::VFMSUBADDPDr213rY: - NewFMAOpc = X86::VFMSUBADDPDr231rY; - break; - case X86::VFMSUBADDPSr213rY: - 
NewFMAOpc = X86::VFMSUBADDPSr231rY; - break; - default: - llvm_unreachable("Unrecognized FMA variant."); - } - - const TargetInstrInfo &TII = *Subtarget.getInstrInfo(); - MachineInstrBuilder MIB = - BuildMI(MF, MI.getDebugLoc(), TII.get(NewFMAOpc)) - .addOperand(MI.getOperand(0)) - .addOperand(MI.getOperand(3)) - .addOperand(MI.getOperand(2)) - .addOperand(MI.getOperand(1)); - MBB->insert(MachineBasicBlock::iterator(MI), MIB); - MI.eraseFromParent(); - } - } - - return MBB; -} - MachineBasicBlock * X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MachineBasicBlock *BB) const { + MachineFunction *MF = BB->getParent(); + const TargetInstrInfo *TII = Subtarget.getInstrInfo(); + DebugLoc DL = MI.getDebugLoc(); + switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected instr type to insert"); case X86::TAILJMPd64: case X86::TAILJMPr64: case X86::TAILJMPm64: - case X86::TAILJMPd64_REX: case X86::TAILJMPr64_REX: case X86::TAILJMPm64_REX: llvm_unreachable("TAILJMP64 would not be touched here."); @@ -24423,8 +25911,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::RDFLAGS32: case X86::RDFLAGS64: { - DebugLoc DL = MI.getDebugLoc(); - const TargetInstrInfo *TII = Subtarget.getInstrInfo(); unsigned PushF = MI.getOpcode() == X86::RDFLAGS32 ? X86::PUSHF32 : X86::PUSHF64; unsigned Pop = MI.getOpcode() == X86::RDFLAGS32 ? X86::POP32r : X86::POP64r; @@ -24442,8 +25928,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::WRFLAGS32: case X86::WRFLAGS64: { - DebugLoc DL = MI.getDebugLoc(); - const TargetInstrInfo *TII = Subtarget.getInstrInfo(); unsigned Push = MI.getOpcode() == X86::WRFLAGS32 ? X86::PUSH32r : X86::PUSH64r; unsigned PopF = @@ -24468,19 +25952,15 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::FP80_TO_INT16_IN_MEM: case X86::FP80_TO_INT32_IN_MEM: case X86::FP80_TO_INT64_IN_MEM: { - MachineFunction *F = BB->getParent(); - const TargetInstrInfo *TII = Subtarget.getInstrInfo(); - DebugLoc DL = MI.getDebugLoc(); - // Change the floating point control register to use "round towards zero" // mode when truncating to an integer value. - int CWFrameIdx = F->getFrameInfo()->CreateStackObject(2, 2, false); + int CWFrameIdx = MF->getFrameInfo().CreateStackObject(2, 2, false); addFrameReference(BuildMI(*BB, MI, DL, TII->get(X86::FNSTCW16m)), CWFrameIdx); // Load the old value of the high byte of the control word... 
unsigned OldCW = - F->getRegInfo().createVirtualRegister(&X86::GR16RegClass); + MF->getRegInfo().createVirtualRegister(&X86::GR16RegClass); addFrameReference(BuildMI(*BB, MI, DL, TII->get(X86::MOV16rm), OldCW), CWFrameIdx); @@ -24588,39 +26068,57 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case TargetOpcode::PATCHPOINT: return emitPatchPoint(MI, BB); - case X86::VFMADDPDr213r: - case X86::VFMADDPSr213r: - case X86::VFMADDSDr213r: - case X86::VFMADDSSr213r: - case X86::VFMSUBPDr213r: - case X86::VFMSUBPSr213r: - case X86::VFMSUBSDr213r: - case X86::VFMSUBSSr213r: - case X86::VFNMADDPDr213r: - case X86::VFNMADDPSr213r: - case X86::VFNMADDSDr213r: - case X86::VFNMADDSSr213r: - case X86::VFNMSUBPDr213r: - case X86::VFNMSUBPSr213r: - case X86::VFNMSUBSDr213r: - case X86::VFNMSUBSSr213r: - case X86::VFMADDSUBPDr213r: - case X86::VFMADDSUBPSr213r: - case X86::VFMSUBADDPDr213r: - case X86::VFMSUBADDPSr213r: - case X86::VFMADDPDr213rY: - case X86::VFMADDPSr213rY: - case X86::VFMSUBPDr213rY: - case X86::VFMSUBPSr213rY: - case X86::VFNMADDPDr213rY: - case X86::VFNMADDPSr213rY: - case X86::VFNMSUBPDr213rY: - case X86::VFNMSUBPSr213rY: - case X86::VFMADDSUBPDr213rY: - case X86::VFMADDSUBPSr213rY: - case X86::VFMSUBADDPDr213rY: - case X86::VFMSUBADDPSr213rY: - return emitFMA3Instr(MI, BB); + case X86::LCMPXCHG8B: { + const X86RegisterInfo *TRI = Subtarget.getRegisterInfo(); + // In addition to 4 E[ABCD] registers implied by encoding, CMPXCHG8B + // requires a memory operand. If it happens that the current architecture is + // i686 and the current function needs a base pointer + // - which is ESI for i686 - the register allocator would not be able to + // allocate registers for an address in the form of X(%reg, %reg, Y) + // - there never would be enough unreserved registers during regalloc + // (without the need for a base ptr the only option would be X(%edi, %esi, Y)). + // We give the register allocator a hand by precomputing the address in + // a new vreg using LEA. + + // If it is not i686 or there is no base pointer - nothing to do here. + if (!Subtarget.is32Bit() || !TRI->hasBasePointer(*MF)) + return BB; + + // Even though this code does not necessarily need the base pointer to + // be ESI, we check for that. The reason: if this assert fails, some + // changes have happened in the compiler's base pointer handling, which most + // probably have to be addressed somehow here. + assert(TRI->getBaseRegister() == X86::ESI && + "LCMPXCHG8B custom insertion for i686 is written with X86::ESI as a " + "base pointer in mind"); + + MachineRegisterInfo &MRI = MF->getRegInfo(); + MVT SPTy = getPointerTy(MF->getDataLayout()); + const TargetRegisterClass *AddrRegClass = getRegClassFor(SPTy); + unsigned computedAddrVReg = MRI.createVirtualRegister(AddrRegClass); + + X86AddressMode AM = getAddressFromInstr(&MI, 0); + // Regalloc does not need any help when the memory operand of CMPXCHG8B + // does not use an index register. + if (AM.IndexReg == X86::NoRegister) + return BB; + + // After X86TargetLowering::ReplaceNodeResults CMPXCHG8B is glued to its + // four operand definitions that are E[ABCD] registers. We skip them and + // then insert the LEA.
+ MachineBasicBlock::iterator MBBI(MI); + while (MBBI->definesRegister(X86::EAX) || MBBI->definesRegister(X86::EBX) || + MBBI->definesRegister(X86::ECX) || MBBI->definesRegister(X86::EDX)) + --MBBI; + addFullAddress( + BuildMI(*BB, *MBBI, DL, TII->get(X86::LEA32r), computedAddrVReg), AM); + + setDirectAddressInInstr(&MI, 0, computedAddrVReg); + + return BB; + } + case X86::LCMPXCHG16B: + return BB; case X86::LCMPXCHG8B_SAVE_EBX: case X86::LCMPXCHG16B_SAVE_RBX: { unsigned BasePtr = @@ -24667,7 +26165,7 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, // These nodes' second result is a boolean. if (Op.getResNo() == 0) break; - // Fallthrough + LLVM_FALLTHROUGH; case X86ISD::SETCC: KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - 1); break; @@ -24676,16 +26174,36 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - NumLoBits); break; } + case X86ISD::VZEXT: { + SDValue N0 = Op.getOperand(0); + unsigned NumElts = Op.getValueType().getVectorNumElements(); + unsigned InNumElts = N0.getValueType().getVectorNumElements(); + unsigned InBitWidth = N0.getValueType().getScalarSizeInBits(); + + KnownZero = KnownOne = APInt(InBitWidth, 0); + APInt DemandedElts = APInt::getLowBitsSet(InNumElts, NumElts); + DAG.computeKnownBits(N0, KnownZero, KnownOne, DemandedElts, Depth + 1); + KnownOne = KnownOne.zext(BitWidth); + KnownZero = KnownZero.zext(BitWidth); + KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - InBitWidth); + break; + } } } unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( - SDValue Op, - const SelectionDAG &, - unsigned Depth) const { + SDValue Op, const SelectionDAG &DAG, unsigned Depth) const { // SETCC_CARRY sets the dest to ~0 for true or 0 for false. if (Op.getOpcode() == X86ISD::SETCC_CARRY) - return Op.getValueType().getScalarSizeInBits(); + return Op.getScalarValueSizeInBits(); + + if (Op.getOpcode() == X86ISD::VSEXT) { + EVT VT = Op.getValueType(); + EVT SrcVT = Op.getOperand(0).getValueType(); + unsigned Tmp = DAG.ComputeNumSignBits(Op.getOperand(0), Depth + 1); + Tmp += VT.getScalarSizeInBits() - SrcVT.getScalarSizeInBits(); + return Tmp; + } // Fallback case. return 1; @@ -24706,171 +26224,113 @@ bool X86TargetLowering::isGAPlusOffset(SDNode *N, return TargetLowering::isGAPlusOffset(N, GA, Offset); } -/// Performs shuffle combines for 256-bit vectors. -/// FIXME: This could be expanded to support 512 bit vectors as well. -static SDValue combineShuffle256(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { - SDLoc dl(N); - ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(N); - SDValue V1 = SVOp->getOperand(0); - SDValue V2 = SVOp->getOperand(1); - MVT VT = SVOp->getSimpleValueType(0); - unsigned NumElems = VT.getVectorNumElements(); - - if (V1.getOpcode() == ISD::CONCAT_VECTORS && - V2.getOpcode() == ISD::CONCAT_VECTORS) { - // - // 0,0,0,... - // | - // V UNDEF BUILD_VECTOR UNDEF - // \ / \ / - // CONCAT_VECTOR CONCAT_VECTOR - // \ / - // \ / - // RESULT: V + zero extended - // - if (V2.getOperand(0).getOpcode() != ISD::BUILD_VECTOR || - !V2.getOperand(1).isUndef() || !V1.getOperand(1).isUndef()) - return SDValue(); - - if (!ISD::isBuildVectorAllZeros(V2.getOperand(0).getNode())) - return SDValue(); - - // To match the shuffle mask, the first half of the mask should - // be exactly the first vector, and all the rest a splat with the - // first element of the second one. 
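// [Editor's sketch - not part of the patch] The reasoning encoded by the new
// VSEXT case in ComputeNumSignBitsForTargetNode above: sign-extension copies
// the sign bit into every added high bit, so a lane gains exactly
// (dst width - src width) known sign bits on top of what the source had.
#include <cstdint>

namespace sketch {
inline unsigned numSignBits32(int32_t V) {
  uint32_t U = static_cast<uint32_t>(V);
  uint32_t Sign = U >> 31;
  unsigned N = 1; // the sign bit itself always counts
  while (N < 32 && ((U >> (31 - N)) & 1) == Sign)
    ++N;
  return N;
}
// For an i8 lane sign-extended to i32: numSignBits32(int32_t(int8_t(X)))
// equals the i8 sign-bit count plus 24, matching "Tmp += dst - src" above.
} // namespace sketch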
- for (unsigned i = 0; i != NumElems/2; ++i) - if (!isUndefOrEqual(SVOp->getMaskElt(i), i) || - !isUndefOrEqual(SVOp->getMaskElt(i+NumElems/2), NumElems)) - return SDValue(); - - // If V1 is coming from a vector load then just fold to a VZEXT_LOAD. - if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(V1.getOperand(0))) { - if (Ld->hasNUsesOfValue(1, 0)) { - SDVTList Tys = DAG.getVTList(MVT::v4i64, MVT::Other); - SDValue Ops[] = { Ld->getChain(), Ld->getBasePtr() }; - SDValue ResNode = - DAG.getMemIntrinsicNode(X86ISD::VZEXT_LOAD, dl, Tys, Ops, - Ld->getMemoryVT(), - Ld->getPointerInfo(), - Ld->getAlignment(), - false/*isVolatile*/, true/*ReadMem*/, - false/*WriteMem*/); - - // Make sure the newly-created LOAD is in the same position as Ld in - // terms of dependency. We create a TokenFactor for Ld and ResNode, - // and update uses of Ld's output chain to use the TokenFactor. - if (Ld->hasAnyUseOfValue(1)) { - SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, - SDValue(Ld, 1), SDValue(ResNode.getNode(), 1)); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewChain); - DAG.UpdateNodeOperands(NewChain.getNode(), SDValue(Ld, 1), - SDValue(ResNode.getNode(), 1)); - } - - return DAG.getBitcast(VT, ResNode); - } - } - - // Emit a zeroed vector and insert the desired subvector on its - // first half. - SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl); - SDValue InsV = insert128BitVector(Zeros, V1.getOperand(0), 0, DAG, dl); - return DCI.CombineTo(N, InsV); - } - - return SDValue(); -} - // Attempt to match a combined shuffle mask against supported unary shuffle // instructions. // TODO: Investigate sharing more of this with shuffle lowering. -static bool matchUnaryVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, +static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, + bool FloatDomain, const X86Subtarget &Subtarget, - unsigned &Shuffle, MVT &ShuffleVT) { - bool FloatDomain = SrcVT.isFloatingPoint() || - (!Subtarget.hasAVX2() && SrcVT.is256BitVector()); + unsigned &Shuffle, MVT &SrcVT, MVT &DstVT) { + unsigned NumMaskElts = Mask.size(); + unsigned MaskEltSize = MaskVT.getScalarSizeInBits(); - // Match a 128-bit integer vector against a VZEXT_MOVL (MOVQ) instruction. - if (!FloatDomain && SrcVT.is128BitVector() && - isTargetShuffleEquivalent(Mask, {0, SM_SentinelZero})) { + // Match against a VZEXT_MOVL instruction, SSE1 only supports 32-bits (MOVSS). + if (((MaskEltSize == 32) || (MaskEltSize == 64 && Subtarget.hasSSE2())) && + isUndefOrEqual(Mask[0], 0) && + isUndefOrZeroInRange(Mask, 1, NumMaskElts - 1)) { Shuffle = X86ISD::VZEXT_MOVL; - ShuffleVT = MVT::v2i64; + SrcVT = DstVT = !Subtarget.hasSSE2() ? MVT::v4f32 : MaskVT; return true; } + // Match against a VZEXT instruction. + // TODO: Add 256/512-bit vector support. + if (!FloatDomain && MaskVT.is128BitVector() && Subtarget.hasSSE41()) { + unsigned MaxScale = 64 / MaskEltSize; + for (unsigned Scale = 2; Scale <= MaxScale; Scale *= 2) { + bool Match = true; + unsigned NumDstElts = NumMaskElts / Scale; + for (unsigned i = 0; i != NumDstElts && Match; ++i) { + Match &= isUndefOrEqual(Mask[i * Scale], (int)i); + Match &= isUndefOrZeroInRange(Mask, (i * Scale) + 1, Scale - 1); + } + if (Match) { + SrcVT = MaskVT; + DstVT = MVT::getIntegerVT(Scale * MaskEltSize); + DstVT = MVT::getVectorVT(DstVT, NumDstElts); + Shuffle = X86ISD::VZEXT; + return true; + } + } + } + // Check if we have SSE3 which will let us use MOVDDUP etc. 
The // instructions are no slower than UNPCKLPD but has the option to // fold the input operand into even an unaligned memory load. - if (SrcVT.is128BitVector() && Subtarget.hasSSE3() && FloatDomain) { + if (MaskVT.is128BitVector() && Subtarget.hasSSE3() && FloatDomain) { if (isTargetShuffleEquivalent(Mask, {0, 0})) { Shuffle = X86ISD::MOVDDUP; - ShuffleVT = MVT::v2f64; + SrcVT = DstVT = MVT::v2f64; return true; } if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2})) { Shuffle = X86ISD::MOVSLDUP; - ShuffleVT = MVT::v4f32; + SrcVT = DstVT = MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {1, 1, 3, 3})) { Shuffle = X86ISD::MOVSHDUP; - ShuffleVT = MVT::v4f32; + SrcVT = DstVT = MVT::v4f32; return true; } } - if (SrcVT.is256BitVector() && FloatDomain) { + if (MaskVT.is256BitVector() && FloatDomain) { assert(Subtarget.hasAVX() && "AVX required for 256-bit vector shuffles"); if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2})) { Shuffle = X86ISD::MOVDDUP; - ShuffleVT = MVT::v4f64; + SrcVT = DstVT = MVT::v4f64; return true; } if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2, 4, 4, 6, 6})) { Shuffle = X86ISD::MOVSLDUP; - ShuffleVT = MVT::v8f32; + SrcVT = DstVT = MVT::v8f32; return true; } if (isTargetShuffleEquivalent(Mask, {1, 1, 3, 3, 5, 5, 7, 7})) { Shuffle = X86ISD::MOVSHDUP; - ShuffleVT = MVT::v8f32; + SrcVT = DstVT = MVT::v8f32; return true; } } - if (SrcVT.is512BitVector() && FloatDomain) { + if (MaskVT.is512BitVector() && FloatDomain) { assert(Subtarget.hasAVX512() && "AVX512 required for 512-bit vector shuffles"); if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2, 4, 4, 6, 6})) { Shuffle = X86ISD::MOVDDUP; - ShuffleVT = MVT::v8f64; + SrcVT = DstVT = MVT::v8f64; return true; } if (isTargetShuffleEquivalent( Mask, {0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14})) { Shuffle = X86ISD::MOVSLDUP; - ShuffleVT = MVT::v16f32; + SrcVT = DstVT = MVT::v16f32; return true; } if (isTargetShuffleEquivalent( Mask, {1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15})) { Shuffle = X86ISD::MOVSHDUP; - ShuffleVT = MVT::v16f32; + SrcVT = DstVT = MVT::v16f32; return true; } } // Attempt to match against broadcast-from-vector. if (Subtarget.hasAVX2()) { - unsigned NumElts = Mask.size(); - SmallVector<int, 64> BroadcastMask(NumElts, 0); + SmallVector<int, 64> BroadcastMask(NumMaskElts, 0); if (isTargetShuffleEquivalent(Mask, BroadcastMask)) { - unsigned EltSize = SrcVT.getSizeInBits() / NumElts; - ShuffleVT = FloatDomain ? MVT::getFloatingPointVT(EltSize) - : MVT::getIntegerVT(EltSize); - ShuffleVT = MVT::getVectorVT(ShuffleVT, NumElts); + SrcVT = DstVT = MaskVT; Shuffle = X86ISD::VBROADCAST; return true; } @@ -24882,19 +26342,44 @@ static bool matchUnaryVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, // Attempt to match a combined shuffle mask against supported unary immediate // permute instructions. // TODO: Investigate sharing more of this with shuffle lowering. -static bool matchPermuteVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, - const X86Subtarget &Subtarget, - unsigned &Shuffle, MVT &ShuffleVT, - unsigned &PermuteImm) { - // Ensure we don't contain any zero elements. 
- for (int M : Mask) { - if (M == SM_SentinelZero) - return false; - assert(SM_SentinelUndef <= M && M < (int)Mask.size() && - "Expected unary shuffle"); +static bool matchUnaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, + bool FloatDomain, + const X86Subtarget &Subtarget, + unsigned &Shuffle, MVT &ShuffleVT, + unsigned &PermuteImm) { + unsigned NumMaskElts = Mask.size(); + + bool ContainsZeros = false; + SmallBitVector Zeroable(NumMaskElts, false); + for (unsigned i = 0; i != NumMaskElts; ++i) { + int M = Mask[i]; + Zeroable[i] = isUndefOrZero(M); + ContainsZeros |= (M == SM_SentinelZero); + } + + // Attempt to match against byte/bit shifts. + // FIXME: Add 512-bit support. + if (!FloatDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSE2()) || + (MaskVT.is256BitVector() && Subtarget.hasAVX2()))) { + int ShiftAmt = matchVectorShuffleAsShift(ShuffleVT, Shuffle, + MaskVT.getScalarSizeInBits(), Mask, + 0, Zeroable, Subtarget); + if (0 < ShiftAmt) { + PermuteImm = (unsigned)ShiftAmt; + return true; + } } - unsigned MaskScalarSizeInBits = SrcVT.getSizeInBits() / Mask.size(); + // Ensure we don't contain any zero elements. + if (ContainsZeros) + return false; + + assert(llvm::all_of(Mask, [&](int M) { + return SM_SentinelUndef <= M && M < (int)NumMaskElts; + }) && "Expected unary shuffle"); + + unsigned InputSizeInBits = MaskVT.getSizeInBits(); + unsigned MaskScalarSizeInBits = InputSizeInBits / Mask.size(); MVT MaskEltVT = MVT::getIntegerVT(MaskScalarSizeInBits); // Handle PSHUFLW/PSHUFHW repeated patterns. @@ -24908,7 +26393,7 @@ static bool matchPermuteVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, if (isUndefOrInRange(LoMask, 0, 4) && isSequentialOrUndefInRange(HiMask, 0, 4, 4)) { Shuffle = X86ISD::PSHUFLW; - ShuffleVT = MVT::getVectorVT(MVT::i16, SrcVT.getSizeInBits() / 16); + ShuffleVT = MVT::getVectorVT(MVT::i16, InputSizeInBits / 16); PermuteImm = getV4X86ShuffleImm(LoMask); return true; } @@ -24922,7 +26407,7 @@ static bool matchPermuteVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, OffsetHiMask[i] = (HiMask[i] < 0 ? HiMask[i] : HiMask[i] - 4); Shuffle = X86ISD::PSHUFHW; - ShuffleVT = MVT::getVectorVT(MVT::i16, SrcVT.getSizeInBits() / 16); + ShuffleVT = MVT::getVectorVT(MVT::i16, InputSizeInBits / 16); PermuteImm = getV4X86ShuffleImm(OffsetHiMask); return true; } @@ -24938,24 +26423,23 @@ static bool matchPermuteVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, // AVX introduced the VPERMILPD/VPERMILPS float permutes, before then we // had to use 2-input SHUFPD/SHUFPS shuffles (not handled here). - bool FloatDomain = SrcVT.isFloatingPoint(); if (FloatDomain && !Subtarget.hasAVX()) return false; // Pre-AVX2 we must use float shuffles on 256-bit vectors. - if (SrcVT.is256BitVector() && !Subtarget.hasAVX2()) + if (MaskVT.is256BitVector() && !Subtarget.hasAVX2()) FloatDomain = true; // Check for lane crossing permutes. if (is128BitLaneCrossingShuffleMask(MaskEltVT, Mask)) { // PERMPD/PERMQ permutes within a 256-bit vector (AVX2+). - if (Subtarget.hasAVX2() && SrcVT.is256BitVector() && Mask.size() == 4) { + if (Subtarget.hasAVX2() && MaskVT.is256BitVector() && Mask.size() == 4) { Shuffle = X86ISD::VPERMI; ShuffleVT = (FloatDomain ? 
MVT::v4f64 : MVT::v4i64); PermuteImm = getV4X86ShuffleImm(Mask); return true; } - if (Subtarget.hasAVX512() && SrcVT.is512BitVector() && Mask.size() == 8) { + if (Subtarget.hasAVX512() && MaskVT.is512BitVector() && Mask.size() == 8) { SmallVector<int, 4> RepeatedMask; if (is256BitLaneRepeatedShuffleMask(MVT::v8f64, Mask, RepeatedMask)) { Shuffle = X86ISD::VPERMI; @@ -24994,7 +26478,7 @@ static bool matchPermuteVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, Shuffle = (FloatDomain ? X86ISD::VPERMILPI : X86ISD::PSHUFD); ShuffleVT = (FloatDomain ? MVT::f32 : MVT::i32); - ShuffleVT = MVT::getVectorVT(ShuffleVT, SrcVT.getSizeInBits() / 32); + ShuffleVT = MVT::getVectorVT(ShuffleVT, InputSizeInBits / 32); PermuteImm = getV4X86ShuffleImm(WordMask); return true; } @@ -25002,47 +26486,259 @@ static bool matchPermuteVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, // Attempt to match a combined unary shuffle mask against supported binary // shuffle instructions. // TODO: Investigate sharing more of this with shuffle lowering. -static bool matchBinaryVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, - unsigned &Shuffle, MVT &ShuffleVT) { - bool FloatDomain = SrcVT.isFloatingPoint(); +static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, + bool FloatDomain, SDValue &V1, SDValue &V2, + const X86Subtarget &Subtarget, + unsigned &Shuffle, MVT &ShuffleVT, + bool IsUnary) { + unsigned EltSizeInBits = MaskVT.getScalarSizeInBits(); - if (SrcVT.is128BitVector()) { + if (MaskVT.is128BitVector()) { if (isTargetShuffleEquivalent(Mask, {0, 0}) && FloatDomain) { + V2 = V1; Shuffle = X86ISD::MOVLHPS; ShuffleVT = MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {1, 1}) && FloatDomain) { + V2 = V1; Shuffle = X86ISD::MOVHLPS; ShuffleVT = MVT::v4f32; return true; } - if (isTargetShuffleEquivalent(Mask, {0, 0, 1, 1}) && FloatDomain) { - Shuffle = X86ISD::UNPCKL; - ShuffleVT = MVT::v4f32; + if (isTargetShuffleEquivalent(Mask, {0, 3}) && Subtarget.hasSSE2() && + (FloatDomain || !Subtarget.hasSSE41())) { + std::swap(V1, V2); + Shuffle = X86ISD::MOVSD; + ShuffleVT = MaskVT; return true; } - if (isTargetShuffleEquivalent(Mask, {2, 2, 3, 3}) && FloatDomain) { - Shuffle = X86ISD::UNPCKH; - ShuffleVT = MVT::v4f32; + if (isTargetShuffleEquivalent(Mask, {4, 1, 2, 3}) && + (FloatDomain || !Subtarget.hasSSE41())) { + Shuffle = X86ISD::MOVSS; + ShuffleVT = MaskVT; + return true; + } + } + + // Attempt to match against either a unary or binary UNPCKL/UNPCKH shuffle. + if ((MaskVT == MVT::v4f32 && Subtarget.hasSSE1()) || + (MaskVT.is128BitVector() && Subtarget.hasSSE2()) || + (MaskVT.is256BitVector() && 32 <= EltSizeInBits && Subtarget.hasAVX()) || + (MaskVT.is256BitVector() && Subtarget.hasAVX2()) || + (MaskVT.is512BitVector() && Subtarget.hasAVX512())) { + MVT LegalVT = MaskVT; + if (LegalVT.is256BitVector() && !Subtarget.hasAVX2()) + LegalVT = (32 == EltSizeInBits ? 
MVT::v8f32 : MVT::v4f64); + + SmallVector<int, 64> Unpckl, Unpckh; + if (IsUnary) { + createUnpackShuffleMask(MaskVT, Unpckl, true, true); + if (isTargetShuffleEquivalent(Mask, Unpckl)) { + V2 = V1; + Shuffle = X86ISD::UNPCKL; + ShuffleVT = LegalVT; + return true; + } + + createUnpackShuffleMask(MaskVT, Unpckh, false, true); + if (isTargetShuffleEquivalent(Mask, Unpckh)) { + V2 = V1; + Shuffle = X86ISD::UNPCKH; + ShuffleVT = LegalVT; + return true; + } + } else { + createUnpackShuffleMask(MaskVT, Unpckl, true, false); + if (isTargetShuffleEquivalent(Mask, Unpckl)) { + Shuffle = X86ISD::UNPCKL; + ShuffleVT = LegalVT; + return true; + } + + createUnpackShuffleMask(MaskVT, Unpckh, false, false); + if (isTargetShuffleEquivalent(Mask, Unpckh)) { + Shuffle = X86ISD::UNPCKH; + ShuffleVT = LegalVT; + return true; + } + + ShuffleVectorSDNode::commuteMask(Unpckl); + if (isTargetShuffleEquivalent(Mask, Unpckl)) { + std::swap(V1, V2); + Shuffle = X86ISD::UNPCKL; + ShuffleVT = LegalVT; + return true; + } + + ShuffleVectorSDNode::commuteMask(Unpckh); + if (isTargetShuffleEquivalent(Mask, Unpckh)) { + std::swap(V1, V2); + Shuffle = X86ISD::UNPCKH; + ShuffleVT = LegalVT; + return true; + } + } + } + + return false; +} + +static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, + bool FloatDomain, + SDValue &V1, SDValue &V2, + SDLoc &DL, SelectionDAG &DAG, + const X86Subtarget &Subtarget, + unsigned &Shuffle, MVT &ShuffleVT, + unsigned &PermuteImm) { + unsigned NumMaskElts = Mask.size(); + + // Attempt to match against PALIGNR byte rotate. + if (!FloatDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSSE3()) || + (MaskVT.is256BitVector() && Subtarget.hasAVX2()))) { + int ByteRotation = matchVectorShuffleAsByteRotate(MaskVT, V1, V2, Mask); + if (0 < ByteRotation) { + Shuffle = X86ISD::PALIGNR; + ShuffleVT = MVT::getVectorVT(MVT::i8, MaskVT.getSizeInBits() / 8); + PermuteImm = ByteRotation; + return true; + } + } + + // Attempt to combine to X86ISD::BLENDI. + if (NumMaskElts <= 8 && ((Subtarget.hasSSE41() && MaskVT.is128BitVector()) || + (Subtarget.hasAVX() && MaskVT.is256BitVector()))) { + // Determine a type compatible with X86ISD::BLENDI. + // TODO - add 16i16 support (requires lane duplication). + MVT BlendVT = MaskVT; + if (Subtarget.hasAVX2()) { + if (BlendVT == MVT::v4i64) + BlendVT = MVT::v8i32; + else if (BlendVT == MVT::v2i64) + BlendVT = MVT::v4i32; + } else { + if (BlendVT == MVT::v2i64 || BlendVT == MVT::v4i32) + BlendVT = MVT::v8i16; + else if (BlendVT == MVT::v4i64) + BlendVT = MVT::v4f64; + else if (BlendVT == MVT::v8i32) + BlendVT = MVT::v8f32; + } + + unsigned BlendSize = BlendVT.getVectorNumElements(); + unsigned MaskRatio = BlendSize / NumMaskElts; + + // Can we blend with zero? + if (isSequentialOrUndefOrZeroInRange(Mask, /*Pos*/ 0, /*Size*/ NumMaskElts, + /*Low*/ 0) && + NumMaskElts <= BlendVT.getVectorNumElements()) { + PermuteImm = 0; + for (unsigned i = 0; i != BlendSize; ++i) + if (Mask[i / MaskRatio] < 0) + PermuteImm |= 1u << i; + + V2 = getZeroVector(BlendVT, Subtarget, DAG, DL); + Shuffle = X86ISD::BLENDI; + ShuffleVT = BlendVT; return true; } - if (isTargetShuffleEquivalent(Mask, {0, 0, 1, 1, 2, 2, 3, 3}) || - isTargetShuffleEquivalent( - Mask, {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7})) { - Shuffle = X86ISD::UNPCKL; - ShuffleVT = Mask.size() == 8 ? MVT::v8i16 : MVT::v16i8; + + // Attempt to match as a binary blend. 
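Both the zero-blend case above and the binary blend that follows reduce to building an 8-bit BLENDI immediate in which bit i selects destination element i from the second source. A minimal standalone sketch of that encoding (illustrative only, not taken from the patch):

  #include <cstdint>
  #include <vector>

  // Bit i set   -> destination element i comes from the second input.
  // Bit i clear -> destination element i comes from the first input.
  static uint8_t buildBlendImm(const std::vector<int> &Mask) {
    uint8_t Imm = 0;
    for (size_t i = 0; i != Mask.size(); ++i)
      if (Mask[i] >= (int)Mask.size())
        Imm |= uint8_t(1u << i);
    return Imm;
  }

  int main() {
    // Take lanes 1 and 3 from the second vector of a 4-element blend.
    return buildBlendImm({0, 5, 2, 7}) == 0x0A ? 0 : 1;
  }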
+ if (NumMaskElts <= BlendVT.getVectorNumElements()) { + bool MatchBlend = true; + for (int i = 0; i != (int)NumMaskElts; ++i) { + int M = Mask[i]; + if (M == SM_SentinelUndef) + continue; + else if (M == SM_SentinelZero) + MatchBlend = false; + else if ((M != i) && (M != (i + (int)NumMaskElts))) + MatchBlend = false; + } + + if (MatchBlend) { + PermuteImm = 0; + for (unsigned i = 0; i != BlendSize; ++i) + if ((int)NumMaskElts <= Mask[i / MaskRatio]) + PermuteImm |= 1u << i; + + Shuffle = X86ISD::BLENDI; + ShuffleVT = BlendVT; + return true; + } + } + } + + // Attempt to combine to INSERTPS. + if (Subtarget.hasSSE41() && MaskVT == MVT::v4f32) { + SmallBitVector Zeroable(4, false); + for (unsigned i = 0; i != NumMaskElts; ++i) + if (Mask[i] < 0) + Zeroable[i] = true; + + if (Zeroable.any() && + matchVectorShuffleAsInsertPS(V1, V2, PermuteImm, Zeroable, Mask, DAG)) { + Shuffle = X86ISD::INSERTPS; + ShuffleVT = MVT::v4f32; return true; } - if (isTargetShuffleEquivalent(Mask, {4, 4, 5, 5, 6, 6, 7, 7}) || - isTargetShuffleEquivalent(Mask, {8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, - 13, 14, 14, 15, 15})) { - Shuffle = X86ISD::UNPCKH; - ShuffleVT = Mask.size() == 8 ? MVT::v8i16 : MVT::v16i8; + } + + // Attempt to combine to SHUFPD. + if ((MaskVT == MVT::v2f64 && Subtarget.hasSSE2()) || + (MaskVT == MVT::v4f64 && Subtarget.hasAVX()) || + (MaskVT == MVT::v8f64 && Subtarget.hasAVX512())) { + if (matchVectorShuffleWithSHUFPD(MaskVT, V1, V2, PermuteImm, Mask)) { + Shuffle = X86ISD::SHUFP; + ShuffleVT = MaskVT; return true; } } + // Attempt to combine to SHUFPS. + if ((MaskVT == MVT::v4f32 && Subtarget.hasSSE1()) || + (MaskVT == MVT::v8f32 && Subtarget.hasAVX()) || + (MaskVT == MVT::v16f32 && Subtarget.hasAVX512())) { + SmallVector<int, 4> RepeatedMask; + if (isRepeatedTargetShuffleMask(128, MaskVT, Mask, RepeatedMask)) { + auto MatchHalf = [&](unsigned Offset, int &S0, int &S1) { + int M0 = RepeatedMask[Offset]; + int M1 = RepeatedMask[Offset + 1]; + + if (isUndefInRange(RepeatedMask, Offset, 2)) { + return DAG.getUNDEF(MaskVT); + } else if (isUndefOrZeroInRange(RepeatedMask, Offset, 2)) { + S0 = (SM_SentinelUndef == M0 ? -1 : 0); + S1 = (SM_SentinelUndef == M1 ? -1 : 1); + return getZeroVector(MaskVT, Subtarget, DAG, DL); + } else if (isUndefOrInRange(M0, 0, 4) && isUndefOrInRange(M1, 0, 4)) { + S0 = (SM_SentinelUndef == M0 ? -1 : M0 & 3); + S1 = (SM_SentinelUndef == M1 ? -1 : M1 & 3); + return V1; + } else if (isUndefOrInRange(M0, 4, 8) && isUndefOrInRange(M1, 4, 8)) { + S0 = (SM_SentinelUndef == M0 ? -1 : M0 & 3); + S1 = (SM_SentinelUndef == M1 ? -1 : M1 & 3); + return V2; + } + + return SDValue(); + }; + + int ShufMask[4] = {-1, -1, -1, -1}; + SDValue Lo = MatchHalf(0, ShufMask[0], ShufMask[1]); + SDValue Hi = MatchHalf(2, ShufMask[2], ShufMask[3]); + + if (Lo && Hi) { + V1 = Lo; + V2 = Hi; + Shuffle = X86ISD::SHUFP; + ShuffleVT = MaskVT; + PermuteImm = getV4X86ShuffleImm(ShufMask); + return true; + } + } + } + return false; } @@ -25055,33 +26751,44 @@ static bool matchBinaryVectorShuffle(MVT SrcVT, ArrayRef<int> Mask, /// into either a single instruction if there is a special purpose instruction /// for this operation, or into a PSHUFB instruction which is a fully general /// instruction but should only be used to replace chains over a certain depth. 
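The SHUFPS/PSHUFD matches above all finish by packing a 4-element mask into an immediate via getV4X86ShuffleImm. A rough standalone sketch of that encoding, two bits per destination element (illustrative only; the handling of undef entries here is a simplification):

  #include <array>
  #include <cstdint>

  static uint8_t getV4ShuffleImm(const std::array<int, 4> &Mask) {
    uint8_t Imm = 0;
    for (int i = 0; i != 4; ++i)
      Imm |= uint8_t((Mask[i] < 0 ? 0 : Mask[i] & 3) << (i * 2));  // undef -> 0
    return Imm;
  }

  int main() {
    // {3, 2, 1, 0} reverses the elements: 0b00011011 == 0x1B.
    return getV4ShuffleImm({3, 2, 1, 0}) == 0x1B ? 0 : 1;
  }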
-static bool combineX86ShuffleChain(SDValue Input, SDValue Root, +static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth, bool HasVariableMask, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!"); + assert((Inputs.size() == 1 || Inputs.size() == 2) && + "Unexpected number of shuffle inputs!"); - // Find the operand that enters the chain. Note that multiple uses are OK - // here, we're not going to remove the operand we find. - Input = peekThroughBitcasts(Input); + // Find the inputs that enter the chain. Note that multiple uses are OK + // here, we're not going to remove the operands we find. + bool UnaryShuffle = (Inputs.size() == 1); + SDValue V1 = peekThroughBitcasts(Inputs[0]); + SDValue V2 = (UnaryShuffle ? V1 : peekThroughBitcasts(Inputs[1])); - MVT VT = Input.getSimpleValueType(); + MVT VT1 = V1.getSimpleValueType(); + MVT VT2 = V2.getSimpleValueType(); MVT RootVT = Root.getSimpleValueType(); - SDLoc DL(Root); + assert(VT1.getSizeInBits() == RootVT.getSizeInBits() && + VT2.getSizeInBits() == RootVT.getSizeInBits() && + "Vector size mismatch"); + SDLoc DL(Root); SDValue Res; unsigned NumBaseMaskElts = BaseMask.size(); if (NumBaseMaskElts == 1) { assert(BaseMask[0] == 0 && "Invalid shuffle index found!"); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Input), + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, V1), /*AddTo*/ true); return true; } unsigned RootSizeInBits = RootVT.getSizeInBits(); + unsigned NumRootElts = RootVT.getVectorNumElements(); unsigned BaseMaskEltSizeInBits = RootSizeInBits / NumBaseMaskElts; + bool FloatDomain = VT1.isFloatingPoint() || VT2.isFloatingPoint() || + (RootVT.is256BitVector() && !Subtarget.hasAVX2()); // Don't combine if we are a AVX512/EVEX target and the mask element size // is different from the root element size - this would prevent writemasks @@ -25089,26 +26796,25 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, // TODO - this currently prevents all lane shuffles from occurring. // TODO - check for writemasks usage instead of always preventing combining. // TODO - attempt to narrow Mask back to writemask size. - if (RootVT.getScalarSizeInBits() != BaseMaskEltSizeInBits && - (RootSizeInBits == 512 || - (Subtarget.hasVLX() && RootSizeInBits >= 128))) { + bool IsEVEXShuffle = + RootSizeInBits == 512 || (Subtarget.hasVLX() && RootSizeInBits >= 128); + if (IsEVEXShuffle && (RootVT.getScalarSizeInBits() != BaseMaskEltSizeInBits)) return false; - } // TODO - handle 128/256-bit lane shuffles of 512-bit vectors. // Handle 128-bit lane shuffles of 256-bit vectors. - if (VT.is256BitVector() && NumBaseMaskElts == 2 && + // TODO - this should support binary shuffles. + if (UnaryShuffle && RootVT.is256BitVector() && NumBaseMaskElts == 2 && !isSequentialOrUndefOrZeroInRange(BaseMask, 0, 2, 0)) { if (Depth == 1 && Root.getOpcode() == X86ISD::VPERM2X128) return false; // Nothing to do! - MVT ShuffleVT = (VT.isFloatingPoint() || !Subtarget.hasAVX2() ? MVT::v4f64 - : MVT::v4i64); + MVT ShuffleVT = (FloatDomain ? MVT::v4f64 : MVT::v4i64); unsigned PermMask = 0; PermMask |= ((BaseMask[0] < 0 ? 0x8 : (BaseMask[0] & 1)) << 0); PermMask |= ((BaseMask[1] < 0 ? 
0x8 : (BaseMask[1] & 1)) << 4); - Res = DAG.getBitcast(ShuffleVT, Input); + Res = DAG.getBitcast(ShuffleVT, V1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT, Res, DAG.getUNDEF(ShuffleVT), @@ -25134,144 +26840,234 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, unsigned MaskEltSizeInBits = RootSizeInBits / NumMaskElts; // Determine the effective mask value type. - bool FloatDomain = - (VT.isFloatingPoint() || (VT.is256BitVector() && !Subtarget.hasAVX2())) && - (32 <= MaskEltSizeInBits); + FloatDomain &= (32 <= MaskEltSizeInBits); MVT MaskVT = FloatDomain ? MVT::getFloatingPointVT(MaskEltSizeInBits) : MVT::getIntegerVT(MaskEltSizeInBits); MaskVT = MVT::getVectorVT(MaskVT, NumMaskElts); + // Only allow legal mask types. + if (!DAG.getTargetLoweringInfo().isTypeLegal(MaskVT)) + return false; + // Attempt to match the mask against known shuffle patterns. - MVT ShuffleVT; + MVT ShuffleSrcVT, ShuffleVT; unsigned Shuffle, PermuteImm; - if (matchUnaryVectorShuffle(VT, Mask, Subtarget, Shuffle, ShuffleVT)) { - if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, Input); - DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); - DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + if (UnaryShuffle) { + // If we are shuffling a X86ISD::VZEXT_LOAD then we can use the load + // directly if we don't shuffle the lower element and we shuffle the upper + // (zero) elements within themselves. + if (V1.getOpcode() == X86ISD::VZEXT_LOAD && + (V1.getScalarValueSizeInBits() % MaskEltSizeInBits) == 0) { + unsigned Scale = V1.getScalarValueSizeInBits() / MaskEltSizeInBits; + ArrayRef<int> HiMask(Mask.data() + Scale, NumMaskElts - Scale); + if (isSequentialOrUndefInRange(Mask, 0, Scale, 0) && + isUndefOrZeroOrInRange(HiMask, Scale, NumMaskElts)) { + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, V1), + /*AddTo*/ true); + return true; + } + } + + if (matchUnaryVectorShuffle(MaskVT, Mask, FloatDomain, Subtarget, Shuffle, + ShuffleSrcVT, ShuffleVT)) { + if (Depth == 1 && Root.getOpcode() == Shuffle) + return false; // Nothing to do! + if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) + return false; // AVX512 Writemask clash. + Res = DAG.getBitcast(ShuffleSrcVT, V1); + DCI.AddToWorklist(Res.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } + + if (matchUnaryPermuteVectorShuffle(MaskVT, Mask, FloatDomain, Subtarget, + Shuffle, ShuffleVT, PermuteImm)) { + if (Depth == 1 && Root.getOpcode() == Shuffle) + return false; // Nothing to do! + if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) + return false; // AVX512 Writemask clash. 
+ Res = DAG.getBitcast(ShuffleVT, V1); + DCI.AddToWorklist(Res.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, + DAG.getConstant(PermuteImm, DL, MVT::i8)); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } } - if (matchPermuteVectorShuffle(VT, Mask, Subtarget, Shuffle, ShuffleVT, - PermuteImm)) { + if (matchBinaryVectorShuffle(MaskVT, Mask, FloatDomain, V1, V2, Subtarget, + Shuffle, ShuffleVT, UnaryShuffle)) { if (Depth == 1 && Root.getOpcode() == Shuffle) return false; // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, Input); - DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, - DAG.getConstant(PermuteImm, DL, MVT::i8)); + if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) + return false; // AVX512 Writemask clash. + V1 = DAG.getBitcast(ShuffleVT, V1); + DCI.AddToWorklist(V1.getNode()); + V2 = DAG.getBitcast(ShuffleVT, V2); + DCI.AddToWorklist(V2.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2); DCI.AddToWorklist(Res.getNode()); DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), /*AddTo*/ true); return true; } - if (matchBinaryVectorShuffle(VT, Mask, Shuffle, ShuffleVT)) { + if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, FloatDomain, V1, V2, DL, + DAG, Subtarget, Shuffle, ShuffleVT, + PermuteImm)) { if (Depth == 1 && Root.getOpcode() == Shuffle) return false; // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, Input); - DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, Res); + if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) + return false; // AVX512 Writemask clash. + V1 = DAG.getBitcast(ShuffleVT, V1); + DCI.AddToWorklist(V1.getNode()); + V2 = DAG.getBitcast(ShuffleVT, V2); + DCI.AddToWorklist(V2.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2, + DAG.getConstant(PermuteImm, DL, MVT::i8)); DCI.AddToWorklist(Res.getNode()); DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), /*AddTo*/ true); return true; } - // Attempt to blend with zero. - if (NumMaskElts <= 8 && - ((Subtarget.hasSSE41() && VT.is128BitVector()) || - (Subtarget.hasAVX() && VT.is256BitVector()))) { - // Convert VT to a type compatible with X86ISD::BLENDI. - // TODO - add 16i16 support (requires lane duplication). - MVT ShuffleVT = MaskVT; - if (Subtarget.hasAVX2()) { - if (ShuffleVT == MVT::v4i64) - ShuffleVT = MVT::v8i32; - else if (ShuffleVT == MVT::v2i64) - ShuffleVT = MVT::v4i32; - } else { - if (ShuffleVT == MVT::v2i64 || ShuffleVT == MVT::v4i32) - ShuffleVT = MVT::v8i16; - else if (ShuffleVT == MVT::v4i64) - ShuffleVT = MVT::v4f64; - else if (ShuffleVT == MVT::v8i32) - ShuffleVT = MVT::v8f32; - } - - if (isSequentialOrUndefOrZeroInRange(Mask, /*Pos*/ 0, /*Size*/ NumMaskElts, - /*Low*/ 0) && - NumMaskElts <= ShuffleVT.getVectorNumElements()) { - unsigned BlendMask = 0; - unsigned ShuffleSize = ShuffleVT.getVectorNumElements(); - unsigned MaskRatio = ShuffleSize / NumMaskElts; - - if (Depth == 1 && Root.getOpcode() == X86ISD::BLENDI) - return false; - - for (unsigned i = 0; i != ShuffleSize; ++i) - if (Mask[i / MaskRatio] < 0) - BlendMask |= 1u << i; + // Don't try to re-form single instruction chains under any circumstances now + // that we've done encoding canonicalization for them. 
+ if (Depth < 2) + return false; - SDValue Zero = getZeroVector(ShuffleVT, Subtarget, DAG, DL); - Res = DAG.getBitcast(ShuffleVT, Input); + bool MaskContainsZeros = + any_of(Mask, [](int M) { return M == SM_SentinelZero; }); + + if (is128BitLaneCrossingShuffleMask(MaskVT, Mask)) { + // If we have a single input lane-crossing shuffle then lower to VPERMV. + if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && !MaskContainsZeros && + ((Subtarget.hasAVX2() && + (MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) || + (Subtarget.hasAVX512() && + (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 || + MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32)) || + (Subtarget.hasBWI() && MaskVT == MVT::v32i16) || + (Subtarget.hasBWI() && Subtarget.hasVLX() && MaskVT == MVT::v16i16) || + (Subtarget.hasVBMI() && MaskVT == MVT::v64i8) || + (Subtarget.hasVBMI() && Subtarget.hasVLX() && MaskVT == MVT::v32i8))) { + MVT VPermMaskSVT = MVT::getIntegerVT(MaskEltSizeInBits); + MVT VPermMaskVT = MVT::getVectorVT(VPermMaskSVT, NumMaskElts); + SDValue VPermMask = getConstVector(Mask, VPermMaskVT, DAG, DL, true); + DCI.AddToWorklist(VPermMask.getNode()); + Res = DAG.getBitcast(MaskVT, V1); DCI.AddToWorklist(Res.getNode()); - Res = DAG.getNode(X86ISD::BLENDI, DL, ShuffleVT, Res, Zero, - DAG.getConstant(BlendMask, DL, MVT::i8)); + Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res); DCI.AddToWorklist(Res.getNode()); DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), /*AddTo*/ true); return true; } - } - // Attempt to combine to INSERTPS. - if (Subtarget.hasSSE41() && NumMaskElts == 4 && - (VT == MVT::v2f64 || VT == MVT::v4f32)) { - SmallBitVector Zeroable(4, false); - for (unsigned i = 0; i != NumMaskElts; ++i) - if (Mask[i] < 0) - Zeroable[i] = true; + // Lower a unary+zero lane-crossing shuffle as VPERMV3 with a zero + // vector as the second source. + if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && + ((Subtarget.hasAVX512() && + (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 || + MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32)) || + (Subtarget.hasVLX() && + (MaskVT == MVT::v4f64 || MaskVT == MVT::v4i64 || + MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) || + (Subtarget.hasBWI() && MaskVT == MVT::v32i16) || + (Subtarget.hasBWI() && Subtarget.hasVLX() && MaskVT == MVT::v16i16) || + (Subtarget.hasVBMI() && MaskVT == MVT::v64i8) || + (Subtarget.hasVBMI() && Subtarget.hasVLX() && MaskVT == MVT::v32i8))) { + // Adjust shuffle mask - replace SM_SentinelZero with second source index. + for (unsigned i = 0; i != NumMaskElts; ++i) + if (Mask[i] == SM_SentinelZero) + Mask[i] = NumMaskElts + i; + + MVT VPermMaskSVT = MVT::getIntegerVT(MaskEltSizeInBits); + MVT VPermMaskVT = MVT::getVectorVT(VPermMaskSVT, NumMaskElts); + SDValue VPermMask = getConstVector(Mask, VPermMaskVT, DAG, DL, true); + DCI.AddToWorklist(VPermMask.getNode()); + Res = DAG.getBitcast(MaskVT, V1); + DCI.AddToWorklist(Res.getNode()); + SDValue Zero = getZeroVector(MaskVT, Subtarget, DAG, DL); + DCI.AddToWorklist(Zero.getNode()); + Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, Res, VPermMask, Zero); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } - unsigned InsertPSMask; - SDValue V1 = Input, V2 = Input; - if (Zeroable.any() && matchVectorShuffleAsInsertPS(V1, V2, InsertPSMask, - Zeroable, Mask, DAG)) { - if (Depth == 1 && Root.getOpcode() == X86ISD::INSERTPS) - return false; // Nothing to do! 
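For the unary-plus-zero lane-crossing case above, the rewrite is simply to point every zeroed lane at the matching element of an all-zero second source before emitting VPERMV3. A tiny standalone sketch (illustrative; the sentinel handling is simplified):

  #include <vector>

  static const int SentinelZero = -2;   // stands in for SM_SentinelZero

  static void redirectZerosToSecondSource(std::vector<int> &Mask) {
    int N = (int)Mask.size();
    for (int i = 0; i != N; ++i)
      if (Mask[i] == SentinelZero)
        Mask[i] = N + i;                // element i of the zero vector
  }

  int main() {
    std::vector<int> Mask = {3, SentinelZero, 1, SentinelZero};
    redirectZerosToSecondSource(Mask);
    return (Mask[1] == 5 && Mask[3] == 7) ? 0 : 1;
  }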
- V1 = DAG.getBitcast(MVT::v4f32, V1); + // If we have a dual input lane-crossing shuffle then lower to VPERMV3. + if ((Depth >= 3 || HasVariableMask) && !MaskContainsZeros && + ((Subtarget.hasAVX512() && + (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 || + MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32)) || + (Subtarget.hasVLX() && + (MaskVT == MVT::v4f64 || MaskVT == MVT::v4i64 || + MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) || + (Subtarget.hasBWI() && MaskVT == MVT::v32i16) || + (Subtarget.hasBWI() && Subtarget.hasVLX() && MaskVT == MVT::v16i16) || + (Subtarget.hasVBMI() && MaskVT == MVT::v64i8) || + (Subtarget.hasVBMI() && Subtarget.hasVLX() && MaskVT == MVT::v32i8))) { + MVT VPermMaskSVT = MVT::getIntegerVT(MaskEltSizeInBits); + MVT VPermMaskVT = MVT::getVectorVT(VPermMaskSVT, NumMaskElts); + SDValue VPermMask = getConstVector(Mask, VPermMaskVT, DAG, DL, true); + DCI.AddToWorklist(VPermMask.getNode()); + V1 = DAG.getBitcast(MaskVT, V1); DCI.AddToWorklist(V1.getNode()); - V2 = DAG.getBitcast(MVT::v4f32, V2); + V2 = DAG.getBitcast(MaskVT, V2); DCI.AddToWorklist(V2.getNode()); - Res = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32, V1, V2, - DAG.getConstant(InsertPSMask, DL, MVT::i8)); + Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, V1, VPermMask, V2); DCI.AddToWorklist(Res.getNode()); DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), /*AddTo*/ true); return true; } - } - - // Don't try to re-form single instruction chains under any circumstances now - // that we've done encoding canonicalization for them. - if (Depth < 2) - return false; - - if (is128BitLaneCrossingShuffleMask(MaskVT, Mask)) return false; + } - bool MaskContainsZeros = - llvm::any_of(Mask, [](int M) { return M == SM_SentinelZero; }); + // See if we can combine a single input shuffle with zeros to a bit-mask, + // which is much simpler than any shuffle. + if (UnaryShuffle && MaskContainsZeros && (Depth >= 3 || HasVariableMask) && + isSequentialOrUndefOrZeroInRange(Mask, 0, NumMaskElts, 0) && + DAG.getTargetLoweringInfo().isTypeLegal(MaskVT)) { + APInt Zero = APInt::getNullValue(MaskEltSizeInBits); + APInt AllOnes = APInt::getAllOnesValue(MaskEltSizeInBits); + SmallBitVector UndefElts(NumMaskElts, false); + SmallVector<APInt, 64> EltBits(NumMaskElts, Zero); + for (unsigned i = 0; i != NumMaskElts; ++i) { + int M = Mask[i]; + if (M == SM_SentinelUndef) { + UndefElts[i] = true; + continue; + } + if (M == SM_SentinelZero) + continue; + EltBits[i] = AllOnes; + } + SDValue BitMask = getConstVector(EltBits, UndefElts, MaskVT, DAG, DL); + DCI.AddToWorklist(BitMask.getNode()); + Res = DAG.getBitcast(MaskVT, V1); + DCI.AddToWorklist(Res.getNode()); + unsigned AndOpcode = + FloatDomain ? unsigned(X86ISD::FAND) : unsigned(ISD::AND); + Res = DAG.getNode(AndOpcode, DL, MaskVT, Res, BitMask); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } // If we have a single input shuffle with different shuffle patterns in the // the 128-bit lanes use the variable mask to VPERMILPS. // TODO Combine other mask types at higher depths. 
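VPERMILPS only permutes within each 128-bit lane, so each entry of the variable index vector needs just the low two bits of the corresponding shuffle-mask element. A standalone sketch of that index construction (illustrative only, not the patch's code):

  #include <array>

  static std::array<int, 8> buildVPermilIdx(const std::array<int, 8> &Mask) {
    std::array<int, 8> Idx{};
    for (size_t i = 0; i != Mask.size(); ++i)
      Idx[i] = Mask[i] < 0 ? 0 : Mask[i] & 3;   // undef lanes may use any index
    return Idx;
  }

  int main() {
    auto Idx = buildVPermilIdx({1, 0, 3, 2, 6, 7, 4, 5});
    return (Idx[4] == 2 && Idx[7] == 1) ? 0 : 1;
  }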
- if (HasVariableMask && !MaskContainsZeros && + if (UnaryShuffle && HasVariableMask && !MaskContainsZeros && ((MaskVT == MVT::v8f32 && Subtarget.hasAVX()) || (MaskVT == MVT::v16f32 && Subtarget.hasAVX512()))) { SmallVector<SDValue, 16> VPermIdx; @@ -25283,7 +27079,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, MVT VPermMaskVT = MVT::getVectorVT(MVT::i32, NumMaskElts); SDValue VPermMask = DAG.getBuildVector(VPermMaskVT, DL, VPermIdx); DCI.AddToWorklist(VPermMask.getNode()); - Res = DAG.getBitcast(MaskVT, Input); + Res = DAG.getBitcast(MaskVT, V1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask); DCI.AddToWorklist(Res.getNode()); @@ -25292,17 +27088,60 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, return true; } + // With XOP, binary shuffles of 128/256-bit floating point vectors can combine + // to VPERMIL2PD/VPERMIL2PS. + if ((Depth >= 3 || HasVariableMask) && Subtarget.hasXOP() && + (MaskVT == MVT::v2f64 || MaskVT == MVT::v4f64 || MaskVT == MVT::v4f32 || + MaskVT == MVT::v8f32)) { + // VPERMIL2 Operation. + // Bits[3] - Match Bit. + // Bits[2:1] - (Per Lane) PD Shuffle Mask. + // Bits[2:0] - (Per Lane) PS Shuffle Mask. + unsigned NumLanes = MaskVT.getSizeInBits() / 128; + unsigned NumEltsPerLane = NumMaskElts / NumLanes; + SmallVector<int, 8> VPerm2Idx; + MVT MaskIdxSVT = MVT::getIntegerVT(MaskVT.getScalarSizeInBits()); + MVT MaskIdxVT = MVT::getVectorVT(MaskIdxSVT, NumMaskElts); + unsigned M2ZImm = 0; + for (int M : Mask) { + if (M == SM_SentinelUndef) { + VPerm2Idx.push_back(-1); + continue; + } + if (M == SM_SentinelZero) { + M2ZImm = 2; + VPerm2Idx.push_back(8); + continue; + } + int Index = (M % NumEltsPerLane) + ((M / NumMaskElts) * NumEltsPerLane); + Index = (MaskVT.getScalarSizeInBits() == 64 ? Index << 1 : Index); + VPerm2Idx.push_back(Index); + } + V1 = DAG.getBitcast(MaskVT, V1); + DCI.AddToWorklist(V1.getNode()); + V2 = DAG.getBitcast(MaskVT, V2); + DCI.AddToWorklist(V2.getNode()); + SDValue VPerm2MaskOp = getConstVector(VPerm2Idx, MaskIdxVT, DAG, DL, true); + DCI.AddToWorklist(VPerm2MaskOp.getNode()); + Res = DAG.getNode(X86ISD::VPERMIL2, DL, MaskVT, V1, V2, VPerm2MaskOp, + DAG.getConstant(M2ZImm, DL, MVT::i8)); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } + // If we have 3 or more shuffle instructions or a chain involving a variable // mask, we can replace them with a single PSHUFB instruction profitably. // Intel's manuals suggest only using PSHUFB if doing so replacing 5 // instructions, but in practice PSHUFB tends to be *very* fast so we're // more aggressive. 
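A PSHUFB control byte with its top bit set zeroes the destination byte; otherwise its low bits select a source byte, so an element-level shuffle mask is expanded to bytes by scaling each index. A small standalone sketch of that expansion (illustrative, not the patch's code):

  #include <cstdint>
  #include <vector>

  static std::vector<uint8_t> buildPshufbMask(const std::vector<int> &Mask,
                                              int NumBytes) {
    int Ratio = NumBytes / (int)Mask.size();   // bytes per mask element
    std::vector<uint8_t> Bytes;
    for (int i = 0; i < NumBytes; ++i) {
      int M = Mask[i / Ratio];
      if (M < 0) {                             // zero (or undef) lane
        Bytes.push_back(0xFF);
        continue;
      }
      Bytes.push_back(uint8_t(Ratio * M + i % Ratio));
    }
    return Bytes;
  }

  int main() {
    // v4i32 mask {1, -1, 3, 0} expanded to a 16-byte PSHUFB control.
    auto B = buildPshufbMask({1, -1, 3, 0}, 16);
    return (B[0] == 4 && B[4] == 0xFF && B[12] == 0) ? 0 : 1;
  }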
- if ((Depth >= 3 || HasVariableMask) && - ((VT.is128BitVector() && Subtarget.hasSSSE3()) || - (VT.is256BitVector() && Subtarget.hasAVX2()) || - (VT.is512BitVector() && Subtarget.hasBWI()))) { + if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && + ((RootVT.is128BitVector() && Subtarget.hasSSSE3()) || + (RootVT.is256BitVector() && Subtarget.hasAVX2()) || + (RootVT.is512BitVector() && Subtarget.hasBWI()))) { SmallVector<SDValue, 16> PSHUFBMask; - int NumBytes = VT.getSizeInBits() / 8; + int NumBytes = RootVT.getSizeInBits() / 8; int Ratio = NumBytes / NumMaskElts; for (int i = 0; i < NumBytes; ++i) { int M = Mask[i / Ratio]; @@ -25319,7 +27158,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, PSHUFBMask.push_back(DAG.getConstant(M, DL, MVT::i8)); } MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes); - Res = DAG.getBitcast(ByteVT, Input); + Res = DAG.getBitcast(ByteVT, V1); DCI.AddToWorklist(Res.getNode()); SDValue PSHUFBMaskOp = DAG.getBuildVector(ByteVT, DL, PSHUFBMask); DCI.AddToWorklist(PSHUFBMaskOp.getNode()); @@ -25330,10 +27169,135 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, return true; } + // With XOP, if we have a 128-bit binary input shuffle we can always combine + // to VPPERM. We match the depth requirement of PSHUFB - VPPERM is never + // slower than PSHUFB on targets that support both. + if ((Depth >= 3 || HasVariableMask) && RootVT.is128BitVector() && + Subtarget.hasXOP()) { + // VPPERM Mask Operation + // Bits[4:0] - Byte Index (0 - 31) + // Bits[7:5] - Permute Operation (0 - Source byte, 4 - ZERO) + SmallVector<SDValue, 16> VPPERMMask; + int NumBytes = 16; + int Ratio = NumBytes / NumMaskElts; + for (int i = 0; i < NumBytes; ++i) { + int M = Mask[i / Ratio]; + if (M == SM_SentinelUndef) { + VPPERMMask.push_back(DAG.getUNDEF(MVT::i8)); + continue; + } + if (M == SM_SentinelZero) { + VPPERMMask.push_back(DAG.getConstant(128, DL, MVT::i8)); + continue; + } + M = Ratio * M + i % Ratio; + VPPERMMask.push_back(DAG.getConstant(M, DL, MVT::i8)); + } + MVT ByteVT = MVT::v16i8; + V1 = DAG.getBitcast(ByteVT, V1); + DCI.AddToWorklist(V1.getNode()); + V2 = DAG.getBitcast(ByteVT, V2); + DCI.AddToWorklist(V2.getNode()); + SDValue VPPERMMaskOp = DAG.getBuildVector(ByteVT, DL, VPPERMMask); + DCI.AddToWorklist(VPPERMMaskOp.getNode()); + Res = DAG.getNode(X86ISD::VPPERM, DL, ByteVT, V1, V2, VPPERMMaskOp); + DCI.AddToWorklist(Res.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), + /*AddTo*/ true); + return true; + } + // Failed to find any combines. return false; } +// Attempt to constant fold all of the constant source ops. +// Returns true if the entire shuffle is folded to a constant. +// TODO: Extend this to merge multiple constant Ops and update the mask. +static bool combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, + ArrayRef<int> Mask, SDValue Root, + bool HasVariableMask, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + MVT VT = Root.getSimpleValueType(); + + unsigned SizeInBits = VT.getSizeInBits(); + unsigned NumMaskElts = Mask.size(); + unsigned MaskSizeInBits = SizeInBits / NumMaskElts; + unsigned NumOps = Ops.size(); + + // Extract constant bits from each source op. 
+ bool OneUseConstantOp = false; + SmallVector<SmallBitVector, 4> UndefEltsOps(NumOps); + SmallVector<SmallVector<APInt, 8>, 4> RawBitsOps(NumOps); + for (unsigned i = 0; i != NumOps; ++i) { + SDValue SrcOp = Ops[i]; + OneUseConstantOp |= SrcOp.hasOneUse(); + if (!getTargetConstantBitsFromNode(SrcOp, MaskSizeInBits, UndefEltsOps[i], + RawBitsOps[i])) + return false; + } + + // Only fold if at least one of the constants is only used once or + // the combined shuffle has included a variable mask shuffle, this + // is to avoid constant pool bloat. + if (!OneUseConstantOp && !HasVariableMask) + return false; + + // Shuffle the constant bits according to the mask. + SmallBitVector UndefElts(NumMaskElts, false); + SmallBitVector ZeroElts(NumMaskElts, false); + SmallBitVector ConstantElts(NumMaskElts, false); + SmallVector<APInt, 8> ConstantBitData(NumMaskElts, + APInt::getNullValue(MaskSizeInBits)); + for (unsigned i = 0; i != NumMaskElts; ++i) { + int M = Mask[i]; + if (M == SM_SentinelUndef) { + UndefElts[i] = true; + continue; + } else if (M == SM_SentinelZero) { + ZeroElts[i] = true; + continue; + } + assert(0 <= M && M < (int)(NumMaskElts * NumOps)); + + unsigned SrcOpIdx = (unsigned)M / NumMaskElts; + unsigned SrcMaskIdx = (unsigned)M % NumMaskElts; + + auto &SrcUndefElts = UndefEltsOps[SrcOpIdx]; + if (SrcUndefElts[SrcMaskIdx]) { + UndefElts[i] = true; + continue; + } + + auto &SrcEltBits = RawBitsOps[SrcOpIdx]; + APInt &Bits = SrcEltBits[SrcMaskIdx]; + if (!Bits) { + ZeroElts[i] = true; + continue; + } + + ConstantElts[i] = true; + ConstantBitData[i] = Bits; + } + assert((UndefElts | ZeroElts | ConstantElts).count() == NumMaskElts); + + // Create the constant data. + MVT MaskSVT; + if (VT.isFloatingPoint() && (MaskSizeInBits == 32 || MaskSizeInBits == 64)) + MaskSVT = MVT::getFloatingPointVT(MaskSizeInBits); + else + MaskSVT = MVT::getIntegerVT(MaskSizeInBits); + + MVT MaskVT = MVT::getVectorVT(MaskSVT, NumMaskElts); + + SDLoc DL(Root); + SDValue CstOp = getConstVector(ConstantBitData, UndefElts, MaskVT, DAG, DL); + DCI.AddToWorklist(CstOp.getNode()); + DCI.CombineTo(Root.getNode(), DAG.getBitcast(VT, CstOp)); + return true; +} + /// \brief Fully generic combining of x86 shuffle instructions. /// /// This should be the last combine run over the x86 shuffle instructions. Once @@ -25350,7 +27314,7 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, /// instructions, and replace them with the slightly more expensive SSSE3 /// PSHUFB instruction if available. We do this as the last combining step /// to ensure we avoid using PSHUFB if we can implement the shuffle with -/// a suitable short sequence of other instructions. The PHUFB will either +/// a suitable short sequence of other instructions. The PSHUFB will either /// use a register or have to read from memory and so is slightly (but only /// slightly) more expensive than the other shuffle instructions. /// @@ -25363,7 +27327,8 @@ static bool combineX86ShuffleChain(SDValue Input, SDValue Root, /// would simplify under the threshold for PSHUFB formation because of /// combine-ordering. To fix this, we should do the redundant instruction /// combining in this recursive walk. 
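At its core the recursion composes the root mask with the mask of each shuffle feeding it, looking through to that shuffle's inputs while propagating the undef and zero sentinels; the real code additionally rescales between masks of different widths and tracks multiple inputs. A much-simplified standalone sketch of the same-width, single-input composition:

  #include <vector>

  enum { SentinelUndef = -1, SentinelZero = -2 };

  static std::vector<int> composeMasks(const std::vector<int> &Outer,
                                       const std::vector<int> &Inner) {
    std::vector<int> Result;
    for (int M : Outer) {
      if (M < 0) {                 // undef/zero in the outer mask wins
        Result.push_back(M);
        continue;
      }
      Result.push_back(Inner[M]);  // otherwise look through to the inner input
    }
    return Result;
  }

  int main() {
    // pshufd {2,3,0,1} on top of pshufd {1,0,undef,3} gives {undef,3,1,0}.
    auto R = composeMasks({2, 3, 0, 1}, {1, 0, SentinelUndef, 3});
    return (R[0] == SentinelUndef && R[1] == 3 && R[2] == 1 && R[3] == 0) ? 0 : 1;
  }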
-static bool combineX86ShufflesRecursively(SDValue Op, SDValue Root, +static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, + int SrcOpIndex, SDValue Root, ArrayRef<int> RootMask, int Depth, bool HasVariableMask, SelectionDAG &DAG, @@ -25375,8 +27340,8 @@ static bool combineX86ShufflesRecursively(SDValue Op, SDValue Root, return false; // Directly rip through bitcasts to find the underlying operand. - while (Op.getOpcode() == ISD::BITCAST && Op.getOperand(0).hasOneUse()) - Op = Op.getOperand(0); + SDValue Op = SrcOps[SrcOpIndex]; + Op = peekThroughOneUseBitcasts(Op); MVT VT = Op.getSimpleValueType(); if (!VT.isVector()) @@ -25393,8 +27358,27 @@ static bool combineX86ShufflesRecursively(SDValue Op, SDValue Root, if (!resolveTargetShuffleInputs(Op, Input0, Input1, OpMask)) return false; - assert(VT.getVectorNumElements() == OpMask.size() && - "Different mask size from vector size!"); + // Add the inputs to the Ops list, avoiding duplicates. + SmallVector<SDValue, 8> Ops(SrcOps.begin(), SrcOps.end()); + + int InputIdx0 = -1, InputIdx1 = -1; + for (int i = 0, e = Ops.size(); i < e; ++i) { + SDValue BC = peekThroughBitcasts(Ops[i]); + if (Input0 && BC == peekThroughBitcasts(Input0)) + InputIdx0 = i; + if (Input1 && BC == peekThroughBitcasts(Input1)) + InputIdx1 = i; + } + + if (Input0 && InputIdx0 < 0) { + InputIdx0 = SrcOpIndex; + Ops[SrcOpIndex] = Input0; + } + if (Input1 && InputIdx1 < 0) { + InputIdx1 = Ops.size(); + Ops.push_back(Input1); + } + assert(((RootMask.size() > OpMask.size() && RootMask.size() % OpMask.size() == 0) || (OpMask.size() > RootMask.size() && @@ -25424,6 +27408,17 @@ static bool combineX86ShufflesRecursively(SDValue Op, SDValue Root, } int RootMaskedIdx = RootMask[RootIdx] * RootRatio + i % RootRatio; + + // Just insert the scaled root mask value if it references an input other + // than the SrcOp we're currently inserting. + if ((RootMaskedIdx < (SrcOpIndex * MaskWidth)) || + (((SrcOpIndex + 1) * MaskWidth) <= RootMaskedIdx)) { + Mask.push_back(RootMaskedIdx); + continue; + } + + RootMaskedIdx %= MaskWidth; + int OpIdx = RootMaskedIdx / OpRatio; if (OpMask[OpIdx] < 0) { // The incoming lanes are zero or undef, it doesn't matter which ones we @@ -25432,17 +27427,27 @@ static bool combineX86ShufflesRecursively(SDValue Op, SDValue Root, continue; } - // Ok, we have non-zero lanes, map them through. - Mask.push_back(OpMask[OpIdx] * OpRatio + - RootMaskedIdx % OpRatio); + // Ok, we have non-zero lanes, map them through to one of the Op's inputs. + int OpMaskedIdx = OpMask[OpIdx] * OpRatio + RootMaskedIdx % OpRatio; + OpMaskedIdx %= MaskWidth; + + if (OpMask[OpIdx] < (int)OpMask.size()) { + assert(0 <= InputIdx0 && "Unknown target shuffle input"); + OpMaskedIdx += InputIdx0 * MaskWidth; + } else { + assert(0 <= InputIdx1 && "Unknown target shuffle input"); + OpMaskedIdx += InputIdx1 * MaskWidth; + } + + Mask.push_back(OpMaskedIdx); } // Handle the all undef/zero cases early. - if (llvm::all_of(Mask, [](int Idx) { return Idx == SM_SentinelUndef; })) { + if (all_of(Mask, [](int Idx) { return Idx == SM_SentinelUndef; })) { DCI.CombineTo(Root.getNode(), DAG.getUNDEF(Root.getValueType())); return true; } - if (llvm::all_of(Mask, [](int Idx) { return Idx < 0; })) { + if (all_of(Mask, [](int Idx) { return Idx < 0; })) { // TODO - should we handle the mixed zero/undef case as well? Just returning // a zero mask will lose information on undef elements possibly reducing // future combine possibilities. 
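The hunk below then prunes source operands that the composed mask no longer references, shifting the remaining indices down so they stay contiguous. A standalone sketch of that renumbering (illustrative only; plain ints stand in for the SDValue operands):

  #include <vector>

  static void dropUnusedOps(std::vector<int> &Mask, std::vector<int> &Ops,
                            int MaskWidth) {
    std::vector<int> Used;
    for (size_t Op = 0; Op != Ops.size(); ++Op) {
      int Lo = (int)Used.size() * MaskWidth, Hi = Lo + MaskWidth;
      bool AnyUse = false;
      for (int M : Mask)
        AnyUse |= (Lo <= M && M < Hi);
      if (AnyUse) {
        Used.push_back(Ops[Op]);
        continue;
      }
      for (int &M : Mask)          // close the gap left by the dropped operand
        if (M >= Lo)
          M -= MaskWidth;
    }
    Ops = Used;
  }

  int main() {
    // Three 4-element sources, but the mask only touches ops 0 and 2.
    std::vector<int> Mask = {0, 1, 8, 9};
    std::vector<int> Ops = {10, 20, 30};
    dropUnusedOps(Mask, Ops, 4);
    return (Ops.size() == 2 && Mask[2] == 4 && Mask[3] == 5) ? 0 : 1;
  }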
@@ -25451,30 +27456,40 @@ static bool combineX86ShufflesRecursively(SDValue Op, SDValue Root, return true; } - int MaskSize = Mask.size(); - bool UseInput0 = std::any_of(Mask.begin(), Mask.end(), - [MaskSize](int Idx) { return 0 <= Idx && Idx < MaskSize; }); - bool UseInput1 = std::any_of(Mask.begin(), Mask.end(), - [MaskSize](int Idx) { return MaskSize <= Idx; }); - - // At the moment we can only combine unary shuffle mask cases. - if (UseInput0 && UseInput1) - return false; - else if (UseInput1) { - std::swap(Input0, Input1); - ShuffleVectorSDNode::commuteMask(Mask); + // Remove unused shuffle source ops. + SmallVector<SDValue, 8> UsedOps; + for (int i = 0, e = Ops.size(); i < e; ++i) { + int lo = UsedOps.size() * MaskWidth; + int hi = lo + MaskWidth; + if (any_of(Mask, [lo, hi](int i) { return (lo <= i) && (i < hi); })) { + UsedOps.push_back(Ops[i]); + continue; + } + for (int &M : Mask) + if (lo <= M) + M -= MaskWidth; } - - assert(Input0 && "Shuffle with no inputs detected"); + assert(!UsedOps.empty() && "Shuffle with no inputs detected"); + Ops = UsedOps; HasVariableMask |= isTargetShuffleVariableMask(Op.getOpcode()); - // See if we can recurse into Input0 (if it's a target shuffle). - if (Op->isOnlyUserOf(Input0.getNode()) && - combineX86ShufflesRecursively(Input0, Root, Mask, Depth + 1, - HasVariableMask, DAG, DCI, Subtarget)) + // See if we can recurse into each shuffle source op (if it's a target shuffle). + for (int i = 0, e = Ops.size(); i < e; ++i) + if (Ops[i].getNode()->hasOneUse() || Op->isOnlyUserOf(Ops[i].getNode())) + if (combineX86ShufflesRecursively(Ops, i, Root, Mask, Depth + 1, + HasVariableMask, DAG, DCI, Subtarget)) + return true; + + // Attempt to constant fold all of the constant source ops. + if (combineX86ShufflesConstants(Ops, Mask, Root, HasVariableMask, DAG, DCI, + Subtarget)) return true; + // We can only combine unary and binary shuffle mask cases. + if (Ops.size() > 2) + return false; + // Minor canonicalization of the accumulated shuffle mask to make it easier // to match below. All this does is detect masks with sequential pairs of // elements, and shrink them to the half-width mask. It does this in a loop @@ -25485,7 +27500,14 @@ static bool combineX86ShufflesRecursively(SDValue Op, SDValue Root, Mask = std::move(WidenedMask); } - return combineX86ShuffleChain(Input0, Root, Mask, Depth, HasVariableMask, DAG, + // Canonicalization of binary shuffle masks to improve pattern matching by + // commuting the inputs. + if (Ops.size() == 2 && canonicalizeShuffleMaskWithCommute(Mask)) { + ShuffleVectorSDNode::commuteMask(Mask); + std::swap(Ops[0], Ops[1]); + } + + return combineX86ShuffleChain(Ops, Root, Mask, Depth, HasVariableMask, DAG, DCI, Subtarget); } @@ -25612,7 +27634,7 @@ combineRedundantDWordShuffle(SDValue N, MutableArrayRef<int> Mask, Chain.push_back(V); - // Fallthrough! 
+ LLVM_FALLTHROUGH; case ISD::BITCAST: V = V.getOperand(0); continue; @@ -25742,7 +27764,8 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, MVT VT = N.getSimpleValueType(); SmallVector<int, 4> Mask; - switch (N.getOpcode()) { + unsigned Opcode = N.getOpcode(); + switch (Opcode) { case X86ISD::PSHUFD: case X86ISD::PSHUFLW: case X86ISD::PSHUFHW: @@ -25750,6 +27773,17 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, assert(Mask.size() == 4); break; case X86ISD::UNPCKL: { + auto Op0 = N.getOperand(0); + auto Op1 = N.getOperand(1); + unsigned Opcode0 = Op0.getOpcode(); + unsigned Opcode1 = Op1.getOpcode(); + + // Combine X86ISD::UNPCKL with 2 X86ISD::FHADD inputs into a single + // X86ISD::FHADD. This is generated by UINT_TO_FP v2f64 scalarization. + // TODO: Add other horizontal operations as required. + if (VT == MVT::v2f64 && Opcode0 == Opcode1 && Opcode0 == X86ISD::FHADD) + return DAG.getNode(Opcode0, DL, VT, Op0.getOperand(0), Op1.getOperand(0)); + // Combine X86ISD::UNPCKL and ISD::VECTOR_SHUFFLE into X86ISD::UNPCKH, in // which X86ISD::UNPCKL has a ISD::UNDEF operand, and ISD::VECTOR_SHUFFLE // moves upper half elements into the lower half part. For example: @@ -25767,9 +27801,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, if (!VT.is128BitVector()) return SDValue(); - auto Op0 = N.getOperand(0); - auto Op1 = N.getOperand(1); - if (Op0.isUndef() && Op1.getNode()->getOpcode() == ISD::VECTOR_SHUFFLE) { + if (Op0.isUndef() && Opcode1 == ISD::VECTOR_SHUFFLE) { ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op1.getNode())->getMask(); unsigned NumElts = VT.getVectorNumElements(); @@ -25806,44 +27838,31 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, return DAG.getNode(X86ISD::BLENDI, DL, VT, V1, V0, NewMask); } - // Attempt to merge blend(insertps(x,y),zero). - if (V0.getOpcode() == X86ISD::INSERTPS || - V1.getOpcode() == X86ISD::INSERTPS) { - assert(VT == MVT::v4f32 && "INSERTPS ValueType must be MVT::v4f32"); - - // Determine which elements are known to be zero. - SmallVector<int, 8> TargetMask; - SmallVector<SDValue, 2> BlendOps; - if (!setTargetShuffleZeroElements(N, TargetMask, BlendOps)) - return SDValue(); - - // Helper function to take inner insertps node and attempt to - // merge the blend with zero into its zero mask. - auto MergeInsertPSAndBlend = [&](SDValue V, int Offset) { - if (V.getOpcode() != X86ISD::INSERTPS) - return SDValue(); - SDValue Op0 = V.getOperand(0); - SDValue Op1 = V.getOperand(1); - SDValue Op2 = V.getOperand(2); - unsigned InsertPSMask = cast<ConstantSDNode>(Op2)->getZExtValue(); - - // Check each element of the blend node's target mask - must either - // be zeroable (and update the zero mask) or selects the element from - // the inner insertps node. 
- for (int i = 0; i != 4; ++i) - if (TargetMask[i] < 0) - InsertPSMask |= (1u << i); - else if (TargetMask[i] != (i + Offset)) - return SDValue(); - return DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32, Op0, Op1, - DAG.getConstant(InsertPSMask, DL, MVT::i8)); - }; - - if (SDValue V = MergeInsertPSAndBlend(V0, 0)) - return V; - if (SDValue V = MergeInsertPSAndBlend(V1, 4)) - return V; + return SDValue(); + } + case X86ISD::MOVSD: + case X86ISD::MOVSS: { + bool isFloat = VT.isFloatingPoint(); + SDValue V0 = peekThroughBitcasts(N->getOperand(0)); + SDValue V1 = peekThroughBitcasts(N->getOperand(1)); + bool isFloat0 = V0.getSimpleValueType().isFloatingPoint(); + bool isFloat1 = V1.getSimpleValueType().isFloatingPoint(); + bool isZero0 = ISD::isBuildVectorAllZeros(V0.getNode()); + bool isZero1 = ISD::isBuildVectorAllZeros(V1.getNode()); + assert(!(isZero0 && isZero1) && "Zeroable shuffle detected."); + + // We often lower to MOVSD/MOVSS from integer as well as native float + // types; remove unnecessary domain-crossing bitcasts if we can to make it + // easier to combine shuffles later on. We've already accounted for the + // domain switching cost when we decided to lower with it. + if ((isFloat != isFloat0 || isZero0) && (isFloat != isFloat1 || isZero1)) { + MVT NewVT = isFloat ? (X86ISD::MOVSD == Opcode ? MVT::v2i64 : MVT::v4i32) + : (X86ISD::MOVSD == Opcode ? MVT::v2f64 : MVT::v4f32); + V0 = DAG.getBitcast(NewVT, V0); + V1 = DAG.getBitcast(NewVT, V1); + return DAG.getBitcast(VT, DAG.getNode(Opcode, DL, NewVT, V0, V1)); } + return SDValue(); } case X86ISD::INSERTPS: { @@ -25976,9 +27995,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, V.getOpcode() == X86ISD::PSHUFHW) && V.getOpcode() != N.getOpcode() && V.hasOneUse()) { - SDValue D = V.getOperand(0); - while (D.getOpcode() == ISD::BITCAST && D.hasOneUse()) - D = D.getOperand(0); + SDValue D = peekThroughOneUseBitcasts(V.getOperand(0)); if (D.getOpcode() == X86ISD::PSHUFD && D.hasOneUse()) { SmallVector<int, 4> VMask = getPSHUFShuffleMask(V); SmallVector<int, 4> DMask = getPSHUFShuffleMask(D); @@ -26017,31 +28034,32 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, return SDValue(); } -/// \brief Try to combine a shuffle into a target-specific add-sub node. +/// Returns true iff the shuffle node \p N can be replaced with ADDSUB +/// operation. If true is returned then the operands of ADDSUB operation +/// are written to the parameters \p Opnd0 and \p Opnd1. /// -/// We combine this directly on the abstract vector shuffle nodes so it is -/// easier to generically match. We also insert dummy vector shuffle nodes for -/// the operands which explicitly discard the lanes which are unused by this -/// operation to try to flow through the rest of the combiner the fact that -/// they're unused. -static SDValue combineShuffleToAddSub(SDNode *N, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - SDLoc DL(N); +/// We combine shuffle to ADDSUB directly on the abstract vector shuffle nodes +/// so it is easier to generically match. We also insert dummy vector shuffle +/// nodes for the operands which explicitly discard the lanes which are unused +/// by this operation to try to flow through the rest of the combiner +/// the fact that they're unused. 
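The expected blend patterns checked below ({0,3}, {0,5,2,7}, {0,9,2,11,...}) all have the same shape: even lanes taken from the FSUB, odd lanes from the FADD of the same operands. A standalone predicate sketch for that shape (illustrative only; it ignores the operand and one-use checks the real code performs):

  #include <vector>

  static bool isAddSubMask(const std::vector<int> &Mask) {
    int N = (int)Mask.size();
    for (int i = 0; i != N; ++i) {
      int Expected = (i % 2 == 0) ? i : N + i;  // even from op0, odd from op1
      if (Mask[i] >= 0 && Mask[i] != Expected)
        return false;
    }
    return true;
  }

  int main() {
    return (isAddSubMask({0, 5, 2, 7}) && !isAddSubMask({0, 1, 2, 3})) ? 0 : 1;
  }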
+static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget, + SDValue &Opnd0, SDValue &Opnd1) { + EVT VT = N->getValueType(0); if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) && - (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64))) - return SDValue(); + (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64)) && + (!Subtarget.hasAVX512() || (VT != MVT::v16f32 && VT != MVT::v8f64))) + return false; // We only handle target-independent shuffles. // FIXME: It would be easy and harmless to use the target shuffle mask // extraction tool to support more. if (N->getOpcode() != ISD::VECTOR_SHUFFLE) - return SDValue(); + return false; - auto *SVN = cast<ShuffleVectorSDNode>(N); - SmallVector<int, 8> Mask; - for (int M : SVN->getMask()) - Mask.push_back(M); + ArrayRef<int> OrigMask = cast<ShuffleVectorSDNode>(N)->getMask(); + SmallVector<int, 16> Mask(OrigMask.begin(), OrigMask.end()); SDValue V1 = N->getOperand(0); SDValue V2 = N->getOperand(1); @@ -26052,27 +28070,102 @@ static SDValue combineShuffleToAddSub(SDNode *N, const X86Subtarget &Subtarget, ShuffleVectorSDNode::commuteMask(Mask); std::swap(V1, V2); } else if (V1.getOpcode() != ISD::FSUB || V2.getOpcode() != ISD::FADD) - return SDValue(); + return false; // If there are other uses of these operations we can't fold them. if (!V1->hasOneUse() || !V2->hasOneUse()) - return SDValue(); + return false; // Ensure that both operations have the same operands. Note that we can // commute the FADD operands. SDValue LHS = V1->getOperand(0), RHS = V1->getOperand(1); if ((V2->getOperand(0) != LHS || V2->getOperand(1) != RHS) && (V2->getOperand(0) != RHS || V2->getOperand(1) != LHS)) - return SDValue(); + return false; // We're looking for blends between FADD and FSUB nodes. We insist on these // nodes being lined up in a specific expected pattern. if (!(isShuffleEquivalent(V1, V2, Mask, {0, 3}) || isShuffleEquivalent(V1, V2, Mask, {0, 5, 2, 7}) || - isShuffleEquivalent(V1, V2, Mask, {0, 9, 2, 11, 4, 13, 6, 15}))) + isShuffleEquivalent(V1, V2, Mask, {0, 9, 2, 11, 4, 13, 6, 15}) || + isShuffleEquivalent(V1, V2, Mask, {0, 17, 2, 19, 4, 21, 6, 23, + 8, 25, 10, 27, 12, 29, 14, 31}))) + return false; + + Opnd0 = LHS; + Opnd1 = RHS; + return true; +} + +/// \brief Try to combine a shuffle into a target-specific add-sub or +/// mul-add-sub node. +static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue Opnd0, Opnd1; + if (!isAddSub(N, Subtarget, Opnd0, Opnd1)) return SDValue(); - return DAG.getNode(X86ISD::ADDSUB, DL, VT, LHS, RHS); + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // Try to generate X86ISD::FMADDSUB node here. + SDValue Opnd2; + if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2)) + return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); + + // Do not generate X86ISD::ADDSUB node for 512-bit types even though + // the ADDSUB idiom has been successfully recognized. There are no known + // X86 targets with 512-bit ADDSUB instructions! + if (VT.is512BitVector()) + return SDValue(); + + return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1); +} + +// We are looking for a shuffle where both sources are concatenated with undef +// and have a width that is half of the output's width. AVX2 has VPERMD/Q, so +// if we can express this as a single-source shuffle, that's preferable. 
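isAddSub only accepts blend masks that take even lanes from the FSUB node and odd lanes from the FADD node ({0,5,2,7} for v4f32, and the wider analogues above), which is exactly the lane pattern of addsubps/addsubpd. A scalar model of what the matched shuffle computes (standalone sketch, not LLVM code):

#include <array>
#include <cassert>
#include <cstddef>

using V4 = std::array<float, 4>;

// Scalar model of X86ISD::ADDSUB: subtract in even lanes, add in odd lanes.
static V4 addsub(const V4 &A, const V4 &B) {
  V4 R;
  for (std::size_t I = 0; I != 4; ++I)
    R[I] = (I % 2) ? A[I] + B[I] : A[I] - B[I];
  return R;
}

int main() {
  V4 X{1, 2, 3, 4}, Y{10, 20, 30, 40};
  V4 Sub, Add;
  for (std::size_t I = 0; I != 4; ++I) {
    Sub[I] = X[I] - Y[I];
    Add[I] = X[I] + Y[I];
  }
  // Mask elements >= 4 index the second shuffle operand, so {0, 5, 2, 7}
  // picks lanes 0/2 of Sub and lanes 1/3 of Add -- a single ADDSUB node.
  V4 Blend{Sub[0], Add[1], Sub[2], Add[3]};
  assert(Blend == addsub(X, Y));
}

When the operands additionally form a multiply (checked by isFMAddSub, defined elsewhere in this file), combineShuffleToAddSubOrFMAddSub emits X86ISD::FMADDSUB instead, fusing the multiply into the same node.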
+static SDValue combineShuffleOfConcatUndef(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (!Subtarget.hasAVX2() || !isa<ShuffleVectorSDNode>(N)) + return SDValue(); + + EVT VT = N->getValueType(0); + + // We only care about shuffles of 128/256-bit vectors of 32/64-bit values. + if (!VT.is128BitVector() && !VT.is256BitVector()) + return SDValue(); + + if (VT.getVectorElementType() != MVT::i32 && + VT.getVectorElementType() != MVT::i64 && + VT.getVectorElementType() != MVT::f32 && + VT.getVectorElementType() != MVT::f64) + return SDValue(); + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // Check that both sources are concats with undef. + if (N0.getOpcode() != ISD::CONCAT_VECTORS || + N1.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 || + N1.getNumOperands() != 2 || !N0.getOperand(1).isUndef() || + !N1.getOperand(1).isUndef()) + return SDValue(); + + // Construct the new shuffle mask. Elements from the first source retain their + // index, but elements from the second source no longer need to skip an undef. + SmallVector<int, 8> Mask; + int NumElts = VT.getVectorNumElements(); + + ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(N); + for (int Elt : SVOp->getMask()) + Mask.push_back(Elt < NumElts ? Elt : (Elt - NumElts / 2)); + + SDLoc DL(N); + SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, N0.getOperand(0), + N1.getOperand(0)); + return DAG.getVectorShuffle(VT, DL, Concat, DAG.getUNDEF(VT), Mask); } static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, @@ -26089,14 +28182,9 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, // If we have legalized the vector types, look for blends of FADD and FSUB // nodes that we can fuse into an ADDSUB node. if (TLI.isTypeLegal(VT)) - if (SDValue AddSub = combineShuffleToAddSub(N, Subtarget, DAG)) + if (SDValue AddSub = combineShuffleToAddSubOrFMAddSub(N, Subtarget, DAG)) return AddSub; - // Combine 256-bit vector shuffles. This is only profitable when in AVX mode - if (TLI.isTypeLegal(VT) && Subtarget.hasFp256() && VT.is256BitVector() && - N->getOpcode() == ISD::VECTOR_SHUFFLE) - return combineShuffle256(N, DAG, DCI, Subtarget); - // During Type Legalization, when promoting illegal vector types, // the backend might introduce new shuffle dag nodes and bitcasts. // @@ -26127,13 +28215,18 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, bool CanFold = false; switch (Opcode) { default : break; - case ISD::ADD : - case ISD::FADD : - case ISD::SUB : - case ISD::FSUB : - case ISD::MUL : - case ISD::FMUL : - CanFold = true; + case ISD::ADD: + case ISD::SUB: + case ISD::MUL: + // isOperationLegal lies for integer ops on floating point types. + CanFold = VT.isInteger(); + break; + case ISD::FADD: + case ISD::FSUB: + case ISD::FMUL: + // isOperationLegal lies for floating point ops on integer types. 
+ CanFold = VT.isFloatingPoint(); + break; } unsigned SVTNumElts = SVT.getVectorNumElements(); @@ -26162,9 +28255,18 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, if (SDValue LD = EltsFromConsecutiveLoads(VT, Elts, dl, DAG, true)) return LD; + // For AVX2, we sometimes want to combine + // (vector_shuffle <mask> (concat_vectors t1, undef) + // (concat_vectors t2, undef)) + // Into: + // (vector_shuffle <mask> (concat_vectors t1, t2), undef) + // Since the latter can be efficiently lowered with VPERMD/VPERMQ + if (SDValue ShufConcat = combineShuffleOfConcatUndef(N, DAG, Subtarget)) + return ShufConcat; + if (isTargetShuffle(N->getOpcode())) { - if (SDValue Shuffle = - combineTargetShuffle(SDValue(N, 0), DAG, DCI, Subtarget)) + SDValue Op(N, 0); + if (SDValue Shuffle = combineTargetShuffle(Op, DAG, DCI, Subtarget)) return Shuffle; // Try recursively combining arbitrary sequences of x86 shuffle @@ -26174,8 +28276,8 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, // a particular chain. SmallVector<int, 1> NonceMask; // Just a placeholder. NonceMask.push_back(0); - if (combineX86ShufflesRecursively(SDValue(N, 0), SDValue(N, 0), NonceMask, - /*Depth*/ 1, /*HasPSHUFB*/ false, DAG, + if (combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, + /*Depth*/ 1, /*HasVarMask*/ false, DAG, DCI, Subtarget)) return SDValue(); // This routine will use CombineTo to replace N. } @@ -26305,11 +28407,10 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, } // Convert a bitcasted integer logic operation that has one bitcasted - // floating-point operand and one constant operand into a floating-point - // logic operation. This may create a load of the constant, but that is - // cheaper than materializing the constant in an integer register and - // transferring it to an SSE register or transferring the SSE operand to - // integer register and back. + // floating-point operand into a floating-point logic operation. This may + // create a load of a constant, but that is cheaper than materializing the + // constant in an integer register and transferring it to an SSE register or + // transferring the SSE operand to integer register and back. 
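The bitcast-of-logic rewrite above relies on AND/OR/XOR being pure bit operations: they commute with bitcasts, so the same bits come out whether the op runs in the integer or the FP domain (FAND/FOR/FXOR). A short sketch of the classic fabs case, with C++20 std::bit_cast standing in for the DAG bitcasts:

#include <bit>
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  float X = -1.5f;
  // bitcast(and(bitcast(X), 0x7fffffff)) clears the sign bit, i.e. fabs(X).
  // The combine lets the backend keep this as andps on an SSE register
  // instead of bouncing the value through an integer register and back.
  float R = std::bit_cast<float>(std::bit_cast<uint32_t>(X) & 0x7fffffffu);
  assert(R == std::fabs(X));
}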
unsigned FPOpcode; switch (N0.getOpcode()) { case ISD::AND: FPOpcode = X86ISD::FAND; break; @@ -26317,25 +28418,238 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, case ISD::XOR: FPOpcode = X86ISD::FXOR; break; default: return SDValue(); } - if (((Subtarget.hasSSE1() && VT == MVT::f32) || - (Subtarget.hasSSE2() && VT == MVT::f64)) && - isa<ConstantSDNode>(N0.getOperand(1)) && - N0.getOperand(0).getOpcode() == ISD::BITCAST && - N0.getOperand(0).getOperand(0).getValueType() == VT) { - SDValue N000 = N0.getOperand(0).getOperand(0); - SDValue FPConst = DAG.getBitcast(VT, N0.getOperand(1)); - return DAG.getNode(FPOpcode, SDLoc(N0), VT, N000, FPConst); + + if (!((Subtarget.hasSSE1() && VT == MVT::f32) || + (Subtarget.hasSSE2() && VT == MVT::f64))) + return SDValue(); + + SDValue LogicOp0 = N0.getOperand(0); + SDValue LogicOp1 = N0.getOperand(1); + SDLoc DL0(N0); + + // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y)) + if (N0.hasOneUse() && LogicOp0.getOpcode() == ISD::BITCAST && + LogicOp0.hasOneUse() && LogicOp0.getOperand(0).getValueType() == VT && + !isa<ConstantSDNode>(LogicOp0.getOperand(0))) { + SDValue CastedOp1 = DAG.getBitcast(VT, LogicOp1); + return DAG.getNode(FPOpcode, DL0, VT, LogicOp0.getOperand(0), CastedOp1); + } + // bitcast(logic(X, bitcast(Y))) --> logic'(bitcast(X), Y) + if (N0.hasOneUse() && LogicOp1.getOpcode() == ISD::BITCAST && + LogicOp1.hasOneUse() && LogicOp1.getOperand(0).getValueType() == VT && + !isa<ConstantSDNode>(LogicOp1.getOperand(0))) { + SDValue CastedOp0 = DAG.getBitcast(VT, LogicOp0); + return DAG.getNode(FPOpcode, DL0, VT, LogicOp1.getOperand(0), CastedOp0); } return SDValue(); } +// Match a binop + shuffle pyramid that represents a horizontal reduction over +// the elements of a vector. +// Returns the vector that is being reduced on, or SDValue() if a reduction +// was not matched. +static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { + // The pattern must end in an extract from index 0. + if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) || + !isNullConstant(Extract->getOperand(1))) + return SDValue(); + + unsigned Stages = + Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements()); + + SDValue Op = Extract->getOperand(0); + // At each stage, we're looking for something that looks like: + // %s = shufflevector <8 x i32> %op, <8 x i32> undef, + // <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, + // i32 undef, i32 undef, i32 undef, i32 undef> + // %a = binop <8 x i32> %op, %s + // Where the mask changes according to the stage. E.g. for a 3-stage pyramid, + // we expect something like: + // <4,5,6,7,u,u,u,u> + // <2,3,u,u,u,u,u,u> + // <1,u,u,u,u,u,u,u> + for (unsigned i = 0; i < Stages; ++i) { + if (Op.getOpcode() != BinOp) + return SDValue(); + + ShuffleVectorSDNode *Shuffle = + dyn_cast<ShuffleVectorSDNode>(Op.getOperand(0).getNode()); + if (Shuffle) { + Op = Op.getOperand(1); + } else { + Shuffle = dyn_cast<ShuffleVectorSDNode>(Op.getOperand(1).getNode()); + Op = Op.getOperand(0); + } + + // The first operand of the shuffle should be the same as the other operand + // of the add. + if (!Shuffle || (Shuffle->getOperand(0) != Op)) + return SDValue(); + + // Verify the shuffle has the expected (at this stage of the pyramid) mask. 
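At stage i the expected mask moves element (1 << i) + Index down to lane Index, so for eight elements the pyramid uses <4,5,6,7,u,...>, <2,3,u,...>, <1,u,...>. A scalar sketch of the reduction shape matchBinOpReduction recognizes, shown for ADD (standalone C++, not LLVM code):

#include <array>
#include <cassert>

int main() {
  std::array<int, 8> Op{1, 2, 3, 4, 5, 6, 7, 8};

  // Three shuffle+binop stages for 2^3 elements; each stage folds the upper
  // half of the live lanes into the lower half ('u' lanes are don't-care).
  for (unsigned Stage = 3; Stage-- > 0;) {
    unsigned MaskEnd = 1u << Stage;
    for (unsigned Index = 0; Index != MaskEnd; ++Index)
      Op[Index] += Op[MaskEnd + Index];   // binop(Op, shuffle(Op))
  }

  // The extract from index 0 that roots the pattern sees the full sum.
  assert(Op[0] == 36);
}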
+ for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index) + if (Shuffle->getMaskElt(Index) != MaskEnd + Index) + return SDValue(); + } + + return Op; +} + +// Given a select, detect the following pattern: +// 1: %2 = zext <N x i8> %0 to <N x i32> +// 2: %3 = zext <N x i8> %1 to <N x i32> +// 3: %4 = sub nsw <N x i32> %2, %3 +// 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N] +// 5: %6 = sub nsw <N x i32> zeroinitializer, %4 +// 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6 +// This is useful as it is the input into a SAD pattern. +static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, + SDValue &Op1) { + // Check the condition of the select instruction is greater-than. + SDValue SetCC = Select->getOperand(0); + if (SetCC.getOpcode() != ISD::SETCC) + return false; + ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); + if (CC != ISD::SETGT) + return false; + + SDValue SelectOp1 = Select->getOperand(1); + SDValue SelectOp2 = Select->getOperand(2); + + // The second operand of the select should be the negation of the first + // operand, which is implemented as 0 - SelectOp1. + if (!(SelectOp2.getOpcode() == ISD::SUB && + ISD::isBuildVectorAllZeros(SelectOp2.getOperand(0).getNode()) && + SelectOp2.getOperand(1) == SelectOp1)) + return false; + + // The first operand of SetCC is the first operand of the select, which is the + // difference between the two input vectors. + if (SetCC.getOperand(0) != SelectOp1) + return false; + + // The second operand of the comparison can be either -1 or 0. + if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || + ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode()))) + return false; + + // The first operand of the select is the difference between the two input + // vectors. + if (SelectOp1.getOpcode() != ISD::SUB) + return false; + + Op0 = SelectOp1.getOperand(0); + Op1 = SelectOp1.getOperand(1); + + // Check if the operands of the sub are zero-extended from vectors of i8. + if (Op0.getOpcode() != ISD::ZERO_EXTEND || + Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 || + Op1.getOpcode() != ISD::ZERO_EXTEND || + Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8) + return false; + + return true; +} + +// Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs +// to these zexts. +static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0, + const SDValue &Zext1, const SDLoc &DL) { + + // Find the appropriate width for the PSADBW. + EVT InVT = Zext0.getOperand(0).getValueType(); + unsigned RegSize = std::max(128u, InVT.getSizeInBits()); + + // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we + // fill in the missing vector elements with 0. + unsigned NumConcat = RegSize / InVT.getSizeInBits(); + SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT)); + Ops[0] = Zext0.getOperand(0); + MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8); + SDValue SadOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops); + Ops[0] = Zext1.getOperand(0); + SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops); + + // Actually build the SAD + MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64); + return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1); +} + +static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // PSADBW is only supported on SSE2 and up. 
+ if (!Subtarget.hasSSE2()) + return SDValue(); + + // Verify the type we're extracting from is appropriate + // TODO: There's nothing special about i32, any integer type above i16 should + // work just as well. + EVT VT = Extract->getOperand(0).getValueType(); + if (!VT.isSimple() || !(VT.getVectorElementType() == MVT::i32)) + return SDValue(); + + unsigned RegSize = 128; + if (Subtarget.hasBWI()) + RegSize = 512; + else if (Subtarget.hasAVX2()) + RegSize = 256; + + // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512. + // TODO: We should be able to handle larger vectors by splitting them before + // feeding them into several SADs, and then reducing over those. + if (VT.getSizeInBits() / 4 > RegSize) + return SDValue(); + + // Match shuffle + add pyramid. + SDValue Root = matchBinOpReduction(Extract, ISD::ADD); + + // If there was a match, we want Root to be a select that is the root of an + // abs-diff pattern. + if (!Root || (Root.getOpcode() != ISD::VSELECT)) + return SDValue(); + + // Check whether we have an abs-diff pattern feeding into the select. + SDValue Zext0, Zext1; + if (!detectZextAbsDiff(Root, Zext0, Zext1)) + return SDValue(); + + // Create the SAD instruction + SDLoc DL(Extract); + SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL); + + // If the original vector was wider than 8 elements, sum over the results + // in the SAD vector. + unsigned Stages = Log2_32(VT.getVectorNumElements()); + MVT SadVT = SAD.getSimpleValueType(); + if (Stages > 3) { + unsigned SadElems = SadVT.getVectorNumElements(); + + for(unsigned i = Stages - 3; i > 0; --i) { + SmallVector<int, 16> Mask(SadElems, -1); + for(unsigned j = 0, MaskEnd = 1 << (i - 1); j < MaskEnd; ++j) + Mask[j] = MaskEnd + j; + + SDValue Shuffle = + DAG.getVectorShuffle(SadVT, DL, SAD, DAG.getUNDEF(SadVT), Mask); + SAD = DAG.getNode(ISD::ADD, DL, SadVT, SAD, Shuffle); + } + } + + // Return the lowest i32. + MVT ResVT = MVT::getVectorVT(MVT::i32, SadVT.getSizeInBits() / 32); + SAD = DAG.getNode(ISD::BITCAST, DL, ResVT, SAD); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, SAD, + Extract->getOperand(1)); +} + /// Detect vector gather/scatter index generation and convert it from being a /// bunch of shuffles and extracts into a somewhat faster sequence. /// For i686, the best sequence is apparently storing the value and loading /// scalars back, while for x64 we should use 64-bit extracts and shifts. static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { if (SDValue NewOp = XFormVExtractWithShuffleIntoLoad(N, DAG, DCI)) return NewOp; @@ -26347,7 +28661,7 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, InputVector.getValueType() == MVT::v2i32 && isa<ConstantSDNode>(N->getOperand(1)) && N->getConstantOperandVal(1) == 0) { - SDValue MMXSrc = InputVector.getNode()->getOperand(0); + SDValue MMXSrc = InputVector.getOperand(0); // The bitcast source is a direct mmx result. if (MMXSrc.getValueType() == MVT::x86mmx) @@ -26366,6 +28680,13 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, uint64_t Res = (InputValue >> ExtractedElt) & 1; return DAG.getConstant(Res, dl, MVT::i1); } + + // Check whether this extract is the root of a sum of absolute differences + // pattern. 
This has to be done here because we really want it to happen + // pre-legalization, + if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) + return SAD; + // Only operate on vectors of 4 elements, where the alternative shuffling // gets to be more expensive. if (InputVector.getValueType() != MVT::v4i32) @@ -26467,6 +28788,310 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, return SDValue(); } +/// If a vector select has an operand that is -1 or 0, try to simplify the +/// select to a bitwise logic operation. +static SDValue +combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + SDValue Cond = N->getOperand(0); + SDValue LHS = N->getOperand(1); + SDValue RHS = N->getOperand(2); + EVT VT = LHS.getValueType(); + EVT CondVT = Cond.getValueType(); + SDLoc DL(N); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + if (N->getOpcode() != ISD::VSELECT) + return SDValue(); + + assert(CondVT.isVector() && "Vector select expects a vector selector!"); + + bool FValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode()); + // Check if the first operand is all zeros and Cond type is vXi1. + // This situation only applies to avx512. + if (FValIsAllZeros && Subtarget.hasAVX512() && Cond.hasOneUse() && + CondVT.getVectorElementType() == MVT::i1) { + //Invert the cond to not(cond) : xor(op,allones)=not(op) + SDValue CondNew = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, + DAG.getConstant(APInt::getAllOnesValue(CondVT.getScalarSizeInBits()), + DL, CondVT)); + //Vselect cond, op1, op2 = Vselect not(cond), op2, op1 + return DAG.getNode(ISD::VSELECT, DL, VT, CondNew, RHS, LHS); + } + + // To use the condition operand as a bitwise mask, it must have elements that + // are the same size as the select elements. Ie, the condition operand must + // have already been promoted from the IR select condition type <N x i1>. + // Don't check if the types themselves are equal because that excludes + // vector floating-point selects. + if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits()) + return SDValue(); + + bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode()); + FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode()); + + // Try to invert the condition if true value is not all 1s and false value is + // not all 0s. + if (!TValIsAllOnes && !FValIsAllZeros && + // Check if the selector will be produced by CMPP*/PCMP*. + Cond.getOpcode() == ISD::SETCC && + // Check if SETCC has already been promoted. + TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) == + CondVT) { + bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode()); + bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode()); + + if (TValIsAllZeros || FValIsAllOnes) { + SDValue CC = Cond.getOperand(2); + ISD::CondCode NewCC = + ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(), + Cond.getOperand(0).getValueType().isInteger()); + Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1), + NewCC); + std::swap(LHS, RHS); + TValIsAllOnes = FValIsAllOnes; + FValIsAllZeros = TValIsAllZeros; + } + } + + // vselect Cond, 111..., 000... 
-> Cond + if (TValIsAllOnes && FValIsAllZeros) + return DAG.getBitcast(VT, Cond); + + if (!DCI.isBeforeLegalize() && !TLI.isTypeLegal(CondVT)) + return SDValue(); + + // vselect Cond, 111..., X -> or Cond, X + if (TValIsAllOnes) { + SDValue CastRHS = DAG.getBitcast(CondVT, RHS); + SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, CastRHS); + return DAG.getBitcast(VT, Or); + } + + // vselect Cond, X, 000... -> and Cond, X + if (FValIsAllZeros) { + SDValue CastLHS = DAG.getBitcast(CondVT, LHS); + SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, CastLHS); + return DAG.getBitcast(VT, And); + } + + return SDValue(); +} + +static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG) { + SDValue Cond = N->getOperand(0); + SDValue LHS = N->getOperand(1); + SDValue RHS = N->getOperand(2); + SDLoc DL(N); + + auto *TrueC = dyn_cast<ConstantSDNode>(LHS); + auto *FalseC = dyn_cast<ConstantSDNode>(RHS); + if (!TrueC || !FalseC) + return SDValue(); + + // Don't do this for crazy integer types. + if (!DAG.getTargetLoweringInfo().isTypeLegal(LHS.getValueType())) + return SDValue(); + + // If this is efficiently invertible, canonicalize the LHSC/RHSC values + // so that TrueC (the true value) is larger than FalseC. + bool NeedsCondInvert = false; + if (TrueC->getAPIntValue().ult(FalseC->getAPIntValue()) && + // Efficiently invertible. + (Cond.getOpcode() == ISD::SETCC || // setcc -> invertible. + (Cond.getOpcode() == ISD::XOR && // xor(X, C) -> invertible. + isa<ConstantSDNode>(Cond.getOperand(1))))) { + NeedsCondInvert = true; + std::swap(TrueC, FalseC); + } + + // Optimize C ? 8 : 0 -> zext(C) << 3. Likewise for any pow2/0. + if (FalseC->getAPIntValue() == 0 && TrueC->getAPIntValue().isPowerOf2()) { + if (NeedsCondInvert) // Invert the condition if needed. + Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, + DAG.getConstant(1, DL, Cond.getValueType())); + + // Zero extend the condition if needed. + Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, LHS.getValueType(), Cond); + + unsigned ShAmt = TrueC->getAPIntValue().logBase2(); + return DAG.getNode(ISD::SHL, DL, LHS.getValueType(), Cond, + DAG.getConstant(ShAmt, DL, MVT::i8)); + } + + // Optimize Cond ? cst+1 : cst -> zext(setcc(C)+cst. + if (FalseC->getAPIntValue() + 1 == TrueC->getAPIntValue()) { + if (NeedsCondInvert) // Invert the condition if needed. + Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, + DAG.getConstant(1, DL, Cond.getValueType())); + + // Zero extend the condition if needed. + Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, FalseC->getValueType(0), Cond); + return DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond, + SDValue(FalseC, 0)); + } + + // Optimize cases that will turn into an LEA instruction. This requires + // an i32 or i64 and an efficient multiplier (1, 2, 3, 4, 5, 8, 9). 
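combineSelectOfTwoConstants replaces a select between two constants with branch-free arithmetic on the 0/1 condition; the three shapes it produces, sketched as plain C++ (C is a setcc result, 0 or 1):

#include <cassert>

static int selPow2(int C) { return C << 3; }    // C ? 8 : 0   -> zext(C) << 3
static int selAdd1(int C) { return C + 41; }    // C ? 42 : 41 -> zext(C) + 41
static int selLea(int C)  { return C * 5 + 7; } // C ? 12 : 7  -> LEA form
// The last one maps to lea r, [base + cond + cond*4], which is why only the
// scale-friendly diffs 1, 2, 3, 4, 5, 8 and 9 count as fast multipliers.

int main() {
  for (int C : {0, 1}) {
    assert(selPow2(C) == (C ? 8 : 0));
    assert(selAdd1(C) == (C ? 42 : 41));
    assert(selLea(C) == (C ? 12 : 7));
  }
}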
+ if (N->getValueType(0) == MVT::i32 || N->getValueType(0) == MVT::i64) { + uint64_t Diff = TrueC->getZExtValue() - FalseC->getZExtValue(); + if (N->getValueType(0) == MVT::i32) + Diff = (unsigned)Diff; + + bool isFastMultiplier = false; + if (Diff < 10) { + switch ((unsigned char)Diff) { + default: + break; + case 1: // result = add base, cond + case 2: // result = lea base( , cond*2) + case 3: // result = lea base(cond, cond*2) + case 4: // result = lea base( , cond*4) + case 5: // result = lea base(cond, cond*4) + case 8: // result = lea base( , cond*8) + case 9: // result = lea base(cond, cond*8) + isFastMultiplier = true; + break; + } + } + + if (isFastMultiplier) { + APInt Diff = TrueC->getAPIntValue() - FalseC->getAPIntValue(); + if (NeedsCondInvert) // Invert the condition if needed. + Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, + DAG.getConstant(1, DL, Cond.getValueType())); + + // Zero extend the condition if needed. + Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, FalseC->getValueType(0), Cond); + // Scale the condition by the difference. + if (Diff != 1) + Cond = DAG.getNode(ISD::MUL, DL, Cond.getValueType(), Cond, + DAG.getConstant(Diff, DL, Cond.getValueType())); + + // Add the base if non-zero. + if (FalseC->getAPIntValue() != 0) + Cond = DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond, + SDValue(FalseC, 0)); + return Cond; + } + } + + return SDValue(); +} + +// If this is a bitcasted op that can be represented as another type, push the +// the bitcast to the inputs. This allows more opportunities for pattern +// matching masked instructions. This is called when we know that the operation +// is used as one of the inputs of a vselect. +static bool combineBitcastForMaskedOp(SDValue OrigOp, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI) { + // Make sure we have a bitcast. + if (OrigOp.getOpcode() != ISD::BITCAST) + return false; + + SDValue Op = OrigOp.getOperand(0); + + // If the operation is used by anything other than the bitcast, we shouldn't + // do this combine as that would replicate the operation. + if (!Op.hasOneUse()) + return false; + + MVT VT = OrigOp.getSimpleValueType(); + MVT EltVT = VT.getVectorElementType(); + SDLoc DL(Op.getNode()); + + auto BitcastAndCombineShuffle = [&](unsigned Opcode, SDValue Op0, SDValue Op1, + SDValue Op2) { + Op0 = DAG.getBitcast(VT, Op0); + DCI.AddToWorklist(Op0.getNode()); + Op1 = DAG.getBitcast(VT, Op1); + DCI.AddToWorklist(Op1.getNode()); + DCI.CombineTo(OrigOp.getNode(), + DAG.getNode(Opcode, DL, VT, Op0, Op1, Op2)); + return true; + }; + + unsigned Opcode = Op.getOpcode(); + switch (Opcode) { + case X86ISD::PALIGNR: + // PALIGNR can be converted to VALIGND/Q for 128-bit vectors. + if (!VT.is128BitVector()) + return false; + Opcode = X86ISD::VALIGN; + LLVM_FALLTHROUGH; + case X86ISD::VALIGN: { + if (EltVT != MVT::i32 && EltVT != MVT::i64) + return false; + uint64_t Imm = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue(); + MVT OpEltVT = Op.getSimpleValueType().getVectorElementType(); + unsigned ShiftAmt = Imm * OpEltVT.getSizeInBits(); + unsigned EltSize = EltVT.getSizeInBits(); + // Make sure we can represent the same shift with the new VT. + if ((ShiftAmt % EltSize) != 0) + return false; + Imm = ShiftAmt / EltSize; + return BitcastAndCombineShuffle(Opcode, Op.getOperand(0), Op.getOperand(1), + DAG.getConstant(Imm, DL, MVT::i8)); + } + case X86ISD::SHUF128: { + if (EltVT.getSizeInBits() != 32 && EltVT.getSizeInBits() != 64) + return false; + // Only change element size, not type. 
+ if (VT.isInteger() != Op.getSimpleValueType().isInteger()) + return false; + return BitcastAndCombineShuffle(Opcode, Op.getOperand(0), Op.getOperand(1), + Op.getOperand(2)); + } + case ISD::INSERT_SUBVECTOR: { + unsigned EltSize = EltVT.getSizeInBits(); + if (EltSize != 32 && EltSize != 64) + return false; + MVT OpEltVT = Op.getSimpleValueType().getVectorElementType(); + // Only change element size, not type. + if (VT.isInteger() != OpEltVT.isInteger()) + return false; + uint64_t Imm = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue(); + Imm = (Imm * OpEltVT.getSizeInBits()) / EltSize; + SDValue Op0 = DAG.getBitcast(VT, Op.getOperand(0)); + DCI.AddToWorklist(Op0.getNode()); + // Op1 needs to be bitcasted to a smaller vector with the same element type. + SDValue Op1 = Op.getOperand(1); + MVT Op1VT = MVT::getVectorVT(EltVT, + Op1.getSimpleValueType().getSizeInBits() / EltSize); + Op1 = DAG.getBitcast(Op1VT, Op1); + DCI.AddToWorklist(Op1.getNode()); + DCI.CombineTo(OrigOp.getNode(), + DAG.getNode(Opcode, DL, VT, Op0, Op1, + DAG.getConstant(Imm, DL, MVT::i8))); + return true; + } + case ISD::EXTRACT_SUBVECTOR: { + unsigned EltSize = EltVT.getSizeInBits(); + if (EltSize != 32 && EltSize != 64) + return false; + MVT OpEltVT = Op.getSimpleValueType().getVectorElementType(); + // Only change element size, not type. + if (VT.isInteger() != OpEltVT.isInteger()) + return false; + uint64_t Imm = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue(); + Imm = (Imm * OpEltVT.getSizeInBits()) / EltSize; + // Op0 needs to be bitcasted to a larger vector with the same element type. + SDValue Op0 = Op.getOperand(0); + MVT Op0VT = MVT::getVectorVT(EltVT, + Op0.getSimpleValueType().getSizeInBits() / EltSize); + Op0 = DAG.getBitcast(Op0VT, Op0); + DCI.AddToWorklist(Op0.getNode()); + DCI.CombineTo(OrigOp.getNode(), + DAG.getNode(Opcode, DL, VT, Op0, + DAG.getConstant(Imm, DL, MVT::i8))); + return true; + } + } + + return false; +} + /// Do target-specific dag combines on SELECT and VSELECT nodes. static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, @@ -26477,6 +29102,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, SDValue LHS = N->getOperand(1); SDValue RHS = N->getOperand(2); EVT VT = LHS.getValueType(); + EVT CondVT = Cond.getValueType(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // If we have SSE[12] support, try to form min/max nodes. SSE min/max @@ -26625,117 +29251,24 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS); } - EVT CondVT = Cond.getValueType(); - if (Subtarget.hasAVX512() && VT.isVector() && CondVT.isVector() && - CondVT.getVectorElementType() == MVT::i1) { - // v16i8 (select v16i1, v16i8, v16i8) does not have a proper - // lowering on KNL. In this case we convert it to - // v16i8 (select v16i8, v16i8, v16i8) and use AVX instruction. - // The same situation for all 128 and 256-bit vectors of i8 and i16. - // Since SKX these selects have a proper lowering. - EVT OpVT = LHS.getValueType(); - if ((OpVT.is128BitVector() || OpVT.is256BitVector()) && - (OpVT.getVectorElementType() == MVT::i8 || - OpVT.getVectorElementType() == MVT::i16) && - !(Subtarget.hasBWI() && Subtarget.hasVLX())) { - Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, OpVT, Cond); - DCI.AddToWorklist(Cond.getNode()); - return DAG.getNode(N->getOpcode(), DL, OpVT, Cond, LHS, RHS); - } + // v16i8 (select v16i1, v16i8, v16i8) does not have a proper + // lowering on KNL. 
In this case we convert it to + // v16i8 (select v16i8, v16i8, v16i8) and use AVX instruction. + // The same situation for all 128 and 256-bit vectors of i8 and i16. + // Since SKX these selects have a proper lowering. + if (Subtarget.hasAVX512() && CondVT.isVector() && + CondVT.getVectorElementType() == MVT::i1 && + (VT.is128BitVector() || VT.is256BitVector()) && + (VT.getVectorElementType() == MVT::i8 || + VT.getVectorElementType() == MVT::i16) && + !(Subtarget.hasBWI() && Subtarget.hasVLX())) { + Cond = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Cond); + DCI.AddToWorklist(Cond.getNode()); + return DAG.getNode(N->getOpcode(), DL, VT, Cond, LHS, RHS); } - // If this is a select between two integer constants, try to do some - // optimizations. - if (ConstantSDNode *TrueC = dyn_cast<ConstantSDNode>(LHS)) { - if (ConstantSDNode *FalseC = dyn_cast<ConstantSDNode>(RHS)) - // Don't do this for crazy integer types. - if (DAG.getTargetLoweringInfo().isTypeLegal(LHS.getValueType())) { - // If this is efficiently invertible, canonicalize the LHSC/RHSC values - // so that TrueC (the true value) is larger than FalseC. - bool NeedsCondInvert = false; - - if (TrueC->getAPIntValue().ult(FalseC->getAPIntValue()) && - // Efficiently invertible. - (Cond.getOpcode() == ISD::SETCC || // setcc -> invertible. - (Cond.getOpcode() == ISD::XOR && // xor(X, C) -> invertible. - isa<ConstantSDNode>(Cond.getOperand(1))))) { - NeedsCondInvert = true; - std::swap(TrueC, FalseC); - } - - // Optimize C ? 8 : 0 -> zext(C) << 3. Likewise for any pow2/0. - if (FalseC->getAPIntValue() == 0 && - TrueC->getAPIntValue().isPowerOf2()) { - if (NeedsCondInvert) // Invert the condition if needed. - Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, - DAG.getConstant(1, DL, Cond.getValueType())); - - // Zero extend the condition if needed. - Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, LHS.getValueType(), Cond); - - unsigned ShAmt = TrueC->getAPIntValue().logBase2(); - return DAG.getNode(ISD::SHL, DL, LHS.getValueType(), Cond, - DAG.getConstant(ShAmt, DL, MVT::i8)); - } - - // Optimize Cond ? cst+1 : cst -> zext(setcc(C)+cst. - if (FalseC->getAPIntValue()+1 == TrueC->getAPIntValue()) { - if (NeedsCondInvert) // Invert the condition if needed. - Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, - DAG.getConstant(1, DL, Cond.getValueType())); - - // Zero extend the condition if needed. - Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, - FalseC->getValueType(0), Cond); - return DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond, - SDValue(FalseC, 0)); - } - // Optimize cases that will turn into an LEA instruction. This requires - // an i32 or i64 and an efficient multiplier (1, 2, 3, 4, 5, 8, 9). - if (N->getValueType(0) == MVT::i32 || N->getValueType(0) == MVT::i64) { - uint64_t Diff = TrueC->getZExtValue()-FalseC->getZExtValue(); - if (N->getValueType(0) == MVT::i32) Diff = (unsigned)Diff; - - bool isFastMultiplier = false; - if (Diff < 10) { - switch ((unsigned char)Diff) { - default: break; - case 1: // result = add base, cond - case 2: // result = lea base( , cond*2) - case 3: // result = lea base(cond, cond*2) - case 4: // result = lea base( , cond*4) - case 5: // result = lea base(cond, cond*4) - case 8: // result = lea base( , cond*8) - case 9: // result = lea base(cond, cond*8) - isFastMultiplier = true; - break; - } - } - - if (isFastMultiplier) { - APInt Diff = TrueC->getAPIntValue()-FalseC->getAPIntValue(); - if (NeedsCondInvert) // Invert the condition if needed. 
- Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, - DAG.getConstant(1, DL, Cond.getValueType())); - - // Zero extend the condition if needed. - Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, FalseC->getValueType(0), - Cond); - // Scale the condition by the difference. - if (Diff != 1) - Cond = DAG.getNode(ISD::MUL, DL, Cond.getValueType(), Cond, - DAG.getConstant(Diff, DL, - Cond.getValueType())); - - // Add the base if non-zero. - if (FalseC->getAPIntValue() != 0) - Cond = DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond, - SDValue(FalseC, 0)); - return Cond; - } - } - } - } + if (SDValue V = combineSelectOfTwoConstants(N, DAG)) + return V; // Canonicalize max and min: // (x > y) ? x : y -> (x >= y) ? x : y @@ -26832,53 +29365,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, } } - // Simplify vector selection if condition value type matches vselect - // operand type - if (N->getOpcode() == ISD::VSELECT && CondVT == VT) { - assert(Cond.getValueType().isVector() && - "vector select expects a vector selector!"); - - bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode()); - bool FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode()); - - // Try invert the condition if true value is not all 1s and false value - // is not all 0s. - if (!TValIsAllOnes && !FValIsAllZeros && - // Check if the selector will be produced by CMPP*/PCMP* - Cond.getOpcode() == ISD::SETCC && - // Check if SETCC has already been promoted - TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) == - CondVT) { - bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode()); - bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode()); - - if (TValIsAllZeros || FValIsAllOnes) { - SDValue CC = Cond.getOperand(2); - ISD::CondCode NewCC = - ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(), - Cond.getOperand(0).getValueType().isInteger()); - Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1), NewCC); - std::swap(LHS, RHS); - TValIsAllOnes = FValIsAllOnes; - FValIsAllZeros = TValIsAllZeros; - } - } - - if (TValIsAllOnes || FValIsAllZeros) { - SDValue Ret; - - if (TValIsAllOnes && FValIsAllZeros) - Ret = Cond; - else if (TValIsAllOnes) - Ret = - DAG.getNode(ISD::OR, DL, CondVT, Cond, DAG.getBitcast(CondVT, RHS)); - else if (FValIsAllZeros) - Ret = DAG.getNode(ISD::AND, DL, CondVT, Cond, - DAG.getBitcast(CondVT, LHS)); - - return DAG.getBitcast(VT, Ret); - } - } + if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DCI, Subtarget)) + return V; // If this is a *dynamic* select (non-constant condition) and we can match // this node with one of the variable blend instructions, restructure the @@ -26887,7 +29375,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, if (N->getOpcode() == ISD::VSELECT && DCI.isBeforeLegalizeOps() && !DCI.isBeforeLegalize() && !ISD::isBuildVectorOfConstantSDNodes(Cond.getNode())) { - unsigned BitWidth = Cond.getValueType().getScalarSizeInBits(); + unsigned BitWidth = Cond.getScalarValueSizeInBits(); // Don't optimize vector selects that map to mask-registers. if (BitWidth == 1) @@ -26965,6 +29453,17 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, } } + // Look for vselects with LHS/RHS being bitcasted from an operation that + // can be executed on another type. Push the bitcast to the inputs of + // the operation. This exposes opportunities for using masking instructions. 
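When the vselect combine that follows pushes a bitcast through a shuffle-like op, combineBitcastForMaskedOp (added above) must rescale the immediate to the new element size and bail out when the shift is not a whole number of new elements. A sketch of just that immediate arithmetic, with hypothetical helper naming:

#include <cassert>
#include <optional>

// Rescale a rotation immediate from OpEltBits-sized elements to EltBits-sized
// elements, as in the VALIGN case above; empty when not representable.
static std::optional<unsigned> rescaleImm(unsigned Imm, unsigned OpEltBits,
                                          unsigned EltBits) {
  unsigned ShiftAmt = Imm * OpEltBits;
  if (ShiftAmt % EltBits != 0)
    return std::nullopt;
  return ShiftAmt / EltBits;
}

int main() {
  // palignr by 8 bytes on v16i8 is valignd by 2 on v4i32...
  assert(rescaleImm(8, 8, 32) == 2u);
  // ...but a 4-byte rotation has no whole-element v2i64 VALIGNQ equivalent.
  assert(!rescaleImm(4, 8, 64));
}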
+ if (N->getOpcode() == ISD::VSELECT && !DCI.isBeforeLegalizeOps() && + CondVT.getVectorElementType() == MVT::i1) { + if (combineBitcastForMaskedOp(LHS, DAG, DCI)) + return SDValue(N, 0); + if (combineBitcastForMaskedOp(RHS, DAG, DCI)) + return SDValue(N, 0); + } + return SDValue(); } @@ -26981,6 +29480,12 @@ static SDValue combineSetCCAtomicArith(SDValue Cmp, X86::CondCode &CC, (Cmp.getOpcode() == X86ISD::SUB && !Cmp->hasAnyUseOfValue(0)))) return SDValue(); + // Can't replace the cmp if it has more uses than the one we're looking at. + // FIXME: We would like to be able to handle this, but would need to make sure + // all uses were updated. + if (!Cmp.hasOneUse()) + return SDValue(); + // This only applies to variations of the common case: // (icmp slt x, 0) -> (icmp sle (add x, 1), 0) // (icmp sge x, 0) -> (icmp sgt (add x, 1), 0) @@ -27088,7 +29593,6 @@ static SDValue checkBoolTestSetCCCombine(SDValue Cmp, X86::CondCode &CC) { // Skip (zext $x), (trunc $x), or (and $x, 1) node. while (SetCC.getOpcode() == ISD::ZERO_EXTEND || SetCC.getOpcode() == ISD::TRUNCATE || - SetCC.getOpcode() == ISD::AssertZext || SetCC.getOpcode() == ISD::AND) { if (SetCC.getOpcode() == ISD::AND) { int OpIdx = -1; @@ -27114,7 +29618,7 @@ static SDValue checkBoolTestSetCCCombine(SDValue Cmp, X86::CondCode &CC) { break; assert(X86::CondCode(SetCC.getConstantOperandVal(0)) == X86::COND_B && "Invalid use of SETCC_CARRY!"); - // FALL THROUGH + LLVM_FALLTHROUGH; case X86ISD::SETCC: // Set the condition code or opposite one if necessary. CC = X86::CondCode(SetCC.getConstantOperandVal(0)); @@ -27187,7 +29691,7 @@ static bool checkBoolTestAndOrSetCCCombine(SDValue Cond, X86::CondCode &CC0, case ISD::AND: case X86ISD::AND: isAnd = true; - // fallthru + LLVM_FALLTHROUGH; case ISD::OR: case X86ISD::OR: SetCC0 = Cond->getOperand(0); @@ -27270,8 +29774,7 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, // This is efficient for any integer data type (including i8/i16) and // shift amount. if (FalseC->getAPIntValue() == 0 && TrueC->getAPIntValue().isPowerOf2()) { - Cond = DAG.getNode(X86ISD::SETCC, DL, MVT::i8, - DAG.getConstant(CC, DL, MVT::i8), Cond); + Cond = getSETCC(CC, Cond, DL, DAG); // Zero extend the condition if needed. Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, TrueC->getValueType(0), Cond); @@ -27287,8 +29790,7 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, // Optimize Cond ? cst+1 : cst -> zext(setcc(C)+cst. This is efficient // for any integer data type, including i8/i16. if (FalseC->getAPIntValue()+1 == TrueC->getAPIntValue()) { - Cond = DAG.getNode(X86ISD::SETCC, DL, MVT::i8, - DAG.getConstant(CC, DL, MVT::i8), Cond); + Cond = getSETCC(CC, Cond, DL, DAG); // Zero extend the condition if needed. Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, @@ -27325,8 +29827,7 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, if (isFastMultiplier) { APInt Diff = TrueC->getAPIntValue()-FalseC->getAPIntValue(); - Cond = DAG.getNode(X86ISD::SETCC, DL, MVT::i8, - DAG.getConstant(CC, DL, MVT::i8), Cond); + Cond = getSETCC(CC, Cond, DL ,DAG); // Zero extend the condition if needed. Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, FalseC->getValueType(0), Cond); @@ -27525,10 +30026,17 @@ static bool canReduceVMulWidth(SDNode *N, SelectionDAG &DAG, ShrinkMode &Mode) { /// generate pmullw+pmulhuw for it (MULU16 mode). static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { - // pmulld is supported since SSE41. It is better to use pmulld - // instead of pmullw+pmulhw. 
+ // Check for legality // pmullw/pmulhw are not supported by SSE. - if (Subtarget.hasSSE41() || !Subtarget.hasSSE2()) + if (!Subtarget.hasSSE2()) + return SDValue(); + + // Check for profitability + // pmulld is supported since SSE41. It is better to use pmulld + // instead of pmullw+pmulhw, except for subtargets where pmulld is slower than + // the expansion. + bool OptForMinSize = DAG.getMachineFunction().getFunction()->optForMinSize(); + if (Subtarget.hasSSE41() && (OptForMinSize || !Subtarget.isPMULLDSlow())) return SDValue(); ShrinkMode Mode; @@ -27591,7 +30099,12 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, // <4 x i16> undef). // // Legalize the operands of mul. - SmallVector<SDValue, 16> Ops(RegSize / ReducedVT.getSizeInBits(), + // FIXME: We may be able to handle non-concatenated vectors by insertion. + unsigned ReducedSizeInBits = ReducedVT.getSizeInBits(); + if ((RegSize % ReducedSizeInBits) != 0) + return SDValue(); + + SmallVector<SDValue, 16> Ops(RegSize / ReducedSizeInBits, DAG.getUNDEF(ReducedVT)); Ops[0] = NewN0; NewN0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops); @@ -27851,7 +30364,7 @@ static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG, if (auto *AmtSplat = AmtBV->getConstantSplatNode()) { const APInt &ShiftAmt = AmtSplat->getAPIntValue(); unsigned MaxAmount = - VT.getSimpleVT().getVectorElementType().getSizeInBits(); + VT.getSimpleVT().getScalarSizeInBits(); // SSE2/AVX2 logical shifts always return a vector of 0s // if the shift amount is bigger than or equal to @@ -27883,6 +30396,45 @@ static SDValue combineShift(SDNode* N, SelectionDAG &DAG, return SDValue(); } +static SDValue combineVectorShift(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + assert((X86ISD::VSHLI == N->getOpcode() || X86ISD::VSRLI == N->getOpcode()) && + "Unexpected opcode"); + EVT VT = N->getValueType(0); + unsigned NumBitsPerElt = VT.getScalarSizeInBits(); + + // This fails for mask register (vXi1) shifts. + if ((NumBitsPerElt % 8) != 0) + return SDValue(); + + // Out of range logical bit shifts are guaranteed to be zero. + APInt ShiftVal = cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue(); + if (ShiftVal.zextOrTrunc(8).uge(NumBitsPerElt)) + return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, SDLoc(N)); + + // Shift N0 by zero -> N0. + if (!ShiftVal) + return N->getOperand(0); + + // Shift zero -> zero. + if (ISD::isBuildVectorAllZeros(N->getOperand(0).getNode())) + return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, SDLoc(N)); + + // We can decode 'whole byte' logical bit shifts as shuffles. + if ((ShiftVal.getZExtValue() % 8) == 0) { + SDValue Op(N, 0); + SmallVector<int, 1> NonceMask; // Just a placeholder. + NonceMask.push_back(0); + if (combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, + /*Depth*/ 1, /*HasVarMask*/ false, DAG, + DCI, Subtarget)) + return SDValue(); // This routine will use CombineTo to replace N. + } + + return SDValue(); +} + /// Recognize the distinctive (AND (setcc ...) (setcc ..)) where both setccs /// reference the same FP CMP, and rewrite for CMPEQSS and friends. Likewise for /// OR -> CMPNEQSS. @@ -27943,7 +30495,7 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG, // See X86ATTInstPrinter.cpp:printSSECC(). unsigned x86cc = (cc0 == X86::COND_E) ? 
0 : 4; if (Subtarget.hasAVX512()) { - SDValue FSetCC = DAG.getNode(X86ISD::FSETCC, DL, MVT::i1, CMP00, + SDValue FSetCC = DAG.getNode(X86ISD::FSETCCM, DL, MVT::i1, CMP00, CMP01, DAG.getConstant(x86cc, DL, MVT::i8)); if (N->getValueType(0) != MVT::i1) @@ -27995,9 +30547,7 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { SDValue N1 = N->getOperand(1); SDLoc DL(N); - if (VT != MVT::v2i64 && VT != MVT::v4i64 && - VT != MVT::v8i64 && VT != MVT::v16i32 && - VT != MVT::v4i32 && VT != MVT::v8i32) // Legal with VLX + if (VT != MVT::v2i64 && VT != MVT::v4i64 && VT != MVT::v8i64) return SDValue(); // Canonicalize XOR to the left. @@ -28111,95 +30661,6 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, } } -static SDValue combineVectorZext(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDLoc DL(N); - - // A vector zext_in_reg may be represented as a shuffle, - // feeding into a bitcast (this represents anyext) feeding into - // an and with a mask. - // We'd like to try to combine that into a shuffle with zero - // plus a bitcast, removing the and. - if (N0.getOpcode() != ISD::BITCAST || - N0.getOperand(0).getOpcode() != ISD::VECTOR_SHUFFLE) - return SDValue(); - - // The other side of the AND should be a splat of 2^C, where C - // is the number of bits in the source type. - N1 = peekThroughBitcasts(N1); - if (N1.getOpcode() != ISD::BUILD_VECTOR) - return SDValue(); - BuildVectorSDNode *Vector = cast<BuildVectorSDNode>(N1); - - ShuffleVectorSDNode *Shuffle = cast<ShuffleVectorSDNode>(N0.getOperand(0)); - EVT SrcType = Shuffle->getValueType(0); - - // We expect a single-source shuffle - if (!Shuffle->getOperand(1)->isUndef()) - return SDValue(); - - unsigned SrcSize = SrcType.getScalarSizeInBits(); - unsigned NumElems = SrcType.getVectorNumElements(); - - APInt SplatValue, SplatUndef; - unsigned SplatBitSize; - bool HasAnyUndefs; - if (!Vector->isConstantSplat(SplatValue, SplatUndef, - SplatBitSize, HasAnyUndefs)) - return SDValue(); - - unsigned ResSize = N1.getValueType().getScalarSizeInBits(); - // Make sure the splat matches the mask we expect - if (SplatBitSize > ResSize || - (SplatValue + 1).exactLogBase2() != (int)SrcSize) - return SDValue(); - - // Make sure the input and output size make sense - if (SrcSize >= ResSize || ResSize % SrcSize) - return SDValue(); - - // We expect a shuffle of the form <0, u, u, u, 1, u, u, u...> - // The number of u's between each two values depends on the ratio between - // the source and dest type. - unsigned ZextRatio = ResSize / SrcSize; - bool IsZext = true; - for (unsigned i = 0; i != NumElems; ++i) { - if (i % ZextRatio) { - if (Shuffle->getMaskElt(i) > 0) { - // Expected undef - IsZext = false; - break; - } - } else { - if (Shuffle->getMaskElt(i) != (int)(i / ZextRatio)) { - // Expected element number - IsZext = false; - break; - } - } - } - - if (!IsZext) - return SDValue(); - - // Ok, perform the transformation - replace the shuffle with - // a shuffle of the form <0, k, k, k, 1, k, k, k> with zero - // (instead of undef) where the k elements come from the zero vector. 
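combineVectorShift (added in the hunk above) decodes whole-byte logical shifts as shuffles because shifting each lane by a multiple of 8 bits just relocates bytes and fills with zeros, which the shuffle combiners can then merge further. A standalone sketch for psrlq $16 on v2i64, assuming x86's little-endian byte order:

#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  std::array<uint64_t, 2> V{0x1122334455667788ull, 0x99aabbccddeeff00ull};

  // The shift itself: psrlq $16 on v2i64.
  std::array<uint64_t, 2> Shifted{V[0] >> 16, V[1] >> 16};

  // The same operation as a byte shuffle: within each little-endian 8-byte
  // lane, byte I takes byte I + 2 and the top two bytes become zero.
  std::array<uint8_t, 16> Src, Dst{};
  std::memcpy(Src.data(), V.data(), 16);
  for (int Lane = 0; Lane != 2; ++Lane)
    for (int I = 0; I != 6; ++I)
      Dst[Lane * 8 + I] = Src[Lane * 8 + I + 2];

  std::array<uint64_t, 2> AsShuffle;
  std::memcpy(AsShuffle.data(), Dst.data(), 16);
  assert(AsShuffle == Shifted);
}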
- SmallVector<int, 8> Mask; - for (unsigned i = 0; i != NumElems; ++i) - if (i % ZextRatio) - Mask.push_back(NumElems); - else - Mask.push_back(i / ZextRatio); - - SDValue NewShuffle = DAG.getVectorShuffle(Shuffle->getValueType(0), DL, - Shuffle->getOperand(0), DAG.getConstant(0, DL, SrcType), Mask); - return DAG.getBitcast(N0.getValueType(), NewShuffle); -} - /// If both input operands of a logic op are being cast from floating point /// types, try to convert this into a floating point logic node to avoid /// unnecessary moves from SSE to integer registers. @@ -28255,7 +30716,7 @@ static SDValue combinePCMPAnd1(SDNode *N, SelectionDAG &DAG) { // masked compare nodes, so they should not make it here. EVT VT0 = Op0.getValueType(); EVT VT1 = Op1.getValueType(); - unsigned EltBitWidth = VT0.getScalarType().getSizeInBits(); + unsigned EltBitWidth = VT0.getScalarSizeInBits(); if (VT0 != VT1 || EltBitWidth == 8) return SDValue(); @@ -28277,9 +30738,6 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, if (DCI.isBeforeLegalizeOps()) return SDValue(); - if (SDValue Zext = combineVectorZext(N, DAG, DCI, Subtarget)) - return Zext; - if (SDValue R = combineCompareEqual(N, DAG, DCI, Subtarget)) return R; @@ -28297,6 +30755,17 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, SDValue N1 = N->getOperand(1); SDLoc DL(N); + // Attempt to recursively combine a bitmask AND with shuffles. + if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) { + SDValue Op(N, 0); + SmallVector<int, 1> NonceMask; // Just a placeholder. + NonceMask.push_back(0); + if (combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, + /*Depth*/ 1, /*HasVarMask*/ false, DAG, + DCI, Subtarget)) + return SDValue(); // This routine will use CombineTo to replace N. + } + // Create BEXTR instructions // BEXTR is ((X >> imm) & (2**size-1)) if (VT != MVT::i32 && VT != MVT::i64) @@ -28372,7 +30841,7 @@ static SDValue combineLogicBlendIntoPBLENDV(SDNode *N, SelectionDAG &DAG, // Validate that the Mask operand is a vector sra node. // FIXME: what to do for bytes, since there is a psignb/pblendvb, but // there is no psrai.b - unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits(); + unsigned EltBits = MaskVT.getScalarSizeInBits(); unsigned SraAmt = ~0; if (Mask.getOpcode() == ISD::SRA) { if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Mask.getOperand(1))) @@ -28450,6 +30919,114 @@ static SDValue combineLogicBlendIntoPBLENDV(SDNode *N, SelectionDAG &DAG, return DAG.getBitcast(VT, Mask); } +// Helper function for combineOrCmpEqZeroToCtlzSrl +// Transforms: +// seteq(cmp x, 0) +// into: +// srl(ctlz x), log2(bitsize(x)) +// Input pattern is checked by caller. +static SDValue lowerX86CmpEqZeroToCtlzSrl(SDValue Op, EVT ExtTy, + SelectionDAG &DAG) { + SDValue Cmp = Op.getOperand(1); + EVT VT = Cmp.getOperand(0).getValueType(); + unsigned Log2b = Log2_32(VT.getSizeInBits()); + SDLoc dl(Op); + SDValue Clz = DAG.getNode(ISD::CTLZ, dl, VT, Cmp->getOperand(0)); + // The result of the shift is true or false, and on X86, the 32-bit + // encoding of shr and lzcnt is more desirable. 
+ SDValue Trunc = DAG.getZExtOrTrunc(Clz, dl, MVT::i32); + SDValue Scc = DAG.getNode(ISD::SRL, dl, MVT::i32, Trunc, + DAG.getConstant(Log2b, dl, VT)); + return DAG.getZExtOrTrunc(Scc, dl, ExtTy); +} + +// Try to transform: +// zext(or(setcc(eq, (cmp x, 0)), setcc(eq, (cmp y, 0)))) +// into: +// srl(or(ctlz(x), ctlz(y)), log2(bitsize(x)) +// Will also attempt to match more generic cases, eg: +// zext(or(or(setcc(eq, cmp 0), setcc(eq, cmp 0)), setcc(eq, cmp 0))) +// Only applies if the target supports the FastLZCNT feature. +static SDValue combineOrCmpEqZeroToCtlzSrl(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + if (DCI.isBeforeLegalize() || !Subtarget.getTargetLowering()->isCtlzFast()) + return SDValue(); + + auto isORCandidate = [](SDValue N) { + return (N->getOpcode() == ISD::OR && N->hasOneUse()); + }; + + // Check the zero extend is extending to 32-bit or more. The code generated by + // srl(ctlz) for 16-bit or less variants of the pattern would require extra + // instructions to clear the upper bits. + if (!N->hasOneUse() || !N->getSimpleValueType(0).bitsGE(MVT::i32) || + !isORCandidate(N->getOperand(0))) + return SDValue(); + + // Check the node matches: setcc(eq, cmp 0) + auto isSetCCCandidate = [](SDValue N) { + return N->getOpcode() == X86ISD::SETCC && N->hasOneUse() && + X86::CondCode(N->getConstantOperandVal(0)) == X86::COND_E && + N->getOperand(1).getOpcode() == X86ISD::CMP && + N->getOperand(1).getConstantOperandVal(1) == 0 && + N->getOperand(1).getValueType().bitsGE(MVT::i32); + }; + + SDNode *OR = N->getOperand(0).getNode(); + SDValue LHS = OR->getOperand(0); + SDValue RHS = OR->getOperand(1); + + // Save nodes matching or(or, setcc(eq, cmp 0)). + SmallVector<SDNode *, 2> ORNodes; + while (((isORCandidate(LHS) && isSetCCCandidate(RHS)) || + (isORCandidate(RHS) && isSetCCCandidate(LHS)))) { + ORNodes.push_back(OR); + OR = (LHS->getOpcode() == ISD::OR) ? LHS.getNode() : RHS.getNode(); + LHS = OR->getOperand(0); + RHS = OR->getOperand(1); + } + + // The last OR node should match or(setcc(eq, cmp 0), setcc(eq, cmp 0)). + if (!(isSetCCCandidate(LHS) && isSetCCCandidate(RHS)) || + !isORCandidate(SDValue(OR, 0))) + return SDValue(); + + // We have a or(setcc(eq, cmp 0), setcc(eq, cmp 0)) pattern, try to lower it + // to + // or(srl(ctlz),srl(ctlz)). + // The dag combiner can then fold it into: + // srl(or(ctlz, ctlz)). + EVT VT = OR->getValueType(0); + SDValue NewLHS = lowerX86CmpEqZeroToCtlzSrl(LHS, VT, DAG); + SDValue Ret, NewRHS; + if (NewLHS && (NewRHS = lowerX86CmpEqZeroToCtlzSrl(RHS, VT, DAG))) + Ret = DAG.getNode(ISD::OR, SDLoc(OR), VT, NewLHS, NewRHS); + + if (!Ret) + return SDValue(); + + // Try to lower nodes matching the or(or, setcc(eq, cmp 0)) pattern. + while (ORNodes.size() > 0) { + OR = ORNodes.pop_back_val(); + LHS = OR->getOperand(0); + RHS = OR->getOperand(1); + // Swap rhs with lhs to match or(setcc(eq, cmp, 0), or). 
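The identity behind lowerX86CmpEqZeroToCtlzSrl: with lzcnt semantics, ctlz of a w-bit value is w exactly when the value is zero and at most w - 1 otherwise, so bit log2(w) of the count is a ready-made x == 0 flag. Sketch using C++20 std::countl_zero:

#include <bit>
#include <cassert>
#include <cstdint>

// (x == 0) as srl(ctlz(x), log2(32)): lzcnt yields 32 for zero and <= 31
// otherwise, so the shift produces exactly 1 or 0.
static uint32_t isZeroCtlz(uint32_t X) {
  return uint32_t(std::countl_zero(X)) >> 5;
}

int main() {
  assert(isZeroCtlz(0) == 1);
  assert(isZeroCtlz(1) == 0);
  assert(isZeroCtlz(0x80000000u) == 0);

  // The OR chains fold as described above: bit 5 of a count is set iff the
  // input was zero, so ORing raw counts and shifting once gives x==0 || y==0.
  uint32_t X = 0, Y = 7;
  assert(((uint32_t(std::countl_zero(X)) |
           uint32_t(std::countl_zero(Y))) >> 5) == 1);
}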
+ if (RHS->getOpcode() == ISD::OR) + std::swap(LHS, RHS); + EVT VT = OR->getValueType(0); + SDValue NewRHS = lowerX86CmpEqZeroToCtlzSrl(RHS, VT, DAG); + if (!NewRHS) + return SDValue(); + Ret = DAG.getNode(ISD::OR, SDLoc(OR), VT, Ret, NewRHS); + } + + if (Ret) + Ret = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), Ret); + + return Ret; +} + static SDValue combineOr(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -28505,18 +31082,23 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, unsigned Opc = X86ISD::SHLD; SDValue Op0 = N0.getOperand(0); SDValue Op1 = N1.getOperand(0); - if (ShAmt0.getOpcode() == ISD::SUB) { + if (ShAmt0.getOpcode() == ISD::SUB || + ShAmt0.getOpcode() == ISD::XOR) { Opc = X86ISD::SHRD; std::swap(Op0, Op1); std::swap(ShAmt0, ShAmt1); } + // OR( SHL( X, C ), SRL( Y, 32 - C ) ) -> SHLD( X, Y, C ) + // OR( SRL( X, C ), SHL( Y, 32 - C ) ) -> SHRD( X, Y, C ) + // OR( SHL( X, C ), SRL( SRL( Y, 1 ), XOR( C, 31 ) ) ) -> SHLD( X, Y, C ) + // OR( SRL( X, C ), SHL( SHL( Y, 1 ), XOR( C, 31 ) ) ) -> SHRD( X, Y, C ) unsigned Bits = VT.getSizeInBits(); if (ShAmt1.getOpcode() == ISD::SUB) { SDValue Sum = ShAmt1.getOperand(0); if (ConstantSDNode *SumC = dyn_cast<ConstantSDNode>(Sum)) { SDValue ShAmt1Op1 = ShAmt1.getOperand(1); - if (ShAmt1Op1.getNode()->getOpcode() == ISD::TRUNCATE) + if (ShAmt1Op1.getOpcode() == ISD::TRUNCATE) ShAmt1Op1 = ShAmt1Op1.getOperand(0); if (SumC->getSExtValue() == Bits && ShAmt1Op1 == ShAmt0) return DAG.getNode(Opc, DL, VT, @@ -28526,18 +31108,39 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, } } else if (ConstantSDNode *ShAmt1C = dyn_cast<ConstantSDNode>(ShAmt1)) { ConstantSDNode *ShAmt0C = dyn_cast<ConstantSDNode>(ShAmt0); - if (ShAmt0C && - ShAmt0C->getSExtValue() + ShAmt1C->getSExtValue() == Bits) + if (ShAmt0C && (ShAmt0C->getSExtValue() + ShAmt1C->getSExtValue()) == Bits) return DAG.getNode(Opc, DL, VT, N0.getOperand(0), N1.getOperand(0), DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); + } else if (ShAmt1.getOpcode() == ISD::XOR) { + SDValue Mask = ShAmt1.getOperand(1); + if (ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask)) { + unsigned InnerShift = (X86ISD::SHLD == Opc ? ISD::SRL : ISD::SHL); + SDValue ShAmt1Op0 = ShAmt1.getOperand(0); + if (ShAmt1Op0.getOpcode() == ISD::TRUNCATE) + ShAmt1Op0 = ShAmt1Op0.getOperand(0); + if (MaskC->getSExtValue() == (Bits - 1) && ShAmt1Op0 == ShAmt0) { + if (Op1.getOpcode() == InnerShift && + isa<ConstantSDNode>(Op1.getOperand(1)) && + Op1.getConstantOperandVal(1) == 1) { + return DAG.getNode(Opc, DL, VT, Op0, Op1.getOperand(0), + DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); + } + // Test for ADD( Y, Y ) as an equivalent to SHL( Y, 1 ). + if (InnerShift == ISD::SHL && Op1.getOpcode() == ISD::ADD && + Op1.getOperand(0) == Op1.getOperand(1)) { + return DAG.getNode(Opc, DL, VT, Op0, Op1.getOperand(0), + DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); + } + } + } } return SDValue(); } -// Generate NEG and CMOV for integer abs. +/// Generate NEG and CMOV for integer abs. static SDValue combineIntegerAbs(SDNode *N, SelectionDAG &DAG) { EVT VT = N->getValueType(0); @@ -28553,21 +31156,19 @@ static SDValue combineIntegerAbs(SDNode *N, SelectionDAG &DAG) { // Check pattern of XOR(ADD(X,Y), Y) where Y is SRA(X, size(X)-1) // and change it to SUB and CMOV. 
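The XOR(ADD(X, Y), Y) with Y = SRA(X, size(X)-1) shape matched here is the classic branch-free abs: Y is all-ones when X is negative, so the add-then-xor negates, and it is a no-op otherwise. A sketch of both the matched input form and the SUB + CMOV form the combine emits (arithmetic right shift on signed values is assumed, as on x86):

#include <cassert>
#include <cstdint>

// The matched pattern: Y = X >> 31 is the sign mask, and (X + Y) ^ Y
// equals -X when X is negative, X otherwise (wrapping at INT32_MIN).
static int32_t absXorAddSra(int32_t X) {
  int32_t Y = X >> 31;   // SRA(X, size(X) - 1)
  return (X + Y) ^ Y;
}

int main() {
  assert(absXorAddSra(5) == 5);
  assert(absXorAddSra(-5) == 5);
  assert(absXorAddSra(0) == 0);

  // The replacement: Neg = 0 - X, then CMOV on the GE flag from the SUB.
  int32_t X = -5, Neg = 0 - X;
  assert((X >= 0 ? X : Neg) == absXorAddSra(X));
}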
if (VT.isInteger() && N->getOpcode() == ISD::XOR && - N0.getOpcode() == ISD::ADD && - N0.getOperand(1) == N1 && - N1.getOpcode() == ISD::SRA && - N1.getOperand(0) == N0.getOperand(0)) - if (ConstantSDNode *Y1C = dyn_cast<ConstantSDNode>(N1.getOperand(1))) - if (Y1C->getAPIntValue() == VT.getSizeInBits()-1) { - // Generate SUB & CMOV. - SDValue Neg = DAG.getNode(X86ISD::SUB, DL, DAG.getVTList(VT, MVT::i32), - DAG.getConstant(0, DL, VT), N0.getOperand(0)); - - SDValue Ops[] = { N0.getOperand(0), Neg, - DAG.getConstant(X86::COND_GE, DL, MVT::i8), - SDValue(Neg.getNode(), 1) }; - return DAG.getNode(X86ISD::CMOV, DL, DAG.getVTList(VT, MVT::Glue), Ops); - } + N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1 && + N1.getOpcode() == ISD::SRA && N1.getOperand(0) == N0.getOperand(0)) { + auto *Y1C = dyn_cast<ConstantSDNode>(N1.getOperand(1)); + if (Y1C && Y1C->getAPIntValue() == VT.getSizeInBits() - 1) { + // Generate SUB & CMOV. + SDValue Neg = DAG.getNode(X86ISD::SUB, DL, DAG.getVTList(VT, MVT::i32), + DAG.getConstant(0, DL, VT), N0.getOperand(0)); + SDValue Ops[] = {N0.getOperand(0), Neg, + DAG.getConstant(X86::COND_GE, DL, MVT::i8), + SDValue(Neg.getNode(), 1)}; + return DAG.getNode(X86ISD::CMOV, DL, DAG.getVTList(VT, MVT::Glue), Ops); + } + } return SDValue(); } @@ -28671,28 +31272,6 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG, return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones); } -static SDValue combineXor(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { - if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget)) - return Cmp; - - if (DCI.isBeforeLegalizeOps()) - return SDValue(); - - if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG)) - return RV; - - if (Subtarget.hasCMov()) - if (SDValue RV = combineIntegerAbs(N, DAG)) - return RV; - - if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget)) - return FPLogic; - - return SDValue(); -} - /// This function detects the AVG pattern between vectors of unsigned i8/i16, /// which is c = (a + b + 1) / 2, and replace this operation with the efficient /// X86ISD::AVG instruction. @@ -28717,7 +31296,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - if (Subtarget.hasAVX512()) { + if (Subtarget.hasBWI()) { if (VT.getSizeInBits() > 512) return SDValue(); } else if (Subtarget.hasAVX2()) { @@ -28999,6 +31578,11 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { MaskedLoadSDNode *Mld = cast<MaskedLoadSDNode>(N); + + // TODO: Expanding load with constant mask may be optimized as well. + if (Mld->isExpandingLoad()) + return SDValue(); + if (Mld->getExtensionType() == ISD::NON_EXTLOAD) { if (SDValue ScalarLoad = reduceMaskedLoadToScalarLoad(Mld, DAG, DCI)) return ScalarLoad; @@ -29018,8 +31602,8 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG, SDLoc dl(Mld); assert(LdVT != VT && "Cannot extend to the same type"); - unsigned ToSz = VT.getVectorElementType().getSizeInBits(); - unsigned FromSz = LdVT.getVectorElementType().getSizeInBits(); + unsigned ToSz = VT.getScalarSizeInBits(); + unsigned FromSz = LdVT.getScalarSizeInBits(); // From/To sizes and ElemCount must be pow of two. 
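// Editorial example (not from the patch): for an extending masked load
// v8i16 -> v8i32, NumElems = 8, FromSz = 16 and ToSz = 32, so the product
// checked below is 8 * 16 * 32 = 4096 = 2^12.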
 assert (isPowerOf2_32(NumElems * FromSz * ToSz) &&
 "Unexpected size for extending masked load");
@@ -29114,6 +31698,10 @@ static SDValue reduceMaskedStoreToScalarStore(MaskedStoreSDNode *MS,
 static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
 const X86Subtarget &Subtarget) {
 MaskedStoreSDNode *Mst = cast<MaskedStoreSDNode>(N);
+
+ if (Mst->isCompressingStore())
+ return SDValue();
+
 if (!Mst->isTruncatingStore())
 return reduceMaskedStoreToScalarStore(Mst, DAG);
 
@@ -29124,8 +31712,8 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
 SDLoc dl(Mst);
 
 assert(StVT != VT && "Cannot truncate to the same type");
- unsigned FromSz = VT.getVectorElementType().getSizeInBits();
- unsigned ToSz = StVT.getVectorElementType().getSizeInBits();
+ unsigned FromSz = VT.getScalarSizeInBits();
+ unsigned ToSz = StVT.getScalarSizeInBits();
 
 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
@@ -29253,8 +31841,8 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 unsigned NumElems = VT.getVectorNumElements();
 assert(StVT != VT && "Cannot truncate to the same type");
- unsigned FromSz = VT.getVectorElementType().getSizeInBits();
- unsigned ToSz = StVT.getVectorElementType().getSizeInBits();
+ unsigned FromSz = VT.getScalarSizeInBits();
+ unsigned ToSz = StVT.getScalarSizeInBits();
 
 // The truncating store is legal in some cases. For example
 // vpmovqb, vpmovqw, vpmovqd, vpmovdb, vpmovdw
@@ -29596,6 +32184,83 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
 return SDValue();
 }
 
+/// Attempt to pre-truncate inputs to arithmetic ops if it will simplify
+/// the codegen.
+/// e.g. TRUNC( BINOP( X, Y ) ) --> BINOP( TRUNC( X ), TRUNC( Y ) )
+static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget,
+ SDLoc &DL) {
+ assert(N->getOpcode() == ISD::TRUNCATE && "Wrong opcode");
+ SDValue Src = N->getOperand(0);
+ unsigned Opcode = Src.getOpcode();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+ EVT VT = N->getValueType(0);
+ EVT SrcVT = Src.getValueType();
+
+ auto IsRepeatedOpOrOneUseConstant = [](SDValue Op0, SDValue Op1) {
+ // TODO: Add extra cases where we can truncate both inputs for the
+ // cost of one (or none).
+ // e.g. TRUNC( BINOP( EXT( X ), EXT( Y ) ) ) --> BINOP( X, Y )
+ if (Op0 == Op1)
+ return true;
+
+ SDValue BC0 = peekThroughOneUseBitcasts(Op0);
+ SDValue BC1 = peekThroughOneUseBitcasts(Op1);
+ return ISD::isBuildVectorOfConstantSDNodes(BC0.getNode()) ||
+ ISD::isBuildVectorOfConstantSDNodes(BC1.getNode());
+ };
+
+ auto TruncateArithmetic = [&](SDValue N0, SDValue N1) {
+ SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, VT, N0);
+ SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
+ return DAG.getNode(Opcode, DL, VT, Trunc0, Trunc1);
+ };
+
+ // Don't combine if the operation has other uses.
+ if (!N->isOnlyUserOf(Src.getNode()))
+ return SDValue();
+
+ // Only support vector truncation for now.
+ // TODO: i64 scalar math would benefit as well.
+ if (!VT.isVector())
+ return SDValue();
+
+ // In most cases it's only worth pre-truncating if we're only facing the cost
+ // of one truncation.
+ // i.e. if one of the inputs will constant fold or the input is repeated.
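// Editorial illustration (assumes a target where the narrow op is legal,
// e.g. v4i32 AND): trunc(v4i64 and(x, build_vector(c))) becomes
// v4i32 and(trunc(x), trunc(c)); the truncate of the constant operand
// folds away, so only the single truncate of x is actually emitted.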
+ switch (Opcode) {
+ case ISD::AND:
+ case ISD::XOR:
+ case ISD::OR: {
+ SDValue Op0 = Src.getOperand(0);
+ SDValue Op1 = Src.getOperand(1);
+ if (TLI.isOperationLegalOrPromote(Opcode, VT) &&
+ IsRepeatedOpOrOneUseConstant(Op0, Op1))
+ return TruncateArithmetic(Op0, Op1);
+ break;
+ }
+
+ case ISD::MUL:
+ // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) -
+ // it's better to truncate if we have the chance.
+ if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) &&
+ !TLI.isOperationLegal(Opcode, SrcVT))
+ return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1));
+ LLVM_FALLTHROUGH;
+ case ISD::ADD: {
+ SDValue Op0 = Src.getOperand(0);
+ SDValue Op1 = Src.getOperand(1);
+ if (TLI.isOperationLegal(Opcode, VT) &&
+ IsRepeatedOpOrOneUseConstant(Op0, Op1))
+ return TruncateArithmetic(Op0, Op1);
+ break;
+ }
+ }
+
+ return SDValue();
+}
+
 /// Truncate a group of v4i32 into v16i8/v8i16 using X86ISD::PACKUS.
 static SDValue
 combineVectorTruncationWithPACKUS(SDNode *N, SelectionDAG &DAG,
@@ -29653,7 +32318,8 @@ combineVectorTruncationWithPACKUS(SDNode *N, SelectionDAG &DAG,
 
 /// Truncate a group of v4i32 into v8i16 using X86ISD::PACKSS.
 static SDValue
-combineVectorTruncationWithPACKSS(SDNode *N, SelectionDAG &DAG,
+combineVectorTruncationWithPACKSS(SDNode *N, const X86Subtarget &Subtarget,
+ SelectionDAG &DAG,
 SmallVector<SDValue, 8> &Regs) {
 assert(Regs.size() > 0 && Regs[0].getValueType() == MVT::v4i32);
 EVT OutVT = N->getValueType(0);
@@ -29662,8 +32328,10 @@ combineVectorTruncationWithPACKSS(SDNode *N, SelectionDAG &DAG,
 // Shift left by 16 bits, then arithmetic-shift right by 16 bits.
 SDValue ShAmt = DAG.getConstant(16, DL, MVT::i32);
 for (auto &Reg : Regs) {
- Reg = getTargetVShiftNode(X86ISD::VSHLI, DL, MVT::v4i32, Reg, ShAmt, DAG);
- Reg = getTargetVShiftNode(X86ISD::VSRAI, DL, MVT::v4i32, Reg, ShAmt, DAG);
+ Reg = getTargetVShiftNode(X86ISD::VSHLI, DL, MVT::v4i32, Reg, ShAmt,
+ Subtarget, DAG);
+ Reg = getTargetVShiftNode(X86ISD::VSRAI, DL, MVT::v4i32, Reg, ShAmt,
+ Subtarget, DAG);
 }
 
 for (unsigned i = 0, e = Regs.size() / 2; i < e; i++)
@@ -29681,7 +32349,7 @@ combineVectorTruncationWithPACKSS(SDNode *N, SelectionDAG &DAG,
 /// X86ISD::PACKUS/X86ISD::PACKSS operations. We do it here because after type
 /// legalization the truncation will be translated into a BUILD_VECTOR with each
 /// element that is extracted from a vector and then truncated, and it is
-/// diffcult to do this optimization based on them.
+/// difficult to do this optimization based on them.
 static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG,
 const X86Subtarget &Subtarget) {
 EVT OutVT = N->getValueType(0);
@@ -29732,17 +32400,60 @@ static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG,
 if (Subtarget.hasSSE41() || OutSVT == MVT::i8)
 return combineVectorTruncationWithPACKUS(N, DAG, SubVec);
 else if (InSVT == MVT::i32)
- return combineVectorTruncationWithPACKSS(N, DAG, SubVec);
+ return combineVectorTruncationWithPACKSS(N, Subtarget, DAG, SubVec);
 else
 return SDValue();
 }
 
+/// This function transforms vector truncation of 'all or none' bits values,
+/// i.e. vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32, into X86ISD::PACKSS operations.
+static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL,
+ SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ // Requires SSE2 but AVX512 has fast truncate.
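// Editorial sketch (not part of the patch): a v8i32 SETCC result is
// all-ones or all-zeros per lane, i.e. 32 sign bits per element, so one
// PACKSSDW (signed saturation) narrows it losslessly to v8i16.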
+ if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
+ return SDValue();
+
+ if (!N->getValueType(0).isVector() || !N->getValueType(0).isSimple())
+ return SDValue();
+
+ SDValue In = N->getOperand(0);
+ if (!In.getValueType().isSimple())
+ return SDValue();
+
+ MVT VT = N->getValueType(0).getSimpleVT();
+ MVT SVT = VT.getScalarType();
+
+ MVT InVT = In.getValueType().getSimpleVT();
+ MVT InSVT = InVT.getScalarType();
+
+ // Use PACKSS if the input is a splatted sign bit.
+ // e.g. Comparison result, sext_in_reg, etc.
+ unsigned NumSignBits = DAG.ComputeNumSignBits(In);
+ if (NumSignBits != InSVT.getSizeInBits())
+ return SDValue();
+
+ // Check we have a truncation suited for PACKSS.
+ if (!VT.is128BitVector() && !VT.is256BitVector())
+ return SDValue();
+ if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32)
+ return SDValue();
+ if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64)
+ return SDValue();
+
+ return truncateVectorCompareWithPACKSS(VT, In, DL, DAG, Subtarget);
+}
+
 static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
 const X86Subtarget &Subtarget) {
 EVT VT = N->getValueType(0);
 SDValue Src = N->getOperand(0);
 SDLoc DL(N);
 
+ // Attempt to pre-truncate inputs to arithmetic ops instead.
+ if (SDValue V = combineTruncatedArithmetic(N, DAG, Subtarget, DL))
+ return V;
+
 // Try to detect AVG pattern first.
 if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
 return Avg;
@@ -29755,15 +32466,75 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
 return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc);
 }
 
+ // Try to truncate extended sign bits with PACKSS.
+ if (SDValue V = combineVectorSignBitsTruncation(N, DL, DAG, Subtarget))
+ return V;
+
 return combineVectorTruncation(N, DAG, Subtarget);
 }
 
+/// Returns the negated value if the node \p N flips the sign of an FP value.
+///
+/// An FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
+/// AVX512F does not have FXOR, so FNEG is lowered as
+/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
+/// In this case we go through all bitcasts.
+static SDValue isFNEG(SDNode *N) {
+ if (N->getOpcode() == ISD::FNEG)
+ return N->getOperand(0);
+
+ SDValue Op = peekThroughBitcasts(SDValue(N, 0));
+ if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
+ return SDValue();
+
+ SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
+ if (!Op1.getValueType().isFloatingPoint())
+ return SDValue();
+
+ SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
+
+ unsigned EltBits = Op1.getScalarValueSizeInBits();
+ auto isSignBitValue = [&](const ConstantFP *C) {
+ return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
+ };
+
+ // There is more than one way to represent the same constant on
+ // different X86 targets. The type of the node may also depend on size.
+ // - load scalar value and broadcast
+ // - BUILD_VECTOR node
+ // - load from a constant pool.
+ // We check all variants here.
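// Editorial example (not from the patch): for v4f32 the sign mask is a
// splat of 0x80000000, i.e. APInt::getSignBit(32); for v2f64 it is a splat
// of 0x8000000000000000. Any of the three forms below may carry it.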
+ if (Op1.getOpcode() == X86ISD::VBROADCAST) {
+ if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
+ if (isSignBitValue(cast<ConstantFP>(C)))
+ return Op0;
+
+ } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
+ if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
+ if (isSignBitValue(CN->getConstantFPValue()))
+ return Op0;
+
+ } else if (auto *C = getTargetConstantFromNode(Op1)) {
+ if (C->getType()->isVectorTy()) {
+ if (auto *SplatV = C->getSplatValue())
+ if (isSignBitValue(cast<ConstantFP>(SplatV)))
+ return Op0;
+ } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
+ if (isSignBitValue(FPConst))
+ return Op0;
+ }
+ return SDValue();
+}
+
 /// Do target-specific dag combines on floating point negations.
 static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
 const X86Subtarget &Subtarget) {
- EVT VT = N->getValueType(0);
+ EVT OrigVT = N->getValueType(0);
+ SDValue Arg = isFNEG(N);
+ assert(Arg.getNode() && "N is expected to be an FNEG node");
+
+ EVT VT = Arg.getValueType();
 EVT SVT = VT.getScalarType();
- SDValue Arg = N->getOperand(0);
 SDLoc DL(N);
 
 // Let legalize expand this if it isn't a legal type yet.
@@ -29776,70 +32547,182 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
 if (Arg.getOpcode() == ISD::FMUL && (SVT == MVT::f32 || SVT == MVT::f64) &&
 Arg->getFlags()->hasNoSignedZeros() && Subtarget.hasAnyFMA()) {
 SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
- return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
- Arg.getOperand(1), Zero);
+ SDValue NewNode = DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
+ Arg.getOperand(1), Zero);
+ return DAG.getBitcast(OrigVT, NewNode);
 }
 
- // If we're negating a FMA node, then we can adjust the
+ // If we're negating an FMA node, then we can adjust the
 // instruction to include the extra negation.
+ unsigned NewOpcode = 0;
 if (Arg.hasOneUse()) {
 switch (Arg.getOpcode()) {
- case X86ISD::FMADD:
- return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
- Arg.getOperand(1), Arg.getOperand(2));
- case X86ISD::FMSUB:
- return DAG.getNode(X86ISD::FNMADD, DL, VT, Arg.getOperand(0),
- Arg.getOperand(1), Arg.getOperand(2));
- case X86ISD::FNMADD:
- return DAG.getNode(X86ISD::FMSUB, DL, VT, Arg.getOperand(0),
- Arg.getOperand(1), Arg.getOperand(2));
- case X86ISD::FNMSUB:
- return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
- Arg.getOperand(1), Arg.getOperand(2));
- }
- }
+ case X86ISD::FMADD: NewOpcode = X86ISD::FNMSUB; break;
+ case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break;
+ case X86ISD::FNMADD: NewOpcode = X86ISD::FMSUB; break;
+ case X86ISD::FNMSUB: NewOpcode = X86ISD::FMADD; break;
+ case X86ISD::FMADD_RND: NewOpcode = X86ISD::FNMSUB_RND; break;
+ case X86ISD::FMSUB_RND: NewOpcode = X86ISD::FNMADD_RND; break;
+ case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND; break;
+ case X86ISD::FNMSUB_RND: NewOpcode = X86ISD::FMADD_RND; break;
+ // We can't handle a scalar intrinsic node here because it would only
+ // invert one element and not the whole vector. But we could try to handle
+ // a negation of the lower element only.
+ }
+ }
+ if (NewOpcode)
+ return DAG.getBitcast(OrigVT, DAG.getNode(NewOpcode, DL, VT,
+ Arg.getNode()->ops()));
+
 return SDValue();
 }
 
 static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
- EVT VT = N->getValueType(0);
- if (VT.is512BitVector() && !Subtarget.hasDQI()) {
- // VXORPS, VORPS, VANDPS, VANDNPS are supported only under DQ extention.
- // These logic operations may be executed in the integer domain.
+ const X86Subtarget &Subtarget) {
+ MVT VT = N->getSimpleValueType(0);
+ // If we have integer vector types available, use the integer opcodes.
+ if (VT.isVector() && Subtarget.hasSSE2()) {
 SDLoc dl(N);
- MVT IntScalar = MVT::getIntegerVT(VT.getScalarSizeInBits());
- MVT IntVT = MVT::getVectorVT(IntScalar, VT.getVectorNumElements());
+
+ MVT IntVT = MVT::getVectorVT(MVT::i64, VT.getSizeInBits() / 64);
 SDValue Op0 = DAG.getBitcast(IntVT, N->getOperand(0));
 SDValue Op1 = DAG.getBitcast(IntVT, N->getOperand(1));
- unsigned IntOpcode = 0;
+ unsigned IntOpcode;
 switch (N->getOpcode()) {
- default: llvm_unreachable("Unexpected FP logic op");
- case X86ISD::FOR: IntOpcode = ISD::OR; break;
- case X86ISD::FXOR: IntOpcode = ISD::XOR; break;
- case X86ISD::FAND: IntOpcode = ISD::AND; break;
- case X86ISD::FANDN: IntOpcode = X86ISD::ANDNP; break;
+ default: llvm_unreachable("Unexpected FP logic op");
+ case X86ISD::FOR: IntOpcode = ISD::OR; break;
+ case X86ISD::FXOR: IntOpcode = ISD::XOR; break;
+ case X86ISD::FAND: IntOpcode = ISD::AND; break;
+ case X86ISD::FANDN: IntOpcode = X86ISD::ANDNP; break;
 }
 SDValue IntOp = DAG.getNode(IntOpcode, dl, IntVT, Op0, Op1);
 return DAG.getBitcast(VT, IntOp);
 }
 return SDValue();
 }
+
+static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
+ if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
+ return Cmp;
+
+ if (DCI.isBeforeLegalizeOps())
+ return SDValue();
+
+ if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
+ return RV;
+
+ if (Subtarget.hasCMov())
+ if (SDValue RV = combineIntegerAbs(N, DAG))
+ return RV;
+
+ if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
+ return FPLogic;
+
+ if (isFNEG(N))
+ return combineFneg(N, DAG, Subtarget);
+ return SDValue();
+}
+
+static bool isNullFPScalarOrVectorConst(SDValue V) {
+ return isNullFPConstant(V) || ISD::isBuildVectorAllZeros(V.getNode());
+}
+
+/// If a value is a scalar FP zero or a vector FP zero (potentially including
+/// undefined elements), return a zero constant that may be used to fold away
+/// that value. In the case of a vector, the returned constant will not contain
+/// undefined elements even if the input parameter does. This makes it suitable
+/// to be used as a replacement operand with operations (e.g., bitwise-and)
+/// where an undef should not propagate.
+static SDValue getNullFPConstForNullVal(SDValue V, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ if (!isNullFPScalarOrVectorConst(V))
+ return SDValue();
+
+ if (V.getValueType().isVector())
+ return getZeroVector(V.getSimpleValueType(), Subtarget, DAG, SDLoc(V));
+
+ return V;
+}
+
+static SDValue combineFAndFNotToFAndn(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ // Vector types are handled in combineANDXORWithAllOnesIntoANDNP().
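// Editorial note: X86ISD::FANDN(a, b) computes ~a & b (ANDNPS operand
// order), so in the folds below the operand that carried the
// FXOR-with-all-ones becomes the first, complemented, FANDN operand.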
+ if (!((VT == MVT::f32 && Subtarget.hasSSE1()) || + (VT == MVT::f64 && Subtarget.hasSSE2()))) + return SDValue(); + + auto isAllOnesConstantFP = [](SDValue V) { + auto *C = dyn_cast<ConstantFPSDNode>(V); + return C && C->getConstantFPValue()->isAllOnesValue(); + }; + + // fand (fxor X, -1), Y --> fandn X, Y + if (N0.getOpcode() == X86ISD::FXOR && isAllOnesConstantFP(N0.getOperand(1))) + return DAG.getNode(X86ISD::FANDN, DL, VT, N0.getOperand(0), N1); + + // fand X, (fxor Y, -1) --> fandn Y, X + if (N1.getOpcode() == X86ISD::FXOR && isAllOnesConstantFP(N1.getOperand(1))) + return DAG.getNode(X86ISD::FANDN, DL, VT, N1.getOperand(0), N0); + + return SDValue(); +} + +/// Do target-specific dag combines on X86ISD::FAND nodes. +static SDValue combineFAnd(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // FAND(0.0, x) -> 0.0 + if (SDValue V = getNullFPConstForNullVal(N->getOperand(0), DAG, Subtarget)) + return V; + + // FAND(x, 0.0) -> 0.0 + if (SDValue V = getNullFPConstForNullVal(N->getOperand(1), DAG, Subtarget)) + return V; + + if (SDValue V = combineFAndFNotToFAndn(N, DAG, Subtarget)) + return V; + + return lowerX86FPLogicOp(N, DAG, Subtarget); +} + +/// Do target-specific dag combines on X86ISD::FANDN nodes. +static SDValue combineFAndn(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // FANDN(0.0, x) -> x + if (isNullFPScalarOrVectorConst(N->getOperand(0))) + return N->getOperand(1); + + // FANDN(x, 0.0) -> 0.0 + if (SDValue V = getNullFPConstForNullVal(N->getOperand(1), DAG, Subtarget)) + return V; + + return lowerX86FPLogicOp(N, DAG, Subtarget); +} + /// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes. static SDValue combineFOr(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { assert(N->getOpcode() == X86ISD::FOR || N->getOpcode() == X86ISD::FXOR); // F[X]OR(0.0, x) -> x - if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) - if (C->getValueAPF().isPosZero()) - return N->getOperand(1); + if (isNullFPScalarOrVectorConst(N->getOperand(0))) + return N->getOperand(1); // F[X]OR(x, 0.0) -> x - if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N->getOperand(1))) - if (C->getValueAPF().isPosZero()) - return N->getOperand(0); + if (isNullFPScalarOrVectorConst(N->getOperand(1))) + return N->getOperand(0); + + if (isFNEG(N)) + if (SDValue NewVal = combineFneg(N, DAG, Subtarget)) + return NewVal; return lowerX86FPLogicOp(N, DAG, Subtarget); } @@ -29921,38 +32804,6 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG, return DAG.getNode(SelectOpcode, DL, VT, IsOp0Nan, Op1, MinOrMax); } -/// Do target-specific dag combines on X86ISD::FAND nodes. 
-static SDValue combineFAnd(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - // FAND(0.0, x) -> 0.0 - if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) - if (C->getValueAPF().isPosZero()) - return N->getOperand(0); - - // FAND(x, 0.0) -> 0.0 - if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N->getOperand(1))) - if (C->getValueAPF().isPosZero()) - return N->getOperand(1); - - return lowerX86FPLogicOp(N, DAG, Subtarget); -} - -/// Do target-specific dag combines on X86ISD::FANDN nodes -static SDValue combineFAndn(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - // FANDN(0.0, x) -> x - if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) - if (C->getValueAPF().isPosZero()) - return N->getOperand(1); - - // FANDN(x, 0.0) -> 0.0 - if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N->getOperand(1))) - if (C->getValueAPF().isPosZero()) - return N->getOperand(1); - - return lowerX86FPLogicOp(N, DAG, Subtarget); -} - static SDValue combineBT(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { // BT ignores high bits in the bit index operand. @@ -29971,17 +32822,6 @@ static SDValue combineBT(SDNode *N, SelectionDAG &DAG, return SDValue(); } -static SDValue combineVZextMovl(SDNode *N, SelectionDAG &DAG) { - SDValue Op = peekThroughBitcasts(N->getOperand(0)); - EVT VT = N->getValueType(0), OpVT = Op.getValueType(); - if (Op.getOpcode() == X86ISD::VZEXT_LOAD && - VT.getVectorElementType().getSizeInBits() == - OpVT.getVectorElementType().getSizeInBits()) { - return DAG.getBitcast(VT, Op); - } - return SDValue(); -} - static SDValue combineSignExtendInReg(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); @@ -30018,19 +32858,32 @@ static SDValue combineSignExtendInReg(SDNode *N, SelectionDAG &DAG, } /// sext(add_nsw(x, C)) --> add(sext(x), C_sext) -/// Promoting a sign extension ahead of an 'add nsw' exposes opportunities -/// to combine math ops, use an LEA, or use a complex addressing mode. This can -/// eliminate extend, add, and shift instructions. -static SDValue promoteSextBeforeAddNSW(SDNode *Sext, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +/// zext(add_nuw(x, C)) --> add(zext(x), C_zext) +/// Promoting a sign/zero extension ahead of a no overflow 'add' exposes +/// opportunities to combine math ops, use an LEA, or use a complex addressing +/// mode. This can eliminate extend, add, and shift instructions. +static SDValue promoteExtBeforeAdd(SDNode *Ext, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + if (Ext->getOpcode() != ISD::SIGN_EXTEND && + Ext->getOpcode() != ISD::ZERO_EXTEND) + return SDValue(); + // TODO: This should be valid for other integer types. - EVT VT = Sext->getValueType(0); + EVT VT = Ext->getValueType(0); if (VT != MVT::i64) return SDValue(); - // We need an 'add nsw' feeding into the 'sext'. 
- SDValue Add = Sext->getOperand(0); - if (Add.getOpcode() != ISD::ADD || !Add->getFlags()->hasNoSignedWrap()) + SDValue Add = Ext->getOperand(0); + if (Add.getOpcode() != ISD::ADD) + return SDValue(); + + bool Sext = Ext->getOpcode() == ISD::SIGN_EXTEND; + bool NSW = Add->getFlags()->hasNoSignedWrap(); + bool NUW = Add->getFlags()->hasNoUnsignedWrap(); + + // We need an 'add nsw' feeding into the 'sext' or 'add nuw' feeding + // into the 'zext' + if ((Sext && !NSW) || (!Sext && !NUW)) return SDValue(); // Having a constant operand to the 'add' ensures that we are not increasing @@ -30046,7 +32899,7 @@ static SDValue promoteSextBeforeAddNSW(SDNode *Sext, SelectionDAG &DAG, // of single 'add' instructions, but the cost model for selecting an LEA // currently has a high threshold. bool HasLEAPotential = false; - for (auto *User : Sext->uses()) { + for (auto *User : Ext->uses()) { if (User->getOpcode() == ISD::ADD || User->getOpcode() == ISD::SHL) { HasLEAPotential = true; break; @@ -30055,17 +32908,18 @@ static SDValue promoteSextBeforeAddNSW(SDNode *Sext, SelectionDAG &DAG, if (!HasLEAPotential) return SDValue(); - // Everything looks good, so pull the 'sext' ahead of the 'add'. - int64_t AddConstant = AddOp1->getSExtValue(); + // Everything looks good, so pull the '{s|z}ext' ahead of the 'add'. + int64_t AddConstant = Sext ? AddOp1->getSExtValue() : AddOp1->getZExtValue(); SDValue AddOp0 = Add.getOperand(0); - SDValue NewSext = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Sext), VT, AddOp0); + SDValue NewExt = DAG.getNode(Ext->getOpcode(), SDLoc(Ext), VT, AddOp0); SDValue NewConstant = DAG.getConstant(AddConstant, SDLoc(Add), VT); // The wider add is guaranteed to not wrap because both operands are // sign-extended. SDNodeFlags Flags; - Flags.setNoSignedWrap(true); - return DAG.getNode(ISD::ADD, SDLoc(Add), VT, NewSext, NewConstant, &Flags); + Flags.setNoSignedWrap(NSW); + Flags.setNoUnsignedWrap(NUW); + return DAG.getNode(ISD::ADD, SDLoc(Add), VT, NewExt, NewConstant, &Flags); } /// (i8,i32 {s/z}ext ({s/u}divrem (i8 x, i8 y)) -> @@ -30157,18 +33011,17 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, // ISD::*_EXTEND_VECTOR_INREG which ensures lowering to X86ISD::V*EXT. // Also use this if we don't have SSE41 to allow the legalizer do its job. if (!Subtarget.hasSSE41() || VT.is128BitVector() || - (VT.is256BitVector() && Subtarget.hasInt256())) { + (VT.is256BitVector() && Subtarget.hasInt256()) || + (VT.is512BitVector() && Subtarget.hasAVX512())) { SDValue ExOp = ExtendVecSize(DL, N0, VT.getSizeInBits()); return Opcode == ISD::SIGN_EXTEND ? DAG.getSignExtendVectorInReg(ExOp, DL, VT) : DAG.getZeroExtendVectorInReg(ExOp, DL, VT); } - // On pre-AVX2 targets, split into 128-bit nodes of - // ISD::*_EXTEND_VECTOR_INREG. 
- if (!Subtarget.hasInt256() && !(VT.getSizeInBits() % 128)) { - unsigned NumVecs = VT.getSizeInBits() / 128; - unsigned NumSubElts = 128 / SVT.getSizeInBits(); + auto SplitAndExtendInReg = [&](unsigned SplitSize) { + unsigned NumVecs = VT.getSizeInBits() / SplitSize; + unsigned NumSubElts = SplitSize / SVT.getSizeInBits(); EVT SubVT = EVT::getVectorVT(*DAG.getContext(), SVT, NumSubElts); EVT InSubVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumSubElts); @@ -30176,14 +33029,24 @@ static SDValue combineToExtendVectorInReg(SDNode *N, SelectionDAG &DAG, for (unsigned i = 0, Offset = 0; i != NumVecs; ++i, Offset += NumSubElts) { SDValue SrcVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InSubVT, N0, DAG.getIntPtrConstant(Offset, DL)); - SrcVec = ExtendVecSize(DL, SrcVec, 128); + SrcVec = ExtendVecSize(DL, SrcVec, SplitSize); SrcVec = Opcode == ISD::SIGN_EXTEND ? DAG.getSignExtendVectorInReg(SrcVec, DL, SubVT) : DAG.getZeroExtendVectorInReg(SrcVec, DL, SubVT); Opnds.push_back(SrcVec); } return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Opnds); - } + }; + + // On pre-AVX2 targets, split into 128-bit nodes of + // ISD::*_EXTEND_VECTOR_INREG. + if (!Subtarget.hasInt256() && !(VT.getSizeInBits() % 128)) + return SplitAndExtendInReg(128); + + // On pre-AVX512 targets, split into 256-bit nodes of + // ISD::*_EXTEND_VECTOR_INREG. + if (!Subtarget.hasAVX512() && !(VT.getSizeInBits() % 256)) + return SplitAndExtendInReg(256); return SDValue(); } @@ -30216,7 +33079,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget)) return R; - if (SDValue NewAdd = promoteSextBeforeAddNSW(N, DAG, Subtarget)) + if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget)) return NewAdd; return SDValue(); @@ -30239,26 +33102,58 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, SDValue B = N->getOperand(1); SDValue C = N->getOperand(2); - bool NegA = (A.getOpcode() == ISD::FNEG); - bool NegB = (B.getOpcode() == ISD::FNEG); - bool NegC = (C.getOpcode() == ISD::FNEG); + auto invertIfNegative = [](SDValue &V) { + if (SDValue NegVal = isFNEG(V.getNode())) { + V = NegVal; + return true; + } + return false; + }; + + // Do not convert the passthru input of scalar intrinsics. + // FIXME: We could allow negations of the lower element only. + bool NegA = N->getOpcode() != X86ISD::FMADDS1_RND && invertIfNegative(A); + bool NegB = invertIfNegative(B); + bool NegC = N->getOpcode() != X86ISD::FMADDS3_RND && invertIfNegative(C); // Negative multiplication when NegA xor NegB bool NegMul = (NegA != NegB); - if (NegA) - A = A.getOperand(0); - if (NegB) - B = B.getOperand(0); - if (NegC) - C = C.getOperand(0); - unsigned Opcode; + unsigned NewOpcode; if (!NegMul) - Opcode = (!NegC) ? X86ISD::FMADD : X86ISD::FMSUB; + NewOpcode = (!NegC) ? X86ISD::FMADD : X86ISD::FMSUB; else - Opcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB; + NewOpcode = (!NegC) ? 
X86ISD::FNMADD : X86ISD::FNMSUB; + + + if (N->getOpcode() == X86ISD::FMADD_RND) { + switch (NewOpcode) { + case X86ISD::FMADD: NewOpcode = X86ISD::FMADD_RND; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB_RND; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break; + } + } else if (N->getOpcode() == X86ISD::FMADDS1_RND) { + switch (NewOpcode) { + case X86ISD::FMADD: NewOpcode = X86ISD::FMADDS1_RND; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1_RND; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1_RND; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1_RND; break; + } + } else if (N->getOpcode() == X86ISD::FMADDS3_RND) { + switch (NewOpcode) { + case X86ISD::FMADD: NewOpcode = X86ISD::FMADDS3_RND; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3_RND; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3_RND; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3_RND; break; + } + } else { + assert((N->getOpcode() == X86ISD::FMADD || N->getOpcode() == ISD::FMA) && + "Unexpected opcode!"); + return DAG.getNode(NewOpcode, dl, VT, A, B, C); + } - return DAG.getNode(Opcode, dl, VT, A, B, C); + return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); } static SDValue combineZext(SDNode *N, SelectionDAG &DAG, @@ -30308,6 +33203,12 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG, if (SDValue DivRem8 = getDivRem8(N, DAG)) return DivRem8; + if (SDValue NewAdd = promoteExtBeforeAdd(N, DAG, Subtarget)) + return NewAdd; + + if (SDValue R = combineOrCmpEqZeroToCtlzSrl(N, DAG, DCI, Subtarget)) + return R; + return SDValue(); } @@ -30443,10 +33344,8 @@ static SDValue combineX86SetCC(SDNode *N, SelectionDAG &DAG, return MaterializeSETB(DL, EFLAGS, DAG, N->getSimpleValueType(0)); // Try to simplify the EFLAGS and condition code operands. - if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG)) { - SDValue Cond = DAG.getConstant(CC, DL, MVT::i8); - return DAG.getNode(X86ISD::SETCC, DL, N->getVTList(), Cond, Flags); - } + if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG)) + return getSETCC(CC, Flags, DL, DAG); return SDValue(); } @@ -30539,6 +33438,12 @@ static SDValue combineUIntToFP(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::SINT_TO_FP, dl, VT, P); } + // Since UINT_TO_FP is legal (it's marked custom), dag combiner won't + // optimize it to a SINT_TO_FP when the sign bit is known zero. Perform + // the optimization here. + if (DAG.SignBitIsZero(Op0)) + return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, Op0); + return SDValue(); } @@ -30555,9 +33460,12 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG, EVT InVT = Op0.getValueType(); EVT InSVT = InVT.getScalarType(); + // SINT_TO_FP(vXi1) -> SINT_TO_FP(SEXT(vXi1 to vXi32)) // SINT_TO_FP(vXi8) -> SINT_TO_FP(SEXT(vXi8 to vXi32)) // SINT_TO_FP(vXi16) -> SINT_TO_FP(SEXT(vXi16 to vXi32)) - if (InVT.isVector() && (InSVT == MVT::i8 || InSVT == MVT::i16)) { + if (InVT.isVector() && + (InSVT == MVT::i8 || InSVT == MVT::i16 || + (InSVT == MVT::i1 && !DAG.getTargetLoweringInfo().isTypeLegal(InVT)))) { SDLoc dl(N); EVT DstVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, InVT.getVectorNumElements()); @@ -30565,6 +33473,23 @@ static SDValue combineSIntToFP(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::SINT_TO_FP, dl, VT, P); } + // Without AVX512DQ we only support i64 to float scalar conversion. 
For both + // vectors and scalars, see if we know that the upper bits are all the sign + // bit, in which case we can truncate the input to i32 and convert from that. + if (InVT.getScalarSizeInBits() > 32 && !Subtarget.hasDQI()) { + unsigned BitWidth = InVT.getScalarSizeInBits(); + unsigned NumSignBits = DAG.ComputeNumSignBits(Op0); + if (NumSignBits >= (BitWidth - 31)) { + EVT TruncVT = EVT::getIntegerVT(*DAG.getContext(), 32); + if (InVT.isVector()) + TruncVT = EVT::getVectorVT(*DAG.getContext(), TruncVT, + InVT.getVectorNumElements()); + SDLoc dl(N); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Op0); + return DAG.getNode(ISD::SINT_TO_FP, dl, VT, Trunc); + } + } + // Transform (SINT_TO_FP (i64 ...)) into an x87 operation if we have // a 32-bit target where SSE doesn't support i64->FP operations. if (!Subtarget.useSoftFloat() && Op0.getOpcode() == ISD::LOAD) { @@ -30654,13 +33579,15 @@ static SDValue OptimizeConditionalInDecrement(SDNode *N, SelectionDAG &DAG) { DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp); } -static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDLoc DL(N); EVT VT = N->getValueType(0); SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); + // TODO: There's nothing special about i32, any integer type above i16 should + // work just as well. if (!VT.isVector() || !VT.isSimple() || !(VT.getVectorElementType() == MVT::i32)) return SDValue(); @@ -30672,24 +33599,13 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, RegSize = 256; // We only handle v16i32 for SSE2 / v32i32 for AVX2 / v64i32 for AVX512. + // TODO: We should be able to handle larger vectors by splitting them before + // feeding them into several SADs, and then reducing over those. if (VT.getSizeInBits() / 4 > RegSize) return SDValue(); - // Detect the following pattern: - // - // 1: %2 = zext <N x i8> %0 to <N x i32> - // 2: %3 = zext <N x i8> %1 to <N x i32> - // 3: %4 = sub nsw <N x i32> %2, %3 - // 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N] - // 5: %6 = sub nsw <N x i32> zeroinitializer, %4 - // 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6 - // 7: %8 = add nsw <N x i32> %7, %vec.phi - // - // The last instruction must be a reduction add. The instructions 3-6 forms an - // ABSDIFF pattern. - - // The two operands of reduction add are from PHI and a select-op as in line 7 - // above. + // We know N is a reduction add, which means one of its operands is a phi. + // To match SAD, we need the other operand to be a vector select. SDValue SelectOp, Phi; if (Op0.getOpcode() == ISD::VSELECT) { SelectOp = Op0; @@ -30700,77 +33616,22 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, } else return SDValue(); - // Check the condition of the select instruction is greater-than. - SDValue SetCC = SelectOp->getOperand(0); - if (SetCC.getOpcode() != ISD::SETCC) - return SDValue(); - ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); - if (CC != ISD::SETGT) - return SDValue(); - - Op0 = SelectOp->getOperand(1); - Op1 = SelectOp->getOperand(2); - - // The second operand of SelectOp Op1 is the negation of the first operand - // Op0, which is implemented as 0 - Op0. 
- if (!(Op1.getOpcode() == ISD::SUB &&
- ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) &&
- Op1.getOperand(1) == Op0))
- return SDValue();
-
- // The first operand of SetCC is the first operand of SelectOp, which is the
- // difference between two input vectors.
- if (SetCC.getOperand(0) != Op0)
- return SDValue();
-
- // The second operand of > comparison can be either -1 or 0.
- if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
- ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
- return SDValue();
-
- // The first operand of SelectOp is the difference between two input vectors.
- if (Op0.getOpcode() != ISD::SUB)
- return SDValue();
-
- Op1 = Op0.getOperand(1);
- Op0 = Op0.getOperand(0);
-
- // Check if the operands of the diff are zero-extended from vectors of i8.
- if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
- Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
- Op1.getOpcode() != ISD::ZERO_EXTEND ||
- Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
+ // Check whether we have an abs-diff pattern feeding into the select.
+ if (!detectZextAbsDiff(SelectOp, Op0, Op1))
 return SDValue();
 
 // SAD pattern detected. Now build a SAD instruction and an addition for
- // reduction. Note that the number of elments of the result of SAD is less
+ // reduction. Note that the number of elements of the result of SAD is less
 // than the number of elements of its input. Therefore, we could only update
 // part of elements in the reduction vector.
-
- // Legalize the type of the inputs of PSADBW.
- EVT InVT = Op0.getOperand(0).getValueType();
- if (InVT.getSizeInBits() <= 128)
- RegSize = 128;
- else if (InVT.getSizeInBits() <= 256)
- RegSize = 256;
-
- unsigned NumConcat = RegSize / InVT.getSizeInBits();
- SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
- Ops[0] = Op0.getOperand(0);
- MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
- Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
- Ops[0] = Op1.getOperand(0);
- Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+ SDValue Sad = createPSADBW(DAG, Op0, Op1, DL);
 
 // The output of PSADBW is a vector of i64.
- MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
- SDValue Sad = DAG.getNode(X86ISD::PSADBW, DL, SadVT, Op0, Op1);
-
 // We need to turn the vector of i64 into a vector of i32.
 // If the reduction vector is at least as wide as the psadbw result, just
 // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
 // anyway.
- MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+ MVT ResVT = MVT::getVectorVT(MVT::i32, Sad.getValueSizeInBits() / 32);
 if (VT.getSizeInBits() >= ResVT.getSizeInBits())
 Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
 else
@@ -30793,7 +33654,7 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
 const X86Subtarget &Subtarget) {
 const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
 if (Flags->hasVectorReduction()) {
- if (SDValue Sad = detectSADPattern(N, DAG, Subtarget))
+ if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget))
 return Sad;
 }
 EVT VT = N->getValueType(0);
@@ -30832,20 +33693,21 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
 }
 }
 
- // Try to synthesize horizontal adds from adds of shuffles.
+ // Try to synthesize horizontal subs from subs of shuffles.
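// Editorial illustration (assuming v4i32 with SSSE3): if Op0 and Op1 take
// the even and odd elements of the same pair of vectors,
//   Op0 = shuffle<0,2,4,6>(X, Y), Op1 = shuffle<1,3,5,7>(X, Y)
// then sub(Op0, Op1) is PHSUBD(X, Y) = { x0-x1, x2-x3, y0-y1, y2-y3 }.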
EVT VT = N->getValueType(0); if (((Subtarget.hasSSSE3() && (VT == MVT::v8i16 || VT == MVT::v4i32)) || (Subtarget.hasInt256() && (VT == MVT::v16i16 || VT == MVT::v8i32))) && - isHorizontalBinOp(Op0, Op1, true)) + isHorizontalBinOp(Op0, Op1, false)) return DAG.getNode(X86ISD::HSUB, SDLoc(N), VT, Op0, Op1); return OptimizeConditionalInDecrement(N, DAG); } -static SDValue combineVZext(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { +static SDValue combineVSZext(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { SDLoc DL(N); + unsigned Opcode = N->getOpcode(); MVT VT = N->getSimpleValueType(0); MVT SVT = VT.getVectorElementType(); SDValue Op = N->getOperand(0); @@ -30854,25 +33716,28 @@ static SDValue combineVZext(SDNode *N, SelectionDAG &DAG, unsigned InputBits = OpEltVT.getSizeInBits() * VT.getVectorNumElements(); // Perform any constant folding. + // FIXME: Reduce constant pool usage and don't fold when OptSize is enabled. if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { - SmallVector<SDValue, 4> Vals; - for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) { + unsigned NumDstElts = VT.getVectorNumElements(); + SmallBitVector Undefs(NumDstElts, false); + SmallVector<APInt, 4> Vals(NumDstElts, APInt(SVT.getSizeInBits(), 0)); + for (unsigned i = 0; i != NumDstElts; ++i) { SDValue OpElt = Op.getOperand(i); if (OpElt.getOpcode() == ISD::UNDEF) { - Vals.push_back(DAG.getUNDEF(SVT)); + Undefs[i] = true; continue; } APInt Cst = cast<ConstantSDNode>(OpElt.getNode())->getAPIntValue(); - assert(Cst.getBitWidth() == OpEltVT.getSizeInBits()); - Cst = Cst.zextOrTrunc(SVT.getSizeInBits()); - Vals.push_back(DAG.getConstant(Cst, DL, SVT)); + Vals[i] = Opcode == X86ISD::VZEXT ? Cst.zextOrTrunc(SVT.getSizeInBits()) + : Cst.sextOrTrunc(SVT.getSizeInBits()); } - return DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Vals); + return getConstVector(Vals, Undefs, VT, DAG, DL); } // (vzext (bitcast (vzext (x)) -> (vzext x) + // TODO: (vsext (bitcast (vsext (x)) -> (vsext x) SDValue V = peekThroughBitcasts(Op); - if (V != Op && V.getOpcode() == X86ISD::VZEXT) { + if (Opcode == X86ISD::VZEXT && V != Op && V.getOpcode() == X86ISD::VZEXT) { MVT InnerVT = V.getSimpleValueType(); MVT InnerEltVT = InnerVT.getVectorElementType(); @@ -30897,7 +33762,9 @@ static SDValue combineVZext(SDNode *N, SelectionDAG &DAG, // Check if we can bypass extracting and re-inserting an element of an input // vector. 
Essentially: // (bitcast (sclr2vec (ext_vec_elt x))) -> (bitcast x) - if (V.getOpcode() == ISD::SCALAR_TO_VECTOR && + // TODO: Add X86ISD::VSEXT support + if (Opcode == X86ISD::VZEXT && + V.getOpcode() == ISD::SCALAR_TO_VECTOR && V.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT && V.getOperand(0).getSimpleValueType().getSizeInBits() == InputBits) { SDValue ExtractedV = V.getOperand(0); @@ -30976,7 +33843,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, SelectionDAG &DAG = DCI.DAG; switch (N->getOpcode()) { default: break; - case ISD::EXTRACT_VECTOR_ELT: return combineExtractVectorElt(N, DAG, DCI); + case ISD::EXTRACT_VECTOR_ELT: + return combineExtractVectorElt(N, DAG, DCI, Subtarget); case ISD::VSELECT: case ISD::SELECT: case X86ISD::SHRUNKBLEND: return combineSelect(N, DAG, DCI, Subtarget); @@ -31002,16 +33870,15 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::FSUB: return combineFaddFsub(N, DAG, Subtarget); case ISD::FNEG: return combineFneg(N, DAG, Subtarget); case ISD::TRUNCATE: return combineTruncate(N, DAG, Subtarget); + case X86ISD::FAND: return combineFAnd(N, DAG, Subtarget); + case X86ISD::FANDN: return combineFAndn(N, DAG, Subtarget); case X86ISD::FXOR: case X86ISD::FOR: return combineFOr(N, DAG, Subtarget); case X86ISD::FMIN: case X86ISD::FMAX: return combineFMinFMax(N, DAG); case ISD::FMINNUM: case ISD::FMAXNUM: return combineFMinNumFMaxNum(N, DAG, Subtarget); - case X86ISD::FAND: return combineFAnd(N, DAG, Subtarget); - case X86ISD::FANDN: return combineFAndn(N, DAG, Subtarget); case X86ISD::BT: return combineBT(N, DAG, DCI); - case X86ISD::VZEXT_MOVL: return combineVZextMovl(N, DAG); case ISD::ANY_EXTEND: case ISD::ZERO_EXTEND: return combineZext(N, DAG, DCI, Subtarget); case ISD::SIGN_EXTEND: return combineSext(N, DAG, DCI, Subtarget); @@ -31019,7 +33886,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::SETCC: return combineSetCC(N, DAG, Subtarget); case X86ISD::SETCC: return combineX86SetCC(N, DAG, DCI, Subtarget); case X86ISD::BRCOND: return combineBrCond(N, DAG, DCI, Subtarget); - case X86ISD::VZEXT: return combineVZext(N, DAG, DCI, Subtarget); + case X86ISD::VSHLI: + case X86ISD::VSRLI: return combineVectorShift(N, DAG, DCI, Subtarget); + case X86ISD::VSEXT: + case X86ISD::VZEXT: return combineVSZext(N, DAG, DCI, Subtarget); case X86ISD::SHUFP: // Handle all target specific shuffles case X86ISD::INSERTPS: case X86ISD::PALIGNR: @@ -31043,11 +33913,17 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::VPERMI: case X86ISD::VPERMV: case X86ISD::VPERMV3: + case X86ISD::VPERMIV3: case X86ISD::VPERMIL2: case X86ISD::VPERMILPI: case X86ISD::VPERMILPV: case X86ISD::VPERM2X128: + case X86ISD::VZEXT_MOVL: case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI,Subtarget); + case X86ISD::FMADD: + case X86ISD::FMADD_RND: + case X86ISD::FMADDS1_RND: + case X86ISD::FMADDS3_RND: case ISD::FMA: return combineFMA(N, DAG, Subtarget); case ISD::MGATHER: case ISD::MSCATTER: return combineGatherScatter(N, DAG); @@ -31133,7 +34009,7 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const { case ISD::OR: case ISD::XOR: Commute = true; - // fallthrough + LLVM_FALLTHROUGH; case ISD::SUB: { SDValue N0 = Op.getOperand(0); SDValue N1 = Op.getOperand(1); @@ -31280,9 +34156,11 @@ X86TargetLowering::getConstraintType(StringRef Constraint) const { case 'u': case 'y': case 'x': + case 'v': case 'Y': case 'l': return C_RegisterClass; + case 'k': // AVX512 masking registers. 
case 'a': case 'b': case 'c': @@ -31306,6 +34184,19 @@ X86TargetLowering::getConstraintType(StringRef Constraint) const { break; } } + else if (Constraint.size() == 2) { + switch (Constraint[0]) { + default: + break; + case 'Y': + switch (Constraint[1]) { + default: + break; + case 'k': + return C_Register; + } + } + } return TargetLowering::getConstraintType(Constraint); } @@ -31349,12 +34240,28 @@ TargetLowering::ConstraintWeight if (type->isX86_MMXTy() && Subtarget.hasMMX()) weight = CW_SpecificReg; break; - case 'x': case 'Y': + // Other "Y<x>" (e.g. "Yk") constraints should be implemented below. + if (constraint[1] == 'k') { + // Support for 'Yk' (similarly to the 'k' variant below). + weight = CW_SpecificReg; + break; + } + // Else fall through (handle "Y" constraint). + LLVM_FALLTHROUGH; + case 'v': + if ((type->getPrimitiveSizeInBits() == 512) && Subtarget.hasAVX512()) + weight = CW_Register; + LLVM_FALLTHROUGH; + case 'x': if (((type->getPrimitiveSizeInBits() == 128) && Subtarget.hasSSE1()) || ((type->getPrimitiveSizeInBits() == 256) && Subtarget.hasFp256())) weight = CW_Register; break; + case 'k': + // Enable conditional vector operations using %k<#> registers. + weight = CW_SpecificReg; + break; case 'I': if (ConstantInt *C = dyn_cast<ConstantInt>(info.CallOperandVal)) { if (C->getZExtValue() <= 31) @@ -31601,60 +34508,21 @@ void X86TargetLowering::LowerAsmOperandForConstraint(SDValue Op, /// Check if \p RC is a general purpose register class. /// I.e., GR* or one of their variant. static bool isGRClass(const TargetRegisterClass &RC) { - switch (RC.getID()) { - case X86::GR8RegClassID: - case X86::GR8_ABCD_LRegClassID: - case X86::GR8_ABCD_HRegClassID: - case X86::GR8_NOREXRegClassID: - case X86::GR16RegClassID: - case X86::GR16_ABCDRegClassID: - case X86::GR16_NOREXRegClassID: - case X86::GR32RegClassID: - case X86::GR32_ABCDRegClassID: - case X86::GR32_TCRegClassID: - case X86::GR32_NOREXRegClassID: - case X86::GR32_NOAXRegClassID: - case X86::GR32_NOSPRegClassID: - case X86::GR32_NOREX_NOSPRegClassID: - case X86::GR32_ADRegClassID: - case X86::GR64RegClassID: - case X86::GR64_ABCDRegClassID: - case X86::GR64_TCRegClassID: - case X86::GR64_TCW64RegClassID: - case X86::GR64_NOREXRegClassID: - case X86::GR64_NOSPRegClassID: - case X86::GR64_NOREX_NOSPRegClassID: - case X86::LOW32_ADDR_ACCESSRegClassID: - case X86::LOW32_ADDR_ACCESS_RBPRegClassID: - return true; - default: - return false; - } + return RC.hasSuperClassEq(&X86::GR8RegClass) || + RC.hasSuperClassEq(&X86::GR16RegClass) || + RC.hasSuperClassEq(&X86::GR32RegClass) || + RC.hasSuperClassEq(&X86::GR64RegClass) || + RC.hasSuperClassEq(&X86::LOW32_ADDR_ACCESS_RBPRegClass); } /// Check if \p RC is a vector register class. /// I.e., FR* / VR* or one of their variant. 
static bool isFRClass(const TargetRegisterClass &RC) { - switch (RC.getID()) { - case X86::FR32RegClassID: - case X86::FR32XRegClassID: - case X86::FR64RegClassID: - case X86::FR64XRegClassID: - case X86::FR128RegClassID: - case X86::VR64RegClassID: - case X86::VR128RegClassID: - case X86::VR128LRegClassID: - case X86::VR128HRegClassID: - case X86::VR128XRegClassID: - case X86::VR256RegClassID: - case X86::VR256LRegClassID: - case X86::VR256HRegClassID: - case X86::VR256XRegClassID: - case X86::VR512RegClassID: - return true; - default: - return false; - } + return RC.hasSuperClassEq(&X86::FR32XRegClass) || + RC.hasSuperClassEq(&X86::FR64XRegClass) || + RC.hasSuperClassEq(&X86::VR128XRegClass) || + RC.hasSuperClassEq(&X86::VR256XRegClass) || + RC.hasSuperClassEq(&X86::VR512RegClass); } std::pair<unsigned, const TargetRegisterClass *> @@ -31670,6 +34538,24 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, // TODO: Slight differences here in allocation order and leaving // RIP in the class. Do they matter any more here than they do // in the normal allocation? + case 'k': + if (Subtarget.hasAVX512()) { + // Only supported in AVX512 or later. + switch (VT.SimpleTy) { + default: break; + case MVT::i32: + return std::make_pair(0U, &X86::VK32RegClass); + case MVT::i16: + return std::make_pair(0U, &X86::VK16RegClass); + case MVT::i8: + return std::make_pair(0U, &X86::VK8RegClass); + case MVT::i1: + return std::make_pair(0U, &X86::VK1RegClass); + case MVT::i64: + return std::make_pair(0U, &X86::VK64RegClass); + } + } + break; case 'q': // GENERAL_REGS in 64-bit mode, Q_REGS in 32-bit mode. if (Subtarget.is64Bit()) { if (VT == MVT::i32 || VT == MVT::f32) @@ -31723,18 +34609,24 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, return std::make_pair(0U, &X86::VR64RegClass); case 'Y': // SSE_REGS if SSE2 allowed if (!Subtarget.hasSSE2()) break; - // FALL THROUGH. + LLVM_FALLTHROUGH; + case 'v': case 'x': // SSE_REGS if SSE1 allowed or AVX_REGS if AVX allowed if (!Subtarget.hasSSE1()) break; + bool VConstraint = (Constraint[0] == 'v'); switch (VT.SimpleTy) { default: break; // Scalar SSE types. case MVT::f32: case MVT::i32: + if (VConstraint && Subtarget.hasAVX512() && Subtarget.hasVLX()) + return std::make_pair(0U, &X86::FR32XRegClass); return std::make_pair(0U, &X86::FR32RegClass); case MVT::f64: case MVT::i64: + if (VConstraint && Subtarget.hasVLX()) + return std::make_pair(0U, &X86::FR64XRegClass); return std::make_pair(0U, &X86::FR64RegClass); // TODO: Handle f128 and i128 in FR128RegClass after it is tested well. // Vector types. @@ -31744,6 +34636,8 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, case MVT::v2i64: case MVT::v4f32: case MVT::v2f64: + if (VConstraint && Subtarget.hasVLX()) + return std::make_pair(0U, &X86::VR128XRegClass); return std::make_pair(0U, &X86::VR128RegClass); // AVX types. 
 case MVT::v32i8:
@@ -31752,6 +34646,8 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
 case MVT::v16i16:
 case MVT::v8i32:
 case MVT::v4i64:
 case MVT::v8f32:
 case MVT::v4f64:
+ if (VConstraint && Subtarget.hasVLX())
+ return std::make_pair(0U, &X86::VR256XRegClass);
 return std::make_pair(0U, &X86::VR256RegClass);
 case MVT::v8f64:
 case MVT::v16f32:
@@ -31761,6 +34657,29 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
 }
 break;
 }
+ } else if (Constraint.size() == 2 && Constraint[0] == 'Y') {
+ switch (Constraint[1]) {
+ default:
+ break;
+ case 'k':
+ // This register class doesn't allocate k0 for masked vector operations.
+ if (Subtarget.hasAVX512()) { // Only supported in AVX512.
+ switch (VT.SimpleTy) {
+ default: break;
+ case MVT::i32:
+ return std::make_pair(0U, &X86::VK32WMRegClass);
+ case MVT::i16:
+ return std::make_pair(0U, &X86::VK16WMRegClass);
+ case MVT::i8:
+ return std::make_pair(0U, &X86::VK8WMRegClass);
+ case MVT::i1:
+ return std::make_pair(0U, &X86::VK1WMRegClass);
+ case MVT::i64:
+ return std::make_pair(0U, &X86::VK64WMRegClass);
+ }
+ }
+ break;
+ }
 }
 
 // Use the default implementation in TargetLowering to convert the register
@@ -31954,3 +34873,7 @@ void X86TargetLowering::insertCopiesSplitCSR(
 .addReg(NewVR);
 }
 }
+
+bool X86TargetLowering::supportSwiftError() const {
+ return Subtarget.is64Bit();
+}
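// Editorial usage sketch (not part of this patch; assumes GCC-style inline
// asm on an AVX-512 target). With the constraints added above, 'v' allows
// any xmm/ymm/zmm register, 'k' any mask register k0-k7, and 'Yk' a mask
// register other than k0 (k0 cannot be used as a write mask):
//
//   __m512i add512(__m512i a, __m512i b) {
//     __m512i dst;
//     __asm__("vpaddd %2, %1, %0" : "=v"(dst) : "v"(a), "v"(b));
//     return dst;
//   }
//
//   unsigned short copy_mask(unsigned short m) {
//     unsigned short out;
//     __asm__("kmovw %1, %0" : "=Yk"(out) : "Yk"(m));  // never picks k0
//     return out;
//   }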