Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

revert: "fix(interpolation): fix spline bug" #1648

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,35 @@
#ifndef INTERPOLATION__SPLINE_INTERPOLATION_HPP_
#define INTERPOLATION__SPLINE_INTERPOLATION_HPP_

#include <Eigen/Dense>
#include "interpolation/interpolation_utils.hpp"
#include "tier4_autoware_utils/geometry/geometry.hpp"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

namespace interpolation
{
// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1)
struct MultiSplineCoef
{
MultiSplineCoef() = default;

explicit MultiSplineCoef(const size_t num_spline)
{
a.resize(num_spline);
b.resize(num_spline);
c.resize(num_spline);
d.resize(num_spline);
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

// static spline interpolation functions
std::vector<double> slerp(
Expand Down Expand Up @@ -62,14 +84,8 @@ class SplineInterpolation
std::vector<double> getSplineInterpolatedDiffValues(const std::vector<double> & query_keys) const;

private:
Eigen::VectorXd a_;
Eigen::VectorXd b_;
Eigen::VectorXd c_;
Eigen::VectorXd d_;

std::vector<double> base_keys_;

Eigen::Index get_index(double key) const;
interpolation::MultiSplineCoef multi_spline_coef_;
};

#endif // INTERPOLATION__SPLINE_INTERPOLATION_HPP_
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
#define INTERPOLATION__SPLINE_INTERPOLATION_POINTS_2D_HPP_

#include "interpolation/spline_interpolation.hpp"
#include "tier4_autoware_utils/geometry/geometry.hpp"

#include <geometry_msgs/msg/point.hpp>

#include <vector>

Expand Down
1 change: 0 additions & 1 deletion common/interpolation/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
<license>Apache License 2.0</license>
<buildtool_depend>ament_cmake_auto</buildtool_depend>

<depend>eigen</depend>
<depend>tier4_autoware_utils</depend>

<test_depend>ament_lint_auto</test_depend>
Expand Down
210 changes: 121 additions & 89 deletions common/interpolation/src/spline_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,70 @@

#include "interpolation/spline_interpolation.hpp"

#include "interpolation/interpolation_utils.hpp"
#include <vector>

#include <algorithm>

Eigen::VectorXd solve_tridiagonal_matrix_algorithm(
const Eigen::Ref<const Eigen::VectorXd> & a, const Eigen::Ref<const Eigen::VectorXd> & b,
const Eigen::Ref<const Eigen::VectorXd> & c, const Eigen::Ref<const Eigen::VectorXd> & d)
namespace
{
auto n = d.size();

if (n == 1) {
return d.array() / b.array();
}

Eigen::VectorXd c_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd d_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd x = Eigen::VectorXd::Zero(n);

// Forward sweep
c_prime(0) = c(0) / b(0);
d_prime(0) = d(0) / b(0);

for (auto i = 1; i < n; i++) {
double m = 1.0 / (b(i) - a(i - 1) * c_prime(i - 1));
c_prime(i) = i < n - 1 ? c(i) * m : 0;
d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) * m;
// solve Ax = d
// where A is tridiagonal matrix
// [b_0 c_0 ... ]
// [a_0 b_1 c_1 ... O ]
// A = [ ... ]
// [ O ... a_N-3 b_N-2 c_N-2]
// [ ... a_N-2 b_N-1]
struct TDMACoef
{
explicit TDMACoef(const size_t num_row)
{
a.resize(num_row - 1);
b.resize(num_row);
c.resize(num_row - 1);
d.resize(num_row);
}

// Back substitution
x(n - 1) = d_prime(n - 1);
std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

for (auto i = n - 2; i >= 0; i--) {
x(i) = d_prime(i) - c_prime(i) * x(i + 1);
inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
{
const auto & a = tdma_coef.a;
const auto & b = tdma_coef.b;
const auto & c = tdma_coef.c;
const auto & d = tdma_coef.d;

const size_t num_row = b.size();

std::vector<double> x(num_row);
if (num_row != 1) {
// calculate p and q
std::vector<double> p;
std::vector<double> q;
p.push_back(-c[0] / b[0]);
q.push_back(d[0] / b[0]);

for (size_t i = 1; i < num_row; ++i) {
const double den = b[i] + a[i - 1] * p[i - 1];
p.push_back(-c[i - 1] / den);
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
}

// calculate solution
x[num_row - 1] = q[num_row - 1];

for (size_t i = 1; i < num_row; ++i) {
const size_t j = num_row - 1 - i;
x[j] = p[j] * x[j + 1] + q[j];
}
} else {
x.push_back(d[0] / b[0]);
}

return x;
}
} // namespace

namespace interpolation
{
Expand All @@ -74,55 +101,48 @@ void SplineInterpolation::calcSplineCoefficients(
// throw exceptions for invalid arguments
interpolation_utils::validateKeysAndValues(base_keys, base_values);

Eigen::VectorXd x = Eigen::Map<const Eigen::VectorXd>(
base_keys.data(), static_cast<Eigen::Index>(base_keys.size()));
Eigen::VectorXd y = Eigen::Map<const Eigen::VectorXd>(
base_values.data(), static_cast<Eigen::Index>(base_values.size()));

const auto n = x.size();

if (n == 2) {
a_ = Eigen::VectorXd::Zero(1);
b_ = Eigen::VectorXd::Zero(1);
c_ = Eigen::VectorXd::Zero(1);
d_ = Eigen::VectorXd::Zero(1);
c_[0] = (y[1] - y[0]) / (x[1] - x[0]);
d_[0] = y[0];
base_keys_ = base_keys;
return;
const size_t num_base = base_keys.size(); // N+1

std::vector<double> diff_keys; // N
std::vector<double> diff_values; // N
for (size_t i = 0; i < num_base - 1; ++i) {
diff_keys.push_back(base_keys.at(i + 1) - base_keys.at(i));
diff_values.push_back(base_values.at(i + 1) - base_values.at(i));
}

// Create Tridiagonal matrix
Eigen::VectorXd v(n);
Eigen::VectorXd h = x.segment(1, n - 1) - x.segment(0, n - 1);
Eigen::VectorXd a = h.segment(1, n - 3);
Eigen::VectorXd b = 2 * (h.segment(0, n - 2) + h.segment(1, n - 2));
Eigen::VectorXd c = h.segment(1, n - 3);
Eigen::VectorXd y_diff = y.segment(1, n - 1) - y.segment(0, n - 1);
Eigen::VectorXd d = 6 * (y_diff.segment(1, n - 2).array() / h.tail(n - 2).array() -
y_diff.segment(0, n - 2).array() / h.head(n - 2).array());

// Solve tridiagonal matrix
v.segment(1, n - 2) = solve_tridiagonal_matrix_algorithm(a, b, c, d);
v[0] = 0;
v[n - 1] = 0;

// Calculate spline coefficients
a_ = (v.tail(n - 1) - v.head(n - 1)).array() / 6.0 / (x.tail(n - 1) - x.head(n - 1)).array();
b_ = v.segment(0, n - 1) / 2.0;
c_ = (y.tail(n - 1) - y.head(n - 1)).array() / (x.tail(n - 1) - x.head(n - 1)).array() -
(x.tail(n - 1) - x.head(n - 1)).array() *
(2 * v.segment(0, n - 1).array() + v.segment(1, n - 1).array()) / 6.0;
d_ = y.head(n - 1);
base_keys_ = base_keys;
}
std::vector<double> v = {0.0};
if (num_base > 2) {
// solve tridiagonal matrix algorithm
TDMACoef tdma_coef(num_base - 2); // N-1

for (size_t i = 0; i < num_base - 2; ++i) {
tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]);
if (i != num_base - 3) {
tdma_coef.a[i] = diff_keys[i + 1];
tdma_coef.c[i] = diff_keys[i + 1];
}
tdma_coef.d[i] =
6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]);
}

const std::vector<double> tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef);

// calculate v
v.insert(v.end(), tdma_res.begin(), tdma_res.end());
}
v.push_back(0.0);

// calculate a, b, c, d of spline coefficients
multi_spline_coef_ = interpolation::MultiSplineCoef{num_base - 1}; // N
for (size_t i = 0; i < num_base - 1; ++i) {
multi_spline_coef_.a[i] = (v[i + 1] - v[i]) / 6.0 / diff_keys[i];
multi_spline_coef_.b[i] = v[i] / 2.0;
multi_spline_coef_.c[i] =
diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1]) / 6.0;
multi_spline_coef_.d[i] = base_values[i];
}

