Skip to content

Commit

Permalink
Add function to get neighboring indices at target.
Browse files Browse the repository at this point in the history
  • Loading branch information
nealkruis committed Mar 5, 2024
1 parent 525f3b0 commit 66d5f8d
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/btwxt/regular-grid-interpolator.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ class RegularGridInterpolator {

std::vector<double> operator()() { return get_values_at_target(); }

[[nodiscard]] std::vector<std::size_t> get_neighboring_indices_at_target() const;

std::vector<std::size_t> get_neighboring_indices_at_target(const std::vector<double>& target);

const std::vector<double>& get_target();

[[nodiscard]] const std::vector<TargetBoundsStatus>& get_target_bounds_status() const;
Expand Down
33 changes: 33 additions & 0 deletions src/regular-grid-interpolator-implementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,39 @@ double RegularGridInterpolatorImplementation::get_grid_point_weighting_factor(
return weighting_factor;
}

std::vector<std::size_t> RegularGridInterpolatorImplementation::get_neighboring_indices_at_target()
{
if (!target_is_set) {
send_error("Cannot retrieve neighboring indices. No target has been set.");
}
std::vector<std::vector<std::size_t>> axes_neighbor_indices(
number_of_grid_axes,
std::vector<std::size_t>()); // For each axis, what are the neighboring indices?
for (std::size_t axis_index = 0; axis_index < number_of_grid_axes; ++axis_index) {
auto floor_index = floor_grid_point_coordinates[axis_index];
if (floor_to_ceiling_fractions[axis_index] < 1.0) {
axes_neighbor_indices[axis_index].push_back(floor_index);
}
if (grid_axis_lengths[axis_index] > 1 && floor_to_ceiling_fractions[axis_index] > 0.0) {
axes_neighbor_indices[axis_index].push_back(floor_index + 1);
}
}
std::vector<std::vector<std::size_t>> axes_neighbor_coordinates =
cartesian_product(axes_neighbor_indices);
std::vector<std::size_t> neighbor_indices;
neighbor_indices.reserve(axes_neighbor_coordinates.size());
for (const auto& coordinates : axes_neighbor_coordinates) {
neighbor_indices.push_back(get_grid_point_index(coordinates));
}
return neighbor_indices;
}

std::vector<std::size_t> RegularGridInterpolatorImplementation::get_neighboring_indices_at_target(
const std::vector<double>& target_in)
{
set_target(target_in);
return get_neighboring_indices_at_target();
}
// private methods

void RegularGridInterpolatorImplementation::setup()
Expand Down
4 changes: 4 additions & 0 deletions src/regular-grid-interpolator-implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ class RegularGridInterpolatorImplementation : public Courier::Sender {
return floor_grid_point_coordinates;
};

std::vector<std::size_t> get_neighboring_indices_at_target();

std::vector<std::size_t> get_neighboring_indices_at_target(const std::vector<double>& target);

[[nodiscard]] inline const std::vector<std::vector<double>>&
get_interpolation_coefficients() const
{
Expand Down
11 changes: 11 additions & 0 deletions src/regular-grid-interpolator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,17 @@ std::vector<double> RegularGridInterpolator::get_values_at_target()
return implementation->get_results();
}

std::vector<std::size_t> RegularGridInterpolator::get_neighboring_indices_at_target() const
{
return implementation->get_neighboring_indices_at_target();
}

std::vector<std::size_t>
RegularGridInterpolator::get_neighboring_indices_at_target(const std::vector<double>& target_in)
{
return implementation->get_neighboring_indices_at_target(target_in);
}

const std::vector<double>& RegularGridInterpolator::get_target()
{
return implementation->get_target();
Expand Down
66 changes: 66 additions & 0 deletions test/btwxt-tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,72 @@ TEST_F(GridFixture, two_point_cubic_1d_interpolate)
EXPECT_NEAR(result, 5.25, 0.0001);
}

TEST_F(GridFixture, get_neighboring_indices)
{
grid = {{0, 1, 2}, {0, 1, 2}};
// clang-format off
data_sets = {{
// 0 1 2 < dim 2
0, 1, 2, // 0 dim 1
3, 4, 5, // 1 "
6, 7, 8 // 2 "
}};
// clang-format on
setup();

// Outside grid points
EXPECT_THAT(interpolator.get_neighboring_indices_at_target({-1, -1}), testing::ElementsAre(0));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({-1, 0.5}),
testing::ElementsAre(0, 1));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({-1, 3}), testing::ElementsAre(2));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({3, 3}), testing::ElementsAre(8));

// On outside boundaries
EXPECT_THAT(interpolator.get_neighboring_indices_at_target({0, 0.5}),
testing::ElementsAre(0, 1));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({0.5, 0}),
testing::ElementsAre(0, 3));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({2, 1.5}),
testing::ElementsAre(7, 8));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({0.5, 2}),
testing::ElementsAre(2, 5));

// On inside boundaries
EXPECT_THAT(interpolator.get_neighboring_indices_at_target({1, 0.5}),
testing::ElementsAre(3, 4));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({0.5, 1}),
testing::ElementsAre(1, 4));

// Inside cells
EXPECT_THAT(interpolator.get_neighboring_indices_at_target({0.5, 0.5}),
testing::ElementsAre(0, 1, 3, 4));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({0.5, 1.5}),
testing::ElementsAre(1, 2, 4, 5));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({1.5, 0.5}),
testing::ElementsAre(3, 4, 6, 7));

EXPECT_THAT(interpolator.get_neighboring_indices_at_target({1.5, 1.5}),
testing::ElementsAre(4, 5, 7, 8));

// On grid points
for (auto g0 : grid[0]) {
for (auto g1 : grid[1]) {
interpolator.set_target({g0, g1});
EXPECT_THAT(interpolator.get_neighboring_indices_at_target(),
testing::ElementsAre(interpolator.get_values_at_target()[0]));
}
}
}

TEST_F(Grid2DFixture, target_undefined)
{
std::vector<double> returned_target;
Expand Down

0 comments on commit 66d5f8d

Please sign in to comment.