Skip to content

Commit

Permalink
Added omp pragmas in the SWIG layer
Browse files (browse the repository at this point in the history)
This closes #89
  • Loading branch information
tnipen committed Jan 30, 2022
1 parent a15800b commit a41be89
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion swig/vector.i
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ namespace std {
int s0 = array_size($input, 0);
int s1 = array_size($input, 1);
std::vector<std::vector<DTYPE> > temp = std::vector<std::vector<DTYPE> >(s0);
#pragma omp parallel for
for(int i = 0; i < s0; i++) {
temp[i] = std::vector<DTYPE>(arg + i*s1, arg + i*s1 + s1);
}
Expand Down Expand Up @@ -250,6 +251,7 @@ namespace std {
int s0 = array_size($input, 0);
int s1 = array_size($input, 1);
temp = std::vector<std::vector<DTYPE> >(s0);
#pragma omp parallel for
for(int i = 0; i < s0; i++) {
temp[i] = std::vector<DTYPE>(arg + i*s1, arg + i*s1 + s1);
}
Expand Down Expand Up @@ -288,6 +290,7 @@ namespace std {
int s0 = array_size($input, 0);
int s1 = array_size($input, 1);
temp = std::vector<std::vector<DTYPE> >(s0);
#pragma omp parallel for
for(int i = 0; i < s0; i++) {
temp[i] = std::vector<DTYPE>(arg + i*s1, arg + i*s1 + s1);
}
Expand Down Expand Up @@ -320,6 +323,7 @@ namespace std {
s1 = temp[0].size();
npy_intp dims[2] = {s0, s1};
py_obj = PyArray_ZEROS(2, dims, NPY_DTYPE, 0);
#pragma omp parallel for collapse(2)
for(long i = 0; i < s0; i++) {
for(long j = 0; j < s1; j++) {
DTYPE* ref = (DTYPE*) PyArray_GETPTR2((PyArrayObject*) py_obj, i, j);
Expand All @@ -338,6 +342,7 @@ namespace std {
s1 = temp[0].size();
npy_intp dims[2] = {s0, s1};
$result = PyArray_ZEROS(2, dims, NPY_DTYPE, 0);
#pragma omp parallel for collapse(2)
for(long i = 0; i < s0; i++) {
for(long j = 0; j < s1; j++) {
DTYPE* ref = (DTYPE*) PyArray_GETPTR2((PyArrayObject*) $result, i, j);
Expand Down Expand Up @@ -383,6 +388,9 @@ namespace std {
std::vector<std::vector<std::vector<DTYPE> > > temp = std::vector<std::vector<std::vector<DTYPE> > >(s0);
for(int i = 0; i < s0; i++) {
temp[i].resize(s1);
}
#pragma omp parallel for collapse(2)
for(int i = 0; i < s0; i++) {
for(int j = 0; j < s1; j++) {
temp[i][j] = std::vector<DTYPE>(arg + i * s1 * s2 + j * s2, arg + i * s1 * s2 + (j + 1) * s2);
}
Expand Down Expand Up @@ -424,6 +432,9 @@ namespace std {
temp = std::vector<std::vector<std::vector<DTYPE> > >(s0);
for(int i = 0; i < s0; i++) {
temp[i].resize(s1);
}
#pragma omp parallel for collapse(2)
for(int i = 0; i < s0; i++) {
for(int j = 0; j < s1; j++) {
temp[i][j] = std::vector<DTYPE>(arg + i * s1 * s2 + j * s2, arg + i * s1 * s2 + (j + 1) * s2);
}
Expand All @@ -442,7 +453,7 @@ namespace std {

/* Same as the const version above */
%typemap(in) std::vector<std::vector<std::vector<DTYPE> > > & (std::vector<std::vector<std::vector<DTYPE> > >*ptr=NULL, std::vector<std::vector<std::vector<DTYPE> > > temp, PyArrayObject* py_array=NULL, PyObject* py_obj=NULL, PyObject* py_obj0=NULL){
- PRINT_DEBUG("Typemap(in) const std::vector<std::vector<std::vector<DTYPE> > > &");
+ PRINT_DEBUG("Typemap(in) std::vector<std::vector<std::vector<DTYPE> > > &");
if(is_array($input)) {
int num_dims = array_numdims($input);
if(num_dims != 3)
Expand All @@ -466,6 +477,9 @@ namespace std {
temp = std::vector<std::vector<std::vector<DTYPE> > >(s0);
for(int i = 0; i < s0; i++) {
temp[i].resize(s1);
}
#pragma omp parallel for collapse(2)
for(int i = 0; i < s0; i++) {
for(int j = 0; j < s1; j++) {
temp[i][j] = std::vector<DTYPE>(arg + i * s1 * s2 + j * s2, arg + i * s1 * s2 + (j + 1) * s2);
}
Expand Down Expand Up @@ -501,6 +515,7 @@ namespace std {
s2 = temp[0][0].size();
npy_intp dims[3] = {s0, s1, s2};
py_obj = PyArray_ZEROS(3, dims, NPY_DTYPE, 0);
#pragma omp parallel for collapse(3)
for(long i = 0; i < s0; i++) {
for(long j = 0; j < s1; j++) {
for(long k = 0; k < s2; k++) {
Expand All @@ -524,6 +539,7 @@ namespace std {
s2 = temp[0][0].size();
npy_intp dims[3] = {s0, s1, s2};
$result = PyArray_ZEROS(3, dims, NPY_DTYPE, 0);
#pragma omp parallel for collapse(3)
for(long i = 0; i < s0; i++) {
for(long j = 0; j < s1; j++) {
for(long k = 0; k < s2; k++) {
Expand Down

0 comments on commit a41be89

Please sign in to comment.