-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Henry Isaacson
committed
Apr 22, 2024
1 parent
d23e2ef
commit 442e8c1
Showing
27 changed files
with
1,178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import dataclasses | ||
from typing import List, Tuple, Type, TypeVar | ||
|
||
import optree | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def _dataclass(cls: Type[T]): | ||
def _set(self: dataclasses.dataclass, **kwargs): | ||
return dataclasses.replace(self, **kwargs) | ||
|
||
cls.set = _set | ||
|
||
dataclass_cls = dataclasses.dataclass(frozen=True)(cls) | ||
|
||
data_fields, metadata_fields = [], [] | ||
|
||
for name, kind in dataclass_cls.__dataclass_fields__.items(): | ||
if not kind.metadata.get("static", False): | ||
data_fields = [*data_fields, name] | ||
else: | ||
metadata_fields = [*metadata_fields, name] | ||
|
||
def _iterate_cls(_x) -> List[Tuple]: | ||
data_iterable = [] | ||
|
||
for k in data_fields: | ||
data_iterable.append(getattr(_x, k)) | ||
|
||
metadata_iterable = [] | ||
|
||
for k in metadata_fields: | ||
metadata_iterable.append(getattr(_x, k)) | ||
|
||
return [data_iterable, metadata_iterable] | ||
|
||
def _iterable_to_cls(meta, data): | ||
meta_args = tuple(zip(metadata_fields, meta)) | ||
data_args = tuple(zip(data_fields, data)) | ||
kwargs = dict(meta_args + data_args) | ||
|
||
return dataclass_cls(**kwargs) | ||
|
||
optree.register_pytree_node( | ||
dataclass_cls, | ||
_iterate_cls, | ||
_iterable_to_cls, | ||
"prescient.func", | ||
) | ||
|
||
return dataclass_cls |
54 changes: 54 additions & 0 deletions
54
src/beignet/func/_molecular_dynamics/_partition/__cell_dimensions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import functools | ||
import math | ||
import operator | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _cell_dimensions( | ||
spatial_dimension: int, | ||
box_size: Tensor, | ||
minimum_cell_size: float, | ||
) -> (Tensor, Tensor, Tensor, int): | ||
if isinstance(box_size, int): | ||
box_size = float(box_size) | ||
|
||
if isinstance(box_size, Tensor): | ||
if box_size.dtype in {torch.int32, torch.int64}: | ||
box_size = float(box_size) | ||
|
||
cells_per_side = math.floor(box_size / minimum_cell_size) | ||
|
||
cell_size = box_size / cells_per_side | ||
|
||
cells_per_side = torch.tensor(cells_per_side, dtype=torch.int32) | ||
|
||
if isinstance(box_size, Tensor): | ||
if box_size.ndim == 1 or box_size.ndim == 2: | ||
assert box_size.size == spatial_dimension | ||
|
||
flattened_cells_per_side = torch.reshape( | ||
cells_per_side, | ||
[ | ||
-1, | ||
], | ||
) | ||
|
||
for cells in flattened_cells_per_side: | ||
if cells < 3: | ||
raise ValueError | ||
|
||
cell_count = functools.reduce( | ||
operator.mul, | ||
flattened_cells_per_side, | ||
1, | ||
) | ||
elif box_size.ndim == 0: | ||
cell_count = math.pow(cells_per_side, spatial_dimension) | ||
else: | ||
raise ValueError | ||
else: | ||
cell_count = math.pow(cells_per_side, spatial_dimension) | ||
|
||
return box_size, cell_size, cells_per_side, int(cell_count) |
34 changes: 34 additions & 0 deletions
34
src/beignet/func/_molecular_dynamics/_partition/__cell_list.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Callable, Dict | ||
|
||
from torch import Tensor | ||
|
||
from ..__dataclass import _dataclass | ||
from .._static_field import static_field | ||
|
||
|
||
@_dataclass | ||
class _CellList: | ||
exceeded_maximum_size: Tensor | ||
|
||
indexes: Tensor | ||
|
||
item_size: float = static_field() | ||
|
||
parameters: Dict[str, Tensor] | ||
|
||
positions_buffer: Tensor | ||
|
||
size: int = static_field() | ||
|
||
update_fn: Callable[..., "_CellList"] = static_field() | ||
|
||
def update(self, positions: Tensor, **kwargs) -> "_CellList": | ||
return self.update_fn( | ||
positions, | ||
[ | ||
self.size, | ||
self.exceeded_maximum_size, | ||
self.update_fn, | ||
], | ||
**kwargs, | ||
) |
17 changes: 17 additions & 0 deletions
17
src/beignet/func/_molecular_dynamics/_partition/__cell_list_function_list.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Callable | ||
|
||
from torch import Tensor | ||
|
||
from ..__dataclass import _dataclass | ||
from .._static_field import static_field | ||
from .__cell_list import _CellList | ||
|
||
|
||
@_dataclass | ||
class _CellListFunctionList: | ||
setup_fn: Callable[..., _CellList] = static_field() | ||
|
||
update_fn: Callable[[Tensor, _CellList | int], _CellList] = static_field() | ||
|
||
def __iter__(self): | ||
return iter([self.setup_fn, self.update_fn]) |
6 changes: 6 additions & 0 deletions
6
src/beignet/func/_molecular_dynamics/_partition/__cell_size.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _cell_size(box: Tensor, minimum_unit_size: Tensor) -> Tensor: | ||
return box / torch.floor(box / minimum_unit_size) |
30 changes: 30 additions & 0 deletions
30
src/beignet/func/_molecular_dynamics/_partition/__hash_constants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import math | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _hash_constants(spatial_dimensions: int, cells_per_side: Tensor) -> Tensor: | ||
if cells_per_side.ndim == 0: | ||
constants = [] | ||
|
||
for spatial_dimension in range(spatial_dimensions): | ||
constants = [ | ||
*constants, | ||
math.pow(cells_per_side, spatial_dimension), | ||
] | ||
|
||
return torch.tensor([constants], dtype=torch.int32) | ||
|
||
if cells_per_side.size == spatial_dimensions: | ||
cells_per_side = torch.concatenate( | ||
[ | ||
torch.tensor([[1]], dtype=torch.int32), | ||
cells_per_side[:, :-1], | ||
], | ||
dim=1, | ||
) | ||
|
||
return torch.cumprod(torch.flatten(cells_per_side), dim=0) | ||
|
||
raise ValueError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor: | ||
dimensions = [] | ||
|
||
for index, _ in enumerate(shape): | ||
if index != dim: | ||
dimension = 1 | ||
else: | ||
dimension = shape[index] | ||
|
||
dimensions = [*dimensions, dimension] | ||
|
||
return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) |
8 changes: 8 additions & 0 deletions
8
src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_format_valid.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from .__neighbor_list_format import ( | ||
_NeighborListFormat, | ||
) | ||
|
||
|
||
def _is_neighbor_list_format_valid(neighbor_list_format: _NeighborListFormat): | ||
if neighbor_list_format not in list(_NeighborListFormat): | ||
raise ValueError |
12 changes: 12 additions & 0 deletions
12
src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_sparse.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from .__neighbor_list_format import ( | ||
_NeighborListFormat, | ||
) | ||
|
||
|
||
def _is_neighbor_list_sparse( | ||
neighbor_list_format: _NeighborListFormat, | ||
) -> bool: | ||
return neighbor_list_format in { | ||
_NeighborListFormat.ORDERED_SPARSE, | ||
_NeighborListFormat.SPARSE, | ||
} |
12 changes: 12 additions & 0 deletions
12
src/beignet/func/_molecular_dynamics/_partition/__is_space_valid.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _is_space_valid(space: Tensor) -> bool: | ||
if space.ndim == 0 or space.ndim == 1: | ||
return torch.tensor([True]) | ||
|
||
if space.ndim == 2: | ||
return torch.tensor([torch.all(torch.triu(space) == space)]) | ||
|
||
return torch.tensor([False]) |
14 changes: 14 additions & 0 deletions
14
src/beignet/func/_molecular_dynamics/_partition/__map_bond.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import torch | ||
|
||
|
||
def _map_bond(distance_fn): | ||
def wrapper(start_positions, end_positions): | ||
batch_size = start_positions.shape[0] | ||
return torch.stack( | ||
[ | ||
distance_fn(start_positions[i], end_positions[i]) | ||
for i in range(batch_size) | ||
] | ||
) | ||
|
||
return wrapper |
2 changes: 2 additions & 0 deletions
2
src/beignet/func/_molecular_dynamics/_partition/__map_neighbor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def _map_neighbor(**kwargs): | ||
return |
35 changes: 35 additions & 0 deletions
35
src/beignet/func/_molecular_dynamics/_partition/__neighbor_list.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Any, Callable | ||
|
||
from torch import Tensor | ||
|
||
from ..__dataclass import _dataclass | ||
from .._static_field import static_field | ||
from .__cell_list import _CellList | ||
from .__neighbor_list_format import _NeighborListFormat | ||
from .__partition_error import _PartitionError | ||
|
||
|
||
@_dataclass | ||
class _NeighborList: | ||
buffer_fn: Callable[[Tensor, _CellList], _CellList] = static_field() | ||
|
||
indexes: Tensor | ||
|
||
item_size: float | None = static_field() | ||
|
||
maximum_size: int = static_field() | ||
|
||
format: _NeighborListFormat = static_field() | ||
|
||
partition_error: _PartitionError | ||
|
||
reference_positions: Tensor | ||
|
||
units_buffer_size: int | None = static_field() | ||
|
||
update_fn: Callable[[Tensor, "_NeighborList", Any], "_NeighborList"] = ( | ||
static_field() | ||
) | ||
|
||
def update(self, positions: Tensor, **kwargs) -> "_NeighborList": | ||
return self.update_fn(positions, self, **kwargs) |
7 changes: 7 additions & 0 deletions
7
src/beignet/func/_molecular_dynamics/_partition/__neighbor_list_format.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from enum import Enum | ||
|
||
|
||
class _NeighborListFormat(Enum): | ||
DENSE = 0 | ||
ORDERED_SPARSE = 1 | ||
SPARSE = 2 |
17 changes: 17 additions & 0 deletions
17
src/beignet/func/_molecular_dynamics/_partition/__neighbor_list_function_list.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Callable | ||
|
||
from torch import Tensor | ||
|
||
from ..__dataclass import _dataclass | ||
from .._static_field import static_field | ||
from .__neighbor_list import _NeighborList | ||
|
||
|
||
@_dataclass | ||
class _NeighborListFunctionList: | ||
setup_fn: Callable[..., _NeighborList] = static_field() | ||
|
||
update_fn: Callable[[Tensor, _NeighborList], _NeighborList] = static_field() | ||
|
||
def __iter__(self): | ||
return iter((self.setup_fn, self.update_fn)) |
10 changes: 10 additions & 0 deletions
10
src/beignet/func/_molecular_dynamics/_partition/__neighboring_cell_lists.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from typing import Generator | ||
|
||
import numpy | ||
|
||
|
||
def _neighboring_cell_lists( | ||
dimension: int, | ||
) -> Generator[numpy.ndarray, None, None]: | ||
for index in numpy.ndindex(*([3] * dimension)): | ||
yield numpy.array(index, dtype=numpy.int32) - 1 |
44 changes: 44 additions & 0 deletions
44
src/beignet/func/_molecular_dynamics/_partition/__normalize_cell_size.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
|
||
|
||
def _normalize_cell_size(box, cutoff): | ||
if box.ndim == 0: | ||
return cutoff / box | ||
|
||
if box.ndim == 1: | ||
return cutoff / torch.min(box) | ||
|
||
if box.ndim == 2: | ||
if box.shape[0] == 1: | ||
return 1 / torch.floor(box[0, 0] / cutoff) | ||
|
||
if box.shape[0] == 2: | ||
xx = box[0, 0] | ||
yy = box[1, 1] | ||
xy = box[0, 1] / yy | ||
|
||
nx = xx / torch.sqrt(1 + xy**2) | ||
ny = yy | ||
|
||
nmin = torch.floor(torch.min(torch.tensor([nx, ny])) / cutoff) | ||
|
||
return 1 / torch.where(nmin == 0, 1, nmin) | ||
|
||
if box.shape[0] == 3: | ||
xx = box[0, 0] | ||
yy = box[1, 1] | ||
zz = box[2, 2] | ||
xy = box[0, 1] / yy | ||
xz = box[0, 2] / zz | ||
yz = box[1, 2] / zz | ||
|
||
nx = xx / torch.sqrt(1 + xy**2 + (xy * yz - xz) ** 2) | ||
ny = yy / torch.sqrt(1 + yz**2) | ||
nz = zz | ||
|
||
nmin = torch.floor(torch.min(torch.tensor([nx, ny, nz])) / cutoff) | ||
return 1 / torch.where(nmin == 0, 1, nmin) | ||
else: | ||
raise ValueError | ||
else: | ||
raise ValueError |
Oops, something went wrong.