From 5de0a4b8b21d48cf9c34a20b9a8d4f0ab9b6c678 Mon Sep 17 00:00:00 2001 From: Benjamin Menetrier <30638301+benjaminmenetrier@users.noreply.github.com> Date: Mon, 16 Sep 2024 09:45:51 +0200 Subject: [PATCH 1/2] Add regional interpolation (#215) --- src/atlas/CMakeLists.txt | 2 + .../detail/StructuredColumns_setup.cc | 21 +- .../interpolation/method/MethodFactory.cc | 2 + .../method/structured/RegionalLinear2D.cc | 568 ++++++++++++++++++ .../method/structured/RegionalLinear2D.h | 77 +++ src/tests/interpolation/CMakeLists.txt | 8 + ...test_interpolation_k_nearest_neighbours.cc | 4 - ...est_interpolation_structured2D_regional.cc | 217 +++++++ 8 files changed, 888 insertions(+), 11 deletions(-) create mode 100644 src/atlas/interpolation/method/structured/RegionalLinear2D.cc create mode 100644 src/atlas/interpolation/method/structured/RegionalLinear2D.h create mode 100644 src/tests/interpolation/test_interpolation_structured2D_regional.cc diff --git a/src/atlas/CMakeLists.txt b/src/atlas/CMakeLists.txt index 5ad375d66..30f4a14b1 100644 --- a/src/atlas/CMakeLists.txt +++ b/src/atlas/CMakeLists.txt @@ -655,6 +655,8 @@ interpolation/method/structured/QuasiCubic2D.cc interpolation/method/structured/QuasiCubic2D.h interpolation/method/structured/QuasiCubic3D.cc interpolation/method/structured/QuasiCubic3D.h +interpolation/method/structured/RegionalLinear2D.cc +interpolation/method/structured/RegionalLinear2D.h interpolation/method/structured/StructuredInterpolation2D.h interpolation/method/structured/StructuredInterpolation2D.tcc interpolation/method/structured/StructuredInterpolation3D.h diff --git a/src/atlas/functionspace/detail/StructuredColumns_setup.cc b/src/atlas/functionspace/detail/StructuredColumns_setup.cc index 4bdc0a4e9..094b47849 100644 --- a/src/atlas/functionspace/detail/StructuredColumns_setup.cc +++ b/src/atlas/functionspace/detail/StructuredColumns_setup.cc @@ -588,13 +588,20 @@ void StructuredColumns::setup(const grid::Distribution& distribution, const ecki atlas_omp_parallel_for(idx_t n = 0; n < gridpoints.size(); ++n) { const GridPoint& gp = gridpoints[n]; - if (gp.j >= 0 && gp.j < grid_->ny()) { - xy(gp.r, XX) = grid_->x(gp.i, gp.j); - xy(gp.r, YY) = grid_->y(gp.j); - } - else { - xy(gp.r, XX) = compute_x(gp.i, gp.j); - xy(gp.r, YY) = compute_y(gp.j); + if (regional) { + std::array lonlatVec; + grid_->lonlat(gp.i, gp.j, lonlatVec.data()); + xy(gp.r, XX) = lonlatVec[0]; + xy(gp.r, YY) = lonlatVec[1]; + } else { + if (gp.j >= 0 && gp.j < grid_->ny()) { + xy(gp.r, XX) = grid_->x(gp.i, gp.j); + xy(gp.r, YY) = grid_->y(gp.j); + } + else { + xy(gp.r, XX) = compute_x(gp.i, gp.j); + xy(gp.r, YY) = compute_y(gp.j); + } } bool in_domain(false); diff --git a/src/atlas/interpolation/method/MethodFactory.cc b/src/atlas/interpolation/method/MethodFactory.cc index a34e29961..f7a25c72f 100644 --- a/src/atlas/interpolation/method/MethodFactory.cc +++ b/src/atlas/interpolation/method/MethodFactory.cc @@ -24,6 +24,7 @@ #include "structured/Linear3D.h" #include "structured/QuasiCubic2D.h" #include "structured/QuasiCubic3D.h" +#include "structured/RegionalLinear2D.h" #include "unstructured/FiniteElement.h" #include "unstructured/UnstructuredBilinearLonLat.h" @@ -46,6 +47,7 @@ void force_link() { MethodBuilder(); MethodBuilder(); MethodBuilder(); + MethodBuilder(); MethodBuilder(); MethodBuilder(); MethodBuilder(); diff --git a/src/atlas/interpolation/method/structured/RegionalLinear2D.cc b/src/atlas/interpolation/method/structured/RegionalLinear2D.cc new file mode 100644 index 000000000..d294171b7 --- /dev/null +++ b/src/atlas/interpolation/method/structured/RegionalLinear2D.cc @@ -0,0 +1,568 @@ +/* + * (C) Copyright 2024 Meteorologisk Institutt + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include "atlas/interpolation/method/structured/RegionalLinear2D.h" + +#include + +#include "atlas/array.h" +#include "atlas/interpolation/method/MethodFactory.h" +#include "atlas/util/KDTree.h" +#include "atlas/util/Point.h" + +namespace atlas { +namespace interpolation { +namespace method { + +namespace { +MethodBuilder __builder("regional-linear-2d"); +} + +void RegionalLinear2D::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; } + +void RegionalLinear2D::do_setup(const Grid& source, const Grid& target, + const Cache&) { + ATLAS_NOTIMPLEMENTED; +} + +void RegionalLinear2D::do_setup(const FunctionSpace& source, + const FunctionSpace& target) { + ATLAS_TRACE("interpolation::method::RegionalLinear2D::do_setup"); + source_ = source; + target_ = target; + + if (target_.size() == 0) { + return; + } + ASSERT(source_.type() == "StructuredColumns"); + + // Get grid parameters + const functionspace::StructuredColumns sourceFs(source_); + const RegularGrid sourceGrid(sourceFs.grid()); + const Projection & sourceProj = sourceGrid.projection(); + const size_t sourceNx = sourceGrid.nx(); + const size_t sourceNy = sourceGrid.ny(); + const double sourceDx = sourceGrid.dx(); + const double sourceDy = std::abs(sourceGrid.y(1)-sourceGrid.y(0)); + const bool reversedY = sourceGrid.y(1) < sourceGrid.y(0); + + // Check grid regularity in y direction + for (size_t sourceJ = 0; sourceJ < sourceNy-1; ++sourceJ) { + if (reversedY) { + ASSERT(std::abs(sourceGrid.y(sourceJ)-sourceGrid.y(sourceJ+1)-sourceDy) < 1.0e-12*sourceDy); + } else { + ASSERT(std::abs(sourceGrid.y(sourceJ+1)-sourceGrid.y(sourceJ)-sourceDy) < 1.0e-12*sourceDy); + } + } + + // Source grid indices + const Field sourceFieldIndexI = sourceFs.index_i(); + const Field sourceFieldIndexJ = sourceFs.index_j(); + const auto sourceIndexIView = array::make_view(sourceFieldIndexI); + const auto sourceIndexJView = array::make_view(sourceFieldIndexJ); + sourceSize_ = sourceFs.size(); + + // Destination grid size + targetSize_ = target_.size(); + + // Ghost points + const auto sourceGhostView = array::make_view(sourceFs.ghost()); + const auto targetGhostView = array::make_view(target_.ghost()); + + // Define reduced grid horizontal distribution + std::vector mpiTask(sourceNx*sourceNy, 0); + for (size_t sourceJnode = 0; sourceJnode < sourceSize_; ++sourceJnode) { + if (sourceGhostView(sourceJnode) == 0) { + mpiTask[(sourceIndexIView(sourceJnode)-1)*sourceNy+sourceIndexJView(sourceJnode)-1] = comm_.rank(); + } + } + comm_.allReduceInPlace(mpiTask.begin(), mpiTask.end(), eckit::mpi::sum()); + + // Define local tree on destination grid + std::vector targetPoints; + std::vector targetIndices; + const auto targetLonLatView = array::make_view(target_.lonlat()); + for (size_t targetJnode = 0; targetJnode < targetSize_; ++targetJnode) { + PointLonLat p({targetLonLatView(targetJnode, 0), targetLonLatView(targetJnode, 1)}); + sourceProj.lonlat2xy(p); + targetPoints.push_back(Point3(p[0], p[1], 0.0)); + targetIndices.push_back(targetJnode); + } + util::IndexKDTree targetTree; + if (targetSize_ > 0) { + targetTree.build(targetPoints, targetIndices); + } + const double radius = std::sqrt(sourceDx*sourceDx+sourceDy*sourceDy); + + // Delta for colocation + const double eps = 1.0e-8; + + // RecvCounts and received points list + targetRecvCounts_.resize(comm_.size()); + std::fill(targetRecvCounts_.begin(), targetRecvCounts_.end(), 0); + std::vector targetRecvPointsList; + for (size_t sourceJ = 0; sourceJ < sourceNy; ++sourceJ) { + double yMin, yMax; + if (reversedY) { + yMin = sourceJ < sourceNy-1 ? sourceGrid.y(sourceJ+1)-eps : -std::numeric_limits::max(); + yMax = sourceJ > 0 ? sourceGrid.y(sourceJ-1)+eps : std::numeric_limits::max(); + } else { + yMin = sourceJ > 0 ? sourceGrid.y(sourceJ-1)-eps : -std::numeric_limits::max(); + yMax = sourceJ < sourceNy-1 ? sourceGrid.y(sourceJ+1)+eps : std::numeric_limits::max(); + } + for (size_t sourceI = 0; sourceI < sourceNx; ++sourceI) { + const double xMin = sourceI > 0 ? sourceGrid.x(sourceI-1)-eps : -std::numeric_limits::max(); + const double xMax = sourceI < sourceNx-1 ? sourceGrid.x(sourceI+1)+eps : + std::numeric_limits::max(); + + bool pointsNeeded = false; + if (targetSize_ > 0) { + const Point3 p(sourceGrid.x(sourceI), sourceGrid.y(sourceJ), 0.0); + const auto list = targetTree.closestPointsWithinRadius(p, radius); + for (const auto & item : list) { + const PointXYZ targetPoint = item.point(); + const size_t targetJnode = item.payload(); + if (targetGhostView(targetJnode) == 0) { + const bool inX = (xMin <= targetPoint[0] && targetPoint[0] <= xMax); + const bool inY = (yMin <= targetPoint[1] && targetPoint[1] <= yMax); + if (inX && inY) { + pointsNeeded = true; + break; + } + } + } + } + if (pointsNeeded) { + ++targetRecvCounts_[mpiTask[sourceI*sourceNy+sourceJ]]; + targetRecvPointsList.push_back(sourceI*sourceNy+sourceJ); + } + } + } + + // Buffer size + targetRecvSize_ = targetRecvPointsList.size(); + + if (targetRecvSize_ > 0) { + // RecvDispls + targetRecvDispls_.push_back(0); + for (size_t jt = 0; jt < comm_.size()-1; ++jt) { + targetRecvDispls_.push_back(targetRecvDispls_[jt]+targetRecvCounts_[jt]); + } + + // Allgather RecvCounts + eckit::mpi::Buffer targetRecvCountsBuffer(comm_.size()); + comm_.allGatherv(targetRecvCounts_.begin(), targetRecvCounts_.end(), targetRecvCountsBuffer); + std::vector targetRecvCountsGlb_ = std::move(targetRecvCountsBuffer.buffer); + + // SendCounts + for (size_t jt = 0; jt < comm_.size(); ++jt) { + sourceSendCounts_.push_back(targetRecvCountsGlb_[jt*comm_.size()+comm_.rank()]); + } + + // Buffer size + sourceSendSize_ = 0; + for (const auto & n : sourceSendCounts_) sourceSendSize_ += n; + + // SendDispls + sourceSendDispls_.push_back(0); + for (size_t jt = 0; jt < comm_.size()-1; ++jt) { + sourceSendDispls_.push_back(sourceSendDispls_[jt]+sourceSendCounts_[jt]); + } + + // Ordered received points list + std::vector targetRecvOffset(comm_.size(), 0); + std::vector targetRecvPointsListOrdered(targetRecvSize_); + for (size_t jr = 0; jr < targetRecvSize_; ++jr) { + const size_t sourceI = targetRecvPointsList[jr]/sourceNy; + const size_t sourceJ = targetRecvPointsList[jr]-sourceI*sourceNy; + size_t jt = mpiTask[sourceI*sourceNy+sourceJ]; + size_t jro = targetRecvDispls_[jt]+targetRecvOffset[jt]; + targetRecvPointsListOrdered[jro] = targetRecvPointsList[jr]; + ++targetRecvOffset[jt]; + } + std::vector sourceSentPointsList(sourceSendSize_); + comm_.allToAllv(targetRecvPointsListOrdered.data(), targetRecvCounts_.data(), targetRecvDispls_.data(), + sourceSentPointsList.data(), sourceSendCounts_.data(), sourceSendDispls_.data()); + + // Sort indices + std::vector gij; + for (size_t sourceJnode = 0; sourceJnode < sourceSize_; ++sourceJnode) { + if (sourceGhostView(sourceJnode) == 0) { + gij.push_back((sourceIndexIView(sourceJnode)-1)*sourceNy+sourceIndexJView(sourceJnode)-1); + } else { + gij.push_back(-1); + } + } + std::vector gidx(sourceSize_); + std::iota(gidx.begin(), gidx.end(), 0); + std::stable_sort(gidx.begin(), gidx.end(), [&gij](size_t i1, size_t i2) + {return gij[i1] < gij[i2];}); + std::vector ridx(sourceSendSize_); + std::iota(ridx.begin(), ridx.end(), 0); + std::stable_sort(ridx.begin(), ridx.end(), [&sourceSentPointsList](size_t i1, size_t i2) + {return sourceSentPointsList[i1] < sourceSentPointsList[i2];}); + + // Mapping for sent points + sourceSendMapping_.resize(sourceSendSize_); + size_t sourceJnode = 0; + for (size_t js = 0; js < sourceSendSize_; ++js) { + while (gij[gidx[sourceJnode]] < sourceSentPointsList[ridx[js]]) { + ++sourceJnode; + ASSERT(sourceJnode < sourceSize_); + } + sourceSendMapping_[ridx[js]] = gidx[sourceJnode]; + } + + // Sort indices + std::vector idx(targetRecvPointsListOrdered.size()); + std::iota(idx.begin(), idx.end(), 0); + std::stable_sort(idx.begin(), idx.end(), [&targetRecvPointsListOrdered](size_t i1, size_t i2) + {return targetRecvPointsListOrdered[i1] < targetRecvPointsListOrdered[i2];}); + + // Compute horizontal interpolation + stencil_.resize(targetSize_); + weights_.resize(targetSize_); + stencilSize_.resize(targetSize_); + for (size_t targetJnode = 0; targetJnode < targetSize_; ++targetJnode) { + // Interpolation element default values + if (targetGhostView(targetJnode) == 0) { + // Destination grid indices + const double targetX = targetPoints[targetJnode][0]; + bool colocatedX = false; + int indexI = -1; + for (size_t sourceI = 0; sourceI < sourceNx-1; ++sourceI) { + if (std::abs(targetX-sourceGrid.x(sourceI)) < eps) { + indexI = sourceI; + colocatedX = true; + } + if (sourceGrid.x(sourceI)+eps < targetX && targetX < sourceGrid.x(sourceI+1)-eps) { + indexI = sourceI; + colocatedX = false; + } + } + if (std::abs(targetX-sourceGrid.x(sourceNx-1)) < eps) { + indexI = sourceNx-1; + colocatedX = true; + } + const double targetY = targetPoints[targetJnode][1]; + bool colocatedY = false; + int indexJ = -1; + for (size_t sourceJ = 0; sourceJ < sourceNy-1; ++sourceJ) { + if (std::abs(targetY-sourceGrid.y(sourceJ)) < eps) { + indexJ = sourceJ; + colocatedY = true; + } + if (reversedY) { + if (sourceGrid.y(sourceJ+1)+eps < targetY && targetY < sourceGrid.y(sourceJ)-eps) { + indexJ = sourceJ; + colocatedY = false; + } + } else { + if (sourceGrid.y(sourceJ)+eps < targetY && targetY < sourceGrid.y(sourceJ+1)-eps) { + indexJ = sourceJ; + colocatedY = false; + } + } + } + if (std::abs(targetY-sourceGrid.y(sourceNy-1)) < eps) { + indexJ = sourceNy-1; + colocatedY = true; + } + + if (indexI == -1 || indexJ == -1) { + // Point outside of the domain, using nearest neighbor + if (indexI > -1) { + if (!colocatedX && + (std::abs(targetX-sourceGrid.x(indexI+1)) < std::abs(targetX-sourceGrid.x(indexI)))) { + indexI += 1; + } + } else { + if (std::abs(targetX-sourceGrid.x(0)) < std::abs(targetX-sourceGrid.x(sourceNx-1))) { + indexI = 0; + } else { + indexI = sourceNx-1; + } + } + if (indexJ > -1) { + if (!colocatedY && + (std::abs(targetY-sourceGrid.y(indexJ+1)) < std::abs(targetY-sourceGrid.y(indexJ)))) { + indexJ += 1; + } + } else { + if (std::abs(targetY-sourceGrid.y(0)) < std::abs(targetY-sourceGrid.y(sourceNy-1))) { + indexJ = 0; + } else { + indexJ = sourceNy-1; + } + Log::info() << "WARNING: point outside of the domain" << std::endl; + } + + // Colocated point (actually nearest neighbor) + colocatedX = true; + colocatedY = true; + } + + // Bilinear interpolation factor + const double alphaX = 1.0-(sourceGrid.x(indexI)+sourceDx-targetX)/sourceDx; + const double alphaY = reversedY ? (sourceGrid.y(indexJ)-targetY)/sourceDy + : 1.0-(sourceGrid.y(indexJ)+sourceDy-targetY)/sourceDy; + + // Points to find + std::vector toFind = {true, !colocatedX, !colocatedY, !colocatedX && !colocatedY}; + std::vector valueToFind = {indexI*sourceNy+indexJ, (indexI+1)*sourceNy+indexJ, + indexI*sourceNy+(indexJ+1), (indexI+1)*sourceNy+(indexJ+1)}; + std::array foundIndex; + foundIndex.fill(-1); + + // Binary search for each point + for (size_t jj = 0; jj < 4; ++jj) { + if (toFind[jj]) { + size_t low = 0; + size_t high = targetRecvPointsListOrdered.size()-1; + while (low <= high) { + size_t mid = low+(high-low)/2; + if (valueToFind[jj] == static_cast(targetRecvPointsListOrdered[idx[mid]])) { + foundIndex[jj] = idx[mid]; + break; + } + if (valueToFind[jj] > static_cast(targetRecvPointsListOrdered[idx[mid]])) { + low = mid+1; + } + if (valueToFind[jj] < static_cast(targetRecvPointsListOrdered[idx[mid]])) { + high = mid-1; + } + } + ASSERT(foundIndex[jj] > -1); + ASSERT(static_cast(targetRecvPointsListOrdered[foundIndex[jj]]) == + valueToFind[jj]); + } + } + + // Create interpolation operations + if (colocatedX && colocatedY) { + // Colocated point + stencil_[targetJnode][0] = foundIndex[0]; + weights_[targetJnode][0] = 1.0; + stencilSize_[targetJnode] = 1; + } else if (colocatedY) { + // Linear interpolation along x + stencil_[targetJnode][0] = foundIndex[0]; + weights_[targetJnode][0] = 1.0-alphaX; + stencil_[targetJnode][1] = foundIndex[1]; + weights_[targetJnode][1] = alphaX; + stencilSize_[targetJnode] = 2; + } else if (colocatedX) { + // Linear interpolation along y + stencil_[targetJnode][0] = foundIndex[0]; + weights_[targetJnode][0] = 1.0-alphaY; + stencil_[targetJnode][1] = foundIndex[2]; + weights_[targetJnode][1] = alphaY; + stencilSize_[targetJnode] = 2; + } else { + // Bilinear interpolation + stencil_[targetJnode][0] = foundIndex[0]; + weights_[targetJnode][0] = (1.0-alphaX)*(1.0-alphaY); + stencil_[targetJnode][1] = foundIndex[1]; + weights_[targetJnode][1] = alphaX*(1.0-alphaY); + stencil_[targetJnode][2] = foundIndex[2]; + weights_[targetJnode][2] = (1.0-alphaX)*alphaY; + stencil_[targetJnode][3] = foundIndex[3]; + weights_[targetJnode][3] = alphaX*alphaY; + stencilSize_[targetJnode] = 4; + } + } else { + // Ghost point + stencilSize_[targetJnode] = 0; + } + } + } +} + +void RegionalLinear2D::do_execute(const FieldSet& sourceFieldSet, + FieldSet& targetFieldSet, + Metadata& metadata) const { + ATLAS_TRACE("atlas::interpolation::method::RegionalLinear2D::do_execute()"); + ATLAS_ASSERT(sourceFieldSet.size() == targetFieldSet.size()); + + for (auto i = 0; i < sourceFieldSet.size(); ++i) { + do_execute(sourceFieldSet[i], targetFieldSet[i], metadata); + } +} + +void RegionalLinear2D::do_execute(const Field& sourceField, Field& targetField, + Metadata&) const { + ATLAS_TRACE("atlas::interpolation::method::RegionalLinear2D::do_execute()"); + + if (targetField.size() == 0) { + return; + } + + // Check number of levels + ASSERT(sourceField.levels() == targetField.levels()); + const size_t nz = sourceField.levels() > 0 ? sourceField.levels() : 1; + const size_t ndim = sourceField.levels() > 0 ? 2 : 1; + + // Scale counts and displs for all levels + std::vector sourceSendCounts3D(comm_.size()); + std::vector sourceSendDispls3D(comm_.size()); + std::vector targetRecvCounts3D(comm_.size()); + std::vector targetRecvDispls3D(comm_.size()); + for (size_t jt = 0; jt < comm_.size(); ++jt) { + sourceSendCounts3D[jt] = sourceSendCounts_[jt]*nz; + sourceSendDispls3D[jt] = sourceSendDispls_[jt]*nz; + targetRecvCounts3D[jt] = targetRecvCounts_[jt]*nz; + targetRecvDispls3D[jt] = targetRecvDispls_[jt]*nz; + } + + // Halo exchange + haloExchange(sourceField); + + // Serialize + std::vector sourceSendVec(sourceSendSize_*nz); + if (ndim == 1) { + const auto sourceView = array::make_view(sourceField); + for (size_t js = 0; js < sourceSendSize_; ++js) { + size_t sourceJnode = sourceSendMapping_[js]; + sourceSendVec[js] = sourceView(sourceJnode); + } + } else if (ndim == 2) { + const auto sourceView = array::make_view(sourceField); + for (size_t js = 0; js < sourceSendSize_; ++js) { + for (size_t k = 0; k < nz; ++k) { + size_t sourceJnode = sourceSendMapping_[js]; + sourceSendVec[js*nz+k] = sourceView(sourceJnode, k); + } + } + } + + // Communication + std::vector targetRecvVec(targetRecvSize_*nz); + comm_.allToAllv(sourceSendVec.data(), sourceSendCounts3D.data(), sourceSendDispls3D.data(), + targetRecvVec.data(), targetRecvCounts3D.data(), targetRecvDispls3D.data()); + + // Interpolation + if (ndim == 1) { + auto targetView = array::make_view(targetField); + targetView.assign(0.0); + for (size_t targetJnode = 0; targetJnode < targetSize_; ++targetJnode) { + for (size_t jj = 0; jj < stencilSize_[targetJnode]; ++jj) { + targetView(targetJnode) += weights_[targetJnode][jj] + *targetRecvVec[stencil_[targetJnode][jj]]; + } + } + } else if (ndim == 2) { + auto targetView = array::make_view(targetField); + targetView.assign(0.0); + for (size_t targetJnode = 0; targetJnode < targetSize_; ++targetJnode) { + for (size_t jj = 0; jj < stencilSize_[targetJnode]; ++jj) { + for (size_t k = 0; k < nz; ++k) { + targetView(targetJnode, k) += weights_[targetJnode][jj] + *targetRecvVec[stencil_[targetJnode][jj]*nz+k]; + } + } + } + } + + // Set target field dirty + targetField.set_dirty(); +} + +void RegionalLinear2D::do_execute_adjoint(FieldSet& sourceFieldSet, + const FieldSet& targetFieldSet, + Metadata& metadata) const { + ATLAS_TRACE( + "atlas::interpolation::method::RegionalLinear2D::do_execute_adjoint()"); + ATLAS_ASSERT(sourceFieldSet.size() == targetFieldSet.size()); + + for (auto i = 0; i < sourceFieldSet.size(); ++i) { + do_execute_adjoint(sourceFieldSet[i], targetFieldSet[i], metadata); + } +} + +void RegionalLinear2D::do_execute_adjoint(Field& sourceField, + const Field& targetField, + Metadata& metadata) const { + ATLAS_TRACE( + "atlas::interpolation::method::RegionalLinear2D::do_execute_adjoint()"); + + if (targetField.size() == 0) { + return; + } + + // Check number of levels + ASSERT(sourceField.levels() == targetField.levels()); + const size_t nz = sourceField.levels() > 0 ? sourceField.levels() : 1; + const size_t ndim = sourceField.levels() > 0 ? 2 : 1; + + // Scale counts and displs for all levels + std::vector sourceSendCounts3D(comm_.size()); + std::vector sourceSendDispls3D(comm_.size()); + std::vector targetRecvCounts3D(comm_.size()); + std::vector targetRecvDispls3D(comm_.size()); + for (size_t jt = 0; jt < comm_.size(); ++jt) { + sourceSendCounts3D[jt] = sourceSendCounts_[jt]*nz; + sourceSendDispls3D[jt] = sourceSendDispls_[jt]*nz; + targetRecvCounts3D[jt] = targetRecvCounts_[jt]*nz; + targetRecvDispls3D[jt] = targetRecvDispls_[jt]*nz; + } + + // Copy destination field + Field targetTmpField = targetField.clone(); + + // Interpolation adjoint + std::vector targetRecvVec(targetRecvSize_*nz, 0.0); + if (ndim == 1) { + const auto targetView = array::make_view(targetTmpField); + for (size_t targetJnode = 0; targetJnode < targetSize_; ++targetJnode) { + for (size_t jj = 0; jj < stencilSize_[targetJnode]; ++jj) { + targetRecvVec[stencil_[targetJnode][jj]] += weights_[targetJnode][jj] + *targetView(targetJnode); + } + } + } else if (ndim == 2) { + const auto targetView = array::make_view(targetTmpField); + for (size_t targetJnode = 0; targetJnode < targetSize_; ++targetJnode) { + for (size_t jj = 0; jj < stencilSize_[targetJnode]; ++jj) { + for (size_t k = 0; k < nz; ++k) { + targetRecvVec[stencil_[targetJnode][jj]*nz+k] += weights_[targetJnode][jj] + *targetView(targetJnode, k); + } + } + } + } + + // Communication + std::vector sourceSendVec(sourceSendSize_*nz); + comm_.allToAllv(targetRecvVec.data(), targetRecvCounts3D.data(), targetRecvDispls3D.data(), + sourceSendVec.data(), sourceSendCounts3D.data(), sourceSendDispls3D.data()); + + // Deserialize + if (ndim == 1) { + auto sourceView = array::make_view(sourceField); + sourceView.assign(0.0); + for (size_t js = 0; js < sourceSendSize_; ++js) { + size_t sourceJnode = sourceSendMapping_[js]; + sourceView(sourceJnode) += sourceSendVec[js]; + } + } else if (ndim == 2) { + auto sourceView = array::make_view(sourceField); + sourceView.assign(0.0); + for (size_t js = 0; js < sourceSendSize_; ++js) { + size_t sourceJnode = sourceSendMapping_[js]; + for (size_t k = 0; k < nz; ++k) { + sourceView(sourceJnode, k) += sourceSendVec[js*nz+k]; + } + } + } + + // Adjoint halo exchange + adjointHaloExchange(sourceField); +} + +} // namespace method +} // namespace interpolation +} // namespace atlas diff --git a/src/atlas/interpolation/method/structured/RegionalLinear2D.h b/src/atlas/interpolation/method/structured/RegionalLinear2D.h new file mode 100644 index 000000000..462f88662 --- /dev/null +++ b/src/atlas/interpolation/method/structured/RegionalLinear2D.h @@ -0,0 +1,77 @@ +/* + * (C) Copyright 2024 Meteorologisk Institutt + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#pragma once + +#include "atlas/interpolation/method/Method.h" + +#include + +#include "atlas/field.h" +#include "atlas/functionspace.h" +#include "atlas/grid/Grid.h" + +#include "eckit/config/Configuration.h" +#include "eckit/mpi/Comm.h" + +namespace atlas { +namespace interpolation { +namespace method { + + +class RegionalLinear2D : public Method { + public: + /// @brief Regional linear interpolation + /// + /// @details + /// + RegionalLinear2D(const Config& config) : Method(config), comm_(eckit::mpi::comm()) {} + ~RegionalLinear2D() override {} + + void print(std::ostream&) const override; + const FunctionSpace& source() const override { return source_; } + const FunctionSpace& target() const override { return target_; } + + void do_execute(const FieldSet& sourceFieldSet, FieldSet& targetFieldSet, + Metadata& metadata) const override; + void do_execute(const Field& sourceField, Field& targetField, + Metadata& metadata) const override; + + void do_execute_adjoint(FieldSet& sourceFieldSet, + const FieldSet& targetFieldSet, + Metadata& metadata) const override; + void do_execute_adjoint(Field& sourceField, const Field& targetField, + Metadata& metadata) const override; + + private: + using Method::do_setup; + void do_setup(const FunctionSpace& source, const FunctionSpace& target) override; + void do_setup(const Grid& source, const Grid& target, const Cache&) override; + + FunctionSpace source_{}; + FunctionSpace target_{}; + + const eckit::mpi::Comm & comm_; + size_t sourceSize_; + std::vector mpiTask_; + size_t targetSize_; + size_t sourceSendSize_; + size_t targetRecvSize_; + std::vector sourceSendCounts_; + std::vector sourceSendDispls_; + std::vector targetRecvCounts_; + std::vector targetRecvDispls_; + std::vector sourceSendMapping_; + std::vector> stencil_; + std::vector> weights_; + std::vector stencilSize_; +}; + + +} // namespace method +} // namespace interpolation +} // namespace atlas diff --git a/src/tests/interpolation/CMakeLists.txt b/src/tests/interpolation/CMakeLists.txt index 82005b229..a5ac0e86a 100644 --- a/src/tests/interpolation/CMakeLists.txt +++ b/src/tests/interpolation/CMakeLists.txt @@ -75,6 +75,14 @@ ecbuild_add_executable( TARGET atlas_test_interpolation_structured2D NOINSTALL ) +ecbuild_add_test( TARGET atlas_test_interpolation_structured2D_regional + SOURCES test_interpolation_structured2D_regional.cc + LIBS atlas + MPI 2 + CONDITION eckit_HAVE_MPI + ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} +) + ecbuild_add_test( TARGET atlas_test_interpolation_non_linear SOURCES test_interpolation_non_linear.cc LIBS atlas diff --git a/src/tests/interpolation/test_interpolation_k_nearest_neighbours.cc b/src/tests/interpolation/test_interpolation_k_nearest_neighbours.cc index 0ce3cd190..641005fb6 100644 --- a/src/tests/interpolation/test_interpolation_k_nearest_neighbours.cc +++ b/src/tests/interpolation/test_interpolation_k_nearest_neighbours.cc @@ -1,9 +1,5 @@ /* -<<<<<<< HEAD - * (C) Copyright 1996- ECMWF. -======= * (C) Copyright 2013 ECMWF. ->>>>>>> develop * * This software is licensed under the terms of the Apache Licence Version 2.0 * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. diff --git a/src/tests/interpolation/test_interpolation_structured2D_regional.cc b/src/tests/interpolation/test_interpolation_structured2D_regional.cc new file mode 100644 index 000000000..8bba8692e --- /dev/null +++ b/src/tests/interpolation/test_interpolation_structured2D_regional.cc @@ -0,0 +1,217 @@ +/* + * (C) Copyright 2024 Meteorologisk Institutt + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include "atlas/array.h" +#include "atlas/field.h" +#include "atlas/functionspace/StructuredColumns.h" +#include "atlas/grid.h" +#include "atlas/interpolation.h" +#include "atlas/option.h" +#include "atlas/util/Config.h" + +#include "tests/AtlasTestEnvironment.h" + +using atlas::functionspace::StructuredColumns; +using atlas::util::Config; + +namespace atlas { +namespace test { + +//----------------------------------------------------------------------------- + +static Config gridConfigs() { + Config gridConfigs; + + Config projectionConfig; + projectionConfig.set("type", "lambert_conformal_conic"); + projectionConfig.set("latitude0", 56.3); + projectionConfig.set("longitude0", 0.0); + + Config gridConfig; + gridConfig.set("type", "regional"); + std::vector lonlat = {9.9, 56.3}; + gridConfig.set("lonlat(centre)", lonlat); + gridConfig.set("projection", projectionConfig); + + const size_t sourceNx = 21; + const size_t sourceNy = 31; + const double sourceDx = 8.0e3; + const double sourceDy = 9.0e3; + + const size_t xFactor = 4; + const size_t yFactor = 3; + + const size_t targetNx = (sourceNx-1)*xFactor+1; + const size_t targetNy = (sourceNy-1)*yFactor+1; + const double targetDx = static_cast(sourceNx-1)/static_cast(targetNx-1)*sourceDx; + const double targetDy = static_cast(sourceNy-1)/static_cast(targetNy-1)*sourceDy; + + gridConfig.set("nx", sourceNx); + gridConfig.set("ny", sourceNy); + gridConfig.set("dx", sourceDx); + gridConfig.set("dy", sourceDy); + gridConfigs.set("source", gridConfig); + gridConfigs.set("source normalization", (sourceNx-1)*(sourceNy-1)); + + gridConfig.set("nx", targetNx); + gridConfig.set("ny", targetNy); + gridConfig.set("dx", targetDx); + gridConfig.set("dy", targetDy); + gridConfigs.set("target", gridConfig); + gridConfigs.set("target normalization", (targetNx-1)*(targetNy-1)); + + return gridConfigs; +} + +// Dot product +double dotProd(const Field& field01, const Field& field02) { + double dprod{}; + + const size_t ndim = field01.levels() > 0 ? 2 : 1; + if (ndim == 1) { + const auto field01_view = array::make_view(field01); + const auto field02_view = array::make_view(field02); + for (idx_t i=0; i < field01_view.shape(0); ++i) { + dprod += field01_view(i) * field02_view(i); + } + } else if (ndim == 2) { + const auto field01_view = array::make_view(field01); + const auto field02_view = array::make_view(field02); + for (idx_t l=0; l < field01_view.shape(1); ++l) { + for (idx_t i=0; i < field01_view.shape(0); ++i) { + dprod += field01_view(i, l) * field02_view(i, l); + } + } + } + eckit::mpi::comm().allReduceInPlace(dprod, eckit::mpi::Operation::SUM); + + return dprod; +} + + +CASE("test_interpolation_structured2D_regional_1d") { + Grid sourceGrid(gridConfigs().getSubConfiguration("source")); + Grid targetGrid(gridConfigs().getSubConfiguration("target")); + + StructuredColumns sourceFs(sourceGrid, option::halo(1)); + StructuredColumns targetFs(targetGrid, option::halo(1)); + + Interpolation interpolation(Config("type", "regional-linear-2d"), sourceFs, targetFs); + + auto sourceField = sourceFs.createField(Config("name", "source")); + auto targetField = targetFs.createField(Config("name", "target")); + + // Accuracy test + const auto sourceIView = array::make_view(sourceFs.index_i()); + const auto sourceJView = array::make_view(sourceFs.index_j()); + auto sourceView = array::make_view(sourceField); + const auto sourceGhostView = atlas::array::make_view(sourceFs.ghost()); + sourceView.assign(0.0); + for (idx_t i = 0; i < sourceFs.size(); ++i) { + if (sourceGhostView(i) == 0) { + sourceView(i) = static_cast((sourceIView(i)-1)*(sourceJView(i)-1)) + /static_cast(gridConfigs().getInt("source normalization")); + } + } + + interpolation.execute(sourceField, targetField); + + const auto targetIView = array::make_view(targetFs.index_i()); + const auto targetJView = array::make_view(targetFs.index_j()); + const auto targetView = array::make_view(targetField); + const auto targetGhostView = atlas::array::make_view(targetFs.ghost()); + const double tolerance = 1.e-12; + for (idx_t i = 0; i < targetFs.size(); ++i) { + if (targetGhostView(i) == 0) { + const double targetTest = static_cast((targetIView(i)-1)*(targetJView(i)-1)) + /static_cast(gridConfigs().getInt("target normalization")); + EXPECT_APPROX_EQ(targetView(i), targetTest, tolerance); + } + } + + // Adjoint test + auto targetAdjoint = targetFs.createField(); + array::make_view(targetAdjoint).assign(array::make_view(targetField)); + targetAdjoint.adjointHaloExchange(); + + auto sourceAdjoint = sourceFs.createField(); + array::make_view(sourceAdjoint).assign(0.); + interpolation.execute_adjoint(sourceAdjoint, targetAdjoint); + + const auto yDotY = dotProd(targetField, targetField); + const auto xDotXAdj = dotProd(sourceField, sourceAdjoint); + + EXPECT_APPROX_EQ(yDotY / xDotXAdj, 1., 1e-14); +} + + +CASE("test_interpolation_structured2D_regional_2d") { + Grid sourceGrid(gridConfigs().getSubConfiguration("source")); + Grid targetGrid(gridConfigs().getSubConfiguration("target")); + + const idx_t nlevs = 2; + StructuredColumns sourceFs(sourceGrid, option::halo(1) | option::levels(nlevs)); + StructuredColumns targetFs(targetGrid, option::halo(1) | option::levels(nlevs)); + + Interpolation interpolation(Config("type", "regional-linear-2d"), sourceFs, targetFs); + + auto sourceField = sourceFs.createField(Config("name", "source")); + auto targetField = targetFs.createField(Config("name", "target")); + + // Accuracy test + const auto sourceIView = array::make_view(sourceFs.index_i()); + const auto sourceJView = array::make_view(sourceFs.index_j()); + auto sourceView = array::make_view(sourceField); + const auto sourceGhostView = atlas::array::make_view(sourceFs.ghost()); + sourceView.assign(0.0); + for (idx_t i = 0; i < sourceFs.size(); ++i) { + if (sourceGhostView(i) == 0) { + for (idx_t k = 0; k < nlevs; ++k) { + sourceView(i, k) = static_cast((sourceIView(i)-1)*(sourceJView(i)-1)) + /static_cast(gridConfigs().getInt("source normalization")); + } + } + } + + interpolation.execute(sourceField, targetField); + + const auto targetIView = array::make_view(targetFs.index_i()); + const auto targetJView = array::make_view(targetFs.index_j()); + const auto targetView = array::make_view(targetField); + const auto targetGhostView = atlas::array::make_view(targetFs.ghost()); + const double tolerance = 1.e-12; + for (idx_t i = 0; i < targetFs.size(); ++i) { + if (targetGhostView(i) == 0) { + const double targetTest = static_cast((targetIView(i)-1)*(targetJView(i)-1)) + /static_cast(gridConfigs().getInt("target normalization")); + for (idx_t k = 0; k < nlevs; ++k) { + EXPECT_APPROX_EQ(targetView(i, k), targetTest, tolerance); + } + } + } + + // Adjoint test + auto targetAdjoint = targetFs.createField(); + array::make_view(targetAdjoint).assign(array::make_view(targetField)); + targetAdjoint.adjointHaloExchange(); + + auto sourceAdjoint = sourceFs.createField(); + array::make_view(sourceAdjoint).assign(0.); + interpolation.execute_adjoint(sourceAdjoint, targetAdjoint); + + const auto yDotY = dotProd(targetField, targetField); + const auto xDotXAdj = dotProd(sourceField, sourceAdjoint); + + EXPECT_APPROX_EQ(yDotY / xDotXAdj, 1., 1e-14); +} + +} // namespace test +} // namespace atlas + +int main(int argc, char** argv) { + return atlas::test::run(argc, argv); +} From e2974add35bb906a0c4ecf48a7ddaa32cd2a5241 Mon Sep 17 00:00:00 2001 From: Oliver Lomax Date: Mon, 16 Sep 2024 12:59:48 +0100 Subject: [PATCH 2/2] Pack vector components into higher-rank vector fields. (#218) --- src/atlas/CMakeLists.txt | 2 + src/atlas/util/PackVectorFields.cc | 229 +++++++++++++++++++++ src/atlas/util/PackVectorFields.h | 40 ++++ src/tests/util/CMakeLists.txt | 9 + src/tests/util/test_pack_vector_fields.cc | 235 ++++++++++++++++++++++ 5 files changed, 515 insertions(+) create mode 100644 src/atlas/util/PackVectorFields.cc create mode 100644 src/atlas/util/PackVectorFields.h create mode 100644 src/tests/util/test_pack_vector_fields.cc diff --git a/src/atlas/CMakeLists.txt b/src/atlas/CMakeLists.txt index 30f4a14b1..c87a26971 100644 --- a/src/atlas/CMakeLists.txt +++ b/src/atlas/CMakeLists.txt @@ -826,6 +826,8 @@ util/PolygonXY.cc util/PolygonXY.h util/Metadata.cc util/Metadata.h +util/PackVectorFields.cc +util/PackVectorFields.h util/Point.cc util/Point.h util/Polygon.cc diff --git a/src/atlas/util/PackVectorFields.cc b/src/atlas/util/PackVectorFields.cc new file mode 100644 index 000000000..640a1d46b --- /dev/null +++ b/src/atlas/util/PackVectorFields.cc @@ -0,0 +1,229 @@ +/* + * (C) Crown Copyright 2024 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include "atlas/util/PackVectorFields.h" + +#include +#include +#include +#include + +#include "atlas/array.h" +#include "atlas/array/helpers/ArrayForEach.h" +#include "atlas/functionspace.h" +#include "atlas/option.h" +#include "atlas/runtime/Exception.h" +#include "atlas/util/Config.h" +#include "eckit/config/LocalConfiguration.h" + +namespace atlas { +namespace util { + +namespace { + +using eckit::LocalConfiguration; + +using array::DataType; +using array::helpers::arrayForEachDim; + +void addOrReplaceField(FieldSet& fieldSet, const Field& field) { + const auto fieldName = field.name(); + if (fieldSet.has(fieldName)) { + fieldSet[fieldName] = field; + } else { + fieldSet.add(field); + } +} + +Field& getOrCreateField(FieldSet& fieldSet, const FunctionSpace& functionSpace, + const Config& config) { + const auto fieldName = config.getString("name"); + if (!fieldSet.has(fieldName)) { + fieldSet.add(functionSpace.createField(config)); + } + return fieldSet[fieldName]; +} + +void checkFieldCompatibility(const Field& componentField, + const Field& vectorField) { + ATLAS_ASSERT(componentField.functionspace().size() == + vectorField.functionspace().size()); + ATLAS_ASSERT(componentField.levels() == vectorField.levels()); + ATLAS_ASSERT(componentField.variables() == 0); + ATLAS_ASSERT(vectorField.variables() > 0); + + const auto checkStandardShape = [](const Field& field) { + // Check for "standard" Atlas field shape. + auto dim = 0; + const auto rank = field.rank(); + const auto shape = field.shape(); + if (field.functionspace().size() != shape[dim++]) { + return false; + } + if (const auto levels = field.levels(); + levels && (dim >= rank || levels != shape[dim++])) { + return false; + } + if (const auto variables = field.variables(); + variables && (dim >= rank || variables != shape[dim++])) { + return false; + } + if (dim != rank) { + return false; + } + return true; + }; + + ATLAS_ASSERT(checkStandardShape(componentField)); + ATLAS_ASSERT(checkStandardShape(vectorField)); +} + +template +void copyFieldData(ComponentField& componentField, VectorField& vectorField, + const Functor& copier) { + checkFieldCompatibility(componentField, vectorField); + + const auto copyArrayData = [&](auto value, auto rank) { + // Resolve value-type and rank from arguments. + using Value = decltype(value); + constexpr auto Rank = decltype(rank)::value; + + // Iterate over fields. + auto vectorView = array::make_view(vectorField); + auto componentView = array::make_view(componentField); + constexpr auto Dims = std::make_integer_sequence{}; + arrayForEachDim(Dims, execution::par, std::tie(componentView, vectorView), + copier); + }; + + const auto selectRank = [&](auto value) { + switch (vectorField.rank()) { + case 2: + return copyArrayData(value, std::integral_constant{}); + case 3: + return copyArrayData(value, std::integral_constant{}); + default: + ATLAS_THROW_EXCEPTION("Unsupported vector field rank: " + + std::to_string(vectorField.rank())); + } + }; + + const auto selectType = [&]() { + switch (vectorField.datatype().kind()) { + case DataType::kind(): + return selectRank(double{}); + case DataType::kind(): + return selectRank(float{}); + case DataType::kind(): + return selectRank(long{}); + case DataType::kind(): + return selectRank(int{}); + default: + ATLAS_THROW_EXCEPTION("Unknown datatype: " + + std::to_string(vectorField.datatype().kind())); + } + }; + + selectType(); +} + +} // namespace + +FieldSet pack_vector_fields(const FieldSet& fields, FieldSet packedFields) { + // Get the number of variables for each vector field. + auto vectorSizeMap = std::map{}; + for (const auto& field : fields) { + auto vectorFieldName = std::string{}; + if (field.metadata().get("vector_field_name", vectorFieldName)) { + ++vectorSizeMap[vectorFieldName]; + } + } + auto vectorIndexMap = std::map{}; + + // Pack vector fields. + for (const auto& field : fields) { + auto vectorFieldName = std::string{}; + if (!field.metadata().get("vector_field_name", vectorFieldName)) { + // Not a vector component field. + addOrReplaceField(packedFields, field); + continue; + } + + // Field is vector field component. + const auto& componentField = field; + + // Get or create vector field. + const auto vectorFieldConfig = + option::name(vectorFieldName) | + option::levels(componentField.levels()) | + option::vector(vectorSizeMap[vectorFieldName]) | + option::datatype(componentField.datatype()); + auto& vectorField = getOrCreateField( + packedFields, componentField.functionspace(), vectorFieldConfig); + + // Copy field data. + const auto vectorIndex = vectorIndexMap[vectorFieldName]++; + const auto copier = [&](auto&& componentElem, auto&& vectorElem) { + vectorElem(vectorIndex) = componentElem; + }; + copyFieldData(componentField, vectorField, copier); + + // Copy metadata. + const auto componentFieldMetadata = componentField.metadata(); + auto componentFieldMetadataVector = std::vector{}; + vectorField.metadata().get("component_field_metadata", + componentFieldMetadataVector); + componentFieldMetadataVector.push_back(componentFieldMetadata); + vectorField.metadata().set("component_field_metadata", + componentFieldMetadataVector); + } + return packedFields; +} + +FieldSet unpack_vector_fields(const FieldSet& fields, FieldSet unpackedFields) { + for (const auto& field : fields) { + auto componentFieldMetadataVector = std::vector{}; + if (!field.metadata().get("component_field_metadata", + componentFieldMetadataVector)) { + // Not a vector field. + addOrReplaceField(unpackedFields, field); + continue; + } + + // Field is vector. + const auto& vectorField = field; + + auto vectorIndex = 0; + for (const auto& componentFieldMetadata : componentFieldMetadataVector) { + + // Get or create field. + auto componentFieldName = std::string{}; + componentFieldMetadata.get("name", componentFieldName); + const auto componentFieldConfig = + option::name(componentFieldName) | + option::levels(vectorField.levels()) | + option::datatype(vectorField.datatype()); + auto& componentField = getOrCreateField( + unpackedFields, vectorField.functionspace(), componentFieldConfig); + + // Copy field data. + const auto copier = [&](auto&& componentElem, auto&& vectorElem) { + componentElem = vectorElem(vectorIndex); + }; + copyFieldData(componentField, vectorField, copier); + + // Copy metadata. + componentField.metadata() = componentFieldMetadata; + + ++vectorIndex; + } + } + return unpackedFields; +} + +} // namespace util +} // namespace atlas diff --git a/src/atlas/util/PackVectorFields.h b/src/atlas/util/PackVectorFields.h new file mode 100644 index 000000000..fc77da6c5 --- /dev/null +++ b/src/atlas/util/PackVectorFields.h @@ -0,0 +1,40 @@ +/* + * (C) Crown Copyright 2024 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#pragma once + +#include "atlas/field.h" + +namespace eckit { +class LocalConfiguration; +} + +namespace atlas { +namespace util { + +/// @brief Packs vector field components into vector fields +/// +/// @details Iterates through @param fields and creates vector fields from any +/// component field with the "vector name" string metadata. These, as +/// well as any present scalar fields, are added to the return-value +/// field set. +/// Note, a mutable @param packedFields field set can be supplied if +/// one needs to guarantee the order of the packed fields +FieldSet pack_vector_fields(const FieldSet& fields, + FieldSet packedFields = FieldSet{}); + +/// @brief Unpacks vector field into vector field components. +/// +/// @details Undoes "pack" operation when a set of packed fields are supplied +/// as @param fields. A mutable @param unpackedFields field set can be +/// supplied if one needs to guarantee the order of the unpacked +/// fields. +FieldSet unpack_vector_fields(const FieldSet& fields, + FieldSet unpackedFields = FieldSet{}); + +} // namespace util +} // namespace atlas diff --git a/src/tests/util/CMakeLists.txt b/src/tests/util/CMakeLists.txt index 47d8b6def..7c7262ccf 100644 --- a/src/tests/util/CMakeLists.txt +++ b/src/tests/util/CMakeLists.txt @@ -94,3 +94,12 @@ ecbuild_add_test( TARGET atlas_test_unitsphere ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ) +ecbuild_add_test( TARGET atlas_test_pack_vector_fields + SOURCES test_pack_vector_fields.cc + LIBS atlas + ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} +) + + + + diff --git a/src/tests/util/test_pack_vector_fields.cc b/src/tests/util/test_pack_vector_fields.cc new file mode 100644 index 000000000..3c6ea89df --- /dev/null +++ b/src/tests/util/test_pack_vector_fields.cc @@ -0,0 +1,235 @@ +/* + * (C) Crown Copyright 2024 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include +#include + +#include "atlas/array.h" +#include "atlas/field.h" +#include "atlas/functionspace.h" +#include "atlas/grid.h" +#include "atlas/option.h" +#include "atlas/util/Config.h" +#include "atlas/util/PackVectorFields.h" +#include "tests/AtlasTestEnvironment.h" + +namespace atlas { +namespace test { + +FieldSet setFields(const FunctionSpace& functionSpace, + const std::vector& fieldConfigs) { + auto fields = FieldSet{}; + + // Set unique values to all field elements. + auto value = 0; + for (const auto& fieldConfig : fieldConfigs) { + auto field = fields.add(functionSpace.createField(fieldConfig)); + for (auto arrayIdx = size_t{0}; arrayIdx < field.size(); arrayIdx++) { + field->data()[arrayIdx] = value++; + } + field.metadata().set("comment", "This field is made with love."); + auto vectorFieldName = std::string{}; + if (fieldConfig.get("vector_field_name", vectorFieldName)) { + field.metadata().set("vector_field_name", vectorFieldName); + } + } + return fields; +} + +FieldSet createOrderedTestFields() { + const auto grid = Grid("O16"); + const auto functionSpace = functionspace::StructuredColumns(grid); + + // Note: vector components 0 and 1 are contiguous in field set. + auto fieldConfigs = std::vector{}; + fieldConfigs.push_back(option::name("scalar") | option::levels(1) | + option::datatype(DataType::kind())); + fieldConfigs.push_back(option::name("vector_component_0") | + option::levels(1) | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + fieldConfigs.push_back(option::name("vector_component_1") | + option::levels(1) | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + + return setFields(functionSpace, fieldConfigs); +} + +FieldSet createUnorderedTestFields() { + auto fields = FieldSet{}; + + const auto grid = Grid("O16"); + const auto functionSpace = functionspace::StructuredColumns(grid); + + // Note: vector components 0 and 1 are not contiguous in field set. + auto fieldConfigs = std::vector{}; + fieldConfigs.push_back(option::name("vector_component_0") | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + fieldConfigs.push_back(option::name("scalar") | + option::datatype(DataType::kind())); + fieldConfigs.push_back(option::name("vector_component_1") | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + + return setFields(functionSpace, fieldConfigs); +} + +FieldSet createInconsistentRankFields() { + auto fields = FieldSet{}; + + const auto grid = Grid("O16"); + const auto functionSpace = functionspace::StructuredColumns(grid); + + // Note: vector components 0 and 1 have different ranks. + auto fieldConfigs = std::vector{}; + fieldConfigs.push_back(option::name("vector_component_0") | + option::levels(10) | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + fieldConfigs.push_back(option::name("scalar") | + option::datatype(DataType::kind())); + fieldConfigs.push_back(option::name("vector_component_1") | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + + return setFields(functionSpace, fieldConfigs); +} + +FieldSet createInconsistentDatatypeFields() { + auto fields = FieldSet{}; + + const auto grid = Grid("O16"); + const auto functionSpace = functionspace::StructuredColumns(grid); + + // Note: vector components 0 and 1 have different datatypes. + auto fieldConfigs = std::vector{}; + fieldConfigs.push_back(option::name("vector_component_0") | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + fieldConfigs.push_back(option::name("scalar") | + option::datatype(DataType::kind())); + fieldConfigs.push_back(option::name("vector_component_1") | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + + return setFields(functionSpace, fieldConfigs); +} + +FieldSet createInconsistentLevelsFields() { + auto fields = FieldSet{}; + + const auto grid = Grid("O16"); + const auto functionSpace = functionspace::StructuredColumns(grid); + + // Note: vector components 0 and 1 have different number of levels. + auto fieldConfigs = std::vector{}; + fieldConfigs.push_back(option::name("vector_component_0") | + option::levels(10) | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + fieldConfigs.push_back(option::name("scalar") | + option::datatype(DataType::kind())); + fieldConfigs.push_back(option::name("vector_component_1") | + option::levels(20) | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + + return setFields(functionSpace, fieldConfigs); +} + +FieldSet createInconsistentVariablesFields() { + auto fields = FieldSet{}; + + const auto grid = Grid("O16"); + const auto functionSpace = functionspace::StructuredColumns(grid); + + // Note: vector components 0 and 1 have different number of variables. + auto fieldConfigs = std::vector{}; + fieldConfigs.push_back(option::name("vector_component_0") | + option::datatype(DataType::kind()) | + option::variables(2) | + util::Config{"vector_field_name", "vector"}); + fieldConfigs.push_back(option::name("scalar") | + option::datatype(DataType::kind())); + fieldConfigs.push_back(option::name("vector_component_1") | + option::datatype(DataType::kind()) | + util::Config{"vector_field_name", "vector"}); + + return setFields(functionSpace, fieldConfigs); +} + +void checkTestFields(const FieldSet& fields) { + auto value = 0; + for (const auto& field : fields) { + for (auto arrayIdx = size_t{0}; arrayIdx < field.size(); arrayIdx++) { + EXPECT(field->data()[arrayIdx] == value++); + } + EXPECT(field.metadata().get("comment") == + "This field is made with love."); + } +} + +CASE("Basic pack and unpack") { + const auto fields = createOrderedTestFields(); + + const auto packedFields = util::pack_vector_fields(fields); + + EXPECT(!packedFields.has("vector_component_0")); + EXPECT(!packedFields.has("vector_component_1")); + EXPECT(packedFields.has("vector")); + EXPECT(packedFields.has("scalar")); + + const auto unpackedFields = util::unpack_vector_fields(packedFields); + + EXPECT(unpackedFields.has("vector_component_0")); + EXPECT(unpackedFields.has("vector_component_1")); + EXPECT(!unpackedFields.has("vector")); + EXPECT(unpackedFields.has("scalar")); + + checkTestFields(unpackedFields); +} + +CASE("unpack into existing field set") { + auto fields = createUnorderedTestFields(); + + const auto packedFields = util::pack_vector_fields(fields); + + EXPECT(!packedFields.has("vector_component_0")); + EXPECT(!packedFields.has("vector_component_1")); + EXPECT(packedFields.has("vector")); + EXPECT(packedFields.has("scalar")); + + // Need to unpack into existing field to guarantee field order is preserved. + array::make_view(fields["vector_component_0"]).assign(0.); + array::make_view(fields["vector_component_1"]).assign(0.); + util::unpack_vector_fields(packedFields, fields); + + EXPECT(fields.has("vector_component_0")); + EXPECT(fields.has("vector_component_1")); + EXPECT(!fields.has("vector")); + EXPECT(fields.has("scalar")); + + checkTestFields(fields); +} + +CASE("check that bad inputs throw") { + // Try to apply pack to inconsistent field sets. + EXPECT_THROWS(util::pack_vector_fields(createInconsistentRankFields())); + EXPECT_THROWS( + util::pack_vector_fields(createInconsistentDatatypeFields())); + EXPECT_THROWS( + util::pack_vector_fields(createInconsistentLevelsFields())); + EXPECT_THROWS( + util::pack_vector_fields(createInconsistentVariablesFields())); +} + +} // namespace test +} // namespace atlas + +int main(int argc, char** argv) { return atlas::test::run(argc, argv); }