Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 10, 2024
1 parent 1c08d8e commit a001652
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 159 deletions.
6 changes: 5 additions & 1 deletion src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from ._apply_rotation_matrix import apply_rotation_matrix
from ._apply_rotation_vector import apply_rotation_vector
from ._apply_transform import apply_transform
from ._compose_euler_angle import compose_euler_angle
from ._compose_quaternion import compose_quaternion
from ._compose_rotation_matrix import compose_rotation_matrix
Expand All @@ -28,6 +29,7 @@
from ._invert_quaternion import invert_quaternion
from ._invert_rotation_matrix import invert_rotation_matrix
from ._invert_rotation_vector import invert_rotation_vector
from ._invert_transform import invert_transform
from ._quaternion_identity import quaternion_identity
from ._quaternion_magnitude import quaternion_magnitude
from ._quaternion_mean import quaternion_mean
Expand Down Expand Up @@ -72,6 +74,7 @@
"apply_quaternion",
"apply_rotation_matrix",
"apply_rotation_vector",
"apply_transform",
"compose_euler_angle",
"compose_quaternion",
"compose_rotation_matrix",
Expand All @@ -86,9 +89,11 @@
"invert_quaternion",
"invert_rotation_matrix",
"invert_rotation_vector",
"invert_transform",
"quaternion_identity",
"quaternion_magnitude",
"quaternion_mean",
"quaternion_slerp",
"quaternion_to_euler_angle",
"quaternion_to_rotation_matrix",
"quaternion_to_rotation_vector",
Expand All @@ -108,6 +113,5 @@
"rotation_vector_to_euler_angle",
"rotation_vector_to_quaternion",
"rotation_vector_to_rotation_matrix",
"quaternion_slerp",
"translation_identity",
]
114 changes: 114 additions & 0 deletions src/beignet/_apply_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import torch
from torch import Tensor
from torch.autograd import Function


def _apply_transform(input: Tensor, transform: Tensor) -> Tensor:
"""
Applies an affine transformation to the position vector.
Parameters
----------
input : Tensor
Position, must have the shape `(..., dimension)`.
transform : Tensor
The affine transformation matrix, must be a scalar, a vector, or a
matrix with the shape `(dimension, dimension)`.
Returns
-------
Tensor
Affine transformed position vector, has the same shape as the
position vector.
"""
if transform.ndim == 0:
return input * transform

indices = [chr(ord("a") + index) for index in range(input.ndim - 1)]

indices = "".join(indices)

if transform.ndim == 1:
return torch.einsum(
f"i,{indices}i->{indices}i",
transform,
input,
)

if transform.ndim == 2:
return torch.einsum(
f"ij,{indices}j->{indices}i",
transform,
input,
)

raise ValueError


class _ApplyTransform(Function):
generate_vmap_rule = True

@staticmethod
def forward(transform: Tensor, position: Tensor) -> Tensor:
"""
Return affine transformed position.
Parameters
----------
transform : Tensor
Affine transformation matrix, must have shape
`(dimension, dimension)`.
position : Tensor
Position, must have shape `(..., dimension)`.
Returns
-------
Tensor
Affine transformed position of shape `(..., dimension)`.
"""
return _apply_transform(position, transform)

@staticmethod
def setup_context(ctx, inputs, output):
transformation, position = inputs

ctx.save_for_backward(transformation, position, output)

@staticmethod
def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor):
transformation, position, _ = ctx.saved_tensors

output = _apply_transform(position, transformation)

grad_output = grad_position + _apply_transform(position, grad_transform)

return output, grad_output

@staticmethod
def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor):
_, _, output = ctx.saved_tensors

return output, grad_output


def apply_transform(input: Tensor, transform: Tensor) -> Tensor:
"""
Return affine transformed position.
Parameters
----------
input : Tensor
Position, must have shape `(..., dimension)`.
transform : Tensor
Affine transformation matrix, must have shape
`(dimension, dimension)`.
Returns
-------
Tensor
Affine transformed position of shape `(..., dimension)`.
"""
return _ApplyTransform.apply(transform, input)
25 changes: 25 additions & 0 deletions src/beignet/_invert_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
from torch import Tensor


def invert_transform(transform: Tensor) -> Tensor:
"""
Calculates the inverse of an affine transformation matrix.
Parameters
----------
transform : Tensor
The affine transformation matrix to be inverted.
Returns
-------
Tensor
The inverse of the given affine transformation matrix.
"""
if transform.ndim in {0, 1}:
return 1.0 / transform

if transform.ndim == 2:
return torch.linalg.inv(transform)

raise ValueError
Loading

0 comments on commit a001652

Please sign in to comment.