Skip to content

Commit

Permalink
Interpolating spline (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
myeatman-bdai authored Oct 24, 2024
1 parent 0c57dc0 commit a6d6641
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 52 deletions.
27 changes: 21 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
repos:
# - repo: https://github.com/charliermarsh/ruff-pre-commit
# # Ruff version.
# rev: 'v0.1.0'
# hooks:
# - id: ruff
# args: ['--fix', '--config', 'pyproject.toml']
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: 'v0.1.0'
hooks:
- id: ruff
args: ['--fix', '--config', 'pyproject.toml']

- repo: https://github.com/psf/black
rev: 23.10.0
Expand All @@ -14,6 +14,21 @@ repos:
args: ['--config', 'pyproject.toml']
verbose: true

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: debug-statements # Ensure we don't commit `import pdb; pdb.set_trace()`
exclude: |
(?x)^(
docker/ros/web/static/.*|
)$
- id: trailing-whitespace
exclude: |
(?x)^(
docker/ros/web/static/.*|
(.*/).*\.patch|
)$
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.6.1
# hooks:
Expand Down
4 changes: 3 additions & 1 deletion spatialmath/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from spatialmath.quaternion import Quaternion, UnitQuaternion
from spatialmath.DualQuaternion import DualQuaternion, UnitDualQuaternion
from spatialmath.spline import BSplineSE3
from spatialmath.spline import BSplineSE3, InterpSplineSE3, SplineFit

# from spatialmath.Plucker import *
# from spatialmath import base as smb
Expand Down Expand Up @@ -45,6 +45,8 @@
"Polygon2",
"Ellipse",
"BSplineSE3",
"InterpSplineSE3",
"SplineFit"
]

try:
Expand Down
2 changes: 1 addition & 1 deletion spatialmath/base/animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def update(frame, animation):
if isinstance(frame, float):
# passed a single transform, interpolate it
T = smb.trinterp(start=self.start, end=self.end, s=frame)
elif isinstance(frame, NDArray):
elif isinstance(frame, np.ndarray):
# type is SO3Array or SE3Array when Animate.trajectory is not None
T = frame
else:
Expand Down
273 changes: 233 additions & 40 deletions spatialmath/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,242 @@
# MIT Licence, see details in top-level file: LICENCE

"""
Classes for parameterizing a trajectory in SE3 with B-splines.
Copies parts of the API from scipy's B-spline class.
Classes for parameterizing a trajectory in SE3 with splines.
"""

from typing import Any, Dict, List, Optional
from scipy.interpolate import BSpline
from spatialmath import SE3
import numpy as np
from abc import ABC, abstractmethod
from functools import cached_property
from typing import List, Optional, Tuple, Set

import matplotlib.pyplot as plt
from spatialmath.base.transforms3d import tranimate, trplot
import numpy as np
from scipy.interpolate import BSpline, CubicSpline
from scipy.spatial.transform import Rotation, RotationSpline

from spatialmath import SE3, SO3, Twist3
from spatialmath.base.transforms3d import tranimate


class SplineSE3(ABC):
def __init__(self) -> None:
self.control_poses: SE3

@abstractmethod
def __call__(self, t: float) -> SE3:
pass

def visualize(
self,
sample_times: List[float],
input_trajectory: Optional[List[SE3]] = None,
pose_marker_length: float = 0.2,
animate: bool = False,
repeat: bool = True,
ax: Optional[plt.Axes] = None,
) -> None:
"""Displays an animation of the trajectory with the control poses against an optional input trajectory.
Args:
sample_times: which times to sample the spline at and plot
"""
if ax is None:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")

samples = [self(t) for t in sample_times]
if not animate:
pos = np.array([pose.t for pose in samples])
ax.plot(
pos[:, 0], pos[:, 1], pos[:, 2], "c", linewidth=1.0
) # plot spline fit

pos = np.array([pose.t for pose in self.control_poses])
ax.plot(pos[:, 0], pos[:, 1], pos[:, 2], "r*") # plot control_poses

if input_trajectory is not None:
pos = np.array([pose.t for pose in input_trajectory])
ax.plot(
pos[:, 0], pos[:, 1], pos[:, 2], "go", fillstyle="none"
) # plot compare to input poses

if animate:
tranimate(
samples, length=pose_marker_length, wait=True, repeat=repeat
) # animate pose along trajectory
else:
plt.show()


class InterpSplineSE3(SplineSE3):
"""Class for an interpolated trajectory in SE3, as a function of time, through control_poses with a cubic spline.
A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic)
under the hood.
"""

_e = 1e-12

def __init__(
self,
timepoints: List[float],
control_poses: List[SE3],
*,
normalize_time: bool = False,
bc_type: str = "not-a-knot", # not-a-knot is scipy default; None is invalid
) -> None:
"""Construct a InterpSplineSE3 object
Extends the scipy CubicSpline object
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline
Args :
timepoints : list of times corresponding to provided poses
control_poses : list of SE3 objects that govern the shape of the spline.
normalize_time : flag to map times into the range [0, 1]
bc_type : boundary condition provided to scipy CubicSpline backend.
string options: ["not-a-knot" (default), "clamped", "natural", "periodic"].
For tuple options and details see the scipy docs link above.
"""
super().__init__()
self.control_poses = control_poses
self.timepoints = np.array(timepoints)

if self.timepoints[-1] < self._e:
raise ValueError(
"Difference between start and end timepoints is less than {self._e}"
)

if len(self.control_poses) != len(self.timepoints):
raise ValueError("Length of control_poses and timepoints must be equal.")

