From 883b3e55fac5cac50563b6b2a52bd7260b728d4b Mon Sep 17 00:00:00 2001 From: Henry Isaacson Date: Fri, 19 Apr 2024 14:48:27 -0400 Subject: [PATCH 01/15] beignet.func.space --- src/beignet/func/__init__.py | 1 + src/beignet/func/_space.py | 459 +++++++++++++++++++++++++++++++++++ 2 files changed, 460 insertions(+) create mode 100644 src/beignet/func/__init__.py create mode 100644 src/beignet/func/_space.py diff --git a/src/beignet/func/__init__.py b/src/beignet/func/__init__.py new file mode 100644 index 0000000000..cc1c09cfca --- /dev/null +++ b/src/beignet/func/__init__.py @@ -0,0 +1 @@ +from ._space import space diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py new file mode 100644 index 0000000000..72e6ed1fde --- /dev/null +++ b/src/beignet/func/_space.py @@ -0,0 +1,459 @@ +from typing import ( + Callable, + Optional, + Tuple, + TypeVar, +) + +import torch +from torch import Tensor +from torch.autograd import Function + +T = TypeVar("T") + + +def _inverse_transform(transformation: Tensor) -> Tensor: + """ + Calculates the inverse of an affine transformation matrix. + + Parameters + ---------- + transformation : Tensor + The affine transformation matrix to be inverted. + + Returns + ------- + Tensor + The inverse of the given affine transformation matrix. + """ + if transformation.ndim in {0, 1}: + return 1.0 / transformation + + if transformation.ndim == 2: + return torch.linalg.inv(transformation) + + raise ValueError("Unsupported transformation dimensions.") + + +def _transform(transformation: Tensor, position: Tensor) -> Tensor: + """ + Applies an affine transformation to the position vector. + + Parameters + ---------- + position : Tensor + Position, must have the shape `(..., dimension)`. + + transformation : Tensor + The affine transformation matrix, must be a scalar, a vector, or a + matrix with the shape `(dimension, dimension)`. + + Returns + ------- + Tensor + Affine transformed position vector, has the same shape as the + position vector. + """ + if transformation.ndim == 0: + return position * transformation + + indices = [chr(ord("a") + index) for index in range(position.ndim - 1)] + + indices = "".join(indices) + + if transformation.ndim == 1: + return torch.einsum( + f"i,{indices}i->{indices}i", + transformation, + position, + ) + + if transformation.ndim == 2: + return torch.einsum( + f"ij,{indices}j->{indices}i", + transformation, + position, + ) + + raise ValueError("Unsupported transformation dimensions.") + + +def transform(transformation: Tensor, position: Tensor) -> Tensor: + """ + Return affine transformed position. + + Parameters + ---------- + transformation : Tensor + Affine transformation matrix, must have shape + `(dimension, dimension)`. + + position : Tensor + Position, must have shape `(..., dimension)`. + + Returns + ------- + Tensor + Affine transformed position of shape `(..., dimension)`. + """ + + class _Transform(Function): + generate_vmap_rule = True + + @staticmethod + def forward(transformation: Tensor, position: Tensor) -> Tensor: + """ + Return affine transformed position. + + Parameters + ---------- + transformation : Tensor + Affine transformation matrix, must have shape + `(dimension, dimension)`. + + position : Tensor + Position, must have shape `(..., dimension)`. + + Returns + ------- + Tensor + Affine transformed position of shape `(..., dimension)`. + """ + return _transform(transformation, position) + + @staticmethod + def setup_context(ctx, inputs, output): + transformation, position = inputs + + ctx.save_for_backward(transformation, position, output) + + @staticmethod + def jvp( + ctx, + grad_transformation: Tensor, + grad_position: Tensor, + ) -> Tuple[Tensor, Tensor]: + transformation, position, _ = ctx.saved_tensors + + output = _transform(transformation, position) + + grad_output = grad_position + _transform( + grad_transformation, + position, + ) + + return output, grad_output + + @staticmethod + def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor]: + _, _, output = ctx.saved_tensors + + return output, grad_output + + return _Transform.apply(transformation, position) + + +def space( + dimensions: Optional[Tensor] = None, + *, + normalized: bool = True, + parallelepiped: bool = True, + remapped: bool = True, +) -> Tuple[Callable, Callable]: + r"""Define a simulation space. + + This function is fundamental in constructing simulation spaces derived from + subsets of $\mathbb{R}^{D}$ (where $D = 1$, $2$, or $3$) and is + instrumental in setting up simulation environments with specific + characteristics (e.g., periodic boundary conditions). The function returns + a a displacement function and a shift function to compute particle + interactions and movements in space. + + This function supports deformation of the simulation cell, crucial for + certain types of simulations, such as those involving finite deformations + or the computation of elastic constants. + + Parameters + ---------- + dimensions : Optional[Tensor], default=None + Dimensions of the simulation space. Interpretation varies based on the + value of `parallelepiped`. If `parallelepiped` is `True`, must be an + affine transformation, $T$, specified in one of three ways: a cube, + $L$; an orthorhombic unit cell, $[L_{x}, L_{y}, L_{z}]$; or a triclinic + cell, upper triangular matrix. If `parallelepiped` is `False`, must be + the edge lengths. If `None`, the simulation space has free boundary + conditions. + + normalized : bool, default=True + If `True`, positions are stored in the unit cube. Displacements and + shifts are computed in a normalized simulation space and can be + transformed back to real simulation space using the provided affine + transformation matrix. If `False`, positions are expressed and + computations performed directly in the real simulation space. + + parallelepiped : bool, default=True + If `True`, the simulation space is defined as a ${1, 2, 3}$-dimensional + parallelepiped with periodic boundary conditions. If `False`, the space + is defined on a ${1, 2, 3}$-dimensional hypercube. + + remapped : bool, default=True + If `True`, positions and displacements are remapped to stay in the + bounds of the defined simulation space. A rempapped simulation space is + topologically equivalent to a torus, ensuring that particles exiting + one boundary re-enter from the opposite side. This is particularly + relevant for simulation spaces with periodic boundary conditions. + + Returns + ------- + Tuple[Callable[[Tensor, Tensor], Tensor], Callable[[Tensor, Tensor], Tensor]] + A tuple containing two functions: + + 1. The displacement function, $\overrightarrow{d}$, measures the + difference between two points in the simulation space, factoring in + the geometry and boundary conditions. This function is used to + calculate particle interactions and dynamics. + 2. The shift function, $u$, applies a displacement vector to a point + in the space, effectively moving it. This function is used to + update simulated particle positions. + + Examples + -------- + transformation = torch.tensor([10.0]) + + displacement_fn, shift_fn = space( + transformation, + normalized=False, + ) + + normalized_displacement_fn, normalized_shift_fn = space( + transformation, + normalized=True, + ) + + normalized_position = torch.rand([4, 3]) + + position = transformation * normalized_position + + displacement = torch.randn([4, 3]) + + torch.testing.assert_close( + displacement_fn(position[0], position[1]), + normalized_displacement_fn( + normalized_position[0], + normalized_position[1], + ), + ) + """ + if isinstance(dimensions, (int, float)): + dimensions = torch.tensor([dimensions]) + + if dimensions is None: + + def displacement_fn( + input: Tensor, + other: Tensor, + *, + perturbation: Optional[Tensor] = None, + **_, + ) -> Tensor: + if len(input.shape) != 1: + raise ValueError + + if input.shape != other.shape: + raise ValueError + + if perturbation is not None: + return _transform(input - other, perturbation) + + return input - other + + def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: + return input + other + + return displacement_fn, shift_fn + + if parallelepiped: + inverse_transformation = _inverse_transform(dimensions) + + if normalized: + + def displacement_fn( + input: Tensor, + other: Tensor, + *, + perturbation: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + _transformation = dimensions + + _inverse_transformation = inverse_transformation + + if "transformation" in kwargs: + _transformation = kwargs["transformation"] + + if "updated_transformation" in kwargs: + _transformation = kwargs["updated_transformation"] + + if len(input.shape) != 1: + raise ValueError + + if input.shape != other.shape: + raise ValueError + + displacement = transform( + _transformation, + torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, + ) + + if perturbation is not None: + return _transform(displacement, perturbation) + + return displacement + + if remapped: + + def u(input: Tensor, other: Tensor) -> Tensor: + return torch.remainder(input + other, 1.0) + + def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: + _transformation = dimensions + + _inverse_transformation = inverse_transformation + + if "transformation" in kwargs: + _transformation = kwargs["transformation"] + + _inverse_transformation = _inverse_transform(_transformation) + + if "updated_transformation" in kwargs: + _transformation = kwargs["updated_transformation"] + + return u(input, transform(_inverse_transformation, other)) + + return displacement_fn, shift_fn + + def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: + _transformation = dimensions + + _inverse_transformation = inverse_transformation + + if "transformation" in kwargs: + _transformation = kwargs["transformation"] + + _inverse_transformation = _inverse_transform( + _transformation, + ) + + if "updated_transformation" in kwargs: + _transformation = kwargs["updated_transformation"] + + return input + transform(_inverse_transformation, other) + + return displacement_fn, shift_fn + + def displacement_fn( + input: Tensor, + other: Tensor, + *, + perturbation: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + _transformation = dimensions + + _inverse_transformation = inverse_transformation + + if "transformation" in kwargs: + _transformation = kwargs["transformation"] + + _inverse_transformation = _inverse_transform(_transformation) + + if "updated_transformation" in kwargs: + _transformation = kwargs["updated_transformation"] + + input = transform(_inverse_transformation, input) + other = transform(_inverse_transformation, other) + + if len(input.shape) != 1: + raise ValueError + + if input.shape != other.shape: + raise ValueError + + displacement = transform( + _transformation, + torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, + ) + + if perturbation is not None: + return _transform(displacement, perturbation) + + return displacement + + if remapped: + + def u(a: Tensor, b: Tensor) -> Tensor: + return torch.remainder(a + b, 1.0) + + def shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor: + _transformation = dimensions + + _inverse_transformation = inverse_transformation + + if "transformation" in kwargs: + _transformation = kwargs["transformation"] + + _inverse_transformation = _inverse_transform( + _transformation, + ) + + if "updated_transformation" in kwargs: + _transformation = kwargs["updated_transformation"] + + return transform( + _transformation, + u( + transform(_inverse_transformation, a), + transform(_inverse_transformation, b), + ), + ) + + return displacement_fn, shift_fn + + def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: + return input + other + + return displacement_fn, shift_fn + + def displacement_fn( + input: Tensor, + other: Tensor, + *, + perturbation: Tensor | None = None, + **_, + ) -> Tensor: + if len(input.shape) != 1: + raise ValueError + + if input.shape != other.shape: + raise ValueError + + displacement = ( + torch.remainder(input - other + dimensions * 0.5, dimensions) + - dimensions * 0.5 + ) + + if perturbation is not None: + return _transform(displacement, perturbation) + + return displacement + + if remapped: + + def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: + return torch.remainder(input + other, dimensions) + else: + + def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: + return input + other + + return displacement_fn, shift_fn From 49c060497f8a2b96050a97be1b5d954c4e4e91e0 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Thu, 9 May 2024 14:51:05 -0400 Subject: [PATCH 02/15] beignet.func.space --- src/beignet/func/_space.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index 72e6ed1fde..3280fa235f 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -391,10 +391,10 @@ def displacement_fn( if remapped: - def u(a: Tensor, b: Tensor) -> Tensor: - return torch.remainder(a + b, 1.0) + def u(input: Tensor, other: Tensor) -> Tensor: + return torch.remainder(input + other, 1.0) - def shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor: + def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: _transformation = dimensions _inverse_transformation = inverse_transformation @@ -412,8 +412,8 @@ def shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor: return transform( _transformation, u( - transform(_inverse_transformation, a), - transform(_inverse_transformation, b), + transform(_inverse_transformation, input), + transform(_inverse_transformation, other), ), ) @@ -437,15 +437,12 @@ def displacement_fn( if input.shape != other.shape: raise ValueError - displacement = ( - torch.remainder(input - other + dimensions * 0.5, dimensions) - - dimensions * 0.5 - ) + displacement = torch.remainder(input - other + dimensions * 0.5, dimensions) if perturbation is not None: - return _transform(displacement, perturbation) + return _transform(displacement - dimensions * 0.5, perturbation) - return displacement + return displacement - dimensions * 0.5 if remapped: From 9139a4e3a0d4afbdbde791570469e287f6979ca1 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Thu, 9 May 2024 15:09:44 -0400 Subject: [PATCH 03/15] beignet.func.space --- docs/beignet.func.md | 3 ++ src/beignet/func/_space.py | 96 +++++++++++++++++++------------------- 2 files changed, 52 insertions(+), 47 deletions(-) create mode 100644 docs/beignet.func.md diff --git a/docs/beignet.func.md b/docs/beignet.func.md new file mode 100644 index 0000000000..acc54935f7 --- /dev/null +++ b/docs/beignet.func.md @@ -0,0 +1,3 @@ +# beignet.func + +::: beignet.func.space \ No newline at end of file diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index 3280fa235f..8e5bffe755 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -1,9 +1,4 @@ -from typing import ( - Callable, - Optional, - Tuple, - TypeVar, -) +from typing import Callable, TypeVar import torch from torch import Tensor @@ -35,7 +30,7 @@ def _inverse_transform(transformation: Tensor) -> Tensor: raise ValueError("Unsupported transformation dimensions.") -def _transform(transformation: Tensor, position: Tensor) -> Tensor: +def _apply_transform(transformation: Tensor, position: Tensor) -> Tensor: """ Applies an affine transformation to the position vector. @@ -78,7 +73,7 @@ def _transform(transformation: Tensor, position: Tensor) -> Tensor: raise ValueError("Unsupported transformation dimensions.") -def transform(transformation: Tensor, position: Tensor) -> Tensor: +def apply_transform(transformation: Tensor, position: Tensor) -> Tensor: """ Return affine transformed position. @@ -119,7 +114,7 @@ def forward(transformation: Tensor, position: Tensor) -> Tensor: Tensor Affine transformed position of shape `(..., dimension)`. """ - return _transform(transformation, position) + return _apply_transform(transformation, position) @staticmethod def setup_context(ctx, inputs, output): @@ -132,12 +127,12 @@ def jvp( ctx, grad_transformation: Tensor, grad_position: Tensor, - ) -> Tuple[Tensor, Tensor]: + ) -> (Tensor, Tensor): transformation, position, _ = ctx.saved_tensors - output = _transform(transformation, position) + output = _apply_transform(transformation, position) - grad_output = grad_position + _transform( + grad_output = grad_position + _apply_transform( grad_transformation, position, ) @@ -145,7 +140,7 @@ def jvp( return output, grad_output @staticmethod - def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor]: + def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): _, _, output = ctx.saved_tensors return output, grad_output @@ -154,12 +149,12 @@ def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor]: def space( - dimensions: Optional[Tensor] = None, + dimensions: Tensor | None = None, *, normalized: bool = True, parallelepiped: bool = True, remapped: bool = True, -) -> Tuple[Callable, Callable]: +) -> (Callable, Callable): r"""Define a simulation space. This function is fundamental in constructing simulation spaces derived from @@ -175,20 +170,27 @@ def space( Parameters ---------- - dimensions : Optional[Tensor], default=None - Dimensions of the simulation space. Interpretation varies based on the - value of `parallelepiped`. If `parallelepiped` is `True`, must be an - affine transformation, $T$, specified in one of three ways: a cube, - $L$; an orthorhombic unit cell, $[L_{x}, L_{y}, L_{z}]$; or a triclinic - cell, upper triangular matrix. If `parallelepiped` is `False`, must be - the edge lengths. If `None`, the simulation space has free boundary + dimensions : Tensor | None, default=None + Dimensions of the simulation space. + + Interpretation varies based on the value of `parallelepiped`. If + `parallelepiped` is `True`, must be an affine transformation, $T$, + specified in one of three ways: + + 1. a cube, $L$; + 2. an orthorhombic unit cell, $[L_{x}, L_{y}, L_{z}]$; or + 3. a triclinic cell, upper triangular matrix. + + If `parallelepiped` is `False`, must be the edge lengths. + + If `dimensions` is `None`, the simulation space has free boundary conditions. normalized : bool, default=True - If `True`, positions are stored in the unit cube. Displacements and - shifts are computed in a normalized simulation space and can be - transformed back to real simulation space using the provided affine - transformation matrix. If `False`, positions are expressed and + If `normalized` is `True`, positions are stored in the unit cube. + Displacements and shifts are computed in a normalized simulation space + and can be transformed back to real simulation space using the affine + transformation. If `normalized` is `False`, positions are expressed and computations performed directly in the real simulation space. parallelepiped : bool, default=True @@ -200,18 +202,18 @@ def space( If `True`, positions and displacements are remapped to stay in the bounds of the defined simulation space. A rempapped simulation space is topologically equivalent to a torus, ensuring that particles exiting - one boundary re-enter from the opposite side. This is particularly - relevant for simulation spaces with periodic boundary conditions. + one boundary re-enter from the opposite side. Returns ------- - Tuple[Callable[[Tensor, Tensor], Tensor], Callable[[Tensor, Tensor], Tensor]] - A tuple containing two functions: + (Callable[[Tensor, Tensor], Tensor], Callable[[Tensor, Tensor], Tensor]) + A pair of functions: - 1. The displacement function, $\overrightarrow{d}$, measures the + 1. The displacement function, $\vec{d}$, measures the difference between two points in the simulation space, factoring in the geometry and boundary conditions. This function is used to calculate particle interactions and dynamics. + 2. The shift function, $u$, applies a displacement vector to a point in the space, effectively moving it. This function is used to update simulated particle positions. @@ -253,7 +255,7 @@ def displacement_fn( input: Tensor, other: Tensor, *, - perturbation: Optional[Tensor] = None, + perturbation: Tensor | None = None, **_, ) -> Tensor: if len(input.shape) != 1: @@ -263,7 +265,7 @@ def displacement_fn( raise ValueError if perturbation is not None: - return _transform(input - other, perturbation) + return _apply_transform(input - other, perturbation) return input - other @@ -281,7 +283,7 @@ def displacement_fn( input: Tensor, other: Tensor, *, - perturbation: Optional[Tensor] = None, + perturbation: Tensor | None = None, **kwargs, ) -> Tensor: _transformation = dimensions @@ -300,13 +302,13 @@ def displacement_fn( if input.shape != other.shape: raise ValueError - displacement = transform( + displacement = apply_transform( _transformation, torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, ) if perturbation is not None: - return _transform(displacement, perturbation) + return _apply_transform(displacement, perturbation) return displacement @@ -328,7 +330,7 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - return u(input, transform(_inverse_transformation, other)) + return u(input, apply_transform(_inverse_transformation, other)) return displacement_fn, shift_fn @@ -347,7 +349,7 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - return input + transform(_inverse_transformation, other) + return input + apply_transform(_inverse_transformation, other) return displacement_fn, shift_fn @@ -355,7 +357,7 @@ def displacement_fn( input: Tensor, other: Tensor, *, - perturbation: Optional[Tensor] = None, + perturbation: Tensor | None = None, **kwargs, ) -> Tensor: _transformation = dimensions @@ -370,8 +372,8 @@ def displacement_fn( if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - input = transform(_inverse_transformation, input) - other = transform(_inverse_transformation, other) + input = apply_transform(_inverse_transformation, input) + other = apply_transform(_inverse_transformation, other) if len(input.shape) != 1: raise ValueError @@ -379,13 +381,13 @@ def displacement_fn( if input.shape != other.shape: raise ValueError - displacement = transform( + displacement = apply_transform( _transformation, torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, ) if perturbation is not None: - return _transform(displacement, perturbation) + return _apply_transform(displacement, perturbation) return displacement @@ -409,11 +411,11 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - return transform( + return apply_transform( _transformation, u( - transform(_inverse_transformation, input), - transform(_inverse_transformation, other), + apply_transform(_inverse_transformation, input), + apply_transform(_inverse_transformation, other), ), ) @@ -440,7 +442,7 @@ def displacement_fn( displacement = torch.remainder(input - other + dimensions * 0.5, dimensions) if perturbation is not None: - return _transform(displacement - dimensions * 0.5, perturbation) + return _apply_transform(displacement - dimensions * 0.5, perturbation) return displacement - dimensions * 0.5 From d6ef8977924a79b5293e90400fdbd6aea9e4a777 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Fri, 10 May 2024 11:11:18 -0400 Subject: [PATCH 04/15] beignet.func.space --- src/beignet/__init__.py | 6 +- src/beignet/_apply_transform.py | 114 +++++++++++++++ src/beignet/_invert_transform.py | 25 ++++ src/beignet/func/_space.py | 177 +++--------------------- tests/beignet/func/test__space.py | 0 tests/beignet/test__apply_transform.py | 21 +++ tests/beignet/test__invert_transform.py | 16 +++ 7 files changed, 200 insertions(+), 159 deletions(-) create mode 100644 src/beignet/_apply_transform.py create mode 100644 src/beignet/_invert_transform.py create mode 100644 tests/beignet/func/test__space.py create mode 100644 tests/beignet/test__apply_transform.py create mode 100644 tests/beignet/test__invert_transform.py diff --git a/src/beignet/__init__.py b/src/beignet/__init__.py index f91d4adc64..2b6404f460 100644 --- a/src/beignet/__init__.py +++ b/src/beignet/__init__.py @@ -12,6 +12,7 @@ ) from ._apply_rotation_matrix import apply_rotation_matrix from ._apply_rotation_vector import apply_rotation_vector +from ._apply_transform import apply_transform from ._compose_euler_angle import compose_euler_angle from ._compose_quaternion import compose_quaternion from ._compose_rotation_matrix import compose_rotation_matrix @@ -28,6 +29,7 @@ from ._invert_quaternion import invert_quaternion from ._invert_rotation_matrix import invert_rotation_matrix from ._invert_rotation_vector import invert_rotation_vector +from ._invert_transform import invert_transform from ._quaternion_identity import quaternion_identity from ._quaternion_magnitude import quaternion_magnitude from ._quaternion_mean import quaternion_mean @@ -72,6 +74,7 @@ "apply_quaternion", "apply_rotation_matrix", "apply_rotation_vector", + "apply_transform", "compose_euler_angle", "compose_quaternion", "compose_rotation_matrix", @@ -86,9 +89,11 @@ "invert_quaternion", "invert_rotation_matrix", "invert_rotation_vector", + "invert_transform", "quaternion_identity", "quaternion_magnitude", "quaternion_mean", + "quaternion_slerp", "quaternion_to_euler_angle", "quaternion_to_rotation_matrix", "quaternion_to_rotation_vector", @@ -108,6 +113,5 @@ "rotation_vector_to_euler_angle", "rotation_vector_to_quaternion", "rotation_vector_to_rotation_matrix", - "quaternion_slerp", "translation_identity", ] diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py new file mode 100644 index 0000000000..6b0defbc40 --- /dev/null +++ b/src/beignet/_apply_transform.py @@ -0,0 +1,114 @@ +import torch +from torch import Tensor +from torch.autograd import Function + + +def _apply_transform(input: Tensor, transform: Tensor) -> Tensor: + """ + Applies an affine transformation to the position vector. + + Parameters + ---------- + input : Tensor + Position, must have the shape `(..., dimension)`. + + transform : Tensor + The affine transformation matrix, must be a scalar, a vector, or a + matrix with the shape `(dimension, dimension)`. + + Returns + ------- + Tensor + Affine transformed position vector, has the same shape as the + position vector. + """ + if transform.ndim == 0: + return input * transform + + indices = [chr(ord("a") + index) for index in range(input.ndim - 1)] + + indices = "".join(indices) + + if transform.ndim == 1: + return torch.einsum( + f"i,{indices}i->{indices}i", + transform, + input, + ) + + if transform.ndim == 2: + return torch.einsum( + f"ij,{indices}j->{indices}i", + transform, + input, + ) + + raise ValueError + + +class _ApplyTransform(Function): + generate_vmap_rule = True + + @staticmethod + def forward(transform: Tensor, position: Tensor) -> Tensor: + """ + Return affine transformed position. + + Parameters + ---------- + transform : Tensor + Affine transformation matrix, must have shape + `(dimension, dimension)`. + + position : Tensor + Position, must have shape `(..., dimension)`. + + Returns + ------- + Tensor + Affine transformed position of shape `(..., dimension)`. + """ + return _apply_transform(position, transform) + + @staticmethod + def setup_context(ctx, inputs, output): + transformation, position = inputs + + ctx.save_for_backward(transformation, position, output) + + @staticmethod + def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): + transformation, position, _ = ctx.saved_tensors + + output = _apply_transform(position, transformation) + + grad_output = grad_position + _apply_transform(position, grad_transform) + + return output, grad_output + + @staticmethod + def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): + _, _, output = ctx.saved_tensors + + return output, grad_output + + +def apply_transform(input: Tensor, transform: Tensor) -> Tensor: + """ + Return affine transformed position. + + Parameters + ---------- + input : Tensor + Position, must have shape `(..., dimension)`. + + transform : Tensor + Affine transformation matrix, must have shape + `(dimension, dimension)`. + + Returns + ------- + Tensor + Affine transformed position of shape `(..., dimension)`. + """ + return _ApplyTransform.apply(transform, input) diff --git a/src/beignet/_invert_transform.py b/src/beignet/_invert_transform.py new file mode 100644 index 0000000000..f1348fcb66 --- /dev/null +++ b/src/beignet/_invert_transform.py @@ -0,0 +1,25 @@ +import torch +from torch import Tensor + + +def invert_transform(transform: Tensor) -> Tensor: + """ + Calculates the inverse of an affine transformation matrix. + + Parameters + ---------- + transform : Tensor + The affine transformation matrix to be inverted. + + Returns + ------- + Tensor + The inverse of the given affine transformation matrix. + """ + if transform.ndim in {0, 1}: + return 1.0 / transform + + if transform.ndim == 2: + return torch.linalg.inv(transform) + + raise ValueError diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index 8e5bffe755..2739d685fa 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -2,150 +2,11 @@ import torch from torch import Tensor -from torch.autograd import Function -T = TypeVar("T") - - -def _inverse_transform(transformation: Tensor) -> Tensor: - """ - Calculates the inverse of an affine transformation matrix. - - Parameters - ---------- - transformation : Tensor - The affine transformation matrix to be inverted. - - Returns - ------- - Tensor - The inverse of the given affine transformation matrix. - """ - if transformation.ndim in {0, 1}: - return 1.0 / transformation - - if transformation.ndim == 2: - return torch.linalg.inv(transformation) - - raise ValueError("Unsupported transformation dimensions.") - - -def _apply_transform(transformation: Tensor, position: Tensor) -> Tensor: - """ - Applies an affine transformation to the position vector. - - Parameters - ---------- - position : Tensor - Position, must have the shape `(..., dimension)`. - - transformation : Tensor - The affine transformation matrix, must be a scalar, a vector, or a - matrix with the shape `(dimension, dimension)`. - - Returns - ------- - Tensor - Affine transformed position vector, has the same shape as the - position vector. - """ - if transformation.ndim == 0: - return position * transformation - - indices = [chr(ord("a") + index) for index in range(position.ndim - 1)] - - indices = "".join(indices) - - if transformation.ndim == 1: - return torch.einsum( - f"i,{indices}i->{indices}i", - transformation, - position, - ) - - if transformation.ndim == 2: - return torch.einsum( - f"ij,{indices}j->{indices}i", - transformation, - position, - ) - - raise ValueError("Unsupported transformation dimensions.") - - -def apply_transform(transformation: Tensor, position: Tensor) -> Tensor: - """ - Return affine transformed position. - - Parameters - ---------- - transformation : Tensor - Affine transformation matrix, must have shape - `(dimension, dimension)`. - - position : Tensor - Position, must have shape `(..., dimension)`. - - Returns - ------- - Tensor - Affine transformed position of shape `(..., dimension)`. - """ - - class _Transform(Function): - generate_vmap_rule = True - - @staticmethod - def forward(transformation: Tensor, position: Tensor) -> Tensor: - """ - Return affine transformed position. - - Parameters - ---------- - transformation : Tensor - Affine transformation matrix, must have shape - `(dimension, dimension)`. - - position : Tensor - Position, must have shape `(..., dimension)`. - - Returns - ------- - Tensor - Affine transformed position of shape `(..., dimension)`. - """ - return _apply_transform(transformation, position) - - @staticmethod - def setup_context(ctx, inputs, output): - transformation, position = inputs - - ctx.save_for_backward(transformation, position, output) - - @staticmethod - def jvp( - ctx, - grad_transformation: Tensor, - grad_position: Tensor, - ) -> (Tensor, Tensor): - transformation, position, _ = ctx.saved_tensors - - output = _apply_transform(transformation, position) - - grad_output = grad_position + _apply_transform( - grad_transformation, - position, - ) - - return output, grad_output - - @staticmethod - def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): - _, _, output = ctx.saved_tensors - - return output, grad_output +from beignet._apply_transform import _apply_transform, apply_transform +from beignet._invert_transform import invert_transform - return _Transform.apply(transformation, position) +T = TypeVar("T") def space( @@ -265,7 +126,7 @@ def displacement_fn( raise ValueError if perturbation is not None: - return _apply_transform(input - other, perturbation) + return _apply_transform(perturbation, input - other) return input - other @@ -275,7 +136,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: return displacement_fn, shift_fn if parallelepiped: - inverse_transformation = _inverse_transform(dimensions) + inverse_transformation = invert_transform(dimensions) if normalized: @@ -303,12 +164,12 @@ def displacement_fn( raise ValueError displacement = apply_transform( - _transformation, torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, + _transformation, ) if perturbation is not None: - return _apply_transform(displacement, perturbation) + return _apply_transform(perturbation, displacement) return displacement @@ -325,12 +186,12 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform(_transformation) + _inverse_transformation = invert_transform(_transformation) if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - return u(input, apply_transform(_inverse_transformation, other)) + return u(input, apply_transform(other, _inverse_transformation)) return displacement_fn, shift_fn @@ -342,14 +203,14 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform( + _inverse_transformation = invert_transform( _transformation, ) if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - return input + apply_transform(_inverse_transformation, other) + return input + apply_transform(other, _inverse_transformation) return displacement_fn, shift_fn @@ -367,13 +228,13 @@ def displacement_fn( if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform(_transformation) + _inverse_transformation = invert_transform(_transformation) if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] - input = apply_transform(_inverse_transformation, input) - other = apply_transform(_inverse_transformation, other) + input = apply_transform(input, _inverse_transformation) + other = apply_transform(other, _inverse_transformation) if len(input.shape) != 1: raise ValueError @@ -382,12 +243,12 @@ def displacement_fn( raise ValueError displacement = apply_transform( - _transformation, torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, + _transformation, ) if perturbation is not None: - return _apply_transform(displacement, perturbation) + return _apply_transform(perturbation, displacement) return displacement @@ -404,7 +265,7 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transformation" in kwargs: _transformation = kwargs["transformation"] - _inverse_transformation = _inverse_transform( + _inverse_transformation = invert_transform( _transformation, ) @@ -412,11 +273,11 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: _transformation = kwargs["updated_transformation"] return apply_transform( - _transformation, u( apply_transform(_inverse_transformation, input), apply_transform(_inverse_transformation, other), ), + _transformation, ) return displacement_fn, shift_fn @@ -442,7 +303,7 @@ def displacement_fn( displacement = torch.remainder(input - other + dimensions * 0.5, dimensions) if perturbation is not None: - return _apply_transform(displacement - dimensions * 0.5, perturbation) + return _apply_transform(perturbation, displacement - dimensions * 0.5) return displacement - dimensions * 0.5 diff --git a/tests/beignet/func/test__space.py b/tests/beignet/func/test__space.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/beignet/test__apply_transform.py b/tests/beignet/test__apply_transform.py new file mode 100644 index 0000000000..7811aebdba --- /dev/null +++ b/tests/beignet/test__apply_transform.py @@ -0,0 +1,21 @@ +import torch +import torch.func +from beignet import apply_transform +from torch import Tensor + + +def test_apply_transform(): + input = torch.randn([32, 3]) + + transform = torch.randn([3, 3]) + + def f(r: Tensor) -> Tensor: + return torch.sum(r**2) + + def g(r: Tensor, t: Tensor) -> Tensor: + return torch.sum(apply_transform(r, t) ** 2) + + torch.testing.assert_close( + torch.func.grad(f)(apply_transform(input, transform)), + torch.func.grad(g, 0)(input, transform), + ) diff --git a/tests/beignet/test__invert_transform.py b/tests/beignet/test__invert_transform.py new file mode 100644 index 0000000000..adbd8a9e88 --- /dev/null +++ b/tests/beignet/test__invert_transform.py @@ -0,0 +1,16 @@ +import beignet +import torch.testing + + +def test_invert_transform(): + input = torch.randn([32, 3]) + + transform = torch.randn([3, 3]) + + torch.testing.assert_close( + input, + beignet.apply_transform( + beignet.apply_transform(input, transform), + beignet.invert_transform(transform), + ), + ) From cdf85ad64a25359d0a9b1dd363f46affef04fac1 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Fri, 10 May 2024 14:59:36 -0400 Subject: [PATCH 05/15] beignet.func.space --- tests/beignet/func/test__space.py | 98 +++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/tests/beignet/func/test__space.py b/tests/beignet/func/test__space.py index e69de29bb2..9ba118aa95 100644 --- a/tests/beignet/func/test__space.py +++ b/tests/beignet/func/test__space.py @@ -0,0 +1,98 @@ +from typing import Callable + +import beignet.func +import hypothesis +import hypothesis.strategies +import torch.testing + + +def map_product(fn: Callable) -> Callable: + return torch.vmap( + torch.vmap( + fn, + in_dims=(0, None), + out_dims=0, + ), + in_dims=(None, 0), + out_dims=0, + ) + + +@hypothesis.strategies.composite +def _strategy(function): + dtype = function( + hypothesis.strategies.sampled_from( + [ + torch.float32, + torch.float64, + ], + ), + ) + + maximum_size = function( + hypothesis.strategies.floats( + min_value=1.0, + max_value=8.0, + ), + ) + + particles = function( + hypothesis.strategies.integers( + min_value=16, + max_value=32, + ), + ) + + spatial_dimension = function( + hypothesis.strategies.integers( + min_value=1, + max_value=3, + ), + ) + + return ( + dtype, + torch.rand([particles, spatial_dimension], dtype=dtype), + particles, + torch.rand([spatial_dimension], dtype=dtype) * maximum_size, + spatial_dimension, + ) + + +@hypothesis.given(_strategy()) +@hypothesis.settings(deadline=None) +def test_space(data): + dtype, input, particles, size, spatial_dimension = data + + displacement_fn, shift_fn = beignet.func.space(size, parallelepiped=False) + + ( + parallelepiped_displacement_fn, + parallelepiped_shift_fn, + ) = beignet.func.space( + torch.diag(size), + ) + + standardized_input = input * size + + torch.testing.assert_close( + map_product( + displacement_fn, + )( + standardized_input, + standardized_input, + ), + map_product( + parallelepiped_displacement_fn, + )( + input, + input, + ), + ) + + displacement = torch.randn([particles, spatial_dimension], dtype=dtype) + + torch.testing.assert_close( + shift_fn(standardized_input, displacement), + parallelepiped_shift_fn(input, displacement) * size, + ) From 1c41099b810bed8199e18461b525182ea2a7c675 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Fri, 10 May 2024 15:46:23 -0400 Subject: [PATCH 06/15] beignet.func.space --- src/beignet/func/_space.py | 20 +++++------ tests/beignet/func/test__space.py | 58 +++++++++++++++++++++++++++---- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index 2739d685fa..f764d5f413 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -151,8 +151,8 @@ def displacement_fn( _inverse_transformation = inverse_transformation - if "transformation" in kwargs: - _transformation = kwargs["transformation"] + if "transform" in kwargs: + _transformation = kwargs["transform"] if "updated_transformation" in kwargs: _transformation = kwargs["updated_transformation"] @@ -183,8 +183,8 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: _inverse_transformation = inverse_transformation - if "transformation" in kwargs: - _transformation = kwargs["transformation"] + if "transform" in kwargs: + _transformation = kwargs["transform"] _inverse_transformation = invert_transform(_transformation) @@ -200,8 +200,8 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: _inverse_transformation = inverse_transformation - if "transformation" in kwargs: - _transformation = kwargs["transformation"] + if "transform" in kwargs: + _transformation = kwargs["transform"] _inverse_transformation = invert_transform( _transformation, @@ -225,8 +225,8 @@ def displacement_fn( _inverse_transformation = inverse_transformation - if "transformation" in kwargs: - _transformation = kwargs["transformation"] + if "transform" in kwargs: + _transformation = kwargs["transform"] _inverse_transformation = invert_transform(_transformation) @@ -262,8 +262,8 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: _inverse_transformation = inverse_transformation - if "transformation" in kwargs: - _transformation = kwargs["transformation"] + if "transform" in kwargs: + _transformation = kwargs["transform"] _inverse_transformation = invert_transform( _transformation, diff --git a/tests/beignet/func/test__space.py b/tests/beignet/func/test__space.py index 9ba118aa95..e91eaf6b44 100644 --- a/tests/beignet/func/test__space.py +++ b/tests/beignet/func/test__space.py @@ -1,9 +1,11 @@ +import functools from typing import Callable import beignet.func import hypothesis import hypothesis.strategies import torch.testing +from torch import Tensor def map_product(fn: Callable) -> Callable: @@ -75,16 +77,18 @@ def test_space(data): standardized_input = input * size + displacement_fn = map_product(displacement_fn) + + parallelepiped_displacement_fn = map_product( + parallelepiped_displacement_fn, + ) + torch.testing.assert_close( - map_product( - displacement_fn, - )( + displacement_fn( standardized_input, standardized_input, ), - map_product( - parallelepiped_displacement_fn, - )( + parallelepiped_displacement_fn( input, input, ), @@ -96,3 +100,45 @@ def test_space(data): shift_fn(standardized_input, displacement), parallelepiped_shift_fn(input, displacement) * size, ) + + def f(input: Tensor) -> Tensor: + return torch.sum(displacement_fn(input, input) ** 2) + + def g(input: Tensor) -> Tensor: + return torch.sum(parallelepiped_displacement_fn(input, input) ** 2) + + torch.testing.assert_close( + torch.func.grad(f)(standardized_input), + torch.func.grad(g)(input), + rtol=0.0001, + atol=0.0001, + ) + + size_a = 10.0 * torch.rand([]) + size_b = 10.0 * torch.rand([], dtype=dtype) + + transform_a = 0.5 * torch.randn([spatial_dimension, spatial_dimension]) + transform_b = 0.5 * torch.randn([spatial_dimension, spatial_dimension], dtype=dtype) + + transform_a = size_a * (torch.eye(spatial_dimension) + transform_a) + transform_b = size_b * (torch.eye(spatial_dimension) + transform_b) + + displacement_fn_a, shift_fn_a = beignet.func.space(transform_a) + displacement_fn_b, shift_fn_b = beignet.func.space(transform_b) + + displacement = torch.randn([particles, spatial_dimension], dtype=dtype) + + torch.testing.assert_close( + map_product( + functools.partial( + displacement_fn_a, + transform=transform_b, + ), + )(input, input), + map_product(displacement_fn_b)(input, input), + ) + + torch.testing.assert_close( + shift_fn_a(input, displacement, transform=transform_b), + shift_fn_b(input, displacement), + ) From 30a4ea8ca761ad404f56f5d4367f56ac66e7a2f8 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Fri, 10 May 2024 16:21:16 -0400 Subject: [PATCH 07/15] beignet.func.space --- src/beignet/func/_space.py | 72 ++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index f764d5f413..52e92658e6 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -136,7 +136,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: return displacement_fn, shift_fn if parallelepiped: - inverse_transformation = invert_transform(dimensions) + inverted_transform = invert_transform(dimensions) if normalized: @@ -147,15 +147,15 @@ def displacement_fn( perturbation: Tensor | None = None, **kwargs, ) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] if len(input.shape) != 1: raise ValueError @@ -165,7 +165,7 @@ def displacement_fn( displacement = apply_transform( torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, - _transformation, + _transform, ) if perturbation is not None: @@ -179,38 +179,36 @@ def u(input: Tensor, other: Tensor) -> Tensor: return torch.remainder(input + other, 1.0) def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform(_transformation) + _inverted_transform = invert_transform(_transform) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] - return u(input, apply_transform(other, _inverse_transformation)) + return u(input, apply_transform(other, _inverted_transform)) return displacement_fn, shift_fn def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform( - _transformation, - ) + _inverted_transform = invert_transform(_transform) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] - return input + apply_transform(other, _inverse_transformation) + return input + apply_transform(other, _inverted_transform) return displacement_fn, shift_fn @@ -221,20 +219,20 @@ def displacement_fn( perturbation: Tensor | None = None, **kwargs, ) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform(_transformation) + _inverted_transform = invert_transform(_transform) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] - input = apply_transform(input, _inverse_transformation) - other = apply_transform(other, _inverse_transformation) + input = apply_transform(input, _inverted_transform) + other = apply_transform(other, _inverted_transform) if len(input.shape) != 1: raise ValueError @@ -244,7 +242,7 @@ def displacement_fn( displacement = apply_transform( torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, - _transformation, + _transform, ) if perturbation is not None: @@ -258,26 +256,26 @@ def u(input: Tensor, other: Tensor) -> Tensor: return torch.remainder(input + other, 1.0) def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform( - _transformation, + _inverted_transform = invert_transform( + _transform, ) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] return apply_transform( u( - apply_transform(_inverse_transformation, input), - apply_transform(_inverse_transformation, other), + apply_transform(_inverted_transform, input), + apply_transform(_inverted_transform, other), ), - _transformation, + _transform, ) return displacement_fn, shift_fn From 0cd8676e05b8fd6681731daa47b3a9e03129bbdd Mon Sep 17 00:00:00 2001 From: Henry Isaacson Date: Fri, 24 May 2024 15:13:40 -0400 Subject: [PATCH 08/15] apply_transform --- src/beignet/_apply_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py index 6b0defbc40..c639219f44 100644 --- a/src/beignet/_apply_transform.py +++ b/src/beignet/_apply_transform.py @@ -31,14 +31,14 @@ def _apply_transform(input: Tensor, transform: Tensor) -> Tensor: if transform.ndim == 1: return torch.einsum( - f"i,{indices}i->{indices}i", + "i,...i->...i", transform, input, ) if transform.ndim == 2: return torch.einsum( - f"ij,{indices}j->{indices}i", + f"ij,...j->...i", transform, input, ) From 131c6783045cdc0147476558d767be61f1e3fa9d Mon Sep 17 00:00:00 2001 From: Henry Isaacson Date: Thu, 30 May 2024 10:55:42 -0400 Subject: [PATCH 09/15] dynamo import --- src/beignet/_apply_transform.py | 11 +++++++++++ tests/beignet/test__apply_transform.py | 3 ++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py index c639219f44..ddbf40c021 100644 --- a/src/beignet/_apply_transform.py +++ b/src/beignet/_apply_transform.py @@ -3,6 +3,13 @@ from torch.autograd import Function +# Only import torch._dynamo when necessary https://github.com/pytorch/pytorch/issues/110549 +def conditional_import_torch_dynamo(): + import torch._dynamo + + return torch._dynamo + + def _apply_transform(input: Tensor, transform: Tensor) -> Tensor: """ Applies an affine transformation to the position vector. @@ -80,6 +87,8 @@ def setup_context(ctx, inputs, output): def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): transformation, position, _ = ctx.saved_tensors + _dynamo = conditional_import_torch_dynamo() + output = _apply_transform(position, transformation) grad_output = grad_position + _apply_transform(position, grad_transform) @@ -90,6 +99,8 @@ def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): _, _, output = ctx.saved_tensors + _dynamo = conditional_import_torch_dynamo() + return output, grad_output diff --git a/tests/beignet/test__apply_transform.py b/tests/beignet/test__apply_transform.py index 7811aebdba..357b4d91be 100644 --- a/tests/beignet/test__apply_transform.py +++ b/tests/beignet/test__apply_transform.py @@ -1,10 +1,11 @@ import torch import torch.func -from beignet import apply_transform from torch import Tensor def test_apply_transform(): + from beignet._apply_transform import apply_transform + input = torch.randn([32, 3]) transform = torch.randn([3, 3]) From 865c9e3cb2dd6d7919c88c2bae2ba9f2f0711152 Mon Sep 17 00:00:00 2001 From: Henry Isaacson Date: Thu, 30 May 2024 11:09:41 -0400 Subject: [PATCH 10/15] pin torch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 16bb1154bc..9d6cd642a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires = [ authors = [{ email = "allen.goodman@icloud.com", name = "Allen Goodman" }] dependencies = [ "pooch", - "torch", + "torch==2.2.2", "tqdm", ] dynamic = ["version"] From 08199e933e8effcf3085fb75e0545e980ddd4784 Mon Sep 17 00:00:00 2001 From: Henry Isaacson Date: Thu, 30 May 2024 11:21:03 -0400 Subject: [PATCH 11/15] revert dynamo changes --- src/beignet/_apply_transform.py | 11 ----------- tests/beignet/test__apply_transform.py | 3 +-- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py index ddbf40c021..c639219f44 100644 --- a/src/beignet/_apply_transform.py +++ b/src/beignet/_apply_transform.py @@ -3,13 +3,6 @@ from torch.autograd import Function -# Only import torch._dynamo when necessary https://github.com/pytorch/pytorch/issues/110549 -def conditional_import_torch_dynamo(): - import torch._dynamo - - return torch._dynamo - - def _apply_transform(input: Tensor, transform: Tensor) -> Tensor: """ Applies an affine transformation to the position vector. @@ -87,8 +80,6 @@ def setup_context(ctx, inputs, output): def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): transformation, position, _ = ctx.saved_tensors - _dynamo = conditional_import_torch_dynamo() - output = _apply_transform(position, transformation) grad_output = grad_position + _apply_transform(position, grad_transform) @@ -99,8 +90,6 @@ def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): _, _, output = ctx.saved_tensors - _dynamo = conditional_import_torch_dynamo() - return output, grad_output diff --git a/tests/beignet/test__apply_transform.py b/tests/beignet/test__apply_transform.py index 357b4d91be..7811aebdba 100644 --- a/tests/beignet/test__apply_transform.py +++ b/tests/beignet/test__apply_transform.py @@ -1,11 +1,10 @@ import torch import torch.func +from beignet import apply_transform from torch import Tensor def test_apply_transform(): - from beignet._apply_transform import apply_transform - input = torch.randn([32, 3]) transform = torch.randn([3, 3]) From fe35ab8f8193208d8fcd7e741a53d3ab277d364d Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Mon, 10 Jun 2024 14:45:10 -0400 Subject: [PATCH 12/15] cleanup --- src/beignet/_apply_transform.py | 31 +++----- src/beignet/func/_space.py | 132 ++++++++++++++++++++++++-------- 2 files changed, 110 insertions(+), 53 deletions(-) diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py index c639219f44..8b5b0193a1 100644 --- a/src/beignet/_apply_transform.py +++ b/src/beignet/_apply_transform.py @@ -22,28 +22,15 @@ def _apply_transform(input: Tensor, transform: Tensor) -> Tensor: Affine transformed position vector, has the same shape as the position vector. """ - if transform.ndim == 0: - return input * transform - - indices = [chr(ord("a") + index) for index in range(input.ndim - 1)] - - indices = "".join(indices) - - if transform.ndim == 1: - return torch.einsum( - "i,...i->...i", - transform, - input, - ) - - if transform.ndim == 2: - return torch.einsum( - f"ij,...j->...i", - transform, - input, - ) - - raise ValueError + match transform.ndim: + case 0: + return input * transform + case 1: + return torch.einsum("i,...i->...i", transform, input) + case 2: + return torch.einsum("ij,...j->...i", transform, input) + case _: + raise ValueError class _ApplyTransform(Function): diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index 52e92658e6..4b552231d8 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -3,8 +3,7 @@ import torch from torch import Tensor -from beignet._apply_transform import _apply_transform, apply_transform -from beignet._invert_transform import invert_transform +import beignet T = TypeVar("T") @@ -126,7 +125,25 @@ def displacement_fn( raise ValueError if perturbation is not None: - return _apply_transform(perturbation, input - other) + transform = input - other + + match transform.ndim: + case 0: + return perturbation * transform + case 1: + return torch.einsum( + "i,...i->...i", + transform, + perturbation, + ) + case 2: + return torch.einsum( + "ij,...j->...i", + transform, + perturbation, + ) + case _: + raise ValueError return input - other @@ -136,7 +153,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: return displacement_fn, shift_fn if parallelepiped: - inverted_transform = invert_transform(dimensions) + inverted_transform = beignet.invert_transform(dimensions) if normalized: @@ -154,8 +171,8 @@ def displacement_fn( if "transform" in kwargs: _transform = kwargs["transform"] - if "updated_transformation" in kwargs: - _transform = kwargs["updated_transformation"] + if "updated_transform" in kwargs: + _transform = kwargs["updated_transform"] if len(input.shape) != 1: raise ValueError @@ -163,13 +180,29 @@ def displacement_fn( if input.shape != other.shape: raise ValueError - displacement = apply_transform( + displacement = beignet.apply_transform( torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, _transform, ) if perturbation is not None: - return _apply_transform(perturbation, displacement) + match displacement.ndim: + case 0: + return perturbation * displacement + case 1: + return torch.einsum( + "i,...i->...i", + displacement, + perturbation, + ) + case 2: + return torch.einsum( + "ij,...j->...i", + displacement, + perturbation, + ) + case _: + raise ValueError return displacement @@ -186,12 +219,12 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transform" in kwargs: _transform = kwargs["transform"] - _inverted_transform = invert_transform(_transform) + _inverted_transform = beignet.invert_transform(_transform) - if "updated_transformation" in kwargs: - _transform = kwargs["updated_transformation"] + if "updated_transform" in kwargs: + _transform = kwargs["updated_transform"] - return u(input, apply_transform(other, _inverted_transform)) + return u(input, beignet.apply_transform(other, _inverted_transform)) return displacement_fn, shift_fn @@ -203,12 +236,12 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transform" in kwargs: _transform = kwargs["transform"] - _inverted_transform = invert_transform(_transform) + _inverted_transform = beignet.invert_transform(_transform) - if "updated_transformation" in kwargs: - _transform = kwargs["updated_transformation"] + if "updated_transform" in kwargs: + _transform = kwargs["updated_transform"] - return input + apply_transform(other, _inverted_transform) + return input + beignet.apply_transform(other, _inverted_transform) return displacement_fn, shift_fn @@ -226,13 +259,13 @@ def displacement_fn( if "transform" in kwargs: _transform = kwargs["transform"] - _inverted_transform = invert_transform(_transform) + _inverted_transform = beignet.invert_transform(_transform) - if "updated_transformation" in kwargs: - _transform = kwargs["updated_transformation"] + if "updated_transform" in kwargs: + _transform = kwargs["updated_transform"] - input = apply_transform(input, _inverted_transform) - other = apply_transform(other, _inverted_transform) + input = beignet.apply_transform(input, _inverted_transform) + other = beignet.apply_transform(other, _inverted_transform) if len(input.shape) != 1: raise ValueError @@ -240,13 +273,29 @@ def displacement_fn( if input.shape != other.shape: raise ValueError - displacement = apply_transform( + displacement = beignet.apply_transform( torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, _transform, ) if perturbation is not None: - return _apply_transform(perturbation, displacement) + match displacement.ndim: + case 0: + return perturbation * displacement + case 1: + return torch.einsum( + "i,...i->...i", + displacement, + perturbation, + ) + case 2: + return torch.einsum( + "ij,...j->...i", + displacement, + perturbation, + ) + case _: + raise ValueError return displacement @@ -263,17 +312,17 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: if "transform" in kwargs: _transform = kwargs["transform"] - _inverted_transform = invert_transform( + _inverted_transform = beignet.invert_transform( _transform, ) - if "updated_transformation" in kwargs: - _transform = kwargs["updated_transformation"] + if "updated_transform" in kwargs: + _transform = kwargs["updated_transform"] - return apply_transform( + return beignet.apply_transform( u( - apply_transform(_inverted_transform, input), - apply_transform(_inverted_transform, other), + beignet.apply_transform(_inverted_transform, input), + beignet.apply_transform(_inverted_transform, other), ), _transform, ) @@ -298,10 +347,31 @@ def displacement_fn( if input.shape != other.shape: raise ValueError - displacement = torch.remainder(input - other + dimensions * 0.5, dimensions) + displacement = torch.remainder( + input - other + dimensions * 0.5, + dimensions, + ) if perturbation is not None: - return _apply_transform(perturbation, displacement - dimensions * 0.5) + transform = displacement - dimensions * 0.5 + + match transform.ndim: + case 0: + return perturbation * transform + case 1: + return torch.einsum( + "i,...i->...i", + transform, + perturbation, + ) + case 2: + return torch.einsum( + "ij,...j->...i", + transform, + perturbation, + ) + case _: + raise ValueError return displacement - dimensions * 0.5 From 5c9cc924757d95857a094b2952d0590e0c2b6f26 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Mon, 10 Jun 2024 14:48:49 -0400 Subject: [PATCH 13/15] cleanup --- docs/beignet.func.md | 2 +- src/beignet/_apply_transform.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/beignet.func.md b/docs/beignet.func.md index acc54935f7..d1ce96093a 100644 --- a/docs/beignet.func.md +++ b/docs/beignet.func.md @@ -1,3 +1,3 @@ # beignet.func -::: beignet.func.space \ No newline at end of file +::: beignet.func.space diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py index 8b5b0193a1..5128262681 100644 --- a/src/beignet/_apply_transform.py +++ b/src/beignet/_apply_transform.py @@ -37,7 +37,7 @@ class _ApplyTransform(Function): generate_vmap_rule = True @staticmethod - def forward(transform: Tensor, position: Tensor) -> Tensor: + def forward(transform: Tensor, input: Tensor) -> Tensor: """ Return affine transformed position. @@ -47,7 +47,7 @@ def forward(transform: Tensor, position: Tensor) -> Tensor: Affine transformation matrix, must have shape `(dimension, dimension)`. - position : Tensor + input : Tensor Position, must have shape `(..., dimension)`. Returns @@ -55,21 +55,21 @@ def forward(transform: Tensor, position: Tensor) -> Tensor: Tensor Affine transformed position of shape `(..., dimension)`. """ - return _apply_transform(position, transform) + return _apply_transform(input, transform) @staticmethod def setup_context(ctx, inputs, output): - transformation, position = inputs + transformation, input = inputs - ctx.save_for_backward(transformation, position, output) + ctx.save_for_backward(transformation, input, output) @staticmethod - def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): + def jvp(ctx, grad_transform: Tensor, grad_input: Tensor) -> (Tensor, Tensor): transformation, position, _ = ctx.saved_tensors output = _apply_transform(position, transformation) - grad_output = grad_position + _apply_transform(position, grad_transform) + grad_output = grad_input + _apply_transform(position, grad_transform) return output, grad_output From b1a49a8b48f2f4af19ccc42194a770cd9a24302e Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Mon, 10 Jun 2024 14:52:47 -0400 Subject: [PATCH 14/15] cleanup --- src/beignet/_apply_transform.py | 10 ++++----- src/beignet/func/_space.py | 39 +++++++++++++++------------------ 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/beignet/_apply_transform.py b/src/beignet/_apply_transform.py index 5128262681..6c4f3e06e7 100644 --- a/src/beignet/_apply_transform.py +++ b/src/beignet/_apply_transform.py @@ -59,17 +59,17 @@ def forward(transform: Tensor, input: Tensor) -> Tensor: @staticmethod def setup_context(ctx, inputs, output): - transformation, input = inputs + transform, input = inputs - ctx.save_for_backward(transformation, input, output) + ctx.save_for_backward(transform, input, output) @staticmethod def jvp(ctx, grad_transform: Tensor, grad_input: Tensor) -> (Tensor, Tensor): - transformation, position, _ = ctx.saved_tensors + transform, input, _ = ctx.saved_tensors - output = _apply_transform(position, transformation) + output = _apply_transform(input, transform) - grad_output = grad_input + _apply_transform(position, grad_transform) + grad_output = grad_input + _apply_transform(input, grad_transform) return output, grad_output diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index 4b552231d8..b567f70a7d 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -9,7 +9,7 @@ def space( - dimensions: Tensor | None = None, + box: Tensor | None = None, *, normalized: bool = True, parallelepiped: bool = True, @@ -21,7 +21,7 @@ def space( subsets of $\mathbb{R}^{D}$ (where $D = 1$, $2$, or $3$) and is instrumental in setting up simulation environments with specific characteristics (e.g., periodic boundary conditions). The function returns - a a displacement function and a shift function to compute particle + a displacement function and a shift function to compute particle interactions and movements in space. This function supports deformation of the simulation cell, crucial for @@ -30,9 +30,7 @@ def space( Parameters ---------- - dimensions : Tensor | None, default=None - Dimensions of the simulation space. - + box : Tensor | None, default=None Interpretation varies based on the value of `parallelepiped`. If `parallelepiped` is `True`, must be an affine transformation, $T$, specified in one of three ways: @@ -43,8 +41,7 @@ def space( If `parallelepiped` is `False`, must be the edge lengths. - If `dimensions` is `None`, the simulation space has free boundary - conditions. + If `box` is `None`, the simulation space has free boundary conditions. normalized : bool, default=True If `normalized` is `True`, positions are stored in the unit cube. @@ -106,10 +103,10 @@ def space( ), ) """ - if isinstance(dimensions, (int, float)): - dimensions = torch.tensor([dimensions]) + if isinstance(box, (int, float)): + box = torch.tensor([box]) - if dimensions is None: + if box is None: def displacement_fn( input: Tensor, @@ -153,7 +150,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: return displacement_fn, shift_fn if parallelepiped: - inverted_transform = beignet.invert_transform(dimensions) + inverted_transform = beignet.invert_transform(box) if normalized: @@ -164,7 +161,7 @@ def displacement_fn( perturbation: Tensor | None = None, **kwargs, ) -> Tensor: - _transform = dimensions + _transform = box _inverted_transform = inverted_transform @@ -212,7 +209,7 @@ def u(input: Tensor, other: Tensor) -> Tensor: return torch.remainder(input + other, 1.0) def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transform = dimensions + _transform = box _inverted_transform = inverted_transform @@ -229,7 +226,7 @@ def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: return displacement_fn, shift_fn def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transform = dimensions + _transform = box _inverted_transform = inverted_transform @@ -252,7 +249,7 @@ def displacement_fn( perturbation: Tensor | None = None, **kwargs, ) -> Tensor: - _transform = dimensions + _transform = box _inverted_transform = inverted_transform @@ -305,7 +302,7 @@ def u(input: Tensor, other: Tensor) -> Tensor: return torch.remainder(input + other, 1.0) def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transform = dimensions + _transform = box _inverted_transform = inverted_transform @@ -348,12 +345,12 @@ def displacement_fn( raise ValueError displacement = torch.remainder( - input - other + dimensions * 0.5, - dimensions, + input - other + box * 0.5, + box, ) if perturbation is not None: - transform = displacement - dimensions * 0.5 + transform = displacement - box * 0.5 match transform.ndim: case 0: @@ -373,12 +370,12 @@ def displacement_fn( case _: raise ValueError - return displacement - dimensions * 0.5 + return displacement - box * 0.5 if remapped: def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: - return torch.remainder(input + other, dimensions) + return torch.remainder(input + other, box) else: def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: From ad2964ae1c236a56184d5d12eddb07367ba31e57 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Tue, 11 Jun 2024 15:00:07 -0400 Subject: [PATCH 15/15] test_invert_transform --- tests/beignet/test__invert_transform.py | 29 +++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/beignet/test__invert_transform.py b/tests/beignet/test__invert_transform.py index adbd8a9e88..dc0679ab51 100644 --- a/tests/beignet/test__invert_transform.py +++ b/tests/beignet/test__invert_transform.py @@ -1,12 +1,37 @@ import beignet +import hypothesis.strategies import torch.testing -def test_invert_transform(): - input = torch.randn([32, 3]) +@hypothesis.strategies.composite +def _strategy(function): + input = torch.randn( + [ + function( + hypothesis.strategies.integers( + min_value=1, + max_value=8, + ), + ), + 3, + ] + ) transform = torch.randn([3, 3]) + return ( + (input, transform), + beignet.apply_transform( + beignet.apply_transform(input, transform), + beignet.invert_transform(transform), + ), + ) + + +@hypothesis.given(_strategy()) +def test_invert_transform(data): + (input, transform), expected = data + torch.testing.assert_close( input, beignet.apply_transform(