From a001652787da012d4fc98c508ce8ee73fa59666c Mon Sep 17 00:00:00 2001
From: Allen Goodman
Date: Fri, 10 May 2024 11:11:18 -0400
Subject: [PATCH] beignet.func.space

---
 src/beignet/__init__.py                 |   6 +-
 src/beignet/_apply_transform.py         | 114 +++++++++++++++
 src/beignet/_invert_transform.py        |  25 ++++
 src/beignet/func/_space.py              | 177 +++---------------
 tests/beignet/func/test__space.py       |   0
 tests/beignet/test__apply_transform.py  |  21 +++
 tests/beignet/test__invert_transform.py |  16 +++
 7 files changed, 200 insertions(+), 159 deletions(-)
 create mode 100644 src/beignet/_apply_transform.py
 create mode 100644 src/beignet/_invert_transform.py
 create mode 100644 tests/beignet/func/test__space.py
 create mode 100644 tests/beignet/test__apply_transform.py
 create mode 100644 tests/beignet/test__invert_transform.py

diff --git a/src/beignet/__init__.py b/src/beignet/__init__.py
index f91d4adc64..2b6404f460 100644
--- a/src/beignet/__init__.py
+++ b/src/beignet/__init__.py
@@ -12,6 +12,7 @@
 )
 from ._apply_rotation_matrix import apply_rotation_matrix
 from ._apply_rotation_vector import apply_rotation_vector
+from ._apply_transform import apply_transform
 from ._compose_euler_angle import compose_euler_angle
 from ._compose_quaternion import compose_quaternion
 from ._compose_rotation_matrix import compose_rotation_matrix
@@ -28,6 +29,7 @@
 from ._invert_quaternion import invert_quaternion
 from ._invert_rotation_matrix import invert_rotation_matrix
 from ._invert_rotation_vector import invert_rotation_vector
+from ._invert_transform import invert_transform
 from ._quaternion_identity import quaternion_identity
 from ._quaternion_magnitude import quaternion_magnitude
 from ._quaternion_mean import quaternion_mean
@@ -72,6 +74,7 @@
     "apply_quaternion",
     "apply_rotation_matrix",
     "apply_rotation_vector",
+    "apply_transform",
     "compose_euler_angle",
     "compose_quaternion",
     "compose_rotation_matrix",
@@ -86,9 +89,11 @@
     "invert_quaternion",
     "invert_rotation_matrix",
     "invert_rotation_vector",
+    "invert_transform",
     "quaternion_identity",
     "quaternion_magnitude",
     "quaternion_mean",
+    "quaternion_slerp",
     "quaternion_to_euler_angle",
     "quaternion_to_rotation_matrix",
     "quaternion_to_rotation_vector",
@@ -108,6 +113,5 @@
     "rotation_vector_to_euler_angle",
     "rotation_vector_to_quaternion",
     "rotation_vector_to_rotation_matrix",
-    "quaternion_slerp",
     "translation_identity",
 ]
diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py
new file mode 100644
index 0000000000..6b0defbc40
--- /dev/null
+++ b/src/beignet/_apply_transform.py
@@ -0,0 +1,114 @@
+import torch
+from torch import Tensor
+from torch.autograd import Function
+
+
+def _apply_transform(input: Tensor, transform: Tensor) -> Tensor:
+    """
+    Applies an affine transformation to the position vector.
+
+    Parameters
+    ----------
+    input : Tensor
+        Position, must have the shape `(..., dimension)`.
+
+    transform : Tensor
+        The affine transformation matrix, must be a scalar, a vector, or a
+        matrix with the shape `(dimension, dimension)`.
+
+    Returns
+    -------
+    Tensor
+        Affine transformed position vector, has the same shape as the
+        position vector.
+ """ + if transform.ndim == 0: + return input * transform + + indices = [chr(ord("a") + index) for index in range(input.ndim - 1)] + + indices = "".join(indices) + + if transform.ndim == 1: + return torch.einsum( + f"i,{indices}i->{indices}i", + transform, + input, + ) + + if transform.ndim == 2: + return torch.einsum( + f"ij,{indices}j->{indices}i", + transform, + input, + ) + + raise ValueError + + +class _ApplyTransform(Function): + generate_vmap_rule = True + + @staticmethod + def forward(transform: Tensor, position: Tensor) -> Tensor: + """ + Return affine transformed position. + + Parameters + ---------- + transform : Tensor + Affine transformation matrix, must have shape + `(dimension, dimension)`. + + position : Tensor + Position, must have shape `(..., dimension)`. + + Returns + ------- + Tensor + Affine transformed position of shape `(..., dimension)`. + """ + return _apply_transform(position, transform) + + @staticmethod + def setup_context(ctx, inputs, output): + transformation, position = inputs + + ctx.save_for_backward(transformation, position, output) + + @staticmethod + def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): + transformation, position, _ = ctx.saved_tensors + + output = _apply_transform(position, transformation) + + grad_output = grad_position + _apply_transform(position, grad_transform) + + return output, grad_output + + @staticmethod + def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): + _, _, output = ctx.saved_tensors + + return output, grad_output + + +def apply_transform(input: Tensor, transform: Tensor) -> Tensor: + """ + Return affine transformed position. + + Parameters + ---------- + input : Tensor + Position, must have shape `(..., dimension)`. + + transform : Tensor + Affine transformation matrix, must have shape + `(dimension, dimension)`. + + Returns + ------- + Tensor + Affine transformed position of shape `(..., dimension)`. + """ + return _ApplyTransform.apply(transform, input) diff --git a/src/beignet/_invert_transform.py b/src/beignet/_invert_transform.py new file mode 100644 index 0000000000..f1348fcb66 --- /dev/null +++ b/src/beignet/_invert_transform.py @@ -0,0 +1,25 @@ +import torch +from torch import Tensor + + +def invert_transform(transform: Tensor) -> Tensor: + """ + Calculates the inverse of an affine transformation matrix. + + Parameters + ---------- + transform : Tensor + The affine transformation matrix to be inverted. + + Returns + ------- + Tensor + The inverse of the given affine transformation matrix. + """ + if transform.ndim in {0, 1}: + return 1.0 / transform + + if transform.ndim == 2: + return torch.linalg.inv(transform) + + raise ValueError diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index 8e5bffe755..2739d685fa 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -2,150 +2,11 @@ import torch from torch import Tensor -from torch.autograd import Function -T = TypeVar("T") - - -def _inverse_transform(transformation: Tensor) -> Tensor: - """ - Calculates the inverse of an affine transformation matrix. - - Parameters - ---------- - transformation : Tensor - The affine transformation matrix to be inverted. - - Returns - ------- - Tensor - The inverse of the given affine transformation matrix. 
- """ - if transformation.ndim in {0, 1}: - return 1.0 / transformation - - if transformation.ndim == 2: - return torch.linalg.inv(transformation) - - raise ValueError("Unsupported transformation dimensions.") - - -def _apply_transform(transformation: Tensor, position: Tensor) -> Tensor: - """ - Applies an affine transformation to the position vector. - - Parameters - ---------- - position : Tensor - Position, must have the shape `(..., dimension)`. - - transformation : Tensor - The affine transformation matrix, must be a scalar, a vector, or a - matrix with the shape `(dimension, dimension)`. - - Returns - ------- - Tensor - Affine transformed position vector, has the same shape as the - position vector. - """ - if transformation.ndim == 0: - return position * transformation - - indices = [chr(ord("a") + index) for index in range(position.ndim - 1)] - - indices = "".join(indices) - - if transformation.ndim == 1: - return torch.einsum( - f"i,{indices}i->{indices}i", - transformation, - position, - ) - - if transformation.ndim == 2: - return torch.einsum( - f"ij,{indices}j->{indices}i", - transformation, - position, - ) - - raise ValueError("Unsupported transformation dimensions.") - - -def apply_transform(transformation: Tensor, position: Tensor) -> Tensor: - """ - Return affine transformed position. - - Parameters - ---------- - transformation : Tensor - Affine transformation matrix, must have shape - `(dimension, dimension)`. - - position : Tensor - Position, must have shape `(..., dimension)`. - - Returns - ------- - Tensor - Affine transformed position of shape `(..., dimension)`. - """ - - class _Transform(Function): - generate_vmap_rule = True - - @staticmethod - def forward(transformation: Tensor, position: Tensor) -> Tensor: - """ - Return affine transformed position. - - Parameters - ---------- - transformation : Tensor - Affine transformation matrix, must have shape - `(dimension, dimension)`. - - position : Tensor - Position, must have shape `(..., dimension)`. - - Returns - ------- - Tensor - Affine transformed position of shape `(..., dimension)`. 
- """ - return _apply_transform(transformation, position) - - @staticmethod - def setup_context(ctx, inputs, output): - transformation, position = inputs - - ctx.save_for_backward(transformation, position, output) - - @staticmethod - def jvp( - ctx, - grad_transformation: Tensor, - grad_position: Tensor, - ) -> (Tensor, Tensor): - transformation, position, _ = ctx.saved_tensors - - output = _apply_transform(transformation, position) - - grad_output = grad_position + _apply_transform( - grad_transformation, - position, - ) - - return output, grad_output - - @staticmethod - def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): - _, _, output = ctx.saved_tensors - - return output, grad_output +from beignet._apply_transform import _apply_transform, apply_transform +from beignet._invert_transform import invert_transform - return _Transform.apply(transformation, position) +T = TypeVar("T") def space( @@ -265,7 +126,7 @@ def displacement_fn( raise ValueError if perturbation is not None: - return _apply_transform(input - other, perturbation) + return _apply_transform(perturbation, input - other) return input - other @@ -275,7 +136,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: return displacement_fn, shift_fn if parallelepiped: - inverse_transformation = _inverse_transform(dimensions) + inverse_transformation = invert_transform(dimensions) if normalized: @@ -303,12 +164,12 @@ def displacement_fn( raise ValueError displacement = apply_transform( - _transformation, torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, + _transformation, ) if perturbation is not None: - return _apply_transform(displacement, perturbation) + return _apply_transform(perturbation, displacement) return displacement @@ -325,12 +186,12 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform(_transformation) + _inverse_transformation = invert_transform(_transformation) if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - return u(input, apply_transform(_inverse_transformation, other)) + return u(input, apply_transform(other, _inverse_transformation)) return displacement_fn, shift_fn @@ -342,14 +203,14 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform( + _inverse_transformation = invert_transform( _transformation, ) if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - return input + apply_transform(_inverse_transformation, other) + return input + apply_transform(other, _inverse_transformation) return displacement_fn, shift_fn @@ -367,13 +228,13 @@ def displacement_fn( if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform(_transformation) + _inverse_transformation = invert_transform(_transformation) if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - input = apply_transform(_inverse_transformation, input) - other = apply_transform(_inverse_transformation, other) + input = apply_transform(input, _inverse_transformation) + other = apply_transform(other, _inverse_transformation) if len(input.shape) != 1: raise ValueError @@ -382,12 +243,12 @@ def displacement_fn( raise ValueError displacement = apply_transform( - _transformation, torch.remainder(input - 
other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, + _transformation, ) if perturbation is not None: - return _apply_transform(displacement, perturbation) + return _apply_transform(perturbation, displacement) return displacement @@ -404,7 +265,7 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform( + _inverse_transformation = invert_transform( _transformation, ) @@ -412,11 +273,11 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: _transformation = kwargs["updated_transformation"] return apply_transform( - _transformation, u( apply_transform(_inverse_transformation, input), apply_transform(_inverse_transformation, other), ), + _transformation, ) return displacement_fn, shift_fn @@ -442,7 +303,7 @@ def displacement_fn( displacement = torch.remainder(input - other + dimensions * 0.5, dimensions) if perturbation is not None: - return _apply_transform(displacement - dimensions * 0.5, perturbation) + return _apply_transform(perturbation, displacement - dimensions * 0.5) return displacement - dimensions * 0.5 diff --git a/tests/beignet/func/test__space.py b/tests/beignet/func/test__space.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/beignet/test__apply_transform.py b/tests/beignet/test__apply_transform.py new file mode 100644 index 0000000000..7811aebdba --- /dev/null +++ b/tests/beignet/test__apply_transform.py @@ -0,0 +1,21 @@ +import torch +import torch.func +from beignet import apply_transform +from torch import Tensor + + +def test_apply_transform(): + input = torch.randn([32, 3]) + + transform = torch.randn([3, 3]) + + def f(r: Tensor) -> Tensor: + return torch.sum(r**2) + + def g(r: Tensor, t: Tensor) -> Tensor: + return torch.sum(apply_transform(r, t) ** 2) + + torch.testing.assert_close( + torch.func.grad(f)(apply_transform(input, transform)), + torch.func.grad(g, 0)(input, transform), + ) diff --git a/tests/beignet/test__invert_transform.py b/tests/beignet/test__invert_transform.py new file mode 100644 index 0000000000..adbd8a9e88 --- /dev/null +++ b/tests/beignet/test__invert_transform.py @@ -0,0 +1,16 @@ +import beignet +import torch.testing + + +def test_invert_transform(): + input = torch.randn([32, 3]) + + transform = torch.randn([3, 3]) + + torch.testing.assert_close( + input, + beignet.apply_transform( + beignet.apply_transform(input, transform), + beignet.invert_transform(transform), + ), + )
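
Usage note: a minimal sketch of the new public API, mirroring the added tests;
tensor shapes follow the docstrings (positions are `(..., dimension)`,
transforms are `(dimension, dimension)`); variable names here are illustrative.

    import beignet
    import torch
    import torch.testing

    # positions and an (assumed invertible) affine transform
    input = torch.randn([32, 3])
    transform = torch.randn([3, 3])

    # apply the transform, then undo it with its inverse
    transformed = beignet.apply_transform(input, transform)
    recovered = beignet.apply_transform(
        transformed,
        beignet.invert_transform(transform),
    )

    torch.testing.assert_close(input, recovered)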