From c79a3e81f178eec01d1c35861ba426802f1346bd Mon Sep 17 00:00:00 2001
From: "Y.Hisaki" <yhisaki31@gmail.com>
Date: Tue, 17 Sep 2024 18:11:24 +0900
Subject: [PATCH] fix spline bug

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>
---
 .../interpolation/spline_interpolation.hpp    |  32 +--
 .../spline_interpolation_points_2d.hpp        |   3 +
 common/interpolation/package.xml              |   1 +
 .../src/spline_interpolation.cpp              | 210 ++++++++----------
 4 files changed, 101 insertions(+), 145 deletions(-)

diff --git a/common/interpolation/include/interpolation/spline_interpolation.hpp b/common/interpolation/include/interpolation/spline_interpolation.hpp
index 09a01d03727eb..578b08a1fa225 100644
--- a/common/interpolation/include/interpolation/spline_interpolation.hpp
+++ b/common/interpolation/include/interpolation/spline_interpolation.hpp
@@ -15,35 +15,13 @@
 #ifndef INTERPOLATION__SPLINE_INTERPOLATION_HPP_
 #define INTERPOLATION__SPLINE_INTERPOLATION_HPP_
 
-#include "interpolation/interpolation_utils.hpp"
-#include "tier4_autoware_utils/geometry/geometry.hpp"
+#include <Eigen/Dense>
 
-#include <algorithm>
 #include <cmath>
