summaryrefslogtreecommitdiffstats
path: root/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp')
-rw-r--r--contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp50
1 files changed, 41 insertions, 9 deletions
diff --git a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 3c2594c..0139646 100644
--- a/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/contrib/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -12,11 +12,11 @@
//
//===----------------------------------------------------------------------===//
+#include "NVPTXAsmPrinter.h"
#include "InstPrinter/NVPTXInstPrinter.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "MCTargetDesc/NVPTXMCAsmInfo.h"
#include "NVPTX.h"
-#include "NVPTXAsmPrinter.h"
#include "NVPTXMCExpr.h"
#include "NVPTXMachineFunctionInfo.h"
#include "NVPTXRegisterInfo.h"
@@ -73,8 +73,8 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Path.h"
-#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLowering.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetMachine.h"
@@ -320,6 +320,10 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
switch (Cnt->getType()->getTypeID()) {
default: report_fatal_error("Unsupported FP type"); break;
+ case Type::HalfTyID:
+ MCOp = MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
+ break;
case Type::FloatTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
@@ -357,6 +361,10 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
Ret = (5 << 28);
} else if (RC == &NVPTX::Float64RegsRegClass) {
Ret = (6 << 28);
+ } else if (RC == &NVPTX::Float16RegsRegClass) {
+ Ret = (7 << 28);
+ } else if (RC == &NVPTX::Float16x2RegsRegClass) {
+ Ret = (8 << 28);
} else {
report_fatal_error("Bad register class");
}
@@ -396,12 +404,15 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
size = ITy->getBitWidth();
- if (size < 32)
- size = 32;
} else {
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
size = Ty->getPrimitiveSizeInBits();
}
+ // PTX ABI requires all scalar return values to be at least 32
+ // bits in size. fp16 normally uses .b16 as its storage type in
+ // PTX, so its size must be adjusted here, too.
+ if (size < 32)
+ size = 32;
O << ".param .b" << size << " func_retval0";
} else if (isa<PointerType>(Ty)) {
@@ -1221,7 +1232,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
else
O << " .align " << GVar->getAlignment();
- if (ETy->isFloatingPointTy() || ETy->isIntegerTy() || ETy->isPointerTy()) {
+ if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
+ (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
O << " .";
// Special case: ABI requires that we use .u8 for predicates
if (ETy->isIntegerTy(1))
@@ -1262,6 +1274,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
// targets that support these high level field accesses. Structs, arrays
// and vectors are lowered into arrays of bytes.
switch (ETy->getTypeID()) {
+ case Type::IntegerTyID: // Integers larger than 64 bits
case Type::StructTyID:
case Type::ArrayTyID:
case Type::VectorTyID:
@@ -1376,6 +1389,9 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
}
break;
}
+ case Type::HalfTyID:
+ // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
+ return "b16";
case Type::FloatTyID:
return "f32";
case Type::DoubleTyID:
@@ -1477,7 +1493,7 @@ void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
const DataLayout &DL = getDataLayout();
- const AttributeSet &PAL = F->getAttributes();
+ const AttributeList &PAL = F->getAttributes();
const TargetLowering *TLI = nvptxSubtarget->getTargetLowering();
Function::const_arg_iterator I, E;
unsigned paramIndex = 0;
@@ -1534,12 +1550,12 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
}
}
- if (!PAL.hasAttribute(paramIndex + 1, Attribute::ByVal)) {
+ if (!PAL.hasParamAttribute(paramIndex, Attribute::ByVal)) {
if (Ty->isAggregateType() || Ty->isVectorTy()) {
// Just print .param .align <a> .b8 .param[size];
// <a> = PAL.getparamalignment
// size = typeallocsize of element type
- unsigned align = PAL.getParamAlignment(paramIndex + 1);
+ unsigned align = PAL.getParamAlignment(paramIndex);
if (align == 0)
align = DL.getABITypeAlignment(Ty);
@@ -1601,6 +1617,11 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
sz = 32;
} else if (isa<PointerType>(Ty))
sz = thePointerTy.getSizeInBits();
+ else if (Ty->isHalfTy())
+ // PTX ABI requires all scalar parameters to be at least 32
+ // bits in size. fp16 normally uses .b16 as its storage type
+ // in PTX, so its size must be adjusted here, too.
+ sz = 32;
else
sz = Ty->getPrimitiveSizeInBits();
if (isABI)
@@ -1620,7 +1641,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
// Just print .param .align <a> .b8 .param[size];
// <a> = PAL.getparamalignment
// size = typeallocsize of element type
- unsigned align = PAL.getParamAlignment(paramIndex + 1);
+ unsigned align = PAL.getParamAlignment(paramIndex);
if (align == 0)
align = DL.getABITypeAlignment(ETy);
// Work around a bug in ptxas. When PTX code takes address of
@@ -1977,6 +1998,17 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
const DataLayout &DL = getDataLayout();
int Bytes;
+ // Integers of arbitrary width
+ if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
+ APInt Val = CI->getValue();
+ for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
+ uint8_t Byte = Val.getLoBits(8).getZExtValue();
+ aggBuffer->addBytes(&Byte, 1, 1);
+ Val.lshrInPlace(8);
+ }
+ return;
+ }
+
// Old constants
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
if (CPV->getNumOperands())
OpenPOWER on IntegriCloud