Eigen::Index SplineInterpolation::get_index(double key) const
{
auto it = std::lower_bound(base_keys_.begin(), base_keys_.end(), key);
return std::clamp(
static_cast<int>(std::distance(base_keys_.begin(), it)) - 1, 0,
static_cast<int>(base_keys_.size()) - 2);
base_keys_ = base_keys;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
Expand All @@ -131,17 +151,23 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
// throw exceptions for invalid arguments
interpolation_utils::validateKeys(base_keys_, query_keys);

std::vector<double> interpolated_values;
interpolated_values.reserve(query_keys.size());
const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;
const auto & d = multi_spline_coef_.d;

for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_values.emplace_back(
a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]);
std::vector<double> res;
size_t j = 0;
for (const auto & query_key : query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(d.at(j) + (c.at(j) + (b.at(j) + a.at(j) * ds) * ds) * ds);
}

return interpolated_values;
return res;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
Expand All @@ -150,14 +176,20 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
// throw exceptions for invalid arguments
interpolation_utils::validateKeys(base_keys_, query_keys);

std::vector<double> interpolated_diff_values;
interpolated_diff_values.reserve(query_keys.size());
const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_diff_values.emplace_back(3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]);
const double ds = query_key - base_keys_.at(j);
res.push_back(c.at(j) + (2.0 * b.at(j) + 3.0 * a.at(j) * ds) * ds);
}

return interpolated_diff_values;
return res;
}
Loading