-
Notifications
You must be signed in to change notification settings - Fork 1
/
connect-omp-gpu.cpp
87 lines (75 loc) · 2.83 KB
/
connect-omp-gpu.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <sys/timeb.h>
#include <float.h>
#include "connect.h"
/*
 * Forward pass of a fully connected layer, offloaded to an OpenMP target
 * device.  Computes output += input * weights^T:
 *   output[i*N+j] += sum_k input[i*K+k] * weights[j*K+k]
 *
 * batch   number of input rows
 * K       input feature count (columns of input, columns of weights)
 * N       output feature count (rows of weights)
 * input   batch*K array, read-only
 * output  batch*N array, accumulated in place
 * weights N*K array, read-only
 * dev_id  OpenMP device to run on
 * num_dev accepted but unused here — presumably for multi-device dispatch
 *         by the caller; TODO confirm
 */
void connect(int batch, int K, int N, float *input, float *output, float *weights, int dev_id, int num_dev) {
    (void)num_dev;  /* not used in the single-device path */
    int HWC_in = batch*K;
    int HWC_out = batch*N;
    int HWC_weight = N*K;
    /* A combined loop construct must be followed directly by the loop nest;
       the original wrapped it in a compound statement, which is invalid
       OpenMP.  Loop-scoped indices make i/j/k private automatically. */
    #pragma omp target teams distribute parallel for collapse(2) \
        map(to: input[0:HWC_in], weights[0:HWC_weight]) \
        map(tofrom: output[0:HWC_out]) device(dev_id)
    for (int i = 0; i < batch; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                sum += input[i*K+k] * weights[j*K+k];
            }
            output[i*N+j] += sum;
        }
    }
}
/*
 * Backward pass of a fully connected layer on an OpenMP target device.
 * Two GEMMs:
 *   weight_updates[i*N+j] = sum_k delta_in[k*M+i] * input[k*N+j]
 *       (gradient w.r.t. weights: delta_in^T * input, an M x N result)
 *   delta_out[i*N+j]      = sum_k delta_in[i*M+k] * weights[k*N+j]
 *       (gradient w.r.t. input: delta_in * weights, a batch x N result)
 *
 * batch          number of rows in input / delta_in / delta_out
 * N              input feature count
 * M              output feature count
 * delta_in       batch*M incoming gradient, read-only
 * input          batch*N forward activations, read-only
 * weight_updates M*N, overwritten with the weight gradient
 * weights        M*N, read-only
 * delta_out      batch*N, overwritten with the input gradient
 * dev_id         OpenMP device to run on
 * num_dev        accepted but unused — TODO confirm intended use
 *
 * Fixes vs. the original:
 *  - Both reduction loops stepped k += 10, silently dropping 90% of the
 *    reduction terms (debugging leftover; the commented-out code shows the
 *    intended full sum).  Restored to unit stride.
 *  - The compound statement after the loop-associated directives was
 *    invalid OpenMP; the directives now precede the loop nests directly.
 *  - The first region used only "teams distribute"; "parallel for" added
 *    for thread-level parallelism, matching the other regions.
 */
void connect_backward(int batch, int N, int M, float *delta_in, float *input, float *weight_updates, float *weights, float *delta_out, int dev_id, int num_dev) {
    (void)num_dev;  /* not used in the single-device path */
    int HWC_in = batch*N;
    int HWC_delta_in = batch*M;
    int HWC_delta_out = batch*N;
    int HWC_weight = M*N;
    int HWC_weight_updates = M*N;
    /* gemm: weight_updates = delta_in^T * input */
    #pragma omp target teams distribute parallel for collapse(2) \
        map(to: input[0:HWC_in], delta_in[0:HWC_delta_in]) \
        map(tofrom: weight_updates[0:HWC_weight_updates]) device(dev_id)
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < batch; k++) {
                sum += delta_in[k*M+i] * input[k*N+j];
            }
            weight_updates[i*N+j] = sum;
        }
    }
    /* gemm2: delta_out = delta_in * weights */
    #pragma omp target teams distribute parallel for collapse(2) \
        map(to: delta_in[0:HWC_delta_in], weights[0:HWC_weight]) \
        map(tofrom: delta_out[0:HWC_delta_out]) device(dev_id)
    for (int i = 0; i < batch; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < M; k++) {
                sum += delta_in[i*M+k] * weights[k*N+j];
            }
            delta_out[i*N+j] = sum;
        }
    }
}
/*
 * Host-side SGD-style parameter update for the connected layer.
 * For each bias b:   b += p1*db; then db *= p3.
 * For each weight w: dw += p2*w (decay term); w += p1*dw; then dw *= p3.
 * p1 is the learning-rate factor, p2 the decay factor, p3 the momentum
 * scale applied to the update buffers.
 *
 * Each element is independent, so the original's five separate passes are
 * fused into one pass over the biases and one over the weights; the
 * per-element operation order is unchanged, so results are identical.
 */
void connect_update(int nbias, float *biases, float *bias_updates, int nweights, float *weights, float *weight_updates, float p1, float p2, float p3) {
    for (int idx = 0; idx < nbias; idx++) {
        biases[idx] += p1 * bias_updates[idx];   /* axpy */
        bias_updates[idx] *= p3;                 /* scale */
    }
    for (int idx = 0; idx < nweights; idx++) {
        weight_updates[idx] += p2 * weights[idx];        /* decay axpy */
        weights[idx] += p1 * weight_updates[idx];        /* apply update */
        weight_updates[idx] *= p3;                       /* scale */
    }
}