From ad7ecce0212f26133f3babaccaff46153e6f223b Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 14:05:47 -0700 Subject: [PATCH 1/8] Start transforms cleanup --- src/viser/transforms/_base.py | 21 +++++++++++++++++---- src/viser/transforms/_se2.py | 18 ++++++++++-------- src/viser/transforms/_se3.py | 22 +++++++++++++--------- src/viser/transforms/_so2.py | 21 ++++++++++++--------- src/viser/transforms/_so3.py | 21 +++++++++++++-------- src/viser/transforms/utils/__init__.py | 5 ++--- src/viser/transforms/utils/_utils.py | 25 +------------------------ 7 files changed, 68 insertions(+), 65 deletions(-) diff --git a/src/viser/transforms/_base.py b/src/viser/transforms/_base.py index f3660a095..bf20e0e1b 100644 --- a/src/viser/transforms/_base.py +++ b/src/viser/transforms/_base.py @@ -9,9 +9,6 @@ class MatrixLieGroup(abc.ABC): """Interface definition for matrix Lie groups.""" - # Class properties. - # > These will be set in `_utils.register_lie_group()`. - matrix_dim: ClassVar[int] """Dimension of square matrix output from `.as_matrix()`.""" @@ -36,6 +33,19 @@ def __init__( """Construct a group object from its underlying parameters.""" raise NotImplementedError() + def __init_subclass__( + cls, + matrix_dim: int = 0, + parameters_dim: int = 0, + tangent_dim: int = 0, + space_dim: int = 0, + ) -> None: + """Set class properties for subclasses. We default to dummy values.""" + cls.matrix_dim = matrix_dim + cls.parameters_dim = parameters_dim + cls.tangent_dim = tangent_dim + cls.space_dim = space_dim + # Shared implementations. @overload @@ -66,11 +76,14 @@ def __matmul__( @classmethod @abc.abstractmethod - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> Self: """Returns identity element. Args: batch_axes: Any leading batch axes for the output transform. + dtype: Datatype for the output. Returns: Identity element. diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index 8952684f9..2a1e83694 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -7,17 +7,17 @@ from . import _base, hints from ._so2 import SO2 -from .utils import broadcast_leading_axes, get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SE2( + _base.SEBase[SO2], matrix_dim=3, parameters_dim=4, tangent_dim=3, - space_dim=2, -) -@dataclasses.dataclass(frozen=True) -class SE2(_base.SEBase[SO2]): + space_dim=3, +): """Special Euclidean group for proper rigid transforms in 2D. Broadcasting rules are the same as for numpy. @@ -77,10 +77,12 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2": + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> "SE2": return SE2( unit_complex_xy=onp.broadcast_to( - onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4) + onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4) ) ) diff --git a/src/viser/transforms/_se3.py b/src/viser/transforms/_se3.py index 5ce5a3ea3..c15a51dde 100644 --- a/src/viser/transforms/_se3.py +++ b/src/viser/transforms/_se3.py @@ -9,7 +9,7 @@ from . import _base from ._so3 import SO3 -from .utils import broadcast_leading_axes, get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: @@ -23,14 +23,14 @@ def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: ).reshape((*omega.shape[:-1], 3, 3)) -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SE3( + _base.SEBase[SO3], matrix_dim=4, parameters_dim=7, tangent_dim=6, space_dim=3, -) -@dataclasses.dataclass(frozen=True) -class SE3(_base.SEBase[SO3]): +): """Special Euclidean group for proper rigid transforms in 3D. Broadcasting rules are the same as for numpy. @@ -76,10 +76,13 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SE3: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> SE3: return SE3( wxyz_xyz=onp.broadcast_to( - onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7) + onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=dtype), + (*batch_axes, 7), ) ) @@ -97,7 +100,7 @@ def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE3: @override def as_matrix(self) -> onpt.NDArray[onp.floating]: - out = onp.zeros((*self.get_batch_axes(), 4, 4)) + out = onp.zeros((*self.get_batch_axes(), 4, 4), dtype=self.wxyz_xyz.dtype) out[..., :3, :3] = self.rotation().as_matrix() out[..., :3, 3] = self.translation() out[..., 3, 3] = 1.0 @@ -212,7 +215,8 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: axis=-1, ), onp.concatenate( - [onp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1 + [onp.zeros((*self.get_batch_axes(), 3, 3), dtype=R.dtype), R], + axis=-1, ), ], axis=-2, diff --git a/src/viser/transforms/_so2.py b/src/viser/transforms/_so2.py index db4421245..28322b758 100644 --- a/src/viser/transforms/_so2.py +++ b/src/viser/transforms/_so2.py @@ -8,17 +8,17 @@ from typing_extensions import override from . import _base, hints -from .utils import broadcast_leading_axes, register_lie_group +from .utils import broadcast_leading_axes -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SO2( + _base.SOBase, matrix_dim=2, parameters_dim=2, tangent_dim=1, space_dim=2, -) -@dataclasses.dataclass(frozen=True) -class SO2(_base.SOBase): +): """Special orthogonal group for 2D rotations. Broadcasting rules are the same as for numpy. @@ -53,10 +53,13 @@ def as_radians(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> SO2: return SO2( unit_complex=onp.stack( - [onp.ones(batch_axes), onp.zeros(batch_axes)], axis=-1 + [onp.ones(batch_axes, dtype=dtype), onp.zeros(batch_axes, dtype=dtype)], + axis=-1, ) ) @@ -64,7 +67,7 @@ def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: @override def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO2: assert matrix.shape[-2:] == (2, 2) - return SO2(unit_complex=onp.asarray(matrix[..., :, 0])) + return SO2(unit_complex=onp.array(matrix[..., :, 0])) # Accessors. @@ -119,7 +122,7 @@ def log(self) -> onpt.NDArray[onp.floating]: @override def adjoint(self) -> onpt.NDArray[onp.floating]: - return onp.ones((*self.get_batch_axes(), 1, 1)) + return onp.ones((*self.get_batch_axes(), 1, 1), dtype=self.unit_complex.dtype) @override def inverse(self) -> SO2: diff --git a/src/viser/transforms/_so3.py b/src/viser/transforms/_so3.py index c1436604f..1a74ebbe7 100644 --- a/src/viser/transforms/_so3.py +++ b/src/viser/transforms/_so3.py @@ -8,7 +8,7 @@ from typing_extensions import override from . import _base, hints -from .utils import broadcast_leading_axes, get_epsilon, register_lie_group +from .utils import broadcast_leading_axes, get_epsilon @dataclasses.dataclass(frozen=True) @@ -20,14 +20,14 @@ class RollPitchYaw: yaw: onpt.NDArray[onp.floating] -@register_lie_group( +@dataclasses.dataclass(frozen=True) +class SO3( + _base.SOBase, matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=3, -) -@dataclasses.dataclass(frozen=True) -class SO3(_base.SOBase): +): """Special orthogonal group for 3D rotations. Broadcasting rules are the same as for numpy. @@ -173,9 +173,13 @@ def compute_yaw_radians(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO3: + def identity( + cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + ) -> SO3: return SO3( - wxyz=onp.broadcast_to(onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)) + wxyz=onp.broadcast_to( + onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4) + ) ) @classmethod @@ -316,7 +320,8 @@ def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating # Compute using quaternion multiplys. padded_target = onp.concatenate( - [onp.zeros((*self.get_batch_axes(), 1)), target], axis=-1 + [onp.zeros((*self.get_batch_axes(), 1), dtype=target.dtype), target], + axis=-1, ) return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[..., 1:] diff --git a/src/viser/transforms/utils/__init__.py b/src/viser/transforms/utils/__init__.py index 11980074b..3ecb41d21 100644 --- a/src/viser/transforms/utils/__init__.py +++ b/src/viser/transforms/utils/__init__.py @@ -1,3 +1,2 @@ -from ._utils import broadcast_leading_axes, get_epsilon, register_lie_group - -__all__ = ["get_epsilon", "register_lie_group", "broadcast_leading_axes"] +from ._utils import broadcast_leading_axes as broadcast_leading_axes +from ._utils import get_epsilon as get_epsilon diff --git a/src/viser/transforms/utils/_utils.py b/src/viser/transforms/utils/_utils.py index a25e0ac10..8ec773729 100644 --- a/src/viser/transforms/utils/_utils.py +++ b/src/viser/transforms/utils/_utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Tuple, TypeVar, Union, cast import numpy as onp @@ -26,29 +26,6 @@ def get_epsilon(dtype: onp.dtype) -> float: assert False -def register_lie_group( - *, - matrix_dim: int, - parameters_dim: int, - tangent_dim: int, - space_dim: int, -) -> Callable[[Type[T]], Type[T]]: - """Decorator for registering Lie group dataclasses. - - Sets dimensionality class variables, and marks all methods for JIT compilation. - """ - - def _wrap(cls: Type[T]) -> Type[T]: - # Register dimensions as class attributes. - cls.matrix_dim = matrix_dim - cls.parameters_dim = parameters_dim - cls.tangent_dim = tangent_dim - cls.space_dim = space_dim - return cls - - return _wrap - - TupleOfBroadcastable = TypeVar( "TupleOfBroadcastable", bound="Tuple[Union[MatrixLieGroup, onp.ndarray], ...]", From be294be414ad4904632204ac63d16fc461a7ca73 Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 14:16:08 -0700 Subject: [PATCH 2/8] Add back `sample_uniform()` --- src/viser/transforms/_base.py | 33 ++++++++++-------- src/viser/transforms/_se2.py | 44 +++++++++++------------- src/viser/transforms/_se3.py | 26 +++++++------- src/viser/transforms/_so2.py | 24 +++++++------ src/viser/transforms/_so3.py | 64 ++++++++++++++++++----------------- 5 files changed, 99 insertions(+), 92 deletions(-) diff --git a/src/viser/transforms/_base.py b/src/viser/transforms/_base.py index bf20e0e1b..523a926cb 100644 --- a/src/viser/transforms/_base.py +++ b/src/viser/transforms/_base.py @@ -185,20 +185,25 @@ def normalize(self) -> Self: Normalized group member. """ - # @classmethod - # @abc.abstractmethod - # def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self: - # """Draw a uniform sample from the group. Translations (if applicable) are in the - # range [-1, 1]. - # - # Args: - # key: PRNG key, as returned by `jax.random.PRNGKey()`. - # batch_axes: Any leading batch axes for the output transforms. Each - # sampled transform will be different. - # - # Returns: - # Sampled group member. - # """ + @classmethod + @abc.abstractmethod + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> Self: + """Draw a uniform sample from the group. Translations (if applicable) are in the + range [-1, 1]. + + Args: + rng: numpy generator object. + batch_axes: Any leading batch axes for the output transforms. Each + sampled transform will be different. + + Returns: + Sampled group member. + """ @final def get_batch_axes(self) -> Tuple[int, ...]: diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index 2a1e83694..27dc60c26 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from typing import Tuple, cast @@ -39,7 +41,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})" @staticmethod - def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2": + def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> SE2: """Construct a transformation from standard 2D pose parameters. This is not the same as integrating over a length-3 twist. @@ -56,7 +58,7 @@ def from_rotation_and_translation( cls, rotation: SO2, translation: onpt.NDArray[onp.floating], - ) -> "SE2": + ) -> SE2: assert translation.shape[-1:] == (2,) rotation, translation = broadcast_leading_axes((rotation, translation)) return SE2( @@ -79,7 +81,7 @@ def translation(self) -> onpt.NDArray[onp.floating]: @override def identity( cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 - ) -> "SE2": + ) -> SE2: return SE2( unit_complex_xy=onp.broadcast_to( onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4) @@ -88,7 +90,7 @@ def identity( @classmethod @override - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> "SE2": + def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE2: assert matrix.shape[-2:] == (3, 3) or matrix.shape[-2:] == (2, 3) # Currently assumes bottom row is [0, 0, 1]. return SE2.from_rotation_and_translation( @@ -125,7 +127,7 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2": + def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 # Also see: @@ -229,7 +231,7 @@ def log(self) -> onpt.NDArray[onp.floating]: return tangent @override - def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]: + def adjoint(self: SE2) -> onpt.NDArray[onp.floating]: cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) return onp.stack( [ @@ -246,21 +248,15 @@ def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]: axis=-1, ).reshape((*self.get_batch_axes(), 3, 3)) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> "SE2": - # key0, key1 = jax.random.split(key) - # return SE2.from_rotation_and_translation( - # rotation=SO2.sample_uniform(key0, batch_axes=batch_axes), - # translation=jax.random.uniform( - # key=key1, - # shape=( - # *batch_axes, - # 2, - # ), - # minval=-1.0, - # maxval=1.0, - # ), - # ) + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SE2: + return SE2.from_rotation_and_translation( + SO2.sample_uniform(rng, batch_axes=batch_axes, dtype=type), + rng.uniform(low=-1.0, high=1.0, size=(*batch_axes, 2)).astype(dtype), + ) diff --git a/src/viser/transforms/_se3.py b/src/viser/transforms/_se3.py index c15a51dde..e426d5ca3 100644 --- a/src/viser/transforms/_se3.py +++ b/src/viser/transforms/_se3.py @@ -222,15 +222,17 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: axis=-2, ) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> SE3: - # key0, key1 = jax.random.split(key) - # return SE3.from_rotation_and_translation( - # rotation=SO3.sample_uniform(key0, batch_axes=batch_axes), - # translation=jax.random.uniform( - # key=key1, shape=(*batch_axes, 3), minval=-1.0, maxval=1.0 - # ), - # ) + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SE3: + return SE3.from_rotation_and_translation( + rotation=SO3.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype), + translation=rng.uniform(low=-1.0, high=1.0, size=(*batch_axes, 3)).astype( + dtype=dtype + ), + ) diff --git a/src/viser/transforms/_so2.py b/src/viser/transforms/_so2.py index 28322b758..e375b33d0 100644 --- a/src/viser/transforms/_so2.py +++ b/src/viser/transforms/_so2.py @@ -135,14 +135,16 @@ def normalize(self) -> SO2: / onp.linalg.norm(self.unit_complex, axis=-1, keepdims=True) ) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> SO2: - # out = SO2.from_radians( - # jax.random.uniform( - # key=key, shape=batch_axes, minval=0.0, maxval=2.0 * onp.pi) - # ) - # assert out.get_batch_axes() == batch_axes - # return out + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SO2: + out = SO2.from_radians( + rng.uniform(0.0, 2.0 * onp.pi, size=batch_axes).astype(dtype=dtype) + ) + assert out.get_batch_axes() == batch_axes + return out diff --git a/src/viser/transforms/_so3.py b/src/viser/transforms/_so3.py index 1a74ebbe7..e403a924a 100644 --- a/src/viser/transforms/_so3.py +++ b/src/viser/transforms/_so3.py @@ -434,34 +434,36 @@ def inverse(self) -> SO3: def normalize(self) -> SO3: return SO3(wxyz=self.wxyz / onp.linalg.norm(self.wxyz, axis=-1, keepdims=True)) - # @classmethod - # @override - # def sample_uniform( - # cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = () - # ) -> SO3: - # # Uniformly sample over S^3. - # # > Reference: http://planning.cs.uiuc.edu/node198.html - # u1, u2, u3 = onp.moveaxis( - # jax.random.uniform( - # key=key, - # shape=(*batch_axes, 3), - # minval=onp.zeros(3), - # maxval=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), - # ), - # -1, - # 0, - # ) - # a = onp.sqrt(1.0 - u1) - # b = onp.sqrt(u1) - # - # return SO3( - # wxyz=onp.stack( - # [ - # a * onp.sin(u2), - # a * onp.cos(u2), - # b * onp.sin(u3), - # b * onp.cos(u3), - # ], - # axis=-1, - # ) - # ) + @classmethod + @override + def sample_uniform( + cls, + rng: onp.random.Generator, + batch_axes: Tuple[int, ...] = (), + dtype: onpt.DTypeLike = onp.float64, + ) -> SO3: + # Uniformly sample over S^3. + # > Reference: http://planning.cs.uiuc.edu/node198.html + u1, u2, u3 = onp.moveaxis( + rng.uniform( + low=onp.zeros(3), + high=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), + size=(*batch_axes, 3), + ).astype(dtype=dtype), + -1, + 0, + ) + a = onp.sqrt(1.0 - u1) + b = onp.sqrt(u1) + + return SO3( + wxyz=onp.stack( + [ + a * onp.sin(u2), + a * onp.cos(u2), + b * onp.sin(u3), + b * onp.cos(u3), + ], + axis=-1, + ) + ) From eacc8cc234213a1b3ea30de5485a14a4e53ba6a7 Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 15:47:30 -0700 Subject: [PATCH 3/8] Add basic transform tests --- src/viser/transforms/_se2.py | 4 +- tests/test_transforms_bijective.py | 149 +++++++++++++++++++++++++++++ tests/utils.py | 111 +++++++++++++++++++++ 3 files changed, 262 insertions(+), 2 deletions(-) create mode 100644 tests/test_transforms_bijective.py create mode 100644 tests/utils.py diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index 27dc60c26..8b9aa3739 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -18,7 +18,7 @@ class SE2( matrix_dim=3, parameters_dim=4, tangent_dim=3, - space_dim=3, + space_dim=2, ): """Special Euclidean group for proper rigid transforms in 2D. Broadcasting rules are the same as for numpy. @@ -257,6 +257,6 @@ def sample_uniform( dtype: onpt.DTypeLike = onp.float64, ) -> SE2: return SE2.from_rotation_and_translation( - SO2.sample_uniform(rng, batch_axes=batch_axes, dtype=type), + SO2.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype), rng.uniform(low=-1.0, high=1.0, size=(*batch_axes, 2)).astype(dtype), ) diff --git a/tests/test_transforms_bijective.py b/tests/test_transforms_bijective.py new file mode 100644 index 000000000..5da720192 --- /dev/null +++ b/tests/test_transforms_bijective.py @@ -0,0 +1,149 @@ +"""Tests for general operation definitions.""" + +from typing import Tuple, Type + +import numpy as onp +import viser.transforms as vtf +from hypothesis import given, settings +from hypothesis import strategies as st + +from utils import ( + assert_arrays_close, + assert_transforms_close, + general_group_test, + sample_transform, +) + + +@general_group_test +def test_sample_uniform_valid( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...] +): + """Check that sample_uniform() returns valid group members.""" + T = sample_transform(Group, batch_axes) # Calls sample_uniform under the hood. + assert_transforms_close(T, T.normalize()) + + +@settings(deadline=None) +@given(_random_module=st.random_module()) +def test_so2_from_to_radians_bijective(_random_module): + """Check that we can convert from and to radians.""" + radians = onp.random.uniform(low=-onp.pi, high=onp.pi) + assert_arrays_close(vtf.SO2.from_radians(radians).as_radians(), radians) + + +@settings(deadline=None) +@given(_random_module=st.random_module()) +def test_so3_xyzw_bijective(_random_module): + """Check that we can convert between xyzw and wxyz quaternions.""" + T = sample_transform(vtf.SO3) + assert_transforms_close(T, vtf.SO3.from_quaternion_xyzw(T.as_quaternion_xyzw())) + + +@settings(deadline=None) +@given(_random_module=st.random_module()) +def test_so3_rpy_bijective(_random_module): + """Check that we can convert between quaternions and Euler angles.""" + T = sample_transform(vtf.SO3) + rpy = T.as_rpy_radians() + assert_transforms_close(T, vtf.SO3.from_rpy_radians(rpy.roll, rpy.pitch, rpy.yaw)) + + +@general_group_test +def test_log_exp_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...] +): + """Check 1-to-1 mapping for log <=> exp operations.""" + transform = sample_transform(Group, batch_axes) + + tangent = transform.log() + assert tangent.shape == (*batch_axes, Group.tangent_dim) + + exp_transform = Group.exp(tangent) + assert_transforms_close(transform, exp_transform) + assert_arrays_close(tangent, exp_transform.log()) + + +@general_group_test +def test_inverse_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...] +): + """Check inverse of inverse.""" + transform = sample_transform(Group, batch_axes) + assert_transforms_close(transform, transform.inverse().inverse()) + + +@general_group_test +def test_matrix_bijective(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): + """Check that we can convert to and from matrices.""" + transform = sample_transform(Group, batch_axes) + assert_transforms_close(transform, Group.from_matrix(transform.as_matrix())) + + +@general_group_test +def test_adjoint(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): + """Check adjoint definition.""" + transform = sample_transform(Group, batch_axes) + omega = onp.random.randn(*batch_axes, Group.tangent_dim) + assert_transforms_close( + transform @ Group.exp(omega), + Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) + @ transform, + ) + + +@general_group_test +def test_repr(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): + """Smoke test for __repr__ implementations.""" + transform = sample_transform(Group, batch_axes) + print(transform) + + +@general_group_test +def test_apply(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): + """Check group action interfaces.""" + T_w_b = sample_transform(Group, batch_axes) + p_b = onp.random.randn(*batch_axes, Group.space_dim) + + if Group.matrix_dim == Group.space_dim: + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), + ) + else: + # Homogeneous coordinates. + assert Group.matrix_dim == Group.space_dim + 1 + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum( + "...ij,...j->...i", + T_w_b.as_matrix(), + onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), + )[..., :-1], + ) + + +@general_group_test +def test_multiply(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): + """Check multiply interfaces.""" + T_w_b = sample_transform(Group, batch_axes) + T_b_a = sample_transform(Group, batch_axes) + assert_arrays_close( + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), T_w_b.inverse().as_matrix() + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + ), + ) + assert_arrays_close( + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + ), + ) + assert_transforms_close(T_w_b @ T_b_a, Group.multiply(T_w_b, T_b_a)) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..ce8a93a60 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,111 @@ +import functools +import random +from typing import Any, Callable, Tuple, Type, TypeVar, Union, cast + +import numpy as onp +import numpy.typing as onpt +import pytest +import viser.transforms as vtf +from hypothesis import given, settings +from hypothesis import strategies as st + +T = TypeVar("T", bound=vtf.MatrixLieGroup) + + +def sample_transform(Group: Type[T], batch_axes: Tuple[int, ...] = ()) -> T: + """Sample a random transform from a group.""" + seed = random.getrandbits(32) + strategy = random.randint(0, 2) + + if strategy == 0: + # Uniform sampling. + return cast( + T, + Group.sample_uniform(onp.random.default_rng(seed), batch_axes=batch_axes), + ) + elif strategy == 1: + # Sample from normally-sampled tangent vector. + return cast(T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim))) + elif strategy == 2: + # Sample near identity. + return cast( + T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim) * 1e-7) + ) + else: + assert False + + +def general_group_test( + f: Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...]], None], + max_examples: int = 30, +) -> Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], Any], None]: + """Decorator for defining tests that run on all group types.""" + + # Disregard unused argument. + def f_wrapped( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], _random_module + ) -> None: + f(Group, batch_axes) + + # Disable timing check (first run requires JIT tracing and will be slower). + f_wrapped = settings(deadline=None, max_examples=max_examples)(f_wrapped) + + # Add _random_module parameter. + f_wrapped = given(_random_module=st.random_module())(f_wrapped) + + # Parametrize tests with each group type. + f_wrapped = pytest.mark.parametrize( + "Group", + [ + vtf.SO2, + vtf.SE2, + vtf.SO3, + vtf.SE3, + ], + )(f_wrapped) + + # Parametrize tests with each group type. + f_wrapped = pytest.mark.parametrize( + "batch_axes", + [ + (), + (1,), + (3, 1, 2, 1), + ], + )(f_wrapped) + return f_wrapped + + +general_group_test_faster = functools.partial(general_group_test, max_examples=5) + + +def assert_transforms_close(a: vtf.MatrixLieGroup, b: vtf.MatrixLieGroup): + """Make sure two transforms are equivalent.""" + # Check matrix representation. + assert_arrays_close(a.as_matrix(), b.as_matrix()) + + # Flip signs for quaternions. + # We use `jnp.asarray` here in case inputs are onp arrays and don't support `.at()`. + p1 = a.parameters().copy() + p2 = b.parameters().copy() + if isinstance(a, vtf.SO3): + p1 = p1 * onp.sign(onp.sum(p1, axis=-1, keepdims=True)) + p2 = p2 * onp.sign(onp.sum(p2, axis=-1, keepdims=True)) + elif isinstance(a, vtf.SE3): + p1[..., :4] *= onp.sign(onp.sum(p1[..., :4], axis=-1, keepdims=True)) + p2[..., :4] *= onp.sign(onp.sum(p2[..., :4], axis=-1, keepdims=True)) + + # Make sure parameters are equal. + assert_arrays_close(p1, p2) + + +def assert_arrays_close( + *arrays: Union[onpt.NDArray[onp.float64], float], + rtol: float = 1e-8, + atol: float = 1e-8, +): + """Make sure two arrays are close. (and not NaN)""" + for array1, array2 in zip(arrays[:-1], arrays[1:]): + onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol) + assert not onp.any(onp.isnan(array1)) + assert not onp.any(onp.isnan(array2)) From 12fdb4f3e98ae087cfca57d9ef17f7b0e39d8af3 Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 18:13:20 -0700 Subject: [PATCH 4/8] Ensure dtype consistency --- src/viser/transforms/_se2.py | 29 ++++++------ src/viser/transforms/_se3.py | 6 ++- src/viser/transforms/_so2.py | 6 ++- src/viser/transforms/_so3.py | 74 +++++++++++++----------------- tests/test_transforms_bijective.py | 63 +++++++++++++++---------- tests/utils.py | 48 ++++++++++++++----- 6 files changed, 128 insertions(+), 98 deletions(-) diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index 8b9aa3739..e68b4208e 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -150,21 +150,15 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2: ) theta_sq = theta**2 - sin_over_theta = cast( - onp.ndarray, - onp.where( - use_taylor, - 1.0 - theta_sq / 6.0, - onp.sin(safe_theta) / safe_theta, - ), + sin_over_theta = onp.where( + use_taylor, + 1.0 - theta_sq / 6.0, + onp.sin(safe_theta) / safe_theta, ) - one_minus_cos_over_theta = cast( - onp.ndarray, - onp.where( - use_taylor, - 0.5 * theta - theta * theta_sq / 24.0, - (1.0 - onp.cos(safe_theta)) / safe_theta, - ), + one_minus_cos_over_theta = onp.where( + use_taylor, + 0.5 * theta - theta * theta_sq / 24.0, + (1.0 - onp.cos(safe_theta)) / safe_theta, ) V = onp.stack( @@ -176,9 +170,12 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2: ], axis=-1, ).reshape((*tangent.shape[:-1], 2, 2)) + return SE2.from_rotation_and_translation( rotation=SO2.from_radians(theta), - translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]), + translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]).astype( + tangent.dtype + ), ) @override @@ -228,7 +225,7 @@ def log(self) -> onpt.NDArray[onp.floating]: ], axis=-1, ) - return tangent + return tangent.astype(self.unit_complex_xy.dtype) @override def adjoint(self: SE2) -> onpt.NDArray[onp.floating]: diff --git a/src/viser/transforms/_se3.py b/src/viser/transforms/_se3.py index e426d5ca3..2bc0b187b 100644 --- a/src/viser/transforms/_se3.py +++ b/src/viser/transforms/_se3.py @@ -157,7 +157,9 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE3: return SE3.from_rotation_and_translation( rotation=rotation, - translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]), + translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]).astype( + tangent.dtype + ), ) @override @@ -203,7 +205,7 @@ def log(self) -> onpt.NDArray[onp.floating]: ) return onp.concatenate( [onp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1 - ) + ).astype(self.wxyz_xyz.dtype) @override def adjoint(self) -> onpt.NDArray[onp.floating]: diff --git a/src/viser/transforms/_so2.py b/src/viser/transforms/_so2.py index e375b33d0..a6b9c5161 100644 --- a/src/viser/transforms/_so2.py +++ b/src/viser/transforms/_so2.py @@ -77,7 +77,7 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: out = onp.stack( [ # [cos, -sin], - cos_sin * onp.array([1, -1]), + cos_sin * onp.array([1, -1], dtype=cos_sin.dtype), # [sin, cos], cos_sin[..., ::-1], ], @@ -126,7 +126,9 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: @override def inverse(self) -> SO2: - return SO2(unit_complex=self.unit_complex * onp.array([1, -1])) + unit_complex = self.unit_complex.copy() + unit_complex[..., 1] *= -1 + return SO2(unit_complex) @override def normalize(self) -> SO2: diff --git a/src/viser/transforms/_so3.py b/src/viser/transforms/_so3.py index e403a924a..c8a7deb63 100644 --- a/src/viser/transforms/_so3.py +++ b/src/viser/transforms/_so3.py @@ -264,26 +264,7 @@ def case3(m): onp.where(cond1[..., None], case0_q, case1_q), onp.where(cond2[..., None], case2_q, case3_q), ) - - # We can also choose to branch, but this is slower. - # t, q = jax.lax.cond( - # matrix[2, 2] < 0, - # true_fun=lambda matrix: jax.lax.cond( - # matrix[0, 0] > matrix[1, 1], - # true_fun=case0, - # false_fun=case1, - # operand=matrix, - # ), - # false_fun=lambda matrix: jax.lax.cond( - # matrix[0, 0] < -matrix[1, 1], - # true_fun=case2, - # false_fun=case3, - # operand=matrix, - # ), - # operand=matrix, - # ) - - return SO3(wxyz=q * 0.5 / onp.sqrt(t[..., None])) + return SO3(wxyz=(q * 0.5 / onp.sqrt(t[..., None])).astype(matrix.dtype)) # Accessors. @@ -292,20 +273,24 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: norm_sq = onp.sum(onp.square(self.wxyz), axis=-1, keepdims=True) q = self.wxyz * onp.sqrt(2.0 / norm_sq) # (*, 4) q_outer = onp.einsum("...i,...j->...ij", q, q) # (*, 4, 4) - return onp.stack( - [ - 1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3], - q_outer[..., 1, 2] - q_outer[..., 3, 0], - q_outer[..., 1, 3] + q_outer[..., 2, 0], - q_outer[..., 1, 2] + q_outer[..., 3, 0], - 1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3], - q_outer[..., 2, 3] - q_outer[..., 1, 0], - q_outer[..., 1, 3] - q_outer[..., 2, 0], - q_outer[..., 2, 3] + q_outer[..., 1, 0], - 1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2], - ], - axis=-1, - ).reshape(*q.shape[:-1], 3, 3) + return ( + onp.stack( + [ + 1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3], + q_outer[..., 1, 2] - q_outer[..., 3, 0], + q_outer[..., 1, 3] + q_outer[..., 2, 0], + q_outer[..., 1, 2] + q_outer[..., 3, 0], + 1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3], + q_outer[..., 2, 3] - q_outer[..., 1, 0], + q_outer[..., 1, 3] - q_outer[..., 2, 0], + q_outer[..., 2, 3] + q_outer[..., 1, 0], + 1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2], + ], + axis=-1, + ) + .reshape(*q.shape[:-1], 3, 3) + .astype(self.wxyz.dtype) + ) @override def parameters(self) -> onpt.NDArray[onp.floating]: @@ -362,14 +347,16 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3: theta_squared, ) ) - safe_half_theta = 0.5 * safe_theta + # Fun fact: when safe_theta is a `float32` _scalar_, this + # multiplication will promote `safe_half_theta` to `float64`. We'll + # cast at the end to make sure our input/output dtypes match. + safe_half_theta = 0.5 * safe_theta real_factor = onp.where( use_taylor, 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, onp.cos(safe_half_theta), ) - imaginary_factor = onp.where( use_taylor, 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, @@ -383,7 +370,7 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO3: imaginary_factor[..., None] * tangent, ], axis=-1, - ) + ).astype(tangent.dtype) ) @override @@ -414,12 +401,11 @@ def log(self) -> onpt.NDArray[onp.floating]: 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3, onp.where( onp.abs(w) < get_epsilon(w.dtype), - onp.where(w > 0, 1.0, -1.0) * onp.pi / norm_safe, + onp.where(w > 0, 1.0, -1.0).astype(dtype=w.dtype) * onp.pi / norm_safe, 2.0 * atan_n_over_w / norm_safe, ), ) - - return atan_factor[..., None] * self.wxyz[..., 1:] # type: ignore + return (atan_factor[..., None] * self.wxyz[..., 1:]).astype(self.wxyz.dtype) @override def adjoint(self) -> onpt.NDArray[onp.floating]: @@ -428,7 +414,9 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: @override def inverse(self) -> SO3: # Negate complex terms. - return SO3(wxyz=self.wxyz * onp.array([1, -1, -1, -1])) + wxyz = self.wxyz.copy() + wxyz[..., 1:] *= -1 + return SO3(wxyz) @override def normalize(self) -> SO3: @@ -449,7 +437,7 @@ def sample_uniform( low=onp.zeros(3), high=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), size=(*batch_axes, 3), - ).astype(dtype=dtype), + ), -1, 0, ) @@ -465,5 +453,5 @@ def sample_uniform( b * onp.cos(u3), ], axis=-1, - ) + ).astype(dtype=dtype) ) diff --git a/tests/test_transforms_bijective.py b/tests/test_transforms_bijective.py index 5da720192..5982f8431 100644 --- a/tests/test_transforms_bijective.py +++ b/tests/test_transforms_bijective.py @@ -3,6 +3,7 @@ from typing import Tuple, Type import numpy as onp +import numpy.typing as onpt import viser.transforms as vtf from hypothesis import given, settings from hypothesis import strategies as st @@ -17,10 +18,12 @@ @general_group_test def test_sample_uniform_valid( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...] + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike ): """Check that sample_uniform() returns valid group members.""" - T = sample_transform(Group, batch_axes) # Calls sample_uniform under the hood. + T = sample_transform( + Group, batch_axes, dtype + ) # Calls sample_uniform under the hood. assert_transforms_close(T, T.normalize()) @@ -36,7 +39,7 @@ def test_so2_from_to_radians_bijective(_random_module): @given(_random_module=st.random_module()) def test_so3_xyzw_bijective(_random_module): """Check that we can convert between xyzw and wxyz quaternions.""" - T = sample_transform(vtf.SO3) + T = sample_transform(vtf.SO3, (), dtype=onp.float64) assert_transforms_close(T, vtf.SO3.from_quaternion_xyzw(T.as_quaternion_xyzw())) @@ -44,19 +47,21 @@ def test_so3_xyzw_bijective(_random_module): @given(_random_module=st.random_module()) def test_so3_rpy_bijective(_random_module): """Check that we can convert between quaternions and Euler angles.""" - T = sample_transform(vtf.SO3) + T = sample_transform(vtf.SO3, (), dtype=onp.float64) rpy = T.as_rpy_radians() assert_transforms_close(T, vtf.SO3.from_rpy_radians(rpy.roll, rpy.pitch, rpy.yaw)) @general_group_test def test_log_exp_bijective( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...] + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike ): """Check 1-to-1 mapping for log <=> exp operations.""" - transform = sample_transform(Group, batch_axes) + transform = sample_transform(Group, batch_axes, dtype) + assert transform.parameters().dtype == dtype tangent = transform.log() + assert tangent.dtype == dtype assert tangent.shape == (*batch_axes, Group.tangent_dim) exp_transform = Group.exp(tangent) @@ -66,25 +71,29 @@ def test_log_exp_bijective( @general_group_test def test_inverse_bijective( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...] + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike ): """Check inverse of inverse.""" - transform = sample_transform(Group, batch_axes) + transform = sample_transform(Group, batch_axes, dtype) assert_transforms_close(transform, transform.inverse().inverse()) @general_group_test -def test_matrix_bijective(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): +def test_matrix_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): """Check that we can convert to and from matrices.""" - transform = sample_transform(Group, batch_axes) + transform = sample_transform(Group, batch_axes, dtype) assert_transforms_close(transform, Group.from_matrix(transform.as_matrix())) @general_group_test -def test_adjoint(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): +def test_adjoint( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): """Check adjoint definition.""" - transform = sample_transform(Group, batch_axes) - omega = onp.random.randn(*batch_axes, Group.tangent_dim) + transform = sample_transform(Group, batch_axes, dtype) + omega = onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype) assert_transforms_close( transform @ Group.exp(omega), Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) @@ -93,17 +102,21 @@ def test_adjoint(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): @general_group_test -def test_repr(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): +def test_repr( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): """Smoke test for __repr__ implementations.""" - transform = sample_transform(Group, batch_axes) + transform = sample_transform(Group, batch_axes, dtype) print(transform) @general_group_test -def test_apply(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): +def test_apply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): """Check group action interfaces.""" - T_w_b = sample_transform(Group, batch_axes) - p_b = onp.random.randn(*batch_axes, Group.space_dim) + T_w_b = sample_transform(Group, batch_axes, dtype) + p_b = onp.random.randn(*batch_axes, Group.space_dim).astype(dtype) if Group.matrix_dim == Group.space_dim: assert_arrays_close( @@ -126,16 +139,19 @@ def test_apply(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): @general_group_test -def test_multiply(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): +def test_multiply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): """Check multiply interfaces.""" - T_w_b = sample_transform(Group, batch_axes) - T_b_a = sample_transform(Group, batch_axes) + T_w_b = sample_transform(Group, batch_axes, dtype) + T_b_a = sample_transform(Group, batch_axes, dtype) assert_arrays_close( onp.einsum( "...ij,...jk->...ik", T_w_b.as_matrix(), T_w_b.inverse().as_matrix() ), onp.broadcast_to( - onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), ), ) assert_arrays_close( @@ -143,7 +159,8 @@ def test_multiply(Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...]): "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) ), onp.broadcast_to( - onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), ), ) assert_transforms_close(T_w_b @ T_b_a, Group.multiply(T_w_b, T_b_a)) diff --git a/tests/utils.py b/tests/utils.py index ce8a93a60..0711a838e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,14 +5,17 @@ import numpy as onp import numpy.typing as onpt import pytest -import viser.transforms as vtf from hypothesis import given, settings from hypothesis import strategies as st +import viser.transforms as vtf + T = TypeVar("T", bound=vtf.MatrixLieGroup) -def sample_transform(Group: Type[T], batch_axes: Tuple[int, ...] = ()) -> T: +def sample_transform( + Group: Type[T], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +) -> T: """Sample a random transform from a group.""" seed = random.getrandbits(32) strategy = random.randint(0, 2) @@ -21,31 +24,45 @@ def sample_transform(Group: Type[T], batch_axes: Tuple[int, ...] = ()) -> T: # Uniform sampling. return cast( T, - Group.sample_uniform(onp.random.default_rng(seed), batch_axes=batch_axes), + Group.sample_uniform( + onp.random.default_rng(seed), batch_axes=batch_axes, dtype=dtype + ), ) elif strategy == 1: # Sample from normally-sampled tangent vector. - return cast(T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim))) + return cast( + T, + Group.exp( + onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + ), + ) elif strategy == 2: # Sample near identity. return cast( - T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim) * 1e-7) + T, + Group.exp( + onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + * 1e-7 + ), ) else: assert False def general_group_test( - f: Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...]], None], - max_examples: int = 30, -) -> Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], Any], None]: + f: Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike], None], + max_examples: int = 3, +) -> Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike, Any], None]: """Decorator for defining tests that run on all group types.""" # Disregard unused argument. def f_wrapped( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], _random_module + Group: Type[vtf.MatrixLieGroup], + batch_axes: Tuple[int, ...], + dtype: onpt.DTypeLike, + _random_module, ) -> None: - f(Group, batch_axes) + f(Group, batch_axes, dtype) # Disable timing check (first run requires JIT tracing and will be slower). f_wrapped = settings(deadline=None, max_examples=max_examples)(f_wrapped) @@ -73,6 +90,12 @@ def f_wrapped( (3, 1, 2, 1), ], )(f_wrapped) + + # Parametrize tests with each group type. + f_wrapped = pytest.mark.parametrize( + "dtype", + [onp.float32, onp.float64], + )(f_wrapped) return f_wrapped @@ -101,11 +124,12 @@ def assert_transforms_close(a: vtf.MatrixLieGroup, b: vtf.MatrixLieGroup): def assert_arrays_close( *arrays: Union[onpt.NDArray[onp.float64], float], - rtol: float = 1e-8, - atol: float = 1e-8, + rtol: float = 1e-3, + atol: float = 1e-4, ): """Make sure two arrays are close. (and not NaN)""" for array1, array2 in zip(arrays[:-1], arrays[1:]): + assert onp.asarray(array1).dtype == onp.asarray(array2).dtype onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol) assert not onp.any(onp.isnan(array1)) assert not onp.any(onp.isnan(array2)) From 77dbd948ed8d84a34ba2dc94c34db2fb32bd0d98 Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 18:25:32 -0700 Subject: [PATCH 5/8] More tests, GH action --- .github/workflows/pytest.yml | 28 +++++++ tests/test_transforms_axioms.py | 103 +++++++++++++++++++++++ tests/test_transforms_bijective.py | 4 +- tests/test_transforms_ops.py | 129 +++++++++++++++++++++++++++++ tests/utils.py | 2 +- 5 files changed, 263 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/pytest.yml create mode 100644 tests/test_transforms_axioms.py create mode 100644 tests/test_transforms_ops.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 000000000..d34aafd47 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,28 @@ +name: pytest + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install uv + uv pip install --system -e ".[dev,examples]" + - name: Test with pytest + run: | + pytest diff --git a/tests/test_transforms_axioms.py b/tests/test_transforms_axioms.py new file mode 100644 index 000000000..8ec6d3398 --- /dev/null +++ b/tests/test_transforms_axioms.py @@ -0,0 +1,103 @@ +"""Tests for group axioms. + +https://proofwiki.org/wiki/Definition:Group_Axioms +""" + +from typing import Tuple, Type + +import numpy as onp +import numpy.typing as onpt +from utils import ( + assert_arrays_close, + assert_transforms_close, + general_group_test, + sample_transform, +) + +import viser.transforms as vtf + + +@general_group_test +def test_closure( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check closure property.""" + transform_a = sample_transform(Group, batch_axes, dtype) + transform_b = sample_transform(Group, batch_axes, dtype) + + composed = transform_a @ transform_b + assert_transforms_close(composed, composed.normalize()) + composed = transform_b @ transform_a + assert_transforms_close(composed, composed.normalize()) + composed = Group.multiply(transform_a, transform_b) + assert_transforms_close(composed, composed.normalize()) + composed = Group.multiply(transform_b, transform_a) + assert_transforms_close(composed, composed.normalize()) + + +@general_group_test +def test_identity( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check identity property.""" + transform = sample_transform(Group, batch_axes, dtype) + identity = Group.identity(batch_axes, dtype=dtype) + assert_transforms_close(transform, identity @ transform) + assert_transforms_close(transform, transform @ identity) + assert_arrays_close( + transform.as_matrix(), + onp.einsum("...ij,...jk->...ik", identity.as_matrix(), transform.as_matrix()), + ) + assert_arrays_close( + transform.as_matrix(), + onp.einsum("...ij,...jk->...ik", transform.as_matrix(), identity.as_matrix()), + ) + + +@general_group_test +def test_inverse( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check inverse property.""" + transform = sample_transform(Group, batch_axes, dtype) + identity = Group.identity(batch_axes, dtype=dtype) + assert_transforms_close(identity, transform @ transform.inverse()) + assert_transforms_close(identity, transform.inverse() @ transform) + assert_transforms_close(identity, Group.multiply(transform, transform.inverse())) + assert_transforms_close(identity, Group.multiply(transform.inverse(), transform)) + assert_arrays_close( + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + onp.einsum( + "...ij,...jk->...ik", + transform.as_matrix(), + transform.inverse().as_matrix(), + ), + ) + assert_arrays_close( + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + onp.einsum( + "...ij,...jk->...ik", + transform.inverse().as_matrix(), + transform.as_matrix(), + ), + ) + + +@general_group_test +def test_associative( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check associative property.""" + transform_a = sample_transform(Group, batch_axes, dtype) + transform_b = sample_transform(Group, batch_axes, dtype) + transform_c = sample_transform(Group, batch_axes, dtype) + assert_transforms_close( + (transform_a @ transform_b) @ transform_c, + transform_a @ (transform_b @ transform_c), + ) diff --git a/tests/test_transforms_bijective.py b/tests/test_transforms_bijective.py index 5982f8431..85173a3d7 100644 --- a/tests/test_transforms_bijective.py +++ b/tests/test_transforms_bijective.py @@ -4,10 +4,8 @@ import numpy as onp import numpy.typing as onpt -import viser.transforms as vtf from hypothesis import given, settings from hypothesis import strategies as st - from utils import ( assert_arrays_close, assert_transforms_close, @@ -15,6 +13,8 @@ sample_transform, ) +import viser.transforms as vtf + @general_group_test def test_sample_uniform_valid( diff --git a/tests/test_transforms_ops.py b/tests/test_transforms_ops.py new file mode 100644 index 000000000..583db9c0b --- /dev/null +++ b/tests/test_transforms_ops.py @@ -0,0 +1,129 @@ +"""Tests for general operation definitions.""" + +from typing import Tuple, Type + +import numpy as onp +import numpy.typing as onpt +from jax import numpy as jnp +from utils import ( + assert_arrays_close, + assert_transforms_close, + general_group_test, + sample_transform, +) + +import viser.transforms as vtf + + +@general_group_test +def test_sample_uniform_valid( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check that sample_uniform() returns valid group members.""" + T = sample_transform( + Group, batch_axes, dtype + ) # Calls sample_uniform under the hood. + assert_transforms_close(T, T.normalize()) + + +@general_group_test +def test_log_exp_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check 1-to-1 mapping for log <=> exp operations.""" + transform = sample_transform(Group, batch_axes, dtype) + + tangent = transform.log() + assert tangent.shape == (*batch_axes, Group.tangent_dim) + + exp_transform = Group.exp(tangent) + assert_transforms_close(transform, exp_transform) + assert_arrays_close(tangent, exp_transform.log()) + + +@general_group_test +def test_inverse_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check inverse of inverse.""" + transform = sample_transform(Group, batch_axes, dtype) + assert_transforms_close(transform, transform.inverse().inverse()) + + +@general_group_test +def test_matrix_bijective( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check that we can convert to and from matrices.""" + transform = sample_transform(Group, batch_axes, dtype) + assert_transforms_close(transform, Group.from_matrix(transform.as_matrix())) + + +@general_group_test +def test_adjoint( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check adjoint definition.""" + transform = sample_transform(Group, batch_axes, dtype) + omega = onp.random.randn(*batch_axes, Group.tangent_dim) + assert_transforms_close( + transform @ Group.exp(omega), + Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) + @ transform, + ) + + +@general_group_test +def test_repr( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Smoke test for __repr__ implementations.""" + transform = sample_transform(Group, batch_axes, dtype) + print(transform) + + +@general_group_test +def test_apply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check group action interfaces.""" + T_w_b = sample_transform(Group, batch_axes, dtype) + p_b = onp.random.randn(*batch_axes, Group.space_dim).astype(dtype) + + if Group.matrix_dim == Group.space_dim: + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), + ) + else: + # Homogeneous coordinates. + assert Group.matrix_dim == Group.space_dim + 1 + assert_arrays_close( + T_w_b @ p_b, + T_w_b.apply(p_b), + onp.einsum( + "...ij,...j->...i", + T_w_b.as_matrix(), + onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), + )[..., :-1], + ) + + +@general_group_test +def test_multiply( + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike +): + """Check multiply interfaces.""" + T_w_b = sample_transform(Group, batch_axes, dtype) + T_b_a = sample_transform(Group, batch_axes, dtype) + assert_arrays_close( + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), jnp.linalg.inv(T_w_b.as_matrix()) + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim, dtype=dtype), + (*batch_axes, Group.matrix_dim, Group.matrix_dim), + ), + ) + assert_transforms_close(T_w_b @ T_b_a, Group.multiply(T_w_b, T_b_a)) diff --git a/tests/utils.py b/tests/utils.py index 0711a838e..1d023bdb3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,7 +51,7 @@ def sample_transform( def general_group_test( f: Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike], None], - max_examples: int = 3, + max_examples: int = 15, ) -> Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike, Any], None]: """Decorator for defining tests that run on all group types.""" From 9b7c19d6ee2fb2037ab273f5ad34fc8e1ec287ec Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 19:24:30 -0700 Subject: [PATCH 6/8] Add pytest dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d82a2478a..ac805c890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "pyright>=1.1.308", "ruff==0.6.2", "pre-commit==3.3.2", + "pytest", ] examples = [ "torch>=1.13.1", From 2e7765d8a9f4c5441b6aac7262b0dcae2a5924f0 Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 19:26:49 -0700 Subject: [PATCH 7/8] Add hypothesis --- pyproject.toml | 1 + tests/test_transforms_ops.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ac805c890..7dd76da91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dev = [ "ruff==0.6.2", "pre-commit==3.3.2", "pytest", + "hypothesis[numpy]", ] examples = [ "torch>=1.13.1", diff --git a/tests/test_transforms_ops.py b/tests/test_transforms_ops.py index 583db9c0b..019626d3d 100644 --- a/tests/test_transforms_ops.py +++ b/tests/test_transforms_ops.py @@ -4,7 +4,8 @@ import numpy as onp import numpy.typing as onpt -from jax import numpy as jnp +import viser.transforms as vtf + from utils import ( assert_arrays_close, assert_transforms_close, @@ -12,8 +13,6 @@ sample_transform, ) -import viser.transforms as vtf - @general_group_test def test_sample_uniform_valid( @@ -119,7 +118,7 @@ def test_multiply( T_b_a = sample_transform(Group, batch_axes, dtype) assert_arrays_close( onp.einsum( - "...ij,...jk->...ik", T_w_b.as_matrix(), jnp.linalg.inv(T_w_b.as_matrix()) + "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) ), onp.broadcast_to( onp.eye(Group.matrix_dim, dtype=dtype), From 6df992778b7150f804ad6c7c4d9a2a2383a942f3 Mon Sep 17 00:00:00 2001 From: brentyi Date: Sun, 15 Sep 2024 19:39:35 -0700 Subject: [PATCH 8/8] Minor fixes --- tests/test_transforms_bijective.py | 1 + tests/test_transforms_ops.py | 6 +++--- tests/utils.py | 14 +++++++++----- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/test_transforms_bijective.py b/tests/test_transforms_bijective.py index 85173a3d7..a9f0c835c 100644 --- a/tests/test_transforms_bijective.py +++ b/tests/test_transforms_bijective.py @@ -94,6 +94,7 @@ def test_adjoint( """Check adjoint definition.""" transform = sample_transform(Group, batch_axes, dtype) omega = onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype) + assert (transform @ Group.exp(omega)).parameters().dtype == dtype assert_transforms_close( transform @ Group.exp(omega), Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) diff --git a/tests/test_transforms_ops.py b/tests/test_transforms_ops.py index 019626d3d..6321b8896 100644 --- a/tests/test_transforms_ops.py +++ b/tests/test_transforms_ops.py @@ -4,8 +4,6 @@ import numpy as onp import numpy.typing as onpt -import viser.transforms as vtf - from utils import ( assert_arrays_close, assert_transforms_close, @@ -13,6 +11,8 @@ sample_transform, ) +import viser.transforms as vtf + @general_group_test def test_sample_uniform_valid( @@ -64,7 +64,7 @@ def test_adjoint( ): """Check adjoint definition.""" transform = sample_transform(Group, batch_axes, dtype) - omega = onp.random.randn(*batch_axes, Group.tangent_dim) + omega = onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) assert_transforms_close( transform @ Group.exp(omega), Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) diff --git a/tests/utils.py b/tests/utils.py index 1d023bdb3..ed7346e1e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -122,14 +122,18 @@ def assert_transforms_close(a: vtf.MatrixLieGroup, b: vtf.MatrixLieGroup): assert_arrays_close(p1, p2) -def assert_arrays_close( - *arrays: Union[onpt.NDArray[onp.float64], float], - rtol: float = 1e-3, - atol: float = 1e-4, -): +def assert_arrays_close(*arrays: Union[onpt.NDArray[onp.float64], float]): """Make sure two arrays are close. (and not NaN)""" for array1, array2 in zip(arrays[:-1], arrays[1:]): assert onp.asarray(array1).dtype == onp.asarray(array2).dtype + + if isinstance(array1, (float, int)) or array1.dtype == onp.float64: + rtol = 1e-7 + atol = 1e-7 + else: + rtol = 1e-3 + atol = 1e-3 + onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol) assert not onp.any(onp.isnan(array1)) assert not onp.any(onp.isnan(array2))