diff --git a/src/coreneuron/mpi/lib/mpispike.cpp b/src/coreneuron/mpi/lib/mpispike.cpp index bbe81ac6c2..2e62ab578a 100644 --- a/src/coreneuron/mpi/lib/mpispike.cpp +++ b/src/coreneuron/mpi/lib/mpispike.cpp @@ -289,46 +289,31 @@ void nrnmpi_barrier_impl() { MPI_Barrier(nrnmpi_comm); } -double nrnmpi_dbl_allreduce_impl(double x, int type) { - double result; +static MPI_Op type2OP(int type) { MPI_Op tt; if (type == 1) { - tt = MPI_SUM; + return MPI_SUM; } else if (type == 2) { - tt = MPI_MAX; + return MPI_MAX; } else { - tt = MPI_MIN; + return MPI_MIN; } - MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, tt, nrnmpi_comm); +} + +double nrnmpi_dbl_allreduce_impl(double x, int type) { + double result; + MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, type2OP(type), nrnmpi_comm); return result; } void nrnmpi_dbl_allreduce_vec_impl(double* src, double* dest, int cnt, int type) { - MPI_Op tt; assert(src != dest); - if (type == 1) { - tt = MPI_SUM; - } else if (type == 2) { - tt = MPI_MAX; - } else { - tt = MPI_MIN; - } - MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, tt, nrnmpi_comm); - return; + MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, type2OP(type), nrnmpi_comm); } void nrnmpi_long_allreduce_vec_impl(long* src, long* dest, int cnt, int type) { - MPI_Op tt; assert(src != dest); - if (type == 1) { - tt = MPI_SUM; - } else if (type == 2) { - tt = MPI_MAX; - } else { - tt = MPI_MIN; - } - MPI_Allreduce(src, dest, cnt, MPI_LONG, tt, nrnmpi_comm); - return; + MPI_Allreduce(src, dest, cnt, MPI_LONG, type2OP(type), nrnmpi_comm); } #if NRN_MULTISEND diff --git a/src/nrnmpi/mpispike.cpp b/src/nrnmpi/mpispike.cpp index 53a86e6549..c84e2d0ae1 100644 --- a/src/nrnmpi/mpispike.cpp +++ b/src/nrnmpi/mpispike.cpp @@ -617,25 +617,26 @@ void nrnmpi_barrier() { MPI_Barrier(nrnmpi_comm); } -double nrnmpi_dbl_allreduce(double x, int type) { - double result; - MPI_Op t; - if (nrnmpi_numprocs < 2) { - return x; - } +static MPI_Op type2OP(int type) { if (type == 1) { - t = MPI_SUM; + return MPI_SUM; } else if (type == 2) { - t = MPI_MAX; + return MPI_MAX; } else { - t = MPI_MIN; + return MPI_MIN; + } +} + +double nrnmpi_dbl_allreduce(double x, int type) { + if (nrnmpi_numprocs < 2) { + return x; } - MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, t, nrnmpi_comm); + double result; + MPI_Allreduce(&x, &result, 1, MPI_DOUBLE, type2OP(type), nrnmpi_comm); return result; } extern "C" void nrnmpi_dbl_allreduce_vec(double* src, double* dest, int cnt, int type) { - MPI_Op t; assert(src != dest); if (nrnmpi_numprocs < 2) { for (int i = 0; i < cnt; ++i) { @@ -643,56 +644,31 @@ extern "C" void nrnmpi_dbl_allreduce_vec(double* src, double* dest, int cnt, int } return; } - if (type == 1) { - t = MPI_SUM; - } else if (type == 2) { - t = MPI_MAX; - } else { - t = MPI_MIN; - } - MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, t, nrnmpi_comm); + MPI_Allreduce(src, dest, cnt, MPI_DOUBLE, type2OP(type), nrnmpi_comm); return; } void nrnmpi_longdbl_allreduce_vec(longdbl* src, longdbl* dest, int cnt, int type) { - int i; - MPI_Op t; assert(src != dest); if (nrnmpi_numprocs < 2) { - for (i = 0; i < cnt; ++i) { + for (int i = 0; i < cnt; ++i) { dest[i] = src[i]; } return; } - if (type == 1) { - t = MPI_SUM; - } else if (type == 2) { - t = MPI_MAX; - } else { - t = MPI_MIN; - } - MPI_Allreduce(src, dest, cnt, MPI_LONG_DOUBLE, t, nrnmpi_comm); + MPI_Allreduce(src, dest, cnt, MPI_LONG_DOUBLE, type2OP(type), nrnmpi_comm); return; } void nrnmpi_long_allreduce_vec(long* src, long* dest, int cnt, int type) { - int i; - MPI_Op t; assert(src != dest); if (nrnmpi_numprocs < 2) { - for (i = 0; i < cnt; ++i) { + for (int i = 0; i < cnt; ++i) { dest[i] = src[i]; } return; } - if (type == 1) { - t = MPI_SUM; - } else if (type == 2) { - t = MPI_MAX; - } else { - t = MPI_MIN; - } - MPI_Allreduce(src, dest, cnt, MPI_LONG, t, nrnmpi_comm); + MPI_Allreduce(src, dest, cnt, MPI_LONG, type2OP(type), nrnmpi_comm); return; }