Skip to content

Commit

Permalink
space
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Apr 19, 2024
1 parent 3e4ffd0 commit 0a5e946
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 0 deletions.
Empty file added src/beignet/func/__init__.py
Empty file.
Empty file.
Empty file.
328 changes: 328 additions & 0 deletions src/beignet/func/_molecular_dynamics/_space/_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
from typing import (
Callable,
Optional,
Tuple,
TypeVar,
)

import torch
from torch import Tensor

from .__inverse_transform import _inverse_transform
from .__transform import _transform
from ._transform import transform

T = TypeVar("T")


def space(
dimensions: Optional[Tensor] = None,
*,
normalized: bool = True,
parallelepiped: bool = True,
remapped: bool = True,
) -> Tuple[Callable, Callable]:
r"""Define a simulation space.
This function is fundamental in constructing simulation spaces derived from
subsets of $\mathbb{R}^{D}$ (where $D = 1$, $2$, or $3$) and is
instrumental in setting up simulation environments with specific
characteristics (e.g., periodic boundary conditions). The function returns
a a displacement function and a shift function to compute particle
interactions and movements in space.
This function supports deformation of the simulation cell, crucial for
certain types of simulations, such as those involving finite deformations
or the computation of elastic constants.
Parameters
----------
dimensions : Optional[Tensor], default=None
Dimensions of the simulation space. Interpretation varies based on the
value of `parallelepiped`. If `parallelepiped` is `True`, must be an
affine transformation, $T$, specified in one of three ways: a cube,
$L$; an orthorhombic unit cell, $[L_{x}, L_{y}, L_{z}]$; or a triclinic
cell, upper triangular matrix. If `parallelepiped` is `False`, must be
the edge lengths. If `None`, the simulation space has free boudnary
conditions.
normalized : bool, default=True
If `True`, positions are stored in the unit cube. Displacements and
shifts are computed in a normalized simulation space and can be
transformed back to real simulation space using the provided affine
transformation matrix. If `False`, positions are expressed and
computations performed directly in the real simulation space.
parallelepiped : bool, default=True
If `True`, the simulation space is defined as a ${1, 2, 3}$-dimensional
parallelepiped with periodic boundary conditions. If `False`, the space
is defined on a ${1, 2, 3}$-dimensional hypercube.
remapped : bool, default=True
If `True`, positions and displacements are remapped to stay in the
bounds of the defined simulation space. A rempapped simulation space is
topologically equivalent to a torus, ensuring that particles exiting
one boundary re-enter from the opposite side. This is particularly
relevant for simulation spaces with periodic boundary conditions.
Returns
-------
Tuple[Callable[[Tensor, Tensor], Tensor], Callable[[Tensor, Tensor], Tensor]]
A tuple containing two functions:
1. The displacement function, $\overrightarrow{d}$, measures the
difference between two points in the simulation space, factoring in
the geometry and boundary conditions. This function is used to
calculate particle interactions and dynamics.
2. The shift function, $u$, applies a displacement vector to a point
in the space, effectively moving it. This function is used to
update simulated particle positions.
Examples
--------
transformation = torch.tensor([10.0])
displacement_fn, shift_fn = space(
transformation,
normalized=False,
)
normalized_displacement_fn, normalized_shift_fn = space(
transformation,
normalized=True,
)
normalized_position = torch.rand([4, 3])
position = transformation * normalized_position
displacement = torch.randn([4, 3])
torch.testing.assert_close(
displacement_fn(position[0], position[1]),
normalized_displacement_fn(
normalized_position[0],
normalized_position[1],
),
)
"""
if isinstance(dimensions, (int, float)):
dimensions = torch.tensor([dimensions])

if dimensions is None:

def _displacement_fn(
a: Tensor,
b: Tensor,
*,
perturbation: Optional[Tensor] = None,
**_,
) -> Tensor:
if len(a.shape) != 1:
raise ValueError

if a.shape != b.shape:
raise ValueError

if perturbation is not None:
return _transform(a - b, perturbation)

return a - b

def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return a + b

return _displacement_fn, _shift_fn

if parallelepiped:
inverse_transformation = _inverse_transform(dimensions)

if normalized:

def _displacement_fn(
a: Tensor,
b: Tensor,
*,
perturbation: Optional[Tensor] = None,
**kwargs,
) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation

if "transformation" in kwargs:
_transformation = kwargs["transformation"]

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

if len(a.shape) != 1:
raise ValueError

if a.shape != b.shape:
raise ValueError

displacement = a - b

displacement = torch.remainder(displacement + 1.0 * 0.5, 1.0)

displacement = displacement - 1.0 * 0.5

displacement = transform(_transformation, displacement)

if perturbation is not None:
return _transform(displacement, perturbation)

return displacement

if remapped:

def _u(a: Tensor, b: Tensor) -> Tensor:
return torch.remainder(a + b, 1.0)

def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation

if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = _inverse_transform(
_transformation
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return _u(a, transform(_inverse_transformation, b))

return _displacement_fn, _shift_fn

def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation

if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = _inverse_transform(
_transformation,
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return a + transform(_inverse_transformation, b)

return _displacement_fn, _shift_fn

def _displacement_fn(
a: Tensor,
b: Tensor,
*,
perturbation: Optional[Tensor] = None,
**kwargs,
) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation

if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = _inverse_transform(_transformation)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

a = transform(_inverse_transformation, a)
b = transform(_inverse_transformation, b)

if len(a.shape) != 1:
raise ValueError

if a.shape != b.shape:
raise ValueError

displacement = a - b

displacement = torch.remainder(displacement + 1.0 * 0.5, 1.0)

displacement = displacement - 1.0 * 0.5

displacement = transform(_transformation, displacement)

if perturbation is not None:
return _transform(displacement, perturbation)

return displacement

if remapped:

def _u(a: Tensor, b: Tensor) -> Tensor:
return torch.remainder(a + b, 1.0)

def _shift_fn(a: Tensor, b: Tensor, **kwargs) -> Tensor:
_transformation = dimensions

_inverse_transformation = inverse_transformation

if "transformation" in kwargs:
_transformation = kwargs["transformation"]

_inverse_transformation = _inverse_transform(
_transformation,
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]

return transform(
_transformation,
_u(
transform(_inverse_transformation, a),
transform(_inverse_transformation, b),
),
)

return _displacement_fn, _shift_fn

def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return a + b

return _displacement_fn, _shift_fn

def _displacement_fn(
a: Tensor,
b: Tensor,
*,
perturbation: Optional[Tensor] = None,
**_,
) -> Tensor:
if len(a.shape) != 1:
raise ValueError

if a.shape != b.shape:
raise ValueError

displacement = torch.remainder(a - b + dimensions * 0.5, dimensions)

displacement = displacement - dimensions * 0.5

if perturbation is not None:
return _transform(displacement, perturbation)

return displacement

if remapped:

def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return torch.remainder(a + b, dimensions)
else:

def _shift_fn(a: Tensor, b: Tensor, **_) -> Tensor:
return a + b

return _displacement_fn, _shift_fn

0 comments on commit 0a5e946

Please sign in to comment.