Skip to content

Commit

Permalink
Create an helper function to get the MPI_Op from type (#2560)
Browse files Browse the repository at this point in the history
Internally we have some constants that mean:
1 - MPI_SUM
2 - MPI_MAX
3 - MPI_MIN
  • Loading branch information
Nicolas Cornu authored Sep 29, 2023
1 parent 6c48766 commit 7ad9347
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 67 deletions.
37 changes: 11 additions & 26 deletions src/coreneuron/mpi/lib/mpispike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,46 +289,31 @@ void nrnmpi_barrier_impl() {
MPI_Barrier(nrnmpi_comm);
}

double nrnmpi_dbl_allreduce_impl(double x, int type) {
double result;
static MPI_Op type2OP(int type) {
MPI_Op tt;
if (type == 1) {
tt = MPI_SUM;
return MPI_SUM;
} else if (type == 2) {
tt = MPI_MAX;
return MPI_MAX;
} else {
tt = MPI_MIN;
return MPI_MIN;
}
MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, tt, nrnmpi_comm);
}

double nrnmpi_dbl_allreduce_impl(double x, int type) {
double result;
MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, type2OP(type), nrnmpi_comm);
return result;
}

void nrnmpi_dbl_allreduce_vec_impl(double* src, double* dest, int cnt, int type) {
MPI_Op tt;
assert(src != dest);
if (type == 1) {
tt = MPI_SUM;
} else if (type == 2) {
tt = MPI_MAX;
} else {
tt = MPI_MIN;
}
MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, tt, nrnmpi_comm);
return;
MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, type2OP(type), nrnmpi_comm);
}

void nrnmpi_long_allreduce_vec_impl(long* src, long* dest, int cnt, int type) {
MPI_Op tt;
assert(src != dest);
if (type == 1) {
tt = MPI_SUM;
} else if (type == 2) {
tt = MPI_MAX;
} else {
tt = MPI_MIN;
}
MPI_Allreduce(src, dest, cnt, MPI_LONG, tt, nrnmpi_comm);
return;
MPI_Allreduce(src, dest, cnt, MPI_LONG, type2OP(type), nrnmpi_comm);
}

#if NRN_MULTISEND
Expand Down
58 changes: 17 additions & 41 deletions src/nrnmpi/mpispike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -617,82 +617,58 @@ void nrnmpi_barrier() {
MPI_Barrier(nrnmpi_comm);
}

double nrnmpi_dbl_allreduce(double x, int type) {
double result;
MPI_Op t;
if (nrnmpi_numprocs < 2) {
return x;
}
static MPI_Op type2OP(int type) {
if (type == 1) {
t = MPI_SUM;
return MPI_SUM;
} else if (type == 2) {
t = MPI_MAX;
return MPI_MAX;
} else {
t = MPI_MIN;
return MPI_MIN;
}
}

double nrnmpi_dbl_allreduce(double x, int type) {
if (nrnmpi_numprocs < 2) {
return x;
}
MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, t, nrnmpi_comm);
double result;
MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, type2OP(type), nrnmpi_comm);
return result;
}

extern "C" void nrnmpi_dbl_allreduce_vec(double* src, double* dest, int cnt, int type) {
MPI_Op t;
assert(src != dest);
if (nrnmpi_numprocs < 2) {
for (int i = 0; i < cnt; ++i) {
dest[i] = src[i];
}
return;
}
if (type == 1) {
t = MPI_SUM;
} else if (type == 2) {
t = MPI_MAX;
} else {
t = MPI_MIN;
}
MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, t, nrnmpi_comm);
MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, type2OP(type), nrnmpi_comm);
return;
}

void nrnmpi_longdbl_allreduce_vec(longdbl* src, longdbl* dest, int cnt, int type) {
int i;
MPI_Op t;
assert(src != dest);
if (nrnmpi_numprocs < 2) {
for (i = 0; i < cnt; ++i) {
for (int i = 0; i < cnt; ++i) {
dest[i] = src[i];
}
return;
}
if (type == 1) {
t = MPI_SUM;
} else if (type == 2) {
t = MPI_MAX;
} else {
t = MPI_MIN;
}
MPI_Allreduce(src, dest, cnt, MPI_LONG_DOUBLE, t, nrnmpi_comm);
MPI_Allreduce(src, dest, cnt, MPI_LONG_DOUBLE, type2OP(type), nrnmpi_comm);
return;
}

void nrnmpi_long_allreduce_vec(long* src, long* dest, int cnt, int type) {
int i;
MPI_Op t;
assert(src != dest);
if (nrnmpi_numprocs < 2) {
for (i = 0; i < cnt; ++i) {
for (int i = 0; i < cnt; ++i) {
dest[i] = src[i];
}
return;
}
if (type == 1) {
t = MPI_SUM;
} else if (type == 2) {
t = MPI_MAX;
} else {
t = MPI_MIN;
}
MPI_Allreduce(src, dest, cnt, MPI_LONG, t, nrnmpi_comm);
MPI_Allreduce(src, dest, cnt, MPI_LONG, type2OP(type), nrnmpi_comm);
return;
}

Expand Down

0 comments on commit 7ad9347

Please sign in to comment.