if len(self.timepoints) < 2:
raise ValueError("Need at least 2 data points to make a trajectory.")

if normalize_time:
self.timepoints = self.timepoints - self.timepoints[0]
self.timepoints = self.timepoints / self.timepoints[-1]

self.spline_xyz = CubicSpline(
self.timepoints,
np.array([pose.t for pose in self.control_poses]),
bc_type=bc_type,
)
self.spline_so3 = RotationSpline(
self.timepoints,
Rotation.from_matrix(np.array([(pose.R) for pose in self.control_poses])),
)

def __call__(self, t: float) -> SE3:
"""Compute function value at t.
Return:
pose: SE3
"""
return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix())

def derivative(self, t: float) -> Twist3:
linear_vel = self.spline_xyz.derivative()(t)
angular_vel = self.spline_so3(
t, 1
) # 1 is angular rate, 2 is angular acceleration
return Twist3(linear_vel, angular_vel)


class SplineFit:
"""A general class to fit various SE3 splines to data."""

def __init__(
self,
time_data: List[float],
pose_data: List[SE3],
) -> None:
self.time_data = time_data
self.pose_data = pose_data
self.spline: Optional[SplineSE3] = None

def stochastic_downsample_interpolation(
self,
epsilon_xyz: float = 1e-3,
epsilon_angle: float = 1e-1,
normalize_time: bool = True,
bc_type: str = "not-a-knot",
check_type: str = "local"
) -> Tuple[InterpSplineSE3, List[int]]:
"""
Uses a random dropout to downsample a trajectory with an interpolated spline. Keeps the start and
end points of the trajectory. Takes a random order of the remaining indices, and then checks the error bound
of just that point if check_type=="local", checks the error of the whole trajectory is check_type=="global".
Local is **much** faster.
Return:
downsampled interpolating spline,
list of removed indices from input data
"""

interpolation_indices = list(range(len(self.pose_data)))

# randomly attempt to remove poses from the trajectory
# always keep the start and end
removal_choices = interpolation_indices.copy()
removal_choices.remove(0)
removal_choices.remove(len(self.pose_data) - 1)
np.random.shuffle(removal_choices)
for candidate_removal_index in removal_choices:
interpolation_indices.remove(candidate_removal_index)

self.spline = InterpSplineSE3(
[self.time_data[i] for i in interpolation_indices],
[self.pose_data[i] for i in interpolation_indices],
normalize_time=normalize_time,
bc_type=bc_type,
)

sample_time = self.time_data[candidate_removal_index]
if check_type is "local":
angular_error = SO3(self.pose_data[candidate_removal_index]).angdist(
SO3(self.spline.spline_so3(sample_time).as_matrix())
)
euclidean_error = np.linalg.norm(
self.pose_data[candidate_removal_index].t - self.spline.spline_xyz(sample_time)
)
elif check_type is "global":
angular_error = self.max_angular_error()
euclidean_error = self.max_euclidean_error()
else:
raise ValueError(f"check_type must be 'local' of 'global', is {check_type}.")

if (angular_error > epsilon_angle) or (euclidean_error > epsilon_xyz):
interpolation_indices.append(candidate_removal_index)
interpolation_indices.sort()

self.spline = InterpSplineSE3(
[self.time_data[i] for i in interpolation_indices],
[self.pose_data[i] for i in interpolation_indices],
normalize_time=normalize_time,
bc_type=bc_type,
)

return self.spline, interpolation_indices

def max_angular_error(self) -> float:
return np.max(self.angular_errors())

def angular_errors(self) -> List[float]:
return [
pose.angdist(self.spline(t))
for pose, t in zip(self.pose_data, self.time_data)
]

def max_euclidean_error(self) -> float:
return np.max(self.euclidean_errors())

class BSplineSE3:
def euclidean_errors(self) -> List[float]:
return [
np.linalg.norm(pose.t - self.spline(t).t)
for pose, t in zip(self.pose_data, self.time_data)
]


class BSplineSE3(SplineSE3):
"""A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.
The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline
Expand All @@ -39,9 +261,9 @@ def __init__(
- degree: int that controls degree of the polynomial that governs any given point on the spline.
- knots: list of floats that govern which control points are active during evaluating the spline
at a given t input. If none, they are automatically, uniformly generated based on number of control poses and
degree of spline.
degree of spline on the range [0,1].
"""

super().__init__()
self.control_poses = control_poses

# a matrix where each row is a control pose as a twist
Expand Down Expand Up @@ -74,32 +296,3 @@ def __call__(self, t: float) -> SE3:
"""
twist = np.hstack([spline(t) for spline in self.splines])
return SE3.Exp(twist)

def visualize(
self,
num_samples: int,
length: float = 1.0,
repeat: bool = False,
ax: Optional[plt.Axes] = None,
kwargs_trplot: Dict[str, Any] = {"color": "green"},
kwargs_tranimate: Dict[str, Any] = {"wait": True},
kwargs_plot: Dict[str, Any] = {},
) -> None:
"""Displays an animation of the trajectory with the control poses."""
out_poses = [self(t) for t in np.linspace(0, 1, num_samples)]
x = [pose.x for pose in out_poses]
y = [pose.y for pose in out_poses]
z = [pose.z for pose in out_poses]

if ax is None:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")

trplot(
[np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot
) # plot control points
ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory

tranimate(
out_poses, repeat=repeat, length=length, **kwargs_tranimate
) # animate pose along trajectory
Loading

0 comments on commit a6d6641

Please sign in to comment.