-#include <iostream>
-#include <numeric>
 #include <vector>
 
 namespace interpolation
 {
-// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1)
-struct MultiSplineCoef
-{
-  MultiSplineCoef() = default;
-
-  explicit MultiSplineCoef(const size_t num_spline)
-  {
-    a.resize(num_spline);
-    b.resize(num_spline);
-    c.resize(num_spline);
-    d.resize(num_spline);
-  }
-
-  std::vector<double> a;
-  std::vector<double> b;
-  std::vector<double> c;
-  std::vector<double> d;
-};
 
 // static spline interpolation functions
 std::vector<double> slerp(
@@ -84,8 +62,14 @@ class SplineInterpolation
   std::vector<double> getSplineInterpolatedDiffValues(const std::vector<double> & query_keys) const;
 
 private:
+  Eigen::VectorXd a_;
+  Eigen::VectorXd b_;
+  Eigen::VectorXd c_;
+  Eigen::VectorXd d_;
+
   std::vector<double> base_keys_;
-  interpolation::MultiSplineCoef multi_spline_coef_;
+
+  Eigen::Index get_index(double key) const;
 };
 
 #endif  // INTERPOLATION__SPLINE_INTERPOLATION_HPP_
diff --git a/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp b/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp
index c1f08a6d937ae..f46b64bba4d6a 100644
--- a/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp
+++ b/common/interpolation/include/interpolation/spline_interpolation_points_2d.hpp
@@ -16,6 +16,9 @@
 #define INTERPOLATION__SPLINE_INTERPOLATION_POINTS_2D_HPP_
 
 #include "interpolation/spline_interpolation.hpp"
+#include "tier4_autoware_utils/geometry/geometry.hpp"
+
+#include <geometry_msgs/msg/point.hpp>
 
 #include <vector>
 
diff --git a/common/interpolation/package.xml b/common/interpolation/package.xml
index 72844e0702978..9f94f9221100e 100644
--- a/common/interpolation/package.xml
+++ b/common/interpolation/package.xml
@@ -9,6 +9,7 @@
   <license>Apache License 2.0</license>
   <buildtool_depend>ament_cmake_auto</buildtool_depend>
 
+  <depend>eigen</depend>
   <depend>tier4_autoware_utils</depend>
 
   <test_depend>ament_lint_auto</test_depend>
diff --git a/common/interpolation/src/spline_interpolation.cpp b/common/interpolation/src/spline_interpolation.cpp
index f8d14ff7bba37..bbcaafdef14e6 100644
--- a/common/interpolation/src/spline_interpolation.cpp
+++ b/common/interpolation/src/spline_interpolation.cpp
@@ -14,70 +14,43 @@
 
 #include "interpolation/spline_interpolation.hpp"
 
-#include <vector>
+#include "interpolation/interpolation_utils.hpp"
 
-namespace
-{
-// solve Ax = d
-// where A is tridiagonal matrix
-//     [b_0 c_0 ...                       ]
-//     [a_0 b_1 c_1 ...               O   ]
-// A = [            ...                   ]
-//     [   O         ... a_N-3 b_N-2 c_N-2]
-//     [                   ... a_N-2 b_N-1]
-struct TDMACoef
+#include <algorithm>
+
+Eigen::VectorXd solve_tridiagonal_matrix_algorithm(
+  const Eigen::Ref<const Eigen::VectorXd> & a, const Eigen::Ref<const Eigen::VectorXd> & b,
+  const Eigen::Ref<const Eigen::VectorXd> & c, const Eigen::Ref<const Eigen::VectorXd> & d)
 {
-  explicit TDMACoef(const size_t num_row)
-  {
-    a.resize(num_row - 1);
-    b.resize(num_row);
-    c.resize(num_row - 1);
-    d.resize(num_row);
+  auto n = d.size();
+
+  if (n == 1) {
+    return d.array() / b.array();
   }
 
-  std::vector<double> a;
-  std::vector<double> b;
-  std::vector<double> c;
-  std::vector<double> d;
-};
+  Eigen::VectorXd c_prime = Eigen::VectorXd::Zero(n);
+  Eigen::VectorXd d_prime = Eigen::VectorXd::Zero(n);
+  Eigen::VectorXd x = Eigen::VectorXd::Zero(n);
 
-inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
-{
-  const auto & a = tdma_coef.a;
-  const auto & b = tdma_coef.b;
-  const auto & c = tdma_coef.c;
-  const auto & d = tdma_coef.d;
-
-  const size_t num_row = b.size();
-
-  std::vector<double> x(num_row);
-  if (num_row != 1) {
-    // calculate p and q
-    std::vector<double> p;
-    std::vector<double> q;
-    p.push_back(-c[0] / b[0]);
-    q.push_back(d[0] / b[0]);
-
-    for (size_t i = 1; i < num_row; ++i) {
-      const double den = b[i] + a[i - 1] * p[i - 1];
-      p.push_back(-c[i - 1] / den);
-      q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
-    }
-
-    // calculate solution
-    x[num_row - 1] = q[num_row - 1];
-
-    for (size_t i = 1; i < num_row; ++i) {
-      const size_t j = num_row - 1 - i;
-      x[j] = p[j] * x[j + 1] + q[j];
-    }
-  } else {
-    x.push_back(d[0] / b[0]);
+  // Forward sweep
+  c_prime(0) = c(0) / b(0);
+  d_prime(0) = d(0) / b(0);
+
+  for (auto i = 1; i < n; i++) {
+    double m = 1.0 / (b(i) - a(i - 1) * c_prime(i - 1));
+    c_prime(i) = i < n - 1 ? c(i) * m : 0;
+    d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) * m;
+  }
+
+  // Back substitution
+  x(n - 1) = d_prime(n - 1);
+
+  for (auto i = n - 2; i >= 0; i--) {
+    x(i) = d_prime(i) - c_prime(i) * x(i + 1);
   }
 
   return x;
 }
-}  // namespace
 
 namespace interpolation
 {
@@ -101,73 +74,74 @@ void SplineInterpolation::calcSplineCoefficients(
   // throw exceptions for invalid arguments
   interpolation_utils::validateKeysAndValues(base_keys, base_values);
 
-  const size_t num_base = base_keys.size();  // N+1
-
-  std::vector<double> diff_keys;    // N
-  std::vector<double> diff_values;  // N
-  for (size_t i = 0; i < num_base - 1; ++i) {
-    diff_keys.push_back(base_keys.at(i + 1) - base_keys.at(i));
-    diff_values.push_back(base_values.at(i + 1) - base_values.at(i));
-  }
-
-  std::vector<double> v = {0.0};
-  if (num_base > 2) {
-    // solve tridiagonal matrix algorithm
-    TDMACoef tdma_coef(num_base - 2);  // N-1
-
-    for (size_t i = 0; i < num_base - 2; ++i) {
-      tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]);
-      if (i != num_base - 3) {
-        tdma_coef.a[i] = diff_keys[i + 1];
-        tdma_coef.c[i] = diff_keys[i + 1];
-      }
-      tdma_coef.d[i] =
-        6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]);
-    }
-
-    const std::vector<double> tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef);
-
-    // calculate v
-    v.insert(v.end(), tdma_res.begin(), tdma_res.end());
-  }
-  v.push_back(0.0);
-
-  // calculate a, b, c, d of spline coefficients
-  multi_spline_coef_ = interpolation::MultiSplineCoef{num_base - 1};  // N
-  for (size_t i = 0; i < num_base - 1; ++i) {
-    multi_spline_coef_.a[i] = (v[i + 1] - v[i]) / 6.0 / diff_keys[i];
-    multi_spline_coef_.b[i] = v[i] / 2.0;
-    multi_spline_coef_.c[i] =
-      diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1]) / 6.0;
-    multi_spline_coef_.d[i] = base_values[i];
+  Eigen::VectorXd x = Eigen::Map<const Eigen::VectorXd>(
+    base_keys.data(), static_cast<Eigen::Index>(base_keys.size()));
+  Eigen::VectorXd y = Eigen::Map<const Eigen::VectorXd>(
+    base_values.data(), static_cast<Eigen::Index>(base_values.size()));
+
+  const auto n = x.size();
+
+  if (n == 2) {
+    a_ = Eigen::VectorXd::Zero(1);
+    b_ = Eigen::VectorXd::Zero(1);
+    c_ = Eigen::VectorXd::Zero(1);
+    d_ = Eigen::VectorXd::Zero(1);
+    c_[0] = (y[1] - y[0]) / (x[1] - x[0]);
+    d_[0] = y[0];
+    base_keys_ = base_keys;
+    return;
   }
 
