diff options
Diffstat (limited to 'lib/Target/PTX/PTXISelLowering.cpp')
-rw-r--r-- | lib/Target/PTX/PTXISelLowering.cpp | 168 |
1 files changed, 134 insertions, 34 deletions
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index 3307d91..ef4455b 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#include "PTX.h" #include "PTXISelLowering.h" +#include "PTX.h" #include "PTXMachineFunctionInfo.h" #include "PTXRegisterInfo.h" #include "PTXSubtarget.h" @@ -20,6 +20,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" @@ -46,6 +47,11 @@ PTXTargetLowering::PTXTargetLowering(TargetMachine &TM) setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct? setMinFunctionAlignment(2); + // Let LLVM use loads/stores for all mem* operations + maxStoresPerMemcpy = 4096; + maxStoresPerMemmove = 4096; + maxStoresPerMemset = 4096; + //////////////////////////////////// /////////// Expansion ////////////// //////////////////////////////////// @@ -91,7 +97,8 @@ PTXTargetLowering::PTXTargetLowering(TargetMachine &TM) // customise setcc to use bitwise logic if possible - setOperationAction(ISD::SETCC, MVT::i1, Custom); + //setOperationAction(ISD::SETCC, MVT::i1, Custom); + setOperationAction(ISD::SETCC, MVT::i1, Legal); // customize translation of memory addresses @@ -150,18 +157,27 @@ SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { SDValue Op1 = Op.getOperand(1); SDValue Op2 = Op.getOperand(2); DebugLoc dl = Op.getDebugLoc(); - ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); + //ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); // Look for X == 0, X == 1, X != 0, or X != 1 // We can simplify these to bitwise logic - if (Op1.getOpcode() == ISD::Constant && - (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 || - cast<ConstantSDNode>(Op1)->isNullValue()) && - (CC == ISD::SETEQ || CC == ISD::SETNE)) { + //if (Op1.getOpcode() == ISD::Constant && + // (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 || + // cast<ConstantSDNode>(Op1)->isNullValue()) && + // (CC == ISD::SETEQ || CC == ISD::SETNE)) { + // + // return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1); + //} - return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1); - } + //ConstantSDNode* COp1 = cast<ConstantSDNode>(Op1); + //if(COp1 && COp1->getZExtValue() == 1) { + // if(CC == ISD::SETNE) { + // return DAG.getNode(PTX::XORripreds, dl, MVT::i1, Op0); + // } + //} + + llvm_unreachable("setcc was not matched by a pattern!"); return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2); } @@ -205,7 +221,6 @@ SDValue PTXTargetLowering:: switch (CallConv) { default: llvm_unreachable("Unsupported calling convention"); - break; case CallingConv::PTX_Kernel: MFI->setKernel(true); break; @@ -235,8 +250,25 @@ SDValue PTXTargetLowering:: } else { for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - EVT RegVT = Ins[i].VT; - TargetRegisterClass* TRC = getRegClassFor(RegVT); + EVT RegVT = Ins[i].VT; + const TargetRegisterClass* TRC = getRegClassFor(RegVT); + unsigned RegType; + + // Determine which register class we need + if (RegVT == MVT::i1) + RegType = PTXRegisterType::Pred; + else if (RegVT == MVT::i16) + RegType = PTXRegisterType::B16; + else if (RegVT == MVT::i32) + RegType = PTXRegisterType::B32; + else if (RegVT == MVT::i64) + RegType = PTXRegisterType::B64; + else if (RegVT == MVT::f32) + RegType = PTXRegisterType::F32; + else if (RegVT == MVT::f64) + RegType = PTXRegisterType::F64; + else + llvm_unreachable("Unknown parameter type"); // Use a unique index in the instruction to prevent instruction folding. // Yes, this is a hack. @@ -247,7 +279,7 @@ SDValue PTXTargetLowering:: InVals.push_back(ArgValue); - MFI->addArgReg(Reg); + MFI->addRegister(Reg, RegType, PTXRegisterSpace::Argument); } } @@ -297,26 +329,33 @@ SDValue PTXTargetLowering:: } else { for (unsigned i = 0, e = Outs.size(); i != e; ++i) { EVT RegVT = Outs[i].VT; - TargetRegisterClass* TRC = 0; + const TargetRegisterClass* TRC; + unsigned RegType; // Determine which register class we need if (RegVT == MVT::i1) { TRC = PTX::RegPredRegisterClass; + RegType = PTXRegisterType::Pred; } else if (RegVT == MVT::i16) { TRC = PTX::RegI16RegisterClass; + RegType = PTXRegisterType::B16; } else if (RegVT == MVT::i32) { TRC = PTX::RegI32RegisterClass; + RegType = PTXRegisterType::B32; } else if (RegVT == MVT::i64) { TRC = PTX::RegI64RegisterClass; + RegType = PTXRegisterType::B64; } else if (RegVT == MVT::f32) { TRC = PTX::RegF32RegisterClass; + RegType = PTXRegisterType::F32; } else if (RegVT == MVT::f64) { TRC = PTX::RegF64RegisterClass; + RegType = PTXRegisterType::F64; } else { llvm_unreachable("Unknown parameter type"); @@ -329,7 +368,7 @@ SDValue PTXTargetLowering:: Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg); - MFI->addRetReg(Reg); + MFI->addRegister(Reg, RegType, PTXRegisterSpace::Return); } } @@ -344,7 +383,7 @@ SDValue PTXTargetLowering:: SDValue PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, CallingConv::ID CallConv, bool isVarArg, - bool &isTailCall, + bool doesNotRet, bool &isTailCall, const SmallVectorImpl<ISD::OutputArg> &Outs, const SmallVectorImpl<SDValue> &OutVals, const SmallVectorImpl<ISD::InputArg> &Ins, @@ -352,38 +391,99 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, SmallVectorImpl<SDValue> &InVals) const { MachineFunction& MF = DAG.getMachineFunction(); - PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); - PTXParamManager &PM = MFI->getParamManager(); + PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>(); + PTXParamManager &PM = PTXMFI->getParamManager(); + MachineFrameInfo *MFI = MF.getFrameInfo(); assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() && "Calls are not handled for the target device"); + // Identify the callee function + const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); + const Function *function = cast<Function>(GV); + + // allow non-device calls only for printf + bool isPrintf = function->getName() == "printf" || function->getName() == "puts"; + + assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) && + "PTX function calls must be to PTX device functions"); + + unsigned outSize = isPrintf ? 2 : Outs.size(); + std::vector<SDValue> Ops; // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs] - Ops.resize(Outs.size() + Ins.size() + 4); + Ops.resize(outSize + Ins.size() + 4); Ops[0] = Chain; // Identify the callee function - const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); - assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device && - "PTX function calls must be to PTX device functions"); Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); Ops[Ins.size()+2] = Callee; - // Generate STORE_PARAM nodes for each function argument. In PTX, function - // arguments are explicitly stored into .param variables and passed as - // arguments. There is no register/stack-based calling convention in PTX. - Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32); - for (unsigned i = 0; i != OutVals.size(); ++i) { - unsigned Size = OutVals[i].getValueType().getSizeInBits(); - unsigned Param = PM.addLocalParam(Size); - const std::string &ParamName = PM.getParamName(Param); - SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), - MVT::Other); + // #Outs + Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32); + + if (isPrintf) { + // first argument is the address of the global string variable in memory + unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits()); + SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(), + MVT::Other); Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, - ParamValue, OutVals[i]); - Ops[i+Ins.size()+4] = ParamValue; + ParamValue0, OutVals[0]); + Ops[Ins.size()+4] = ParamValue0; + + // alignment is the maximum size of all the arguments + unsigned alignment = 0; + for (unsigned i = 1; i < OutVals.size(); ++i) { + alignment = std::max(alignment, + OutVals[i].getValueType().getSizeInBits()); + } + + // size is the alignment multiplied by the number of arguments + unsigned size = alignment * (OutVals.size() - 1); + + // second argument is the address of the stack object (unless no arguments) + unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits()); + SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(), + MVT::Other); + Ops[Ins.size()+5] = ParamValue1; + + if (size > 0) + { + // create a local stack object to store the arguments + unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false); + SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy()); + + // store each of the arguments to the stack in turn + for (unsigned int i = 1; i != OutVals.size(); i++) { + SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy())); + Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr, + MachinePointerInfo(), + false, false, 0); + } + + // copy the address of the local frame index to get the address in non-local space + SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex); + + // store this address in the second argument + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr); + } + } + else + { + // Generate STORE_PARAM nodes for each function argument. In PTX, function + // arguments are explicitly stored into .param variables and passed as + // arguments. There is no register/stack-based calling convention in PTX. + for (unsigned i = 0; i != OutVals.size(); ++i) { + unsigned Size = OutVals[i].getValueType().getSizeInBits(); + unsigned Param = PM.addLocalParam(Size); + const std::string &ParamName = PM.getParamName(Param); + SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), + MVT::Other); + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, + ParamValue, OutVals[i]); + Ops[i+Ins.size()+4] = ParamValue; + } } std::vector<SDValue> InParams; |