Skip to content

Commit

Permalink
Cover all test cases | Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kotochleb committed Jul 24, 2024
1 parent aa004fe commit ee3f145
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 14 deletions.
11 changes: 7 additions & 4 deletions happypose_msgs/happypose_msgs_py/symmetries.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,18 @@ def _discretize_continuous(
sym: ContinuousSymmetry, idx: int
) -> npt.NDArray[np.float64]:
axis = np.array([sym.axis.x, sym.axis.y, sym.axis.z])
if not np.isclose(axis.sum(), 1.0):
if not np.isclose(np.linalg.norm(axis), 1.0):
raise ValueError(
f"Continuous symmetry at index {idx} has non unitary rotation axis!"
)
symmetries = np.zeros((n_symmetries_continuous, 4, 4))

# Precompute steps of rotations
rot_base = 2.0 * axis * np.pi / n_symmetries_continuous
rot_base = 2.0 * np.pi / n_symmetries_continuous
for i in range(n_symmetries_continuous):
symmetries[i, :3, :3] = transforms3d.euler.euler2mat(*(rot_base * i))
symmetries[i, :3, :3] = transforms3d.axangles.axangle2mat(
axis, rot_base * i
)

symmetries[:, -1, -1] = 1.0
symmetries[:, :3, -1] = np.array([sym.offset.x, sym.offset.y, sym.offset.z])
Expand All @@ -61,8 +63,9 @@ def _discretize_continuous(

def _transform_msg_to_mat(sym: Transform) -> npt.NDArray[np.float64]:
M = np.eye(4)

M[:3, :3] = transforms3d.quaternions.quat2mat(
(sym.rotation.w, sym.rotation.x, sym.rotation.y, sym.rotation.z)
[sym.rotation.w, sym.rotation.x, sym.rotation.y, sym.rotation.z]
)
M[0, -1] = sym.translation.x
M[1, -1] = sym.translation.y
Expand Down
267 changes: 257 additions & 10 deletions happypose_msgs/test/test_discretize_symmetries.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
#!/usr/bin/env python

import numpy as np
import numpy.typing as npt
import pinocchio as pin
from typing import List

from geometry_msgs.msg import Transform, Vector3, Quaternion

from happypose_msgs_py.symmetries import discretize_symmetries

from happypose_msgs.msg import ObjectSymmetries
from happypose_msgs.msg import ContinuousSymmetry, ObjectSymmetries


def are_transforms_close(t1: Transform, t2: Transform) -> bool:
def are_transforms_close(t1: pin.SE3, t2: pin.SE3) -> bool:
"""Checks if two SE3 transformations are close.
:param t1: First transform to compare.
:type t1: pinocchio.SE3
:param t2: Second transform to compare.
:type t2: pinocchio.SE3
:return: If those transformations are close.
:rtype: bool
"""
diff = t1.inverse() * t2
return np.linalg.norm(pin.log6(diff).vector) < 5e-3


def are_transform_msgs_close(t1: Transform, t2: Transform) -> bool:
"""Checks if two ROS transform messages are close.
:param t1: First transform to compare.
Expand All @@ -35,11 +50,10 @@ def are_transforms_close(t1: Transform, t2: Transform) -> bool:
)
for t in (t1, t2)
]
diff = T1.inverse() * T2
return np.linalg.norm(pin.log6(diff).vector) < 1e-6
return are_transforms_close(T1, T2)


def is_transform_in_list(t1: Transform, t_list: List[Transform]) -> bool:
def is_transform_msg_in_list(t1: Transform, t_list: List[Transform]) -> bool:
"""Checks if a transform is in the list of transformations.
:param t1: Transform to check if in the list.
Expand All @@ -49,7 +63,39 @@ def is_transform_in_list(t1: Transform, t_list: List[Transform]) -> bool:
:return: If the transform in the list.
:rtype: bool
"""
return any(are_transforms_close(t1, t2) for t2 in t_list)
return any(are_transform_msgs_close(t1, t2) for t2 in t_list)


def is_transform_in_se3_list(
t1: npt.NDArray[np.float64], t_list: List[pin.SE3]
) -> bool:
"""Checks if a transform is in the tensor of transformations.
:param t1: Transform to check if in the tensor.
:type t1: npt.NDArray[np.float64]
:param t_list: List of SE3 objects used to find ``t1``.
:type t_list: List[pinocchio.SE3]
:return: If the transform in the list.
:rtype: bool
"""
T1 = pin.SE3(t1)
return any(are_transforms_close(T1, t2) for t2 in t_list)


def pin_to_msg(transform: pin.SE3) -> Transform:
"""Converts 4x4 transformation matrix to ROS Transform message.
:param transform: Input transformation.
:type transform: pinocchio.SE3
:return: Converted SE3 transformation into ROS Transform
message format.
:rtype: geometry_msgs.msg.Transform
"""
pose_vec = pin.SE3ToXYZQUAT(transform)
return Transform(
translation=Vector3(**dict(zip("xyz", pose_vec[:3]))),
rotation=Quaternion(**dict(zip("xyzw", pose_vec[3:]))),
)


def test_empty_message_np() -> None:
Expand All @@ -74,6 +120,28 @@ def test_empty_message_ros() -> None:
assert len(res) == 0, "Results list is not empty!"


def test_only_discrete_np() -> None:
t1 = pin.SE3(
pin.Quaternion(np.array([0.0, 0.0, 0.0, 1.0])), np.array([0.0, 0.0, 0.0])
)
t2 = pin.SE3(
pin.Quaternion(np.array([0.707, 0.0, 0.0, 0.707])), np.array([0.1, 0.1, 0.1])
)
msg = ObjectSymmetries(
symmetries_discrete=[pin_to_msg(t1), pin_to_msg(t2)],
symmetries_continuous=[],
)

res = discretize_symmetries(msg)

assert res.shape == (2, 4, 4), "Result shape is incorrect!"

for i, t in enumerate(res):
assert is_transform_in_se3_list(
t, [t1, t2]
), f"Discrete symmetry at index {i} did not match any of the initial ones!"


def test_only_discrete_ros() -> None:
msg = ObjectSymmetries(
symmetries_discrete=[
Expand All @@ -83,7 +151,7 @@ def test_only_discrete_ros() -> None:
),
Transform(
translation=Vector3(x=0.1, y=0.1, z=0.1),
rotation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0),
rotation=Quaternion(x=0.707, y=0.0, z=0.0, w=0.707),
),
],
symmetries_continuous=[],
Expand All @@ -99,6 +167,185 @@ def test_only_discrete_ros() -> None:
isinstance(r, Transform) for r in res
), "Returned type of elements in the list is not geometry_msgs.msg.Transform!"

assert all(
is_transform_in_list(t, msg.symmetries_discrete) for t in res
), "Resulted discrete symmetries are not close to the initial ones int the message!"
for i, t in enumerate(res):
assert is_transform_msg_in_list(
t, msg.symmetries_discrete
), f"Discrete symmetry at index {i} did not match any initial ones in the message!"


def test_only_continuous_np() -> None:
n_symmetries = 2
axis = np.array([1.0, 0.0, -1.0])
axis = axis / np.linalg.norm(axis)
offset = np.array([-0.1, 0.0, 0.1])

t_list = [
pin.SE3(
pin.Quaternion(pin.AngleAxis(2.0 * np.pi / n_symmetries * i, axis)),
offset,
)
for i in range(n_symmetries)
]
msg = ObjectSymmetries(
symmetries_discrete=[],
symmetries_continuous=[
ContinuousSymmetry(
axis=Vector3(**dict(zip("xyz", axis))),
offset=Vector3(**dict(zip("xyz", offset))),
)
],
)

res = discretize_symmetries(msg, n_symmetries_continuous=n_symmetries)
assert res.shape == (n_symmetries, 4, 4), "Result shape is incorrect!"
for i, t in enumerate(res):
assert is_transform_in_se3_list(
t, t_list
), f"Discrete symmetry at index {i} did not match any symmetry from generated list!"


def test_only_continuous_ros() -> None:
n_symmetries = 9
axis = np.array([1.0, 0.2, 1.0])
axis = axis / np.linalg.norm(axis)
offset = np.array([0.0, 0.0, 0.1])

t_list = [
pin.SE3(
pin.Quaternion(pin.AngleAxis(2.0 * np.pi / n_symmetries * i, axis)),
offset,
)
for i in range(n_symmetries)
]
msg = ObjectSymmetries(
symmetries_discrete=[],
symmetries_continuous=[
ContinuousSymmetry(
axis=Vector3(**dict(zip("xyz", axis))),
offset=Vector3(**dict(zip("xyz", offset))),
)
],
)

res = discretize_symmetries(
msg, n_symmetries_continuous=n_symmetries, return_ros_msg=True
)

assert len(res) == len(
t_list
), "Results list does not have all discrete symmetries from message received!"

t_msgs = [pin_to_msg(t) for t in t_list]

for i, t in enumerate(res):
assert is_transform_msg_in_list(
t, t_msgs
), f"Discrete symmetry at index {i} did not match any symmetry from generated list!"


def test_mixed_np() -> None:
n_symmetries = 32
axis = np.array([-0.1, 0.5, 0.5])
axis = axis / np.linalg.norm(axis)
offset = np.array([-0.1, 0.6, 0.1])

t_c_list = [
pin.SE3(
pin.Quaternion(pin.AngleAxis(2.0 * np.pi / n_symmetries * i, axis)),
offset,
)
for i in range(n_symmetries)
]

t_d_list = [
pin.SE3(
pin.Quaternion(np.array([0.0, 0.0, 0.0, 1.0])), np.array([0.0, 0.0, 0.0])
),
pin.SE3(
pin.Quaternion(np.array([0.707, 0.0, 0.0, 0.707])),
np.array([0.1, 0.1, 0.1]),
),
]
msg = ObjectSymmetries(
symmetries_discrete=[*[pin_to_msg(t) for t in t_d_list]],
symmetries_continuous=[
ContinuousSymmetry(
axis=Vector3(**dict(zip("xyz", axis))),
offset=Vector3(**dict(zip("xyz", offset))),
)
],
)

res = discretize_symmetries(msg, n_symmetries_continuous=n_symmetries)

assert res.shape == (
len(t_c_list) + len(t_d_list) + len(t_d_list) * len(t_c_list),
4,
4,
), "Result shape is incorrect!"

t_test = [
*t_c_list,
*t_d_list,
*[t_d * t_c for t_c in t_c_list for t_d in t_d_list],
]

print(res, flush=True)

for i, t in enumerate(res):
assert is_transform_in_se3_list(
t, t_test
), f"Discrete symmetry at index {i} did not match any symmetry from generated list!"


def test_mixed_ros() -> None:
n_symmetries = 31
axis = np.array([-0.9, 0.2, -0.5])
axis = axis / np.linalg.norm(axis)
offset = np.array([-1.1, 0.6, 0.1])

t_c_list = [
pin.SE3(
pin.Quaternion(pin.AngleAxis(2.0 * np.pi / n_symmetries * i, axis)),
offset,
)
for i in range(n_symmetries)
]

t_d_list = [
pin.SE3(
pin.Quaternion(np.array([0.0, 0.0, 0.0, 1.0])), np.array([0.0, 0.0, 0.0])
),
pin.SE3(
pin.Quaternion(np.array([0.707, 0.0, 0.0, 0.707])),
np.array([0.1, 0.1, 0.1]),
),
]
msg = ObjectSymmetries(
symmetries_discrete=[*[pin_to_msg(t) for t in t_d_list]],
symmetries_continuous=[
ContinuousSymmetry(
axis=Vector3(**dict(zip("xyz", axis))),
offset=Vector3(**dict(zip("xyz", offset))),
)
],
)

res = discretize_symmetries(
msg, n_symmetries_continuous=n_symmetries, return_ros_msg=True
)

assert len(res) == (
len(t_c_list) + len(t_d_list) + len(t_d_list) * len(t_c_list)
), "Size od the results is incorrect!"

t_test = [
*[pin_to_msg(t) for t in t_c_list],
*[pin_to_msg(t) for t in t_d_list],
*[pin_to_msg(t_d * t_c) for t_c in t_c_list for t_d in t_d_list],
]

for i, t in enumerate(res):
assert is_transform_msg_in_list(
t, t_test
), f"Discrete symmetry at index {i} did not match any symmetry from generated list!"

0 comments on commit ee3f145

Please sign in to comment.