From 525391cb091cf303d599b4bc057c593a65afa4af Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 6 Jun 2024 17:40:26 +0300 Subject: [PATCH] feat: add forward kinematics & euler angle conversions monads --- .../geometry/rotation/euler_to_rotmat.yaml | 4 + .../geometry/rotation/rotmat_to_euler.yaml | 4 + .../monads/human/pose/forward_kinematics.yaml | 4 + moai/monads/geometry/rotations/euler.py | 167 ++++++++++++++++++ moai/monads/human/body/__init__.py | 9 - moai/monads/human/body/kinematics.py | 123 ------------- moai/monads/human/pose/forward_kinematics.py | 57 ++++++ 7 files changed, 236 insertions(+), 132 deletions(-) create mode 100644 moai/conf/model/monads/geometry/rotation/euler_to_rotmat.yaml create mode 100644 moai/conf/model/monads/geometry/rotation/rotmat_to_euler.yaml create mode 100644 moai/conf/model/monads/human/pose/forward_kinematics.yaml create mode 100644 moai/monads/geometry/rotations/euler.py delete mode 100644 moai/monads/human/body/kinematics.py create mode 100644 moai/monads/human/pose/forward_kinematics.py diff --git a/moai/conf/model/monads/geometry/rotation/euler_to_rotmat.yaml b/moai/conf/model/monads/geometry/rotation/euler_to_rotmat.yaml new file mode 100644 index 00000000..332d6145 --- /dev/null +++ b/moai/conf/model/monads/geometry/rotation/euler_to_rotmat.yaml @@ -0,0 +1,4 @@ +# @package model.monads.euler_to_rotmat + +_target_: moai.monads.geometry.rotations.euler.EulerToRotationMatrix +order: XYZ \ No newline at end of file diff --git a/moai/conf/model/monads/geometry/rotation/rotmat_to_euler.yaml b/moai/conf/model/monads/geometry/rotation/rotmat_to_euler.yaml new file mode 100644 index 00000000..2b8e69f5 --- /dev/null +++ b/moai/conf/model/monads/geometry/rotation/rotmat_to_euler.yaml @@ -0,0 +1,4 @@ +# @package model.monads.rotmat_to_euler + +_target_: moai.monads.geometry.rotations.euler.RotationMatrixToEuler +order: XYZ \ No newline at end of file diff --git a/moai/conf/model/monads/human/pose/forward_kinematics.yaml b/moai/conf/model/monads/human/pose/forward_kinematics.yaml new file mode 100644 index 00000000..02f0c3e0 --- /dev/null +++ b/moai/conf/model/monads/human/pose/forward_kinematics.yaml @@ -0,0 +1,4 @@ +# @package model.monads.forward_kinematics + +_target_: moai.monads.human.pose.forward_kinematics.ForwardKinematics +parents: ??? # int[] \ No newline at end of file diff --git a/moai/monads/geometry/rotations/euler.py b/moai/monads/geometry/rotations/euler.py new file mode 100644 index 00000000..25355997 --- /dev/null +++ b/moai/monads/geometry/rotations/euler.py @@ -0,0 +1,167 @@ +# NOTE: adapted from PyTorch3D + +import torch + +__all__ = ["euler_angles_to_matrix", "EulerToRotationMatrix"] + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +class EulerToRotationMatrix(torch.nn.Module): + def __init__(self, order: str = "XYZ") -> None: + super().__init__() + self.order = order + + def forward(self, euler: torch.Tensor) -> torch.Tensor: + return euler_angles_to_matrix(euler, self.order) + + +class RotationMatrixToEuler(torch.nn.Module): + def __init__(self, order: str = "XYZ") -> None: + super().__init__() + self.order = order + + def forward(self, rotation: torch.Tensor) -> torch.Tensor: + return matrix_to_euler_angles(rotation, self.order) diff --git a/moai/monads/human/body/__init__.py b/moai/monads/human/body/__init__.py index adf2d1a5..e69de29b 100644 --- a/moai/monads/human/body/__init__.py +++ b/moai/monads/human/body/__init__.py @@ -1,9 +0,0 @@ -# from moai.monads.human.body.init_translation import InitTranslation -# from moai.monads.human.body.joint_regressor import JointRegressor -# from moai.monads.human.body.transfer import BodyTransfer - -# __all__ = [ -# 'InitTranslation', -# 'JointRegressor', -# 'BodyTransfer', -# ] diff --git a/moai/monads/human/body/kinematics.py b/moai/monads/human/body/kinematics.py deleted file mode 100644 index 1b4aed0d..00000000 --- a/moai/monads/human/body/kinematics.py +++ /dev/null @@ -1,123 +0,0 @@ -import typing - -import torch - - -class FK(torch.nn.Module): - def __init__( - self, - parents: torch.Tensor, - ): - """ - Forward Kinematics - returns the global poses of the joints - local_oris: (batch_size, n_joints, 3, 3) - parents: (n_joints, ) - output: (batch_size, n_joints, 3, 3) - """ - super(FK, self).__init__() - self.parents = parents - - def forward(self, local_poses: torch.Tensor) -> torch.Tensor: - n_joints = len(self.parents) - global_oris = torch.zeros_like(local_poses) - for j in range(n_joints): - if self.parents[j] < 0: # root rotation - global_oris[..., j, :, :] = local_poses[..., j, :, :] - else: - parent_rot = global_oris[..., self.parents[j], :, :] - local_rot = local_poses[..., j, :, :] - global_oris[..., j, :, :] = torch.matmul(parent_rot, local_rot) - return global_oris - - -class ForwardKinematics(torch.nn.Module): - def __init__( - self, - parents: torch.Tensor, - ) -> None: - super().__init__() - """ - Forward kinematic process. - In forward kinematics, - you start with the joint rotations - and propagate them from parent joints to child joints to - determine the position and orientation of the - end effector or joints in global space. - - Parameters - ---------- - rot_mats : torch.tensor BxNx3x3 - Tensor of rotation matrices - joints : torch.tensor BxNx3 - Locations of joints - parents : torch.tensor BxN - The kinematic tree of each object - - Returns - ------- - posed_joints : torch.tensor BxNx3 - The locations of the joints after applying the pose rotations - rel_transforms : torch.tensor BxNx4x4 - The relative (with respect to the root joint) rigid transformations - for all the joints - """ - self.parents = parents - - @staticmethod - def transform_mat(R: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - """Creates a batch of transformation matrices - Args: - - R: Bx3x3 array of a batch of rotation matrices - - t: Bx3x1 array of a batch of translation vectors - Returns: - - T: Bx4x4 Transformation matrix - """ - # No padding left or right, only add an extra row - return torch.cat( - [ - torch.nn.functional.pad(R, [0, 0, 0, 1]), - torch.nn.functional.pad(t, [0, 0, 0, 1], value=1), - ], - dim=2, - ) - - def forward( - self, rotation_matrices: torch.Tensor, joints: torch.Tensor # BxNx3x3 # BxNx3 - ) -> typing.Dict[str, torch.Tensor]: - joints = torch.unsqueeze(joints, dim=-1) - - rel_joints = joints.clone() - rel_joints[:, 1:] -= joints[:, self.parents[1:]] - - transforms_mat = self.transform_mat( - rotation_matrices.reshape(-1, 3, 3), - rel_joints.reshape(-1, 3, 1).repeat( - int(rotation_matrices.shape[0] / rel_joints.shape[0]), 1, 1 - ), - ).reshape(-1, joints.shape[1], 4, 4) - - transform_chain = [transforms_mat[:, 0]] - for i in range(1, len(self.parents)): - # Subtract the joint location at the rest pose - # No need for rotation, since it's identity when at rest - curr_res = torch.matmul( - transform_chain[self.parents[i]], transforms_mat[:, i] - ) - transform_chain.append(curr_res) - - transforms = torch.stack(transform_chain, dim=1) - - # The last column of the transformations contains the posed joints - posed_joints = transforms[:, :, :3, 3] - - joints_homogen = torch.nn.functional.pad(joints, [0, 0, 0, 1]) - - rel_transforms = transforms - torch.nn.functional.pad( - torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0] - ) - - return { - "posed_joints": posed_joints, - "rel_transforms": rel_transforms, - } diff --git a/moai/monads/human/pose/forward_kinematics.py b/moai/monads/human/pose/forward_kinematics.py new file mode 100644 index 00000000..46e178cc --- /dev/null +++ b/moai/monads/human/pose/forward_kinematics.py @@ -0,0 +1,57 @@ +import typing + +import numpy as np +import torch + +__all__ = ["ForwardKinematics"] + + +class ForwardKinematics(torch.nn.Module): + def __init__( + self, parents + ): # TODO: add a col/row major param to adjust offset slicing + super().__init__() + self.parents = parents # TODO: register buffer from list? + + def forward( + self, # TODO: add parents tensor input? + rotation: torch.Tensor, # [B, (T), J, 3, 3] + position: torch.Tensor, # [B, (T), 3] + offset: torch.Tensor, # [B, (T), J, 3] + parents: typing.Optional[torch.Tensor] = None, # [B, J] + ) -> typing.Dict[str, torch.Tensor]: # { [B, (T), J, 3], [B, (T), J, 3, 3] } + joints = torch.empty(rotation.shape[:-1], device=rotation.device) + joints[..., 0, :] = position.clone() # first joint according to global position + offset = offset[ + :, np.newaxis, ..., np.newaxis + ] # NOTE: careful, col vs row major order + # offset = offset[np.newaxis, :, np.newaxis, :] #NOTE: careful, col vs row major order + global_rotation = rotation.clone() + # global_rotation = torch.empty(rotation.shape, device=rotation.device) + # global_rotation[..., 0, :3, :3] = rotation[..., 0, :3, :3].clone() + # NOTE: currently the op does not support per batch item parents + parent_indices = ( + parents[0].detach().cpu() + if parents is not None + else self.parents[0].detach().cpu() + ) + if ( + parent_indices.shape[-1] == offset.shape[-3] + ): # NOTE: to support using the same parents everywhere + parent_indices = parent_indices[1:] + for current_idx, parent_idx in enumerate( + parent_indices, start=1 + ): # NOTE: assumes parents exclude root + joints[..., current_idx, :] = torch.matmul( + global_rotation[..., parent_idx, :, :], offset[..., current_idx, :, :] + ).squeeze(-1) + global_rotation[..., current_idx, :, :] = torch.matmul( + global_rotation[..., parent_idx, :, :].clone(), + rotation[..., current_idx, :, :].clone(), + ) + joints[..., current_idx, :] += joints[..., parent_idx, :] + + return { + "positions": joints, + "rotations": global_rotation, + }