+  // Create Tridiagonal matrix
+  Eigen::VectorXd v(n);
+  Eigen::VectorXd h = x.segment(1, n - 1) - x.segment(0, n - 1);
+  Eigen::VectorXd a = h.segment(1, n - 3);
+  Eigen::VectorXd b = 2 * (h.segment(0, n - 2) + h.segment(1, n - 2));
+  Eigen::VectorXd c = h.segment(1, n - 3);
+  Eigen::VectorXd y_diff = y.segment(1, n - 1) - y.segment(0, n - 1);
+  Eigen::VectorXd d = 6 * (y_diff.segment(1, n - 2).array() / h.tail(n - 2).array() -
+                           y_diff.segment(0, n - 2).array() / h.head(n - 2).array());
+
+  // Solve tridiagonal matrix
+  v.segment(1, n - 2) = solve_tridiagonal_matrix_algorithm(a, b, c, d);
+  v[0] = 0;
+  v[n - 1] = 0;
+
+  // Calculate spline coefficients
+  a_ = (v.tail(n - 1) - v.head(n - 1)).array() / 6.0 / (x.tail(n - 1) - x.head(n - 1)).array();
+  b_ = v.segment(0, n - 1) / 2.0;
+  c_ = (y.tail(n - 1) - y.head(n - 1)).array() / (x.tail(n - 1) - x.head(n - 1)).array() -
+       (x.tail(n - 1) - x.head(n - 1)).array() *
+         (2 * v.segment(0, n - 1).array() + v.segment(1, n - 1).array()) / 6.0;
+  d_ = y.head(n - 1);
   base_keys_ = base_keys;
 }
 
+Eigen::Index SplineInterpolation::get_index(double key) const
+{
+  auto it = std::lower_bound(base_keys_.begin(), base_keys_.end(), key);
+  return std::clamp(
+    static_cast<int>(std::distance(base_keys_.begin(), it)) - 1, 0,
+    static_cast<int>(base_keys_.size()) - 2);
+}
+
 std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
   const std::vector<double> & query_keys) const
 {
   // throw exceptions for invalid arguments
   interpolation_utils::validateKeys(base_keys_, query_keys);
 
-  const auto & a = multi_spline_coef_.a;
-  const auto & b = multi_spline_coef_.b;
-  const auto & c = multi_spline_coef_.c;
-  const auto & d = multi_spline_coef_.d;
+  std::vector<double> interpolated_values;
+  interpolated_values.reserve(query_keys.size());
 
-  std::vector<double> res;
-  size_t j = 0;
-  for (const auto & query_key : query_keys) {
-    while (base_keys_.at(j + 1) < query_key) {
-      ++j;
-    }
-
-    const double ds = query_key - base_keys_.at(j);
-    res.push_back(d.at(j) + (c.at(j) + (b.at(j) + a.at(j) * ds) * ds) * ds);
+  for (const auto & key : query_keys) {
+    const auto idx = get_index(key);
+    const auto dx = key - base_keys_[idx];
+    interpolated_values.emplace_back(
+      a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]);
   }
 
-  return res;
+  return interpolated_values;
 }
 
 std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
@@ -176,20 +150,14 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
   // throw exceptions for invalid arguments
   interpolation_utils::validateKeys(base_keys_, query_keys);
 
-  const auto & a = multi_spline_coef_.a;
-  const auto & b = multi_spline_coef_.b;
-  const auto & c = multi_spline_coef_.c;
-
-  std::vector<double> res;
-  size_t j = 0;
-  for (const auto & query_key : query_keys) {
-    while (base_keys_.at(j + 1) < query_key) {
-      ++j;
-    }
+  std::vector<double> interpolated_diff_values;
+  interpolated_diff_values.reserve(query_keys.size());
 
-    const double ds = query_key - base_keys_.at(j);
-    res.push_back(c.at(j) + (2.0 * b.at(j) + 3.0 * a.at(j) * ds) * ds);
+  for (const auto & key : query_keys) {
+    const auto idx = get_index(key);
+    const auto dx = key - base_keys_[idx];
+    interpolated_diff_values.emplace_back(3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]);
   }
 
-  return res;
+  return interpolated_diff_values;
 }