Skip to content

Commit

Permalink
update the test case
Browse files Browse the repository at this point in the history
  • Loading branch information
Sadr Mohsen committed Sep 17, 2024
1 parent 45f794b commit ed0633c
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions test/random/TestInverseTransformSamplingSpecificRange.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Testing the update functionality of the inverse transform sampling method for Normal Distribution
// Testing the customized range policy of the inverse transform sampling method for Normal Distribution
// Example:
// srun ./TestInverseTransformSamplingUpdateBounds --overallocate 2.0 --info 10
// srun ./TestInverseTransformSamplingSpecificRange --overallocate 2.0 --info 10

#include <Kokkos_MathematicalConstants.hpp>
#include <Kokkos_MathematicalFunctions.hpp>
Expand All @@ -20,8 +20,6 @@ const int Dim = 2;

using view_type = typename ippl::detail::ViewType<ippl::Vector<double, Dim>, 1>::view_type;

using Mesh_t = ippl::UniformCartesian<double, Dim>;

using size_type = ippl::detail::size_type;

using GeneratorPool = typename Kokkos::Random_XorShift64_Pool<>;
Expand Down Expand Up @@ -51,15 +49,28 @@ KOKKOS_FUNCTION void get_norm_dist_cent_moms(double stdev, const int P, double *

void get_moments_from_samples(view_type position, int d, int start, int end, const int P, double *moms_p){
int d_ = d;
double temp = 0.0;
double temp = 0.0, mean = 0.0;
double locmoms[P];
int gNpart=0, locNpart = end-start;

for (int p = 0; p < P; p++) {
moms_p[p] = 0.0;
}

Kokkos::parallel_reduce("moments", Kokkos::RangePolicy<>(start, end),
KOKKOS_LAMBDA(const int i, double& valL) {
double myVal = position(i)[d_];
valL += myVal;
}, Kokkos::Sum<double>(temp));
Kokkos::fence();

double mean = temp / (end-start);
moms_p[0] = mean;
locmoms[0] = temp;

MPI_Allreduce(&temp, &mean, 1, MPI_DOUBLE, MPI_SUM, ippl::Comm->getCommunicator());
MPI_Allreduce(&locNpart, &gNpart, 1, MPI_INT, MPI_SUM, ippl::Comm->getCommunicator());
ippl::Comm->barrier();

mean = mean / gNpart;

for (int p = 1; p < P; p++) {
temp = 0.0;
Expand All @@ -68,8 +79,18 @@ void get_moments_from_samples(view_type position, int d, int start, int end, con
double myVal = pow(position(i)[d_] - mean, p + 1);
valL += myVal;
}, Kokkos::Sum<double>(temp));
moms_p[p] = temp / (end-start);
Kokkos::fence();

locmoms[p] = temp;
}

MPI_Allreduce(locmoms, moms_p, P, MPI_DOUBLE, MPI_SUM, ippl::Comm->getCommunicator());
ippl::Comm->barrier();

for (int p = 0; p < P; p++) {
moms_p[p] = moms_p[p] / gNpart;
}

}

void write_error_in_moments(double *moms_p, double *moms_ref_p, int P){
Expand All @@ -89,7 +110,7 @@ int main(int argc, char* argv[]) {
Inform m("test ITS normal");
// set up ITS as other examples to sample position
ippl::Vector<int, 2> nr = {100, 100};
size_type ntotal = 100000;
size_type ntotal = 1000000;

ippl::NDIndex<2> domain;
for (unsigned i = 0; i < Dim; i++) {
Expand All @@ -101,17 +122,6 @@ int main(int argc, char* argv[]) {

ippl::Vector<double, Dim> rmin = -3.;
ippl::Vector<double, Dim> rmax = 3.;
ippl::Vector<double, Dim> length = rmax - rmin;
ippl::Vector<double, Dim> hr = length / nr;
ippl::Vector<double, Dim> origin = rmin;

const bool isAllPeriodic = true;

Mesh_t mesh(domain, hr, origin);

ippl::FieldLayout<Dim> fl(MPI_COMM_WORLD, domain, isParallel, isAllPeriodic);

ippl::detail::RegionLayout<double, Dim, Mesh_t> rlayout(fl, mesh);

int seed = 42;

Expand All @@ -127,7 +137,7 @@ int main(int argc, char* argv[]) {
using sampling_t = ippl::random::InverseTransformSampling<double, Dim, Kokkos::DefaultExecutionSpace, Dist_t>;

Dist_t dist(par);
sampling_t sampling(dist, rmax, rmin, rlayout, ntotal);
sampling_t sampling(dist, rmax, rmin, rmax, rmin, ntotal);
size_type nlocal = sampling.getLocalSamplesNum();
view_type position("position", nlocal);

Expand Down Expand Up @@ -159,7 +169,7 @@ int main(int argc, char* argv[]) {

const double par2[4] = {mu1, sd1, mu2, sd2};
Dist_t dist2(par2);
sampling_t sampling2(dist2, rmax, rmin, rlayout, ntotal);
sampling_t sampling2(dist2, rmax, rmin, rmax, rmin, ntotal);

startIndex = floor(nlocal/2);
endIndex = nlocal;
Expand Down

0 comments on commit ed0633c

Please sign in to comment.