-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Henry Isaacson
committed
Apr 19, 2024
1 parent
3e4ffd0
commit 0a5e946
Showing
4 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,328 @@ | ||
from typing import ( | ||
Callable, | ||
Optional, | ||
Tuple, | ||
TypeVar, | ||
) | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from .__inverse_transform import _inverse_transform | ||
from .__transform import _transform | ||
from ._transform import transform | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def space( | ||
dimensions: Optional[Tensor] = None, | ||
*, | ||
normalized: bool = True, | ||
parallelepiped: bool = True, | ||
remapped: bool = True, | ||
) -> Tuple[Callable, Callable]: | ||
r"""Define a simulation space. | ||
This function is fundamental in constructing simulation spaces derived from | ||
subsets of $\mathbb{R}^{D}$ (where $D = 1$, $2$, or $3$) and is | ||
instrumental in setting up simulation environments with specific | ||
characteristics (e.g., periodic boundary conditions). The function returns | ||
a a displacement function and a shift function to compute particle | ||
interactions and movements in space. | ||
This function supports deformation of the simulation cell, crucial for | ||
certain types of simulations, such as those involving finite deformations | ||
or the computation of elastic constants. | ||
Parameters | ||
---------- | ||
dimensions : Optional[Tensor], default=None | ||
Dimensions of the simulation space. Interpretation varies based on the | ||
value of `parallelepiped`. If `parallelepiped` is `True`, must be an | ||
affine transformation, $T$, specified in one of three ways: a cube, | ||
$L$; an orthorhombic unit cell, $[L_{x}, L_{y}, L_{z}]$; or a triclinic | ||
cell, upper triangular matrix. If `parallelepiped` is `False`, must be | ||
the edge lengths. If `None`, the simulation space has free boudnary | ||
conditions. | ||
normalized : bool, default=True | ||
If `True`, positions are stored in the unit cube. Displacements and | ||
shifts are computed in a normalized simulation space and can be | ||
transformed back to real simulation space using the provided affine | ||
transformation matrix. If `False`, positions are expressed and | ||
computations performed directly in the real simulation space. | ||
parallelepiped : bool, default=True | ||
If `True`, the simulation space is defined as a ${1, 2, 3}$-dimensional | ||
parallelepiped with periodic boundary conditions. If `False`, the space | ||
is defined on a ${1, 2, 3}$-dimensional hypercube. | ||
remapped : bool, default=True | ||
If `True`, positions and displacements are remapped to stay in the | ||
bounds of the defined simulation space. A rempapped simulation space is | ||
topologically equivalent to a torus, ensuring that particles exiting | ||
one boundary re-enter from the opposite side. This is particularly | ||
relevant for simulation spaces with periodic boundary conditions. | ||
Returns | ||
------- | ||
Tuple[Callable[[Tensor, Tensor], Tensor], Callable[[Tensor, Tensor], Tensor]] | ||
A tuple containing two functions: | ||
1. The displacement function, $\overrightarrow{d}$, measures the | ||
difference between two points in the simulation space, factoring in | ||
the geometry and boundary conditions. This function is used to | ||
calculate particle interactions and dynamics. | ||
2. The shift function, $u$, applies a displacement vector to a point | ||
in the space, effectively moving it. This function is used to | ||
update simulated particle positions. | ||
Examples | ||
-------- | ||
transformation = torch.tensor([10.0]) | ||
displacement_fn, shift_fn = space( | ||
transformation, | ||
normalized=False, | ||
) | ||
normalized_displacement_fn, normalized_shift_fn = space( | ||
transformation, | ||
normalized=True, | ||
) | ||
normalized_position = torch.rand([4, 3]) | ||
position = transformation * normalized_position | ||
displacement = torch.randn([4, 3]) | ||
torch.testing.assert_close( | ||
displacement_fn(position[0], position[1]), | ||
normalized_displacement_fn( | ||
normalized_position[0], | ||
normalized_position[1], | ||
), | ||
) | ||
""" | ||
if isinstance(dimensions, (int, float)): | ||
dimensions = torch.tensor([dimensions]) | ||
|
||
if dimensions is None: | ||
|
||
def _displacement_fn( | ||
a: Tensor, | ||
b: Tensor, | ||
*, | ||
perturbation: Optional[Tensor] = None, | ||
**_, | ||
) -> Tensor: | ||
if len(a.shape) != 1: | ||
raise ValueError | ||
|
||
if a.shape != b.shape: | ||
raise ValueError | ||
|
||
if perturbation is not None: | ||
return _transform(a - b, perturbation) | ||
|
||
return a - b | ||
|
||
def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor: | ||
return a + b | ||
|
||
return _displacement_fn, _shift_fn | ||
|
||
if parallelepiped: | ||
inverse_transformation = _inverse_transform(dimensions) | ||
|
||
if normalized: | ||
|
||
def _displacement_fn( | ||
a: Tensor, | ||
b: Tensor, | ||
*, | ||
perturbation: Optional[Tensor] = None, | ||
**kwargs, | ||
) -> Tensor: | ||
_transformation = dimensions | ||
|
||
_inverse_transformation = inverse_transformation | ||
|
||
if "transformation" in kwargs: | ||
_transformation = kwargs["transformation"] | ||
|
||
if "updated_transformation" in kwargs: | ||
_transformation = kwargs["updated_transformation"] | ||
|
||
if len(a.shape) != 1: | ||
raise ValueError | ||
|
||
if a.shape != b.shape: | ||
raise ValueError | ||
|
||
displacement = a - b | ||
|
||
displacement = torch.remainder(displacement + 1.0 * 0.5, 1.0) | ||
|
||
displacement = displacement - 1.0 * 0.5 | ||
|
||
displacement = transform(_transformation, displacement) | ||
|
||
if perturbation is not None: | ||
return _transform(displacement, perturbation) | ||
|
||
return displacement | ||
|
||
if remapped: | ||
|
||
def _u(a: Tensor, b: Tensor) -> Tensor: | ||
return torch.remainder(a + b, 1.0) | ||
|
||
def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor: | ||
_transformation = dimensions | ||
|
||
_inverse_transformation = inverse_transformation | ||
|
||
if "transformation" in kwargs: | ||
_transformation = kwargs["transformation"] | ||
|
||
_inverse_transformation = _inverse_transform( | ||
_transformation | ||
) | ||
|
||
if "updated_transformation" in kwargs: | ||
_transformation = kwargs["updated_transformation"] | ||
|
||
return _u(a, transform(_inverse_transformation, b)) | ||
|
||
return _displacement_fn, _shift_fn | ||
|
||
def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor: | ||
_transformation = dimensions | ||
|
||
_inverse_transformation = inverse_transformation | ||
|
||
if "transformation" in kwargs: | ||
_transformation = kwargs["transformation"] | ||
|
||
_inverse_transformation = _inverse_transform( | ||
_transformation, | ||
) | ||
|
||
if "updated_transformation" in kwargs: | ||
_transformation = kwargs["updated_transformation"] | ||
|
||
return a + transform(_inverse_transformation, b) | ||
|
||
return _displacement_fn, _shift_fn | ||
|
||
def _displacement_fn( | ||
a: Tensor, | ||
b: Tensor, | ||
*, | ||
perturbation: Optional[Tensor] = None, | ||
**kwargs, | ||
) -> Tensor: | ||
_transformation = dimensions | ||
|
||
_inverse_transformation = inverse_transformation | ||
|
||
if "transformation" in kwargs: | ||
_transformation = kwargs["transformation"] | ||
|
||
_inverse_transformation = _inverse_transform(_transformation) | ||
|
||
if "updated_transformation" in kwargs: | ||
_transformation = kwargs["updated_transformation"] | ||
|
||
a = transform(_inverse_transformation, a) | ||
b = transform(_inverse_transformation, b) | ||
|
||
if len(a.shape) != 1: | ||
raise ValueError | ||
|
||
if a.shape != b.shape: | ||
raise ValueError | ||
|
||
displacement = a - b | ||
|
||
displacement = torch.remainder(displacement + 1.0 * 0.5, 1.0) | ||
|
||
displacement = displacement - 1.0 * 0.5 | ||
|
||
displacement = transform(_transformation, displacement) | ||
|
||
if perturbation is not None: | ||
return _transform(displacement, perturbation) | ||
|
||
return displacement | ||
|
||
if remapped: | ||
|
||
def _u(a: Tensor, b: Tensor) -> Tensor: | ||
return torch.remainder(a + b, 1.0) | ||
|
||
def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor: | ||
_transformation = dimensions | ||
|
||
_inverse_transformation = inverse_transformation | ||
|
||
if "transformation" in kwargs: | ||
_transformation = kwargs["transformation"] | ||
|
||
_inverse_transformation = _inverse_transform( | ||
_transformation, | ||
) | ||
|
||
if "updated_transformation" in kwargs: | ||
_transformation = kwargs["updated_transformation"] | ||
|
||
return transform( | ||
_transformation, | ||
_u( | ||
transform(_inverse_transformation, a), | ||
transform(_inverse_transformation, b), | ||
), | ||
) | ||
|
||
return _displacement_fn, _shift_fn | ||
|
||
def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor: | ||
return a + b | ||
|
||
return _displacement_fn, _shift_fn | ||
|
||
def _displacement_fn( | ||
a: Tensor, | ||
b: Tensor, | ||
*, | ||
perturbation: Optional[Tensor] = None, | ||
**_, | ||
) -> Tensor: | ||
if len(a.shape) != 1: | ||
raise ValueError | ||
|
||
if a.shape != b.shape: | ||
raise ValueError | ||
|
||
displacement = torch.remainder(a - b + dimensions * 0.5, dimensions) | ||
|
||
displacement = displacement - dimensions * 0.5 | ||
|
||
if perturbation is not None: | ||
return _transform(displacement, perturbation) | ||
|
||
return displacement | ||
|
||
if remapped: | ||
|
||
def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor: | ||
return torch.remainder(a + b, dimensions) | ||
else: | ||
|
||
def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor: | ||
return a + b | ||
|
||
return _displacement_fn, _shift_fn |