Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform helpers cleanup #285

Merged
merged 8 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: pytest

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
uv pip install --system -e ".[dev,examples]"
- name: Test with pytest
run: |
pytest
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ dev = [
"pyright>=1.1.308",
"ruff==0.6.2",
"pre-commit==3.3.2",
"pytest",
"hypothesis[numpy]",
]
examples = [
"torch>=1.13.1",
Expand Down
54 changes: 36 additions & 18 deletions src/viser/transforms/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
class MatrixLieGroup(abc.ABC):
"""Interface definition for matrix Lie groups."""

# Class properties.
# > These will be set in `_utils.register_lie_group()`.

matrix_dim: ClassVar[int]
"""Dimension of square matrix output from `.as_matrix()`."""

Expand All @@ -36,6 +33,19 @@ def __init__(
"""Construct a group object from its underlying parameters."""
raise NotImplementedError()

def __init_subclass__(
cls,
matrix_dim: int = 0,
parameters_dim: int = 0,
tangent_dim: int = 0,
space_dim: int = 0,
) -> None:
"""Set class properties for subclasses. We default to dummy values."""
cls.matrix_dim = matrix_dim
cls.parameters_dim = parameters_dim
cls.tangent_dim = tangent_dim
cls.space_dim = space_dim

# Shared implementations.

@overload
Expand Down Expand Up @@ -66,11 +76,14 @@ def __matmul__(

@classmethod
@abc.abstractmethod
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self:
def identity(
cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64
) -> Self:
"""Returns identity element.

Args:
batch_axes: Any leading batch axes for the output transform.
dtype: Datatype for the output.

Returns:
Identity element.
Expand Down Expand Up @@ -172,20 +185,25 @@ def normalize(self) -> Self:
Normalized group member.
"""

# @classmethod
# @abc.abstractmethod
# def sample_uniform(cls, key: onp.ndarray, batch_axes: Tuple[int, ...] = ()) -> Self:
# """Draw a uniform sample from the group. Translations (if applicable) are in the
# range [-1, 1].
#
# Args:
# key: PRNG key, as returned by `jax.random.PRNGKey()`.
# batch_axes: Any leading batch axes for the output transforms. Each
# sampled transform will be different.
#
# Returns:
# Sampled group member.
# """
@classmethod
@abc.abstractmethod
def sample_uniform(
cls,
rng: onp.random.Generator,
batch_axes: Tuple[int, ...] = (),
dtype: onpt.DTypeLike = onp.float64,
) -> Self:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].

Args:
rng: numpy generator object.
batch_axes: Any leading batch axes for the output transforms. Each
sampled transform will be different.

Returns:
Sampled group member.
"""

@final
def get_batch_axes(self) -> Tuple[int, ...]:
Expand Down
87 changes: 41 additions & 46 deletions src/viser/transforms/_se2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
from typing import Tuple, cast

Expand All @@ -7,17 +9,17 @@

from . import _base, hints
from ._so2 import SO2
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group
from .utils import broadcast_leading_axes, get_epsilon


@register_lie_group(
@dataclasses.dataclass(frozen=True)
class SE2(
_base.SEBase[SO2],
matrix_dim=3,
parameters_dim=4,
tangent_dim=3,
space_dim=2,
)
@dataclasses.dataclass(frozen=True)
class SE2(_base.SEBase[SO2]):
):
"""Special Euclidean group for proper rigid transforms in 2D. Broadcasting
rules are the same as for numpy.

Expand All @@ -39,7 +41,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})"

@staticmethod
def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2":
def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> SE2:
"""Construct a transformation from standard 2D pose parameters.

This is not the same as integrating over a length-3 twist.
Expand All @@ -56,7 +58,7 @@ def from_rotation_and_translation(
cls,
rotation: SO2,
translation: onpt.NDArray[onp.floating],
) -> "SE2":
) -> SE2:
assert translation.shape[-1:] == (2,)
rotation, translation = broadcast_leading_axes((rotation, translation))
return SE2(
Expand All @@ -77,16 +79,18 @@ def translation(self) -> onpt.NDArray[onp.floating]:

@classmethod
@override
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2":
def identity(
cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64
) -> SE2:
return SE2(
unit_complex_xy=onp.broadcast_to(
onp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)
onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4)
)
)

@classmethod
@override
def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> "SE2":
def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE2:
assert matrix.shape[-2:] == (3, 3) or matrix.shape[-2:] == (2, 3)
# Currently assumes bottom row is [0, 0, 1].
return SE2.from_rotation_and_translation(
Expand Down Expand Up @@ -123,7 +127,7 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]:

@classmethod
@override
def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2":
def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558
# Also see:
Expand All @@ -146,21 +150,15 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2":
)

theta_sq = theta**2
sin_over_theta = cast(
onp.ndarray,
onp.where(
use_taylor,
1.0 - theta_sq / 6.0,
onp.sin(safe_theta) / safe_theta,
),
sin_over_theta = onp.where(
use_taylor,
1.0 - theta_sq / 6.0,
onp.sin(safe_theta) / safe_theta,
)
one_minus_cos_over_theta = cast(
onp.ndarray,
onp.where(
use_taylor,
0.5 * theta - theta * theta_sq / 24.0,
(1.0 - onp.cos(safe_theta)) / safe_theta,
),
one_minus_cos_over_theta = onp.where(
use_taylor,
0.5 * theta - theta * theta_sq / 24.0,
(1.0 - onp.cos(safe_theta)) / safe_theta,
)

V = onp.stack(
Expand All @@ -172,9 +170,12 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> "SE2":
],
axis=-1,
).reshape((*tangent.shape[:-1], 2, 2))

return SE2.from_rotation_and_translation(
rotation=SO2.from_radians(theta),
translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]),
translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]).astype(
tangent.dtype
),
)

@override
Expand Down Expand Up @@ -224,10 +225,10 @@ def log(self) -> onpt.NDArray[onp.floating]:
],
axis=-1,
)
return tangent
return tangent.astype(self.unit_complex_xy.dtype)

@override
def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]:
def adjoint(self: SE2) -> onpt.NDArray[onp.floating]:
cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0)
return onp.stack(
[
Expand All @@ -244,21 +245,15 @@ def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]:
axis=-1,
).reshape((*self.get_batch_axes(), 3, 3))

# @classmethod
# @override
# def sample_uniform(
# cls, key: onp.ndarray, batch_axes: jdc.Static[Tuple[int, ...]] = ()
# ) -> "SE2":
# key0, key1 = jax.random.split(key)
# return SE2.from_rotation_and_translation(
# rotation=SO2.sample_uniform(key0, batch_axes=batch_axes),
# translation=jax.random.uniform(
# key=key1,
# shape=(
# *batch_axes,
# 2,
# ),
# minval=-1.0,
# maxval=1.0,
# ),
# )
@classmethod
@override
def sample_uniform(
cls,
rng: onp.random.Generator,
batch_axes: Tuple[int, ...] = (),
dtype: onpt.DTypeLike = onp.float64,
) -> SE2:
return SE2.from_rotation_and_translation(
SO2.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype),
rng.uniform(low=-1.0, high=1.0, size=(*batch_axes, 2)).astype(dtype),
)
Loading
Loading