summaryrefslogtreecommitdiffstats
path: root/contrib/llvm/lib/Target/NVPTX/NVPTXLowerKernelArgs.cpp
blob: 24dcb122b94eb3678bc8b37d6a19246efe7bc51a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
//===-- NVPTXLowerKernelArgs.cpp - Lower kernel arguments -----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Pointer arguments to kernel functions need to be lowered specially.
//
// 1. Copy byval struct args to local memory. This is a preparation for handling
//    cases like
//
//    kernel void foo(struct A arg, ...)
//    {
//      struct A *p = &arg;
//      ...
//      ... = p->field1 ...  (there is no generic address for .param)
//      p->field2 = ...      (there is no write access to .param)
//    }
//
// 2. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXFavorNonGenericAddrSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// TODO: merge this pass with NVPTXFavorNonGenericAddrSpaces so that other
// passes don't cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//

#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "NVPTXTargetMachine.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"

using namespace llvm;

// Forward declaration of this pass's registration hook.
// NOTE(review): presumably defined by the INITIALIZE_PASS expansion further
// down in this file -- confirm.
namespace llvm {
void initializeNVPTXLowerKernelArgsPass(PassRegistry &);
} // namespace llvm

namespace {
// Function pass that lowers the pointer arguments of kernel functions: byval
// pointer arguments are handled by handleByValParam and all other pointer
// arguments by handlePointerParam. The target machine handed to the
// constructor may be null.
class NVPTXLowerKernelArgs : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  NVPTXLowerKernelArgs(const NVPTXTargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {}

  const char *getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }

private:
  // Pass entry point, invoked once per function.
  bool runOnFunction(Function &F) override;

  // Lowers a single byval pointer parameter.
  void handleByValParam(Argument *);

  // Lowers a single non-byval pointer parameter.
  void handlePointerParam(Argument *);

  // Target machine consulted by runOnFunction for the driver interface;
  // null when the pass is default-constructed.
  const NVPTXTargetMachine *TM;
};
} // namespace

// Definition of the pass identifier declared as a static member of
// NVPTXLowerKernelArgs.
char NVPTXLowerKernelArgs::ID = 1;

// Registers the pass under the name "nvptx-lower-kernel-args" with the
// description "Lower kernel arguments (NVPTX)".
// NOTE(review): by LLVM's INITIALIZE_PASS convention the two trailing `false`
// arguments should be the CFG-only and is-analysis flags -- confirm against
// the macro definition.
INITIALIZE_PASS(NVPTXLowerKernelArgs, "nvptx-lower-kernel-args",
                "Lower kernel arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x *byval %d),
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space in the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
// =============================================================================
void NVPTXLowerKernelArgs::handleByValParam(Argument *Arg) {
  PointerType *ArgPtrTy = dyn_cast<PointerType>(Arg->getType());
  assert(ArgPtrTy && "Expecting pointer type in handleByValParam");

  Function *Kernel = Arg->getParent();
  // Every new instruction is inserted in front of the instruction that is
  // currently first in the entry block, so they end up in creation order.
  Instruction *InsertPt = &Kernel->getEntryBlock().front();
  Type *PointeeTy = ArgPtrTy->getElementType();

  // Stack slot that takes the place of the byval argument. It inherits the
  // parameter's alignment because the users about to be redirected to it were
  // written against that alignment.
  // NOTE(review): the "+ 1" presumably skips attribute index 0 (the return
  // value) -- confirm against the attribute-index convention.
  AllocaInst *LocalCopy = new AllocaInst(PointeeTy, Arg->getName(), InsertPt);
  LocalCopy->setAlignment(Kernel->getParamAlignment(Arg->getArgNo() + 1));
  Arg->replaceAllUsesWith(LocalCopy);

  // Copy the incoming value out of param space into the stack slot. The cast
  // is created only after the replacement above so that its own operand
  // remains the original argument rather than the alloca.
  PointerType *ParamPtrTy = PointerType::get(PointeeTy, ADDRESS_SPACE_PARAM);
  Value *ParamPtr =
      new AddrSpaceCastInst(Arg, ParamPtrTy, Arg->getName(), InsertPt);
  LoadInst *ParamValue = new LoadInst(ParamPtr, Arg->getName(), InsertPt);
  new StoreInst(ParamValue, LocalCopy, InsertPt);
}

// Lowers a non-byval pointer argument by inserting two casts at the top of the
// entry block: one from the argument to the global address space, and one from
// that result back to the argument's original pointer type. Every other use of
// the argument is then redirected to the second cast, so that a later pass can
// fold the pair away and access the memory through the global address space.
//
// Arguments that already point into the global address space are left alone.
void NVPTXLowerKernelArgs::handlePointerParam(Argument *Arg) {
  assert(!Arg->hasByValAttr() &&
         "byval params should be handled by handleByValParam");

  // Nothing to lower if the pointer is already global; moreover, an
  // addrspacecast whose source and destination address spaces are identical
  // is not valid IR.
  if (Arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
    return;

  // Taking the address of front() mirrors handleByValParam and does not rely
  // on an implicit iterator-to-pointer conversion.
  Instruction *FirstInst = &(Arg->getParent()->getEntryBlock().front());
  Instruction *ArgInGlobal = new AddrSpaceCastInst(
      Arg, PointerType::get(Arg->getType()->getPointerElementType(),
                            ADDRESS_SPACE_GLOBAL),
      Arg->getName(), FirstInst);
  Value *ArgInGeneric = new AddrSpaceCastInst(ArgInGlobal, Arg->getType(),
                                              Arg->getName(), FirstInst);
  // Replace with ArgInGeneric all uses of Arg except the one in ArgInGlobal:
  // the blanket replacement also rewrites ArgInGlobal's operand, so that
  // operand is pointed back at Arg afterwards.
  Arg->replaceAllUsesWith(ArgInGeneric);
  ArgInGlobal->setOperand(0, Arg);
}


// =============================================================================
// Main function for this pass.
// =============================================================================
// Lowers the pointer arguments of kernel function F.
//
// Byval pointer arguments are always handed to handleByValParam. The remaining
// pointer arguments are handed to handlePointerParam only when a target
// machine is available and its driver interface is CUDA.
//
// Returns true if a lowering routine was invoked for at least one argument and
// false otherwise (non-kernel functions, or kernels with nothing to lower).
bool NVPTXLowerKernelArgs::runOnFunction(Function &F) {
  // Only kernels are lowered; see the comments at the top of this file.
  if (!isKernelFunction(F))
    return false;

  // The driver interface does not depend on the argument, so query it once.
  const bool LowerToGlobal = TM && TM->getDrvInterface() == NVPTX::CUDA;

  bool Changed = false;
  for (Argument &Arg : F.args()) {
    if (!Arg.getType()->isPointerTy())
      continue;

    // Conservatively report a change whenever a lowering routine has run.
    if (Arg.hasByValAttr()) {
      handleByValParam(&Arg);
      Changed = true;
    } else if (LowerToGlobal) {
      handlePointerParam(&Arg);
      Changed = true;
    }
  }
  return Changed;
}

// Factory for the kernel-argument lowering pass. TM is forwarded unchanged to
// the pass constructor and may be null. The returned pass is heap-allocated
// and handed to the caller.
FunctionPass *
llvm::createNVPTXLowerKernelArgsPass(const NVPTXTargetMachine *TM) {
  FunctionPass *LoweringPass = new NVPTXLowerKernelArgs(TM);
  return LoweringPass;
}
OpenPOWER on IntegriCloud