diff options
Diffstat (limited to 'contrib/llvm/lib/Target/PTX/PTXAsmPrinter.cpp')
-rw-r--r-- | contrib/llvm/lib/Target/PTX/PTXAsmPrinter.cpp | 210 |
1 files changed, 170 insertions, 40 deletions
diff --git a/contrib/llvm/lib/Target/PTX/PTXAsmPrinter.cpp b/contrib/llvm/lib/Target/PTX/PTXAsmPrinter.cpp index a605997..29c4781 100644 --- a/contrib/llvm/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/contrib/llvm/lib/Target/PTX/PTXAsmPrinter.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/MC/MCStreamer.h" #include "llvm/MC/MCSymbol.h" #include "llvm/Target/Mangler.h" @@ -37,13 +38,6 @@ using namespace llvm; -static cl::opt<std::string> -OptPTXVersion("ptx-version", cl::desc("Set PTX version"), cl::init("1.4")); - -static cl::opt<std::string> -OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"), - cl::init("sm_10")); - namespace { class PTXAsmPrinter : public AsmPrinter { public: @@ -68,6 +62,7 @@ public: const char *Modifier = 0); void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS, const char *Modifier = 0); + void printPredicateOperand(const MachineInstr *MI, raw_ostream &O); // autogen'd. void printInstruction(const MachineInstr *MI, raw_ostream &OS); @@ -82,27 +77,20 @@ private: static const char PARAM_PREFIX[] = "__param_"; static const char *getRegisterTypeName(unsigned RegNo) { -#define TEST_REGCLS(cls, clsstr) \ +#define TEST_REGCLS(cls, clsstr) \ if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr; - TEST_REGCLS(RRegs32, s32); TEST_REGCLS(Preds, pred); + TEST_REGCLS(RRegu16, u16); + TEST_REGCLS(RRegu32, u32); + TEST_REGCLS(RRegu64, u64); + TEST_REGCLS(RRegf32, f32); + TEST_REGCLS(RRegf64, f64); #undef TEST_REGCLS llvm_unreachable("Not in any register class!"); return NULL; } -static const char *getInstructionTypeName(const MachineInstr *MI) { - for (int i = 0, e = MI->getNumOperands(); i != e; ++i) { - const MachineOperand &MO = MI->getOperand(i); - if (MO.getType() == MachineOperand::MO_Register) - return getRegisterTypeName(MO.getReg()); - } - - llvm_unreachable("No reg operand found in instruction!"); - return NULL; -} - static const char *getStateSpaceName(unsigned addressSpace) { switch (addressSpace) { default: llvm_unreachable("Unknown state space"); @@ -115,6 +103,28 @@ static const char *getStateSpaceName(unsigned addressSpace) { return NULL; } +static const char *getTypeName(const Type* type) { + while (true) { + switch (type->getTypeID()) { + default: llvm_unreachable("Unknown type"); + case Type::FloatTyID: return ".f32"; + case Type::DoubleTyID: return ".f64"; + case Type::IntegerTyID: + switch (type->getPrimitiveSizeInBits()) { + default: llvm_unreachable("Unknown integer bit-width"); + case 16: return ".u16"; + case 32: return ".u32"; + case 64: return ".u64"; + } + case Type::ArrayTyID: + case Type::PointerTyID: + type = dyn_cast<const SequentialType>(type)->getElementType(); + break; + } + } + return NULL; +} + bool PTXAsmPrinter::doFinalization(Module &M) { // XXX Temproarily remove global variables so that doFinalization() will not // emit them again (global variables are emitted at beginning). @@ -146,8 +156,12 @@ bool PTXAsmPrinter::doFinalization(Module &M) { void PTXAsmPrinter::EmitStartOfAsmFile(Module &M) { - OutStreamer.EmitRawText(Twine("\t.version " + OptPTXVersion)); - OutStreamer.EmitRawText(Twine("\t.target " + OptPTXTarget)); + const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>(); + + OutStreamer.EmitRawText(Twine("\t.version " + ST.getPTXVersionString())); + OutStreamer.EmitRawText(Twine("\t.target " + ST.getTargetString() + + (ST.supportsDouble() ? "" + : ", map_f64_to_f32"))); OutStreamer.AddBlankLine(); // declare global variables @@ -186,17 +200,16 @@ void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) { std::string str; str.reserve(64); - // Write instruction to str raw_string_ostream OS(str); + + // Emit predicate + printPredicateOperand(MI, OS); + + // Write instruction to str printInstruction(MI, OS); OS << ';'; OS.flush(); - // Replace "%type" if found - size_t pos; - if ((pos = str.find("%type")) != std::string::npos) - str.replace(pos, /*strlen("%type")==*/5, getInstructionTypeName(MI)); - StringRef strref = StringRef(str); OutStreamer.EmitRawText(strref); } @@ -213,11 +226,36 @@ void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum, OS << *Mang->getSymbol(MO.getGlobal()); break; case MachineOperand::MO_Immediate: - OS << (int) MO.getImm(); + OS << (long) MO.getImm(); + break; + case MachineOperand::MO_MachineBasicBlock: + OS << *MO.getMBB()->getSymbol(); break; case MachineOperand::MO_Register: OS << getRegisterName(MO.getReg()); break; + case MachineOperand::MO_FPImmediate: + APInt constFP = MO.getFPImm()->getValueAPF().bitcastToAPInt(); + bool isFloat = MO.getFPImm()->getType()->getTypeID() == Type::FloatTyID; + // Emit 0F for 32-bit floats and 0D for 64-bit doubles. + if (isFloat) { + OS << "0F"; + } + else { + OS << "0D"; + } + // Emit the encoded floating-point value. + if (constFP.getZExtValue() > 0) { + OS << constFP.toString(16, false); + } + else { + OS << "00000000"; + // If We have a double-precision zero, pad to 8-bytes. + if (!isFloat) { + OS << "00000000"; + } + } + break; } } @@ -265,13 +303,77 @@ void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) { decl += " "; } - // TODO: add types - decl += ".s32 "; - decl += gvsym->getName(); + if (PointerType::classof(gv->getType())) { + const PointerType* pointerTy = dyn_cast<const PointerType>(gv->getType()); + const Type* elementTy = pointerTy->getElementType(); + + decl += ".b8 "; + decl += gvsym->getName(); + decl += "["; + + if (elementTy->isArrayTy()) + { + assert(elementTy->isArrayTy() && "Only pointers to arrays are supported"); - if (ArrayType::classof(gv->getType()) || PointerType::classof(gv->getType())) - decl += "[]"; + const ArrayType* arrayTy = dyn_cast<const ArrayType>(elementTy); + elementTy = arrayTy->getElementType(); + + unsigned numElements = arrayTy->getNumElements(); + + while (elementTy->isArrayTy()) { + + arrayTy = dyn_cast<const ArrayType>(elementTy); + elementTy = arrayTy->getElementType(); + + numElements *= arrayTy->getNumElements(); + } + + // FIXME: isPrimitiveType() == false for i16? + assert(elementTy->isSingleValueType() && + "Non-primitive types are not handled"); + + // Compute the size of the array, in bytes. + uint64_t arraySize = (elementTy->getPrimitiveSizeInBits() >> 3) + * numElements; + + decl += utostr(arraySize); + } + + decl += "]"; + + // handle string constants (assume ConstantArray means string) + + if (gv->hasInitializer()) + { + Constant *C = gv->getInitializer(); + if (const ConstantArray *CA = dyn_cast<ConstantArray>(C)) + { + decl += " = {"; + + for (unsigned i = 0, e = C->getNumOperands(); i != e; ++i) + { + if (i > 0) decl += ","; + + decl += "0x" + utohexstr(cast<ConstantInt>(CA->getOperand(i))->getZExtValue()); + } + + decl += "}"; + } + } + } + else { + // Note: this is currently the fall-through case and most likely generates + // incorrect code. + decl += getTypeName(gv->getType()); + decl += " "; + + decl += gvsym->getName(); + + if (ArrayType::classof(gv->getType()) || + PointerType::classof(gv->getType())) + decl += "[]"; + } decl += ";"; @@ -313,16 +415,24 @@ void PTXAsmPrinter::EmitFunctionDeclaration() { if (!MFI->argRegEmpty()) { decl += " ("; if (isKernel) { - for (int i = 0, e = MFI->getNumArg(); i != e; ++i) { - if (i != 0) + unsigned cnt = 0; + for(PTXMachineFunctionInfo::reg_iterator + i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; + i != e; ++i) { + reg = *i; + assert(reg != PTX::NoRegister && "Not a valid register!"); + if (i != b) decl += ", "; - decl += ".param .s32 "; // TODO: add types + decl += ".param ."; + decl += getRegisterTypeName(reg); + decl += " "; decl += PARAM_PREFIX; - decl += utostr(i + 1); + decl += utostr(++cnt); } } else { for (PTXMachineFunctionInfo::reg_iterator - i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; i != e; ++i) { + i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; + i != e; ++i) { reg = *i; assert(reg != PTX::NoRegister && "Not a valid register!"); if (i != b) @@ -339,9 +449,29 @@ void PTXAsmPrinter::EmitFunctionDeclaration() { OutStreamer.EmitRawText(Twine(decl)); } +void PTXAsmPrinter:: +printPredicateOperand(const MachineInstr *MI, raw_ostream &O) { + int i = MI->findFirstPredOperandIdx(); + if (i == -1) + llvm_unreachable("missing predicate operand"); + + unsigned reg = MI->getOperand(i).getReg(); + int predOp = MI->getOperand(i+1).getImm(); + + DEBUG(dbgs() << "predicate: (" << reg << ", " << predOp << ")\n"); + + if (reg != PTX::NoRegister) { + O << '@'; + if (predOp == PTX::PRED_NEGATE) + O << '!'; + O << getRegisterName(reg); + } +} + #include "PTXGenAsmWriter.inc" // Force static initialization. extern "C" void LLVMInitializePTXAsmPrinter() { - RegisterAsmPrinter<PTXAsmPrinter> X(ThePTXTarget); + RegisterAsmPrinter<PTXAsmPrinter> X(ThePTX32Target); + RegisterAsmPrinter<PTXAsmPrinter> Y(ThePTX64Target); } |