Diffstat (limited to 'contrib/llvm/lib/Target/NVPTX')
36 files changed, 4578 insertions, 2606 deletions
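The bulk of this import moves the NVPTX backend off hand-written textual instruction printing (the old NVPTXGenAsmWriter.inc include in NVPTXAsmPrinter) and onto the MC layer: EmitInstruction now lowers each MachineInstr to an MCInst, and a newly registered NVPTXInstPrinter renders the result. Below is a minimal standalone sketch of that two-stage shape in plain C++; the types are simplified stand-ins, not the real LLVM classes.

#include <cstdio>
#include <vector>

// Stand-ins for MCOperand/MCInst: an opcode plus a flat operand list.
struct MCOperandSim { bool IsReg; unsigned Reg; long long Imm; };
struct MCInstSim { unsigned Opcode; std::vector<MCOperandSim> Ops; };

// Stage 1 (cf. NVPTXAsmPrinter::lowerToMCInst): copy the opcode and lower
// each machine operand into an MC operand.
static MCInstSim lowerToMCInstSim(unsigned Opcode,
                                  const std::vector<MCOperandSim> &MachineOps) {
  MCInstSim Out;
  Out.Opcode = Opcode;
  for (const MCOperandSim &MO : MachineOps)
    Out.Ops.push_back(MO);
  return Out;
}

// Stage 2 (cf. NVPTXInstPrinter::printOperand): the printer works from the
// MCInst alone, with no access to the MachineInstr.
static void printInstSim(const MCInstSim &MI) {
  std::printf("op%u", MI.Opcode);
  for (const MCOperandSim &Op : MI.Ops) {
    if (Op.IsReg)
      std::printf(" %%r%u", Op.Reg);
    else
      std::printf(" %lld", Op.Imm);
  }
  std::printf("\n");
}

int main() {
  MCInstSim MI = lowerToMCInstSim(42, {{true, 1, 0}, {false, 0, 7}});
  printInstSim(MI); // prints: op42 %r1 7
}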
diff --git a/contrib/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp b/contrib/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp index 10051c7..d5be0e4 100644 --- a/contrib/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp +++ b/contrib/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp @@ -1 +1,289 @@ -// Placeholder +//===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Print MCInst instructions to .ptx format. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "asm-printer" +#include "InstPrinter/NVPTXInstPrinter.h" +#include "NVPTX.h" +#include "MCTargetDesc/NVPTXBaseInfo.h" +#include "llvm/MC/MCExpr.h" +#include "llvm/MC/MCInst.h" +#include "llvm/MC/MCInstrInfo.h" +#include "llvm/MC/MCSymbol.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormattedStream.h" +#include <cctype> +using namespace llvm; + +#include "NVPTXGenAsmWriter.inc" + + +NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII, + const MCRegisterInfo &MRI, + const MCSubtargetInfo &STI) + : MCInstPrinter(MAI, MII, MRI) { + setAvailableFeatures(STI.getFeatureBits()); +} + +void NVPTXInstPrinter::printRegName(raw_ostream &OS, unsigned RegNo) const { + // Decode the virtual register + // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister + unsigned RCId = (RegNo >> 28); + switch (RCId) { + default: report_fatal_error("Bad virtual register encoding"); + case 0: + // This is actually a physical register, so defer to the autogenerated + // register printer + OS << getRegisterName(RegNo); + return; + case 1: + OS << "%p"; + break; + case 2: + OS << "%rs"; + break; + case 3: + OS << "%r"; + break; + case 4: + OS << "%rl"; + break; + case 5: + OS << "%f"; + break; + case 6: + OS << "%fl"; + break; + } + + unsigned VReg = RegNo & 0x0FFFFFFF; + OS << VReg; +} + +void NVPTXInstPrinter::printInst(const MCInst *MI, raw_ostream &OS, + StringRef Annot) { + printInstruction(MI, OS); + + // Next always print the annotation. 
+ printAnnotation(OS, Annot); +} + +void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + const MCOperand &Op = MI->getOperand(OpNo); + if (Op.isReg()) { + unsigned Reg = Op.getReg(); + printRegName(O, Reg); + } else if (Op.isImm()) { + O << markup("<imm:") << formatImm(Op.getImm()) << markup(">"); + } else { + assert(Op.isExpr() && "Unknown operand kind in printOperand"); + O << *Op.getExpr(); + } +} + +void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O, + const char *Modifier) { + const MCOperand &MO = MI->getOperand(OpNum); + int64_t Imm = MO.getImm(); + + if (strcmp(Modifier, "ftz") == 0) { + // FTZ flag + if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG) + O << ".ftz"; + } else if (strcmp(Modifier, "sat") == 0) { + // SAT flag + if (Imm & NVPTX::PTXCvtMode::SAT_FLAG) + O << ".sat"; + } else if (strcmp(Modifier, "base") == 0) { + // Default operand + switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) { + default: + return; + case NVPTX::PTXCvtMode::NONE: + break; + case NVPTX::PTXCvtMode::RNI: + O << ".rni"; + break; + case NVPTX::PTXCvtMode::RZI: + O << ".rzi"; + break; + case NVPTX::PTXCvtMode::RMI: + O << ".rmi"; + break; + case NVPTX::PTXCvtMode::RPI: + O << ".rpi"; + break; + case NVPTX::PTXCvtMode::RN: + O << ".rn"; + break; + case NVPTX::PTXCvtMode::RZ: + O << ".rz"; + break; + case NVPTX::PTXCvtMode::RM: + O << ".rm"; + break; + case NVPTX::PTXCvtMode::RP: + O << ".rp"; + break; + } + } else { + llvm_unreachable("Invalid conversion modifier"); + } +} + +void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O, + const char *Modifier) { + const MCOperand &MO = MI->getOperand(OpNum); + int64_t Imm = MO.getImm(); + + if (strcmp(Modifier, "ftz") == 0) { + // FTZ flag + if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG) + O << ".ftz"; + } else if (strcmp(Modifier, "base") == 0) { + switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) { + default: + return; + case NVPTX::PTXCmpMode::EQ: + O << ".eq"; + break; + case NVPTX::PTXCmpMode::NE: + O << ".ne"; + break; + case NVPTX::PTXCmpMode::LT: + O << ".lt"; + break; + case NVPTX::PTXCmpMode::LE: + O << ".le"; + break; + case NVPTX::PTXCmpMode::GT: + O << ".gt"; + break; + case NVPTX::PTXCmpMode::GE: + O << ".ge"; + break; + case NVPTX::PTXCmpMode::LO: + O << ".lo"; + break; + case NVPTX::PTXCmpMode::LS: + O << ".ls"; + break; + case NVPTX::PTXCmpMode::HI: + O << ".hi"; + break; + case NVPTX::PTXCmpMode::HS: + O << ".hs"; + break; + case NVPTX::PTXCmpMode::EQU: + O << ".equ"; + break; + case NVPTX::PTXCmpMode::NEU: + O << ".neu"; + break; + case NVPTX::PTXCmpMode::LTU: + O << ".ltu"; + break; + case NVPTX::PTXCmpMode::LEU: + O << ".leu"; + break; + case NVPTX::PTXCmpMode::GTU: + O << ".gtu"; + break; + case NVPTX::PTXCmpMode::GEU: + O << ".geu"; + break; + case NVPTX::PTXCmpMode::NUM: + O << ".num"; + break; + case NVPTX::PTXCmpMode::NotANumber: + O << ".nan"; + break; + } + } else { + llvm_unreachable("Empty Modifier"); + } +} + +void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum, + raw_ostream &O, const char *Modifier) { + if (Modifier) { + const MCOperand &MO = MI->getOperand(OpNum); + int Imm = (int) MO.getImm(); + if (!strcmp(Modifier, "volatile")) { + if (Imm) + O << ".volatile"; + } else if (!strcmp(Modifier, "addsp")) { + switch (Imm) { + case NVPTX::PTXLdStInstCode::GLOBAL: + O << ".global"; + break; + case NVPTX::PTXLdStInstCode::SHARED: + O << ".shared"; + break; + case NVPTX::PTXLdStInstCode::LOCAL: + O << ".local"; + break; + case NVPTX::PTXLdStInstCode::PARAM: 
+ O << ".param"; + break; + case NVPTX::PTXLdStInstCode::CONSTANT: + O << ".const"; + break; + case NVPTX::PTXLdStInstCode::GENERIC: + break; + default: + llvm_unreachable("Wrong Address Space"); + } + } else if (!strcmp(Modifier, "sign")) { + if (Imm == NVPTX::PTXLdStInstCode::Signed) + O << "s"; + else if (Imm == NVPTX::PTXLdStInstCode::Unsigned) + O << "u"; + else + O << "f"; + } else if (!strcmp(Modifier, "vec")) { + if (Imm == NVPTX::PTXLdStInstCode::V2) + O << ".v2"; + else if (Imm == NVPTX::PTXLdStInstCode::V4) + O << ".v4"; + } else + llvm_unreachable("Unknown Modifier"); + } else + llvm_unreachable("Empty Modifier"); +} + +void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, + raw_ostream &O, const char *Modifier) { + printOperand(MI, OpNum, O); + + if (Modifier && !strcmp(Modifier, "add")) { + O << ", "; + printOperand(MI, OpNum + 1, O); + } else { + if (MI->getOperand(OpNum + 1).isImm() && + MI->getOperand(OpNum + 1).getImm() == 0) + return; // don't print ',0' or '+0' + O << "+"; + printOperand(MI, OpNum + 1, O); + } +} + +void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum, + raw_ostream &O, const char *Modifier) { + const MCOperand &Op = MI->getOperand(OpNum); + assert(Op.isExpr() && "Call prototype is not an MCExpr?"); + const MCExpr *Expr = Op.getExpr(); + const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol(); + O << Sym.getName(); +} diff --git a/contrib/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h b/contrib/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h new file mode 100644 index 0000000..93029ae --- /dev/null +++ b/contrib/llvm/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h @@ -0,0 +1,53 @@ +//= NVPTXInstPrinter.h - Convert NVPTX MCInst to assembly syntax --*- C++ -*-=// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This class prints an NVPTX MCInst to .ptx file syntax. +// +//===----------------------------------------------------------------------===// + +#ifndef NVPTX_INST_PRINTER_H +#define NVPTX_INST_PRINTER_H + +#include "llvm/MC/MCInstPrinter.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm { + +class MCOperand; +class MCSubtargetInfo; + +class NVPTXInstPrinter : public MCInstPrinter { +public: + NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII, + const MCRegisterInfo &MRI, const MCSubtargetInfo &STI); + + virtual void printRegName(raw_ostream &OS, unsigned RegNo) const; + virtual void printInst(const MCInst *MI, raw_ostream &OS, StringRef Annot); + + // Autogenerated by tblgen. 
+ void printInstruction(const MCInst *MI, raw_ostream &O); + static const char *getRegisterName(unsigned RegNo); + // End + + void printOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O); + void printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O, + const char *Modifier = 0); + void printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O, + const char *Modifier = 0); + void printLdStCode(const MCInst *MI, int OpNum, + raw_ostream &O, const char *Modifier = 0); + void printMemOperand(const MCInst *MI, int OpNum, + raw_ostream &O, const char *Modifier = 0); + void printProtoIdent(const MCInst *MI, int OpNum, + raw_ostream &O, const char *Modifier = 0); +}; + +} + +#endif diff --git a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXBaseInfo.h b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXBaseInfo.h index b3e8b5d..edf4a80 100644 --- a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXBaseInfo.h +++ b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXBaseInfo.h @@ -22,7 +22,6 @@ namespace llvm { enum AddressSpace { ADDRESS_SPACE_GENERIC = 0, ADDRESS_SPACE_GLOBAL = 1, - ADDRESS_SPACE_CONST_NOT_GEN = 2, // Not part of generic space ADDRESS_SPACE_SHARED = 3, ADDRESS_SPACE_CONST = 4, ADDRESS_SPACE_LOCAL = 5, diff --git a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp index 459cd96..f2784b8 100644 --- a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp +++ b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.cpp @@ -17,17 +17,15 @@ using namespace llvm; -bool CompileForDebugging; - // -debug-compile - Command line option to inform opt and llc passes to // compile for debugging -static cl::opt<bool, true> -Debug("debug-compile", cl::desc("Compile for debugging"), cl::Hidden, - cl::location(CompileForDebugging), cl::init(false)); +static cl::opt<bool> CompileForDebugging("debug-compile", + cl::desc("Compile for debugging"), + cl::Hidden, cl::init(false)); void NVPTXMCAsmInfo::anchor() {} -NVPTXMCAsmInfo::NVPTXMCAsmInfo(const Target &T, const StringRef &TT) { +NVPTXMCAsmInfo::NVPTXMCAsmInfo(const StringRef &TT) { Triple TheTriple(TT); if (TheTriple.getArch() == Triple::nvptx64) { PointerSize = CalleeSaveStackSlotSize = 8; @@ -37,8 +35,6 @@ NVPTXMCAsmInfo::NVPTXMCAsmInfo(const Target &T, const StringRef &TT) { PrivateGlobalPrefix = "$L__"; - AllowPeriodsInName = false; - HasSetDirective = false; HasSingleParameterDotFile = false; diff --git a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h index 82097da..7d1633f 100644 --- a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h +++ b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCAsmInfo.h @@ -23,7 +23,7 @@ class StringRef; class NVPTXMCAsmInfo : public MCAsmInfo { virtual void anchor(); public: - explicit NVPTXMCAsmInfo(const Target &T, const StringRef &TT); + explicit NVPTXMCAsmInfo(const StringRef &TT); }; } // namespace llvm diff --git a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp index ccd2970..871bac9 100644 --- a/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp +++ b/contrib/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp @@ -13,6 +13,7 @@ #include "NVPTXMCTargetDesc.h" #include "NVPTXMCAsmInfo.h" +#include "InstPrinter/NVPTXInstPrinter.h" #include "llvm/MC/MCCodeGenInfo.h" #include "llvm/MC/MCInstrInfo.h" #include "llvm/MC/MCRegisterInfo.h" @@ 
-57,6 +58,17 @@ static MCCodeGenInfo *createNVPTXMCCodeGenInfo( return X; } +static MCInstPrinter *createNVPTXMCInstPrinter(const Target &T, + unsigned SyntaxVariant, + const MCAsmInfo &MAI, + const MCInstrInfo &MII, + const MCRegisterInfo &MRI, + const MCSubtargetInfo &STI) { + if (SyntaxVariant == 0) + return new NVPTXInstPrinter(MAI, MII, MRI, STI); + return 0; +} + // Force static initialization. extern "C" void LLVMInitializeNVPTXTargetMC() { // Register the MC asm info. @@ -85,4 +97,9 @@ extern "C" void LLVMInitializeNVPTXTargetMC() { TargetRegistry::RegisterMCSubtargetInfo(TheNVPTXTarget64, createNVPTXMCSubtargetInfo); + // Register the MCInstPrinter. + TargetRegistry::RegisterMCInstPrinter(TheNVPTXTarget32, + createNVPTXMCInstPrinter); + TargetRegistry::RegisterMCInstPrinter(TheNVPTXTarget64, + createNVPTXMCInstPrinter); } diff --git a/contrib/llvm/lib/Target/NVPTX/ManagedStringPool.h b/contrib/llvm/lib/Target/NVPTX/ManagedStringPool.h index d6c79b5..f9fb059 100644 --- a/contrib/llvm/lib/Target/NVPTX/ManagedStringPool.h +++ b/contrib/llvm/lib/Target/NVPTX/ManagedStringPool.h @@ -29,7 +29,7 @@ class ManagedStringPool { public: ManagedStringPool() {} ~ManagedStringPool() { - SmallVector<std::string *, 8>::iterator Current = Pool.begin(); + SmallVectorImpl<std::string *>::iterator Current = Pool.begin(); while (Current != Pool.end()) { delete *Current; Current++; diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTX.h b/contrib/llvm/lib/Target/NVPTX/NVPTX.h index 072c65d..490b49d 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTX.h +++ b/contrib/llvm/lib/Target/NVPTX/NVPTX.h @@ -27,6 +27,7 @@ namespace llvm { class NVPTXTargetMachine; class FunctionPass; +class MachineFunctionPass; class formatted_raw_ostream; namespace NVPTXCC { @@ -60,12 +61,10 @@ inline static const char *NVPTXCondCodeToString(NVPTXCC::CondCodes CC) { FunctionPass * createNVPTXISelDag(NVPTXTargetMachine &TM, llvm::CodeGenOpt::Level OptLevel); -FunctionPass *createLowerStructArgsPass(NVPTXTargetMachine &); -FunctionPass *createNVPTXReMatPass(NVPTXTargetMachine &); -FunctionPass *createNVPTXReMatBlockPass(NVPTXTargetMachine &); ModulePass *createGenericToNVVMPass(); ModulePass *createNVVMReflectPass(); ModulePass *createNVVMReflectPass(const StringMap<int>& Mapping); +MachineFunctionPass *createNVPTXPrologEpilogPass(); bool isImageOrSamplerVal(const Value *, const Module *); @@ -75,8 +74,7 @@ extern Target TheNVPTXTarget64; namespace NVPTX { enum DrvInterface { NVCL, - CUDA, - TEST + CUDA }; // A field inside TSFlags needs a shift and a mask. 
The usage is @@ -130,6 +128,53 @@ enum VecType { V4 = 4 }; } + +/// PTXCvtMode - Conversion code enumeration +namespace PTXCvtMode { +enum CvtMode { + NONE = 0, + RNI, + RZI, + RMI, + RPI, + RN, + RZ, + RM, + RP, + + BASE_MASK = 0x0F, + FTZ_FLAG = 0x10, + SAT_FLAG = 0x20 +}; +} + +/// PTXCmpMode - Comparison mode enumeration +namespace PTXCmpMode { +enum CmpMode { + EQ = 0, + NE, + LT, + LE, + GT, + GE, + LO, + LS, + HI, + HS, + EQU, + NEU, + LTU, + LEU, + GTU, + GEU, + NUM, + // NAN is a MACRO + NotANumber, + + BASE_MASK = 0xFF, + FTZ_FLAG = 0x100 +}; +} } } // end namespace llvm; diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTX.td b/contrib/llvm/lib/Target/NVPTX/NVPTX.td index d78b4e8..6183a75 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTX.td +++ b/contrib/llvm/lib/Target/NVPTX/NVPTX.td @@ -57,6 +57,12 @@ def : Proc<"sm_35", [SM35]>; def NVPTXInstrInfo : InstrInfo { } +def NVPTXAsmWriter : AsmWriter { + bit isMCAsmWriter = 1; + string AsmWriterClassName = "InstPrinter"; +} + def NVPTX : Target { let InstructionSet = NVPTXInstrInfo; + let AssemblyWriters = [NVPTXAsmWriter]; } diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXAllocaHoisting.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXAllocaHoisting.cpp index 0f792ec..1f37696 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXAllocaHoisting.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXAllocaHoisting.cpp @@ -37,7 +37,7 @@ bool NVPTXAllocaHoisting::runOnFunction(Function &function) { } char NVPTXAllocaHoisting::ID = 1; -RegisterPass<NVPTXAllocaHoisting> +static RegisterPass<NVPTXAllocaHoisting> X("alloca-hoisting", "Hoisting alloca instructions in non-entry " "blocks to the entry block"); diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 229e4e5..7552fe7 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -16,10 +16,11 @@ #include "MCTargetDesc/NVPTXMCAsmInfo.h" #include "NVPTX.h" #include "NVPTXInstrInfo.h" -#include "NVPTXNumRegisters.h" +#include "NVPTXMCExpr.h" #include "NVPTXRegisterInfo.h" #include "NVPTXTargetMachine.h" #include "NVPTXUtilities.h" +#include "InstPrinter/NVPTXInstPrinter.h" #include "cl_common_defines.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/ConstantFolding.h" @@ -47,23 +48,17 @@ #include <sstream> using namespace llvm; -#include "NVPTXGenAsmWriter.inc" - -bool RegAllocNilUsed = true; - #define DEPOTNAME "__local_depot" static cl::opt<bool> -EmitLineNumbers("nvptx-emit-line-numbers", +EmitLineNumbers("nvptx-emit-line-numbers", cl::Hidden, cl::desc("NVPTX Specific: Emit Line numbers even without -G"), cl::init(true)); -namespace llvm { bool InterleaveSrcInPtx = false; } - -static cl::opt<bool, true> -InterleaveSrc("nvptx-emit-src", cl::ZeroOrMore, +static cl::opt<bool> +InterleaveSrc("nvptx-emit-src", cl::ZeroOrMore, cl::Hidden, cl::desc("NVPTX Specific: Emit source line in ptx file"), - cl::location(llvm::InterleaveSrcInPtx)); + cl::init(false)); namespace { /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V @@ -131,7 +126,7 @@ const MCExpr *nvptx::LowerConstant(const Constant *CV, AsmPrinter &AP) { return MCConstantExpr::Create(CI->getZExtValue(), Ctx); if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) - return MCSymbolRefExpr::Create(AP.Mang->getSymbol(GV), Ctx); + return MCSymbolRefExpr::Create(AP.getSymbol(GV), Ctx); if (const BlockAddress *BA = dyn_cast<BlockAddress>(CV)) return MCSymbolRefExpr::Create(AP.GetBlockAddressSymbol(BA), 
Ctx); @@ -279,8 +274,10 @@ void NVPTXAsmPrinter::emitLineNumberAsDotLoc(const MachineInstr &MI) { const LLVMContext &ctx = MF->getFunction()->getContext(); DIScope Scope(curLoc.getScope(ctx)); - if (!Scope.Verify()) - return; + assert((!Scope || Scope.isScope()) && + "Scope of a DebugLoc should be null or a DIScope."); + if (!Scope) + return; StringRef fileName(Scope.getFilename()); StringRef dirName(Scope.getDirectory()); @@ -294,7 +291,7 @@ void NVPTXAsmPrinter::emitLineNumberAsDotLoc(const MachineInstr &MI) { return; // Emit the line from the source file. - if (llvm::InterleaveSrcInPtx) + if (InterleaveSrc) this->emitSrcInText(fileName.str(), curLoc.getLine()); std::stringstream temp; @@ -308,8 +305,115 @@ void NVPTXAsmPrinter::EmitInstruction(const MachineInstr *MI) { raw_svector_ostream OS(Str); if (nvptxSubtarget.getDrvInterface() == NVPTX::CUDA) emitLineNumberAsDotLoc(*MI); - printInstruction(MI, OS); - OutStreamer.EmitRawText(OS.str()); + + MCInst Inst; + lowerToMCInst(MI, Inst); + OutStreamer.EmitInstruction(Inst); +} + +void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) { + OutMI.setOpcode(MI->getOpcode()); + + // Special: Do not mangle symbol operand of CALL_PROTOTYPE + if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) { + const MachineOperand &MO = MI->getOperand(0); + OutMI.addOperand(GetSymbolRef(MO, + OutContext.GetOrCreateSymbol(Twine(MO.getSymbolName())))); + return; + } + + for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) { + const MachineOperand &MO = MI->getOperand(i); + + MCOperand MCOp; + if (lowerOperand(MO, MCOp)) + OutMI.addOperand(MCOp); + } +} + +bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, + MCOperand &MCOp) { + switch (MO.getType()) { + default: llvm_unreachable("unknown operand type"); + case MachineOperand::MO_Register: + MCOp = MCOperand::CreateReg(encodeVirtualRegister(MO.getReg())); + break; + case MachineOperand::MO_Immediate: + MCOp = MCOperand::CreateImm(MO.getImm()); + break; + case MachineOperand::MO_MachineBasicBlock: + MCOp = MCOperand::CreateExpr(MCSymbolRefExpr::Create( + MO.getMBB()->getSymbol(), OutContext)); + break; + case MachineOperand::MO_ExternalSymbol: + MCOp = GetSymbolRef(MO, GetExternalSymbolSymbol(MO.getSymbolName())); + break; + case MachineOperand::MO_GlobalAddress: + MCOp = GetSymbolRef(MO, getSymbol(MO.getGlobal())); + break; + case MachineOperand::MO_FPImmediate: { + const ConstantFP *Cnt = MO.getFPImm(); + APFloat Val = Cnt->getValueAPF(); + + switch (Cnt->getType()->getTypeID()) { + default: report_fatal_error("Unsupported FP type"); break; + case Type::FloatTyID: + MCOp = MCOperand::CreateExpr( + NVPTXFloatMCExpr::CreateConstantFPSingle(Val, OutContext)); + break; + case Type::DoubleTyID: + MCOp = MCOperand::CreateExpr( + NVPTXFloatMCExpr::CreateConstantFPDouble(Val, OutContext)); + break; + } + break; + } + } + return true; +} + +unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { + if (TargetRegisterInfo::isVirtualRegister(Reg)) { + const TargetRegisterClass *RC = MRI->getRegClass(Reg); + + DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC]; + unsigned RegNum = RegMap[Reg]; + + // Encode the register class in the upper 4 bits + // Must be kept in sync with NVPTXInstPrinter::printRegName + unsigned Ret = 0; + if (RC == &NVPTX::Int1RegsRegClass) { + Ret = (1 << 28); + } else if (RC == &NVPTX::Int16RegsRegClass) { + Ret = (2 << 28); + } else if (RC == &NVPTX::Int32RegsRegClass) { + Ret = (3 << 28); + } else if (RC == &NVPTX::Int64RegsRegClass) { + Ret = (4 << 28); + 
} else if (RC == &NVPTX::Float32RegsRegClass) { + Ret = (5 << 28); + } else if (RC == &NVPTX::Float64RegsRegClass) { + Ret = (6 << 28); + } else { + report_fatal_error("Bad register class"); + } + + // Insert the vreg number + Ret |= (RegNum & 0x0FFFFFFF); + return Ret; + } else { + // Some special-use registers are actually physical registers. + // Encode this as the register class ID of 0 and the real register ID. + return Reg & 0x0FFFFFFF; + } +} + +MCOperand NVPTXAsmPrinter::GetSymbolRef(const MachineOperand &MO, + const MCSymbol *Symbol) { + const MCExpr *Expr; + Expr = MCSymbolRefExpr::Create(Symbol, MCSymbolRefExpr::VK_None, + OutContext); + return MCOperand::CreateExpr(Expr); } void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { @@ -436,9 +540,7 @@ void NVPTXAsmPrinter::EmitFunctionEntryLabel() { } void NVPTXAsmPrinter::EmitFunctionBodyStart() { - const TargetRegisterInfo &TRI = *TM.getRegisterInfo(); - unsigned numRegClasses = TRI.getNumRegClasses(); - VRidGlobal2LocalMap = new std::map<unsigned, unsigned>[numRegClasses + 1]; + VRegMapping.clear(); OutStreamer.EmitRawText(StringRef("{\n")); setAndEmitFunctionVirtualRegisters(*MF); @@ -450,7 +552,20 @@ void NVPTXAsmPrinter::EmitFunctionBodyStart() { void NVPTXAsmPrinter::EmitFunctionBodyEnd() { OutStreamer.EmitRawText(StringRef("}\n")); - delete[] VRidGlobal2LocalMap; + VRegMapping.clear(); +} + +void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const { + unsigned RegNo = MI->getOperand(0).getReg(); + const TargetRegisterInfo *TRI = TM.getRegisterInfo(); + if (TRI->isVirtualRegister(RegNo)) { + OutStreamer.AddComment(Twine("implicit-def: ") + + getVirtualRegisterName(RegNo)); + } else { + OutStreamer.AddComment(Twine("implicit-def: ") + + TM.getRegisterInfo()->getName(RegNo)); + } + OutStreamer.AddBlankLine(); } void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F, @@ -504,24 +619,30 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F, O << ".minnctapersm " << mincta << "\n"; } -void NVPTXAsmPrinter::getVirtualRegisterName(unsigned vr, bool isVec, - raw_ostream &O) { - const TargetRegisterClass *RC = MRI->getRegClass(vr); - unsigned id = RC->getID(); +std::string +NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const { + const TargetRegisterClass *RC = MRI->getRegClass(Reg); - std::map<unsigned, unsigned> ®map = VRidGlobal2LocalMap[id]; - unsigned mapped_vr = regmap[vr]; + std::string Name; + raw_string_ostream NameStr(Name); - if (!isVec) { - O << getNVPTXRegClassStr(RC) << mapped_vr; - return; - } - report_fatal_error("Bad register!"); + VRegRCMap::const_iterator I = VRegMapping.find(RC); + assert(I != VRegMapping.end() && "Bad register class"); + const DenseMap<unsigned, unsigned> &RegMap = I->second; + + VRegMap::const_iterator VI = RegMap.find(Reg); + assert(VI != RegMap.end() && "Bad virtual register"); + unsigned MappedVR = VI->second; + + NameStr << getNVPTXRegClassStr(RC) << MappedVR; + + NameStr.flush(); + return Name; } -void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr, bool isVec, +void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr, raw_ostream &O) { - getVirtualRegisterName(vr, isVec, O); + O << getVirtualRegisterName(vr); } void NVPTXAsmPrinter::printVecModifiedImmediate( @@ -554,145 +675,7 @@ void NVPTXAsmPrinter::printVecModifiedImmediate( llvm_unreachable("Unknown Modifier on immediate operand"); } -void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, - raw_ostream &O, const char *Modifier) { - const 
MachineOperand &MO = MI->getOperand(opNum); - switch (MO.getType()) { - case MachineOperand::MO_Register: - if (TargetRegisterInfo::isPhysicalRegister(MO.getReg())) { - if (MO.getReg() == NVPTX::VRDepot) - O << DEPOTNAME << getFunctionNumber(); - else - O << getRegisterName(MO.getReg()); - } else { - if (!Modifier) - emitVirtualRegister(MO.getReg(), false, O); - else { - if (strcmp(Modifier, "vecfull") == 0) - emitVirtualRegister(MO.getReg(), true, O); - else - llvm_unreachable( - "Don't know how to handle the modifier on virtual register."); - } - } - return; - case MachineOperand::MO_Immediate: - if (!Modifier) - O << MO.getImm(); - else if (strstr(Modifier, "vec") == Modifier) - printVecModifiedImmediate(MO, Modifier, O); - else - llvm_unreachable( - "Don't know how to handle modifier on immediate operand"); - return; - - case MachineOperand::MO_FPImmediate: - printFPConstant(MO.getFPImm(), O); - break; - - case MachineOperand::MO_GlobalAddress: - O << *Mang->getSymbol(MO.getGlobal()); - break; - - case MachineOperand::MO_ExternalSymbol: { - const char *symbname = MO.getSymbolName(); - if (strstr(symbname, ".PARAM") == symbname) { - unsigned index; - sscanf(symbname + 6, "%u[];", &index); - printParamName(index, O); - } else if (strstr(symbname, ".HLPPARAM") == symbname) { - unsigned index; - sscanf(symbname + 9, "%u[];", &index); - O << *CurrentFnSym << "_param_" << index << "_offset"; - } else - O << symbname; - break; - } - - case MachineOperand::MO_MachineBasicBlock: - O << *MO.getMBB()->getSymbol(); - return; - - default: - llvm_unreachable("Operand type not supported."); - } -} - -void NVPTXAsmPrinter::printImplicitDef(const MachineInstr *MI, - raw_ostream &O) const { -#ifndef __OPTIMIZE__ - O << "\t// Implicit def :"; - //printOperand(MI, 0); - O << "\n"; -#endif -} - -void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum, - raw_ostream &O, const char *Modifier) { - printOperand(MI, opNum, O); - - if (Modifier && !strcmp(Modifier, "add")) { - O << ", "; - printOperand(MI, opNum + 1, O); - } else { - if (MI->getOperand(opNum + 1).isImm() && - MI->getOperand(opNum + 1).getImm() == 0) - return; // don't print ',0' or '+0' - O << "+"; - printOperand(MI, opNum + 1, O); - } -} - -void NVPTXAsmPrinter::printLdStCode(const MachineInstr *MI, int opNum, - raw_ostream &O, const char *Modifier) { - if (Modifier) { - const MachineOperand &MO = MI->getOperand(opNum); - int Imm = (int) MO.getImm(); - if (!strcmp(Modifier, "volatile")) { - if (Imm) - O << ".volatile"; - } else if (!strcmp(Modifier, "addsp")) { - switch (Imm) { - case NVPTX::PTXLdStInstCode::GLOBAL: - O << ".global"; - break; - case NVPTX::PTXLdStInstCode::SHARED: - O << ".shared"; - break; - case NVPTX::PTXLdStInstCode::LOCAL: - O << ".local"; - break; - case NVPTX::PTXLdStInstCode::PARAM: - O << ".param"; - break; - case NVPTX::PTXLdStInstCode::CONSTANT: - O << ".const"; - break; - case NVPTX::PTXLdStInstCode::GENERIC: - if (!nvptxSubtarget.hasGenericLdSt()) - O << ".global"; - break; - default: - llvm_unreachable("Wrong Address Space"); - } - } else if (!strcmp(Modifier, "sign")) { - if (Imm == NVPTX::PTXLdStInstCode::Signed) - O << "s"; - else if (Imm == NVPTX::PTXLdStInstCode::Unsigned) - O << "u"; - else - O << "f"; - } else if (!strcmp(Modifier, "vec")) { - if (Imm == NVPTX::PTXLdStInstCode::V2) - O << ".v2"; - else if (Imm == NVPTX::PTXLdStInstCode::V4) - O << ".v4"; - } else - llvm_unreachable("Unknown Modifier"); - } else - llvm_unreachable("Empty Modifier"); -} void 
NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) { @@ -702,7 +685,7 @@ void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) { else O << ".func "; printReturnValStr(F, O); - O << *Mang->getSymbol(F) << "\n"; + O << *getSymbol(F) << "\n"; emitFunctionParamList(F, O); O << ";\n"; } @@ -912,7 +895,7 @@ bool NVPTXAsmPrinter::doInitialization(Module &M) { const_cast<TargetLoweringObjectFile &>(getObjFileLowering()) .Initialize(OutContext, TM); - Mang = new Mangler(OutContext, *TM.getDataLayout()); + Mang = new Mangler(&TM); // Emit header before any dwarf directives are emitted below. emitHeader(M, OS1); @@ -921,6 +904,16 @@ bool NVPTXAsmPrinter::doInitialization(Module &M) { // Already commented out //bool Result = AsmPrinter::doInitialization(M); + // Emit module-level inline asm if it exists. + if (!M.getModuleInlineAsm().empty()) { + OutStreamer.AddComment("Start of file scope inline assembly"); + OutStreamer.AddBlankLine(); + OutStreamer.EmitRawText(StringRef(M.getModuleInlineAsm())); + OutStreamer.AddBlankLine(); + OutStreamer.AddComment("End of file scope inline assembly"); + OutStreamer.AddBlankLine(); + } + if (nvptxSubtarget.getDrvInterface() == NVPTX::CUDA) recordAndEmitFilenames(M); @@ -1222,12 +1215,11 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, else O << getPTXFundamentalTypeStr(ETy, false); O << " "; - O << *Mang->getSymbol(GVar); + O << *getSymbol(GVar); // Ptx allows variable initialization only for constant and global state // spaces. if (((PTy->getAddressSpace() == llvm::ADDRESS_SPACE_GLOBAL) || - (PTy->getAddressSpace() == llvm::ADDRESS_SPACE_CONST_NOT_GEN) || (PTy->getAddressSpace() == llvm::ADDRESS_SPACE_CONST)) && GVar->hasInitializer()) { const Constant *Initializer = GVar->getInitializer(); @@ -1251,7 +1243,6 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, // Ptx allows variable initialization only for constant and // global state spaces.
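// Aside: a standalone plain-C++ sketch (stand-in names, not LLVM API) of the
// aggregate-declaration logic in the hunk that follows: an initializer that
// contains symbol addresses is emitted as an array of pointer-sized words
// (.u64 or .u32), anything else as raw bytes (.b8).
#include <cstdio>
static void declareInitializedGlobal(const char *Name, unsigned ElementSize,
                                     bool InitHasSymbols, bool Is64Bit) {
  if (InitHasSymbols)
    std::printf(" .u%u %s[%u]", Is64Bit ? 64u : 32u, Name,
                ElementSize / (Is64Bit ? 8u : 4u));
  else
    std::printf(" .b8 %s[%u]", Name, ElementSize);
}
// e.g. declareInitializedGlobal("gv", 16, true, true) prints " .u64 gv[2]".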
if (((PTy->getAddressSpace() == llvm::ADDRESS_SPACE_GLOBAL) || - (PTy->getAddressSpace() == llvm::ADDRESS_SPACE_CONST_NOT_GEN) || (PTy->getAddressSpace() == llvm::ADDRESS_SPACE_CONST)) && GVar->hasInitializer()) { const Constant *Initializer = GVar->getInitializer(); @@ -1260,15 +1251,15 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, bufferAggregateConstant(Initializer, &aggBuffer); if (aggBuffer.numSymbols) { if (nvptxSubtarget.is64Bit()) { - O << " .u64 " << *Mang->getSymbol(GVar) << "["; + O << " .u64 " << *getSymbol(GVar) << "["; O << ElementSize / 8; } else { - O << " .u32 " << *Mang->getSymbol(GVar) << "["; + O << " .u32 " << *getSymbol(GVar) << "["; O << ElementSize / 4; } O << "]"; } else { - O << " .b8 " << *Mang->getSymbol(GVar) << "["; + O << " .b8 " << *getSymbol(GVar) << "["; O << ElementSize; O << "]"; } @@ -1276,7 +1267,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, aggBuffer.print(); O << "}"; } else { - O << " .b8 " << *Mang->getSymbol(GVar); + O << " .b8 " << *getSymbol(GVar); if (ElementSize) { O << "["; O << ElementSize; @@ -1284,7 +1275,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, } } } else { - O << " .b8 " << *Mang->getSymbol(GVar); + O << " .b8 " << *getSymbol(GVar); if (ElementSize) { O << "["; O << ElementSize; @@ -1322,14 +1313,6 @@ void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace, O << "global"; break; case llvm::ADDRESS_SPACE_CONST: - // This logic should be consistent with that in - // getCodeAddrSpace() (NVPTXISelDATToDAT.cpp) - if (nvptxSubtarget.hasGenericLdSt()) - O << "global"; - else - O << "const"; - break; - case llvm::ADDRESS_SPACE_CONST_NOT_GEN: O << "const"; break; case llvm::ADDRESS_SPACE_SHARED: @@ -1399,7 +1382,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, O << " ."; O << getPTXFundamentalTypeStr(ETy); O << " "; - O << *Mang->getSymbol(GVar); + O << *getSymbol(GVar); return; } @@ -1414,7 +1397,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, case Type::ArrayTyID: case Type::VectorTyID: ElementSize = TD->getTypeStoreSize(ETy); - O << " .b8 " << *Mang->getSymbol(GVar) << "["; + O << " .b8 " << *getSymbol(GVar) << "["; if (ElementSize) { O << itostr(ElementSize); } @@ -1469,7 +1452,7 @@ void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I, int paramIndex, raw_ostream &O) { if ((nvptxSubtarget.getDrvInterface() == NVPTX::NVCL) || (nvptxSubtarget.getDrvInterface() == NVPTX::CUDA)) - O << *Mang->getSymbol(I->getParent()) << "_param_" << paramIndex; + O << *getSymbol(I->getParent()) << "_param_" << paramIndex; else { std::string argName = I->getName(); const char *p = argName.c_str(); @@ -1528,13 +1511,13 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { if (llvm::isImage(*I)) { std::string sname = I->getName(); if (llvm::isImageWriteOnly(*I)) - O << "\t.param .surfref " << *Mang->getSymbol(F) << "_param_" + O << "\t.param .surfref " << *getSymbol(F) << "_param_" << paramIndex; else // Default image is read_only - O << "\t.param .texref " << *Mang->getSymbol(F) << "_param_" + O << "\t.param .texref " << *getSymbol(F) << "_param_" << paramIndex; } else // Should be llvm::isSampler(*I) - O << "\t.param .samplerref " << *Mang->getSymbol(F) << "_param_" + O << "\t.param .samplerref " << *getSymbol(F) << "_param_" << paramIndex; continue; } @@ -1569,14 +1552,13 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { default: O 
<< ".ptr "; break; - case llvm::ADDRESS_SPACE_CONST_NOT_GEN: + case llvm::ADDRESS_SPACE_CONST: O << ".ptr .const "; break; case llvm::ADDRESS_SPACE_SHARED: O << ".ptr .shared "; break; case llvm::ADDRESS_SPACE_GLOBAL: - case llvm::ADDRESS_SPACE_CONST: O << ".ptr .global "; break; } @@ -1709,48 +1691,36 @@ void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters( for (unsigned i = 0; i < numVRs; i++) { unsigned int vr = TRI->index2VirtReg(i); const TargetRegisterClass *RC = MRI->getRegClass(vr); - std::map<unsigned, unsigned> ®map = VRidGlobal2LocalMap[RC->getID()]; + DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; int n = regmap.size(); regmap.insert(std::make_pair(vr, n + 1)); } // Emit register declarations // @TODO: Extract out the real register usage - O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n"; - O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n"; - O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n"; - O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n"; - O << "\t.reg .s64 %rl<" << NVPTXNumRegisters << ">;\n"; - O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n"; - O << "\t.reg .f64 %fl<" << NVPTXNumRegisters << ">;\n"; + // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n"; + // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n"; + // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n"; + // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n"; + // O << "\t.reg .s64 %rl<" << NVPTXNumRegisters << ">;\n"; + // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n"; + // O << "\t.reg .f64 %fl<" << NVPTXNumRegisters << ">;\n"; // Emit declaration of the virtual registers or 'physical' registers for // each register class - //for (unsigned i=0; i< numRegClasses; i++) { - // std::map<unsigned, unsigned> ®map = VRidGlobal2LocalMap[i]; - // const TargetRegisterClass *RC = TRI->getRegClass(i); - // std::string rcname = getNVPTXRegClassName(RC); - // std::string rcStr = getNVPTXRegClassStr(RC); - // //int n = regmap.size(); - // if (!isNVPTXVectorRegClass(RC)) { - // O << "\t.reg " << rcname << " \t" << rcStr << "<" - // << NVPTXNumRegisters << ">;\n"; - // } - - // Only declare those registers that may be used. And do not emit vector - // registers as - // they are all elementized to scalar registers. - //if (n && !isNVPTXVectorRegClass(RC)) { - // if (RegAllocNilUsed) { - // O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1) - // << ">;\n"; - // } - // else { - // O << "\t.reg " << rcname << " \t" << StrToUpper(rcStr) - // << "<" << 32 << ">;\n"; - // } - //} - //} + for (unsigned i=0; i< TRI->getNumRegClasses(); i++) { + const TargetRegisterClass *RC = TRI->getRegClass(i); + DenseMap<unsigned, unsigned> ®map = VRegMapping[RC]; + std::string rcname = getNVPTXRegClassName(RC); + std::string rcStr = getNVPTXRegClassStr(RC); + int n = regmap.size(); + + // Only declare those registers that may be used. 
+ if (n) { + O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1) + << ">;\n"; + } + } OutStreamer.EmitRawText(O.str()); } @@ -1794,13 +1764,13 @@ void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) { return; } if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) { - O << *Mang->getSymbol(GVar); + O << *getSymbol(GVar); return; } if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) { const Value *v = Cexpr->stripPointerCasts(); if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { - O << *Mang->getSymbol(GVar); + O << *getSymbol(GVar); return; } else { O << *LowerConstant(CPV, *this); @@ -1918,7 +1888,7 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes, case Type::VectorTyID: case Type::StructTyID: { if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV) || - isa<ConstantStruct>(CPV)) { + isa<ConstantStruct>(CPV) || isa<ConstantDataSequential>(CPV)) { int ElementSize = TD->getTypeAllocSize(CPV->getType()); bufferAggregateConstant(CPV, aggBuffer); if (Bytes > ElementSize) @@ -1991,41 +1961,6 @@ bool NVPTXAsmPrinter::isImageType(const Type *Ty) { return false; } -/// PrintAsmOperand - Print out an operand for an inline asm expression. -/// -bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, - unsigned AsmVariant, - const char *ExtraCode, raw_ostream &O) { - if (ExtraCode && ExtraCode[0]) { - if (ExtraCode[1] != 0) - return true; // Unknown modifier. - - switch (ExtraCode[0]) { - default: - // See if this is a generic print operand - return AsmPrinter::PrintAsmOperand(MI, OpNo, AsmVariant, ExtraCode, O); - case 'r': - break; - } - } - - printOperand(MI, OpNo, O); - - return false; -} - -bool NVPTXAsmPrinter::PrintAsmMemoryOperand( - const MachineInstr *MI, unsigned OpNo, unsigned AsmVariant, - const char *ExtraCode, raw_ostream &O) { - if (ExtraCode && ExtraCode[0]) - return true; // Unknown modifier - - O << '['; - printMemOperand(MI, OpNo, O); - O << ']'; - - return false; -} bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) { switch (MI.getOpcode()) { @@ -2040,7 +1975,6 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) { case NVPTX::CallArgI32: case NVPTX::CallArgI32imm: case NVPTX::CallArgI64: - case NVPTX::CallArgI8: case NVPTX::CallArgParam: case NVPTX::CallVoidInst: case NVPTX::CallVoidInstReg: @@ -2058,10 +1992,6 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) { case NVPTX::StoreParamI32: case NVPTX::StoreParamI64: case NVPTX::StoreParamI8: - case NVPTX::StoreParamS32I8: - case NVPTX::StoreParamU32I8: - case NVPTX::StoreParamS32I16: - case NVPTX::StoreParamU32I16: case NVPTX::StoreRetvalF32: case NVPTX::StoreRetvalF64: case NVPTX::StoreRetvalI16: @@ -2074,7 +2004,6 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) { case NVPTX::LastCallArgI32: case NVPTX::LastCallArgI32imm: case NVPTX::LastCallArgI64: - case NVPTX::LastCallArgI8: case NVPTX::LastCallArgParam: case NVPTX::LoadParamMemF32: case NVPTX::LoadParamMemF64: @@ -2082,12 +2011,6 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) { case NVPTX::LoadParamMemI32: case NVPTX::LoadParamMemI64: case NVPTX::LoadParamMemI8: - case NVPTX::LoadParamRegF32: - case NVPTX::LoadParamRegF64: - case NVPTX::LoadParamRegI16: - case NVPTX::LoadParamRegI32: - case NVPTX::LoadParamRegI64: - case NVPTX::LoadParamRegI8: case NVPTX::PrototypeInst: case NVPTX::DBG_VALUE: return true; @@ -2095,6 +2018,116 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) { return false; } +/// PrintAsmOperand - Print out an 
operand for an inline asm expression. +/// +bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, + unsigned AsmVariant, + const char *ExtraCode, raw_ostream &O) { + if (ExtraCode && ExtraCode[0]) { + if (ExtraCode[1] != 0) + return true; // Unknown modifier. + + switch (ExtraCode[0]) { + default: + // See if this is a generic print operand + return AsmPrinter::PrintAsmOperand(MI, OpNo, AsmVariant, ExtraCode, O); + case 'r': + break; + } + } + + printOperand(MI, OpNo, O); + + return false; +} + +bool NVPTXAsmPrinter::PrintAsmMemoryOperand( + const MachineInstr *MI, unsigned OpNo, unsigned AsmVariant, + const char *ExtraCode, raw_ostream &O) { + if (ExtraCode && ExtraCode[0]) + return true; // Unknown modifier + + O << '['; + printMemOperand(MI, OpNo, O); + O << ']'; + + return false; +} + +void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, + raw_ostream &O, const char *Modifier) { + const MachineOperand &MO = MI->getOperand(opNum); + switch (MO.getType()) { + case MachineOperand::MO_Register: + if (TargetRegisterInfo::isPhysicalRegister(MO.getReg())) { + if (MO.getReg() == NVPTX::VRDepot) + O << DEPOTNAME << getFunctionNumber(); + else + O << NVPTXInstPrinter::getRegisterName(MO.getReg()); + } else { + emitVirtualRegister(MO.getReg(), O); + } + return; + + case MachineOperand::MO_Immediate: + if (!Modifier) + O << MO.getImm(); + else if (strstr(Modifier, "vec") == Modifier) + printVecModifiedImmediate(MO, Modifier, O); + else + llvm_unreachable( + "Don't know how to handle modifier on immediate operand"); + return; + + case MachineOperand::MO_FPImmediate: + printFPConstant(MO.getFPImm(), O); + break; + + case MachineOperand::MO_GlobalAddress: + O << *getSymbol(MO.getGlobal()); + break; + + case MachineOperand::MO_ExternalSymbol: { + const char *symbname = MO.getSymbolName(); + if (strstr(symbname, ".PARAM") == symbname) { + unsigned index; + sscanf(symbname + 6, "%u[];", &index); + printParamName(index, O); + } else if (strstr(symbname, ".HLPPARAM") == symbname) { + unsigned index; + sscanf(symbname + 9, "%u[];", &index); + O << *CurrentFnSym << "_param_" << index << "_offset"; + } else + O << symbname; + break; + } + + case MachineOperand::MO_MachineBasicBlock: + O << *MO.getMBB()->getSymbol(); + return; + + default: + llvm_unreachable("Operand type not supported."); + } +} + +void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum, + raw_ostream &O, const char *Modifier) { + printOperand(MI, opNum, O); + + if (Modifier && !strcmp(Modifier, "add")) { + O << ", "; + printOperand(MI, opNum + 1, O); + } else { + if (MI->getOperand(opNum + 1).isImm() && + MI->getOperand(opNum + 1).getImm() == 0) + return; // don't print ',0' or '+0' + O << "+"; + printOperand(MI, opNum + 1, O); + } +} + + // Force static initialization. 
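// Aside: the "force static initialization" idiom used just below, modeled
// standalone with stand-in types: RegisterAsmPrinter<T> is a small object
// whose constructor plants a factory in the target registry, so merely
// constructing one inside the initializer function wires the printer up.
#include <cstdio>
typedef void *(*PrinterCtorSim)();
static PrinterCtorSim RegisteredCtorSim; // stand-in for the registry slot
template <class PrinterT> struct RegisterPrinterSim {
  RegisterPrinterSim() {
    RegisteredCtorSim = []() -> void * { return new PrinterT(); };
  }
};
struct FakePrinter { FakePrinter() { std::puts("printer constructed"); } };
extern "C" void InitializeBackendSim() {
  RegisterPrinterSim<FakePrinter> X; // constructor performs the registration
  (void)X;
}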
extern "C" void LLVMInitializeNVPTXBackendAsmPrinter() { RegisterAsmPrinter<NVPTXAsmPrinter> X(TheNVPTXTarget32); diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h index 7faa6b2..3abe5d1 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h @@ -155,7 +155,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { if (pos == nextSymbolPos) { const Value *v = Symbols[nSym]; if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { - MCSymbol *Name = AP.Mang->getSymbol(GVar); + MCSymbol *Name = AP.getSymbol(GVar); O << *Name; } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(v)) { O << *nvptx::LowerConstant(Cexpr, AP); @@ -188,16 +188,17 @@ private: void EmitFunctionEntryLabel(); void EmitFunctionBodyStart(); void EmitFunctionBodyEnd(); + void emitImplicitDef(const MachineInstr *MI) const; void EmitInstruction(const MachineInstr *); + void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI); + bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp); + MCOperand GetSymbolRef(const MachineOperand &MO, const MCSymbol *Symbol); + unsigned encodeVirtualRegister(unsigned Reg); void EmitAlignment(unsigned NumBits, const GlobalValue *GV = 0) const {} void printGlobalVariable(const GlobalVariable *GVar); - void printOperand(const MachineInstr *MI, int opNum, raw_ostream &O, - const char *Modifier = 0); - void printLdStCode(const MachineInstr *MI, int opNum, raw_ostream &O, - const char *Modifier = 0); void printVecModifiedImmediate(const MachineOperand &MO, const char *Modifier, raw_ostream &O); void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &O, @@ -213,22 +214,23 @@ private: void emitGlobals(const Module &M); void emitHeader(Module &M, raw_ostream &O); void emitKernelFunctionDirectives(const Function &F, raw_ostream &O) const; - void emitVirtualRegister(unsigned int vr, bool isVec, raw_ostream &O); + void emitVirtualRegister(unsigned int vr, raw_ostream &); void emitFunctionExternParamList(const MachineFunction &MF); void emitFunctionParamList(const Function *, raw_ostream &O); void emitFunctionParamList(const MachineFunction &MF, raw_ostream &O); void setAndEmitFunctionVirtualRegisters(const MachineFunction &MF); void emitFunctionTempData(const MachineFunction &MF, unsigned &FrameSize); bool isImageType(const Type *Ty); + void printReturnValStr(const Function *, raw_ostream &O); + void printReturnValStr(const MachineFunction &MF, raw_ostream &O); bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, unsigned AsmVariant, const char *ExtraCode, raw_ostream &); + void printOperand(const MachineInstr *MI, int opNum, raw_ostream &O, + const char *Modifier = 0); bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo, unsigned AsmVariant, const char *ExtraCode, raw_ostream &); - void printReturnValStr(const Function *, raw_ostream &O); - void printReturnValStr(const MachineFunction &MF, raw_ostream &O); - protected: bool doInitialization(Module &M); bool doFinalization(Module &M); @@ -243,7 +245,9 @@ private: // The contents are specific for each // MachineFunction. But the size of the // array is not. - std::map<unsigned, unsigned> *VRidGlobal2LocalMap; + typedef DenseMap<unsigned, unsigned> VRegMap; + typedef DenseMap<const TargetRegisterClass *, VRegMap> VRegRCMap; + VRegRCMap VRegMapping; // cache the subtarget here. 
const NVPTXSubtarget &nvptxSubtarget; // Build the map between type name and ID based on module's type @@ -281,7 +285,6 @@ public: : AsmPrinter(TM, Streamer), nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) { CurrentBankselLabelInBasicBlock = ""; - VRidGlobal2LocalMap = NULL; reader = NULL; } @@ -292,7 +295,7 @@ public: bool ignoreLoc(const MachineInstr &); - virtual void getVirtualRegisterName(unsigned, bool, raw_ostream &); + std::string getVirtualRegisterName(unsigned) const; DebugLoc prevDebugLoc; void emitLineNumberAsDotLoc(const MachineInstr &); diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp index 6533da5..9030584f 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp @@ -20,6 +20,7 @@ #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/MC/MachineLocation.h" #include "llvm/Target/TargetInstrInfo.h" @@ -36,30 +37,24 @@ void NVPTXFrameLowering::emitPrologue(MachineFunction &MF) const { // in the BB, so giving it no debug location. DebugLoc dl = DebugLoc(); - if (tm.getSubtargetImpl()->hasGenericLdSt()) { - // mov %SPL, %depot; - // cvta.local %SP, %SPL; - if (is64bit) { - MachineInstr *MI = BuildMI( - MBB, MBBI, dl, tm.getInstrInfo()->get(NVPTX::cvta_local_yes_64), - NVPTX::VRFrame).addReg(NVPTX::VRFrameLocal); - BuildMI(MBB, MI, dl, tm.getInstrInfo()->get(NVPTX::IMOV64rr), - NVPTX::VRFrameLocal).addReg(NVPTX::VRDepot); - } else { - MachineInstr *MI = BuildMI( - MBB, MBBI, dl, tm.getInstrInfo()->get(NVPTX::cvta_local_yes), - NVPTX::VRFrame).addReg(NVPTX::VRFrameLocal); - BuildMI(MBB, MI, dl, tm.getInstrInfo()->get(NVPTX::IMOV32rr), - NVPTX::VRFrameLocal).addReg(NVPTX::VRDepot); - } + MachineRegisterInfo &MRI = MF.getRegInfo(); + + // mov %SPL, %depot; + // cvta.local %SP, %SPL; + if (is64bit) { + unsigned LocalReg = MRI.createVirtualRegister(&NVPTX::Int64RegsRegClass); + MachineInstr *MI = BuildMI( + MBB, MBBI, dl, tm.getInstrInfo()->get(NVPTX::cvta_local_yes_64), + NVPTX::VRFrame).addReg(LocalReg); + BuildMI(MBB, MI, dl, tm.getInstrInfo()->get(NVPTX::MOV_DEPOT_ADDR_64), + LocalReg).addImm(MF.getFunctionNumber()); } else { - // mov %SP, %depot; - if (is64bit) - BuildMI(MBB, MBBI, dl, tm.getInstrInfo()->get(NVPTX::IMOV64rr), - NVPTX::VRFrame).addReg(NVPTX::VRDepot); - else - BuildMI(MBB, MBBI, dl, tm.getInstrInfo()->get(NVPTX::IMOV32rr), - NVPTX::VRFrame).addReg(NVPTX::VRDepot); + unsigned LocalReg = MRI.createVirtualRegister(&NVPTX::Int32RegsRegClass); + MachineInstr *MI = BuildMI( + MBB, MBBI, dl, tm.getInstrInfo()->get(NVPTX::cvta_local_yes), + NVPTX::VRFrame).addReg(LocalReg); + BuildMI(MBB, MI, dl, tm.getInstrInfo()->get(NVPTX::MOV_DEPOT_ADDR), + LocalReg).addImm(MF.getFunctionNumber()); } } } diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp index 1077c46..9fb0dd8 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp @@ -142,7 +142,7 @@ bool GenericToNVVM::runOnModule(Module &M) { GlobalVariable *GV = I->first; GlobalVariable *NewGV = I->second; ++I; - Constant *BitCastNewGV = ConstantExpr::getBitCast(NewGV, GV->getType()); + Constant *BitCastNewGV = ConstantExpr::getPointerCast(NewGV, GV->getType()); // At this point, the remaining uses of GV should be found 
only in global // variable initializers, as other uses have already been removed // while walking through the instructions in function definitions. @@ -384,7 +384,7 @@ void GenericToNVVM::remapNamedMDNode(Module *M, NamedMDNode *N) { // Replace the old operands with the new operands. N->dropAllReferences(); - for (SmallVector<MDNode *, 16>::iterator I = NewOperands.begin(), + for (SmallVectorImpl<MDNode *>::iterator I = NewOperands.begin(), E = NewOperands.end(); I != E; ++I) { N->addOperand(*I); diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index b0dfca3..4b8b306 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -25,28 +25,29 @@ using namespace llvm; -static cl::opt<bool> UseFMADInstruction( - "nvptx-mad-enable", cl::ZeroOrMore, - cl::desc("NVPTX Specific: Enable generating FMAD instructions"), - cl::init(false)); - static cl::opt<int> -FMAContractLevel("nvptx-fma-level", cl::ZeroOrMore, +FMAContractLevel("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden, cl::desc("NVPTX Specific: FMA contraction (0: don't do it" " 1: do it 2: do it aggressively)"), cl::init(2)); static cl::opt<int> UsePrecDivF32( - "nvptx-prec-divf32", cl::ZeroOrMore, + "nvptx-prec-divf32", cl::ZeroOrMore, cl::Hidden, cl::desc("NVPTX Specific: 0 use div.approx, 1 use div.full, 2 use" " IEEE Compliant F32 div.rnd if available."), cl::init(2)); static cl::opt<bool> -UsePrecSqrtF32("nvptx-prec-sqrtf32", +UsePrecSqrtF32("nvptx-prec-sqrtf32", cl::Hidden, cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."), cl::init(true)); +static cl::opt<bool> +FtzEnabled("nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden, + cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."), + cl::init(false)); + + /// createNVPTXISelDag - This pass converts a legalized DAG into a /// NVPTX-specific DAG, ready for instruction scheduling. FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM, @@ -58,12 +59,7 @@ NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm, CodeGenOpt::Level OptLevel) : SelectionDAGISel(tm, OptLevel), Subtarget(tm.getSubtarget<NVPTXSubtarget>()) { - // Always do fma.f32 fpcontract if the target supports the instruction. - // Always do fma.f64 fpcontract if the target supports the instruction. - // Do mad.f32 if nvptx-mad-enable is specified and the target does not - // support fma.f32. - doFMADF32 = (OptLevel > 0) && UseFMADInstruction && !Subtarget.hasFMAF32(); doFMAF32 = (OptLevel > 0) && Subtarget.hasFMAF32() && (FMAContractLevel >= 1); doFMAF64 = (OptLevel > 0) && Subtarget.hasFMAF64() && (FMAContractLevel >= 1); doFMAF32AGG = @@ -71,20 +67,51 @@ NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm, doFMAF64AGG = (OptLevel > 0) && Subtarget.hasFMAF64() && (FMAContractLevel == 2); - allowFMA = (FMAContractLevel >= 1) || UseFMADInstruction; - - UseF32FTZ = false; + allowFMA = (FMAContractLevel >= 1); doMulWide = (OptLevel > 0); +} - // Decide how to translate f32 div - do_DIVF32_PREC = UsePrecDivF32; - // Decide how to translate f32 sqrt - do_SQRTF32_PREC = UsePrecSqrtF32; - // sm less than sm_20 does not support div.rnd. Use div.full.
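// Aside: the resolution pattern implemented by getDivF32Level(),
// usePrecSqrtF32() and useF32FTZ() in the hunk below, standalone: an option
// given explicitly on the command line always wins (getNumOccurrences() > 0);
// otherwise the choice is derived from the fast-math state. The names here
// are simplified stand-ins, not the LLVM cl::opt API.
static int resolveDivF32Level(bool OptionSeen, int OptionValue,
                              bool UnsafeFPMath) {
  if (OptionSeen)
    return OptionValue;   // -nvptx-prec-divf32=N was given: honor it
  return UnsafeFPMath ? 0 // fast math: use div.approx
                      : 2; // default: IEEE-compliant div.rnd
}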
- if (do_DIVF32_PREC == 2 && !Subtarget.reqPTX20()) - do_DIVF32_PREC = 1; +int NVPTXDAGToDAGISel::getDivF32Level() const { + if (UsePrecDivF32.getNumOccurrences() > 0) { + // If nvptx-prec-div32=N is used on the command-line, always honor it + return UsePrecDivF32; + } else { + // Otherwise, use div.approx if fast math is enabled + if (TM.Options.UnsafeFPMath) + return 0; + else + return 2; + } +} + +bool NVPTXDAGToDAGISel::usePrecSqrtF32() const { + if (UsePrecSqrtF32.getNumOccurrences() > 0) { + // If nvptx-prec-sqrtf32 is used on the command-line, always honor it + return UsePrecSqrtF32; + } else { + // Otherwise, use sqrt.approx if fast math is enabled + if (TM.Options.UnsafeFPMath) + return false; + else + return true; + } +} +bool NVPTXDAGToDAGISel::useF32FTZ() const { + if (FtzEnabled.getNumOccurrences() > 0) { + // If nvptx-f32ftz is used on the command-line, always honor it + return FtzEnabled; + } else { + const Function *F = MF->getFunction(); + // Otherwise, check for an nvptx-f32ftz attribute on the function + if (F->hasFnAttribute("nvptx-f32ftz")) + return (F->getAttributes().getAttribute(AttributeSet::FunctionIndex, + "nvptx-f32ftz") + .getValueAsString() == "true"); + else + return false; + } } /// Select - Select instructions not customized! Used for @@ -118,6 +145,23 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) { case NVPTXISD::StoreV4: ResNode = SelectStoreVector(N); break; + case NVPTXISD::LoadParam: + case NVPTXISD::LoadParamV2: + case NVPTXISD::LoadParamV4: + ResNode = SelectLoadParam(N); + break; + case NVPTXISD::StoreRetval: + case NVPTXISD::StoreRetvalV2: + case NVPTXISD::StoreRetvalV4: + ResNode = SelectStoreRetval(N); + break; + case NVPTXISD::StoreParam: + case NVPTXISD::StoreParamV2: + case NVPTXISD::StoreParamV4: + case NVPTXISD::StoreParamS32: + case NVPTXISD::StoreParamU32: + ResNode = SelectStoreParam(N); + break; default: break; } @@ -129,42 +173,26 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) { static unsigned int getCodeAddrSpace(MemSDNode *N, const NVPTXSubtarget &Subtarget) { const Value *Src = N->getSrcValue(); + if (!Src) - return NVPTX::PTXLdStInstCode::LOCAL; + return NVPTX::PTXLdStInstCode::GENERIC; if (const PointerType *PT = dyn_cast<PointerType>(Src->getType())) { switch (PT->getAddressSpace()) { - case llvm::ADDRESS_SPACE_LOCAL: - return NVPTX::PTXLdStInstCode::LOCAL; - case llvm::ADDRESS_SPACE_GLOBAL: - return NVPTX::PTXLdStInstCode::GLOBAL; - case llvm::ADDRESS_SPACE_SHARED: - return NVPTX::PTXLdStInstCode::SHARED; - case llvm::ADDRESS_SPACE_CONST_NOT_GEN: - return NVPTX::PTXLdStInstCode::CONSTANT; - case llvm::ADDRESS_SPACE_GENERIC: - return NVPTX::PTXLdStInstCode::GENERIC; - case llvm::ADDRESS_SPACE_PARAM: - return NVPTX::PTXLdStInstCode::PARAM; - case llvm::ADDRESS_SPACE_CONST: - // If the arch supports generic address space, translate it to GLOBAL - // for correctness. - // If the arch does not support generic address space, then the arch - // does not really support ADDRESS_SPACE_CONST, translate it to - // to CONSTANT for better performance. 
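// Aside: the mapping that the rewritten getCodeAddrSpace (surrounding hunk)
// and printLdStCode (earlier in this diff) implement between them, collapsed
// into one standalone step. Numeric values follow the AddressSpace enum in
// NVPTXBaseInfo.h shown above, plus ADDRESS_SPACE_PARAM (101). Note the
// behavioral change in this hunk: a missing or unknown source value now
// conservatively maps to GENERIC rather than LOCAL.
static const char *ldStQualifier(unsigned AddrSpace) {
  switch (AddrSpace) {
  case 1:   return ".global"; // ADDRESS_SPACE_GLOBAL
  case 3:   return ".shared"; // ADDRESS_SPACE_SHARED
  case 4:   return ".const";  // ADDRESS_SPACE_CONST
  case 5:   return ".local";  // ADDRESS_SPACE_LOCAL
  case 101: return ".param";  // ADDRESS_SPACE_PARAM
  default:  return "";        // generic: no qualifier on ld/st
  }
}
// e.g. ldStQualifier(1) yields ".global", as in "ld.global.u32 %r1, [gv];".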
- if (Subtarget.hasGenericLdSt()) - return NVPTX::PTXLdStInstCode::GLOBAL; - else - return NVPTX::PTXLdStInstCode::CONSTANT; - default: - break; + case llvm::ADDRESS_SPACE_LOCAL: return NVPTX::PTXLdStInstCode::LOCAL; + case llvm::ADDRESS_SPACE_GLOBAL: return NVPTX::PTXLdStInstCode::GLOBAL; + case llvm::ADDRESS_SPACE_SHARED: return NVPTX::PTXLdStInstCode::SHARED; + case llvm::ADDRESS_SPACE_GENERIC: return NVPTX::PTXLdStInstCode::GENERIC; + case llvm::ADDRESS_SPACE_PARAM: return NVPTX::PTXLdStInstCode::PARAM; + case llvm::ADDRESS_SPACE_CONST: return NVPTX::PTXLdStInstCode::CONSTANT; + default: break; } } - return NVPTX::PTXLdStInstCode::LOCAL; + return NVPTX::PTXLdStInstCode::GENERIC; } SDNode *NVPTXDAGToDAGISel::SelectLoad(SDNode *N) { - DebugLoc dl = N->getDebugLoc(); + SDLoc dl(N); LoadSDNode *LD = cast<LoadSDNode>(N); EVT LoadedVT = LD->getMemoryVT(); SDNode *NVPTXLD = NULL; @@ -207,7 +235,8 @@ SDNode *NVPTXDAGToDAGISel::SelectLoad(SDNode *N) { // type is integer // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float MVT ScalarVT = SimpleVT.getScalarType(); - unsigned fromTypeWidth = ScalarVT.getSizeInBits(); + // Read at least 8 bits (predicates are stored as 8-bit values) + unsigned fromTypeWidth = std::max(8U, ScalarVT.getSizeInBits()); unsigned int fromType; if ((LD->getExtensionType() == ISD::SEXTLOAD)) fromType = NVPTX::PTXLdStInstCode::Signed; @@ -222,7 +251,7 @@ SDNode *NVPTXDAGToDAGISel::SelectLoad(SDNode *N) { SDValue Addr; SDValue Offset, Base; unsigned Opcode; - MVT::SimpleValueType TargetVT = LD->getValueType(0).getSimpleVT().SimpleTy; + MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; if (SelectDirectAddr(N1, Addr)) { switch (TargetVT) { @@ -403,7 +432,7 @@ SDNode *NVPTXDAGToDAGISel::SelectLoadVector(SDNode *N) { SDValue Op1 = N->getOperand(1); SDValue Addr, Offset, Base; unsigned Opcode; - DebugLoc DL = N->getDebugLoc(); + SDLoc DL(N); SDNode *LD; MemSDNode *MemSD = cast<MemSDNode>(N); EVT LoadedVT = MemSD->getMemoryVT(); @@ -432,7 +461,8 @@ SDNode *NVPTXDAGToDAGISel::SelectLoadVector(SDNode *N) { // type is integer // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float MVT ScalarVT = SimpleVT.getScalarType(); - unsigned FromTypeWidth = ScalarVT.getSizeInBits(); + // Read at least 8 bits (predicates are stored as 8-bit values) + unsigned FromTypeWidth = std::max(8U, ScalarVT.getSizeInBits()); unsigned int FromType; // The last operand holds the original LoadSDNode::getExtensionType() value unsigned ExtensionType = cast<ConstantSDNode>( @@ -784,194 +814,478 @@ SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { SDValue Chain = N->getOperand(0); SDValue Op1 = N->getOperand(1); unsigned Opcode; - DebugLoc DL = N->getDebugLoc(); + SDLoc DL(N); SDNode *LD; + MemSDNode *Mem = cast<MemSDNode>(N); + SDValue Base, Offset, Addr; - EVT RetVT = N->getValueType(0); + EVT EltVT = Mem->getMemoryVT().getVectorElementType(); - // Select opcode - if (Subtarget.is64Bit()) { + if (SelectDirectAddr(Op1, Addr)) { switch (N->getOpcode()) { default: return NULL; case NVPTXISD::LDGV2: - switch (RetVT.getSimpleVT().SimpleTy) { + switch (EltVT.getSimpleVT().SimpleTy) { default: return NULL; case MVT::i8: - Opcode = NVPTX::INT_PTX_LDG_G_v2i8_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v2i8_ELE_avar; break; case MVT::i16: - Opcode = NVPTX::INT_PTX_LDG_G_v2i16_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v2i16_ELE_avar; break; case MVT::i32: - Opcode = NVPTX::INT_PTX_LDG_G_v2i32_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v2i32_ELE_avar; break; case MVT::i64: - Opcode 
= NVPTX::INT_PTX_LDG_G_v2i64_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v2i64_ELE_avar; break; case MVT::f32: - Opcode = NVPTX::INT_PTX_LDG_G_v2f32_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v2f32_ELE_avar; break; case MVT::f64: - Opcode = NVPTX::INT_PTX_LDG_G_v2f64_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v2f64_ELE_avar; break; } break; - case NVPTXISD::LDGV4: - switch (RetVT.getSimpleVT().SimpleTy) { + case NVPTXISD::LDUV2: + switch (EltVT.getSimpleVT().SimpleTy) { default: return NULL; case MVT::i8: - Opcode = NVPTX::INT_PTX_LDG_G_v4i8_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v2i8_ELE_avar; break; case MVT::i16: - Opcode = NVPTX::INT_PTX_LDG_G_v4i16_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v2i16_ELE_avar; break; case MVT::i32: - Opcode = NVPTX::INT_PTX_LDG_G_v4i32_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v2i32_ELE_avar; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_G_v2i64_ELE_avar; break; case MVT::f32: - Opcode = NVPTX::INT_PTX_LDG_G_v4f32_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v2f32_ELE_avar; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_G_v2f64_ELE_avar; break; } break; - case NVPTXISD::LDUV2: - switch (RetVT.getSimpleVT().SimpleTy) { + case NVPTXISD::LDGV4: + switch (EltVT.getSimpleVT().SimpleTy) { default: return NULL; case MVT::i8: - Opcode = NVPTX::INT_PTX_LDU_G_v2i8_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v4i8_ELE_avar; break; case MVT::i16: - Opcode = NVPTX::INT_PTX_LDU_G_v2i16_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v4i16_ELE_avar; break; case MVT::i32: - Opcode = NVPTX::INT_PTX_LDU_G_v2i32_ELE_64; - break; - case MVT::i64: - Opcode = NVPTX::INT_PTX_LDU_G_v2i64_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v4i32_ELE_avar; break; case MVT::f32: - Opcode = NVPTX::INT_PTX_LDU_G_v2f32_ELE_64; - break; - case MVT::f64: - Opcode = NVPTX::INT_PTX_LDU_G_v2f64_ELE_64; + Opcode = NVPTX::INT_PTX_LDG_G_v4f32_ELE_avar; break; } break; case NVPTXISD::LDUV4: - switch (RetVT.getSimpleVT().SimpleTy) { + switch (EltVT.getSimpleVT().SimpleTy) { default: return NULL; case MVT::i8: - Opcode = NVPTX::INT_PTX_LDU_G_v4i8_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v4i8_ELE_avar; break; case MVT::i16: - Opcode = NVPTX::INT_PTX_LDU_G_v4i16_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v4i16_ELE_avar; break; case MVT::i32: - Opcode = NVPTX::INT_PTX_LDU_G_v4i32_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v4i32_ELE_avar; break; case MVT::f32: - Opcode = NVPTX::INT_PTX_LDU_G_v4f32_ELE_64; + Opcode = NVPTX::INT_PTX_LDU_G_v4f32_ELE_avar; break; } break; } - } else { - switch (N->getOpcode()) { - default: - return NULL; - case NVPTXISD::LDGV2: - switch (RetVT.getSimpleVT().SimpleTy) { + + SDValue Ops[] = { Addr, Chain }; + LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), + ArrayRef<SDValue>(Ops, 2)); + } else if (Subtarget.is64Bit() + ? 
SelectADDRri64(Op1.getNode(), Op1, Base, Offset) + : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) { + if (Subtarget.is64Bit()) { + switch (N->getOpcode()) { default: return NULL; - case MVT::i8: - Opcode = NVPTX::INT_PTX_LDG_G_v2i8_ELE_32; - break; - case MVT::i16: - Opcode = NVPTX::INT_PTX_LDG_G_v2i16_ELE_32; - break; - case MVT::i32: - Opcode = NVPTX::INT_PTX_LDG_G_v2i32_ELE_32; + case NVPTXISD::LDGV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari64; + break; + } break; - case MVT::i64: - Opcode = NVPTX::INT_PTX_LDG_G_v2i64_ELE_32; + case NVPTXISD::LDUV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari64; + break; + } break; - case MVT::f32: - Opcode = NVPTX::INT_PTX_LDG_G_v2f32_ELE_32; + case NVPTXISD::LDGV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari64; + break; + } break; - case MVT::f64: - Opcode = NVPTX::INT_PTX_LDG_G_v2f64_ELE_32; + case NVPTXISD::LDUV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari64; + break; + } break; } - break; - case NVPTXISD::LDGV4: - switch (RetVT.getSimpleVT().SimpleTy) { + } else { + switch (N->getOpcode()) { default: return NULL; - case MVT::i8: - Opcode = NVPTX::INT_PTX_LDG_G_v4i8_ELE_32; + case NVPTXISD::LDGV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari32; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari32; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari32; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari32; + break; + } break; - case MVT::i16: - Opcode = NVPTX::INT_PTX_LDG_G_v4i16_ELE_32; + case NVPTXISD::LDUV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari32; + break; + case MVT::i32: + 
Opcode = NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari32; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari32; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari32; + break; + } break; - case MVT::i32: - Opcode = NVPTX::INT_PTX_LDG_G_v4i32_ELE_32; + case NVPTXISD::LDGV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari32; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari32; + break; + } break; - case MVT::f32: - Opcode = NVPTX::INT_PTX_LDG_G_v4f32_ELE_32; + case NVPTXISD::LDUV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari32; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari32; + break; + } break; } - break; - case NVPTXISD::LDUV2: - switch (RetVT.getSimpleVT().SimpleTy) { + } + + SDValue Ops[] = { Base, Offset, Chain }; + + LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), + ArrayRef<SDValue>(Ops, 3)); + } else { + if (Subtarget.is64Bit()) { + switch (N->getOpcode()) { default: return NULL; - case MVT::i8: - Opcode = NVPTX::INT_PTX_LDU_G_v2i8_ELE_32; - break; - case MVT::i16: - Opcode = NVPTX::INT_PTX_LDU_G_v2i16_ELE_32; - break; - case MVT::i32: - Opcode = NVPTX::INT_PTX_LDU_G_v2i32_ELE_32; + case NVPTXISD::LDGV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg64; + break; + } break; - case MVT::i64: - Opcode = NVPTX::INT_PTX_LDU_G_v2i64_ELE_32; + case NVPTXISD::LDUV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg64; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg64; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg64; + break; + } break; - case MVT::f32: - Opcode = NVPTX::INT_PTX_LDU_G_v2f32_ELE_32; + case NVPTXISD::LDGV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg64; + break; + } break; - case MVT::f64: - Opcode = NVPTX::INT_PTX_LDU_G_v2f64_ELE_32; + case NVPTXISD::LDUV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + 
Opcode = NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg64; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg64; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg64; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg64; + break; + } break; } - break; - case NVPTXISD::LDUV4: - switch (RetVT.getSimpleVT().SimpleTy) { + } else { + switch (N->getOpcode()) { default: return NULL; - case MVT::i8: - Opcode = NVPTX::INT_PTX_LDU_G_v4i8_ELE_32; + case NVPTXISD::LDGV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg32; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg32; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg32; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg32; + break; + } break; - case MVT::i16: - Opcode = NVPTX::INT_PTX_LDU_G_v4i16_ELE_32; + case NVPTXISD::LDUV2: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg32; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg32; + break; + case MVT::i64: + Opcode = NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg32; + break; + case MVT::f64: + Opcode = NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg32; + break; + } break; - case MVT::i32: - Opcode = NVPTX::INT_PTX_LDU_G_v4i32_ELE_32; + case NVPTXISD::LDGV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg32; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg32; + break; + } break; - case MVT::f32: - Opcode = NVPTX::INT_PTX_LDU_G_v4f32_ELE_32; + case NVPTXISD::LDUV4: + switch (EltVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i8: + Opcode = NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg32; + break; + case MVT::i16: + Opcode = NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg32; + break; + case MVT::i32: + Opcode = NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg32; + break; + case MVT::f32: + Opcode = NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg32; + break; + } break; } - break; } - } - SDValue Ops[] = { Op1, Chain }; - LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops); + SDValue Ops[] = { Op1, Chain }; + LD = CurDAG->getMachineNode(Opcode, DL, N->getVTList(), + ArrayRef<SDValue>(Ops, 2)); + } MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1); MemRefs0[0] = cast<MemSDNode>(N)->getMemOperand(); @@ -981,7 +1295,7 @@ SDNode *NVPTXDAGToDAGISel::SelectLDGLDUVector(SDNode *N) { } SDNode *NVPTXDAGToDAGISel::SelectStore(SDNode *N) { - DebugLoc dl = N->getDebugLoc(); + SDLoc dl(N); StoreSDNode *ST = cast<StoreSDNode>(N); EVT StoreVT = ST->getMemoryVT(); SDNode *NVPTXST = NULL; @@ -1035,8 +1349,7 @@ SDNode *NVPTXDAGToDAGISel::SelectStore(SDNode *N) { SDValue Addr; SDValue Offset, Base; unsigned Opcode; - MVT::SimpleValueType SourceVT = - N1.getNode()->getValueType(0).getSimpleVT().SimpleTy; + MVT::SimpleValueType SourceVT = N1.getNode()->getSimpleValueType(0).SimpleTy; if 
(SelectDirectAddr(N2, Addr)) { switch (SourceVT) { @@ -1216,7 +1529,7 @@ SDNode *NVPTXDAGToDAGISel::SelectStoreVector(SDNode *N) { SDValue Op1 = N->getOperand(1); SDValue Addr, Offset, Base; unsigned Opcode; - DebugLoc DL = N->getDebugLoc(); + SDLoc DL(N); SDNode *ST; EVT EltVT = Op1.getValueType(); MemSDNode *MemSD = cast<MemSDNode>(N); @@ -1587,6 +1900,414 @@ SDNode *NVPTXDAGToDAGISel::SelectStoreVector(SDNode *N) { return ST; } +SDNode *NVPTXDAGToDAGISel::SelectLoadParam(SDNode *Node) { + SDValue Chain = Node->getOperand(0); + SDValue Offset = Node->getOperand(2); + SDValue Flag = Node->getOperand(3); + SDLoc DL(Node); + MemSDNode *Mem = cast<MemSDNode>(Node); + + unsigned VecSize; + switch (Node->getOpcode()) { + default: + return NULL; + case NVPTXISD::LoadParam: + VecSize = 1; + break; + case NVPTXISD::LoadParamV2: + VecSize = 2; + break; + case NVPTXISD::LoadParamV4: + VecSize = 4; + break; + } + + EVT EltVT = Node->getValueType(0); + EVT MemVT = Mem->getMemoryVT(); + + unsigned Opc = 0; + + switch (VecSize) { + default: + return NULL; + case 1: + switch (MemVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opc = NVPTX::LoadParamMemI8; + break; + case MVT::i8: + Opc = NVPTX::LoadParamMemI8; + break; + case MVT::i16: + Opc = NVPTX::LoadParamMemI16; + break; + case MVT::i32: + Opc = NVPTX::LoadParamMemI32; + break; + case MVT::i64: + Opc = NVPTX::LoadParamMemI64; + break; + case MVT::f32: + Opc = NVPTX::LoadParamMemF32; + break; + case MVT::f64: + Opc = NVPTX::LoadParamMemF64; + break; + } + break; + case 2: + switch (MemVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opc = NVPTX::LoadParamMemV2I8; + break; + case MVT::i8: + Opc = NVPTX::LoadParamMemV2I8; + break; + case MVT::i16: + Opc = NVPTX::LoadParamMemV2I16; + break; + case MVT::i32: + Opc = NVPTX::LoadParamMemV2I32; + break; + case MVT::i64: + Opc = NVPTX::LoadParamMemV2I64; + break; + case MVT::f32: + Opc = NVPTX::LoadParamMemV2F32; + break; + case MVT::f64: + Opc = NVPTX::LoadParamMemV2F64; + break; + } + break; + case 4: + switch (MemVT.getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opc = NVPTX::LoadParamMemV4I8; + break; + case MVT::i8: + Opc = NVPTX::LoadParamMemV4I8; + break; + case MVT::i16: + Opc = NVPTX::LoadParamMemV4I16; + break; + case MVT::i32: + Opc = NVPTX::LoadParamMemV4I32; + break; + case MVT::f32: + Opc = NVPTX::LoadParamMemV4F32; + break; + } + break; + } + + SDVTList VTs; + if (VecSize == 1) { + VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue); + } else if (VecSize == 2) { + VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue); + } else { + EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue }; + VTs = CurDAG->getVTList(&EVTs[0], 5); + } + + unsigned OffsetVal = cast<ConstantSDNode>(Offset)->getZExtValue(); + + SmallVector<SDValue, 2> Ops; + Ops.push_back(CurDAG->getTargetConstant(OffsetVal, MVT::i32)); + Ops.push_back(Chain); + Ops.push_back(Flag); + + SDNode *Ret = + CurDAG->getMachineNode(Opc, DL, VTs, Ops); + return Ret; +} + +SDNode *NVPTXDAGToDAGISel::SelectStoreRetval(SDNode *N) { + SDLoc DL(N); + SDValue Chain = N->getOperand(0); + SDValue Offset = N->getOperand(1); + unsigned OffsetVal = cast<ConstantSDNode>(Offset)->getZExtValue(); + MemSDNode *Mem = cast<MemSDNode>(N); + + // How many elements do we have? 
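// Editorial note, not part of the patch: the StoreRetval* machine nodes
// built below take the operand list { val0, ..., valN-1, byte offset,
// chain } and print as st.param stores into the return value, for example
// (hypothetical virtual registers):
//   st.param.b32    [func_retval0+0], %r1;         // StoreRetvalI32
//   st.param.v2.b32 [func_retval0+0], {%r1, %r2};  // StoreRetvalV2I32
// An i1 result has already been widened by the lowering code, so it is
// stored with the 8-bit form (StoreRetvalI8).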
+ unsigned NumElts = 1; + switch (N->getOpcode()) { + default: + return NULL; + case NVPTXISD::StoreRetval: + NumElts = 1; + break; + case NVPTXISD::StoreRetvalV2: + NumElts = 2; + break; + case NVPTXISD::StoreRetvalV4: + NumElts = 4; + break; + } + + // Build vector of operands + SmallVector<SDValue, 6> Ops; + for (unsigned i = 0; i < NumElts; ++i) + Ops.push_back(N->getOperand(i + 2)); + Ops.push_back(CurDAG->getTargetConstant(OffsetVal, MVT::i32)); + Ops.push_back(Chain); + + // Determine target opcode + // If we have an i1, use an 8-bit store. The lowering code in + // NVPTXISelLowering will have already emitted an upcast. + unsigned Opcode = 0; + switch (NumElts) { + default: + return NULL; + case 1: + switch (Mem->getMemoryVT().getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opcode = NVPTX::StoreRetvalI8; + break; + case MVT::i8: + Opcode = NVPTX::StoreRetvalI8; + break; + case MVT::i16: + Opcode = NVPTX::StoreRetvalI16; + break; + case MVT::i32: + Opcode = NVPTX::StoreRetvalI32; + break; + case MVT::i64: + Opcode = NVPTX::StoreRetvalI64; + break; + case MVT::f32: + Opcode = NVPTX::StoreRetvalF32; + break; + case MVT::f64: + Opcode = NVPTX::StoreRetvalF64; + break; + } + break; + case 2: + switch (Mem->getMemoryVT().getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opcode = NVPTX::StoreRetvalV2I8; + break; + case MVT::i8: + Opcode = NVPTX::StoreRetvalV2I8; + break; + case MVT::i16: + Opcode = NVPTX::StoreRetvalV2I16; + break; + case MVT::i32: + Opcode = NVPTX::StoreRetvalV2I32; + break; + case MVT::i64: + Opcode = NVPTX::StoreRetvalV2I64; + break; + case MVT::f32: + Opcode = NVPTX::StoreRetvalV2F32; + break; + case MVT::f64: + Opcode = NVPTX::StoreRetvalV2F64; + break; + } + break; + case 4: + switch (Mem->getMemoryVT().getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opcode = NVPTX::StoreRetvalV4I8; + break; + case MVT::i8: + Opcode = NVPTX::StoreRetvalV4I8; + break; + case MVT::i16: + Opcode = NVPTX::StoreRetvalV4I16; + break; + case MVT::i32: + Opcode = NVPTX::StoreRetvalV4I32; + break; + case MVT::f32: + Opcode = NVPTX::StoreRetvalV4F32; + break; + } + break; + } + + SDNode *Ret = + CurDAG->getMachineNode(Opcode, DL, MVT::Other, Ops); + MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1); + MemRefs0[0] = cast<MemSDNode>(N)->getMemOperand(); + cast<MachineSDNode>(Ret)->setMemRefs(MemRefs0, MemRefs0 + 1); + + return Ret; +} + +SDNode *NVPTXDAGToDAGISel::SelectStoreParam(SDNode *N) { + SDLoc DL(N); + SDValue Chain = N->getOperand(0); + SDValue Param = N->getOperand(1); + unsigned ParamVal = cast<ConstantSDNode>(Param)->getZExtValue(); + SDValue Offset = N->getOperand(2); + unsigned OffsetVal = cast<ConstantSDNode>(Offset)->getZExtValue(); + MemSDNode *Mem = cast<MemSDNode>(N); + SDValue Flag = N->getOperand(N->getNumOperands() - 1); + + // How many elements do we have? 
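// Editorial note, not part of the patch: the StoreParam* nodes selected
// below take { val0, ..., valN-1, param index, byte offset, chain, glue }
// and print as st.param stores into the argument's .param variable, e.g.
//   st.param.b32 [param0+0], %r1;
// For StoreParamS32/StoreParamU32 (an i16 value feeding a 32-bit
// parameter), selection first emits an explicit widening conversion:
//   cvt.u32.u16  %r1, %rs2;
//   st.param.b32 [param0+0], %r1;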
+ unsigned NumElts = 1; + switch (N->getOpcode()) { + default: + return NULL; + case NVPTXISD::StoreParamU32: + case NVPTXISD::StoreParamS32: + case NVPTXISD::StoreParam: + NumElts = 1; + break; + case NVPTXISD::StoreParamV2: + NumElts = 2; + break; + case NVPTXISD::StoreParamV4: + NumElts = 4; + break; + } + + // Build vector of operands + SmallVector<SDValue, 8> Ops; + for (unsigned i = 0; i < NumElts; ++i) + Ops.push_back(N->getOperand(i + 3)); + Ops.push_back(CurDAG->getTargetConstant(ParamVal, MVT::i32)); + Ops.push_back(CurDAG->getTargetConstant(OffsetVal, MVT::i32)); + Ops.push_back(Chain); + Ops.push_back(Flag); + + // Determine target opcode + // If we have an i1, use an 8-bit store. The lowering code in + // NVPTXISelLowering will have already emitted an upcast. + unsigned Opcode = 0; + switch (N->getOpcode()) { + default: + switch (NumElts) { + default: + return NULL; + case 1: + switch (Mem->getMemoryVT().getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opcode = NVPTX::StoreParamI8; + break; + case MVT::i8: + Opcode = NVPTX::StoreParamI8; + break; + case MVT::i16: + Opcode = NVPTX::StoreParamI16; + break; + case MVT::i32: + Opcode = NVPTX::StoreParamI32; + break; + case MVT::i64: + Opcode = NVPTX::StoreParamI64; + break; + case MVT::f32: + Opcode = NVPTX::StoreParamF32; + break; + case MVT::f64: + Opcode = NVPTX::StoreParamF64; + break; + } + break; + case 2: + switch (Mem->getMemoryVT().getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opcode = NVPTX::StoreParamV2I8; + break; + case MVT::i8: + Opcode = NVPTX::StoreParamV2I8; + break; + case MVT::i16: + Opcode = NVPTX::StoreParamV2I16; + break; + case MVT::i32: + Opcode = NVPTX::StoreParamV2I32; + break; + case MVT::i64: + Opcode = NVPTX::StoreParamV2I64; + break; + case MVT::f32: + Opcode = NVPTX::StoreParamV2F32; + break; + case MVT::f64: + Opcode = NVPTX::StoreParamV2F64; + break; + } + break; + case 4: + switch (Mem->getMemoryVT().getSimpleVT().SimpleTy) { + default: + return NULL; + case MVT::i1: + Opcode = NVPTX::StoreParamV4I8; + break; + case MVT::i8: + Opcode = NVPTX::StoreParamV4I8; + break; + case MVT::i16: + Opcode = NVPTX::StoreParamV4I16; + break; + case MVT::i32: + Opcode = NVPTX::StoreParamV4I32; + break; + case MVT::f32: + Opcode = NVPTX::StoreParamV4F32; + break; + } + break; + } + break; + // Special case: if we have a sign-extend/zero-extend node, insert the + // conversion instruction first, and use that as the value operand to + // the selected StoreParam node. + case NVPTXISD::StoreParamU32: { + Opcode = NVPTX::StoreParamI32; + SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, + MVT::i32); + SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL, + MVT::i32, Ops[0], CvtNone); + Ops[0] = SDValue(Cvt, 0); + break; + } + case NVPTXISD::StoreParamS32: { + Opcode = NVPTX::StoreParamI32; + SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, + MVT::i32); + SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL, + MVT::i32, Ops[0], CvtNone); + Ops[0] = SDValue(Cvt, 0); + break; + } + } + + SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue); + SDNode *Ret = + CurDAG->getMachineNode(Opcode, DL, RetVTs, Ops); + MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1); + MemRefs0[0] = cast<MemSDNode>(N)->getMemOperand(); + cast<MachineSDNode>(Ret)->setMemRefs(MemRefs0, MemRefs0 + 1); + + return Ret; +} + // SelectDirectAddr - Match a direct address for DAG. 
// A direct address could be a globaladdress or externalsymbol. bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index ed16d44..d961e50 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -28,38 +28,22 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel { // If true, generate corresponding FPCONTRACT. This is // language dependent (i.e. CUDA and OpenCL works differently). - bool doFMADF32; bool doFMAF64; bool doFMAF32; bool doFMAF64AGG; bool doFMAF32AGG; bool allowFMA; - // 0: use div.approx - // 1: use div.full - // 2: For sm_20 and later, ieee-compliant div.rnd.f32 can be generated; - // Otherwise, use div.full - int do_DIVF32_PREC; - - // If true, generate sqrt.rn, else generate sqrt.approx. If FTZ - // is true, then generate the corresponding FTZ version. - bool do_SQRTF32_PREC; - - // If true, add .ftz to f32 instructions. - // This is only meaningful for sm_20 and later, as the default - // is not ftz. - // For sm earlier than sm_20, f32 denorms are always ftz by the - // hardware. - // We always add the .ftz modifier regardless of the sm value - // when Use32FTZ is true. - bool UseF32FTZ; - // If true, generate mul.wide from sext and mul bool doMulWide; + int getDivF32Level() const; + bool usePrecSqrtF32() const; + bool useF32FTZ() const; + public: explicit NVPTXDAGToDAGISel(NVPTXTargetMachine &tm, - CodeGenOpt::Level OptLevel); + CodeGenOpt::Level OptLevel); // Pass Name virtual const char *getPassName() const { @@ -80,7 +64,10 @@ private: SDNode *SelectLDGLDUVector(SDNode *N); SDNode *SelectStore(SDNode *N); SDNode *SelectStoreVector(SDNode *N); - + SDNode *SelectLoadParam(SDNode *N); + SDNode *SelectStoreRetval(SDNode *N); + SDNode *SelectStoreParam(SDNode *N); + inline SDValue getI32Imm(unsigned Imm) { return CurDAG->getTargetConstant(Imm, MVT::i32); } diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 6e01a5a..6a8be75 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -51,6 +51,8 @@ static bool IsPTXVectorType(MVT VT) { switch (VT.SimpleTy) { default: return false; + case MVT::v2i1: + case MVT::v4i1: case MVT::v2i8: case MVT::v4i8: case MVT::v2i16: @@ -65,6 +67,37 @@ static bool IsPTXVectorType(MVT VT) { } } +/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive +/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors +/// into their primitive components. +/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the +/// same number of types as the Ins/Outs arrays in LowerFormalArguments, +/// LowerCall, and LowerReturn. 
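// For example (editorial illustration, not part of the patch): for
// Ty = <4 x i32> this produces ValueVTs = { i32, i32, i32, i32 } with
// Offsets = { 0, 4, 8, 12 }, whereas ComputeValueVTs would return the
// single entry { v4i32 } and no longer match the per-scalar Ins/Outs
// arrays index-for-index.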
+static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty, + SmallVectorImpl<EVT> &ValueVTs, + SmallVectorImpl<uint64_t> *Offsets = 0, + uint64_t StartingOffset = 0) { + SmallVector<EVT, 16> TempVTs; + SmallVector<uint64_t, 16> TempOffsets; + + ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset); + for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) { + EVT VT = TempVTs[i]; + uint64_t Off = TempOffsets[i]; + if (VT.isVector()) + for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) { + ValueVTs.push_back(VT.getVectorElementType()); + if (Offsets) + Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize()); + } + else { + ValueVTs.push_back(VT); + if (Offsets) + Offsets->push_back(Off); + } + } +} + // NVPTXTargetLowering Constructor. NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM) : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM), @@ -90,7 +123,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM) setSchedulingPreference(Sched::Source); addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass); - addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass); addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass); addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass); addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass); @@ -106,10 +138,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM) setOperationAction(ISD::BR_CC, MVT::i16, Expand); setOperationAction(ISD::BR_CC, MVT::i32, Expand); setOperationAction(ISD::BR_CC, MVT::i64, Expand); - setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand); - setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand); - setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand); - setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand); + // Some SIGN_EXTEND_INREG can be done using cvt instruction. + // For others we will expand to a SHL/SRA pair. 
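// (Editorial illustration, not part of the patch: with the Legal setting
// below, sign-extending the low 8 bits of a 32-bit register selects to a
// single conversion such as "cvt.s32.s8 %r1, %r1;" rather than the shl/sra
// pair that the Expand action used to produce.)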
+ setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal); setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand); if (nvptxSubtarget.hasROT64()) { @@ -170,6 +204,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM) // TRAP can be lowered to PTX trap setOperationAction(ISD::TRAP, MVT::Other, Legal); + setOperationAction(ISD::ADDC, MVT::i64, Expand); + setOperationAction(ISD::ADDE, MVT::i64, Expand); + // Register custom handling for vector loads/stores for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE; ++i) { @@ -181,6 +218,25 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM) } } + // Custom handling for i8 intrinsics + setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom); + + setOperationAction(ISD::CTLZ, MVT::i16, Legal); + setOperationAction(ISD::CTLZ, MVT::i32, Legal); + setOperationAction(ISD::CTLZ, MVT::i64, Legal); + setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal); + setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal); + setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal); + setOperationAction(ISD::CTTZ, MVT::i16, Expand); + setOperationAction(ISD::CTTZ, MVT::i32, Expand); + setOperationAction(ISD::CTTZ, MVT::i64, Expand); + setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand); + setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand); + setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand); + setOperationAction(ISD::CTPOP, MVT::i16, Legal); + setOperationAction(ISD::CTPOP, MVT::i32, Legal); + setOperationAction(ISD::CTPOP, MVT::i64, Legal); + // Now deduce the information based on the above mentioned // actions computeRegisterProperties(); @@ -196,8 +252,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::RET_FLAG"; case NVPTXISD::Wrapper: return "NVPTXISD::Wrapper"; - case NVPTXISD::NVBuiltin: - return "NVPTXISD::NVBuiltin"; case NVPTXISD::DeclareParam: return "NVPTXISD::DeclareParam"; case NVPTXISD::DeclareScalarParam: @@ -210,14 +264,20 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::PrintCall"; case NVPTXISD::LoadParam: return "NVPTXISD::LoadParam"; + case NVPTXISD::LoadParamV2: + return "NVPTXISD::LoadParamV2"; + case NVPTXISD::LoadParamV4: + return "NVPTXISD::LoadParamV4"; case NVPTXISD::StoreParam: return "NVPTXISD::StoreParam"; + case NVPTXISD::StoreParamV2: + return "NVPTXISD::StoreParamV2"; + case NVPTXISD::StoreParamV4: + return "NVPTXISD::StoreParamV4"; case NVPTXISD::StoreParamS32: return "NVPTXISD::StoreParamS32"; case NVPTXISD::StoreParamU32: return "NVPTXISD::StoreParamU32"; - case NVPTXISD::MoveToParam: - return "NVPTXISD::MoveToParam"; case NVPTXISD::CallArgBegin: return "NVPTXISD::CallArgBegin"; case NVPTXISD::CallArg: @@ -236,12 +296,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::Prototype"; case NVPTXISD::MoveParam: return "NVPTXISD::MoveParam"; - case NVPTXISD::MoveRetval: - return "NVPTXISD::MoveRetval"; - case NVPTXISD::MoveToRetval: - return "NVPTXISD::MoveToRetval"; case NVPTXISD::StoreRetval: return "NVPTXISD::StoreRetval"; + case NVPTXISD::StoreRetvalV2: + return "NVPTXISD::StoreRetvalV2"; + case NVPTXISD::StoreRetvalV4: + return "NVPTXISD::StoreRetvalV4"; case NVPTXISD::PseudoUseParam: return 
"NVPTXISD::PseudoUseParam"; case NVPTXISD::RETURN: @@ -250,6 +310,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::CallSeqBegin"; case NVPTXISD::CallSeqEnd: return "NVPTXISD::CallSeqEnd"; + case NVPTXISD::CallPrototype: + return "NVPTXISD::CallPrototype"; case NVPTXISD::LoadV2: return "NVPTXISD::LoadV2"; case NVPTXISD::LoadV4: @@ -275,89 +337,68 @@ bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const { SDValue NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { - DebugLoc dl = Op.getDebugLoc(); + SDLoc dl(Op); const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal(); Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op); } -std::string NVPTXTargetLowering::getPrototype( - Type *retTy, const ArgListTy &Args, - const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const { +std::string +NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args, + const SmallVectorImpl<ISD::OutputArg> &Outs, + unsigned retAlignment, + const ImmutableCallSite *CS) const { bool isABI = (nvptxSubtarget.getSmVersion() >= 20); + assert(isABI && "Non-ABI compilation is not supported"); + if (!isABI) + return ""; std::stringstream O; O << "prototype_" << uniqueCallSite << " : .callprototype "; - if (retTy->getTypeID() == Type::VoidTyID) + if (retTy->getTypeID() == Type::VoidTyID) { O << "()"; - else { + } else { O << "("; - if (isABI) { - if (retTy->isPrimitiveType() || retTy->isIntegerTy()) { - unsigned size = 0; - if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) { - size = ITy->getBitWidth(); - if (size < 32) - size = 32; - } else { - assert(retTy->isFloatingPointTy() && - "Floating point type expected here"); - size = retTy->getPrimitiveSizeInBits(); - } - - O << ".param .b" << size << " _"; - } else if (isa<PointerType>(retTy)) - O << ".param .b" << getPointerTy().getSizeInBits() << " _"; - else { - if ((retTy->getTypeID() == Type::StructTyID) || - isa<VectorType>(retTy)) { - SmallVector<EVT, 16> vtparts; - ComputeValueVTs(*this, retTy, vtparts); - unsigned totalsz = 0; - for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { - unsigned elems = 1; - EVT elemtype = vtparts[i]; - if (vtparts[i].isVector()) { - elems = vtparts[i].getVectorNumElements(); - elemtype = vtparts[i].getVectorElementType(); - } - for (unsigned j = 0, je = elems; j != je; ++j) { - unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 8)) - sz = 8; - totalsz += sz / 8; - } - } - O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]"; - } else { - assert(false && "Unknown return type"); - } + if (retTy->isPrimitiveType() || retTy->isIntegerTy()) { + unsigned size = 0; + if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) { + size = ITy->getBitWidth(); + if (size < 32) + size = 32; + } else { + assert(retTy->isFloatingPointTy() && + "Floating point type expected here"); + size = retTy->getPrimitiveSizeInBits(); } - } else { - SmallVector<EVT, 16> vtparts; - ComputeValueVTs(*this, retTy, vtparts); - unsigned idx = 0; - for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { - unsigned elems = 1; - EVT elemtype = vtparts[i]; - if (vtparts[i].isVector()) { - elems = vtparts[i].getVectorNumElements(); - elemtype = vtparts[i].getVectorElementType(); - } - for (unsigned j = 0, je = elems; j != je; ++j) { - unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 32)) - sz = 32; - O << 
".reg .b" << sz << " _"; - if (j < je - 1) - O << ", "; - ++idx; + O << ".param .b" << size << " _"; + } else if (isa<PointerType>(retTy)) { + O << ".param .b" << getPointerTy().getSizeInBits() << " _"; + } else { + if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) { + SmallVector<EVT, 16> vtparts; + ComputeValueVTs(*this, retTy, vtparts); + unsigned totalsz = 0; + for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { + unsigned elems = 1; + EVT elemtype = vtparts[i]; + if (vtparts[i].isVector()) { + elems = vtparts[i].getVectorNumElements(); + elemtype = vtparts[i].getVectorElementType(); + } + // TODO: no need to loop + for (unsigned j = 0, je = elems; j != je; ++j) { + unsigned sz = elemtype.getSizeInBits(); + if (elemtype.isInteger() && (sz < 8)) + sz = 8; + totalsz += sz / 8; + } } - if (i < e - 1) - O << ", "; + O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]"; + } else { + assert(false && "Unknown return type"); } } O << ") "; @@ -367,14 +408,38 @@ std::string NVPTXTargetLowering::getPrototype( bool first = true; MVT thePointerTy = getPointerTy(); - for (unsigned i = 0, e = Args.size(); i != e; ++i) { - const Type *Ty = Args[i].Ty; + unsigned OIdx = 0; + for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) { + Type *Ty = Args[i].Ty; if (!first) { O << ", "; } first = false; - if (Outs[i].Flags.isByVal() == false) { + if (Outs[OIdx].Flags.isByVal() == false) { + if (Ty->isAggregateType() || Ty->isVectorTy()) { + unsigned align = 0; + const CallInst *CallI = cast<CallInst>(CS->getInstruction()); + const DataLayout *TD = getDataLayout(); + // +1 because index 0 is reserved for return type alignment + if (!llvm::getAlign(*CallI, i + 1, align)) + align = TD->getABITypeAlignment(Ty); + unsigned sz = TD->getTypeAllocSize(Ty); + O << ".param .align " << align << " .b8 "; + O << "_"; + O << "[" << sz << "]"; + // update the index for Outs + SmallVector<EVT, 16> vtparts; + ComputeValueVTs(*this, Ty, vtparts); + if (unsigned len = vtparts.size()) + OIdx += len - 1; + continue; + } + // i8 types in IR will be i16 types in SDAG + assert((getValueType(Ty) == Outs[OIdx].VT || + (getValueType(Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) && + "type mismatch between callee prototype and arguments"); + // scalar type unsigned sz = 0; if (isa<IntegerType>(Ty)) { sz = cast<IntegerType>(Ty)->getBitWidth(); @@ -384,10 +449,7 @@ std::string NVPTXTargetLowering::getPrototype( sz = thePointerTy.getSizeInBits(); else sz = Ty->getPrimitiveSizeInBits(); - if (isABI) - O << ".param .b" << sz << " "; - else - O << ".reg .b" << sz << " "; + O << ".param .b" << sz << " "; O << "_"; continue; } @@ -395,50 +457,72 @@ std::string NVPTXTargetLowering::getPrototype( assert(PTy && "Param with byval attribute should be a pointer type"); Type *ETy = PTy->getElementType(); - if (isABI) { - unsigned align = Outs[i].Flags.getByValAlign(); - unsigned sz = getDataLayout()->getTypeAllocSize(ETy); - O << ".param .align " << align << " .b8 "; - O << "_"; - O << "[" << sz << "]"; - continue; - } else { - SmallVector<EVT, 16> vtparts; - ComputeValueVTs(*this, ETy, vtparts); - for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { - unsigned elems = 1; - EVT elemtype = vtparts[i]; - if (vtparts[i].isVector()) { - elems = vtparts[i].getVectorNumElements(); - elemtype = vtparts[i].getVectorElementType(); - } + unsigned align = Outs[OIdx].Flags.getByValAlign(); + unsigned sz = getDataLayout()->getTypeAllocSize(ETy); + O << ".param .align " << align << " .b8 "; + O << "_"; + O << "[" << 
sz << "]"; + } + O << ");"; + return O.str(); +} - for (unsigned j = 0, je = elems; j != je; ++j) { - unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 32)) - sz = 32; - O << ".reg .b" << sz << " "; - O << "_"; - if (j < je - 1) - O << ", "; - } - if (i < e - 1) - O << ", "; +unsigned +NVPTXTargetLowering::getArgumentAlignment(SDValue Callee, + const ImmutableCallSite *CS, + Type *Ty, + unsigned Idx) const { + const DataLayout *TD = getDataLayout(); + unsigned Align = 0; + const Value *DirectCallee = CS->getCalledFunction(); + + if (!DirectCallee) { + // We don't have a direct function symbol, but that may be because of + // constant cast instructions in the call. + const Instruction *CalleeI = CS->getInstruction(); + assert(CalleeI && "Call target is not a function or derived value?"); + + // With bitcast'd call targets, the instruction will be the call + if (isa<CallInst>(CalleeI)) { + // Check if we have call alignment metadata + if (llvm::getAlign(*cast<CallInst>(CalleeI), Idx, Align)) + return Align; + + const Value *CalleeV = cast<CallInst>(CalleeI)->getCalledValue(); + // Ignore any bitcast instructions + while(isa<ConstantExpr>(CalleeV)) { + const ConstantExpr *CE = cast<ConstantExpr>(CalleeV); + if (!CE->isCast()) + break; + // Look through the bitcast + CalleeV = cast<ConstantExpr>(CalleeV)->getOperand(0); } - continue; + + // We have now looked past all of the bitcasts. Do we finally have a + // Function? + if (isa<Function>(CalleeV)) + DirectCallee = CalleeV; } } - O << ");"; - return O.str(); + + // Check for function alignment information if we found that the + // ultimate target is a Function + if (DirectCallee) + if (llvm::getAlign(*cast<Function>(DirectCallee), Idx, Align)) + return Align; + + // Call is indirect or alignment information is not available, fall back to + // the ABI type alignment + return TD->getABITypeAlignment(Ty); } SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SmallVectorImpl<SDValue> &InVals) const { SelectionDAG &DAG = CLI.DAG; - DebugLoc &dl = CLI.DL; - SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs; - SmallVector<SDValue, 32> &OutVals = CLI.OutVals; - SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins; + SDLoc dl = CLI.DL; + SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs; + SmallVectorImpl<SDValue> &OutVals = CLI.OutVals; + SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; SDValue Chain = CLI.Chain; SDValue Callee = CLI.Callee; bool &isTailCall = CLI.IsTailCall; @@ -447,53 +531,258 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, ImmutableCallSite *CS = CLI.CS; bool isABI = (nvptxSubtarget.getSmVersion() >= 20); + assert(isABI && "Non-ABI compilation is not supported"); + if (!isABI) + return Chain; + const DataLayout *TD = getDataLayout(); + MachineFunction &MF = DAG.getMachineFunction(); + const Function *F = MF.getFunction(); SDValue tempChain = Chain; Chain = - DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true)); + DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true), + dl); SDValue InFlag = Chain.getValue(1); - assert((Outs.size() == Args.size()) && - "Unexpected number of arguments to function call"); unsigned paramCount = 0; + // Args.size() and Outs.size() need not match. 
+ // Outs.size() will be larger + // * if there is an aggregate argument with multiple fields (each field + // showing up separately in Outs) + // * if there is a vector argument with more than typical vector-length + // elements (generally if more than 4) where each vector element is + // individually present in Outs. + // So a different index should be used for indexing into Outs/OutVals. + // See similar issue in LowerFormalArguments. + unsigned OIdx = 0; // Declare the .params or .reg need to pass values // to the function - for (unsigned i = 0, e = Outs.size(); i != e; ++i) { - EVT VT = Outs[i].VT; + for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) { + EVT VT = Outs[OIdx].VT; + Type *Ty = Args[i].Ty; + + if (Outs[OIdx].Flags.isByVal() == false) { + if (Ty->isAggregateType()) { + // aggregate + SmallVector<EVT, 16> vtparts; + ComputeValueVTs(*this, Ty, vtparts); + + unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1); + // declare .param .align <align> .b8 .param<n>[<size>]; + unsigned sz = TD->getTypeAllocSize(Ty); + SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32), + DAG.getConstant(paramCount, MVT::i32), + DAG.getConstant(sz, MVT::i32), InFlag }; + Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, + DeclareParamOps, 5); + InFlag = Chain.getValue(1); + unsigned curOffset = 0; + for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { + unsigned elems = 1; + EVT elemtype = vtparts[j]; + if (vtparts[j].isVector()) { + elems = vtparts[j].getVectorNumElements(); + elemtype = vtparts[j].getVectorElementType(); + } + for (unsigned k = 0, ke = elems; k != ke; ++k) { + unsigned sz = elemtype.getSizeInBits(); + if (elemtype.isInteger() && (sz < 8)) + sz = 8; + SDValue StVal = OutVals[OIdx]; + if (elemtype.getSizeInBits() < 16) { + StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); + } + SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue CopyParamOps[] = { Chain, + DAG.getConstant(paramCount, MVT::i32), + DAG.getConstant(curOffset, MVT::i32), + StVal, InFlag }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, + CopyParamVTs, &CopyParamOps[0], 5, + elemtype, MachinePointerInfo()); + InFlag = Chain.getValue(1); + curOffset += sz / 8; + ++OIdx; + } + } + if (vtparts.size() > 0) + --OIdx; + ++paramCount; + continue; + } + if (Ty->isVectorTy()) { + EVT ObjectVT = getValueType(Ty); + unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1); + // declare .param .align <align> .b8 .param<n>[<size>]; + unsigned sz = TD->getTypeAllocSize(Ty); + SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32), + DAG.getConstant(paramCount, MVT::i32), + DAG.getConstant(sz, MVT::i32), InFlag }; + Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, + DeclareParamOps, 5); + InFlag = Chain.getValue(1); + unsigned NumElts = ObjectVT.getVectorNumElements(); + EVT EltVT = ObjectVT.getVectorElementType(); + EVT MemVT = EltVT; + bool NeedExtend = false; + if (EltVT.getSizeInBits() < 16) { + NeedExtend = true; + EltVT = MVT::i16; + } + + // V1 store + if (NumElts == 1) { + SDValue Elt = OutVals[OIdx++]; + if (NeedExtend) + Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt); + + SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue CopyParamOps[] = { Chain, + DAG.getConstant(paramCount, MVT::i32), + DAG.getConstant(0, MVT::i32), 
Elt, + InFlag }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, + CopyParamVTs, &CopyParamOps[0], 5, + MemVT, MachinePointerInfo()); + InFlag = Chain.getValue(1); + } else if (NumElts == 2) { + SDValue Elt0 = OutVals[OIdx++]; + SDValue Elt1 = OutVals[OIdx++]; + if (NeedExtend) { + Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0); + Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1); + } + + SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue CopyParamOps[] = { Chain, + DAG.getConstant(paramCount, MVT::i32), + DAG.getConstant(0, MVT::i32), Elt0, Elt1, + InFlag }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl, + CopyParamVTs, &CopyParamOps[0], 6, + MemVT, MachinePointerInfo()); + InFlag = Chain.getValue(1); + } else { + unsigned curOffset = 0; + // V4 stores + // We have at least 4 elements (<3 x Ty> expands to 4 elements) and + // the + // vector will be expanded to a power of 2 elements, so we know we can + // always round up to the next multiple of 4 when creating the vector + // stores. + // e.g. 4 elem => 1 st.v4 + // 6 elem => 2 st.v4 + // 8 elem => 2 st.v4 + // 11 elem => 3 st.v4 + unsigned VecSize = 4; + if (EltVT.getSizeInBits() == 64) + VecSize = 2; + + // This is potentially only part of a vector, so assume all elements + // are packed together. + unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize; + + for (unsigned i = 0; i < NumElts; i += VecSize) { + // Get values + SDValue StoreVal; + SmallVector<SDValue, 8> Ops; + Ops.push_back(Chain); + Ops.push_back(DAG.getConstant(paramCount, MVT::i32)); + Ops.push_back(DAG.getConstant(curOffset, MVT::i32)); + + unsigned Opc = NVPTXISD::StoreParamV2; + + StoreVal = OutVals[OIdx++]; + if (NeedExtend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); + Ops.push_back(StoreVal); + + if (i + 1 < NumElts) { + StoreVal = OutVals[OIdx++]; + if (NeedExtend) + StoreVal = + DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); + } else { + StoreVal = DAG.getUNDEF(EltVT); + } + Ops.push_back(StoreVal); + + if (VecSize == 4) { + Opc = NVPTXISD::StoreParamV4; + if (i + 2 < NumElts) { + StoreVal = OutVals[OIdx++]; + if (NeedExtend) + StoreVal = + DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); + } else { + StoreVal = DAG.getUNDEF(EltVT); + } + Ops.push_back(StoreVal); + + if (i + 3 < NumElts) { + StoreVal = OutVals[OIdx++]; + if (NeedExtend) + StoreVal = + DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); + } else { + StoreVal = DAG.getUNDEF(EltVT); + } + Ops.push_back(StoreVal); + } - if (Outs[i].Flags.isByVal() == false) { + Ops.push_back(InFlag); + + SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0], + Ops.size(), MemVT, + MachinePointerInfo()); + InFlag = Chain.getValue(1); + curOffset += PerStoreOffset; + } + } + ++paramCount; + --OIdx; + continue; + } // Plain scalar // for ABI, declare .param .b<size> .param<n>; - // for nonABI, declare .reg .b<size> .param<n>; - unsigned isReg = 1; - if (isABI) - isReg = 0; unsigned sz = VT.getSizeInBits(); - if (VT.isInteger() && (sz < 32)) - sz = 32; + bool needExtend = false; + if (VT.isInteger()) { + if (sz < 16) + needExtend = true; + if (sz < 32) + sz = 32; + } SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue DeclareParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32), - DAG.getConstant(isReg, MVT::i32), InFlag }; + DAG.getConstant(0, MVT::i32), InFlag }; Chain = 
DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, DeclareParamOps, 5); InFlag = Chain.getValue(1); + SDValue OutV = OutVals[OIdx]; + if (needExtend) { + // zext/sext i1 to i16 + unsigned opc = ISD::ZERO_EXTEND; + if (Outs[OIdx].Flags.isSExt()) + opc = ISD::SIGN_EXTEND; + OutV = DAG.getNode(opc, dl, MVT::i16, OutV); + } SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32), - DAG.getConstant(0, MVT::i32), OutVals[i], - InFlag }; + DAG.getConstant(0, MVT::i32), OutV, InFlag }; unsigned opcode = NVPTXISD::StoreParam; - if (isReg) - opcode = NVPTXISD::MoveToParam; - else { - if (Outs[i].Flags.isZExt()) - opcode = NVPTXISD::StoreParamU32; - else if (Outs[i].Flags.isSExt()) - opcode = NVPTXISD::StoreParamS32; - } - Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5); + if (Outs[OIdx].Flags.isZExt()) + opcode = NVPTXISD::StoreParamU32; + else if (Outs[OIdx].Flags.isSExt()) + opcode = NVPTXISD::StoreParamS32; + Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5, + VT, MachinePointerInfo()); InFlag = Chain.getValue(1); ++paramCount; @@ -505,55 +794,20 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, assert(PTy && "Type of a byval parameter should be pointer"); ComputeValueVTs(*this, PTy->getElementType(), vtparts); - if (isABI) { - // declare .param .align 16 .b8 .param<n>[<size>]; - unsigned sz = Outs[i].Flags.getByValSize(); - SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - // The ByValAlign in the Outs[i].Flags is always set at this point, so we - // don't need to - // worry about natural alignment or not. See TargetLowering::LowerCallTo() - SDValue DeclareParamOps[] = { - Chain, DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32), - DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32), - InFlag - }; - Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, - DeclareParamOps, 5); - InFlag = Chain.getValue(1); - unsigned curOffset = 0; - for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { - unsigned elems = 1; - EVT elemtype = vtparts[j]; - if (vtparts[j].isVector()) { - elems = vtparts[j].getVectorNumElements(); - elemtype = vtparts[j].getVectorElementType(); - } - for (unsigned k = 0, ke = elems; k != ke; ++k) { - unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 8)) - sz = 8; - SDValue srcAddr = - DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i], - DAG.getConstant(curOffset, getPointerTy())); - SDValue theVal = - DAG.getLoad(elemtype, dl, tempChain, srcAddr, - MachinePointerInfo(), false, false, false, 0); - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue CopyParamOps[] = { Chain, - DAG.getConstant(paramCount, MVT::i32), - DAG.getConstant(curOffset, MVT::i32), - theVal, InFlag }; - Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs, - CopyParamOps, 5); - InFlag = Chain.getValue(1); - curOffset += sz / 8; - } - } - ++paramCount; - continue; - } - // Non-abi, struct or vector - // Declare a bunch of .reg .b<size> .param<n> + // declare .param .align <align> .b8 .param<n>[<size>]; + unsigned sz = Outs[OIdx].Flags.getByValSize(); + SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + // The ByValAlign in the Outs[OIdx].Flags is always set at this point, + // so we don't need to worry about natural alignment or not. + // See TargetLowering::LowerCallTo().
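// Editorial illustration, not part of the patch: a 12-byte byval struct
// with 4-byte alignment becomes a sized byte-array parameter plus
// element-wise st.param copies, e.g. (hypothetical registers):
//   .param .align 4 .b8 param1[12];
//   st.param.b32 [param1+0], %r1;
//   st.param.b32 [param1+4], %r2;
//   st.param.b32 [param1+8], %r3;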
+ SDValue DeclareParamOps[] = { + Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32), + DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32), + InFlag + }; + Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, + DeclareParamOps, 5); + InFlag = Chain.getValue(1); unsigned curOffset = 0; for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { unsigned elems = 1; @@ -564,107 +818,66 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } for (unsigned k = 0, ke = elems; k != ke; ++k) { unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 32)) - sz = 32; - SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue DeclareParamOps[] = { Chain, - DAG.getConstant(paramCount, MVT::i32), - DAG.getConstant(sz, MVT::i32), - DAG.getConstant(1, MVT::i32), InFlag }; - Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, - DeclareParamOps, 5); - InFlag = Chain.getValue(1); + if (elemtype.isInteger() && (sz < 8)) + sz = 8; SDValue srcAddr = - DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i], + DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx], DAG.getConstant(curOffset, getPointerTy())); - SDValue theVal = - DAG.getLoad(elemtype, dl, tempChain, srcAddr, MachinePointerInfo(), - false, false, false, 0); + SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr, + MachinePointerInfo(), false, false, false, + 0); + if (elemtype.getSizeInBits() < 16) { + theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal); + } SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32), - DAG.getConstant(0, MVT::i32), theVal, + DAG.getConstant(curOffset, MVT::i32), theVal, InFlag }; - Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs, - CopyParamOps, 5); + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs, + CopyParamOps, 5, elemtype, + MachinePointerInfo()); + InFlag = Chain.getValue(1); - ++paramCount; + curOffset += sz / 8; } } + ++paramCount; } GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode()); unsigned retAlignment = 0; // Handle Result - unsigned retCount = 0; if (Ins.size() > 0) { SmallVector<EVT, 16> resvtparts; ComputeValueVTs(*this, retTy, resvtparts); - // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI or - // individual .reg .b<size> func_retval<0..> for non ABI - unsigned resultsz = 0; - for (unsigned i = 0, e = resvtparts.size(); i != e; ++i) { - unsigned elems = 1; - EVT elemtype = resvtparts[i]; - if (resvtparts[i].isVector()) { - elems = resvtparts[i].getVectorNumElements(); - elemtype = resvtparts[i].getVectorElementType(); - } - for (unsigned j = 0, je = elems; j != je; ++j) { - unsigned sz = elemtype.getSizeInBits(); - if (isABI == false) { - if (elemtype.isInteger() && (sz < 32)) - sz = 32; - } else { - if (elemtype.isInteger() && (sz < 8)) - sz = 8; - } - if (isABI == false) { - SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32), - DAG.getConstant(sz, MVT::i32), - DAG.getConstant(retCount, MVT::i32), - InFlag }; - Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs, - DeclareRetOps, 5); - InFlag = Chain.getValue(1); - ++retCount; - } - resultsz += sz; - } - } - if (isABI) { - if (retTy->isPrimitiveType() || retTy->isIntegerTy() || - retTy->isPointerTy()) { - // Scalar needs to be at least 32bit wide - if (resultsz < 32) - 
resultsz = 32; - SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32), - DAG.getConstant(resultsz, MVT::i32), - DAG.getConstant(0, MVT::i32), InFlag }; - Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs, - DeclareRetOps, 5); - InFlag = Chain.getValue(1); - } else { - if (Func) { // direct call - if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment)) - retAlignment = getDataLayout()->getABITypeAlignment(retTy); - } else { // indirect call - const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction()); - if (!llvm::getAlign(*CallI, 0, retAlignment)) - retAlignment = getDataLayout()->getABITypeAlignment(retTy); - } - SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue DeclareRetOps[] = { Chain, - DAG.getConstant(retAlignment, MVT::i32), - DAG.getConstant(resultsz / 8, MVT::i32), - DAG.getConstant(0, MVT::i32), InFlag }; - Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs, - DeclareRetOps, 5); - InFlag = Chain.getValue(1); - } + // Declare + // .param .align 16 .b8 retval0[<size-in-bytes>], or + // .param .b<size-in-bits> retval0 + unsigned resultsz = TD->getTypeAllocSizeInBits(retTy); + if (retTy->isPrimitiveType() || retTy->isIntegerTy() || + retTy->isPointerTy()) { + // Scalar needs to be at least 32-bit wide + if (resultsz < 32) + resultsz = 32; + SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32), + DAG.getConstant(resultsz, MVT::i32), + DAG.getConstant(0, MVT::i32), InFlag }; + Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs, + DeclareRetOps, 5); + InFlag = Chain.getValue(1); + } else { + retAlignment = getArgumentAlignment(Callee, CS, retTy, 0); + SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue DeclareRetOps[] = { Chain, + DAG.getConstant(retAlignment, MVT::i32), + DAG.getConstant(resultsz / 8, MVT::i32), + DAG.getConstant(0, MVT::i32), InFlag }; + Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs, + DeclareRetOps, 5); + InFlag = Chain.getValue(1); + } } @@ -674,25 +887,22 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _); // to be emitted, and the label has to be used as the last arg of call // instruction. - // The prototype is embedded in a string and put as the operand for an - // INLINEASM SDNode. - SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue); - std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment); - const char *asmstr = nvTM->getManagedStrPool() - ->getManagedString(proto_string.c_str())->c_str(); - SDValue InlineAsmOps[] = { - Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()), - DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag + // The prototype is embedded in a string and put as the operand for a + // CallPrototype SDNode which will print out to the value of the string.
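+ // Illustrative sketch (callee type assumed): for an indirect call to a
+ // function of type float(int, double), getPrototype() would produce a
+ // string along the lines of
+ //   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _, .param .b64 _);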
+ SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue); + std::string Proto = getPrototype(retTy, Args, Outs, retAlignment, CS); + const char *ProtoStr = + nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str(); + SDValue ProtoOps[] = { + Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InFlag, }; - Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5); + Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, &ProtoOps[0], 3); InFlag = Chain.getValue(1); } // Op to just print "call" SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue PrintCallOps[] = { - Chain, - DAG.getConstant(isABI ? ((Ins.size() == 0) ? 0 : 1) : retCount, MVT::i32), - InFlag + Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag }; Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall), dl, PrintCallVTs, PrintCallOps, 3); @@ -740,62 +950,183 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Generate loads from param memory/moves from registers for result if (Ins.size() > 0) { - if (isABI) { - unsigned resoffset = 0; - for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - unsigned sz = Ins[i].VT.getSizeInBits(); - if (Ins[i].VT.isInteger() && (sz < 8)) - sz = 8; - EVT LoadRetVTs[] = { Ins[i].VT, MVT::Other, MVT::Glue }; - SDValue LoadRetOps[] = { Chain, DAG.getConstant(1, MVT::i32), - DAG.getConstant(resoffset, MVT::i32), InFlag }; - SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs, - LoadRetOps, array_lengthof(LoadRetOps)); + unsigned resoffset = 0; + if (retTy && retTy->isVectorTy()) { + EVT ObjectVT = getValueType(retTy); + unsigned NumElts = ObjectVT.getVectorNumElements(); + EVT EltVT = ObjectVT.getVectorElementType(); + assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(), + ObjectVT) == NumElts && + "Vector was not scalarized"); + unsigned sz = EltVT.getSizeInBits(); + bool needTruncate = sz < 16 ? 
true : false; + + if (NumElts == 1) { + // Just a simple load + std::vector<EVT> LoadRetVTs; + if (needTruncate) { + // If loading i1 result, generate + // load i16 + // trunc i16 to i1 + LoadRetVTs.push_back(MVT::i16); + } else + LoadRetVTs.push_back(EltVT); + LoadRetVTs.push_back(MVT::Other); + LoadRetVTs.push_back(MVT::Glue); + std::vector<SDValue> LoadRetOps; + LoadRetOps.push_back(Chain); + LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); + LoadRetOps.push_back(DAG.getConstant(0, MVT::i32)); + LoadRetOps.push_back(InFlag); + SDValue retval = DAG.getMemIntrinsicNode( + NVPTXISD::LoadParam, dl, + DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0], + LoadRetOps.size(), EltVT, MachinePointerInfo()); Chain = retval.getValue(1); InFlag = retval.getValue(2); - InVals.push_back(retval); - resoffset += sz / 8; + SDValue Ret0 = retval; + if (needTruncate) + Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0); + InVals.push_back(Ret0); + } else if (NumElts == 2) { + // LoadV2 + std::vector<EVT> LoadRetVTs; + if (needTruncate) { + // If loading i1 result, generate + // load i16 + // trunc i16 to i1 + LoadRetVTs.push_back(MVT::i16); + LoadRetVTs.push_back(MVT::i16); + } else { + LoadRetVTs.push_back(EltVT); + LoadRetVTs.push_back(EltVT); + } + LoadRetVTs.push_back(MVT::Other); + LoadRetVTs.push_back(MVT::Glue); + std::vector<SDValue> LoadRetOps; + LoadRetOps.push_back(Chain); + LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); + LoadRetOps.push_back(DAG.getConstant(0, MVT::i32)); + LoadRetOps.push_back(InFlag); + SDValue retval = DAG.getMemIntrinsicNode( + NVPTXISD::LoadParamV2, dl, + DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0], + LoadRetOps.size(), EltVT, MachinePointerInfo()); + Chain = retval.getValue(2); + InFlag = retval.getValue(3); + SDValue Ret0 = retval.getValue(0); + SDValue Ret1 = retval.getValue(1); + if (needTruncate) { + Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0); + InVals.push_back(Ret0); + Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1); + InVals.push_back(Ret1); + } else { + InVals.push_back(Ret0); + InVals.push_back(Ret1); + } + } else { + // Split into N LoadV4 + unsigned Ofst = 0; + unsigned VecSize = 4; + unsigned Opc = NVPTXISD::LoadParamV4; + if (EltVT.getSizeInBits() == 64) { + VecSize = 2; + Opc = NVPTXISD::LoadParamV2; + } + EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); + for (unsigned i = 0; i < NumElts; i += VecSize) { + SmallVector<EVT, 8> LoadRetVTs; + if (needTruncate) { + // If loading i1 result, generate + // load i16 + // trunc i16 to i1 + for (unsigned j = 0; j < VecSize; ++j) + LoadRetVTs.push_back(MVT::i16); + } else { + for (unsigned j = 0; j < VecSize; ++j) + LoadRetVTs.push_back(EltVT); + } + LoadRetVTs.push_back(MVT::Other); + LoadRetVTs.push_back(MVT::Glue); + SmallVector<SDValue, 4> LoadRetOps; + LoadRetOps.push_back(Chain); + LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); + LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32)); + LoadRetOps.push_back(InFlag); + SDValue retval = DAG.getMemIntrinsicNode( + Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), + &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo()); + if (VecSize == 2) { + Chain = retval.getValue(2); + InFlag = retval.getValue(3); + } else { + Chain = retval.getValue(4); + InFlag = retval.getValue(5); + } + + for (unsigned j = 0; j < VecSize; ++j) { + if (i + j >= NumElts) + break; + SDValue Elt = retval.getValue(j); + if (needTruncate) + Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); + 
InVals.push_back(Elt); + } + Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); + } } } else { - SmallVector<EVT, 16> resvtparts; - ComputeValueVTs(*this, retTy, resvtparts); - - assert(Ins.size() == resvtparts.size() && - "Unexpected number of return values in non-ABI case"); - unsigned paramNum = 0; + SmallVector<EVT, 16> VTs; + ComputePTXValueVTs(*this, retTy, VTs); + assert(VTs.size() == Ins.size() && "Bad value decomposition"); for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - assert(EVT(Ins[i].VT) == resvtparts[i] && - "Unexpected EVT type in non-ABI case"); - unsigned numelems = 1; - EVT elemtype = Ins[i].VT; - if (Ins[i].VT.isVector()) { - numelems = Ins[i].VT.getVectorNumElements(); - elemtype = Ins[i].VT.getVectorElementType(); - } - std::vector<SDValue> tempRetVals; - for (unsigned j = 0; j < numelems; ++j) { - EVT MoveRetVTs[] = { elemtype, MVT::Other, MVT::Glue }; - SDValue MoveRetOps[] = { Chain, DAG.getConstant(0, MVT::i32), - DAG.getConstant(paramNum, MVT::i32), - InFlag }; - SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs, - MoveRetOps, array_lengthof(MoveRetOps)); - Chain = retval.getValue(1); - InFlag = retval.getValue(2); - tempRetVals.push_back(retval); - ++paramNum; - } - if (Ins[i].VT.isVector()) - InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT, - &tempRetVals[0], tempRetVals.size())); - else - InVals.push_back(tempRetVals[0]); + unsigned sz = VTs[i].getSizeInBits(); + bool needTruncate = sz < 8 ? true : false; + if (VTs[i].isInteger() && (sz < 8)) + sz = 8; + + SmallVector<EVT, 4> LoadRetVTs; + EVT TheLoadType = VTs[i]; + if (retTy->isIntegerTy() && + TD->getTypeAllocSizeInBits(retTy) < 32) { + // This is for integer types only, and specifically not for + // aggregates. + LoadRetVTs.push_back(MVT::i32); + TheLoadType = MVT::i32; + } else if (sz < 16) { + // If loading i1/i8 result, generate + // load i8 (-> i16) + // trunc i16 to i1/i8 + LoadRetVTs.push_back(MVT::i16); + } else + LoadRetVTs.push_back(Ins[i].VT); + LoadRetVTs.push_back(MVT::Other); + LoadRetVTs.push_back(MVT::Glue); + + SmallVector<SDValue, 4> LoadRetOps; + LoadRetOps.push_back(Chain); + LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); + LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32)); + LoadRetOps.push_back(InFlag); + SDValue retval = DAG.getMemIntrinsicNode( + NVPTXISD::LoadParam, dl, + DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0], + LoadRetOps.size(), TheLoadType, MachinePointerInfo()); + Chain = retval.getValue(1); + InFlag = retval.getValue(2); + SDValue Ret0 = retval.getValue(0); + if (needTruncate) + Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0); + InVals.push_back(Ret0); + resoffset += sz / 8; } } } + Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true), DAG.getIntPtrConstant(uniqueCallSite + 1, true), - InFlag); + InFlag, dl); uniqueCallSite++; // set isTailCall to false for now, until we figure out how to express @@ -810,7 +1141,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SDValue NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { SDNode *Node = Op.getNode(); - DebugLoc dl = Node->getDebugLoc(); + SDLoc dl(Node); SmallVector<SDValue, 8> Ops; unsigned NumOperands = Node->getNumOperands(); for (unsigned i = 0; i < NumOperands; ++i) { @@ -861,17 +1192,17 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { // v = ld i1* addr // => -// v1 = ld i8* addr -// v = trunc v1 to i1 +// v1 
= ld i8* addr (-> i16) +// v = trunc i16 to i1 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const { SDNode *Node = Op.getNode(); LoadSDNode *LD = cast<LoadSDNode>(Node); - DebugLoc dl = Node->getDebugLoc(); + SDLoc dl(Node); assert(LD->getExtensionType() == ISD::NON_EXTLOAD); assert(Node->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only"); SDValue newLD = - DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(), + DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(), LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(), LD->isInvariant(), LD->getAlignment()); SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD); @@ -896,7 +1227,7 @@ SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { SDNode *N = Op.getNode(); SDValue Val = N->getOperand(1); - DebugLoc DL = N->getDebugLoc(); + SDLoc DL(N); EVT ValVT = Val.getValueType(); if (ValVT.isVector()) { @@ -955,8 +1286,6 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val, DAG.getIntPtrConstant(i)); if (NeedExt) - // ANY_EXTEND is correct here since the store will only look at the - // lower-order bits anyway. ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal); Ops.push_back(ExtVal); } @@ -981,11 +1310,11 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { // st i1 v, addr // => -// v1 = zxt v to i8 -// st i8, addr +// v1 = zext v to i16 +// st.u8 i16, addr SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const { SDNode *Node = Op.getNode(); - DebugLoc dl = Node->getDebugLoc(); + SDLoc dl(Node); StoreSDNode *ST = cast<StoreSDNode>(Node); SDValue Tmp1 = ST->getChain(); SDValue Tmp2 = ST->getBasePtr(); @@ -994,9 +1323,10 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const { unsigned Alignment = ST->getAlignment(); bool isVolatile = ST->isVolatile(); bool isNonTemporal = ST->isNonTemporal(); - Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Tmp3); - SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(), - isVolatile, isNonTemporal, Alignment); + Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3); + SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, + ST->getPointerInfo(), MVT::i8, isNonTemporal, + isVolatile, Alignment); return Result; } @@ -1011,7 +1341,15 @@ SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname, SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const { - return getExtSymb(DAG, ".PARAM", idx, v); + std::string ParamSym; + raw_string_ostream ParamStr(ParamSym); + + ParamStr << DAG.getMachineFunction().getName() << "_param_" << idx; + ParamStr.flush(); + + std::string *SavedStr = + nvTM->getManagedStrPool()->getManagedString(ParamSym.c_str()); + return DAG.getTargetExternalSymbol(SavedStr->c_str(), v); } SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) { @@ -1046,19 +1384,23 @@ bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) { SDValue NVPTXTargetLowering::LowerFormalArguments( SDValue Chain, CallingConv::ID CallConv, bool isVarArg, - const SmallVectorImpl<ISD::InputArg> &Ins, DebugLoc dl, SelectionDAG &DAG, + const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const { MachineFunction &MF = DAG.getMachineFunction(); const DataLayout *TD = getDataLayout(); const
Function *F = MF.getFunction(); const AttributeSet &PAL = F->getAttributes(); + const TargetLowering *TLI = nvTM->getTargetLowering(); SDValue Root = DAG.getRoot(); std::vector<SDValue> OutChains; bool isKernel = llvm::isKernelFunction(*F); bool isABI = (nvptxSubtarget.getSmVersion() >= 20); + assert(isABI && "Non-ABI compilation is not supported"); + if (!isABI) + return Chain; std::vector<Type *> argTypes; std::vector<const Argument *> theArgs; @@ -1067,15 +1409,20 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( theArgs.push_back(I); argTypes.push_back(I->getType()); } - //assert(argTypes.size() == Ins.size() && - // "Ins types and function types did not match"); + // argTypes.size() (or theArgs.size()) and Ins.size() need not match. + // Ins.size() will be larger + // * if there is an aggregate argument with multiple fields (each field + // showing up separately in Ins) + // * if there is a vector argument with more than typical vector-length + // elements (generally if more than 4) where each vector element is + // individually present in Ins. + // So a different index should be used for indexing into Ins. + // See similar issue in LowerCall. + unsigned InsIdx = 0; int idx = 0; - for (unsigned i = 0, e = argTypes.size(); i != e; ++i, ++idx) { + for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) { Type *Ty = argTypes[i]; - EVT ObjectVT = getValueType(Ty); - //assert(ObjectVT == Ins[i].VT && - // "Ins type did not match function type"); // If the kernel argument is image*_t or sampler_t, convert it to // a i32 constant holding the parameter position. This can later @@ -1091,142 +1438,248 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( if (theArgs[i]->use_empty()) { // argument is dead - if (ObjectVT.isVector()) { - EVT EltVT = ObjectVT.getVectorElementType(); - unsigned NumElts = ObjectVT.getVectorNumElements(); - for (unsigned vi = 0; vi < NumElts; ++vi) { - InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT)); + if (Ty->isAggregateType()) { + SmallVector<EVT, 16> vtparts; + + ComputePTXValueVTs(*this, Ty, vtparts); + assert(vtparts.size() > 0 && "empty aggregate type not expected"); + for (unsigned parti = 0, parte = vtparts.size(); parti != parte; + ++parti) { + EVT partVT = vtparts[parti]; + InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT)); + ++InsIdx; } - } else { - InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT)); + if (vtparts.size() > 0) + --InsIdx; + continue; + } + if (Ty->isVectorTy()) { + EVT ObjectVT = getValueType(Ty); + unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT); + for (unsigned parti = 0; parti < NumRegs; ++parti) { + InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); + ++InsIdx; + } + if (NumRegs > 0) + --InsIdx; + continue; } + InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); continue; } // In the following cases, assign a node order of "idx+1" - // to newly created nodes. The SDNOdes for params have to + // to newly created nodes. The SDNodes for params have to // appear in the same order as their order of appearance // in the original function. "idx+1" holds that order. if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) { - if (ObjectVT.isVector()) { + if (Ty->isAggregateType()) { + SmallVector<EVT, 16> vtparts; + SmallVector<uint64_t, 16> offsets; + + // NOTE: Here, we lose the ability to issue vector loads for vectors + // that are a part of a struct. This should be investigated in the + // future. 
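+ // Illustrative sketch (example type assumed): an argument of type
+ // {i8, i32} decomposes via ComputePTXValueVTs into parts
+ // {MVT::i8, MVT::i32} at byte offsets {0, 4}, and each part is then
+ // loaded separately from the .param address space below.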
+ ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0); + assert(vtparts.size() > 0 && "empty aggregate type not expected"); + bool aggregateIsPacked = false; + if (StructType *STy = llvm::dyn_cast<StructType>(Ty)) + aggregateIsPacked = STy->isPacked(); + + SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); + for (unsigned parti = 0, parte = vtparts.size(); parti != parte; + ++parti) { + EVT partVT = vtparts[parti]; + Value *srcValue = Constant::getNullValue( + PointerType::get(partVT.getTypeForEVT(F->getContext()), + llvm::ADDRESS_SPACE_PARAM)); + SDValue srcAddr = + DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, + DAG.getConstant(offsets[parti], getPointerTy())); + unsigned partAlign = + aggregateIsPacked ? 1 + : TD->getABITypeAlignment( + partVT.getTypeForEVT(F->getContext())); + SDValue p; + if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) { + ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? + ISD::SEXTLOAD : ISD::ZEXTLOAD; + p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr, + MachinePointerInfo(srcValue), partVT, false, + false, partAlign); + } else { + p = DAG.getLoad(partVT, dl, Root, srcAddr, + MachinePointerInfo(srcValue), false, false, false, + partAlign); + } + if (p.getNode()) + p.getNode()->setIROrder(idx + 1); + InVals.push_back(p); + ++InsIdx; + } + if (vtparts.size() > 0) + --InsIdx; + continue; + } + if (Ty->isVectorTy()) { + EVT ObjectVT = getValueType(Ty); + SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); unsigned NumElts = ObjectVT.getVectorNumElements(); + assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts && + "Vector was not scalarized"); + unsigned Ofst = 0; EVT EltVT = ObjectVT.getVectorElementType(); - unsigned Offset = 0; - for (unsigned vi = 0; vi < NumElts; ++vi) { - SDValue A = getParamSymbol(DAG, idx, getPointerTy()); - SDValue B = DAG.getIntPtrConstant(Offset); - SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(), - //getParamSymbol(DAG, idx, EltVT), - //DAG.getConstant(Offset, getPointerTy())); - A, B); + + // V1 load + // f32 = load ... + if (NumElts == 1) { + // We only have one element, so just directly load it Value *SrcValue = Constant::getNullValue(PointerType::get( EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM)); - SDValue Ld = DAG.getLoad( - EltVT, dl, Root, Addr, MachinePointerInfo(SrcValue), false, false, - false, + SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, + DAG.getConstant(Ofst, getPointerTy())); + SDValue P = DAG.getLoad( + EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false, + false, true, TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext()))); - Offset += EltVT.getStoreSizeInBits() / 8; - InVals.push_back(Ld); + if (P.getNode()) + P.getNode()->setIROrder(idx + 1); + + if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) + P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P); + InVals.push_back(P); + Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext())); + ++InsIdx; + } else if (NumElts == 2) { + // V2 load + // f32,f32 = load ... 
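+ // Illustrative sketch (parameter name assumed): a <2 x float> argument
+ // named foo_param_0 would be fetched with a single vector load, e.g.
+ //   ld.param.v2.f32 {%f1, %f2}, [foo_param_0];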
+ EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2); + Value *SrcValue = Constant::getNullValue(PointerType::get( + VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM)); + SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, + DAG.getConstant(Ofst, getPointerTy())); + SDValue P = DAG.getLoad( + VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false, + false, true, + TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext()))); + if (P.getNode()) + P.getNode()->setIROrder(idx + 1); + + SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, + DAG.getIntPtrConstant(0)); + SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, + DAG.getIntPtrConstant(1)); + + if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) { + Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0); + Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1); + } + + InVals.push_back(Elt0); + InVals.push_back(Elt1); + Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); + InsIdx += 2; + } else { + // V4 loads + // We have at least 4 elements (<3 x Ty> expands to 4 elements) and + // the + // vector will be expanded to a power of 2 elements, so we know we can + // always round up to the next multiple of 4 when creating the vector + // loads. + // e.g. 4 elem => 1 ld.v4 + // 6 elem => 2 ld.v4 + // 8 elem => 2 ld.v4 + // 11 elem => 3 ld.v4 + unsigned VecSize = 4; + if (EltVT.getSizeInBits() == 64) { + VecSize = 2; + } + EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); + for (unsigned i = 0; i < NumElts; i += VecSize) { + Value *SrcValue = Constant::getNullValue( + PointerType::get(VecVT.getTypeForEVT(F->getContext()), + llvm::ADDRESS_SPACE_PARAM)); + SDValue SrcAddr = + DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, + DAG.getConstant(Ofst, getPointerTy())); + SDValue P = DAG.getLoad( + VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false, + false, true, + TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext()))); + if (P.getNode()) + P.getNode()->setIROrder(idx + 1); + + for (unsigned j = 0; j < VecSize; ++j) { + if (i + j >= NumElts) + break; + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, + DAG.getIntPtrConstant(j)); + if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) + Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt); + InVals.push_back(Elt); + } + Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); + } + InsIdx += NumElts; } + + if (NumElts > 0) + --InsIdx; continue; } - // A plain scalar. - if (isABI || isKernel) { - // If ABI, load from the param symbol - SDValue Arg = getParamSymbol(DAG, idx); - // Conjure up a value that we can get the address space from. - // FIXME: Using a constant here is a hack. 
- Value *srcValue = Constant::getNullValue( - PointerType::get(ObjectVT.getTypeForEVT(F->getContext()), - llvm::ADDRESS_SPACE_PARAM)); - SDValue p = DAG.getLoad( - ObjectVT, dl, Root, Arg, MachinePointerInfo(srcValue), false, false, - false, - TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext()))); - if (p.getNode()) - DAG.AssignOrdering(p.getNode(), idx + 1); - InVals.push_back(p); + EVT ObjectVT = getValueType(Ty); + // If ABI, load from the param symbol + SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); + Value *srcValue = Constant::getNullValue(PointerType::get( + ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM)); + SDValue p; + if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) { + ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? + ISD::SEXTLOAD : ISD::ZEXTLOAD; + p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, Arg, + MachinePointerInfo(srcValue), ObjectVT, false, false, + TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext()))); } else { - // If no ABI, just move the param symbol - SDValue Arg = getParamSymbol(DAG, idx, ObjectVT); - SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg); - if (p.getNode()) - DAG.AssignOrdering(p.getNode(), idx + 1); - InVals.push_back(p); + p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg, + MachinePointerInfo(srcValue), false, false, false, + TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext()))); } + if (p.getNode()) + p.getNode()->setIROrder(idx + 1); + InVals.push_back(p); continue; } // Param has ByVal attribute - if (isABI || isKernel) { - // Return MoveParam(param symbol). - // Ideally, the param symbol can be returned directly, - // but when SDNode builder decides to use it in a CopyToReg(), - // machine instruction fails because TargetExternalSymbol - // (not lowered) is target dependent, and CopyToReg assumes - // the source is lowered. 
- SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); - SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg); - if (p.getNode()) - DAG.AssignOrdering(p.getNode(), idx + 1); - if (isKernel) - InVals.push_back(p); - else { - SDValue p2 = DAG.getNode( - ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT, - DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p); - InVals.push_back(p2); - } - } else { - // Have to move a set of param symbols to registers and - // store them locally and return the local pointer in InVals - const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]); - assert(elemPtrType && "Byval parameter should be a pointer type"); - Type *elemType = elemPtrType->getElementType(); - // Compute the constituent parts - SmallVector<EVT, 16> vtparts; - SmallVector<uint64_t, 16> offsets; - ComputeValueVTs(*this, elemType, vtparts, &offsets, 0); - unsigned totalsize = 0; - for (unsigned j = 0, je = vtparts.size(); j != je; ++j) - totalsize += vtparts[j].getStoreSizeInBits(); - SDValue localcopy = DAG.getFrameIndex( - MF.getFrameInfo()->CreateStackObject(totalsize / 8, 16, false), - getPointerTy()); - unsigned sizesofar = 0; - std::vector<SDValue> theChains; - for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { - unsigned numElems = 1; - if (vtparts[j].isVector()) - numElems = vtparts[j].getVectorNumElements(); - for (unsigned k = 0, ke = numElems; k != ke; ++k) { - EVT tmpvt = vtparts[j]; - if (tmpvt.isVector()) - tmpvt = tmpvt.getVectorElementType(); - SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt, - getParamSymbol(DAG, idx, tmpvt)); - SDValue addr = - DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy, - DAG.getConstant(sizesofar, getPointerTy())); - theChains.push_back(DAG.getStore( - Chain, dl, arg, addr, MachinePointerInfo(), false, false, 0)); - sizesofar += tmpvt.getStoreSizeInBits() / 8; - ++idx; - } - } - --idx; - Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0], - theChains.size()); - InVals.push_back(localcopy); + // Return MoveParam(param symbol). + // Ideally, the param symbol can be returned directly, + // but when SDNode builder decides to use it in a CopyToReg(), + // machine instruction fails because TargetExternalSymbol + // (not lowered) is target dependent, and CopyToReg assumes + // the source is lowered. + EVT ObjectVT = getValueType(Ty); + assert(ObjectVT == Ins[InsIdx].VT && + "Ins type did not match function type"); + SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); + SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg); + if (p.getNode()) + p.getNode()->setIROrder(idx + 1); + if (isKernel) + InVals.push_back(p); + else { + SDValue p2 = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT, + DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p); + InVals.push_back(p2); } } // Clang will check explicit VarArg and issue error if any. However, Clang // will let code with - // implicit var arg like f() pass. + // implicit var arg like f() pass. See bug 617733. // We treat this case as if the arg list is empty. 
- //if (F.isVarArg()) { + // if (F.isVarArg()) { // assert(0 && "VarArg not supported yet!"); //} @@ -1237,43 +1690,185 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( return Chain; } -SDValue NVPTXTargetLowering::LowerReturn( - SDValue Chain, CallingConv::ID CallConv, bool isVarArg, - const SmallVectorImpl<ISD::OutputArg> &Outs, - const SmallVectorImpl<SDValue> &OutVals, DebugLoc dl, - SelectionDAG &DAG) const { + +SDValue +NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, + bool isVarArg, + const SmallVectorImpl<ISD::OutputArg> &Outs, + const SmallVectorImpl<SDValue> &OutVals, + SDLoc dl, SelectionDAG &DAG) const { + MachineFunction &MF = DAG.getMachineFunction(); + const Function *F = MF.getFunction(); + Type *RetTy = F->getReturnType(); + const DataLayout *TD = getDataLayout(); bool isABI = (nvptxSubtarget.getSmVersion() >= 20); + assert(isABI && "Non-ABI compilation is not supported"); + if (!isABI) + return Chain; + + if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) { + // If we have a vector type, the OutVals array will be the scalarized + // components and we have to combine them into 1 or more vector stores. + unsigned NumElts = VTy->getNumElements(); + assert(NumElts == Outs.size() && "Bad scalarization of return value"); + + // const_cast can be removed in later LLVM versions + EVT EltVT = getValueType(RetTy).getVectorElementType(); + bool NeedExtend = false; + if (EltVT.getSizeInBits() < 16) + NeedExtend = true; + + // V1 store + if (NumElts == 1) { + SDValue StoreVal = OutVals[0]; + // We only have one element, so just directly store it + if (NeedExtend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); + SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl, + DAG.getVTList(MVT::Other), &Ops[0], 3, + EltVT, MachinePointerInfo()); + + } else if (NumElts == 2) { + // V2 store + SDValue StoreVal0 = OutVals[0]; + SDValue StoreVal1 = OutVals[1]; + + if (NeedExtend) { + StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0); + StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1); + } - unsigned sizesofar = 0; - unsigned idx = 0; - for (unsigned i = 0, e = Outs.size(); i != e; ++i) { - SDValue theVal = OutVals[i]; - EVT theValType = theVal.getValueType(); - unsigned numElems = 1; - if (theValType.isVector()) - numElems = theValType.getVectorNumElements(); - for (unsigned j = 0, je = numElems; j != je; ++j) { - SDValue tmpval = theVal; - if (theValType.isVector()) - tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - theValType.getVectorElementType(), tmpval, - DAG.getIntPtrConstant(j)); - Chain = DAG.getNode( - isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl, - MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32), - tmpval); - if (theValType.isVector()) - sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8; - else - sizesofar += theValType.getStoreSizeInBits() / 8; - ++idx; + SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0, + StoreVal1 }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl, + DAG.getVTList(MVT::Other), &Ops[0], 4, + EltVT, MachinePointerInfo()); + } else { + // V4 stores + // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the + // vector will be expanded to a power of 2 elements, so we know we can + // always round up to the next multiple of 4 when creating the vector + // stores. + // e.g.
4 elem => 1 st.v4 + // 6 elem => 2 st.v4 + // 8 elem => 2 st.v4 + // 11 elem => 3 st.v4 + + unsigned VecSize = 4; + if (OutVals[0].getValueType().getSizeInBits() == 64) + VecSize = 2; + + unsigned Offset = 0; + + EVT VecVT = + EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize); + unsigned PerStoreOffset = + TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); + + for (unsigned i = 0; i < NumElts; i += VecSize) { + // Get values + SDValue StoreVal; + SmallVector<SDValue, 8> Ops; + Ops.push_back(Chain); + Ops.push_back(DAG.getConstant(Offset, MVT::i32)); + unsigned Opc = NVPTXISD::StoreRetvalV2; + EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType(); + + StoreVal = OutVals[i]; + if (NeedExtend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); + Ops.push_back(StoreVal); + + if (i + 1 < NumElts) { + StoreVal = OutVals[i + 1]; + if (NeedExtend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); + } else { + StoreVal = DAG.getUNDEF(ExtendedVT); + } + Ops.push_back(StoreVal); + + if (VecSize == 4) { + Opc = NVPTXISD::StoreRetvalV4; + if (i + 2 < NumElts) { + StoreVal = OutVals[i + 2]; + if (NeedExtend) + StoreVal = + DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); + } else { + StoreVal = DAG.getUNDEF(ExtendedVT); + } + Ops.push_back(StoreVal); + + if (i + 3 < NumElts) { + StoreVal = OutVals[i + 3]; + if (NeedExtend) + StoreVal = + DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); + } else { + StoreVal = DAG.getUNDEF(ExtendedVT); + } + Ops.push_back(StoreVal); + } + + // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size()); + Chain = + DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0], + Ops.size(), EltVT, MachinePointerInfo()); + Offset += PerStoreOffset; + } + } + } else { + SmallVector<EVT, 16> ValVTs; + // const_cast is necessary since we are still using an LLVM version from + // before the type system re-write. + ComputePTXValueVTs(*this, RetTy, ValVTs); + assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition"); + + unsigned SizeSoFar = 0; + for (unsigned i = 0, e = Outs.size(); i != e; ++i) { + SDValue theVal = OutVals[i]; + EVT TheValType = theVal.getValueType(); + unsigned numElems = 1; + if (TheValType.isVector()) + numElems = TheValType.getVectorNumElements(); + for (unsigned j = 0, je = numElems; j != je; ++j) { + SDValue TmpVal = theVal; + if (TheValType.isVector()) + TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + TheValType.getVectorElementType(), TmpVal, + DAG.getIntPtrConstant(j)); + EVT TheStoreType = ValVTs[i]; + if (RetTy->isIntegerTy() && + TD->getTypeAllocSizeInBits(RetTy) < 32) { + // The following zero-extension is for integer types only, and + // specifically not for aggregates. 
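+ // Illustrative sketch (not tied to a particular caller): an i8 return
+ // value is therefore stored as a full 32-bit word, matching the widened
+ //   .param .b32 func_retval0;
+ // declaration, e.g.
+ //   st.param.b32 [func_retval0+0], %r1;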
+ TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal); + TheStoreType = MVT::i32; + } + else if (TmpVal.getValueType().getSizeInBits() < 16) + TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal); + + SDValue Ops[] = { Chain, DAG.getConstant(SizeSoFar, MVT::i32), TmpVal }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl, + DAG.getVTList(MVT::Other), &Ops[0], + 3, TheStoreType, + MachinePointerInfo()); + if (TheValType.isVector()) + SizeSoFar += + TheStoreType.getVectorElementType().getStoreSizeInBits() / 8; + else + SizeSoFar += TheStoreType.getStoreSizeInBits() / 8; + } } } return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain); } + void NVPTXTargetLowering::LowerAsmOperandForConstraint( SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops, SelectionDAG &DAG) const { @@ -1337,9 +1932,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( Info.opc = ISD::INTRINSIC_W_CHAIN; if (Intrinsic == Intrinsic::nvvm_ldu_global_i) - Info.memVT = MVT::i32; + Info.memVT = getValueType(I.getType()); else if (Intrinsic == Intrinsic::nvvm_ldu_global_p) - Info.memVT = getPointerTy(); + Info.memVT = getValueType(I.getType()); else Info.memVT = MVT::f32; Info.ptrVal = I.getArgOperand(0); @@ -1420,11 +2015,11 @@ NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const { std::pair<unsigned, const TargetRegisterClass *> NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint, - EVT VT) const { + MVT VT) const { if (Constraint.size() == 1) { switch (Constraint[0]) { case 'c': - return std::make_pair(0U, &NVPTX::Int8RegsRegClass); + return std::make_pair(0U, &NVPTX::Int16RegsRegClass); case 'h': return std::make_pair(0U, &NVPTX::Int16RegsRegClass); case 'r': @@ -1450,7 +2045,7 @@ unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const { static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, SmallVectorImpl<SDValue> &Results) { EVT ResVT = N->getValueType(0); - DebugLoc DL = N->getDebugLoc(); + SDLoc DL(N); assert(ResVT.isVector() && "Vector load must have vector type"); @@ -1543,7 +2138,7 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, SmallVectorImpl<SDValue> &Results) { SDValue Chain = N->getOperand(0); SDValue Intrin = N->getOperand(1); - DebugLoc DL = N->getDebugLoc(); + SDLoc DL(N); // Get the intrinsic ID unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue(); @@ -1564,7 +2159,8 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, unsigned NumElts = ResVT.getVectorNumElements(); EVT EltVT = ResVT.getVectorElementType(); - // Since LDU/LDG are target nodes, we cannot rely on DAG type legalization. + // Since LDU/LDG are target nodes, we cannot rely on DAG type + // legalization. // Therefore, we must ensure the type is legal. For i1 and i8, we set the // loaded type to i16 and propagate the "real" type as the memory type.
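+ // Illustrative sketch (assuming an sm_35-style LDG): a v4i8 load is
+ // therefore selected with i16 result types while the memory VT stays
+ // i8, emitting PTX like
+ //   ld.global.nc.v4.u8 {%rs1, %rs2, %rs3, %rs4}, [%rd1];
+ // and the results are truncated back to i8 afterwards.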
bool NeedTrunc = false; @@ -1623,7 +2219,7 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, OtherOps.push_back(Chain); // Chain // Skip operand 1 (intrinsic ID) - // Others + // Others for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) OtherOps.push_back(N->getOperand(i)); @@ -1671,7 +2267,8 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG, DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0], Ops.size(), MVT::i8, MemSD->getMemOperand()); - Results.push_back(NewLD.getValue(0)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, + NewLD.getValue(0))); Results.push_back(NewLD.getValue(1)); } } @@ -1691,3 +2288,29 @@ void NVPTXTargetLowering::ReplaceNodeResults( return; } } + +// Pin NVPTXSection's and NVPTXTargetObjectFile's vtables to this file. +void NVPTXSection::anchor() {} + +NVPTXTargetObjectFile::~NVPTXTargetObjectFile() { + delete TextSection; + delete DataSection; + delete BSSSection; + delete ReadOnlySection; + + delete StaticCtorSection; + delete StaticDtorSection; + delete LSDASection; + delete EHFrameSection; + delete DwarfAbbrevSection; + delete DwarfInfoSection; + delete DwarfLineSection; + delete DwarfFrameSection; + delete DwarfPubTypesSection; + delete DwarfDebugInlineSection; + delete DwarfStrSection; + delete DwarfLocSection; + delete DwarfARangesSection; + delete DwarfRangesSection; + delete DwarfMacroInfoSection; +} diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 3cd49d3..66e708f 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -29,17 +29,11 @@ enum NodeType { CALL, RET_FLAG, LOAD_PARAM, - NVBuiltin, DeclareParam, DeclareScalarParam, DeclareRetParam, DeclareRet, DeclareScalarRet, - LoadParam, - StoreParam, - StoreParamS32, // to sext and store a <32bit value, not used currently - StoreParamU32, // to zext and store a <32bit value, not used currently - MoveToParam, PrintCall, PrintCallUni, CallArgBegin, @@ -51,13 +45,11 @@ enum NodeType { CallSymbol, Prototype, MoveParam, - MoveRetval, - MoveToRetval, - StoreRetval, PseudoUseParam, RETURN, CallSeqBegin, CallSeqEnd, + CallPrototype, Dummy, LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE, @@ -67,7 +59,18 @@ enum NodeType { LDUV2, // LDU.v2 LDUV4, // LDU.v4 StoreV2, - StoreV4 + StoreV4, + LoadParam, + LoadParamV2, + LoadParamV4, + StoreParam, + StoreParamV2, + StoreParamV4, + StoreParamS32, // to sext and store a <32bit value, not used currently + StoreParamU32, // to zext and store a <32bit value, not used currently + StoreRetval, + StoreRetvalV2, + StoreRetvalV4 }; } @@ -100,7 +103,7 @@ public: /// getFunctionAlignment - Return the Log2 alignment of this function. 
virtual unsigned getFunctionAlignment(const Function *F) const; - virtual EVT getSetCCResultType(EVT VT) const { + virtual EVT getSetCCResultType(LLVMContext &, EVT VT) const { if (VT.isVector()) return MVT::getVectorVT(MVT::i1, VT.getVectorNumElements()); return MVT::i1; @@ -108,11 +111,11 @@ public: ConstraintType getConstraintType(const std::string &Constraint) const; std::pair<unsigned, const TargetRegisterClass *> - getRegForInlineAsmConstraint(const std::string &Constraint, EVT VT) const; + getRegForInlineAsmConstraint(const std::string &Constraint, MVT VT) const; virtual SDValue LowerFormalArguments( SDValue Chain, CallingConv::ID CallConv, bool isVarArg, - const SmallVectorImpl<ISD::InputArg> &Ins, DebugLoc dl, SelectionDAG &DAG, + const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const; virtual SDValue @@ -120,12 +123,13 @@ public: std::string getPrototype(Type *, const ArgListTy &, const SmallVectorImpl<ISD::OutputArg> &, - unsigned retAlignment) const; + unsigned retAlignment, + const ImmutableCallSite *CS) const; virtual SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl<ISD::OutputArg> &Outs, - const SmallVectorImpl<SDValue> &OutVals, DebugLoc dl, + const SmallVectorImpl<SDValue> &OutVals, SDLoc dl, SelectionDAG &DAG) const; virtual void LowerAsmOperandForConstraint(SDValue Op, std::string &Constraint, @@ -144,7 +148,7 @@ private: SDValue getExtSymb(SelectionDAG &DAG, const char *name, int idx, EVT = MVT::i32) const; - SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT = MVT::i32) const; + SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const; SDValue getParamHelpSymbol(SelectionDAG &DAG, int idx); SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const; @@ -158,6 +162,9 @@ private: virtual void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const; + + unsigned getArgumentAlignment(SDValue Callee, const ImmutableCallSite *CS, + Type *Ty, unsigned Idx) const; }; } // namespace llvm diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index 33a63c2..86ddd38 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -14,54 +14,53 @@ #include "NVPTX.h" #include "NVPTXInstrInfo.h" #include "NVPTXTargetMachine.h" -#define GET_INSTRINFO_CTOR +#define GET_INSTRINFO_CTOR_DTOR #include "NVPTXGenInstrInfo.inc" #include "llvm/IR/Function.h" #include "llvm/ADT/STLExtras.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" -#include <cstdio> using namespace llvm; +// Pin the vtable to this file. +void NVPTXInstrInfo::anchor() {} + // FIXME: Add the subtarget support on this constructor. 
NVPTXInstrInfo::NVPTXInstrInfo(NVPTXTargetMachine &tm) - : NVPTXGenInstrInfo(), TM(tm), RegInfo(*this, *TM.getSubtargetImpl()) {} + : NVPTXGenInstrInfo(), TM(tm), RegInfo(*TM.getSubtargetImpl()) {} void NVPTXInstrInfo::copyPhysReg( MachineBasicBlock &MBB, MachineBasicBlock::iterator I, DebugLoc DL, unsigned DestReg, unsigned SrcReg, bool KillSrc) const { - if (NVPTX::Int32RegsRegClass.contains(DestReg) && - NVPTX::Int32RegsRegClass.contains(SrcReg)) + const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); + const TargetRegisterClass *DestRC = MRI.getRegClass(DestReg); + const TargetRegisterClass *SrcRC = MRI.getRegClass(SrcReg); + + if (DestRC != SrcRC) + report_fatal_error("Attempted to created cross-class register copy"); + + if (DestRC == &NVPTX::Int32RegsRegClass) BuildMI(MBB, I, DL, get(NVPTX::IMOV32rr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); - else if (NVPTX::Int8RegsRegClass.contains(DestReg) && - NVPTX::Int8RegsRegClass.contains(SrcReg)) - BuildMI(MBB, I, DL, get(NVPTX::IMOV8rr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); - else if (NVPTX::Int1RegsRegClass.contains(DestReg) && - NVPTX::Int1RegsRegClass.contains(SrcReg)) + .addReg(SrcReg, getKillRegState(KillSrc)); + else if (DestRC == &NVPTX::Int1RegsRegClass) BuildMI(MBB, I, DL, get(NVPTX::IMOV1rr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); - else if (NVPTX::Float32RegsRegClass.contains(DestReg) && - NVPTX::Float32RegsRegClass.contains(SrcReg)) + .addReg(SrcReg, getKillRegState(KillSrc)); + else if (DestRC == &NVPTX::Float32RegsRegClass) BuildMI(MBB, I, DL, get(NVPTX::FMOV32rr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); - else if (NVPTX::Int16RegsRegClass.contains(DestReg) && - NVPTX::Int16RegsRegClass.contains(SrcReg)) + .addReg(SrcReg, getKillRegState(KillSrc)); + else if (DestRC == &NVPTX::Int16RegsRegClass) BuildMI(MBB, I, DL, get(NVPTX::IMOV16rr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); - else if (NVPTX::Int64RegsRegClass.contains(DestReg) && - NVPTX::Int64RegsRegClass.contains(SrcReg)) + .addReg(SrcReg, getKillRegState(KillSrc)); + else if (DestRC == &NVPTX::Int64RegsRegClass) BuildMI(MBB, I, DL, get(NVPTX::IMOV64rr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); - else if (NVPTX::Float64RegsRegClass.contains(DestReg) && - NVPTX::Float64RegsRegClass.contains(SrcReg)) + .addReg(SrcReg, getKillRegState(KillSrc)); + else if (DestRC == &NVPTX::Float64RegsRegClass) BuildMI(MBB, I, DL, get(NVPTX::FMOV64rr), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)); + .addReg(SrcReg, getKillRegState(KillSrc)); else { - llvm_unreachable("Don't know how to copy a register"); + llvm_unreachable("Bad register copy"); } } diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h b/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h index b1972e9..600fc5c 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.h @@ -26,6 +26,7 @@ namespace llvm { class NVPTXInstrInfo : public NVPTXGenInstrInfo { NVPTXTargetMachine &TM; const NVPTXRegisterInfo RegInfo; + virtual void anchor(); public: explicit NVPTXInstrInfo(NVPTXTargetMachine &TM); diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index da6dd39..b23f1e4 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -32,6 +32,86 @@ def isVecOther : VecInstTypeEnum<15>; def brtarget : Operand<OtherVT>; +// CVT conversion modes +// These must match 
the enum in NVPTX.h +def CvtNONE : PatLeaf<(i32 0x0)>; +def CvtRNI : PatLeaf<(i32 0x1)>; +def CvtRZI : PatLeaf<(i32 0x2)>; +def CvtRMI : PatLeaf<(i32 0x3)>; +def CvtRPI : PatLeaf<(i32 0x4)>; +def CvtRN : PatLeaf<(i32 0x5)>; +def CvtRZ : PatLeaf<(i32 0x6)>; +def CvtRM : PatLeaf<(i32 0x7)>; +def CvtRP : PatLeaf<(i32 0x8)>; + +def CvtNONE_FTZ : PatLeaf<(i32 0x10)>; +def CvtRNI_FTZ : PatLeaf<(i32 0x11)>; +def CvtRZI_FTZ : PatLeaf<(i32 0x12)>; +def CvtRMI_FTZ : PatLeaf<(i32 0x13)>; +def CvtRPI_FTZ : PatLeaf<(i32 0x14)>; +def CvtRN_FTZ : PatLeaf<(i32 0x15)>; +def CvtRZ_FTZ : PatLeaf<(i32 0x16)>; +def CvtRM_FTZ : PatLeaf<(i32 0x17)>; +def CvtRP_FTZ : PatLeaf<(i32 0x18)>; + +def CvtSAT : PatLeaf<(i32 0x20)>; +def CvtSAT_FTZ : PatLeaf<(i32 0x30)>; + +def CvtMode : Operand<i32> { + let PrintMethod = "printCvtMode"; +} + +// Compare modes +// These must match the enum in NVPTX.h +def CmpEQ : PatLeaf<(i32 0)>; +def CmpNE : PatLeaf<(i32 1)>; +def CmpLT : PatLeaf<(i32 2)>; +def CmpLE : PatLeaf<(i32 3)>; +def CmpGT : PatLeaf<(i32 4)>; +def CmpGE : PatLeaf<(i32 5)>; +def CmpLO : PatLeaf<(i32 6)>; +def CmpLS : PatLeaf<(i32 7)>; +def CmpHI : PatLeaf<(i32 8)>; +def CmpHS : PatLeaf<(i32 9)>; +def CmpEQU : PatLeaf<(i32 10)>; +def CmpNEU : PatLeaf<(i32 11)>; +def CmpLTU : PatLeaf<(i32 12)>; +def CmpLEU : PatLeaf<(i32 13)>; +def CmpGTU : PatLeaf<(i32 14)>; +def CmpGEU : PatLeaf<(i32 15)>; +def CmpNUM : PatLeaf<(i32 16)>; +def CmpNAN : PatLeaf<(i32 17)>; + +def CmpEQ_FTZ : PatLeaf<(i32 0x100)>; +def CmpNE_FTZ : PatLeaf<(i32 0x101)>; +def CmpLT_FTZ : PatLeaf<(i32 0x102)>; +def CmpLE_FTZ : PatLeaf<(i32 0x103)>; +def CmpGT_FTZ : PatLeaf<(i32 0x104)>; +def CmpGE_FTZ : PatLeaf<(i32 0x105)>; +def CmpLO_FTZ : PatLeaf<(i32 0x106)>; +def CmpLS_FTZ : PatLeaf<(i32 0x107)>; +def CmpHI_FTZ : PatLeaf<(i32 0x108)>; +def CmpHS_FTZ : PatLeaf<(i32 0x109)>; +def CmpEQU_FTZ : PatLeaf<(i32 0x10A)>; +def CmpNEU_FTZ : PatLeaf<(i32 0x10B)>; +def CmpLTU_FTZ : PatLeaf<(i32 0x10C)>; +def CmpLEU_FTZ : PatLeaf<(i32 0x10D)>; +def CmpGTU_FTZ : PatLeaf<(i32 0x10E)>; +def CmpGEU_FTZ : PatLeaf<(i32 0x10F)>; +def CmpNUM_FTZ : PatLeaf<(i32 0x110)>; +def CmpNAN_FTZ : PatLeaf<(i32 0x111)>; + +def CmpMode : Operand<i32> { + let PrintMethod = "printCmpMode"; +} + +def F32ConstZero : Operand<f32>, PatLeaf<(f32 fpimm)>, SDNodeXForm<fpimm, [{ + return CurDAG->getTargetConstantFP(0.0, MVT::f32); + }]>; +def F32ConstOne : Operand<f32>, PatLeaf<(f32 fpimm)>, SDNodeXForm<fpimm, [{ + return CurDAG->getTargetConstantFP(1.0, MVT::f32); + }]>; + //===----------------------------------------------------------------------===// // NVPTX Instruction Predicate Definitions //===----------------------------------------------------------------------===// @@ -56,127 +136,31 @@ def hasLDG : Predicate<"Subtarget.hasLDG()">; def hasLDU : Predicate<"Subtarget.hasLDU()">; def hasGenericLdSt : Predicate<"Subtarget.hasGenericLdSt()">; -def doF32FTZ : Predicate<"UseF32FTZ">; +def doF32FTZ : Predicate<"useF32FTZ()">; +def doNoF32FTZ : Predicate<"!useF32FTZ()">; def doFMAF32 : Predicate<"doFMAF32">; -def doFMAF32_ftz : Predicate<"(doFMAF32 && UseF32FTZ)">; +def doFMAF32_ftz : Predicate<"(doFMAF32 && useF32FTZ())">; def doFMAF32AGG : Predicate<"doFMAF32AGG">; -def doFMAF32AGG_ftz : Predicate<"(doFMAF32AGG && UseF32FTZ)">; +def doFMAF32AGG_ftz : Predicate<"(doFMAF32AGG && useF32FTZ())">; def doFMAF64 : Predicate<"doFMAF64">; def doFMAF64AGG : Predicate<"doFMAF64AGG">; -def doFMADF32 : Predicate<"doFMADF32">; -def doFMADF32_ftz : Predicate<"(doFMADF32 && UseF32FTZ)">; def doMulWide : 
Predicate<"doMulWide">; def allowFMA : Predicate<"allowFMA">; -def allowFMA_ftz : Predicate<"(allowFMA && UseF32FTZ)">; +def allowFMA_ftz : Predicate<"(allowFMA && useF32FTZ())">; -def do_DIVF32_APPROX : Predicate<"do_DIVF32_PREC==0">; -def do_DIVF32_FULL : Predicate<"do_DIVF32_PREC==1">; +def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">; +def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">; -def do_SQRTF32_APPROX : Predicate<"do_SQRTF32_PREC==0">; -def do_SQRTF32_RN : Predicate<"do_SQRTF32_PREC==1">; +def do_SQRTF32_APPROX : Predicate<"!usePrecSqrtF32()">; +def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">; def hasHWROT32 : Predicate<"Subtarget.hasHWROT32()">; def true : Predicate<"1">; -//===----------------------------------------------------------------------===// -// Special Handling for 8-bit Operands and Operations -// -// PTX supports 8-bit signed and unsigned types, but does not support 8-bit -// operations (like add, shift, etc) except for ld/st/cvt. SASS does not have -// 8-bit registers. -// -// PTX ld, st and cvt instructions permit source and destination data operands -// to be wider than the instruction-type size, so that narrow values may be -// loaded, stored, and converted using regular-width registers. -// -// So in PTX generation, we -// - always use 16-bit registers in place in 8-bit registers. -// (8-bit variables should stay as 8-bit as they represent memory layout.) -// - for the following 8-bit operations, we sign-ext/zero-ext the 8-bit values -// before operation -// . div -// . rem -// . neg (sign) -// . set, setp -// . shr -// -// We are patching the operations by inserting the cvt instructions in the -// asm strings of the affected instructions. -// -// Since vector operations, except for ld/st, are eventually elementized. We -// do not need to special-hand the vector 8-bit operations. 
-// -// -//===----------------------------------------------------------------------===// - -// Generate string block like -// { -// .reg .s16 %temp1; -// .reg .s16 %temp2; -// cvt.s16.s8 %temp1, %a; -// cvt.s16.s8 %temp2, %b; -// opc.s16 %dst, %temp1, %temp2; -// } -// when OpcStr=opc.s TypeStr=s16 CVTStr=cvt.s16.s8 -class Handle_i8rr<string OpcStr, string TypeStr, string CVTStr> { - string s = !strconcat("{{\n\t", - !strconcat(".reg .", !strconcat(TypeStr, - !strconcat(" \t%temp1;\n\t", - !strconcat(".reg .", !strconcat(TypeStr, - !strconcat(" \t%temp2;\n\t", - !strconcat(CVTStr, !strconcat(" \t%temp1, $a;\n\t", - !strconcat(CVTStr, !strconcat(" \t%temp2, $b;\n\t", - !strconcat(OpcStr, "16 \t$dst, %temp1, %temp2;\n\t}}")))))))))))); -} - -// Generate string block like -// { -// .reg .s16 %temp1; -// .reg .s16 %temp2; -// cvt.s16.s8 %temp1, %a; -// mov.b16 %temp2, %b; -// cvt.s16.s8 %temp2, %temp2; -// opc.s16 %dst, %temp1, %temp2; -// } -// when OpcStr=opc.s TypeStr=s16 CVTStr=cvt.s16.s8 -class Handle_i8ri<string OpcStr, string TypeStr, string CVTStr> { - string s = !strconcat("{{\n\t", - !strconcat(".reg .", !strconcat(TypeStr, - !strconcat(" \t%temp1;\n\t", - !strconcat(".reg .", - !strconcat(TypeStr, !strconcat(" \t%temp2;\n\t", - !strconcat(CVTStr, !strconcat(" \t%temp1, $a;\n\t", - !strconcat("mov.b16 \t%temp2, $b;\n\t", - !strconcat(CVTStr, !strconcat(" \t%temp2, %temp2;\n\t", - !strconcat(OpcStr, "16 \t$dst, %temp1, %temp2;\n\t}}"))))))))))))); -} - -// Generate string block like -// { -// .reg .s16 %temp1; -// .reg .s16 %temp2; -// mov.b16 %temp1, %b; -// cvt.s16.s8 %temp1, %temp1; -// cvt.s16.s8 %temp2, %a; -// opc.s16 %dst, %temp1, %temp2; -// } -// when OpcStr=opc.s TypeStr=s16 CVTStr=cvt.s16.s8 -class Handle_i8ir<string OpcStr, string TypeStr, string CVTStr> { - string s = !strconcat("{{\n\t", - !strconcat(".reg .", !strconcat(TypeStr, - !strconcat(" \t%temp1;\n\t", - !strconcat(".reg .", !strconcat(TypeStr, - !strconcat(" \t%temp2;\n\t", - !strconcat("mov.b16 \t%temp1, $a;\n\t", - !strconcat(CVTStr, !strconcat(" \t%temp1, %temp1;\n\t", - !strconcat(CVTStr, !strconcat(" \t%temp2, $b;\n\t", - !strconcat(OpcStr, "16 \t$dst, %temp1, %temp2;\n\t}}"))))))))))))); -} - //===----------------------------------------------------------------------===// // Some Common Instruction Class Templates @@ -204,66 +188,6 @@ multiclass I3<string OpcStr, SDNode OpNode> { def i16ri : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, i16imm:$b), !strconcat(OpcStr, "16 \t$dst, $a, $b;"), [(set Int16Regs:$dst, (OpNode Int16Regs:$a, (imm):$b))]>; - def i8rr : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, Int8Regs:$b))]>; - def i8ri : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, i8imm:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, (imm):$b))]>; -} - -multiclass I3_i8<string OpcStr, SDNode OpNode, string TypeStr, string CVTStr> { - def i64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b), - !strconcat(OpcStr, "64 \t$dst, $a, $b;"), - [(set Int64Regs:$dst, (OpNode Int64Regs:$a, - Int64Regs:$b))]>; - def i64ri : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b), - !strconcat(OpcStr, "64 \t$dst, $a, $b;"), - [(set Int64Regs:$dst, (OpNode Int64Regs:$a, imm:$b))]>; - def i32rr : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b), - !strconcat(OpcStr, "32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode 
Int32Regs:$a, - Int32Regs:$b))]>; - def i32ri : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b), - !strconcat(OpcStr, "32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int32Regs:$a, imm:$b))]>; - def i16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int16Regs:$dst, (OpNode Int16Regs:$a, - Int16Regs:$b))]>; - def i16ri : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, i16imm:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int16Regs:$dst, (OpNode Int16Regs:$a, (imm):$b))]>; - def i8rr : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - Handle_i8rr<OpcStr, TypeStr, CVTStr>.s, - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, Int8Regs:$b))]>; - def i8ri : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, i8imm:$b), - Handle_i8ri<OpcStr, TypeStr, CVTStr>.s, - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, (imm):$b))]>; -} - -multiclass I3_noi8<string OpcStr, SDNode OpNode> { - def i64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b), - !strconcat(OpcStr, "64 \t$dst, $a, $b;"), - [(set Int64Regs:$dst, (OpNode Int64Regs:$a, - Int64Regs:$b))]>; - def i64ri : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b), - !strconcat(OpcStr, "64 \t$dst, $a, $b;"), - [(set Int64Regs:$dst, (OpNode Int64Regs:$a, imm:$b))]>; - def i32rr : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b), - !strconcat(OpcStr, "32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int32Regs:$a, - Int32Regs:$b))]>; - def i32ri : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b), - !strconcat(OpcStr, "32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int32Regs:$a, imm:$b))]>; - def i16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int16Regs:$dst, (OpNode Int16Regs:$a, - Int16Regs:$b))]>; - def i16ri : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, i16imm:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int16Regs:$dst, (OpNode Int16Regs:$a, (imm):$b))]>; } multiclass ADD_SUB_INT_32<string OpcStr, SDNode OpNode> { @@ -369,6 +293,90 @@ multiclass F2<string OpcStr, SDNode OpNode> { //===----------------------------------------------------------------------===// //----------------------------------- +// General Type Conversion +//----------------------------------- + +let neverHasSideEffects = 1 in { +// Generate a cvt to the given type from all possible types. +// Each instance takes a CvtMode immediate that defines the conversion mode to +// use. It can be CvtNONE to omit a conversion mode. 
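A sketch of how the generated records behave, assuming the defm instantiations that follow (register numbers are illustrative): the _f32 member of the CVT_s32 instance, given a CvtRMI mode operand (the same mode the frem patterns later in this patch pass), prints as

    cvt.rmi.s32.f32 %r1, %f1;

With CvtNONE, every ${mode:...} suffix prints empty, so e.g. the _u16 member of CVT_u32 degenerates to a plain cvt.u32.u16.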
+multiclass CVT_FROM_ALL<string FromName, RegisterClass RC> { + def _s16 : NVPTXInst<(outs RC:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".s16\t$dst, $src;"), + []>; + def _u16 : NVPTXInst<(outs RC:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".u16\t$dst, $src;"), + []>; + def _f16 : NVPTXInst<(outs RC:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".f16\t$dst, $src;"), + []>; + def _s32 : NVPTXInst<(outs RC:$dst), + (ins Int32Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".s32\t$dst, $src;"), + []>; + def _u32 : NVPTXInst<(outs RC:$dst), + (ins Int32Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".u32\t$dst, $src;"), + []>; + def _s64 : NVPTXInst<(outs RC:$dst), + (ins Int64Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".s64\t$dst, $src;"), + []>; + def _u64 : NVPTXInst<(outs RC:$dst), + (ins Int64Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".u64\t$dst, $src;"), + []>; + def _f32 : NVPTXInst<(outs RC:$dst), + (ins Float32Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".f32\t$dst, $src;"), + []>; + def _f64 : NVPTXInst<(outs RC:$dst), + (ins Float64Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".f64\t$dst, $src;"), + []>; +} + +// Generate a cvt to all possible types. +defm CVT_s16 : CVT_FROM_ALL<"s16", Int16Regs>; +defm CVT_u16 : CVT_FROM_ALL<"u16", Int16Regs>; +defm CVT_f16 : CVT_FROM_ALL<"f16", Int16Regs>; +defm CVT_s32 : CVT_FROM_ALL<"s32", Int32Regs>; +defm CVT_u32 : CVT_FROM_ALL<"u32", Int32Regs>; +defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>; +defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>; +defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>; +defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>; + +// This set of cvt is different from the above. The type of the source +// and target are the same. 
+// +def CVT_INREG_s16_s8 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src), + "cvt.s16.s8 \t$dst, $src;", []>; +def CVT_INREG_s32_s8 : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src), + "cvt.s32.s8 \t$dst, $src;", []>; +def CVT_INREG_s32_s16 : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src), + "cvt.s32.s16 \t$dst, $src;", []>; +def CVT_INREG_s64_s8 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src), + "cvt.s64.s8 \t$dst, $src;", []>; +def CVT_INREG_s64_s16 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src), + "cvt.s64.s16 \t$dst, $src;", []>; +def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src), + "cvt.s64.s32 \t$dst, $src;", []>; +} + +//----------------------------------- // Integer Arithmetic //----------------------------------- @@ -522,81 +530,17 @@ def : Pat<(mul (zext Int16Regs:$a), (i32 UInt16Const:$b)), defm MULT : I3<"mul.lo.s", mul>; -defm MULTHS : I3_noi8<"mul.hi.s", mulhs>; -defm MULTHU : I3_noi8<"mul.hi.u", mulhu>; -def MULTHSi8rr : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - !strconcat("{{ \n\t", - !strconcat(".reg \t.s16 temp1; \n\t", - !strconcat(".reg \t.s16 temp2; \n\t", - !strconcat("cvt.s16.s8 \ttemp1, $a; \n\t", - !strconcat("cvt.s16.s8 \ttemp2, $b; \n\t", - !strconcat("mul.lo.s16 \t$dst, temp1, temp2; \n\t", - !strconcat("shr.s16 \t$dst, $dst, 8; \n\t", - !strconcat("}}", "")))))))), - [(set Int8Regs:$dst, (mulhs Int8Regs:$a, Int8Regs:$b))]>; -def MULTHSi8ri : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, i8imm:$b), - !strconcat("{{ \n\t", - !strconcat(".reg \t.s16 temp1; \n\t", - !strconcat(".reg \t.s16 temp2; \n\t", - !strconcat("cvt.s16.s8 \ttemp1, $a; \n\t", - !strconcat("mov.b16 \ttemp2, $b; \n\t", - !strconcat("cvt.s16.s8 \ttemp2, temp2; \n\t", - !strconcat("mul.lo.s16 \t$dst, temp1, temp2; \n\t", - !strconcat("shr.s16 \t$dst, $dst, 8; \n\t", - !strconcat("}}", ""))))))))), - [(set Int8Regs:$dst, (mulhs Int8Regs:$a, imm:$b))]>; -def MULTHUi8rr : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - !strconcat("{{ \n\t", - !strconcat(".reg \t.u16 temp1; \n\t", - !strconcat(".reg \t.u16 temp2; \n\t", - !strconcat("cvt.u16.u8 \ttemp1, $a; \n\t", - !strconcat("cvt.u16.u8 \ttemp2, $b; \n\t", - !strconcat("mul.lo.u16 \t$dst, temp1, temp2; \n\t", - !strconcat("shr.u16 \t$dst, $dst, 8; \n\t", - !strconcat("}}", "")))))))), - [(set Int8Regs:$dst, (mulhu Int8Regs:$a, Int8Regs:$b))]>; -def MULTHUi8ri : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, i8imm:$b), - !strconcat("{{ \n\t", - !strconcat(".reg \t.u16 temp1; \n\t", - !strconcat(".reg \t.u16 temp2; \n\t", - !strconcat("cvt.u16.u8 \ttemp1, $a; \n\t", - !strconcat("mov.b16 \ttemp2, $b; \n\t", - !strconcat("cvt.u16.u8 \ttemp2, temp2; \n\t", - !strconcat("mul.lo.u16 \t$dst, temp1, temp2; \n\t", - !strconcat("shr.u16 \t$dst, $dst, 8; \n\t", - !strconcat("}}", ""))))))))), - [(set Int8Regs:$dst, (mulhu Int8Regs:$a, imm:$b))]>; - - -defm SDIV : I3_i8<"div.s", sdiv, "s16", "cvt.s16.s8">; -defm UDIV : I3_i8<"div.u", udiv, "u16", "cvt.u16.u8">; - -defm SREM : I3_i8<"rem.s", srem, "s16", "cvt.s16.s8">; +defm MULTHS : I3<"mul.hi.s", mulhs>; +defm MULTHU : I3<"mul.hi.u", mulhu>; + +defm SDIV : I3<"div.s", sdiv>; +defm UDIV : I3<"div.u", udiv>; + +defm SREM : I3<"rem.s", srem>; // The ri version will not be selected as DAGCombiner::visitSREM will lower it. -defm UREM : I3_i8<"rem.u", urem, "u16", "cvt.u16.u8">; +defm UREM : I3<"rem.u", urem>; // The ri version will not be selected as DAGCombiner::visitUREM will lower it. 
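For reference, a hand expansion of one member that defm SDIV : I3<"div.s", sdiv>; produces (the record name comes from concatenating the defm prefix with the def suffix; this is a sketch of the generated record, not a literal quote):

    def SDIVi32rr : NVPTXInst<(outs Int32Regs:$dst),
                              (ins Int32Regs:$a, Int32Regs:$b),
                              "div.s32 \t$dst, $a, $b;",
                              [(set Int32Regs:$dst,
                                    (sdiv Int32Regs:$a, Int32Regs:$b))]>;

With the dedicated i8 register class gone, i8 divides and remainders are presumably handled by promotion to these 16-bit forms, with the CVT_INREG_* instructions above re-normalizing the high bits where sign matters.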
-def MAD8rrr : NVPTXInst<(outs Int8Regs:$dst), - (ins Int8Regs:$a, Int8Regs:$b, Int8Regs:$c), - "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int8Regs:$dst, (add (mul Int8Regs:$a, Int8Regs:$b), - Int8Regs:$c))]>; -def MAD8rri : NVPTXInst<(outs Int8Regs:$dst), - (ins Int8Regs:$a, Int8Regs:$b, i8imm:$c), - "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int8Regs:$dst, (add (mul Int8Regs:$a, Int8Regs:$b), - imm:$c))]>; -def MAD8rir : NVPTXInst<(outs Int8Regs:$dst), - (ins Int8Regs:$a, i8imm:$b, Int8Regs:$c), - "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int8Regs:$dst, (add (mul Int8Regs:$a, imm:$b), - Int8Regs:$c))]>; -def MAD8rii : NVPTXInst<(outs Int8Regs:$dst), - (ins Int8Regs:$a, i8imm:$b, i8imm:$c), - "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int8Regs:$dst, (add (mul Int8Regs:$a, imm:$b), - imm:$c))]>; - def MAD16rrr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), "mad.lo.s16 \t$dst, $a, $b, $c;", @@ -661,10 +605,6 @@ def MAD64rii : NVPTXInst<(outs Int64Regs:$dst), (mul Int64Regs:$a, imm:$b), imm:$c))]>; -def INEG8 : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$src), - !strconcat("cvt.s16.s8 \t$dst, $src;\n\t", - "neg.s16 \t$dst, $dst;"), - [(set Int8Regs:$dst, (ineg Int8Regs:$src))]>; def INEG16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src), "neg.s16 \t$dst, $src;", [(set Int16Regs:$dst, (ineg Int16Regs:$src))]>; @@ -842,6 +782,16 @@ def FDIV32ri_prec : NVPTXInst<(outs Float32Regs:$dst), (fdiv Float32Regs:$a, fpimm:$b))]>, Requires<[reqPTX20]>; +// +// F32 rsqrt +// + +def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b), + "rsqrt.approx.f32 \t$dst, $b;", []>; + +def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)), + (RSQRTF32approx1r Float32Regs:$b)>, + Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>; multiclass FPCONTRACT32<string OpcStr, Predicate Pred> { def rrr : NVPTXInst<(outs Float32Regs:$dst), @@ -912,8 +862,6 @@ multiclass FPCONTRACT64<string OpcStr, Predicate Pred> { // If we reverse the order of the following two lines, then rrr2 rule will be // generated for FMA32, but not for rrr. // Therefore, we manually write the rrr2 rule in FPCONTRACT32. -defm FMAD32_ftz : FPCONTRACT32<"mad.ftz.f32", doFMADF32_ftz>; -defm FMAD32 : FPCONTRACT32<"mad.f32", doFMADF32>; defm FMA32_ftz : FPCONTRACT32<"fma.rn.ftz.f32", doFMAF32_ftz>; defm FMA32 : FPCONTRACT32<"fma.rn.f32", doFMAF32>; defm FMA64 : FPCONTRACT64<"fma.rn.f64", doFMAF64>; @@ -952,8 +900,6 @@ multiclass FPCONTRACT64_SUB_PAT<NVPTXInst Inst, Predicate Pred> { defm FMAF32ext_ftz : FPCONTRACT32_SUB_PAT<FMA32_ftzrrr, doFMAF32AGG_ftz>; defm FMAF32ext : FPCONTRACT32_SUB_PAT<FMA32rrr, doFMAF32AGG>; -defm FMADF32ext_ftz : FPCONTRACT32_SUB_PAT_MAD<FMAD32_ftzrrr, doFMADF32_ftz>; -defm FMADF32ext : FPCONTRACT32_SUB_PAT_MAD<FMAD32rrr, doFMADF32>; defm FMAF64ext : FPCONTRACT64_SUB_PAT<FMA64rrr, doFMAF64AGG>; def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), @@ -963,6 +909,41 @@ def COSF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), "cos.approx.f32 \t$dst, $src;", [(set Float32Regs:$dst, (fcos Float32Regs:$src))]>; +// Lower (frem x, y) into (sub x, (mul (floor (div x, y)) y)) +// e.g. 
"poor man's fmod()" + +// frem - f32 FTZ +def : Pat<(frem Float32Regs:$x, Float32Regs:$y), + (FSUBf32rr_ftz Float32Regs:$x, (FMULf32rr_ftz (CVT_f32_f32 + (FDIV32rr_prec_ftz Float32Regs:$x, Float32Regs:$y), CvtRMI_FTZ), + Float32Regs:$y))>, + Requires<[doF32FTZ]>; +def : Pat<(frem Float32Regs:$x, fpimm:$y), + (FSUBf32rr_ftz Float32Regs:$x, (FMULf32ri_ftz (CVT_f32_f32 + (FDIV32ri_prec_ftz Float32Regs:$x, fpimm:$y), CvtRMI_FTZ), + fpimm:$y))>, + Requires<[doF32FTZ]>; + +// frem - f32 +def : Pat<(frem Float32Regs:$x, Float32Regs:$y), + (FSUBf32rr Float32Regs:$x, (FMULf32rr (CVT_f32_f32 + (FDIV32rr_prec Float32Regs:$x, Float32Regs:$y), CvtRMI), + Float32Regs:$y))>; +def : Pat<(frem Float32Regs:$x, fpimm:$y), + (FSUBf32rr Float32Regs:$x, (FMULf32ri (CVT_f32_f32 + (FDIV32ri_prec Float32Regs:$x, fpimm:$y), CvtRMI), + fpimm:$y))>; + +// frem - f64 +def : Pat<(frem Float64Regs:$x, Float64Regs:$y), + (FSUBf64rr Float64Regs:$x, (FMULf64rr (CVT_f64_f64 + (FDIV64rr Float64Regs:$x, Float64Regs:$y), CvtRMI), + Float64Regs:$y))>; +def : Pat<(frem Float64Regs:$x, fpimm:$y), + (FSUBf64rr Float64Regs:$x, (FMULf64ri (CVT_f64_f64 + (FDIV64ri Float64Regs:$x, fpimm:$y), CvtRMI), + fpimm:$y))>; + //----------------------------------- // Logical Arithmetic //----------------------------------- @@ -974,12 +955,6 @@ multiclass LOG_FORMAT<string OpcStr, SDNode OpNode> { def b1ri: NVPTXInst<(outs Int1Regs:$dst), (ins Int1Regs:$a, i1imm:$b), !strconcat(OpcStr, ".pred \t$dst, $a, $b;"), [(set Int1Regs:$dst, (OpNode Int1Regs:$a, imm:$b))]>; - def b8rr: NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - !strconcat(OpcStr, ".b16 \t$dst, $a, $b;"), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, Int8Regs:$b))]>; - def b8ri: NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, i8imm:$b), - !strconcat(OpcStr, ".b16 \t$dst, $a, $b;"), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, imm:$b))]>; def b16rr: NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b), !strconcat(OpcStr, ".b16 \t$dst, $a, $b;"), [(set Int16Regs:$dst, (OpNode Int16Regs:$a, @@ -1010,9 +985,6 @@ defm XOR : LOG_FORMAT<"xor", xor>; def NOT1: NVPTXInst<(outs Int1Regs:$dst), (ins Int1Regs:$src), "not.pred \t$dst, $src;", [(set Int1Regs:$dst, (not Int1Regs:$src))]>; -def NOT8: NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$src), - "not.b16 \t$dst, $src;", - [(set Int8Regs:$dst, (not Int8Regs:$src))]>; def NOT16: NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src), "not.b16 \t$dst, $src;", [(set Int16Regs:$dst, (not Int16Regs:$src))]>; @@ -1056,21 +1028,13 @@ multiclass LSHIFT_FORMAT<string OpcStr, SDNode OpNode> { !strconcat(OpcStr, "16 \t$dst, $a, $b;"), [(set Int16Regs:$dst, (OpNode Int16Regs:$a, (i32 imm:$b)))]>; - def i8rr : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int32Regs:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, - Int32Regs:$b))]>; - def i8ri : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, i32imm:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, - (i32 imm:$b)))]>; } defm SHL : LSHIFT_FORMAT<"shl.b", shl>; // For shifts, the second src operand must be 32-bit value // Need to add cvt for the 8-bits. 
-multiclass RSHIFT_FORMAT<string OpcStr, SDNode OpNode, string CVTStr> { +multiclass RSHIFT_FORMAT<string OpcStr, SDNode OpNode> { def i64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int32Regs:$b), !strconcat(OpcStr, "64 \t$dst, $a, $b;"), @@ -1102,20 +1066,10 @@ multiclass RSHIFT_FORMAT<string OpcStr, SDNode OpNode, string CVTStr> { !strconcat(OpcStr, "16 \t$dst, $a, $b;"), [(set Int16Regs:$dst, (OpNode Int16Regs:$a, (i32 imm:$b)))]>; - def i8rr : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int32Regs:$b), - !strconcat(CVTStr, !strconcat(" \t$dst, $a;\n\t", - !strconcat(OpcStr, "16 \t$dst, $dst, $b;"))), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, - Int32Regs:$b))]>; - def i8ri : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, i32imm:$b), - !strconcat(CVTStr, !strconcat(" \t$dst, $a;\n\t", - !strconcat(OpcStr, "16 \t$dst, $dst, $b;"))), - [(set Int8Regs:$dst, (OpNode Int8Regs:$a, - (i32 imm:$b)))]>; } -defm SRA : RSHIFT_FORMAT<"shr.s", sra, "cvt.s16.s8">; -defm SRL : RSHIFT_FORMAT<"shr.u", srl, "cvt.u16.u8">; +defm SRA : RSHIFT_FORMAT<"shr.s", sra>; +defm SRL : RSHIFT_FORMAT<"shr.u", srl>; // 32bit def ROT32imm_sw : NVPTXInst<(outs Int32Regs:$dst), @@ -1213,6 +1167,120 @@ def ROTR64reg_sw : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, //----------------------------------- +// General Comparison +//----------------------------------- + +// General setp instructions +multiclass SETP<string TypeStr, RegisterClass RC, Operand ImmCls> { + def rr : NVPTXInst<(outs Int1Regs:$dst), + (ins RC:$a, RC:$b, CmpMode:$cmp), + !strconcat("setp${cmp:base}${cmp:ftz}.", TypeStr, "\t$dst, $a, $b;"), + []>; + def ri : NVPTXInst<(outs Int1Regs:$dst), + (ins RC:$a, ImmCls:$b, CmpMode:$cmp), + !strconcat("setp${cmp:base}${cmp:ftz}.", TypeStr, "\t$dst, $a, $b;"), + []>; + def ir : NVPTXInst<(outs Int1Regs:$dst), + (ins ImmCls:$a, RC:$b, CmpMode:$cmp), + !strconcat("setp${cmp:base}${cmp:ftz}.", TypeStr, "\t$dst, $a, $b;"), + []>; +} + +defm SETP_b16 : SETP<"b16", Int16Regs, i16imm>; +defm SETP_s16 : SETP<"s16", Int16Regs, i16imm>; +defm SETP_u16 : SETP<"u16", Int16Regs, i16imm>; +defm SETP_b32 : SETP<"b32", Int32Regs, i32imm>; +defm SETP_s32 : SETP<"s32", Int32Regs, i32imm>; +defm SETP_u32 : SETP<"u32", Int32Regs, i32imm>; +defm SETP_b64 : SETP<"b64", Int64Regs, i64imm>; +defm SETP_s64 : SETP<"s64", Int64Regs, i64imm>; +defm SETP_u64 : SETP<"u64", Int64Regs, i64imm>; +defm SETP_f32 : SETP<"f32", Float32Regs, f32imm>; +defm SETP_f64 : SETP<"f64", Float64Regs, f64imm>; + +// General set instructions +multiclass SET<string TypeStr, RegisterClass RC, Operand ImmCls> { + def rr : NVPTXInst<(outs Int32Regs:$dst), + (ins RC:$a, RC:$b, CmpMode:$cmp), + !strconcat("set$cmp.", TypeStr, "\t$dst, $a, $b;"), []>; + def ri : NVPTXInst<(outs Int32Regs:$dst), + (ins RC:$a, ImmCls:$b, CmpMode:$cmp), + !strconcat("set$cmp.", TypeStr, "\t$dst, $a, $b;"), []>; + def ir : NVPTXInst<(outs Int32Regs:$dst), + (ins ImmCls:$a, RC:$b, CmpMode:$cmp), + !strconcat("set$cmp.", TypeStr, "\t$dst, $a, $b;"), []>; +} + +defm SET_b16 : SET<"b16", Int16Regs, i16imm>; +defm SET_s16 : SET<"s16", Int16Regs, i16imm>; +defm SET_u16 : SET<"u16", Int16Regs, i16imm>; +defm SET_b32 : SET<"b32", Int32Regs, i32imm>; +defm SET_s32 : SET<"s32", Int32Regs, i32imm>; +defm SET_u32 : SET<"u32", Int32Regs, i32imm>; +defm SET_b64 : SET<"b64", Int64Regs, i64imm>; +defm SET_s64 : SET<"s64", Int64Regs, i64imm>; +defm SET_u64 : SET<"u64", Int64Regs, i64imm>; +defm SET_f32 : SET<"f32", Float32Regs, f32imm>; +defm SET_f64 : SET<"f64", 
Float64Regs, f64imm>; + +//----------------------------------- +// General Selection +//----------------------------------- + +// General selp instructions +multiclass SELP<string TypeStr, RegisterClass RC, Operand ImmCls> { + def rr : NVPTXInst<(outs RC:$dst), + (ins RC:$a, RC:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), []>; + def ri : NVPTXInst<(outs RC:$dst), + (ins RC:$a, ImmCls:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), []>; + def ir : NVPTXInst<(outs RC:$dst), + (ins ImmCls:$a, RC:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), []>; + def ii : NVPTXInst<(outs RC:$dst), + (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), []>; +} + +multiclass SELP_PATTERN<string TypeStr, RegisterClass RC, Operand ImmCls, + SDNode ImmNode> { + def rr : NVPTXInst<(outs RC:$dst), + (ins RC:$a, RC:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), + [(set RC:$dst, (select Int1Regs:$p, RC:$a, RC:$b))]>; + def ri : NVPTXInst<(outs RC:$dst), + (ins RC:$a, ImmCls:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), + [(set RC:$dst, (select Int1Regs:$p, RC:$a, ImmNode:$b))]>; + def ir : NVPTXInst<(outs RC:$dst), + (ins ImmCls:$a, RC:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), + [(set RC:$dst, (select Int1Regs:$p, ImmNode:$a, RC:$b))]>; + def ii : NVPTXInst<(outs RC:$dst), + (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p), + !strconcat("selp.", TypeStr, "\t$dst, $a, $b, $p;"), + [(set RC:$dst, (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>; +} + +defm SELP_b16 : SELP_PATTERN<"b16", Int16Regs, i16imm, imm>; +defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>; +defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>; +defm SELP_b32 : SELP_PATTERN<"b32", Int32Regs, i32imm, imm>; +defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>; +defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>; +defm SELP_b64 : SELP_PATTERN<"b64", Int64Regs, i64imm, imm>; +defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>; +defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>; +defm SELP_f32 : SELP_PATTERN<"f32", Float32Regs, f32imm, fpimm>; +defm SELP_f64 : SELP_PATTERN<"f64", Float64Regs, f64imm, fpimm>; + +// Special select for predicate operands +def : Pat<(i1 (select Int1Regs:$p, Int1Regs:$a, Int1Regs:$b)), + (ORb1rr (ANDb1rr Int1Regs:$p, Int1Regs:$a), + (ANDb1rr (NOT1 Int1Regs:$p), Int1Regs:$b))>; + +//----------------------------------- // Data Movement (Load / Store, Move) //----------------------------------- @@ -1253,12 +1321,19 @@ def MOV_ADDR64 : NVPTXInst<(outs Int64Regs:$dst), (ins imem:$a), "mov.u64 \t$dst, $a;", [(set Int64Regs:$dst, (Wrapper tglobaladdr:$a))]>; +// Get pointer to local stack +def MOV_DEPOT_ADDR + : NVPTXInst<(outs Int32Regs:$d), (ins i32imm:$num), + "mov.u32 \t$d, __local_depot$num;", []>; +def MOV_DEPOT_ADDR_64 + : NVPTXInst<(outs Int64Regs:$d), (ins i32imm:$num), + "mov.u64 \t$d, __local_depot$num;", []>; + + // copyPhysreg is hard-coded in NVPTXInstrInfo.cpp let IsSimpleMove=1 in { def IMOV1rr: NVPTXInst<(outs Int1Regs:$dst), (ins Int1Regs:$sss), "mov.pred \t$dst, $sss;", []>; -def IMOV8rr: NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$sss), - "mov.u16 \t$dst, $sss;", []>; def IMOV16rr: NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$sss), "mov.u16 \t$dst, $sss;", []>; def IMOV32rr: NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$sss), @@ -1274,9 +1349,6 @@ def FMOV64rr: NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src), def IMOV1ri: 
NVPTXInst<(outs Int1Regs:$dst), (ins i1imm:$src), "mov.pred \t$dst, $src;", [(set Int1Regs:$dst, imm:$src)]>; -def IMOV8ri: NVPTXInst<(outs Int8Regs:$dst), (ins i8imm:$src), - "mov.u16 \t$dst, $src;", - [(set Int8Regs:$dst, imm:$src)]>; def IMOV16ri: NVPTXInst<(outs Int16Regs:$dst), (ins i16imm:$src), "mov.u16 \t$dst, $src;", [(set Int16Regs:$dst, imm:$src)]>; @@ -1308,440 +1380,194 @@ def LEA_ADDRi64 : NVPTXInst<(outs Int64Regs:$dst), (ins MEMri64:$addr), // Comparison and Selection //----------------------------------- -// Generate string block like -// { -// .reg .pred p; -// setp.gt.s16 p, %a, %b; -// selp.s16 %dst, -1, 0, p; -// } -// when OpcStr=setp.gt.s sz1=16 sz2=16 d=%dst a=%a b=%b -class Set_Str<string OpcStr, string sz1, string sz2, string d, string a, - string b> { - string t1 = "{{\n\t.reg .pred p;\n\t"; - string t2 = !strconcat(t1 , OpcStr); - string t3 = !strconcat(t2 , sz1); - string t4 = !strconcat(t3 , " \tp, "); - string t5 = !strconcat(t4 , a); - string t6 = !strconcat(t5 , ", "); - string t7 = !strconcat(t6 , b); - string t8 = !strconcat(t7 , ";\n\tselp.s"); - string t9 = !strconcat(t8 , sz2); - string t10 = !strconcat(t9, " \t"); - string t11 = !strconcat(t10, d); - string s = !strconcat(t11, ", -1, 0, p;\n\t}}"); +multiclass ISET_FORMAT<PatFrag OpNode, PatLeaf Mode, + Instruction setp_16rr, + Instruction setp_16ri, + Instruction setp_16ir, + Instruction setp_32rr, + Instruction setp_32ri, + Instruction setp_32ir, + Instruction setp_64rr, + Instruction setp_64ri, + Instruction setp_64ir, + Instruction set_16rr, + Instruction set_16ri, + Instruction set_16ir, + Instruction set_32rr, + Instruction set_32ri, + Instruction set_32ir, + Instruction set_64rr, + Instruction set_64ri, + Instruction set_64ir> { + // i16 -> pred + def : Pat<(i1 (OpNode Int16Regs:$a, Int16Regs:$b)), + (setp_16rr Int16Regs:$a, Int16Regs:$b, Mode)>; + def : Pat<(i1 (OpNode Int16Regs:$a, imm:$b)), + (setp_16ri Int16Regs:$a, imm:$b, Mode)>; + def : Pat<(i1 (OpNode imm:$a, Int16Regs:$b)), + (setp_16ir imm:$a, Int16Regs:$b, Mode)>; + // i32 -> pred + def : Pat<(i1 (OpNode Int32Regs:$a, Int32Regs:$b)), + (setp_32rr Int32Regs:$a, Int32Regs:$b, Mode)>; + def : Pat<(i1 (OpNode Int32Regs:$a, imm:$b)), + (setp_32ri Int32Regs:$a, imm:$b, Mode)>; + def : Pat<(i1 (OpNode imm:$a, Int32Regs:$b)), + (setp_32ir imm:$a, Int32Regs:$b, Mode)>; + // i64 -> pred + def : Pat<(i1 (OpNode Int64Regs:$a, Int64Regs:$b)), + (setp_64rr Int64Regs:$a, Int64Regs:$b, Mode)>; + def : Pat<(i1 (OpNode Int64Regs:$a, imm:$b)), + (setp_64ri Int64Regs:$a, imm:$b, Mode)>; + def : Pat<(i1 (OpNode imm:$a, Int64Regs:$b)), + (setp_64ir imm:$a, Int64Regs:$b, Mode)>; + + // i16 -> i32 + def : Pat<(i32 (OpNode Int16Regs:$a, Int16Regs:$b)), + (set_16rr Int16Regs:$a, Int16Regs:$b, Mode)>; + def : Pat<(i32 (OpNode Int16Regs:$a, imm:$b)), + (set_16ri Int16Regs:$a, imm:$b, Mode)>; + def : Pat<(i32 (OpNode imm:$a, Int16Regs:$b)), + (set_16ir imm:$a, Int16Regs:$b, Mode)>; + // i32 -> i32 + def : Pat<(i32 (OpNode Int32Regs:$a, Int32Regs:$b)), + (set_32rr Int32Regs:$a, Int32Regs:$b, Mode)>; + def : Pat<(i32 (OpNode Int32Regs:$a, imm:$b)), + (set_32ri Int32Regs:$a, imm:$b, Mode)>; + def : Pat<(i32 (OpNode imm:$a, Int32Regs:$b)), + (set_32ir imm:$a, Int32Regs:$b, Mode)>; + // i64 -> i32 + def : Pat<(i32 (OpNode Int64Regs:$a, Int64Regs:$b)), + (set_64rr Int64Regs:$a, Int64Regs:$b, Mode)>; + def : Pat<(i32 (OpNode Int64Regs:$a, imm:$b)), + (set_64ri Int64Regs:$a, imm:$b, Mode)>; + def : Pat<(i32 (OpNode imm:$a, Int64Regs:$b)), + (set_64ir imm:$a, 
Int64Regs:$b, Mode)>; } -// Generate string block like -// { -// .reg .pred p; -// .reg .s16 %temp1; -// .reg .s16 %temp2; -// cvt.s16.s8 %temp1, %a; -// cvt s16.s8 %temp1, %b; -// setp.gt.s16 p, %temp1, %temp2; -// selp.s16 %dst, -1, 0, p; -// } -// when OpcStr=setp.gt.s d=%dst a=%a b=%b type=s16 cvt=cvt.s16.s8 -class Set_Stri8<string OpcStr, string d, string a, string b, string type, - string cvt> { - string t1 = "{{\n\t.reg .pred p;\n\t"; - string t2 = !strconcat(t1, ".reg ."); - string t3 = !strconcat(t2, type); - string t4 = !strconcat(t3, " %temp1;\n\t"); - string t5 = !strconcat(t4, ".reg ."); - string t6 = !strconcat(t5, type); - string t7 = !strconcat(t6, " %temp2;\n\t"); - string t8 = !strconcat(t7, cvt); - string t9 = !strconcat(t8, " \t%temp1, "); - string t10 = !strconcat(t9, a); - string t11 = !strconcat(t10, ";\n\t"); - string t12 = !strconcat(t11, cvt); - string t13 = !strconcat(t12, " \t%temp2, "); - string t14 = !strconcat(t13, b); - string t15 = !strconcat(t14, ";\n\t"); - string t16 = !strconcat(t15, OpcStr); - string t17 = !strconcat(t16, "16"); - string t18 = !strconcat(t17, " \tp, %temp1, %temp2;\n\t"); - string t19 = !strconcat(t18, "selp.s16 \t"); - string t20 = !strconcat(t19, d); - string s = !strconcat(t20, ", -1, 0, p;\n\t}}"); +multiclass ISET_FORMAT_SIGNED<PatFrag OpNode, PatLeaf Mode> + : ISET_FORMAT<OpNode, Mode, + SETP_s16rr, SETP_s16ri, SETP_s16ir, + SETP_s32rr, SETP_s32ri, SETP_s32ir, + SETP_s64rr, SETP_s64ri, SETP_s64ir, + SET_s16rr, SET_s16ri, SET_s16ir, + SET_s32rr, SET_s32ri, SET_s32ir, + SET_s64rr, SET_s64ri, SET_s64ir> { + // TableGen doesn't like empty multiclasses + def : PatLeaf<(i32 0)>; } -multiclass ISET_FORMAT<string OpcStr, string OpcStr_u32, PatFrag OpNode, - string TypeStr, string CVTStr> { - def i8rr_toi8: NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - Set_Stri8<OpcStr, "$dst", "$a", "$b", TypeStr, CVTStr>.s, - []>; - def i16rr_toi16: NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, - Int16Regs:$b), - Set_Str<OpcStr, "16", "16", "$dst", "$a", "$b">.s, - []>; - def i32rr_toi32: NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, - Int32Regs:$b), - Set_Str<OpcStr, "32", "32", "$dst", "$a", "$b">.s, - []>; - def i64rr_toi64: NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, - Int64Regs:$b), - Set_Str<OpcStr, "64", "64", "$dst", "$a", "$b">.s, - []>; - - def i8rr_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - Handle_i8rr<OpcStr, TypeStr, CVTStr>.s, - [(set Int1Regs:$dst, (OpNode Int8Regs:$a, Int8Regs:$b))]>; - def i8ri_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int8Regs:$a, i8imm:$b), - Handle_i8ri<OpcStr, TypeStr, CVTStr>.s, - [(set Int1Regs:$dst, (OpNode Int8Regs:$a, imm:$b))]>; - def i8ir_p: NVPTXInst<(outs Int1Regs:$dst), (ins i8imm:$a, Int8Regs:$b), - Handle_i8ir<OpcStr, TypeStr, CVTStr>.s, - [(set Int1Regs:$dst, (OpNode imm:$a, Int8Regs:$b))]>; - def i16rr_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Int16Regs:$a, Int16Regs:$b))]>; - def i16ri_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int16Regs:$a, i16imm:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Int16Regs:$a, imm:$b))]>; - def i16ir_p: NVPTXInst<(outs Int1Regs:$dst), (ins i16imm:$a, Int16Regs:$b), - !strconcat(OpcStr, "16 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode imm:$a, Int16Regs:$b))]>; - def i32rr_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b), - !strconcat(OpcStr, "32 \t$dst, 
$a, $b;"), - [(set Int1Regs:$dst, (OpNode Int32Regs:$a, Int32Regs:$b))]>; - def i32ri_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int32Regs:$a, i32imm:$b), - !strconcat(OpcStr, "32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Int32Regs:$a, imm:$b))]>; - def i32ir_p: NVPTXInst<(outs Int1Regs:$dst), (ins i32imm:$a, Int32Regs:$b), - !strconcat(OpcStr, "32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode imm:$a, Int32Regs:$b))]>; - def i64rr_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b), - !strconcat(OpcStr, "64 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Int64Regs:$a, Int64Regs:$b))]>; - def i64ri_p: NVPTXInst<(outs Int1Regs:$dst), (ins Int64Regs:$a, i64imm:$b), - !strconcat(OpcStr, "64 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Int64Regs:$a, imm:$b))]>; - def i64ir_p: NVPTXInst<(outs Int1Regs:$dst), (ins i64imm:$a, Int64Regs:$b), - !strconcat(OpcStr, "64 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode imm:$a, Int64Regs:$b))]>; - - def i8rr_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int8Regs:$a, Int8Regs:$b), - Handle_i8rr<OpcStr_u32, TypeStr, CVTStr>.s, - [(set Int32Regs:$dst, (OpNode Int8Regs:$a, Int8Regs:$b))]>; - def i8ri_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int8Regs:$a, i8imm:$b), - Handle_i8ri<OpcStr_u32, TypeStr, CVTStr>.s, - [(set Int32Regs:$dst, (OpNode Int8Regs:$a, imm:$b))]>; - def i8ir_u32: NVPTXInst<(outs Int32Regs:$dst), (ins i8imm:$a, Int8Regs:$b), - Handle_i8ir<OpcStr_u32, TypeStr, CVTStr>.s, - [(set Int32Regs:$dst, (OpNode imm:$a, Int8Regs:$b))]>; - def i16rr_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, - Int16Regs:$b), - !strconcat(OpcStr_u32, "16 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int16Regs:$a, Int16Regs:$b))]>; - def i16ri_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b), - !strconcat(OpcStr_u32, "16 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int16Regs:$a, imm:$b))]>; - def i16ir_u32: NVPTXInst<(outs Int32Regs:$dst), (ins i16imm:$a, Int16Regs:$b), - !strconcat(OpcStr_u32, "16 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode imm:$a, Int16Regs:$b))]>; - def i32rr_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, - Int32Regs:$b), - !strconcat(OpcStr_u32, "32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int32Regs:$a, Int32Regs:$b))]>; - def i32ri_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b), - !strconcat(OpcStr_u32, "32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int32Regs:$a, imm:$b))]>; - def i32ir_u32: NVPTXInst<(outs Int32Regs:$dst), (ins i32imm:$a, Int32Regs:$b), - !strconcat(OpcStr_u32, "32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode imm:$a, Int32Regs:$b))]>; - def i64rr_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$a, - Int64Regs:$b), - !strconcat(OpcStr_u32, "64 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int64Regs:$a, Int64Regs:$b))]>; - def i64ri_u32: NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$a, i64imm:$b), - !strconcat(OpcStr_u32, "64 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Int64Regs:$a, imm:$b))]>; - def i64ir_u32: NVPTXInst<(outs Int32Regs:$dst), (ins i64imm:$a, Int64Regs:$b), - !strconcat(OpcStr_u32, "64 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode imm:$a, Int64Regs:$b))]>; +multiclass ISET_FORMAT_UNSIGNED<PatFrag OpNode, PatLeaf Mode> + : ISET_FORMAT<OpNode, Mode, + SETP_u16rr, SETP_u16ri, SETP_u16ir, + SETP_u32rr, SETP_u32ri, SETP_u32ir, + SETP_u64rr, SETP_u64ri, SETP_u64ir, + SET_u16rr, SET_u16ri, SET_u16ir, + SET_u32rr, SET_u32ri, SET_u32ir, + SET_u64rr, SET_u64ri, 
SET_u64ir> { + // TableGen doesn't like empty multiclasses + def : PatLeaf<(i32 0)>; } -multiclass FSET_FORMAT<string OpcStr, string OpcStr_u32, PatFrag OpNode> { - def f32rr_toi32_ftz: NVPTXInst<(outs Int32Regs:$dst), (ins Float32Regs:$a, - Float32Regs:$b), - Set_Str<OpcStr, "ftz.f32", "32", "$dst", "$a", "$b">.s, - []>, Requires<[doF32FTZ]>; - def f32rr_toi32: NVPTXInst<(outs Int32Regs:$dst), (ins Float32Regs:$a, - Float32Regs:$b), - Set_Str<OpcStr, "f32", "32", "$dst", "$a", "$b">.s, - []>; - def f64rr_toi64: NVPTXInst<(outs Int64Regs:$dst), (ins Float64Regs:$a, - Float64Regs:$b), - Set_Str<OpcStr, "f64", "64", "$dst", "$a", "$b">.s, - []>; - def f64rr_toi32: NVPTXInst<(outs Int32Regs:$dst), (ins Float64Regs:$a, - Float64Regs:$b), - Set_Str<OpcStr, "f64", "32", "$dst", "$a", "$b">.s, - []>; - - def f32rr_p_ftz: NVPTXInst<(outs Int1Regs:$dst), (ins Float32Regs:$a - , Float32Regs:$b), - !strconcat(OpcStr, "ftz.f32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]> - , Requires<[doF32FTZ]>; - def f32rr_p: NVPTXInst<(outs Int1Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - !strconcat(OpcStr, "f32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>; - def f32ri_p_ftz: NVPTXInst<(outs Int1Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - !strconcat(OpcStr, "ftz.f32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>, - Requires<[doF32FTZ]>; - def f32ri_p: NVPTXInst<(outs Int1Regs:$dst), (ins Float32Regs:$a, f32imm:$b), - !strconcat(OpcStr, "f32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>; - def f32ir_p_ftz: NVPTXInst<(outs Int1Regs:$dst), - (ins f32imm:$a, Float32Regs:$b), - !strconcat(OpcStr, "ftz.f32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode fpimm:$a, Float32Regs:$b))]>, - Requires<[doF32FTZ]>; - def f32ir_p: NVPTXInst<(outs Int1Regs:$dst), (ins f32imm:$a, Float32Regs:$b), - !strconcat(OpcStr, "f32 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode fpimm:$a, Float32Regs:$b))]>; - def f64rr_p: NVPTXInst<(outs Int1Regs:$dst), - (ins Float64Regs:$a, Float64Regs:$b), - !strconcat(OpcStr, "f64 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>; - def f64ri_p: NVPTXInst<(outs Int1Regs:$dst), (ins Float64Regs:$a, f64imm:$b), - !strconcat(OpcStr, "f64 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>; - def f64ir_p: NVPTXInst<(outs Int1Regs:$dst), (ins f64imm:$a, Float64Regs:$b), - !strconcat(OpcStr, "f64 \t$dst, $a, $b;"), - [(set Int1Regs:$dst, (OpNode fpimm:$a, Float64Regs:$b))]>; - - def f32rr_u32_ftz: NVPTXInst<(outs Int32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - !strconcat(OpcStr_u32, "ftz.f32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>; - def f32rr_u32: NVPTXInst<(outs Int32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b), - !strconcat(OpcStr_u32, "f32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>; - def f32ri_u32_ftz: NVPTXInst<(outs Int32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - !strconcat(OpcStr_u32, "ftz.f32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>; - def f32ri_u32: NVPTXInst<(outs Int32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b), - !strconcat(OpcStr_u32, "f32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>; - def f32ir_u32_ftz: NVPTXInst<(outs Int32Regs:$dst), - (ins f32imm:$a, Float32Regs:$b), - 
!strconcat(OpcStr_u32, "ftz.f32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode fpimm:$a, Float32Regs:$b))]>; - def f32ir_u32: NVPTXInst<(outs Int32Regs:$dst), - (ins f32imm:$a, Float32Regs:$b), - !strconcat(OpcStr_u32, "f32 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode fpimm:$a, Float32Regs:$b))]>; - def f64rr_u32: NVPTXInst<(outs Int32Regs:$dst), - (ins Float64Regs:$a, Float64Regs:$b), - !strconcat(OpcStr_u32, "f64 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>; - def f64ri_u32: NVPTXInst<(outs Int32Regs:$dst), - (ins Float64Regs:$a, f64imm:$b), - !strconcat(OpcStr_u32, "f64 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>; - def f64ir_u32: NVPTXInst<(outs Int32Regs:$dst), - (ins f64imm:$a, Float64Regs:$b), - !strconcat(OpcStr_u32, "f64 \t$dst, $a, $b;"), - [(set Int32Regs:$dst, (OpNode fpimm:$a, Float64Regs:$b))]>; +defm : ISET_FORMAT_SIGNED<setgt, CmpGT>; +defm : ISET_FORMAT_UNSIGNED<setugt, CmpGT>; +defm : ISET_FORMAT_SIGNED<setlt, CmpLT>; +defm : ISET_FORMAT_UNSIGNED<setult, CmpLT>; +defm : ISET_FORMAT_SIGNED<setge, CmpGE>; +defm : ISET_FORMAT_UNSIGNED<setuge, CmpGE>; +defm : ISET_FORMAT_SIGNED<setle, CmpLE>; +defm : ISET_FORMAT_UNSIGNED<setule, CmpLE>; +defm : ISET_FORMAT_SIGNED<seteq, CmpEQ>; +defm : ISET_FORMAT_UNSIGNED<setueq, CmpEQ>; +defm : ISET_FORMAT_SIGNED<setne, CmpNE>; +defm : ISET_FORMAT_UNSIGNED<setune, CmpNE>; + +// i1 compares +def : Pat<(setne Int1Regs:$a, Int1Regs:$b), + (XORb1rr Int1Regs:$a, Int1Regs:$b)>; +def : Pat<(setune Int1Regs:$a, Int1Regs:$b), + (XORb1rr Int1Regs:$a, Int1Regs:$b)>; + +def : Pat<(seteq Int1Regs:$a, Int1Regs:$b), + (NOT1 (XORb1rr Int1Regs:$a, Int1Regs:$b))>; +def : Pat<(setueq Int1Regs:$a, Int1Regs:$b), + (NOT1 (XORb1rr Int1Regs:$a, Int1Regs:$b))>; + +// i1 compare -> i32 +def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)), + (SELP_u32ii -1, 0, (XORb1rr Int1Regs:$a, Int1Regs:$b))>; +def : Pat<(i32 (setne Int1Regs:$a, Int1Regs:$b)), + (SELP_u32ii 0, -1, (XORb1rr Int1Regs:$a, Int1Regs:$b))>; + + + +multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { + // f32 -> pred + def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)), + (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, + Requires<[doF32FTZ]>; + def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)), + (SETP_f32rr Float32Regs:$a, Float32Regs:$b, Mode)>; + def : Pat<(i1 (OpNode Float32Regs:$a, fpimm:$b)), + (SETP_f32ri Float32Regs:$a, fpimm:$b, ModeFTZ)>, + Requires<[doF32FTZ]>; + def : Pat<(i1 (OpNode Float32Regs:$a, fpimm:$b)), + (SETP_f32ri Float32Regs:$a, fpimm:$b, Mode)>; + def : Pat<(i1 (OpNode fpimm:$a, Float32Regs:$b)), + (SETP_f32ir fpimm:$a, Float32Regs:$b, ModeFTZ)>, + Requires<[doF32FTZ]>; + def : Pat<(i1 (OpNode fpimm:$a, Float32Regs:$b)), + (SETP_f32ir fpimm:$a, Float32Regs:$b, Mode)>; + + // f64 -> pred + def : Pat<(i1 (OpNode Float64Regs:$a, Float64Regs:$b)), + (SETP_f64rr Float64Regs:$a, Float64Regs:$b, Mode)>; + def : Pat<(i1 (OpNode Float64Regs:$a, fpimm:$b)), + (SETP_f64ri Float64Regs:$a, fpimm:$b, Mode)>; + def : Pat<(i1 (OpNode fpimm:$a, Float64Regs:$b)), + (SETP_f64ir fpimm:$a, Float64Regs:$b, Mode)>; + + // f32 -> i32 + def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)), + (SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, + Requires<[doF32FTZ]>; + def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)), + (SET_f32rr Float32Regs:$a, Float32Regs:$b, Mode)>; + def : Pat<(i32 (OpNode Float32Regs:$a, fpimm:$b)), + (SET_f32ri Float32Regs:$a, fpimm:$b, ModeFTZ)>, + 
Requires<[doF32FTZ]>; + def : Pat<(i32 (OpNode Float32Regs:$a, fpimm:$b)), + (SET_f32ri Float32Regs:$a, fpimm:$b, Mode)>; + def : Pat<(i32 (OpNode fpimm:$a, Float32Regs:$b)), + (SET_f32ir fpimm:$a, Float32Regs:$b, ModeFTZ)>, + Requires<[doF32FTZ]>; + def : Pat<(i32 (OpNode fpimm:$a, Float32Regs:$b)), + (SET_f32ir fpimm:$a, Float32Regs:$b, Mode)>; + + // f64 -> i32 + def : Pat<(i32 (OpNode Float64Regs:$a, Float64Regs:$b)), + (SET_f64rr Float64Regs:$a, Float64Regs:$b, Mode)>; + def : Pat<(i32 (OpNode Float64Regs:$a, fpimm:$b)), + (SET_f64ri Float64Regs:$a, fpimm:$b, Mode)>; + def : Pat<(i32 (OpNode fpimm:$a, Float64Regs:$b)), + (SET_f64ir fpimm:$a, Float64Regs:$b, Mode)>; } -defm ISetSGT -: ISET_FORMAT<"setp.gt.s", "set.gt.u32.s", setgt, "s16", "cvt.s16.s8">; -defm ISetUGT -: ISET_FORMAT<"setp.gt.u", "set.gt.u32.u", setugt, "u16", "cvt.u16.u8">; -defm ISetSLT -: ISET_FORMAT<"setp.lt.s", "set.lt.u32.s", setlt, "s16", "cvt.s16.s8">; -defm ISetULT -: ISET_FORMAT<"setp.lt.u", "set.lt.u32.u", setult, "u16", "cvt.u16.u8">; -defm ISetSGE -: ISET_FORMAT<"setp.ge.s", "set.ge.u32.s", setge, "s16", "cvt.s16.s8">; -defm ISetUGE -: ISET_FORMAT<"setp.ge.u", "set.ge.u32.u", setuge, "u16", "cvt.u16.u8">; -defm ISetSLE -: ISET_FORMAT<"setp.le.s", "set.le.u32.s", setle, "s16", "cvt.s16.s8">; -defm ISetULE -: ISET_FORMAT<"setp.le.u", "set.le.u32.u", setule, "u16", "cvt.u16.u8">; -defm ISetSEQ -: ISET_FORMAT<"setp.eq.s", "set.eq.u32.s", seteq, "s16", "cvt.s16.s8">; -defm ISetUEQ -: ISET_FORMAT<"setp.eq.u", "set.eq.u32.u", setueq, "u16", "cvt.u16.u8">; -defm ISetSNE -: ISET_FORMAT<"setp.ne.s", "set.ne.u32.s", setne, "s16", "cvt.s16.s8">; -defm ISetUNE -: ISET_FORMAT<"setp.ne.u", "set.ne.u32.u", setune, "u16", "cvt.u16.u8">; - -def ISetSNEi1rr_p : NVPTXInst<(outs Int1Regs:$dst), - (ins Int1Regs:$a, Int1Regs:$b), - "xor.pred \t$dst, $a, $b;", - [(set Int1Regs:$dst, (setne Int1Regs:$a, Int1Regs:$b))]>; -def ISetUNEi1rr_p : NVPTXInst<(outs Int1Regs:$dst), - (ins Int1Regs:$a, Int1Regs:$b), - "xor.pred \t$dst, $a, $b;", - [(set Int1Regs:$dst, (setune Int1Regs:$a, Int1Regs:$b))]>; -def ISetSEQi1rr_p : NVPTXInst<(outs Int1Regs:$dst), - (ins Int1Regs:$a, Int1Regs:$b), - !strconcat("{{\n\t", - !strconcat(".reg .pred temp;\n\t", - !strconcat("xor.pred \ttemp, $a, $b;\n\t", - !strconcat("not.pred \t$dst, temp;\n\t}}","")))), - [(set Int1Regs:$dst, (seteq Int1Regs:$a, Int1Regs:$b))]>; -def ISetUEQi1rr_p : NVPTXInst<(outs Int1Regs:$dst), - (ins Int1Regs:$a, Int1Regs:$b), - !strconcat("{{\n\t", - !strconcat(".reg .pred temp;\n\t", - !strconcat("xor.pred \ttemp, $a, $b;\n\t", - !strconcat("not.pred \t$dst, temp;\n\t}}","")))), - [(set Int1Regs:$dst, (setueq Int1Regs:$a, Int1Regs:$b))]>; - -// Compare 2 i1's and produce a u32 -def ISETSNEi1rr_u32 : NVPTXInst<(outs Int32Regs:$dst), - (ins Int1Regs:$a, Int1Regs:$b), - !strconcat("{{\n\t", - !strconcat(".reg .pred temp;\n\t", - !strconcat("xor.pred \ttemp, $a, $b;\n\t", - !strconcat("selp.u32 \t$dst, -1, 0, temp;", "\n\t}}")))), - [(set Int32Regs:$dst, (setne Int1Regs:$a, Int1Regs:$b))]>; -def ISETSEQi1rr_u32 : NVPTXInst<(outs Int32Regs:$dst), - (ins Int1Regs:$a, Int1Regs:$b), - !strconcat("{{\n\t", - !strconcat(".reg .pred temp;\n\t", - !strconcat("xor.pred \ttemp, $a, $b;\n\t", - !strconcat("selp.u32 \t$dst, 0, -1, temp;", "\n\t}}")))), - [(set Int32Regs:$dst, (seteq Int1Regs:$a, Int1Regs:$b))]>; - -defm FSetGT : FSET_FORMAT<"setp.gt.", "set.gt.u32.", setogt>; -defm FSetLT : FSET_FORMAT<"setp.lt.", "set.lt.u32.", setolt>; -defm FSetGE : FSET_FORMAT<"setp.ge.", "set.ge.u32.", 
setoge>; -defm FSetLE : FSET_FORMAT<"setp.le.", "set.le.u32.", setole>; -defm FSetEQ : FSET_FORMAT<"setp.eq.", "set.eq.u32.", setoeq>; -defm FSetNE : FSET_FORMAT<"setp.ne.", "set.ne.u32.", setone>; - -defm FSetUGT : FSET_FORMAT<"setp.gtu.", "set.gtu.u32.", setugt>; -defm FSetULT : FSET_FORMAT<"setp.ltu.", "set.ltu.u32.",setult>; -defm FSetUGE : FSET_FORMAT<"setp.geu.", "set.geu.u32.",setuge>; -defm FSetULE : FSET_FORMAT<"setp.leu.", "set.leu.u32.",setule>; -defm FSetUEQ : FSET_FORMAT<"setp.equ.", "set.equ.u32.",setueq>; -defm FSetUNE : FSET_FORMAT<"setp.neu.", "set.neu.u32.",setune>; - -defm FSetNUM : FSET_FORMAT<"setp.num.", "set.num.u32.",seto>; -defm FSetNAN : FSET_FORMAT<"setp.nan.", "set.nan.u32.",setuo>; - -def SELECTi1rr : Pat<(i1 (select Int1Regs:$p, Int1Regs:$a, Int1Regs:$b)), - (ORb1rr (ANDb1rr Int1Regs:$p, Int1Regs:$a), - (ANDb1rr (NOT1 Int1Regs:$p), Int1Regs:$b))>; -def SELECTi8rr : NVPTXInst<(outs Int8Regs:$dst), - (ins Int8Regs:$a, Int8Regs:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int8Regs:$dst, (select Int1Regs:$p, Int8Regs:$a, Int8Regs:$b))]>; -def SELECTi8ri : NVPTXInst<(outs Int8Regs:$dst), - (ins Int8Regs:$a, i8imm:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int8Regs:$dst, (select Int1Regs:$p, Int8Regs:$a, imm:$b))]>; -def SELECTi8ir : NVPTXInst<(outs Int8Regs:$dst), - (ins i8imm:$a, Int8Regs:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int8Regs:$dst, (select Int1Regs:$p, imm:$a, Int8Regs:$b))]>; -def SELECTi8ii : NVPTXInst<(outs Int8Regs:$dst), - (ins i8imm:$a, i8imm:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int8Regs:$dst, (select Int1Regs:$p, imm:$a, imm:$b))]>; - -def SELECTi16rr : NVPTXInst<(outs Int16Regs:$dst), - (ins Int16Regs:$a, Int16Regs:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int16Regs:$dst, (select Int1Regs:$p, Int16Regs:$a, Int16Regs:$b))]>; -def SELECTi16ri : NVPTXInst<(outs Int16Regs:$dst), - (ins Int16Regs:$a, i16imm:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int16Regs:$dst, (select Int1Regs:$p, Int16Regs:$a, imm:$b))]>; -def SELECTi16ir : NVPTXInst<(outs Int16Regs:$dst), - (ins i16imm:$a, Int16Regs:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int16Regs:$dst, (select Int1Regs:$p, imm:$a, Int16Regs:$b))]>; -def SELECTi16ii : NVPTXInst<(outs Int16Regs:$dst), - (ins i16imm:$a, i16imm:$b, Int1Regs:$p), - "selp.b16 \t$dst, $a, $b, $p;", - [(set Int16Regs:$dst, (select Int1Regs:$p, imm:$a, imm:$b))]>; - -def SELECTi32rr : NVPTXInst<(outs Int32Regs:$dst), - (ins Int32Regs:$a, Int32Regs:$b, Int1Regs:$p), - "selp.b32 \t$dst, $a, $b, $p;", - [(set Int32Regs:$dst, (select Int1Regs:$p, Int32Regs:$a, Int32Regs:$b))]>; -def SELECTi32ri : NVPTXInst<(outs Int32Regs:$dst), - (ins Int32Regs:$a, i32imm:$b, Int1Regs:$p), - "selp.b32 \t$dst, $a, $b, $p;", - [(set Int32Regs:$dst, (select Int1Regs:$p, Int32Regs:$a, imm:$b))]>; -def SELECTi32ir : NVPTXInst<(outs Int32Regs:$dst), - (ins i32imm:$a, Int32Regs:$b, Int1Regs:$p), - "selp.b32 \t$dst, $a, $b, $p;", - [(set Int32Regs:$dst, (select Int1Regs:$p, imm:$a, Int32Regs:$b))]>; -def SELECTi32ii : NVPTXInst<(outs Int32Regs:$dst), - (ins i32imm:$a, i32imm:$b, Int1Regs:$p), - "selp.b32 \t$dst, $a, $b, $p;", - [(set Int32Regs:$dst, (select Int1Regs:$p, imm:$a, imm:$b))]>; - -def SELECTi64rr : NVPTXInst<(outs Int64Regs:$dst), - (ins Int64Regs:$a, Int64Regs:$b, Int1Regs:$p), - "selp.b64 \t$dst, $a, $b, $p;", - [(set Int64Regs:$dst, (select Int1Regs:$p, Int64Regs:$a, Int64Regs:$b))]>; -def SELECTi64ri : NVPTXInst<(outs 
Int64Regs:$dst), - (ins Int64Regs:$a, i64imm:$b, Int1Regs:$p), - "selp.b64 \t$dst, $a, $b, $p;", - [(set Int64Regs:$dst, (select Int1Regs:$p, Int64Regs:$a, imm:$b))]>; -def SELECTi64ir : NVPTXInst<(outs Int64Regs:$dst), - (ins i64imm:$a, Int64Regs:$b, Int1Regs:$p), - "selp.b64 \t$dst, $a, $b, $p;", - [(set Int64Regs:$dst, (select Int1Regs:$p, imm:$a, Int64Regs:$b))]>; -def SELECTi64ii : NVPTXInst<(outs Int64Regs:$dst), - (ins i64imm:$a, i64imm:$b, Int1Regs:$p), - "selp.b64 \t$dst, $a, $b, $p;", - [(set Int64Regs:$dst, (select Int1Regs:$p, imm:$a, imm:$b))]>; - -def SELECTf32rr : NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b, Int1Regs:$p), - "selp.f32 \t$dst, $a, $b, $p;", - [(set Float32Regs:$dst, - (select Int1Regs:$p, Float32Regs:$a, Float32Regs:$b))]>; -def SELECTf32ri : NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, f32imm:$b, Int1Regs:$p), - "selp.f32 \t$dst, $a, $b, $p;", - [(set Float32Regs:$dst, (select Int1Regs:$p, Float32Regs:$a, fpimm:$b))]>; -def SELECTf32ir : NVPTXInst<(outs Float32Regs:$dst), - (ins f32imm:$a, Float32Regs:$b, Int1Regs:$p), - "selp.f32 \t$dst, $a, $b, $p;", - [(set Float32Regs:$dst, (select Int1Regs:$p, fpimm:$a, Float32Regs:$b))]>; -def SELECTf32ii : NVPTXInst<(outs Float32Regs:$dst), - (ins f32imm:$a, f32imm:$b, Int1Regs:$p), - "selp.f32 \t$dst, $a, $b, $p;", - [(set Float32Regs:$dst, (select Int1Regs:$p, fpimm:$a, fpimm:$b))]>; - -def SELECTf64rr : NVPTXInst<(outs Float64Regs:$dst), - (ins Float64Regs:$a, Float64Regs:$b, Int1Regs:$p), - "selp.f64 \t$dst, $a, $b, $p;", - [(set Float64Regs:$dst, - (select Int1Regs:$p, Float64Regs:$a, Float64Regs:$b))]>; -def SELECTf64ri : NVPTXInst<(outs Float64Regs:$dst), - (ins Float64Regs:$a, f64imm:$b, Int1Regs:$p), - "selp.f64 \t$dst, $a, $b, $p;", - [(set Float64Regs:$dst, (select Int1Regs:$p, Float64Regs:$a, fpimm:$b))]>; -def SELECTf64ir : NVPTXInst<(outs Float64Regs:$dst), - (ins f64imm:$a, Float64Regs:$b, Int1Regs:$p), - "selp.f64 \t$dst, $a, $b, $p;", - [(set Float64Regs:$dst, (select Int1Regs:$p, fpimm:$a, Float64Regs:$b))]>; -def SELECTf64ii : NVPTXInst<(outs Float64Regs:$dst), - (ins f64imm:$a, f64imm:$b, Int1Regs:$p), - "selp.f64 \t $dst, $a, $b, $p;", - [(set Float64Regs:$dst, (select Int1Regs:$p, fpimm:$a, fpimm:$b))]>; +defm FSetGT : FSET_FORMAT<setogt, CmpGT, CmpGT_FTZ>; +defm FSetLT : FSET_FORMAT<setolt, CmpLT, CmpLT_FTZ>; +defm FSetGE : FSET_FORMAT<setoge, CmpGE, CmpGE_FTZ>; +defm FSetLE : FSET_FORMAT<setole, CmpLE, CmpLE_FTZ>; +defm FSetEQ : FSET_FORMAT<setoeq, CmpEQ, CmpEQ_FTZ>; +defm FSetNE : FSET_FORMAT<setone, CmpNE, CmpNE_FTZ>; + +defm FSetUGT : FSET_FORMAT<setugt, CmpGTU, CmpGTU_FTZ>; +defm FSetULT : FSET_FORMAT<setult, CmpLTU, CmpLTU_FTZ>; +defm FSetUGE : FSET_FORMAT<setuge, CmpGEU, CmpGEU_FTZ>; +defm FSetULE : FSET_FORMAT<setule, CmpLEU, CmpLEU_FTZ>; +defm FSetUEQ : FSET_FORMAT<setueq, CmpEQU, CmpEQU_FTZ>; +defm FSetUNE : FSET_FORMAT<setune, CmpNEU, CmpNEU_FTZ>; + +defm FSetNUM : FSET_FORMAT<seto, CmpNUM, CmpNUM_FTZ>; +defm FSetNAN : FSET_FORMAT<setuo, CmpNAN, CmpNAN_FTZ>; //def ld_param : SDNode<"NVPTXISD::LOAD_PARAM", SDTLoad, // [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; @@ -1751,17 +1577,22 @@ def SDTDeclareParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, def SDTDeclareScalarParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>; def SDTLoadParamProfile : SDTypeProfile<1, 2, [SDTCisInt<1>, SDTCisInt<2>]>; +def SDTLoadParamV2Profile : SDTypeProfile<2, 2, [SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisInt<3>]>; +def 
SDTLoadParamV4Profile : SDTypeProfile<4, 2, [SDTCisInt<4>, SDTCisInt<5>]>; def SDTPrintCallProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>; def SDTPrintCallUniProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>; def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>; +def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>; +def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>; def SDTStoreParam32Profile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>; def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>; def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>; def SDTCallVoidProfile : SDTypeProfile<0, 1, []>; def SDTCallValProfile : SDTypeProfile<1, 0, []>; def SDTMoveParamProfile : SDTypeProfile<1, 1, []>; -def SDTMoveRetvalProfile : SDTypeProfile<0, 1, []>; def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>; +def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>; +def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>; def SDTPseudoUseParamProfile : SDTypeProfile<0, 1, []>; def DeclareParam : SDNode<"NVPTXISD::DeclareParam", SDTDeclareParamProfile, @@ -1776,18 +1607,24 @@ def DeclareRet : SDNode<"NVPTXISD::DeclareRet", SDTDeclareScalarParamProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def LoadParam : SDNode<"NVPTXISD::LoadParam", SDTLoadParamProfile, [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; +def LoadParamV2 : SDNode<"NVPTXISD::LoadParamV2", SDTLoadParamV2Profile, + [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; +def LoadParamV4 : SDNode<"NVPTXISD::LoadParamV4", SDTLoadParamV4Profile, + [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; def PrintCall : SDNode<"NVPTXISD::PrintCall", SDTPrintCallProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def PrintCallUni : SDNode<"NVPTXISD::PrintCallUni", SDTPrintCallUniProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def StoreParam : SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def StoreParamV2 : SDNode<"NVPTXISD::StoreParamV2", SDTStoreParamV2Profile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def StoreParamV4 : SDNode<"NVPTXISD::StoreParamV4", SDTStoreParamV4Profile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def StoreParamU32 : SDNode<"NVPTXISD::StoreParamU32", SDTStoreParam32Profile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def StoreParamS32 : SDNode<"NVPTXISD::StoreParamS32", SDTStoreParam32Profile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def MoveToParam : SDNode<"NVPTXISD::MoveToParam", SDTStoreParamProfile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def CallArgBegin : SDNode<"NVPTXISD::CallArgBegin", SDTCallArgMarkProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def CallArg : SDNode<"NVPTXISD::CallArg", SDTCallArgProfile, @@ -1804,12 +1641,12 @@ def CallVal : SDNode<"NVPTXISD::CallVal", SDTCallValProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def MoveParam : SDNode<"NVPTXISD::MoveParam", SDTMoveParamProfile, []>; -def MoveRetval : SDNode<"NVPTXISD::MoveRetval", SDTMoveRetvalProfile, - [SDNPHasChain, SDNPSideEffect]>; def StoreRetval : SDNode<"NVPTXISD::StoreRetval", SDTStoreRetvalProfile, [SDNPHasChain, SDNPSideEffect]>; -def MoveToRetval : SDNode<"NVPTXISD::MoveToRetval", SDTStoreRetvalProfile, - [SDNPHasChain, SDNPSideEffect]>; +def StoreRetvalV2 : 
SDNode<"NVPTXISD::StoreRetvalV2", SDTStoreRetvalV2Profile, + [SDNPHasChain, SDNPSideEffect]>; +def StoreRetvalV4 : SDNode<"NVPTXISD::StoreRetvalV4", SDTStoreRetvalV4Profile, + [SDNPHasChain, SDNPSideEffect]>; def PseudoUseParam : SDNode<"NVPTXISD::PseudoUseParam", SDTPseudoUseParamProfile, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; @@ -1820,7 +1657,7 @@ class LoadParamMemInst<NVPTXRegClass regclass, string opstr> : NVPTXInst<(outs regclass:$dst), (ins i32imm:$b), !strconcat(!strconcat("ld.param", opstr), "\t$dst, [retval0+$b];"), - [(set regclass:$dst, (LoadParam (i32 1), (i32 imm:$b)))]>; + []>; class LoadParamRegInst<NVPTXRegClass regclass, string opstr> : NVPTXInst<(outs regclass:$dst), (ins i32imm:$b), @@ -1828,35 +1665,57 @@ class LoadParamRegInst<NVPTXRegClass regclass, string opstr> : "\t$dst, retval$b;"), [(set regclass:$dst, (LoadParam (i32 0), (i32 imm:$b)))]>; +class LoadParamV2MemInst<NVPTXRegClass regclass, string opstr> : + NVPTXInst<(outs regclass:$dst, regclass:$dst2), (ins i32imm:$b), + !strconcat(!strconcat("ld.param.v2", opstr), + "\t{{$dst, $dst2}}, [retval0+$b];"), []>; + +class LoadParamV4MemInst<NVPTXRegClass regclass, string opstr> : + NVPTXInst<(outs regclass:$dst, regclass:$dst2, regclass:$dst3, + regclass:$dst4), + (ins i32imm:$b), + !strconcat(!strconcat("ld.param.v4", opstr), + "\t{{$dst, $dst2, $dst3, $dst4}}, [retval0+$b];"), []>; + class StoreParamInst<NVPTXRegClass regclass, string opstr> : NVPTXInst<(outs), (ins regclass:$val, i32imm:$a, i32imm:$b), !strconcat(!strconcat("st.param", opstr), "\t[param$a+$b], $val;"), - [(StoreParam (i32 imm:$a), (i32 imm:$b), regclass:$val)]>; + []>; -class MoveToParamInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs), (ins regclass:$val, i32imm:$a, i32imm:$b), - !strconcat(!strconcat("mov", opstr), - "\tparam$a, $val;"), - [(MoveToParam (i32 imm:$a), (i32 imm:$b), regclass:$val)]>; +class StoreParamV2Inst<NVPTXRegClass regclass, string opstr> : + NVPTXInst<(outs), (ins regclass:$val, regclass:$val2, + i32imm:$a, i32imm:$b), + !strconcat(!strconcat("st.param.v2", opstr), + "\t[param$a+$b], {{$val, $val2}};"), + []>; + +class StoreParamV4Inst<NVPTXRegClass regclass, string opstr> : + NVPTXInst<(outs), (ins regclass:$val, regclass:$val1, regclass:$val2, + regclass:$val3, i32imm:$a, i32imm:$b), + !strconcat(!strconcat("st.param.v4", opstr), + "\t[param$a+$b], {{$val, $val2, $val3, $val4}};"), + []>; class StoreRetvalInst<NVPTXRegClass regclass, string opstr> : NVPTXInst<(outs), (ins regclass:$val, i32imm:$a), !strconcat(!strconcat("st.param", opstr), "\t[func_retval0+$a], $val;"), - [(StoreRetval (i32 imm:$a), regclass:$val)]>; + []>; -class MoveToRetvalInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs), (ins i32imm:$num, regclass:$val), - !strconcat(!strconcat("mov", opstr), - "\tfunc_retval$num, $val;"), - [(MoveToRetval (i32 imm:$num), regclass:$val)]>; +class StoreRetvalV2Inst<NVPTXRegClass regclass, string opstr> : + NVPTXInst<(outs), (ins regclass:$val, regclass:$val2, i32imm:$a), + !strconcat(!strconcat("st.param.v2", opstr), + "\t[func_retval0+$a], {{$val, $val2}};"), + []>; -class MoveRetvalInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs), (ins regclass:$val), - !strconcat(!strconcat("mov", opstr), - "\tfunc_retval0, $val;"), - [(MoveRetval regclass:$val)]>; +class StoreRetvalV4Inst<NVPTXRegClass regclass, string opstr> : + NVPTXInst<(outs), + (ins regclass:$val, regclass:$val2, regclass:$val3, + regclass:$val4, i32imm:$a), + !strconcat(!strconcat("st.param.v4", 
opstr), + "\t[func_retval0+$a], {{$val, $val2, $val3, $val4}};"), + []>; def PrintCallRetInst1 : NVPTXInst<(outs), (ins), "call (retval0), ", @@ -1919,126 +1778,81 @@ def PrintCallUniNoRetInst : NVPTXInst<(outs), (ins), "call.uni ", def LoadParamMemI64 : LoadParamMemInst<Int64Regs, ".b64">; def LoadParamMemI32 : LoadParamMemInst<Int32Regs, ".b32">; def LoadParamMemI16 : LoadParamMemInst<Int16Regs, ".b16">; -def LoadParamMemI8 : LoadParamMemInst<Int8Regs, ".b8">; - -//def LoadParamMemI16 : NVPTXInst<(outs Int16Regs:$dst), (ins i32imm:$b), -// !strconcat("ld.param.b32\ttemp_param_reg, [retval0+$b];\n\t", -// "cvt.u16.u32\t$dst, temp_param_reg;"), -// [(set Int16Regs:$dst, (LoadParam (i32 1), (i32 imm:$b)))]>; -//def LoadParamMemI8 : NVPTXInst<(outs Int8Regs:$dst), (ins i32imm:$b), -// !strconcat("ld.param.b32\ttemp_param_reg, [retval0+$b];\n\t", -// "cvt.u16.u32\t$dst, temp_param_reg;"), -// [(set Int8Regs:$dst, (LoadParam (i32 1), (i32 imm:$b)))]>; - +def LoadParamMemI8 : LoadParamMemInst<Int16Regs, ".b8">; +def LoadParamMemV2I64 : LoadParamV2MemInst<Int64Regs, ".b64">; +def LoadParamMemV2I32 : LoadParamV2MemInst<Int32Regs, ".b32">; +def LoadParamMemV2I16 : LoadParamV2MemInst<Int16Regs, ".b16">; +def LoadParamMemV2I8 : LoadParamV2MemInst<Int16Regs, ".b8">; +def LoadParamMemV4I32 : LoadParamV4MemInst<Int32Regs, ".b32">; +def LoadParamMemV4I16 : LoadParamV4MemInst<Int16Regs, ".b16">; +def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">; def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">; def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">; - -def LoadParamRegI64 : LoadParamRegInst<Int64Regs, ".b64">; -def LoadParamRegI32 : LoadParamRegInst<Int32Regs, ".b32">; -def LoadParamRegI16 : NVPTXInst<(outs Int16Regs:$dst), (ins i32imm:$b), - "cvt.u16.u32\t$dst, retval$b;", - [(set Int16Regs:$dst, - (LoadParam (i32 0), (i32 imm:$b)))]>; -def LoadParamRegI8 : NVPTXInst<(outs Int8Regs:$dst), (ins i32imm:$b), - "cvt.u16.u32\t$dst, retval$b;", - [(set Int8Regs:$dst, - (LoadParam (i32 0), (i32 imm:$b)))]>; - -def LoadParamRegF32 : LoadParamRegInst<Float32Regs, ".f32">; -def LoadParamRegF64 : LoadParamRegInst<Float64Regs, ".f64">; +def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">; +def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">; +def LoadParamMemV4F32 : LoadParamV4MemInst<Float32Regs, ".f32">; def StoreParamI64 : StoreParamInst<Int64Regs, ".b64">; def StoreParamI32 : StoreParamInst<Int32Regs, ".b32">; -def StoreParamI16 : NVPTXInst<(outs), - (ins Int16Regs:$val, i32imm:$a, i32imm:$b), - "st.param.b16\t[param$a+$b], $val;", - [(StoreParam (i32 imm:$a), (i32 imm:$b), Int16Regs:$val)]>; - -def StoreParamI8 : NVPTXInst<(outs), - (ins Int8Regs:$val, i32imm:$a, i32imm:$b), - "st.param.b8\t[param$a+$b], $val;", - [(StoreParam - (i32 imm:$a), (i32 imm:$b), Int8Regs:$val)]>; - -def StoreParamS32I16 : NVPTXInst<(outs), - (ins Int16Regs:$val, i32imm:$a, i32imm:$b), - !strconcat("cvt.s32.s16\ttemp_param_reg, $val;\n\t", - "st.param.b32\t[param$a+$b], temp_param_reg;"), - [(StoreParamS32 (i32 imm:$a), (i32 imm:$b), Int16Regs:$val)]>; -def StoreParamU32I16 : NVPTXInst<(outs), - (ins Int16Regs:$val, i32imm:$a, i32imm:$b), - !strconcat("cvt.u32.u16\ttemp_param_reg, $val;\n\t", - "st.param.b32\t[param$a+$b], temp_param_reg;"), - [(StoreParamU32 (i32 imm:$a), (i32 imm:$b), Int16Regs:$val)]>; - -def StoreParamU32I8 : NVPTXInst<(outs), - (ins Int8Regs:$val, i32imm:$a, i32imm:$b), - !strconcat("cvt.u32.u8\ttemp_param_reg, $val;\n\t", - "st.param.b32\t[param$a+$b], 
temp_param_reg;"), - [(StoreParamU32 (i32 imm:$a), (i32 imm:$b), Int8Regs:$val)]>; -def StoreParamS32I8 : NVPTXInst<(outs), - (ins Int8Regs:$val, i32imm:$a, i32imm:$b), - !strconcat("cvt.s32.s8\ttemp_param_reg, $val;\n\t", - "st.param.b32\t[param$a+$b], temp_param_reg;"), - [(StoreParamS32 (i32 imm:$a), (i32 imm:$b), Int8Regs:$val)]>; +def StoreParamI16 : StoreParamInst<Int16Regs, ".b16">; +def StoreParamI8 : StoreParamInst<Int16Regs, ".b8">; +def StoreParamV2I64 : StoreParamV2Inst<Int64Regs, ".b64">; +def StoreParamV2I32 : StoreParamV2Inst<Int32Regs, ".b32">; +def StoreParamV2I16 : StoreParamV2Inst<Int16Regs, ".b16">; +def StoreParamV2I8 : StoreParamV2Inst<Int16Regs, ".b8">; + +// FIXME: StoreParamV4Inst crashes llvm-tblgen :( +//def StoreParamV4I32 : StoreParamV4Inst<Int32Regs, ".b32">; +def StoreParamV4I32 : NVPTXInst<(outs), (ins Int32Regs:$val, Int32Regs:$val2, + Int32Regs:$val3, Int32Regs:$val4, + i32imm:$a, i32imm:$b), + "st.param.b32\t[param$a+$b], {{$val, $val2, $val3, $val4}};", + []>; + +def StoreParamV4I16 : NVPTXInst<(outs), (ins Int16Regs:$val, Int16Regs:$val2, + Int16Regs:$val3, Int16Regs:$val4, + i32imm:$a, i32imm:$b), + "st.param.v4.b16\t[param$a+$b], {{$val, $val2, $val3, $val4}};", + []>; + +def StoreParamV4I8 : NVPTXInst<(outs), (ins Int16Regs:$val, Int16Regs:$val2, + Int16Regs:$val3, Int16Regs:$val4, + i32imm:$a, i32imm:$b), + "st.param.v4.b8\t[param$a+$b], {{$val, $val2, $val3, $val4}};", + []>; def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">; def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">; +def StoreParamV2F32 : StoreParamV2Inst<Float32Regs, ".f32">; +def StoreParamV2F64 : StoreParamV2Inst<Float64Regs, ".f64">; +// FIXME: StoreParamV4Inst crashes llvm-tblgen :( +//def StoreParamV4F32 : StoreParamV4Inst<Float32Regs, ".f32">; +def StoreParamV4F32 : NVPTXInst<(outs), + (ins Float32Regs:$val, Float32Regs:$val2, + Float32Regs:$val3, Float32Regs:$val4, + i32imm:$a, i32imm:$b), + "st.param.v4.f32\t[param$a+$b], {{$val, $val2, $val3, $val4}};", + []>; -def MoveToParamI64 : MoveToParamInst<Int64Regs, ".b64">; -def MoveToParamI32 : MoveToParamInst<Int32Regs, ".b32">; -def MoveToParamF64 : MoveToParamInst<Float64Regs, ".f64">; -def MoveToParamF32 : MoveToParamInst<Float32Regs, ".f32">; -def MoveToParamI16 : NVPTXInst<(outs), - (ins Int16Regs:$val, i32imm:$a, i32imm:$b), - !strconcat("cvt.u32.u16\ttemp_param_reg, $val;\n\t", - "mov.b32\tparam$a, temp_param_reg;"), - [(MoveToParam (i32 imm:$a), (i32 imm:$b), Int16Regs:$val)]>; -def MoveToParamI8 : NVPTXInst<(outs), - (ins Int8Regs:$val, i32imm:$a, i32imm:$b), - !strconcat("cvt.u32.u16\ttemp_param_reg, $val;\n\t", - "mov.b32\tparam$a, temp_param_reg;"), - [(MoveToParam (i32 imm:$a), (i32 imm:$b), Int8Regs:$val)]>; def StoreRetvalI64 : StoreRetvalInst<Int64Regs, ".b64">; def StoreRetvalI32 : StoreRetvalInst<Int32Regs, ".b32">; def StoreRetvalI16 : StoreRetvalInst<Int16Regs, ".b16">; -def StoreRetvalI8 : StoreRetvalInst<Int8Regs, ".b8">; - -//def StoreRetvalI16 : NVPTXInst<(outs), (ins Int16Regs:$val, i32imm:$a), -// !strconcat("\{\n\t", -// !strconcat(".reg .b32 temp_retval_reg;\n\t", -// !strconcat("cvt.u32.u16\ttemp_retval_reg, $val;\n\t", -// "st.param.b32\t[func_retval0+$a], temp_retval_reg;\n\t\}"))), -// [(StoreRetval (i32 imm:$a), Int16Regs:$val)]>; -//def StoreRetvalI8 : NVPTXInst<(outs), (ins Int8Regs:$val, i32imm:$a), -// !strconcat("\{\n\t", -// !strconcat(".reg .b32 temp_retval_reg;\n\t", -// !strconcat("cvt.u32.u16\ttemp_retval_reg, $val;\n\t", -// "st.param.b32\t[func_retval0+$a], 
temp_retval_reg;\n\t\}"))), -// [(StoreRetval (i32 imm:$a), Int8Regs:$val)]>; +def StoreRetvalI8 : StoreRetvalInst<Int16Regs, ".b8">; +def StoreRetvalV2I64 : StoreRetvalV2Inst<Int64Regs, ".b64">; +def StoreRetvalV2I32 : StoreRetvalV2Inst<Int32Regs, ".b32">; +def StoreRetvalV2I16 : StoreRetvalV2Inst<Int16Regs, ".b16">; +def StoreRetvalV2I8 : StoreRetvalV2Inst<Int16Regs, ".b8">; +def StoreRetvalV4I32 : StoreRetvalV4Inst<Int32Regs, ".b32">; +def StoreRetvalV4I16 : StoreRetvalV4Inst<Int16Regs, ".b16">; +def StoreRetvalV4I8 : StoreRetvalV4Inst<Int16Regs, ".b8">; def StoreRetvalF64 : StoreRetvalInst<Float64Regs, ".f64">; def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">; - -def MoveRetvalI64 : MoveRetvalInst<Int64Regs, ".b64">; -def MoveRetvalI32 : MoveRetvalInst<Int32Regs, ".b32">; -def MoveRetvalI16 : MoveRetvalInst<Int16Regs, ".b16">; -def MoveRetvalI8 : MoveRetvalInst<Int8Regs, ".b8">; -def MoveRetvalF64 : MoveRetvalInst<Float64Regs, ".f64">; -def MoveRetvalF32 : MoveRetvalInst<Float32Regs, ".f32">; - -def MoveToRetvalI64 : MoveToRetvalInst<Int64Regs, ".b64">; -def MoveToRetvalI32 : MoveToRetvalInst<Int32Regs, ".b32">; -def MoveToRetvalF64 : MoveToRetvalInst<Float64Regs, ".f64">; -def MoveToRetvalF32 : MoveToRetvalInst<Float32Regs, ".f32">; -def MoveToRetvalI16 : NVPTXInst<(outs), (ins i32imm:$num, Int16Regs:$val), - "cvt.u32.u16\tfunc_retval$num, $val;", - [(MoveToRetval (i32 imm:$num), Int16Regs:$val)]>; -def MoveToRetvalI8 : NVPTXInst<(outs), (ins i32imm:$num, Int8Regs:$val), - "cvt.u32.u16\tfunc_retval$num, $val;", - [(MoveToRetval (i32 imm:$num), Int8Regs:$val)]>; +def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">; +def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">; +def StoreRetvalV4F32 : StoreRetvalV4Inst<Float32Regs, ".f32">; def CallArgBeginInst : NVPTXInst<(outs), (ins), "(", [(CallArgBegin)]>; def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>; @@ -2056,7 +1870,6 @@ class LastCallArgInst<NVPTXRegClass regclass> : def CallArgI64 : CallArgInst<Int64Regs>; def CallArgI32 : CallArgInst<Int32Regs>; def CallArgI16 : CallArgInst<Int16Regs>; -def CallArgI8 : CallArgInst<Int8Regs>; def CallArgF64 : CallArgInst<Float64Regs>; def CallArgF32 : CallArgInst<Float32Regs>; @@ -2064,7 +1877,6 @@ def CallArgF32 : CallArgInst<Float32Regs>; def LastCallArgI64 : LastCallArgInst<Int64Regs>; def LastCallArgI32 : LastCallArgInst<Int32Regs>; def LastCallArgI16 : LastCallArgInst<Int16Regs>; -def LastCallArgI8 : LastCallArgInst<Int8Regs>; def LastCallArgF64 : LastCallArgInst<Float64Regs>; def LastCallArgF32 : LastCallArgInst<Float32Regs>; @@ -2124,9 +1936,6 @@ def MoveParamI32 : MoveParamInst<Int32Regs, ".b32">; def MoveParamI16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src), "cvt.u16.u32\t$dst, $src;", [(set Int16Regs:$dst, (MoveParam Int16Regs:$src))]>; -def MoveParamI8 : NVPTXInst<(outs Int8Regs:$dst), (ins Int8Regs:$src), - "cvt.u16.u32\t$dst, $src;", - [(set Int8Regs:$dst, (MoveParam Int8Regs:$src))]>; def MoveParamF64 : MoveParamInst<Float64Regs, ".f64">; def MoveParamF32 : MoveParamInst<Float32Regs, ".f32">; @@ -2138,7 +1947,6 @@ class PseudoUseParamInst<NVPTXRegClass regclass> : def PseudoUseParamI64 : PseudoUseParamInst<Int64Regs>; def PseudoUseParamI32 : PseudoUseParamInst<Int32Regs>; def PseudoUseParamI16 : PseudoUseParamInst<Int16Regs>; -def PseudoUseParamI8 : PseudoUseParamInst<Int8Regs>; def PseudoUseParamF64 : PseudoUseParamInst<Float64Regs>; def PseudoUseParamF32 : PseudoUseParamInst<Float32Regs>; @@ -2180,7 +1988,7 @@ 
multiclass LD<NVPTXRegClass regclass> { } let mayLoad=1, neverHasSideEffects=1 in { -defm LD_i8 : LD<Int8Regs>; +defm LD_i8 : LD<Int16Regs>; defm LD_i16 : LD<Int16Regs>; defm LD_i32 : LD<Int32Regs>; defm LD_i64 : LD<Int64Regs>; @@ -2222,7 +2030,7 @@ multiclass ST<NVPTXRegClass regclass> { } let mayStore=1, neverHasSideEffects=1 in { -defm ST_i8 : ST<Int8Regs>; +defm ST_i8 : ST<Int16Regs>; defm ST_i16 : ST<Int16Regs>; defm ST_i32 : ST<Int32Regs>; defm ST_i64 : ST<Int64Regs>; @@ -2306,7 +2114,7 @@ multiclass LD_VEC<NVPTXRegClass regclass> { []>; } let mayLoad=1, neverHasSideEffects=1 in { -defm LDV_i8 : LD_VEC<Int8Regs>; +defm LDV_i8 : LD_VEC<Int16Regs>; defm LDV_i16 : LD_VEC<Int16Regs>; defm LDV_i32 : LD_VEC<Int32Regs>; defm LDV_i64 : LD_VEC<Int64Regs>; @@ -2389,7 +2197,7 @@ multiclass ST_VEC<NVPTXRegClass regclass> { []>; } let mayStore=1, neverHasSideEffects=1 in { -defm STV_i8 : ST_VEC<Int8Regs>; +defm STV_i8 : ST_VEC<Int16Regs>; defm STV_i16 : ST_VEC<Int16Regs>; defm STV_i32 : ST_VEC<Int32Regs>; defm STV_i64 : ST_VEC<Int64Regs>; @@ -2400,291 +2208,6 @@ defm STV_f64 : ST_VEC<Float64Regs>; //---- Conversion ---- -multiclass CVT_INT_TO_FP <string OpStr, SDNode OpNode> { -// FIXME: need to add f16 support -// def CVTf16i8 : -// NVPTXInst<(outs Float16Regs:$d), (ins Int8Regs:$a), -// !strconcat(!strconcat("cvt.rn.f16.", OpStr), "8 \t$d, $a;"), -// [(set Float16Regs:$d, (OpNode Int8Regs:$a))]>; -// def CVTf16i16 : -// NVPTXInst<(outs Float16Regs:$d), (ins Int16Regs:$a), -// !strconcat(!strconcat("cvt.rn.f16.", OpStr), "16 \t$d, $a;"), -// [(set Float16Regs:$d, (OpNode Int16Regs:$a))]>; -// def CVTf16i32 : -// NVPTXInst<(outs Float16Regs:$d), (ins Int32Regs:$a), -// !strconcat(!strconcat("cvt.rn.f16.", OpStr), "32 \t$d, $a;"), -// [(set Float16Regs:$d, (OpNode Int32Regs:$a))]>; -// def CVTf16i64: -// NVPTXInst<(outs Float16Regs:$d), (ins Int64Regs:$a), -// !strconcat(!strconcat("cvt.rn.f32.", OpStr), "64 \t$d, $a;"), -// [(set Float32Regs:$d, (OpNode Int64Regs:$a))]>; - - def CVTf32i1 : - NVPTXInst<(outs Float32Regs:$d), (ins Int1Regs:$a), - "selp.f32 \t$d, 1.0, 0.0, $a;", - [(set Float32Regs:$d, (OpNode Int1Regs:$a))]>; - def CVTf32i8 : - NVPTXInst<(outs Float32Regs:$d), (ins Int8Regs:$a), - !strconcat(!strconcat("cvt.rn.f32.", OpStr), "8 \t$d, $a;"), - [(set Float32Regs:$d, (OpNode Int8Regs:$a))]>; - def CVTf32i16 : - NVPTXInst<(outs Float32Regs:$d), (ins Int16Regs:$a), - !strconcat(!strconcat("cvt.rn.f32.", OpStr), "16 \t$d, $a;"), - [(set Float32Regs:$d, (OpNode Int16Regs:$a))]>; - def CVTf32i32 : - NVPTXInst<(outs Float32Regs:$d), (ins Int32Regs:$a), - !strconcat(!strconcat("cvt.rn.f32.", OpStr), "32 \t$d, $a;"), - [(set Float32Regs:$d, (OpNode Int32Regs:$a))]>; - def CVTf32i64: - NVPTXInst<(outs Float32Regs:$d), (ins Int64Regs:$a), - !strconcat(!strconcat("cvt.rn.f32.", OpStr), "64 \t$d, $a;"), - [(set Float32Regs:$d, (OpNode Int64Regs:$a))]>; - - def CVTf64i1 : - NVPTXInst<(outs Float64Regs:$d), (ins Int1Regs:$a), - "selp.f64 \t$d, 1.0, 0.0, $a;", - [(set Float64Regs:$d, (OpNode Int1Regs:$a))]>; - def CVTf64i8 : - NVPTXInst<(outs Float64Regs:$d), (ins Int8Regs:$a), - !strconcat(!strconcat("cvt.rn.f64.", OpStr), "8 \t$d, $a;"), - [(set Float64Regs:$d, (OpNode Int8Regs:$a))]>; - def CVTf64i16 : - NVPTXInst<(outs Float64Regs:$d), (ins Int16Regs:$a), - !strconcat(!strconcat("cvt.rn.f64.", OpStr), "16 \t$d, $a;"), - [(set Float64Regs:$d, (OpNode Int16Regs:$a))]>; - def CVTf64i32 : - NVPTXInst<(outs Float64Regs:$d), (ins Int32Regs:$a), - !strconcat(!strconcat("cvt.rn.f64.", OpStr), "32 \t$d, 
$a;"), - [(set Float64Regs:$d, (OpNode Int32Regs:$a))]>; - def CVTf64i64: - NVPTXInst<(outs Float64Regs:$d), (ins Int64Regs:$a), - !strconcat(!strconcat("cvt.rn.f64.", OpStr), "64 \t$d, $a;"), - [(set Float64Regs:$d, (OpNode Int64Regs:$a))]>; -} - -defm Sint_to_fp : CVT_INT_TO_FP <"s", sint_to_fp>; -defm Uint_to_fp : CVT_INT_TO_FP <"u", uint_to_fp>; - -multiclass CVT_FP_TO_INT <string OpStr, SDNode OpNode> { -// FIXME: need to add f16 support -// def CVTi8f16: -// NVPTXInst<(outs Int8Regs:$d), (ins Float16Regs:$a), -// !strconcat(!strconcat("cvt.rzi.", OpStr), "8.f16 $d, $a;"), -// [(set Int8Regs:$d, (OpNode Float16Regs:$a))]>; - def CVTi8f32_ftz: - NVPTXInst<(outs Int8Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.ftz.", OpStr), "16.f32 \t$d, $a;"), - [(set Int8Regs:$d, (OpNode Float32Regs:$a))]>, Requires<[doF32FTZ]>; - def CVTi8f32: - NVPTXInst<(outs Int8Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "16.f32 \t$d, $a;"), - [(set Int8Regs:$d, (OpNode Float32Regs:$a))]>; - def CVTi8f64: - NVPTXInst<(outs Int8Regs:$d), (ins Float64Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "16.f64 \t$d, $a;"), - [(set Int8Regs:$d, (OpNode Float64Regs:$a))]>; - -// FIXME: need to add f16 support -// def CVTi16f16: -// NVPTXInst<(outs Int16Regs:$d), (ins Float16Regs:$a), -// !strconcat(!strconcat("cvt.rzi.", OpStr), "16.f16 \t$d, $a;"), -// [(set Int16Regs:$d, (OpNode Float16Regs:$a))]>; - def CVTi16f32_ftz: - NVPTXInst<(outs Int16Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.ftz.", OpStr), "16.f32 \t$d, $a;"), - [(set Int16Regs:$d, (OpNode Float32Regs:$a))]>, Requires<[doF32FTZ]>; - def CVTi16f32: - NVPTXInst<(outs Int16Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "16.f32 \t$d, $a;"), - [(set Int16Regs:$d, (OpNode Float32Regs:$a))]>; - def CVTi16f64: - NVPTXInst<(outs Int16Regs:$d), (ins Float64Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "16.f64 \t$d, $a;"), - [(set Int16Regs:$d, (OpNode Float64Regs:$a))]>; - -// FIXME: need to add f16 support -// def CVTi32f16: def CVTi32f16: -// NVPTXInst<(outs Int32Regs:$d), (ins Float16Regs:$a), -// !strconcat(!strconcat("cvt.rzi.", OpStr), "32.f16 \t$d, $a;"), -// [(set Int32Regs:$d, (OpNode Float16Regs:$a))]>; - def CVTi32f32_ftz: - NVPTXInst<(outs Int32Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.ftz.", OpStr), "32.f32 \t$d, $a;"), - [(set Int32Regs:$d, (OpNode Float32Regs:$a))]>, Requires<[doF32FTZ]>; - def CVTi32f32: - NVPTXInst<(outs Int32Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "32.f32 \t$d, $a;"), - [(set Int32Regs:$d, (OpNode Float32Regs:$a))]>; - def CVTi32f64: - NVPTXInst<(outs Int32Regs:$d), (ins Float64Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "32.f64 \t$d, $a;"), - [(set Int32Regs:$d, (OpNode Float64Regs:$a))]>; - -// FIXME: need to add f16 support -// def CVTi64f16: -// NVPTXInst<(outs Int64Regs:$d), (ins Float16Regs:$a), -// !strconcat(!strconcat("cvt.rzi.", OpStr), "64.f16 \t$d, $a;"), -// [(set Int64Regs:$d, (OpNode Float16Regs:$a))]>; - def CVTi64f32_ftz: - NVPTXInst<(outs Int64Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.ftz.", OpStr), "64.f32 \t$d, $a;"), - [(set Int64Regs:$d, (OpNode Float32Regs:$a))]>, Requires<[doF32FTZ]>; - def CVTi64f32: - NVPTXInst<(outs Int64Regs:$d), (ins Float32Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "64.f32 \t$d, $a;"), - [(set Int64Regs:$d, (OpNode Float32Regs:$a))]>; - def CVTi64f64: - NVPTXInst<(outs 
Int64Regs:$d), (ins Float64Regs:$a), - !strconcat(!strconcat("cvt.rzi.", OpStr), "64.f64 \t$d, $a;"), - [(set Int64Regs:$d, (OpNode Float64Regs:$a))]>; -} - -defm Fp_to_sint : CVT_FP_TO_INT <"s", fp_to_sint>; -defm Fp_to_uint : CVT_FP_TO_INT <"u", fp_to_uint>; - -multiclass INT_EXTEND_UNSIGNED_1 <SDNode OpNode> { - def ext1to8: - NVPTXInst<(outs Int8Regs:$d), (ins Int1Regs:$a), - "selp.u16 \t$d, 1, 0, $a;", - [(set Int8Regs:$d, (OpNode Int1Regs:$a))]>; - def ext1to16: - NVPTXInst<(outs Int16Regs:$d), (ins Int1Regs:$a), - "selp.u16 \t$d, 1, 0, $a;", - [(set Int16Regs:$d, (OpNode Int1Regs:$a))]>; - def ext1to32: - NVPTXInst<(outs Int32Regs:$d), (ins Int1Regs:$a), - "selp.u32 \t$d, 1, 0, $a;", - [(set Int32Regs:$d, (OpNode Int1Regs:$a))]>; - def ext1to64: - NVPTXInst<(outs Int64Regs:$d), (ins Int1Regs:$a), - "selp.u64 \t$d, 1, 0, $a;", - [(set Int64Regs:$d, (OpNode Int1Regs:$a))]>; -} - -multiclass INT_EXTEND_SIGNED_1 <SDNode OpNode> { - def ext1to8: - NVPTXInst<(outs Int8Regs:$d), (ins Int1Regs:$a), - "selp.s16 \t$d, -1, 0, $a;", - [(set Int8Regs:$d, (OpNode Int1Regs:$a))]>; - def ext1to16: - NVPTXInst<(outs Int16Regs:$d), (ins Int1Regs:$a), - "selp.s16 \t$d, -1, 0, $a;", - [(set Int16Regs:$d, (OpNode Int1Regs:$a))]>; - def ext1to32: - NVPTXInst<(outs Int32Regs:$d), (ins Int1Regs:$a), - "selp.s32 \t$d, -1, 0, $a;", - [(set Int32Regs:$d, (OpNode Int1Regs:$a))]>; - def ext1to64: - NVPTXInst<(outs Int64Regs:$d), (ins Int1Regs:$a), - "selp.s64 \t$d, -1, 0, $a;", - [(set Int64Regs:$d, (OpNode Int1Regs:$a))]>; -} - -multiclass INT_EXTEND <string OpStr, SDNode OpNode> { - // All Int8Regs are emiited as 16bit registers in ptx. - // And there is no selp.u8 in ptx. - def ext8to16: - NVPTXInst<(outs Int16Regs:$d), (ins Int8Regs:$a), - !strconcat("cvt.", !strconcat(OpStr, !strconcat("16.", - !strconcat(OpStr, "8 \t$d, $a;")))), - [(set Int16Regs:$d, (OpNode Int8Regs:$a))]>; - def ext8to32: - NVPTXInst<(outs Int32Regs:$d), (ins Int8Regs:$a), - !strconcat("cvt.", !strconcat(OpStr, !strconcat("32.", - !strconcat(OpStr, "8 \t$d, $a;")))), - [(set Int32Regs:$d, (OpNode Int8Regs:$a))]>; - def ext8to64: - NVPTXInst<(outs Int64Regs:$d), (ins Int8Regs:$a), - !strconcat("cvt.", !strconcat(OpStr, !strconcat("64.", - !strconcat(OpStr, "8 \t$d, $a;")))), - [(set Int64Regs:$d, (OpNode Int8Regs:$a))]>; - def ext16to32: - NVPTXInst<(outs Int32Regs:$d), (ins Int16Regs:$a), - !strconcat("cvt.", !strconcat(OpStr, !strconcat("32.", - !strconcat(OpStr, "16 \t$d, $a;")))), - [(set Int32Regs:$d, (OpNode Int16Regs:$a))]>; - def ext16to64: - NVPTXInst<(outs Int64Regs:$d), (ins Int16Regs:$a), - !strconcat("cvt.", !strconcat(OpStr, !strconcat("64.", - !strconcat(OpStr, "16 \t$d, $a;")))), - [(set Int64Regs:$d, (OpNode Int16Regs:$a))]>; - def ext32to64: - NVPTXInst<(outs Int64Regs:$d), (ins Int32Regs:$a), - !strconcat("cvt.", !strconcat(OpStr, !strconcat("64.", - !strconcat(OpStr, "32 \t$d, $a;")))), - [(set Int64Regs:$d, (OpNode Int32Regs:$a))]>; -} - -defm Sint_extend_1 : INT_EXTEND_SIGNED_1<sext>; -defm Zint_extend_1 : INT_EXTEND_UNSIGNED_1<zext>; -defm Aint_extend_1 : INT_EXTEND_UNSIGNED_1<anyext>; - -defm Sint_extend : INT_EXTEND <"s", sext>; -defm Zint_extend : INT_EXTEND <"u", zext>; -defm Aint_extend : INT_EXTEND <"u", anyext>; - -class TRUNC_to1_asm<string sz> { - string s = !strconcat("{{\n\t", - !strconcat(".reg ", - !strconcat(sz, - !strconcat(" temp;\n\t", - !strconcat("and", - !strconcat(sz, - !strconcat("\t temp, $a, 1;\n\t", - !strconcat("setp", - !strconcat(sz, ".eq \t $d, temp, 1;\n\t}}"))))))))); -} - -def 
TRUNC_64to32 : NVPTXInst<(outs Int32Regs:$d), (ins Int64Regs:$a), - "cvt.u32.u64 \t$d, $a;", - [(set Int32Regs:$d, (trunc Int64Regs:$a))]>; -def TRUNC_64to16 : NVPTXInst<(outs Int16Regs:$d), (ins Int64Regs:$a), - "cvt.u16.u64 \t$d, $a;", - [(set Int16Regs:$d, (trunc Int64Regs:$a))]>; -def TRUNC_64to8 : NVPTXInst<(outs Int8Regs:$d), (ins Int64Regs:$a), - "cvt.u8.u64 \t$d, $a;", - [(set Int8Regs:$d, (trunc Int64Regs:$a))]>; -def TRUNC_32to16 : NVPTXInst<(outs Int16Regs:$d), (ins Int32Regs:$a), - "cvt.u16.u32 \t$d, $a;", - [(set Int16Regs:$d, (trunc Int32Regs:$a))]>; -def TRUNC_32to8 : NVPTXInst<(outs Int8Regs:$d), (ins Int32Regs:$a), - "cvt.u8.u32 \t$d, $a;", - [(set Int8Regs:$d, (trunc Int32Regs:$a))]>; -def TRUNC_16to8 : NVPTXInst<(outs Int8Regs:$d), (ins Int16Regs:$a), - "cvt.u8.u16 \t$d, $a;", - [(set Int8Regs:$d, (trunc Int16Regs:$a))]>; -def TRUNC_64to1 : NVPTXInst<(outs Int1Regs:$d), (ins Int64Regs:$a), - TRUNC_to1_asm<".b64">.s, - [(set Int1Regs:$d, (trunc Int64Regs:$a))]>; -def TRUNC_32to1 : NVPTXInst<(outs Int1Regs:$d), (ins Int32Regs:$a), - TRUNC_to1_asm<".b32">.s, - [(set Int1Regs:$d, (trunc Int32Regs:$a))]>; -def TRUNC_16to1 : NVPTXInst<(outs Int1Regs:$d), (ins Int16Regs:$a), - TRUNC_to1_asm<".b16">.s, - [(set Int1Regs:$d, (trunc Int16Regs:$a))]>; -def TRUNC_8to1 : NVPTXInst<(outs Int1Regs:$d), (ins Int8Regs:$a), - TRUNC_to1_asm<".b16">.s, - [(set Int1Regs:$d, (trunc Int8Regs:$a))]>; - -// Select instructions -def : Pat<(select Int32Regs:$pred, Int8Regs:$a, Int8Regs:$b), - (SELECTi8rr Int8Regs:$a, Int8Regs:$b, (TRUNC_32to1 Int32Regs:$pred))>; -def : Pat<(select Int32Regs:$pred, Int16Regs:$a, Int16Regs:$b), - (SELECTi16rr Int16Regs:$a, Int16Regs:$b, - (TRUNC_32to1 Int32Regs:$pred))>; -def : Pat<(select Int32Regs:$pred, Int32Regs:$a, Int32Regs:$b), - (SELECTi32rr Int32Regs:$a, Int32Regs:$b, - (TRUNC_32to1 Int32Regs:$pred))>; -def : Pat<(select Int32Regs:$pred, Int64Regs:$a, Int64Regs:$b), - (SELECTi64rr Int64Regs:$a, Int64Regs:$b, - (TRUNC_32to1 Int32Regs:$pred))>; -def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b), - (SELECTf32rr Float32Regs:$a, Float32Regs:$b, - (TRUNC_32to1 Int32Regs:$pred))>; -def : Pat<(select Int32Regs:$pred, Float64Regs:$a, Float64Regs:$b), - (SELECTf64rr Float64Regs:$a, Float64Regs:$b, - (TRUNC_32to1 Int32Regs:$pred))>; - class F_BITCONVERT<string SzStr, NVPTXRegClass regclassIn, NVPTXRegClass regclassOut> : NVPTXInst<(outs regclassOut:$d), (ins regclassIn:$a), @@ -2696,29 +2219,209 @@ def BITCONVERT_32_F2I : F_BITCONVERT<"32", Float32Regs, Int32Regs>; def BITCONVERT_64_I2F : F_BITCONVERT<"64", Int64Regs, Float64Regs>; def BITCONVERT_64_F2I : F_BITCONVERT<"64", Float64Regs, Int64Regs>; +// NOTE: pred->fp are currently sub-optimal due to an issue in TableGen where +// we cannot specify floating-point literals in isel patterns. Therefore, we +// use an integer selp to select either 1 or 0 and then cvt to floating-point. 
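For illustration, the selp-plus-cvt lowering described in the note above
would emit PTX along these lines for an i1-to-f32 conversion (a minimal
sketch; the register names %p1, %r1 and %f1 are illustrative only, not
taken from the generated code):

    selp.u32        %r1, 1, 0, %p1;   // materialize the predicate as integer 1 or 0
    cvt.rn.f32.s32  %f1, %r1;         // convert that integer to f32, round-to-nearest-even

The unsigned and f64 variants below differ only in the cvt type suffixes
(e.g. cvt.rn.f64.u32 for uint -> f64).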
+ +// sint -> f32 +def : Pat<(f32 (sint_to_fp Int1Regs:$a)), + (CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f32 (sint_to_fp Int16Regs:$a)), + (CVT_f32_s16 Int16Regs:$a, CvtRN)>; +def : Pat<(f32 (sint_to_fp Int32Regs:$a)), + (CVT_f32_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(f32 (sint_to_fp Int64Regs:$a)), + (CVT_f32_s64 Int64Regs:$a, CvtRN)>; + +// uint -> f32 +def : Pat<(f32 (uint_to_fp Int1Regs:$a)), + (CVT_f32_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f32 (uint_to_fp Int16Regs:$a)), + (CVT_f32_u16 Int16Regs:$a, CvtRN)>; +def : Pat<(f32 (uint_to_fp Int32Regs:$a)), + (CVT_f32_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(f32 (uint_to_fp Int64Regs:$a)), + (CVT_f32_u64 Int64Regs:$a, CvtRN)>; + +// sint -> f64 +def : Pat<(f64 (sint_to_fp Int1Regs:$a)), + (CVT_f64_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f64 (sint_to_fp Int16Regs:$a)), + (CVT_f64_s16 Int16Regs:$a, CvtRN)>; +def : Pat<(f64 (sint_to_fp Int32Regs:$a)), + (CVT_f64_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(f64 (sint_to_fp Int64Regs:$a)), + (CVT_f64_s64 Int64Regs:$a, CvtRN)>; + +// uint -> f64 +def : Pat<(f64 (uint_to_fp Int1Regs:$a)), + (CVT_f64_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(f64 (uint_to_fp Int16Regs:$a)), + (CVT_f64_u16 Int16Regs:$a, CvtRN)>; +def : Pat<(f64 (uint_to_fp Int32Regs:$a)), + (CVT_f64_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(f64 (uint_to_fp Int64Regs:$a)), + (CVT_f64_u64 Int64Regs:$a, CvtRN)>; + + +// f32 -> sint +def : Pat<(i1 (fp_to_sint Float32Regs:$a)), + (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_sint Float32Regs:$a)), + (CVT_s16_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i16 (fp_to_sint Float32Regs:$a)), + (CVT_s16_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_sint Float32Regs:$a)), + (CVT_s32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i32 (fp_to_sint Float32Regs:$a)), + (CVT_s32_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_sint Float32Regs:$a)), + (CVT_s64_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i64 (fp_to_sint Float32Regs:$a)), + (CVT_s64_f32 Float32Regs:$a, CvtRZI)>; + +// f32 -> uint +def : Pat<(i1 (fp_to_uint Float32Regs:$a)), + (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_uint Float32Regs:$a)), + (CVT_u16_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i16 (fp_to_uint Float32Regs:$a)), + (CVT_u16_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_uint Float32Regs:$a)), + (CVT_u32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i32 (fp_to_uint Float32Regs:$a)), + (CVT_u32_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_uint Float32Regs:$a)), + (CVT_u64_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i64 (fp_to_uint Float32Regs:$a)), + (CVT_u64_f32 Float32Regs:$a, CvtRZI)>; + +// f64 -> sint +def : Pat<(i1 (fp_to_sint Float64Regs:$a)), + (SETP_b64ri (BITCONVERT_64_F2I Float64Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_sint Float64Regs:$a)), + (CVT_s16_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_sint Float64Regs:$a)), + (CVT_s32_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_sint Float64Regs:$a)), + (CVT_s64_f64 Float64Regs:$a, CvtRZI)>; + +// f64 -> uint +def : Pat<(i1 (fp_to_uint Float64Regs:$a)), + (SETP_b64ri (BITCONVERT_64_F2I Float64Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_uint Float64Regs:$a)), + (CVT_u16_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_uint Float64Regs:$a)), + 
(CVT_u32_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_uint Float64Regs:$a)), + (CVT_u64_f64 Float64Regs:$a, CvtRZI)>; + +// sext i1 +def : Pat<(i16 (sext Int1Regs:$a)), + (SELP_s16ii -1, 0, Int1Regs:$a)>; +def : Pat<(i32 (sext Int1Regs:$a)), + (SELP_s32ii -1, 0, Int1Regs:$a)>; +def : Pat<(i64 (sext Int1Regs:$a)), + (SELP_s64ii -1, 0, Int1Regs:$a)>; + +// zext i1 +def : Pat<(i16 (zext Int1Regs:$a)), + (SELP_u16ii 1, 0, Int1Regs:$a)>; +def : Pat<(i32 (zext Int1Regs:$a)), + (SELP_u32ii 1, 0, Int1Regs:$a)>; +def : Pat<(i64 (zext Int1Regs:$a)), + (SELP_u64ii 1, 0, Int1Regs:$a)>; + +// anyext i1 +def : Pat<(i16 (anyext Int1Regs:$a)), + (SELP_u16ii -1, 0, Int1Regs:$a)>; +def : Pat<(i32 (anyext Int1Regs:$a)), + (SELP_u32ii -1, 0, Int1Regs:$a)>; +def : Pat<(i64 (anyext Int1Regs:$a)), + (SELP_u64ii -1, 0, Int1Regs:$a)>; + +// sext i16 +def : Pat<(i32 (sext Int16Regs:$a)), + (CVT_s32_s16 Int16Regs:$a, CvtNONE)>; +def : Pat<(i64 (sext Int16Regs:$a)), + (CVT_s64_s16 Int16Regs:$a, CvtNONE)>; + +// zext i16 +def : Pat<(i32 (zext Int16Regs:$a)), + (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; +def : Pat<(i64 (zext Int16Regs:$a)), + (CVT_u64_u16 Int16Regs:$a, CvtNONE)>; + +// anyext i16 +def : Pat<(i32 (anyext Int16Regs:$a)), + (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; +def : Pat<(i64 (anyext Int16Regs:$a)), + (CVT_u64_u16 Int16Regs:$a, CvtNONE)>; + +// sext i32 +def : Pat<(i64 (sext Int32Regs:$a)), + (CVT_s64_s32 Int32Regs:$a, CvtNONE)>; + +// zext i32 +def : Pat<(i64 (zext Int32Regs:$a)), + (CVT_u64_u32 Int32Regs:$a, CvtNONE)>; + +// anyext i32 +def : Pat<(i64 (anyext Int32Regs:$a)), + (CVT_u64_u32 Int32Regs:$a, CvtNONE)>; + + +// truncate i64 +def : Pat<(i32 (trunc Int64Regs:$a)), + (CVT_u32_u64 Int64Regs:$a, CvtNONE)>; +def : Pat<(i16 (trunc Int64Regs:$a)), + (CVT_u16_u64 Int64Regs:$a, CvtNONE)>; +def : Pat<(i1 (trunc Int64Regs:$a)), + (SETP_b64ri (ANDb64ri Int64Regs:$a, 1), 1, CmpEQ)>; + +// truncate i32 +def : Pat<(i16 (trunc Int32Regs:$a)), + (CVT_u16_u32 Int32Regs:$a, CvtNONE)>; +def : Pat<(i1 (trunc Int32Regs:$a)), + (SETP_b32ri (ANDb32ri Int32Regs:$a, 1), 1, CmpEQ)>; + +// truncate i16 +def : Pat<(i1 (trunc Int16Regs:$a)), + (SETP_b16ri (ANDb16ri Int16Regs:$a, 1), 1, CmpEQ)>; + +// sext_inreg +def : Pat<(sext_inreg Int16Regs:$a, i8), (CVT_INREG_s16_s8 Int16Regs:$a)>; +def : Pat<(sext_inreg Int32Regs:$a, i8), (CVT_INREG_s32_s8 Int32Regs:$a)>; +def : Pat<(sext_inreg Int32Regs:$a, i16), (CVT_INREG_s32_s16 Int32Regs:$a)>; +def : Pat<(sext_inreg Int64Regs:$a, i8), (CVT_INREG_s64_s8 Int64Regs:$a)>; +def : Pat<(sext_inreg Int64Regs:$a, i16), (CVT_INREG_s64_s16 Int64Regs:$a)>; +def : Pat<(sext_inreg Int64Regs:$a, i32), (CVT_INREG_s64_s32 Int64Regs:$a)>; + + +// Select instructions with 32-bit predicates +def : Pat<(select Int32Regs:$pred, Int16Regs:$a, Int16Regs:$b), + (SELP_b16rr Int16Regs:$a, Int16Regs:$b, + (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; +def : Pat<(select Int32Regs:$pred, Int32Regs:$a, Int32Regs:$b), + (SELP_b32rr Int32Regs:$a, Int32Regs:$b, + (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; +def : Pat<(select Int32Regs:$pred, Int64Regs:$a, Int64Regs:$b), + (SELP_b64rr Int64Regs:$a, Int64Regs:$b, + (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; +def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b), + (SELP_f32rr Float32Regs:$a, Float32Regs:$b, + (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; +def : Pat<(select Int32Regs:$pred, Float64Regs:$a, Float64Regs:$b), + (SELP_f64rr Float64Regs:$a, Float64Regs:$b, + (SETP_b32ri (ANDb32ri 
Int32Regs:$pred, 1), 1, CmpEQ))>; + + // pack a set of smaller int registers to a larger int register -def V4I8toI32 : NVPTXInst<(outs Int32Regs:$d), - (ins Int8Regs:$s1, Int8Regs:$s2, - Int8Regs:$s3, Int8Regs:$s4), - !strconcat("{{\n\t.reg .b8\t%t<4>;", - !strconcat("\n\tcvt.u8.u8\t%t0, $s1;", - !strconcat("\n\tcvt.u8.u8\t%t1, $s2;", - !strconcat("\n\tcvt.u8.u8\t%t2, $s3;", - !strconcat("\n\tcvt.u8.u8\t%t3, $s4;", - "\n\tmov.b32\t$d, {%t0, %t1, %t2, %t3};\n\t}}"))))), - []>; def V4I16toI64 : NVPTXInst<(outs Int64Regs:$d), (ins Int16Regs:$s1, Int16Regs:$s2, Int16Regs:$s3, Int16Regs:$s4), "mov.b64\t$d, {{$s1, $s2, $s3, $s4}};", []>; -def V2I8toI16 : NVPTXInst<(outs Int16Regs:$d), - (ins Int8Regs:$s1, Int8Regs:$s2), - !strconcat("{{\n\t.reg .b8\t%t<2>;", - !strconcat("\n\tcvt.u8.u8\t%t0, $s1;", - !strconcat("\n\tcvt.u8.u8\t%t1, $s2;", - "\n\tmov.b16\t$d, {%t0, %t1};\n\t}}"))), - []>; def V2I16toI32 : NVPTXInst<(outs Int32Regs:$d), (ins Int16Regs:$s1, Int16Regs:$s2), "mov.b32\t$d, {{$s1, $s2}};", @@ -2733,28 +2436,11 @@ def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d), []>; // unpack a larger int register to a set of smaller int registers -def I32toV4I8 : NVPTXInst<(outs Int8Regs:$d1, Int8Regs:$d2, - Int8Regs:$d3, Int8Regs:$d4), - (ins Int32Regs:$s), - !strconcat("{{\n\t.reg .b8\t%t<4>;", - !strconcat("\n\tmov.b32\t{%t0, %t1, %t2, %t3}, $s;", - !strconcat("\n\tcvt.u8.u8\t$d1, %t0;", - !strconcat("\n\tcvt.u8.u8\t$d2, %t1;", - !strconcat("\n\tcvt.u8.u8\t$d3, %t2;", - "\n\tcvt.u8.u8\t$d4, %t3;\n\t}}"))))), - []>; def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2, Int16Regs:$d3, Int16Regs:$d4), (ins Int64Regs:$s), "mov.b64\t{{$d1, $d2, $d3, $d4}}, $s;", []>; -def I16toV2I8 : NVPTXInst<(outs Int8Regs:$d1, Int8Regs:$d2), - (ins Int16Regs:$s), - !strconcat("{{\n\t.reg .b8\t%t<2>;", - !strconcat("\n\tmov.b16\t{%t0, %t1}, $s;", - !strconcat("\n\tcvt.u8.u8\t$d1, %t0;", - "\n\tcvt.u8.u8\t$d2, %t1;\n\t}}"))), - []>; def I32toV2I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2), (ins Int32Regs:$s), "mov.b32\t{{$d1, $d2}}, $s;", @@ -2768,21 +2454,75 @@ def F64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2), "mov.b64\t{{$d1, $d2}}, $s;", []>; -def FPRound_ftz : NVPTXInst<(outs Float32Regs:$d), (ins Float64Regs:$a), - "cvt.rn.ftz.f32.f64 \t$d, $a;", - [(set Float32Regs:$d, (fround Float64Regs:$a))]>, Requires<[doF32FTZ]>; - -def FPRound : NVPTXInst<(outs Float32Regs:$d), (ins Float64Regs:$a), - "cvt.rn.f32.f64 \t$d, $a;", - [(set Float32Regs:$d, (fround Float64Regs:$a))]>; - -def FPExtend_ftz : NVPTXInst<(outs Float64Regs:$d), (ins Float32Regs:$a), - "cvt.ftz.f64.f32 \t$d, $a;", - [(set Float64Regs:$d, (fextend Float32Regs:$a))]>, Requires<[doF32FTZ]>; - -def FPExtend : NVPTXInst<(outs Float64Regs:$d), (ins Float32Regs:$a), - "cvt.f64.f32 \t$d, $a;", - [(set Float64Regs:$d, (fextend Float32Regs:$a))]>; +// Count leading zeros +def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), + "clz.b32\t$d, $a;", + []>; +def CLZr64 : NVPTXInst<(outs Int32Regs:$d), (ins Int64Regs:$a), + "clz.b64\t$d, $a;", + []>; + +// 32-bit has a direct PTX instruction +def : Pat<(ctlz Int32Regs:$a), + (CLZr32 Int32Regs:$a)>; +def : Pat<(ctlz_zero_undef Int32Regs:$a), + (CLZr32 Int32Regs:$a)>; + +// For 64-bit, the result in PTX is actually 32-bit so we zero-extend +// to 64-bit to match the LLVM semantics +def : Pat<(ctlz Int64Regs:$a), + (CVT_u64_u32 (CLZr64 Int64Regs:$a), CvtNONE)>; +def : Pat<(ctlz_zero_undef Int64Regs:$a), + (CVT_u64_u32 (CLZr64 Int64Regs:$a), CvtNONE)>; + +// For 16-bit, we 
zero-extend to 32-bit, then trunc the result back +// to 16-bits (ctlz of a 16-bit value is guaranteed to require less +// than 16 bits to store). We also need to subtract 16 because the +// high-order 16 zeros were counted. +def : Pat<(ctlz Int16Regs:$a), + (SUBi16ri (CVT_u16_u32 (CLZr32 + (CVT_u32_u16 Int16Regs:$a, CvtNONE)), + CvtNONE), 16)>; +def : Pat<(ctlz_zero_undef Int16Regs:$a), + (SUBi16ri (CVT_u16_u32 (CLZr32 + (CVT_u32_u16 Int16Regs:$a, CvtNONE)), + CvtNONE), 16)>; + +// Population count +def POPCr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), + "popc.b32\t$d, $a;", + []>; +def POPCr64 : NVPTXInst<(outs Int32Regs:$d), (ins Int64Regs:$a), + "popc.b64\t$d, $a;", + []>; + +// 32-bit has a direct PTX instruction +def : Pat<(ctpop Int32Regs:$a), + (POPCr32 Int32Regs:$a)>; + +// For 64-bit, the result in PTX is actually 32-bit so we zero-extend +// to 64-bit to match the LLVM semantics +def : Pat<(ctpop Int64Regs:$a), + (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>; + +// For 16-bit, we zero-extend to 32-bit, then trunc the result back +// to 16-bits (ctpop of a 16-bit value is guaranteed to require less +// than 16 bits to store) +def : Pat<(ctpop Int16Regs:$a), + (CVT_u16_u32 (POPCr32 (CVT_u32_u16 Int16Regs:$a, CvtNONE)), + CvtNONE)>; + +// fround f64 -> f32 +def : Pat<(f32 (fround Float64Regs:$a)), + (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f32 (fround Float64Regs:$a)), + (CVT_f32_f64 Float64Regs:$a, CvtRN)>; + +// fextend f32 -> f64 +def : Pat<(f64 (fextend Float32Regs:$a)), + (CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f64 (fextend Float32Regs:$a)), + (CVT_f64_f32 Float32Regs:$a, CvtNONE)>; def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone, [SDNPHasChain, SDNPOptInGlue]>; @@ -2810,8 +2550,8 @@ let isTerminator=1 in { [(br bb:$target)]>; } -def : Pat<(brcond Int32Regs:$a, bb:$target), (CBranch - (ISetUNEi32ri_p Int32Regs:$a, 0), bb:$target)>; +def : Pat<(brcond Int32Regs:$a, bb:$target), + (CBranch (SETP_u32ri Int32Regs:$a, 0, CmpNE), bb:$target)>; // SelectionDAGBuilder::visitSWitchCase() will invert the condition of a // conditional branch if @@ -2867,6 +2607,20 @@ def trapinst : NVPTXInst<(outs), (ins), "trap;", [(trap)]>; +// Call prototype wrapper +def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>; +def CallPrototype + : SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def ProtoIdent : Operand<i32> { + let PrintMethod = "printProtoIdent"; +} +def CALL_PROTOTYPE + : NVPTXInst<(outs), (ins ProtoIdent:$ident), + "$ident", [(CallPrototype (i32 texternalsym:$ident))]>; + + + include "NVPTXIntrinsics.td" diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/contrib/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 24037ca..14049b1 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -82,49 +82,36 @@ def INT_MEMBAR_SYS : MEMBAR<"membar.sys;", int_nvvm_membar_sys>; //----------------------------------- // Map min(1.0, max(0.0, x)) to sat(x) -multiclass SAT<NVPTXRegClass regclass, Operand fimm, Intrinsic IntMinOp, - Intrinsic IntMaxOp, PatLeaf f0, PatLeaf f1, string OpStr> { - - // fmin(1.0, fmax(0.0, x)) => sat(x) - def SAT11 : NVPTXInst<(outs regclass:$dst), - (ins fimm:$srcf0, fimm:$srcf1, regclass:$src), - OpStr, - [(set regclass:$dst, (IntMinOp f1:$srcf0 , - (IntMaxOp f0:$srcf1, regclass:$src)))]>; - - // fmin(1.0, fmax(x, 0.0)) => sat(x) - def SAT12 : 
NVPTXInst<(outs regclass:$dst), - (ins fimm:$srcf0, fimm:$srcf1, regclass:$src), - OpStr, - [(set regclass:$dst, (IntMinOp f1:$srcf0 , - (IntMaxOp regclass:$src, f0:$srcf1)))]>; - - // fmin(fmax(0.0, x), 1.0) => sat(x) - def SAT13 : NVPTXInst<(outs regclass:$dst), - (ins fimm:$srcf0, fimm:$srcf1, regclass:$src), - OpStr, - [(set regclass:$dst, (IntMinOp - (IntMaxOp f0:$srcf0, regclass:$src), f1:$srcf1))]>; - - // fmin(fmax(x, 0.0), 1.0) => sat(x) - def SAT14 : NVPTXInst<(outs regclass:$dst), - (ins fimm:$srcf0, fimm:$srcf1, regclass:$src), - OpStr, - [(set regclass:$dst, (IntMinOp - (IntMaxOp regclass:$src, f0:$srcf0), f1:$srcf1))]>; - -} -// Note that max(0.0, min(x, 1.0)) cannot be mapped to sat(x) because when x -// is NaN +// Note that max(0.0, min(x, 1.0)) cannot be mapped to sat(x) because when x is +// NaN // max(0.0, min(x, 1.0)) is 1.0 while sat(x) is 0. // Same story for fmax, fmin. -defm SAT_fmin_fmax_f : SAT<Float32Regs, f32imm, int_nvvm_fmin_f, - int_nvvm_fmax_f, immFloat0, immFloat1, - "cvt.sat.f32.f32 \t$dst, $src; \n">; -defm SAT_fmin_fmax_d : SAT<Float64Regs, f64imm, int_nvvm_fmin_d, - int_nvvm_fmax_d, immDouble0, immDouble1, - "cvt.sat.f64.f64 \t$dst, $src; \n">; +def : Pat<(int_nvvm_fmin_f immFloat1, + (int_nvvm_fmax_f immFloat0, Float32Regs:$a)), + (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; +def : Pat<(int_nvvm_fmin_f immFloat1, + (int_nvvm_fmax_f Float32Regs:$a, immFloat0)), + (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; +def : Pat<(int_nvvm_fmin_f + (int_nvvm_fmax_f immFloat0, Float32Regs:$a), immFloat1), + (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; +def : Pat<(int_nvvm_fmin_f + (int_nvvm_fmax_f Float32Regs:$a, immFloat0), immFloat1), + (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; + +def : Pat<(int_nvvm_fmin_d immDouble1, + (int_nvvm_fmax_d immDouble0, Float64Regs:$a)), + (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; +def : Pat<(int_nvvm_fmin_d immDouble1, + (int_nvvm_fmax_d Float64Regs:$a, immDouble0)), + (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; +def : Pat<(int_nvvm_fmin_d + (int_nvvm_fmax_d immDouble0, Float64Regs:$a), immDouble1), + (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; +def : Pat<(int_nvvm_fmin_d + (int_nvvm_fmax_d Float64Regs:$a, immDouble0), immDouble1), + (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; // We need a full string for OpcStr here because we need to deal with case like @@ -312,19 +299,19 @@ def INT_NVVM_SAD_UI : F_MATH_3<"sad.u32 \t$dst, $src0, $src1, $src2;", // Floor Ceil // -def INT_NVVM_FLOOR_FTZ_F : F_MATH_1<"cvt.rmi.ftz.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_floor_ftz_f>; -def INT_NVVM_FLOOR_F : F_MATH_1<"cvt.rmi.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_floor_f>; -def INT_NVVM_FLOOR_D : F_MATH_1<"cvt.rmi.f64.f64 \t$dst, $src0;", - Float64Regs, Float64Regs, int_nvvm_floor_d>; +def : Pat<(int_nvvm_floor_ftz_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>; +def : Pat<(int_nvvm_floor_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_floor_d Float64Regs:$a), + (CVT_f64_f64 Float64Regs:$a, CvtRMI)>; -def INT_NVVM_CEIL_FTZ_F : F_MATH_1<"cvt.rpi.ftz.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_ceil_ftz_f>; -def INT_NVVM_CEIL_F : F_MATH_1<"cvt.rpi.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_ceil_f>; -def INT_NVVM_CEIL_D : F_MATH_1<"cvt.rpi.f64.f64 \t$dst, $src0;", - Float64Regs, Float64Regs, int_nvvm_ceil_d>; +def : Pat<(int_nvvm_ceil_ftz_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>; +def : Pat<(int_nvvm_ceil_f Float32Regs:$a), + (CVT_f32_f32 
Float32Regs:$a, CvtRPI)>; +def : Pat<(int_nvvm_ceil_d Float64Regs:$a), + (CVT_f64_f64 Float64Regs:$a, CvtRPI)>; // // Abs @@ -347,37 +334,34 @@ def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs, // Round // -def INT_NVVM_ROUND_FTZ_F : F_MATH_1<"cvt.rni.ftz.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_round_ftz_f>; -def INT_NVVM_ROUND_F : F_MATH_1<"cvt.rni.f32.f32 \t$dst, $src0;", Float32Regs, - Float32Regs, int_nvvm_round_f>; - -def INT_NVVM_ROUND_D : F_MATH_1<"cvt.rni.f64.f64 \t$dst, $src0;", Float64Regs, - Float64Regs, int_nvvm_round_d>; +def : Pat<(int_nvvm_round_ftz_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>; +def : Pat<(int_nvvm_round_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_round_d Float64Regs:$a), + (CVT_f64_f64 Float64Regs:$a, CvtRNI)>; // // Trunc // -def INT_NVVM_TRUNC_FTZ_F : F_MATH_1<"cvt.rzi.ftz.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_trunc_ftz_f>; -def INT_NVVM_TRUNC_F : F_MATH_1<"cvt.rzi.f32.f32 \t$dst, $src0;", Float32Regs, - Float32Regs, int_nvvm_trunc_f>; - -def INT_NVVM_TRUNC_D : F_MATH_1<"cvt.rzi.f64.f64 \t$dst, $src0;", Float64Regs, - Float64Regs, int_nvvm_trunc_d>; +def : Pat<(int_nvvm_trunc_ftz_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>; +def : Pat<(int_nvvm_trunc_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_trunc_d Float64Regs:$a), + (CVT_f64_f64 Float64Regs:$a, CvtRZI)>; // // Saturate // -def INT_NVVM_SATURATE_FTZ_F : F_MATH_1<"cvt.sat.ftz.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_saturate_ftz_f>; -def INT_NVVM_SATURATE_F : F_MATH_1<"cvt.sat.f32.f32 \t$dst, $src0;", - Float32Regs, Float32Regs, int_nvvm_saturate_f>; - -def INT_NVVM_SATURATE_D : F_MATH_1<"cvt.sat.f64.f64 \t$dst, $src0;", - Float64Regs, Float64Regs, int_nvvm_saturate_d>; +def : Pat<(int_nvvm_saturate_ftz_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtSAT_FTZ)>; +def : Pat<(int_nvvm_saturate_f Float32Regs:$a), + (CVT_f32_f32 Float32Regs:$a, CvtSAT)>; +def : Pat<(int_nvvm_saturate_d Float64Regs:$a), + (CVT_f64_f64 Float64Regs:$a, CvtSAT)>; // // Exp2 Log2 @@ -568,110 +552,110 @@ def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64 \t$dst, $src0, $src1;", // Convert // -def INT_NVVM_D2F_RN_FTZ : F_MATH_1<"cvt.rn.ftz.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rn_ftz>; -def INT_NVVM_D2F_RN : F_MATH_1<"cvt.rn.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rn>; -def INT_NVVM_D2F_RZ_FTZ : F_MATH_1<"cvt.rz.ftz.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rz_ftz>; -def INT_NVVM_D2F_RZ : F_MATH_1<"cvt.rz.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rz>; -def INT_NVVM_D2F_RM_FTZ : F_MATH_1<"cvt.rm.ftz.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rm_ftz>; -def INT_NVVM_D2F_RM : F_MATH_1<"cvt.rm.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rm>; -def INT_NVVM_D2F_RP_FTZ : F_MATH_1<"cvt.rp.ftz.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rp_ftz>; -def INT_NVVM_D2F_RP : F_MATH_1<"cvt.rp.f32.f64 \t$dst, $src0;", - Float32Regs, Float64Regs, int_nvvm_d2f_rp>; - -def INT_NVVM_D2I_RN : F_MATH_1<"cvt.rni.s32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2i_rn>; -def INT_NVVM_D2I_RZ : F_MATH_1<"cvt.rzi.s32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2i_rz>; -def INT_NVVM_D2I_RM : F_MATH_1<"cvt.rmi.s32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2i_rm>; 
-def INT_NVVM_D2I_RP : F_MATH_1<"cvt.rpi.s32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2i_rp>; - -def INT_NVVM_D2UI_RN : F_MATH_1<"cvt.rni.u32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2ui_rn>; -def INT_NVVM_D2UI_RZ : F_MATH_1<"cvt.rzi.u32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2ui_rz>; -def INT_NVVM_D2UI_RM : F_MATH_1<"cvt.rmi.u32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2ui_rm>; -def INT_NVVM_D2UI_RP : F_MATH_1<"cvt.rpi.u32.f64 \t$dst, $src0;", - Int32Regs, Float64Regs, int_nvvm_d2ui_rp>; - -def INT_NVVM_I2D_RN : F_MATH_1<"cvt.rn.f64.s32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_i2d_rn>; -def INT_NVVM_I2D_RZ : F_MATH_1<"cvt.rz.f64.s32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_i2d_rz>; -def INT_NVVM_I2D_RM : F_MATH_1<"cvt.rm.f64.s32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_i2d_rm>; -def INT_NVVM_I2D_RP : F_MATH_1<"cvt.rp.f64.s32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_i2d_rp>; - -def INT_NVVM_UI2D_RN : F_MATH_1<"cvt.rn.f64.u32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_ui2d_rn>; -def INT_NVVM_UI2D_RZ : F_MATH_1<"cvt.rz.f64.u32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_ui2d_rz>; -def INT_NVVM_UI2D_RM : F_MATH_1<"cvt.rm.f64.u32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_ui2d_rm>; -def INT_NVVM_UI2D_RP : F_MATH_1<"cvt.rp.f64.u32 \t$dst, $src0;", - Float64Regs, Int32Regs, int_nvvm_ui2d_rp>; - -def INT_NVVM_F2I_RN_FTZ : F_MATH_1<"cvt.rni.ftz.s32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2i_rn_ftz>; -def INT_NVVM_F2I_RN : F_MATH_1<"cvt.rni.s32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2i_rn>; -def INT_NVVM_F2I_RZ_FTZ : F_MATH_1<"cvt.rzi.ftz.s32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2i_rz_ftz>; -def INT_NVVM_F2I_RZ : F_MATH_1<"cvt.rzi.s32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2i_rz>; -def INT_NVVM_F2I_RM_FTZ : F_MATH_1<"cvt.rmi.ftz.s32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2i_rm_ftz>; -def INT_NVVM_F2I_RM : F_MATH_1<"cvt.rmi.s32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2i_rm>; -def INT_NVVM_F2I_RP_FTZ : F_MATH_1<"cvt.rpi.ftz.s32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2i_rp_ftz>; -def INT_NVVM_F2I_RP : F_MATH_1<"cvt.rpi.s32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2i_rp>; - -def INT_NVVM_F2UI_RN_FTZ : F_MATH_1<"cvt.rni.ftz.u32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2ui_rn_ftz>; -def INT_NVVM_F2UI_RN : F_MATH_1<"cvt.rni.u32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2ui_rn>; -def INT_NVVM_F2UI_RZ_FTZ : F_MATH_1<"cvt.rzi.ftz.u32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2ui_rz_ftz>; -def INT_NVVM_F2UI_RZ : F_MATH_1<"cvt.rzi.u32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2ui_rz>; -def INT_NVVM_F2UI_RM_FTZ : F_MATH_1<"cvt.rmi.ftz.u32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2ui_rm_ftz>; -def INT_NVVM_F2UI_RM : F_MATH_1<"cvt.rmi.u32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2ui_rm>; -def INT_NVVM_F2UI_RP_FTZ : F_MATH_1<"cvt.rpi.ftz.u32.f32 \t$dst, $src0;", - Int32Regs, Float32Regs, int_nvvm_f2ui_rp_ftz>; -def INT_NVVM_F2UI_RP : F_MATH_1<"cvt.rpi.u32.f32 \t$dst, $src0;", Int32Regs, - Float32Regs, int_nvvm_f2ui_rp>; - -def INT_NVVM_I2F_RN : F_MATH_1<"cvt.rn.f32.s32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_i2f_rn>; -def INT_NVVM_I2F_RZ : F_MATH_1<"cvt.rz.f32.s32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_i2f_rz>; -def 
INT_NVVM_I2F_RM : F_MATH_1<"cvt.rm.f32.s32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_i2f_rm>; -def INT_NVVM_I2F_RP : F_MATH_1<"cvt.rp.f32.s32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_i2f_rp>; - -def INT_NVVM_UI2F_RN : F_MATH_1<"cvt.rn.f32.u32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_ui2f_rn>; -def INT_NVVM_UI2F_RZ : F_MATH_1<"cvt.rz.f32.u32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_ui2f_rz>; -def INT_NVVM_UI2F_RM : F_MATH_1<"cvt.rm.f32.u32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_ui2f_rm>; -def INT_NVVM_UI2F_RP : F_MATH_1<"cvt.rp.f32.u32 \t$dst, $src0;", Float32Regs, - Int32Regs, int_nvvm_ui2f_rp>; +def : Pat<(int_nvvm_d2f_rn_ftz Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>; +def : Pat<(int_nvvm_d2f_rn Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_d2f_rz_ftz Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRZ_FTZ)>; +def : Pat<(int_nvvm_d2f_rz Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_d2f_rm_ftz Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRM_FTZ)>; +def : Pat<(int_nvvm_d2f_rm Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_d2f_rp_ftz Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRP_FTZ)>; +def : Pat<(int_nvvm_d2f_rp Float64Regs:$a), + (CVT_f32_f64 Float64Regs:$a, CvtRP)>; + +def : Pat<(int_nvvm_d2i_rn Float64Regs:$a), + (CVT_s32_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_d2i_rz Float64Regs:$a), + (CVT_s32_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_d2i_rm Float64Regs:$a), + (CVT_s32_f64 Float64Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_d2i_rp Float64Regs:$a), + (CVT_s32_f64 Float64Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_d2ui_rn Float64Regs:$a), + (CVT_u32_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_d2ui_rz Float64Regs:$a), + (CVT_u32_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_d2ui_rm Float64Regs:$a), + (CVT_u32_f64 Float64Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_d2ui_rp Float64Regs:$a), + (CVT_u32_f64 Float64Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_i2d_rn Int32Regs:$a), + (CVT_f64_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_i2d_rz Int32Regs:$a), + (CVT_f64_s32 Int32Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_i2d_rm Int32Regs:$a), + (CVT_f64_s32 Int32Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_i2d_rp Int32Regs:$a), + (CVT_f64_s32 Int32Regs:$a, CvtRP)>; + +def : Pat<(int_nvvm_ui2d_rn Int32Regs:$a), + (CVT_f64_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_ui2d_rz Int32Regs:$a), + (CVT_f64_u32 Int32Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_ui2d_rm Int32Regs:$a), + (CVT_f64_u32 Int32Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_ui2d_rp Int32Regs:$a), + (CVT_f64_u32 Int32Regs:$a, CvtRP)>; + +def : Pat<(int_nvvm_f2i_rn_ftz Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRNI_FTZ)>; +def : Pat<(int_nvvm_f2i_rn Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_f2i_rz_ftz Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRZI_FTZ)>; +def : Pat<(int_nvvm_f2i_rz Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_f2i_rm_ftz Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRMI_FTZ)>; +def : Pat<(int_nvvm_f2i_rm Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_f2i_rp_ftz Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRPI_FTZ)>; +def : Pat<(int_nvvm_f2i_rp Float32Regs:$a), + (CVT_s32_f32 Float32Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_f2ui_rn_ftz Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRNI_FTZ)>; +def : 
Pat<(int_nvvm_f2ui_rn Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_f2ui_rz_ftz Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRZI_FTZ)>; +def : Pat<(int_nvvm_f2ui_rz Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_f2ui_rm_ftz Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRMI_FTZ)>; +def : Pat<(int_nvvm_f2ui_rm Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_f2ui_rp_ftz Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRPI_FTZ)>; +def : Pat<(int_nvvm_f2ui_rp Float32Regs:$a), + (CVT_u32_f32 Float32Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_i2f_rn Int32Regs:$a), + (CVT_f32_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_i2f_rz Int32Regs:$a), + (CVT_f32_s32 Int32Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_i2f_rm Int32Regs:$a), + (CVT_f32_s32 Int32Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_i2f_rp Int32Regs:$a), + (CVT_f32_s32 Int32Regs:$a, CvtRP)>; + +def : Pat<(int_nvvm_ui2f_rn Int32Regs:$a), + (CVT_f32_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_ui2f_rz Int32Regs:$a), + (CVT_f32_u32 Int32Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_ui2f_rm Int32Regs:$a), + (CVT_f32_u32 Int32Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_ui2f_rp Int32Regs:$a), + (CVT_f32_u32 Int32Regs:$a, CvtRP)>; def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};", Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>; @@ -687,91 +671,106 @@ def INT_NVVM_D2I_HI : F_MATH_1<!strconcat("{{\n\t", "}}"))), Int32Regs, Float64Regs, int_nvvm_d2i_hi>; -def INT_NVVM_F2LL_RN_FTZ : F_MATH_1<"cvt.rni.ftz.s64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ll_rn_ftz>; -def INT_NVVM_F2LL_RN : F_MATH_1<"cvt.rni.s64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ll_rn>; -def INT_NVVM_F2LL_RZ_FTZ : F_MATH_1<"cvt.rzi.ftz.s64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ll_rz_ftz>; -def INT_NVVM_F2LL_RZ : F_MATH_1<"cvt.rzi.s64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ll_rz>; -def INT_NVVM_F2LL_RM_FTZ : F_MATH_1<"cvt.rmi.ftz.s64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ll_rm_ftz>; -def INT_NVVM_F2LL_RM : F_MATH_1<"cvt.rmi.s64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ll_rm>; -def INT_NVVM_F2LL_RP_FTZ : F_MATH_1<"cvt.rpi.ftz.s64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ll_rp_ftz>; -def INT_NVVM_F2LL_RP : F_MATH_1<"cvt.rpi.s64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ll_rp>; - -def INT_NVVM_F2ULL_RN_FTZ : F_MATH_1<"cvt.rni.ftz.u64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ull_rn_ftz>; -def INT_NVVM_F2ULL_RN : F_MATH_1<"cvt.rni.u64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ull_rn>; -def INT_NVVM_F2ULL_RZ_FTZ : F_MATH_1<"cvt.rzi.ftz.u64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ull_rz_ftz>; -def INT_NVVM_F2ULL_RZ : F_MATH_1<"cvt.rzi.u64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ull_rz>; -def INT_NVVM_F2ULL_RM_FTZ : F_MATH_1<"cvt.rmi.ftz.u64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ull_rm_ftz>; -def INT_NVVM_F2ULL_RM : F_MATH_1<"cvt.rmi.u64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ull_rm>; -def INT_NVVM_F2ULL_RP_FTZ : F_MATH_1<"cvt.rpi.ftz.u64.f32 \t$dst, $src0;", - Int64Regs, Float32Regs, int_nvvm_f2ull_rp_ftz>; -def INT_NVVM_F2ULL_RP : F_MATH_1<"cvt.rpi.u64.f32 \t$dst, $src0;", Int64Regs, - Float32Regs, int_nvvm_f2ull_rp>; - -def INT_NVVM_D2LL_RN : F_MATH_1<"cvt.rni.s64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ll_rn>; 
-def INT_NVVM_D2LL_RZ : F_MATH_1<"cvt.rzi.s64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ll_rz>; -def INT_NVVM_D2LL_RM : F_MATH_1<"cvt.rmi.s64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ll_rm>; -def INT_NVVM_D2LL_RP : F_MATH_1<"cvt.rpi.s64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ll_rp>; - -def INT_NVVM_D2ULL_RN : F_MATH_1<"cvt.rni.u64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ull_rn>; -def INT_NVVM_D2ULL_RZ : F_MATH_1<"cvt.rzi.u64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ull_rz>; -def INT_NVVM_D2ULL_RM : F_MATH_1<"cvt.rmi.u64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ull_rm>; -def INT_NVVM_D2ULL_RP : F_MATH_1<"cvt.rpi.u64.f64 \t$dst, $src0;", Int64Regs, - Float64Regs, int_nvvm_d2ull_rp>; - -def INT_NVVM_LL2F_RN : F_MATH_1<"cvt.rn.f32.s64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ll2f_rn>; -def INT_NVVM_LL2F_RZ : F_MATH_1<"cvt.rz.f32.s64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ll2f_rz>; -def INT_NVVM_LL2F_RM : F_MATH_1<"cvt.rm.f32.s64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ll2f_rm>; -def INT_NVVM_LL2F_RP : F_MATH_1<"cvt.rp.f32.s64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ll2f_rp>; -def INT_NVVM_ULL2F_RN : F_MATH_1<"cvt.rn.f32.u64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ull2f_rn>; -def INT_NVVM_ULL2F_RZ : F_MATH_1<"cvt.rz.f32.u64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ull2f_rz>; -def INT_NVVM_ULL2F_RM : F_MATH_1<"cvt.rm.f32.u64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ull2f_rm>; -def INT_NVVM_ULL2F_RP : F_MATH_1<"cvt.rp.f32.u64 \t$dst, $src0;", Float32Regs, - Int64Regs, int_nvvm_ull2f_rp>; - -def INT_NVVM_LL2D_RN : F_MATH_1<"cvt.rn.f64.s64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ll2d_rn>; -def INT_NVVM_LL2D_RZ : F_MATH_1<"cvt.rz.f64.s64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ll2d_rz>; -def INT_NVVM_LL2D_RM : F_MATH_1<"cvt.rm.f64.s64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ll2d_rm>; -def INT_NVVM_LL2D_RP : F_MATH_1<"cvt.rp.f64.s64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ll2d_rp>; -def INT_NVVM_ULL2D_RN : F_MATH_1<"cvt.rn.f64.u64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ull2d_rn>; -def INT_NVVM_ULL2D_RZ : F_MATH_1<"cvt.rz.f64.u64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ull2d_rz>; -def INT_NVVM_ULL2D_RM : F_MATH_1<"cvt.rm.f64.u64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ull2d_rm>; -def INT_NVVM_ULL2D_RP : F_MATH_1<"cvt.rp.f64.u64 \t$dst, $src0;", Float64Regs, - Int64Regs, int_nvvm_ull2d_rp>; +def : Pat<(int_nvvm_f2ll_rn_ftz Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRNI_FTZ)>; +def : Pat<(int_nvvm_f2ll_rn Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_f2ll_rz_ftz Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRZI_FTZ)>; +def : Pat<(int_nvvm_f2ll_rz Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_f2ll_rm_ftz Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRMI_FTZ)>; +def : Pat<(int_nvvm_f2ll_rm Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_f2ll_rp_ftz Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRPI_FTZ)>; +def : Pat<(int_nvvm_f2ll_rp Float32Regs:$a), + (CVT_s64_f32 Float32Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_f2ull_rn_ftz Float32Regs:$a), + (CVT_u64_f32 Float32Regs:$a, CvtRNI_FTZ)>; +def : Pat<(int_nvvm_f2ull_rn Float32Regs:$a), + (CVT_u64_f32 Float32Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_f2ull_rz_ftz Float32Regs:$a), 
+ (CVT_u64_f32 Float32Regs:$a, CvtRZI_FTZ)>; +def : Pat<(int_nvvm_f2ull_rz Float32Regs:$a), + (CVT_u64_f32 Float32Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_f2ull_rm_ftz Float32Regs:$a), + (CVT_u64_f32 Float32Regs:$a, CvtRMI_FTZ)>; +def : Pat<(int_nvvm_f2ull_rm Float32Regs:$a), + (CVT_u64_f32 Float32Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_f2ull_rp_ftz Float32Regs:$a), + (CVT_u64_f32 Float32Regs:$a, CvtRPI_FTZ)>; +def : Pat<(int_nvvm_f2ull_rp Float32Regs:$a), + (CVT_u64_f32 Float32Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_d2ll_rn Float64Regs:$a), + (CVT_s64_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_d2ll_rz Float64Regs:$a), + (CVT_s64_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_d2ll_rm Float64Regs:$a), + (CVT_s64_f64 Float64Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_d2ll_rp Float64Regs:$a), + (CVT_s64_f64 Float64Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_d2ull_rn Float64Regs:$a), + (CVT_u64_f64 Float64Regs:$a, CvtRNI)>; +def : Pat<(int_nvvm_d2ull_rz Float64Regs:$a), + (CVT_u64_f64 Float64Regs:$a, CvtRZI)>; +def : Pat<(int_nvvm_d2ull_rm Float64Regs:$a), + (CVT_u64_f64 Float64Regs:$a, CvtRMI)>; +def : Pat<(int_nvvm_d2ull_rp Float64Regs:$a), + (CVT_u64_f64 Float64Regs:$a, CvtRPI)>; + +def : Pat<(int_nvvm_ll2f_rn Int64Regs:$a), + (CVT_f32_s64 Int64Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_ll2f_rz Int64Regs:$a), + (CVT_f32_s64 Int64Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_ll2f_rm Int64Regs:$a), + (CVT_f32_s64 Int64Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_ll2f_rp Int64Regs:$a), + (CVT_f32_s64 Int64Regs:$a, CvtRP)>; + +def : Pat<(int_nvvm_ull2f_rn Int64Regs:$a), + (CVT_f32_u64 Int64Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_ull2f_rz Int64Regs:$a), + (CVT_f32_u64 Int64Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_ull2f_rm Int64Regs:$a), + (CVT_f32_u64 Int64Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_ull2f_rp Int64Regs:$a), + (CVT_f32_u64 Int64Regs:$a, CvtRP)>; + +def : Pat<(int_nvvm_ll2d_rn Int64Regs:$a), + (CVT_f64_s64 Int64Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_ll2d_rz Int64Regs:$a), + (CVT_f64_s64 Int64Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_ll2d_rm Int64Regs:$a), + (CVT_f64_s64 Int64Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_ll2d_rp Int64Regs:$a), + (CVT_f64_s64 Int64Regs:$a, CvtRP)>; + +def : Pat<(int_nvvm_ull2d_rn Int64Regs:$a), + (CVT_f64_u64 Int64Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_ull2d_rz Int64Regs:$a), + (CVT_f64_u64 Int64Regs:$a, CvtRZ)>; +def : Pat<(int_nvvm_ull2d_rm Int64Regs:$a), + (CVT_f64_u64 Int64Regs:$a, CvtRM)>; +def : Pat<(int_nvvm_ull2d_rp Int64Regs:$a), + (CVT_f64_u64 Int64Regs:$a, CvtRP)>; + + +// FIXME: Ideally, we could use these patterns instead of the scope-creating +// patterns, but ptxas does not like these since .s16 is not compatible with +// .f16. The solution is to use .bXX for all integer register types, but we +// are not there yet. 
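+// For example (illustrative): with Int16Regs emitted as ".reg .s16" (see
+// getNVPTXRegClassName in NVPTXRegisterInfo.cpp), ptxas would reject
+//   cvt.rn.f16.f32 %rs0, %f0;
+// since the .s16 declaration of %rs0 does not match the .f16 operand type,
+// whereas a ".b16" declaration would satisfy the type check.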
+//def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a), +// (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>; +//def : Pat<(int_nvvm_f2h_rn Float32Regs:$a), +// (CVT_f16_f32 Float32Regs:$a, CvtRN)>; +// +//def : Pat<(int_nvvm_h2f Int16Regs:$a), +// (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; def INT_NVVM_F2H_RN_FTZ : F_MATH_1<!strconcat("{{\n\t", !strconcat(".reg .b16 %temp;\n\t", @@ -793,6 +792,13 @@ def INT_NVVM_H2F : F_MATH_1<!strconcat("{{\n\t", "}}")))), Float32Regs, Int16Regs, int_nvvm_h2f>; +def : Pat<(f32 (f16_to_f32 Int16Regs:$a)), + (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; +def : Pat<(i16 (f32_to_f16 Float32Regs:$a)), + (CVT_f16_f32 Float32Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(i16 (f32_to_f16 Float32Regs:$a)), + (CVT_f16_f32 Float32Regs:$a, CvtRN)>; + // // Bitcast // @@ -1270,6 +1276,11 @@ def INT_PTX_SREG_WARPSIZE : F_SREG<"mov.u32 \t$dst, WARP_SZ;", Int32Regs, // Support for ldu on sm_20 or later //----------------------------------- +def ldu_i8 : PatFrag<(ops node:$ptr), (int_nvvm_ldu_global_i node:$ptr), [{ + MemIntrinsicSDNode *M = cast<MemIntrinsicSDNode>(N); + return M->getMemoryVT() == MVT::i8; +}]>; + // Scalar // @TODO: Revisit this, Changed imemAny to imem multiclass LDU_G<string TyStr, NVPTXRegClass regclass, Intrinsic IntOp> { @@ -1291,8 +1302,27 @@ multiclass LDU_G<string TyStr, NVPTXRegClass regclass, Intrinsic IntOp> { [(set regclass:$result, (IntOp ADDRri64:$src))]>, Requires<[hasLDU]>; } -defm INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8 \t$result, [$src];", Int8Regs, -int_nvvm_ldu_global_i>; +multiclass LDU_G_NOINTRIN<string TyStr, NVPTXRegClass regclass, PatFrag IntOp> { + def areg: NVPTXInst<(outs regclass:$result), (ins Int32Regs:$src), + !strconcat("ldu.global.", TyStr), + [(set regclass:$result, (IntOp Int32Regs:$src))]>, Requires<[hasLDU]>; + def areg64: NVPTXInst<(outs regclass:$result), (ins Int64Regs:$src), + !strconcat("ldu.global.", TyStr), + [(set regclass:$result, (IntOp Int64Regs:$src))]>, Requires<[hasLDU]>; + def avar: NVPTXInst<(outs regclass:$result), (ins imem:$src), + !strconcat("ldu.global.", TyStr), + [(set regclass:$result, (IntOp (Wrapper tglobaladdr:$src)))]>, + Requires<[hasLDU]>; + def ari : NVPTXInst<(outs regclass:$result), (ins MEMri:$src), + !strconcat("ldu.global.", TyStr), + [(set regclass:$result, (IntOp ADDRri:$src))]>, Requires<[hasLDU]>; + def ari64 : NVPTXInst<(outs regclass:$result), (ins MEMri64:$src), + !strconcat("ldu.global.", TyStr), + [(set regclass:$result, (IntOp ADDRri64:$src))]>, Requires<[hasLDU]>; +} + +defm INT_PTX_LDU_GLOBAL_i8 : LDU_G_NOINTRIN<"u8 \t$result, [$src];", Int16Regs, + ldu_i8>; defm INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16 \t$result, [$src];", Int16Regs, int_nvvm_ldu_global_i>; defm INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32 \t$result, [$src];", Int32Regs, @@ -1312,25 +1342,43 @@ int_nvvm_ldu_global_p>; // Elementized vector ldu multiclass VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> { - def _32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), - (ins Int32Regs:$src), + def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins Int32Regs:$src), + !strconcat("ldu.global.", TyStr), []>; + def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins Int64Regs:$src), + !strconcat("ldu.global.", TyStr), []>; + def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins MEMri:$src), !strconcat("ldu.global.", TyStr), []>; - def _64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), - (ins Int64Regs:$src), + def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins MEMri64:$src), + 
!strconcat("ldu.global.", TyStr), []>; + def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins imemAny:$src), !strconcat("ldu.global.", TyStr), []>; } -multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { - def _32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int32Regs:$src), +multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { + def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins Int32Regs:$src), + !strconcat("ldu.global.", TyStr), []>; + def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins Int64Regs:$src), + !strconcat("ldu.global.", TyStr), []>; + def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins MEMri:$src), !strconcat("ldu.global.", TyStr), []>; - def _64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins Int64Regs:$src), + def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins MEMri64:$src), + !strconcat("ldu.global.", TyStr), []>; + def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins imemAny:$src), !strconcat("ldu.global.", TyStr), []>; } defm INT_PTX_LDU_G_v2i8_ELE - : VLDU_G_ELE_V2<"v2.u8 \t{{$dst1, $dst2}}, [$src];", Int8Regs>; + : VLDU_G_ELE_V2<"v2.u8 \t{{$dst1, $dst2}}, [$src];", Int16Regs>; defm INT_PTX_LDU_G_v2i16_ELE : VLDU_G_ELE_V2<"v2.u16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>; defm INT_PTX_LDU_G_v2i32_ELE @@ -1342,7 +1390,7 @@ defm INT_PTX_LDU_G_v2i64_ELE defm INT_PTX_LDU_G_v2f64_ELE : VLDU_G_ELE_V2<"v2.f64 \t{{$dst1, $dst2}}, [$src];", Float64Regs>; defm INT_PTX_LDU_G_v4i8_ELE - : VLDU_G_ELE_V4<"v4.u8 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Int8Regs>; + : VLDU_G_ELE_V4<"v4.u8 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Int16Regs>; defm INT_PTX_LDU_G_v4i16_ELE : VLDU_G_ELE_V4<"v4.u16 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Int16Regs>; @@ -1422,20 +1470,38 @@ defm INT_PTX_LDG_GLOBAL_p64 // Elementized vector ldg multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> { - def _32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), (ins Int32Regs:$src), !strconcat("ld.global.nc.", TyStr), []>; - def _64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), (ins Int64Regs:$src), !strconcat("ld.global.nc.", TyStr), []>; + def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins MEMri:$src), + !strconcat("ld.global.nc.", TyStr), []>; + def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins MEMri64:$src), + !strconcat("ld.global.nc.", TyStr), []>; + def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2), + (ins imemAny:$src), + !strconcat("ld.global.nc.", TyStr), []>; } multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { - def _32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, - regclass:$dst3, regclass:$dst4), (ins Int32Regs:$src), + def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins Int32Regs:$src), + !strconcat("ld.global.nc.", TyStr), []>; + def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins Int64Regs:$src), + !strconcat("ld.global.nc.", TyStr), []>; + def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), 
(ins MEMri:$src), !strconcat("ld.global.nc.", TyStr), []>; - def _64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, - regclass:$dst3, regclass:$dst4), (ins Int64Regs:$src), + def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins MEMri64:$src), + !strconcat("ld.global.nc.", TyStr), []>; + def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, + regclass:$dst4), (ins imemAny:$src), !strconcat("ld.global.nc.", TyStr), []>; } @@ -1542,10 +1608,6 @@ def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), // nvvm.move intrinsicc -def nvvm_move_i8 : NVPTXInst<(outs Int8Regs:$r), (ins Int8Regs:$s), - "mov.b16 \t$r, $s;", - [(set Int8Regs:$r, - (int_nvvm_move_i8 Int8Regs:$s))]>; def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), "mov.b16 \t$r, $s;", [(set Int16Regs:$r, diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp new file mode 100644 index 0000000..ca24764 --- /dev/null +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp @@ -0,0 +1,46 @@ +//===-- NVPTXMCExpr.cpp - NVPTX specific MC expression classes ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "nvptx-mcexpr" +#include "NVPTXMCExpr.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/MC/MCAssembler.h" +#include "llvm/MC/MCContext.h" +using namespace llvm; + +const NVPTXFloatMCExpr* +NVPTXFloatMCExpr::Create(VariantKind Kind, APFloat Flt, MCContext &Ctx) { + return new (Ctx) NVPTXFloatMCExpr(Kind, Flt); +} + +void NVPTXFloatMCExpr::PrintImpl(raw_ostream &OS) const { + bool Ignored; + unsigned NumHex; + APFloat APF = getAPFloat(); + + switch (Kind) { + default: llvm_unreachable("Invalid kind!"); + case VK_NVPTX_SINGLE_PREC_FLOAT: + OS << "0f"; + NumHex = 8; + APF.convert(APFloat::IEEEsingle, APFloat::rmNearestTiesToEven, &Ignored); + break; + case VK_NVPTX_DOUBLE_PREC_FLOAT: + OS << "0d"; + NumHex = 16; + APF.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, &Ignored); + break; + } + + APInt API = APF.bitcastToAPInt(); + std::string HexStr(utohexstr(API.getZExtValue())); + if (HexStr.length() < NumHex) + OS << std::string(NumHex - HexStr.length(), '0'); + OS << utohexstr(API.getZExtValue()); +} diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXMCExpr.h b/contrib/llvm/lib/Target/NVPTX/NVPTXMCExpr.h new file mode 100644 index 0000000..0efb231 --- /dev/null +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXMCExpr.h @@ -0,0 +1,83 @@ +//===-- NVPTXMCExpr.h - NVPTX specific MC expression classes ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +// Modeled after ARMMCExpr + +#ifndef NVPTXMCEXPR_H +#define NVPTXMCEXPR_H + +#include "llvm/ADT/APFloat.h" +#include "llvm/MC/MCExpr.h" + +namespace llvm { + +class NVPTXFloatMCExpr : public MCTargetExpr { +public: + enum VariantKind { + VK_NVPTX_None, + VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision + VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision + }; + +private: + const VariantKind Kind; + const APFloat Flt; + + explicit NVPTXFloatMCExpr(VariantKind _Kind, APFloat _Flt) + : Kind(_Kind), Flt(_Flt) {} + +public: + /// @name Construction + /// @{ + + static const NVPTXFloatMCExpr *Create(VariantKind Kind, APFloat Flt, + MCContext &Ctx); + + static const NVPTXFloatMCExpr *CreateConstantFPSingle(APFloat Flt, + MCContext &Ctx) { + return Create(VK_NVPTX_SINGLE_PREC_FLOAT, Flt, Ctx); + } + + static const NVPTXFloatMCExpr *CreateConstantFPDouble(APFloat Flt, + MCContext &Ctx) { + return Create(VK_NVPTX_DOUBLE_PREC_FLOAT, Flt, Ctx); + } + + /// @} + /// @name Accessors + /// @{ + + /// getOpcode - Get the kind of this expression. + VariantKind getKind() const { return Kind; } + + /// getSubExpr - Get the child of this expression. + APFloat getAPFloat() const { return Flt; } + +/// @} + + void PrintImpl(raw_ostream &OS) const; + bool EvaluateAsRelocatableImpl(MCValue &Res, + const MCAsmLayout *Layout) const { + return false; + } + void AddValueSymbols(MCAssembler *) const {}; + const MCSection *FindAssociatedSection() const { + return NULL; + } + + // There are no TLS NVPTXMCExprs at the moment. + void fixELFSymbolsInTLSFixups(MCAssembler &Asm) const {} + + static bool classof(const MCExpr *E) { + return E->getKind() == MCExpr::Target; + } +}; +} // end namespace llvm + +#endif diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXNumRegisters.h b/contrib/llvm/lib/Target/NVPTX/NVPTXNumRegisters.h deleted file mode 100644 index a95c16b..0000000 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXNumRegisters.h +++ /dev/null @@ -1,16 +0,0 @@ - -//===-- NVPTXNumRegisters.h - PTX Register Info ---------------------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// - -#ifndef NVPTX_NUM_REGISTERS_H -#define NVPTX_NUM_REGISTERS_H - -namespace llvm { const unsigned NVPTXNumRegisters = 396; } - -#endif diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXPrologEpilogPass.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXPrologEpilogPass.cpp new file mode 100644 index 0000000..843ebed --- /dev/null +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXPrologEpilogPass.cpp @@ -0,0 +1,225 @@ +//===-- NVPTXPrologEpilogPass.cpp - NVPTX prolog/epilog inserter ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file is a copy of the generic LLVM PrologEpilogInserter pass, modified +// to remove unneeded functionality and to handle virtual registers. Most code +// here is a copy of PrologEpilogInserter.cpp. 
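+// In practice this means frame indices are eliminated while instructions
+// still reference virtual registers, and prologue/epilogue emission does not
+// rely on register spilling or scavenging.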
+// +//===----------------------------------------------------------------------===// + +#include "NVPTX.h" +#include "llvm/Pass.h" +#include "llvm/CodeGen/MachineFrameInfo.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/Target/TargetFrameLowering.h" +#include "llvm/Target/TargetRegisterInfo.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +namespace { +class NVPTXPrologEpilogPass : public MachineFunctionPass { +public: + static char ID; + NVPTXPrologEpilogPass() : MachineFunctionPass(ID) {} + + virtual bool runOnMachineFunction(MachineFunction &MF); + +private: + void calculateFrameObjectOffsets(MachineFunction &Fn); +}; +} + +MachineFunctionPass *llvm::createNVPTXPrologEpilogPass() { + return new NVPTXPrologEpilogPass(); +} + +char NVPTXPrologEpilogPass::ID = 0; + +bool NVPTXPrologEpilogPass::runOnMachineFunction(MachineFunction &MF) { + const TargetMachine &TM = MF.getTarget(); + const TargetFrameLowering &TFI = *TM.getFrameLowering(); + const TargetRegisterInfo &TRI = *TM.getRegisterInfo(); + bool Modified = false; + + calculateFrameObjectOffsets(MF); + + for (MachineFunction::iterator BB = MF.begin(), E = MF.end(); BB != E; ++BB) { + for (MachineBasicBlock::iterator I = BB->begin(); I != BB->end(); ++I) { + MachineInstr *MI = I; + for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) { + if (!MI->getOperand(i).isFI()) + continue; + TRI.eliminateFrameIndex(MI, 0, i, NULL); + Modified = true; + } + } + } + + // Add function prolog/epilog + TFI.emitPrologue(MF); + + for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) { + // If last instruction is a return instruction, add an epilogue + if (!I->empty() && I->back().isReturn()) + TFI.emitEpilogue(MF, *I); + } + + return Modified; +} + +/// AdjustStackOffset - Helper function used to adjust the stack frame offset. +static inline void +AdjustStackOffset(MachineFrameInfo *MFI, int FrameIdx, + bool StackGrowsDown, int64_t &Offset, + unsigned &MaxAlign) { + // If the stack grows down, add the object size to find the lowest address. + if (StackGrowsDown) + Offset += MFI->getObjectSize(FrameIdx); + + unsigned Align = MFI->getObjectAlignment(FrameIdx); + + // If the alignment of this object is greater than that of the stack, then + // increase the stack alignment to match. + MaxAlign = std::max(MaxAlign, Align); + + // Adjust to alignment boundary. + Offset = (Offset + Align - 1) / Align * Align; + + if (StackGrowsDown) { + DEBUG(dbgs() << "alloc FI(" << FrameIdx << ") at SP[" << -Offset << "]\n"); + MFI->setObjectOffset(FrameIdx, -Offset); // Set the computed offset + } else { + DEBUG(dbgs() << "alloc FI(" << FrameIdx << ") at SP[" << Offset << "]\n"); + MFI->setObjectOffset(FrameIdx, Offset); + Offset += MFI->getObjectSize(FrameIdx); + } +} + +void +NVPTXPrologEpilogPass::calculateFrameObjectOffsets(MachineFunction &Fn) { + const TargetFrameLowering &TFI = *Fn.getTarget().getFrameLowering(); + const TargetRegisterInfo *RegInfo = Fn.getTarget().getRegisterInfo(); + + bool StackGrowsDown = + TFI.getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown; + + // Loop over all of the stack objects, assigning sequential addresses... + MachineFrameInfo *MFI = Fn.getFrameInfo(); + + // Start at the beginning of the local area. + // The Offset is the distance from the stack top in the direction + // of stack growth -- so it's always nonnegative. 
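+  // For example (hypothetical values): on a down-growing stack where
+  // getOffsetOfLocalArea() returns -8, LocalAreaOffset below becomes 8, so
+  // object allocation starts 8 bytes from the stack top in the direction of
+  // growth.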
+ int LocalAreaOffset = TFI.getOffsetOfLocalArea(); + if (StackGrowsDown) + LocalAreaOffset = -LocalAreaOffset; + assert(LocalAreaOffset >= 0 + && "Local area offset should be in direction of stack growth"); + int64_t Offset = LocalAreaOffset; + + // If there are fixed sized objects that are preallocated in the local area, + // non-fixed objects can't be allocated right at the start of local area. + // We currently don't support filling in holes in between fixed sized + // objects, so we adjust 'Offset' to point to the end of last fixed sized + // preallocated object. + for (int i = MFI->getObjectIndexBegin(); i != 0; ++i) { + int64_t FixedOff; + if (StackGrowsDown) { + // The maximum distance from the stack pointer is at lower address of + // the object -- which is given by offset. For down growing stack + // the offset is negative, so we negate the offset to get the distance. + FixedOff = -MFI->getObjectOffset(i); + } else { + // The maximum distance from the start pointer is at the upper + // address of the object. + FixedOff = MFI->getObjectOffset(i) + MFI->getObjectSize(i); + } + if (FixedOff > Offset) Offset = FixedOff; + } + + // NOTE: We do not have a call stack + + unsigned MaxAlign = MFI->getMaxAlignment(); + + // No scavenger + + // FIXME: Once this is working, then enable flag will change to a target + // check for whether the frame is large enough to want to use virtual + // frame index registers. Functions which don't want/need this optimization + // will continue to use the existing code path. + if (MFI->getUseLocalStackAllocationBlock()) { + unsigned Align = MFI->getLocalFrameMaxAlign(); + + // Adjust to alignment boundary. + Offset = (Offset + Align - 1) / Align * Align; + + DEBUG(dbgs() << "Local frame base offset: " << Offset << "\n"); + + // Resolve offsets for objects in the local block. + for (unsigned i = 0, e = MFI->getLocalFrameObjectCount(); i != e; ++i) { + std::pair<int, int64_t> Entry = MFI->getLocalFrameObjectMap(i); + int64_t FIOffset = (StackGrowsDown ? -Offset : Offset) + Entry.second; + DEBUG(dbgs() << "alloc FI(" << Entry.first << ") at SP[" << + FIOffset << "]\n"); + MFI->setObjectOffset(Entry.first, FIOffset); + } + // Allocate the local block + Offset += MFI->getLocalFrameSize(); + + MaxAlign = std::max(Align, MaxAlign); + } + + // No stack protector + + // Then assign frame offsets to stack objects that are not used to spill + // callee saved registers. + for (unsigned i = 0, e = MFI->getObjectIndexEnd(); i != e; ++i) { + if (MFI->isObjectPreAllocated(i) && + MFI->getUseLocalStackAllocationBlock()) + continue; + if (MFI->isDeadObjectIndex(i)) + continue; + + AdjustStackOffset(MFI, i, StackGrowsDown, Offset, MaxAlign); + } + + // No scavenger + + if (!TFI.targetHandlesStackFrameRounding()) { + // If we have reserved argument space for call sites in the function + // immediately on entry to the current function, count it as part of the + // overall stack size. + if (MFI->adjustsStack() && TFI.hasReservedCallFrame(Fn)) + Offset += MFI->getMaxCallFrameSize(); + + // Round up the size to a multiple of the alignment. If the function has + // any calls or alloca's, align to the target's StackAlignment value to + // ensure that the callee's frame or the alloca data is suitably aligned; + // otherwise, for leaf functions, align to the TransientStackAlignment + // value. 
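+  // For example (hypothetical values): with Offset == 20 and StackAlign == 8,
+  // AlignMask == 7 and (20 + 7) & ~7 rounds the frame size up to 24.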
+ unsigned StackAlign; + if (MFI->adjustsStack() || MFI->hasVarSizedObjects() || + (RegInfo->needsStackRealignment(Fn) && MFI->getObjectIndexEnd() != 0)) + StackAlign = TFI.getStackAlignment(); + else + StackAlign = TFI.getTransientStackAlignment(); + + // If the frame pointer is eliminated, all frame offsets will be relative to + // SP not FP. Align to MaxAlign so this works. + StackAlign = std::max(StackAlign, MaxAlign); + unsigned AlignMask = StackAlign - 1; + Offset = (Offset + AlignMask) & ~uint64_t(AlignMask); + } + + // Update frame info to pretend that this is part of the stack... + int64_t StackSize = Offset - LocalAreaOffset; + MFI->setStackSize(StackSize); +} diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp index 2824653..4d3a1d9 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp @@ -38,10 +38,6 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) { return ".s32"; } else if (RC == &NVPTX::Int16RegsRegClass) { return ".s16"; - } - // Int8Regs become 16-bit registers in PTX - else if (RC == &NVPTX::Int8RegsRegClass) { - return ".s16"; } else if (RC == &NVPTX::Int1RegsRegClass) { return ".pred"; } else if (RC == &NVPTX::SpecialRegsRegClass) { @@ -57,15 +53,13 @@ std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) { return "%f"; } if (RC == &NVPTX::Float64RegsRegClass) { - return "%fd"; + return "%fl"; } else if (RC == &NVPTX::Int64RegsRegClass) { - return "%rd"; + return "%rl"; } else if (RC == &NVPTX::Int32RegsRegClass) { return "%r"; } else if (RC == &NVPTX::Int16RegsRegClass) { return "%rs"; - } else if (RC == &NVPTX::Int8RegsRegClass) { - return "%rc"; } else if (RC == &NVPTX::Int1RegsRegClass) { return "%p"; } else if (RC == &NVPTX::SpecialRegsRegClass) { @@ -77,8 +71,7 @@ std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) { } } -NVPTXRegisterInfo::NVPTXRegisterInfo(const TargetInstrInfo &tii, - const NVPTXSubtarget &st) +NVPTXRegisterInfo::NVPTXRegisterInfo(const NVPTXSubtarget &st) : NVPTXGenRegisterInfo(0), Is64Bit(st.is64Bit()) {} #define GET_REGINFO_TARGET_DESC diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h b/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h index d406820..0a20f29 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h @@ -35,7 +35,7 @@ private: ManagedStringPool ManagedStrPool; public: - NVPTXRegisterInfo(const TargetInstrInfo &tii, const NVPTXSubtarget &st); + NVPTXRegisterInfo(const NVPTXSubtarget &st); //------------------------------------------------------ // Pure virtual functions from TargetRegisterInfo diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td index 8d100d6..7a38a66 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -29,9 +29,10 @@ def VRFrameLocal : NVPTXReg<"%SPL">; // Special Registers used as the stack def VRDepot : NVPTXReg<"%Depot">; -foreach i = 0-395 in { +// We use virtual registers, but define a few physical registers here to keep +// SDAG and the MachineInstr layers happy. 
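+// For instance, the loop below defines only P0-P4, RS0-RS4, R0-R4, RL0-RL4,
+// F0-F4 and FL0-FL4; any register demand beyond these few names per class is
+// carried by virtual registers.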
+foreach i = 0-4 in {
   def P#i : NVPTXReg<"%p"#i>; // Predicate
-  def RC#i : NVPTXReg<"%rc"#i>; // 8-bit
   def RS#i : NVPTXReg<"%rs"#i>; // 16-bit
   def R#i : NVPTXReg<"%r"#i>; // 32-bit
   def RL#i : NVPTXReg<"%rl"#i>; // 64-bit
@@ -48,17 +49,16 @@ foreach i = 0-395 in {
 //===----------------------------------------------------------------------===//
 // Register classes
 //===----------------------------------------------------------------------===//
-def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 395))>;
-def Int8Regs : NVPTXRegClass<[i8], 8, (add (sequence "RC%u", 0, 395))>;
-def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 395))>;
-def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 395))>;
-def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 395))>;
-def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 395))>;
-def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 395))>;
-def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 395))>;
-def Int64ArgRegs : NVPTXRegClass<[i64], 64, (add (sequence "la%u", 0, 395))>;
-def Float32ArgRegs : NVPTXRegClass<[f32], 32, (add (sequence "fa%u", 0, 395))>;
-def Float64ArgRegs : NVPTXRegClass<[f64], 64, (add (sequence "da%u", 0, 395))>;
+def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
+def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>;
+def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4))>;
+def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4))>;
+def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
+def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
+def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
+def Int64ArgRegs : NVPTXRegClass<[i64], 64, (add (sequence "la%u", 0, 4))>;
+def Float32ArgRegs : NVPTXRegClass<[f32], 32, (add (sequence "fa%u", 0, 4))>;
+def Float64ArgRegs : NVPTXRegClass<[f64], 64, (add (sequence "da%u", 0, 4))>;
 // Read NVPTXRegisterInfo.cpp to see how VRFrame and VRDepot are used.
 def SpecialRegs : NVPTXRegClass<[i32], 32, (add VRFrame, VRDepot)>;
diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXSection.h b/contrib/llvm/lib/Target/NVPTX/NVPTXSection.h
index e57ace9..f8a692e 100644
--- a/contrib/llvm/lib/Target/NVPTX/NVPTXSection.h
+++ b/contrib/llvm/lib/Target/NVPTX/NVPTXSection.h
@@ -24,10 +24,10 @@ namespace llvm {
 /// the ASMPrint interface.
 ///
 class NVPTXSection : public MCSection {
-
+  virtual void anchor();
 public:
   NVPTXSection(SectionVariant V, SectionKind K) : MCSection(V, K) {}
-  ~NVPTXSection() {}
+  virtual ~NVPTXSection() {}
 
   /// Override this as NVPTX has its own way of printing switching
   /// to a section.
diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXSplitBBatBar.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXSplitBBatBar.cpp
index 83dfe12..b64c308 100644
--- a/contrib/llvm/lib/Target/NVPTX/NVPTXSplitBBatBar.cpp
+++ b/contrib/llvm/lib/Target/NVPTX/NVPTXSplitBBatBar.cpp
@@ -36,7 +36,7 @@ bool NVPTXSplitBBatBar::runOnFunction(Function &F) {
     BasicBlock::iterator II = IB;
     BasicBlock::iterator IE = BI->end();
 
-    // Skit the first intruction. No splitting is needed at this
+    // Skip the first instruction. No splitting is needed at this
     // point even if this is a bar.
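+    // (Illustration: if a barrier intrinsic such as llvm.nvvm.barrier0 occurs
+    // mid-block, the loop below splits the block at that instruction.)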
while (II != IE) { if (IntrinsicInst *inst = dyn_cast<IntrinsicInst>(II)) { diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp index 2dcd73d..9771a17 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -19,23 +19,21 @@ using namespace llvm; -// Select Driver Interface -#include "llvm/Support/CommandLine.h" -namespace { -cl::opt<NVPTX::DrvInterface> DriverInterface( - cl::desc("Choose driver interface:"), - cl::values(clEnumValN(NVPTX::NVCL, "drvnvcl", "Nvidia OpenCL driver"), - clEnumValN(NVPTX::CUDA, "drvcuda", "Nvidia CUDA driver"), - clEnumValN(NVPTX::TEST, "drvtest", "Plain Test"), clEnumValEnd), - cl::init(NVPTX::NVCL)); -} + +// Pin the vtable to this file. +void NVPTXSubtarget::anchor() {} NVPTXSubtarget::NVPTXSubtarget(const std::string &TT, const std::string &CPU, const std::string &FS, bool is64Bit) : NVPTXGenSubtargetInfo(TT, CPU, FS), Is64Bit(is64Bit), PTXVersion(0), SmVersion(20) { - drvInterface = DriverInterface; + Triple T(TT); + + if (T.getOS() == Triple::NVCL) + drvInterface = NVPTX::NVCL; + else + drvInterface = NVPTX::CUDA; // Provide the default CPU if none std::string defCPU = "sm_20"; diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 670077d..004be11 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -25,7 +25,7 @@ namespace llvm { class NVPTXSubtarget : public NVPTXGenSubtargetInfo { - + virtual void anchor(); std::string TargetName; NVPTX::DrvInterface drvInterface; bool Is64Bit; diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index 1ae2a7c..46edd6d 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -57,9 +57,6 @@ extern "C" void LLVMInitializeNVPTXTarget() { RegisterTargetMachine<NVPTXTargetMachine32> X(TheNVPTXTarget32); RegisterTargetMachine<NVPTXTargetMachine64> Y(TheNVPTXTarget64); - RegisterMCAsmInfo<NVPTXMCAsmInfo> A(TheNVPTXTarget32); - RegisterMCAsmInfo<NVPTXMCAsmInfo> B(TheNVPTXTarget64); - // FIXME: This pass is really intended to be invoked during IR optimization, // but it's very NVPTX-specific. 
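+  // (For context: NVVMReflect folds __nvvm_reflect(<name>) queries into
+  // integer constants according to a name=value mapping; see the
+  // "string=num assignments" option in NVVMReflect.cpp below.)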
initializeNVVMReflectPass(*PassRegistry::getPassRegistry()); @@ -74,7 +71,9 @@ NVPTXTargetMachine::NVPTXTargetMachine( Subtarget(TT, CPU, FS, is64bit), DL(Subtarget.getDataLayout()), InstrInfo(*this), TLInfo(*this), TSInfo(*this), FrameLowering( - *this, is64bit) /*FrameInfo(TargetFrameInfo::StackGrowsUp, 8, 0)*/ {} + *this, is64bit) /*FrameInfo(TargetFrameInfo::StackGrowsUp, 8, 0)*/ { + initAsmInfo(); +} void NVPTXTargetMachine32::anchor() {} @@ -92,7 +91,7 @@ NVPTXTargetMachine64::NVPTXTargetMachine64( CodeGenOpt::Level OL) : NVPTXTargetMachine(T, TT, CPU, FS, Options, RM, CM, OL, true) {} -namespace llvm { +namespace { class NVPTXPassConfig : public TargetPassConfig { public: NVPTXPassConfig(NVPTXTargetMachine *TM, PassManagerBase &PM) @@ -105,8 +104,13 @@ public: virtual void addIRPasses(); virtual bool addInstSelector(); virtual bool addPreRegAlloc(); + virtual bool addPostRegAlloc(); + + virtual FunctionPass *createTargetRegisterAllocator(bool) LLVM_OVERRIDE; + virtual void addFastRegAlloc(FunctionPass *RegAllocPass); + virtual void addOptimizedRegAlloc(FunctionPass *RegAllocPass); }; -} +} // end anonymous namespace TargetPassConfig *NVPTXTargetMachine::createPassConfig(PassManagerBase &PM) { NVPTXPassConfig *PassConfig = new NVPTXPassConfig(this, PM); @@ -114,6 +118,16 @@ TargetPassConfig *NVPTXTargetMachine::createPassConfig(PassManagerBase &PM) { } void NVPTXPassConfig::addIRPasses() { + // The following passes are known to not play well with virtual regs hanging + // around after register allocation (which in our case, is *all* registers). + // We explicitly disable them here. We do, however, need some functionality + // of the PrologEpilogCodeInserter pass, so we emulate that behavior in the + // NVPTXPrologEpilog pass (see NVPTXPrologEpilogPass.cpp). + disablePass(&PrologEpilogCodeInserterID); + disablePass(&MachineCopyPropagationID); + disablePass(&BranchFolderPassID); + disablePass(&TailDuplicateID); + TargetPassConfig::addIRPasses(); addPass(createGenericToNVVMPass()); } @@ -127,3 +141,41 @@ bool NVPTXPassConfig::addInstSelector() { } bool NVPTXPassConfig::addPreRegAlloc() { return false; } +bool NVPTXPassConfig::addPostRegAlloc() { + addPass(createNVPTXPrologEpilogPass()); + return false; +} + +FunctionPass *NVPTXPassConfig::createTargetRegisterAllocator(bool) { + return 0; // No reg alloc +} + +void NVPTXPassConfig::addFastRegAlloc(FunctionPass *RegAllocPass) { + assert(!RegAllocPass && "NVPTX uses no regalloc!"); + addPass(&PHIEliminationID); + addPass(&TwoAddressInstructionPassID); +} + +void NVPTXPassConfig::addOptimizedRegAlloc(FunctionPass *RegAllocPass) { + assert(!RegAllocPass && "NVPTX uses no regalloc!"); + + addPass(&ProcessImplicitDefsID); + addPass(&LiveVariablesID); + addPass(&MachineLoopInfoID); + addPass(&PHIEliminationID); + + addPass(&TwoAddressInstructionPassID); + addPass(&RegisterCoalescerID); + + // PreRA instruction scheduling. 
+  if (addPass(&MachineSchedulerID))
+    printAndVerify("After Machine Scheduling");
+
+
+  addPass(&StackSlotColoringID);
+
+  // FIXME: Needs physical registers
+  //addPass(&PostRAMachineLICMID);
+
+  printAndVerify("After StackSlotColoring");
+}
diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXTargetObjectFile.h b/contrib/llvm/lib/Target/NVPTX/NVPTXTargetObjectFile.h
index 6ab0e08..2a7394b 100644
--- a/contrib/llvm/lib/Target/NVPTX/NVPTXTargetObjectFile.h
+++ b/contrib/llvm/lib/Target/NVPTX/NVPTXTargetObjectFile.h
@@ -21,31 +21,33 @@ class Module;
 class NVPTXTargetObjectFile : public TargetLoweringObjectFile {
 public:
-  NVPTXTargetObjectFile() {}
-  ~NVPTXTargetObjectFile() {
-    delete TextSection;
-    delete DataSection;
-    delete BSSSection;
-    delete ReadOnlySection;
+  NVPTXTargetObjectFile() {
+    TextSection = 0;
+    DataSection = 0;
+    BSSSection = 0;
+    ReadOnlySection = 0;
-    delete StaticCtorSection;
-    delete StaticDtorSection;
-    delete LSDASection;
-    delete EHFrameSection;
-    delete DwarfAbbrevSection;
-    delete DwarfInfoSection;
-    delete DwarfLineSection;
-    delete DwarfFrameSection;
-    delete DwarfPubTypesSection;
-    delete DwarfDebugInlineSection;
-    delete DwarfStrSection;
-    delete DwarfLocSection;
-    delete DwarfARangesSection;
-    delete DwarfRangesSection;
-    delete DwarfMacroInfoSection;
+    StaticCtorSection = 0;
+    StaticDtorSection = 0;
+    LSDASection = 0;
+    EHFrameSection = 0;
+    DwarfAbbrevSection = 0;
+    DwarfInfoSection = 0;
+    DwarfLineSection = 0;
+    DwarfFrameSection = 0;
+    DwarfPubTypesSection = 0;
+    DwarfDebugInlineSection = 0;
+    DwarfStrSection = 0;
+    DwarfLocSection = 0;
+    DwarfARangesSection = 0;
+    DwarfRangesSection = 0;
+    DwarfMacroInfoSection = 0;
   }
+  virtual ~NVPTXTargetObjectFile();
+  virtual void Initialize(MCContext &ctx, const TargetMachine &TM) {
+    TargetLoweringObjectFile::Initialize(ctx, TM);
     TextSection = new NVPTXSection(MCSection::SV_ELF, SectionKind::getText());
     DataSection = new NVPTXSection(MCSection::SV_ELF, SectionKind::getDataRel());
diff --git a/contrib/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/contrib/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 3cc324b..7406207 100644
--- a/contrib/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/contrib/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -79,7 +79,7 @@ ModulePass *llvm::createNVVMReflectPass(const StringMap<int>& Mapping) {
 }
 static cl::opt<bool>
-NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true),
+NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
   cl::desc("NVVM reflection, enabled by default"));
 char NVVMReflect::ID = 0;
@@ -88,7 +88,7 @@ INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
 false)
 static cl::list<std::string>
-ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"),
+ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
   cl::desc("A list of string=num assignments"),
   cl::ValueRequired);
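For reference, both NVVMReflect options remain usable after being marked
cl::Hidden. A hypothetical invocation through opt, using the "nvvm-reflect"
pass name registered by INITIALIZE_PASS above (the __CUDA_FTZ name is an
assumed example, not taken from this diff):

  opt -nvvm-reflect -nvvm-reflect-list=__CUDA_FTZ=1 -S input.ll -o output.ll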