diff options
Diffstat (limited to 'contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp')
-rw-r--r-- | contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 50 |
1 files changed, 41 insertions, 9 deletions
diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 3c2594c..0139646 100644 --- a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// +#include "NVPTXAsmPrinter.h" #include "InstPrinter/NVPTXInstPrinter.h" #include "MCTargetDesc/NVPTXBaseInfo.h" #include "MCTargetDesc/NVPTXMCAsmInfo.h" #include "NVPTX.h" -#include "NVPTXAsmPrinter.h" #include "NVPTXMCExpr.h" #include "NVPTXMachineFunctionInfo.h" #include "NVPTXRegisterInfo.h" @@ -73,8 +73,8 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/Path.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetLowering.h" #include "llvm/Target/TargetLoweringObjectFile.h" #include "llvm/Target/TargetMachine.h" @@ -320,6 +320,10 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, switch (Cnt->getType()->getTypeID()) { default: report_fatal_error("Unsupported FP type"); break; + case Type::HalfTyID: + MCOp = MCOperand::createExpr( + NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); + break; case Type::FloatTyID: MCOp = MCOperand::createExpr( NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); @@ -357,6 +361,10 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { Ret = (5 << 28); } else if (RC == &NVPTX::Float64RegsRegClass) { Ret = (6 << 28); + } else if (RC == &NVPTX::Float16RegsRegClass) { + Ret = (7 << 28); + } else if (RC == &NVPTX::Float16x2RegsRegClass) { + Ret = (8 << 28); } else { report_fatal_error("Bad register class"); } @@ -396,12 +404,15 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { unsigned size = 0; if (auto *ITy = dyn_cast<IntegerType>(Ty)) { size = ITy->getBitWidth(); - if (size < 32) - size = 32; } else { assert(Ty->isFloatingPointTy() && "Floating point type expected here"); size = Ty->getPrimitiveSizeInBits(); } + // PTX ABI requires all scalar return values to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type in + // PTX, so its size must be adjusted here, too. + if (size < 32) + size = 32; O << ".param .b" << size << " func_retval0"; } else if (isa<PointerType>(Ty)) { @@ -1221,7 +1232,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, else O << " .align " << GVar->getAlignment(); - if (ETy->isFloatingPointTy() || ETy->isIntegerTy() || ETy->isPointerTy()) { + if (ETy->isFloatingPointTy() || ETy->isPointerTy() || + (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) { O << " ."; // Special case: ABI requires that we use .u8 for predicates if (ETy->isIntegerTy(1)) @@ -1262,6 +1274,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, // targets that support these high level field accesses. Structs, arrays // and vectors are lowered into arrays of bytes. switch (ETy->getTypeID()) { + case Type::IntegerTyID: // Integers larger than 64 bits case Type::StructTyID: case Type::ArrayTyID: case Type::VectorTyID: @@ -1376,6 +1389,9 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const { } break; } + case Type::HalfTyID: + // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly. + return "b16"; case Type::FloatTyID: return "f32"; case Type::DoubleTyID: @@ -1477,7 +1493,7 @@ void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I, void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { const DataLayout &DL = getDataLayout(); - const AttributeSet &PAL = F->getAttributes(); + const AttributeList &PAL = F->getAttributes(); const TargetLowering *TLI = nvptxSubtarget->getTargetLowering(); Function::const_arg_iterator I, E; unsigned paramIndex = 0; @@ -1534,12 +1550,12 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { } } - if (!PAL.hasAttribute(paramIndex + 1, Attribute::ByVal)) { + if (!PAL.hasParamAttribute(paramIndex, Attribute::ByVal)) { if (Ty->isAggregateType() || Ty->isVectorTy()) { // Just print .param .align <a> .b8 .param[size]; // <a> = PAL.getparamalignment // size = typeallocsize of element type - unsigned align = PAL.getParamAlignment(paramIndex + 1); + unsigned align = PAL.getParamAlignment(paramIndex); if (align == 0) align = DL.getABITypeAlignment(Ty); @@ -1601,6 +1617,11 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { sz = 32; } else if (isa<PointerType>(Ty)) sz = thePointerTy.getSizeInBits(); + else if (Ty->isHalfTy()) + // PTX ABI requires all scalar parameters to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type + // in PTX, so its size must be adjusted here, too. + sz = 32; else sz = Ty->getPrimitiveSizeInBits(); if (isABI) @@ -1620,7 +1641,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { // Just print .param .align <a> .b8 .param[size]; // <a> = PAL.getparamalignment // size = typeallocsize of element type - unsigned align = PAL.getParamAlignment(paramIndex + 1); + unsigned align = PAL.getParamAlignment(paramIndex); if (align == 0) align = DL.getABITypeAlignment(ETy); // Work around a bug in ptxas. When PTX code takes address of @@ -1977,6 +1998,17 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV, const DataLayout &DL = getDataLayout(); int Bytes; + // Integers of arbitrary width + if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) { + APInt Val = CI->getValue(); + for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) { + uint8_t Byte = Val.getLoBits(8).getZExtValue(); + aggBuffer->addBytes(&Byte, 1, 1); + Val.lshrInPlace(8); + } + return; + } + // Old constants if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) { if (CPV->getNumOperands()) |