-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnew-forward.cuh
executable file
·301 lines (235 loc) · 10.6 KB
/
new-forward.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
#ifndef MXNET_OPERATOR_NEW_FORWARD_CUH_
#define MXNET_OPERATOR_NEW_FORWARD_CUH_
#include <mxnet/base.h>
// When defined, the `base` kernel reads filter weights from the __constant__
// deviceKernel array instead of from its `k` argument.
#define CONSTANT_KERNEL
// Side length of every convolution filter (filters are K x K).
static constexpr int K = 7;
// Hard-coded image count of 10000. NOTE(review): presumably the dataset's batch
// size; it is not derived from any tensor shape, so confirm before relying on it.
static constexpr int B = 10000;
namespace mxnet
{
namespace op
{
// Filter weights held in __constant__ memory; the host fills it with
// cudaMemcpyToSymbol before launching a kernel that reads it.
// Capacity: 12 * K * K floats (588 for K = 7), so a kernel reading weights from
// here requires M * C * K * K <= 12 * K * K.
__constant__ float deviceKernel[12 * K * K];
/*
 * base: reference (unoptimized) forward convolution.
 *
 * Each thread handles one whole image b = blockDim.x * blockIdx.x + threadIdx.x.
 * It loops over every output map m, output row h and output column w, and computes
 * a valid (no padding, stride 1) K x K convolution summed over all C input maps.
 * Threads with b >= B do nothing, so any 1-D launch with at least B threads works.
 *
 * Layouts (row-major, as fixed by the index macros below):
 *   y : B x M x H_out x W_out, with H_out = H - K + 1 and W_out = W - K + 1
 *   x : B x C x H x W
 *   k : M x C x K x K
 * When CONSTANT_KERNEL is defined, the weights are read from the __constant__ array
 * deviceKernel instead of k. In that case the caller must have copied them there
 * first, and M * C * K * K must fit in deviceKernel.
 */
__global__ void base(float *y, const float *x, const float *k, const int B, const int M, const int C, const int H, const int W)
{
    const int H_out = H - K + 1;
    const int W_out = W - K + 1;

// Output rows are W_out elements wide, so the row stride is W_out (not H_out;
// the two only coincide for square inputs).
#define y4d(i3, i2, i1, i0) y[(i3) * (M * H_out * W_out) + (i2) * (H_out * W_out) + (i1) * (W_out) + i0]
#define x4d(i3, i2, i1, i0) x[(i3) * (C * H * W) + (i2) * (H * W) + (i1) * (W) + i0]
#ifdef CONSTANT_KERNEL
#define k4d(i3, i2, i1, i0) deviceKernel[(i3) * (C * K * K) + (i2) * (K * K) + (i1) * (K) + i0]
    (void)k; // weights come from deviceKernel; silence the unused-parameter warning
#else
#define k4d(i3, i2, i1, i0) k[(i3) * (C * K * K) + (i2) * (K * K) + (i1) * (K) + i0]
#endif // CONSTANT_KERNEL

    const int b = blockDim.x * blockIdx.x + threadIdx.x;
    if (b < B) // one image per thread
    {
        for (int m = 0; m < M; m++)             // each output feature map
            for (int h = 0; h < H_out; h++)     // each output row
                for (int w = 0; w < W_out; w++) // each output column
                {
                    float Pvalue = 0.0f;
                    for (int c = 0; c < C; c++)     // sum over all input feature maps
                        for (int p = 0; p < K; p++) // K x K filter window
                            for (int q = 0; q < K; q++)
                                Pvalue += x4d(b, c, h + p, w + q) * k4d(m, c, p, q);
                    y4d(b, m, h, w) = Pvalue;
                }
    }
#undef y4d
#undef x4d
#undef k4d
}
// parallel_output: one thread per output pixel; filter weights are read from the
// __constant__ array deviceKernel.
//
// Expected launch (matches the indexing below):
//   gridDim  = (batch, M, T * T) where T = D_o / BLOCK_SIZE + 1 tiles per side
//   blockDim = (BLOCK_SIZE, BLOCK_SIZE, 1)
// blockIdx.x selects the image, blockIdx.y the output map, and blockIdx.z the
// BLOCK_SIZE x BLOCK_SIZE output tile (row-major over the T x T tiles).
// Layouts (row-major): y is batch x M x D_o x D_o and x is batch x CHANNELS x H x W.
// deviceKernel must already hold the M x CHANNELS x K x K weights.
template<const int BLOCK_SIZE, const int CHANNELS, const int D_o, const int M, const int H, const int W>
__global__ void parallel_output(float * __restrict__ y, const float * const __restrict__ x){
    constexpr int tilesPerSide = D_o / BLOCK_SIZE + 1;
    constexpr int inPlane = H * W;       // elements per input channel
    constexpr int outPlane = D_o * D_o;  // elements per output map
    constexpr int filterSize = K * K;    // elements per (map, channel) filter

    const int image = blockIdx.x;
    const int outMap = blockIdx.y;
    const int tileIdx = blockIdx.z;
    const int lane_x = threadIdx.x;
    const int lane_y = threadIdx.y;

    // Split the flattened tile index without a modulo: tileIdx = tileRow * tilesPerSide + tileCol.
    const int tileRow = tileIdx / tilesPerSide;
    const int tileCol = tileIdx - tileRow * tilesPerSide;
    const int outRow = tileRow * BLOCK_SIZE + lane_y;
    const int outCol = tileCol * BLOCK_SIZE + lane_x;

    // Threads of edge tiles that fall outside the D_o x D_o output have nothing to do.
    if (outRow >= D_o || outCol >= D_o) return;

    const int xImage = image * (CHANNELS * inPlane);
    const int wMap = outMap * (CHANNELS * filterSize);
    float sum = 0;
#pragma unroll
    for (int ch = 0; ch < CHANNELS; ++ch) {
#pragma unroll
        for (int dr = 0; dr < K; ++dr) {
#pragma unroll
            for (int dc = 0; dc < K; ++dc) {
                sum += x[xImage + ch * inPlane + (outRow + dr) * W + outCol + dc] * deviceKernel[wMap + ch * filterSize + dr * K + dc];
            }
        }
    }
    y[image * (M * outPlane) + outMap * outPlane + outRow * D_o + outCol] = sum;
}
// shared_convolution: tiled convolution that first stages an input tile (plus its
// K-1 halo) for every input channel in shared memory, then computes the outputs.
//
// Expected launch (matches the indexing below):
//   gridDim  = (B, M, T * T) with T = ceil(D_out / TILE_SIZE) output tiles per side
//   blockDim = (BLOCK_SIZE, BLOCK_SIZE, 1), with BLOCK_SIZE >= TILE_SIZE + K - 1
// All BLOCK_SIZE x BLOCK_SIZE threads load one input element per channel. Only the
// top-left TILE_SIZE x TILE_SIZE threads produce an output element.
// Preconditions:
//   - C <= MAX_CHANNELS. The static shared buffer holds
//     MAX_CHANNELS * BLOCK_SIZE^2 floats; MAX_CHANNELS defaults to 12.
//   - D_out + K - 1 <= H and D_out + K - 1 <= W.
// Layouts (row-major): y is B x M x D_out x D_out, x is B x C x H x W, and the
// weights are M x C x K x K. Weights are read from deviceKernel when CONSTANT is
// defined, and from k otherwise.
template<const int TILE_SIZE, const int BLOCK_SIZE, const int MAX_CHANNELS = 12>
__global__ void shared_convolution(float * __restrict__ y, const float * __restrict__ x, const float * __restrict__ k, const int B, const int M, const int C, const int H, const int W, const int D_out)
{
    static_assert(BLOCK_SIZE >= TILE_SIZE + K - 1, "BLOCK_SIZE must cover a TILE_SIZE tile plus the K-1 halo");
#define y4d(i3, i2, i1, i0) y[(i3) * (M * D_out * D_out) + (i2) * (D_out * D_out) + (i1) * (D_out) + i0]
#define x4d(i3, i2, i1, i0) x[(i3) * (C * H * W) + (i2) * (H * W) + (i1) * (W) + i0]
#ifdef CONSTANT
#define k4d(i3, i2, i1, i0) deviceKernel[(i3) * (C * K * K) + (i2) * (K * K) + (i1) * (K) + i0]
#else
#define k4d(i3, i2, i1, i0) k[(i3) * (C * K * K) + (i2) * (K * K) + (i1) * (K) + i0]
#endif
    // Staging buffer indexed [channel][row][col]. The channel must be the OUTER
    // dimension: it is written as Xds[ch][ty][tx] and read as Xds[ch][ty + r][tx + c].
    // (A [row][col][channel] declaration makes distinct (row, col) pairs alias.)
    __shared__ float Xds[MAX_CHANNELS][BLOCK_SIZE][BLOCK_SIZE];
    assert(C <= MAX_CHANNELS);

    // Integer ceil-div; avoids double-precision math inside the kernel.
    const int D_grid = (D_out + TILE_SIZE - 1) / TILE_SIZE;
    const int bx = blockIdx.x, by = blockIdx.y, bz = blockIdx.z;
    const int tx = threadIdx.x, ty = threadIdx.y;
    const int row_start = (bz / D_grid) * TILE_SIZE + ty;
    const int col_start = (bz % D_grid) * TILE_SIZE + tx;

    // Every thread (including halo threads with tx/ty >= TILE_SIZE) loads one element per channel.
    for (int ch = 0; ch < C; ch++) {
        Xds[ch][ty][tx] = (row_start < H && col_start < W) ? x4d(bx, ch, row_start, col_start) : 0.0f;
    }
    __syncthreads();

    if (tx < TILE_SIZE && ty < TILE_SIZE && row_start < D_out && col_start < D_out) {
        float Pvalue = 0;
        for (int ch = 0; ch < C; ch++)
#pragma unroll
            for (int r = 0; r < K; r++)
#pragma unroll
                for (int c = 0; c < K; c++)
                    Pvalue += Xds[ch][ty + r][tx + c] * k4d(by, ch, r, c);
        y4d(bx, by, row_start, col_start) = Pvalue;
    }
#undef y4d
#undef x4d
#undef k4d
}
// fused_unroll_gemm: convolution computed as a tiled GEMM, Y[b] = W * X_unrolled[b].
// The unrolled input matrix X_unrolled has (C*K*K) rows and (D_out*D_out) columns.
// It is never materialized: each element is gathered from x on the fly while the
// shared-memory tile is loaded.
//
// Expected launch (matches the indexing below):
//   gridDim  = (ceil(D_out*D_out / BLOCK_SIZE), ceil(M / BLOCK_SIZE), batch)
//   blockDim = (BLOCK_SIZE, BLOCK_SIZE, 1)
// Static shared memory: 2 * BLOCK_SIZE^2 floats.
// Layouts (row-major): y is batch x M x D_out x D_out, x is batch x C x H x W, and
// k is M x (C*K*K), i.e. M x C x K x K.
// Precondition: D_out + K - 1 <= H and D_out + K - 1 <= W.
template<const int BLOCK_SIZE, const int C, const int D_out, const int M, const int H, const int W>
__global__ void fused_unroll_gemm(float * __restrict__ y, const float * __restrict__ x, const float * __restrict__ k){
    constexpr const int k2 = K * K;
    constexpr const int numAColumns = C * k2;        // GEMM inner dimension (rows of X_unrolled)
    constexpr const int numBColumns = D_out * D_out; // output pixels per feature map
    // Number of BLOCK_SIZE-wide tiles along the inner dimension (ceil-div, no wasted all-zero tile).
    constexpr const int numTiles = (numAColumns + BLOCK_SIZE - 1) / BLOCK_SIZE;
#define x4d(i3, i2, i1, i0) x[(i3) * (C * H * W) + (i2) * (H * W) + (i1) * (W) + i0]
#define y4d(i3, i2, i1, i0) y[(i3) * (M * numBColumns) + (i2) * (numBColumns) + (i1) * (D_out) + i0]
    __shared__ float Mds[BLOCK_SIZE][BLOCK_SIZE]; // tile of the weight matrix
    __shared__ float Nds[BLOCK_SIZE][BLOCK_SIZE]; // tile of the implicit unrolled input
    const int bx = blockIdx.x, by = blockIdx.y, bz = blockIdx.z;
    const int tx = threadIdx.x, ty = threadIdx.y;
    const int out_idx = bx * BLOCK_SIZE + tx;   // output pixel (column of Y)
    const int mask_out = by * BLOCK_SIZE + ty;  // output feature map (row of Y)
    const int h_out = out_idx / D_out;
    const int w_out = out_idx - h_out * D_out;  // out_idx % D_out
    float Pvalue = 0;
#pragma unroll
    for (int i = 0; i < numTiles; ++i) {
        const int r = i * BLOCK_SIZE + ty;      // row of X_unrolled loaded by this thread
        const int c = i * BLOCK_SIZE + tx;      // column of the weight matrix loaded by this thread
        // Decode r into (input channel, filter row, filter column) without modulo.
        const int mask_in = r / k2;
        const int tmp = r - k2 * mask_in;       // r % k2
        const int tk = tmp / K;                 // filter row
        const int w_unrolled = w_out + (tmp - tk * K); // output column + filter column
        Nds[ty][tx] = (r < numAColumns && out_idx < numBColumns) ? x4d(bz, mask_in, h_out + tk, w_unrolled) : 0;
        Mds[ty][tx] = (mask_out < M && c < numAColumns) ? k[mask_out * numAColumns + c] : 0;
        __syncthreads();
#pragma unroll
        for (int j = 0; j < BLOCK_SIZE; ++j)    // named j so it does not shadow the weight pointer k
            Pvalue += Mds[ty][j] * Nds[j][tx];
        __syncthreads();
    }
    if (mask_out < M && out_idx < numBColumns)
        y4d(bz, mask_out, h_out, w_out) = Pvalue;
#undef y4d
#undef x4d
}
/*
 * forward<gpu, float>: GPU forward convolution; this function is called by new-inl.h.
 * For ECE408 only the float version of the operator is expected to be called, so
 * only float is specialized here.
 *
 * The kernels are compiled for two fixed layer shapes, selected by the number of
 * output feature maps M = y.shape_[1]:
 *   M == 12: C = 1,  H = W = 72 -> parallel_output, weights staged in deviceKernel
 *   else   : M = 24, C = 12, H = W = 33 -> fused_unroll_gemm, weights read from w
 *
 * The runtime input shape is checked against these assumptions with CHECK_EQ. The
 * batch size is taken from x.shape_[0] rather than a compile-time constant, so a
 * smaller batch never launches blocks past the ends of x and y.
 * `base` and `shared_convolution` are alternative kernels that are not dispatched here.
 */
template <>
void forward<gpu, float>(mshadow::Tensor<gpu, 4, float> &y, const mshadow::Tensor<gpu, 4, float> &x, const mshadow::Tensor<gpu, 4, float> &w)
{
    const int batch = x.shape_[0];
    const int M = y.shape_[1]; // num_filter
    const int C_in = x.shape_[1];
    const int H_in = x.shape_[2];
    const int W_in = x.shape_[3];

    // A zero-sized grid is an invalid launch configuration, and there is nothing to compute.
    if (batch == 0) return;

    cudaError_t err = cudaSuccess;
    if (M == 12) {
        constexpr const int BLOCK_SIZE = 24;
        constexpr const int C = 1;
        constexpr const int H = 72;
        constexpr const int W = 72;
        constexpr const int M_FIXED = 12;
        constexpr const int D_o = H - K + 1; // 66
        // Must match parallel_output's decomposition of blockIdx.z into tiles.
        constexpr const int tiles_per_side = (D_o / BLOCK_SIZE) + 1;
        constexpr const size_t weight_bytes = M_FIXED * C * K * K * sizeof(float);
        static_assert(weight_bytes <= sizeof(deviceKernel), "layer weights must fit in deviceKernel");

        CHECK_EQ(C_in, C) << "parallel_output is compiled for a fixed input channel count";
        CHECK_EQ(H_in, H) << "parallel_output is compiled for a fixed input height";
        CHECK_EQ(W_in, W) << "parallel_output is compiled for a fixed input width";

        // w.dptr_ is device memory, hence the device-to-device copy into constant memory.
        err = cudaMemcpyToSymbol(deviceKernel, w.dptr_, weight_bytes, 0, cudaMemcpyDeviceToDevice);
        CHECK_EQ(err, cudaSuccess) << "cudaMemcpyToSymbol failed: " << cudaGetErrorString(err);

        dim3 grid(batch, M_FIXED, tiles_per_side * tiles_per_side); // images x output maps x tiles per image
        dim3 block(BLOCK_SIZE, BLOCK_SIZE, 1);
        parallel_output<BLOCK_SIZE, C, D_o, M_FIXED, H, W><<<grid, block>>>(y.dptr_, x.dptr_);
    }
    else {
        constexpr const int BLOCK_SIZE = 24;
        constexpr const int C = 12;
        constexpr const int H = 33;
        constexpr const int W = 33;
        constexpr const int M_FIXED = 24;
        constexpr const int D_o = H - K + 1; // 27

        CHECK_EQ(M, M_FIXED) << "fused_unroll_gemm is compiled for a fixed output map count";
        CHECK_EQ(C_in, C) << "fused_unroll_gemm is compiled for a fixed input channel count";
        CHECK_EQ(H_in, H) << "fused_unroll_gemm is compiled for a fixed input height";
        CHECK_EQ(W_in, W) << "fused_unroll_gemm is compiled for a fixed input width";
        // The batch index is carried in gridDim.z, whose hardware limit is 65535.
        CHECK_LE(batch, 65535) << "batch too large for gridDim.z";

        dim3 grid((D_o * D_o + BLOCK_SIZE - 1) / BLOCK_SIZE,   // output-pixel tiles
                  (M_FIXED + BLOCK_SIZE - 1) / BLOCK_SIZE,     // output-map tiles
                  batch);                                      // images
        dim3 block(BLOCK_SIZE, BLOCK_SIZE, 1);
        fused_unroll_gemm<BLOCK_SIZE, C, D_o, M_FIXED, H, W><<<grid, block>>>(y.dptr_, x.dptr_, w.dptr_);
    }

    // Kernel launches do not return errors directly; surface bad launch configurations here.
    err = cudaGetLastError();
    CHECK_EQ(err, cudaSuccess) << "kernel launch failed: " << cudaGetErrorString(err);
}
// Fallback for data types other than float. No implementation exists, so fail loudly.
// LOG(FATAL) is used rather than assert() because assert() compiles to nothing under
// NDEBUG, which would silently leave y unwritten in release builds.
template <typename gpu, typename DType>
void forward(mshadow::Tensor<gpu, 4, DType> &y, const mshadow::Tensor<gpu, 4, DType> &x, const mshadow::Tensor<gpu, 4, DType> &w)
{
    LOG(FATAL) << "No forward implementation for other datatypes";
}
}
}
#endif