Skip to content

Commit

Permalink
partition
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Apr 22, 2024
1 parent d23e2ef commit 442e8c1
Show file tree
Hide file tree
Showing 27 changed files with 1,178 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/beignet/func/_molecular_dynamics/__dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import dataclasses
from typing import List, Tuple, Type, TypeVar

import optree

T = TypeVar("T")


def _dataclass(cls: Type[T]):
def _set(self: dataclasses.dataclass, **kwargs):
return dataclasses.replace(self, **kwargs)

cls.set = _set

dataclass_cls = dataclasses.dataclass(frozen=True)(cls)

data_fields, metadata_fields = [], []

for name, kind in dataclass_cls.__dataclass_fields__.items():
if not kind.metadata.get("static", False):
data_fields = [*data_fields, name]
else:
metadata_fields = [*metadata_fields, name]

def _iterate_cls(_x) -> List[Tuple]:
data_iterable = []

for k in data_fields:
data_iterable.append(getattr(_x, k))

metadata_iterable = []

for k in metadata_fields:
metadata_iterable.append(getattr(_x, k))

return [data_iterable, metadata_iterable]

def _iterable_to_cls(meta, data):
meta_args = tuple(zip(metadata_fields, meta))
data_args = tuple(zip(data_fields, data))
kwargs = dict(meta_args + data_args)

return dataclass_cls(**kwargs)

optree.register_pytree_node(
dataclass_cls,
_iterate_cls,
_iterable_to_cls,
"prescient.func",
)

return dataclass_cls
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import functools
import math
import operator

import torch
from torch import Tensor


def _cell_dimensions(
spatial_dimension: int,
box_size: Tensor,
minimum_cell_size: float,
) -> (Tensor, Tensor, Tensor, int):
if isinstance(box_size, int):
box_size = float(box_size)

if isinstance(box_size, Tensor):
if box_size.dtype in {torch.int32, torch.int64}:
box_size = float(box_size)

cells_per_side = math.floor(box_size / minimum_cell_size)

cell_size = box_size / cells_per_side

cells_per_side = torch.tensor(cells_per_side, dtype=torch.int32)

if isinstance(box_size, Tensor):
if box_size.ndim == 1 or box_size.ndim == 2:
assert box_size.size == spatial_dimension

flattened_cells_per_side = torch.reshape(
cells_per_side,
[
-1,
],
)

for cells in flattened_cells_per_side:
if cells < 3:
raise ValueError

cell_count = functools.reduce(
operator.mul,
flattened_cells_per_side,
1,
)
elif box_size.ndim == 0:
cell_count = math.pow(cells_per_side, spatial_dimension)
else:
raise ValueError
else:
cell_count = math.pow(cells_per_side, spatial_dimension)

return box_size, cell_size, cells_per_side, int(cell_count)
34 changes: 34 additions & 0 deletions src/beignet/func/_molecular_dynamics/_partition/__cell_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Callable, Dict

from torch import Tensor

from ..__dataclass import _dataclass
from .._static_field import static_field


@_dataclass
class _CellList:
exceeded_maximum_size: Tensor

indexes: Tensor

item_size: float = static_field()

parameters: Dict[str, Tensor]

positions_buffer: Tensor

size: int = static_field()

update_fn: Callable[..., "_CellList"] = static_field()

def update(self, positions: Tensor, **kwargs) -> "_CellList":
return self.update_fn(
positions,
[
self.size,
self.exceeded_maximum_size,
self.update_fn,
],
**kwargs,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Callable

from torch import Tensor

from ..__dataclass import _dataclass
from .._static_field import static_field
from .__cell_list import _CellList


@_dataclass
class _CellListFunctionList:
setup_fn: Callable[..., _CellList] = static_field()

update_fn: Callable[[Tensor, _CellList | int], _CellList] = static_field()

def __iter__(self):
return iter([self.setup_fn, self.update_fn])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch
from torch import Tensor


def _cell_size(box: Tensor, minimum_unit_size: Tensor) -> Tensor:
return box / torch.floor(box / minimum_unit_size)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import math

import torch
from torch import Tensor


def _hash_constants(spatial_dimensions: int, cells_per_side: Tensor) -> Tensor:
if cells_per_side.ndim == 0:
constants = []

for spatial_dimension in range(spatial_dimensions):
constants = [
*constants,
math.pow(cells_per_side, spatial_dimension),
]

return torch.tensor([constants], dtype=torch.int32)

if cells_per_side.size == spatial_dimensions:
cells_per_side = torch.concatenate(
[
torch.tensor([[1]], dtype=torch.int32),
cells_per_side[:, :-1],
],
dim=1,
)

return torch.cumprod(torch.flatten(cells_per_side), dim=0)

raise ValueError
16 changes: 16 additions & 0 deletions src/beignet/func/_molecular_dynamics/_partition/__iota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
from torch import Tensor


def _iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor:
dimensions = []

for index, _ in enumerate(shape):
if index != dim:
dimension = 1
else:
dimension = shape[index]

dimensions = [*dimensions, dimension]

return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .__neighbor_list_format import (
_NeighborListFormat,
)


def _is_neighbor_list_format_valid(neighbor_list_format: _NeighborListFormat):
if neighbor_list_format not in list(_NeighborListFormat):
raise ValueError
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .__neighbor_list_format import (
_NeighborListFormat,
)


def _is_neighbor_list_sparse(
neighbor_list_format: _NeighborListFormat,
) -> bool:
return neighbor_list_format in {
_NeighborListFormat.ORDERED_SPARSE,
_NeighborListFormat.SPARSE,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
from torch import Tensor


def _is_space_valid(space: Tensor) -> bool:
if space.ndim == 0 or space.ndim == 1:
return torch.tensor([True])

if space.ndim == 2:
return torch.tensor([torch.all(torch.triu(space) == space)])

return torch.tensor([False])
14 changes: 14 additions & 0 deletions src/beignet/func/_molecular_dynamics/_partition/__map_bond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch


def _map_bond(distance_fn):
def wrapper(start_positions, end_positions):
batch_size = start_positions.shape[0]
return torch.stack(
[
distance_fn(start_positions[i], end_positions[i])
for i in range(batch_size)
]
)

return wrapper
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def _map_neighbor(**kwargs):
return
35 changes: 35 additions & 0 deletions src/beignet/func/_molecular_dynamics/_partition/__neighbor_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Callable

from torch import Tensor

from ..__dataclass import _dataclass
from .._static_field import static_field
from .__cell_list import _CellList
from .__neighbor_list_format import _NeighborListFormat
from .__partition_error import _PartitionError


@_dataclass
class _NeighborList:
buffer_fn: Callable[[Tensor, _CellList], _CellList] = static_field()

indexes: Tensor

item_size: float | None = static_field()

maximum_size: int = static_field()

format: _NeighborListFormat = static_field()

partition_error: _PartitionError

reference_positions: Tensor

units_buffer_size: int | None = static_field()

update_fn: Callable[[Tensor, "_NeighborList", Any], "_NeighborList"] = (
static_field()
)

def update(self, positions: Tensor, **kwargs) -> "_NeighborList":
return self.update_fn(positions, self, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class _NeighborListFormat(Enum):
DENSE = 0
ORDERED_SPARSE = 1
SPARSE = 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Callable

from torch import Tensor

from ..__dataclass import _dataclass
from .._static_field import static_field
from .__neighbor_list import _NeighborList


@_dataclass
class _NeighborListFunctionList:
setup_fn: Callable[..., _NeighborList] = static_field()

update_fn: Callable[[Tensor, _NeighborList], _NeighborList] = static_field()

def __iter__(self):
return iter((self.setup_fn, self.update_fn))
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Generator

import numpy


def _neighboring_cell_lists(
dimension: int,
) -> Generator[numpy.ndarray, None, None]:
for index in numpy.ndindex(*([3] * dimension)):
yield numpy.array(index, dtype=numpy.int32) - 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch


def _normalize_cell_size(box, cutoff):
if box.ndim == 0:
return cutoff / box

if box.ndim == 1:
return cutoff / torch.min(box)

if box.ndim == 2:
if box.shape[0] == 1:
return 1 / torch.floor(box[0, 0] / cutoff)

if box.shape[0] == 2:
xx = box[0, 0]
yy = box[1, 1]
xy = box[0, 1] / yy

nx = xx / torch.sqrt(1 + xy**2)
ny = yy

nmin = torch.floor(torch.min(torch.tensor([nx, ny])) / cutoff)

return 1 / torch.where(nmin == 0, 1, nmin)

if box.shape[0] == 3:
xx = box[0, 0]
yy = box[1, 1]
zz = box[2, 2]
xy = box[0, 1] / yy
xz = box[0, 2] / zz
yz = box[1, 2] / zz

nx = xx / torch.sqrt(1 + xy**2 + (xy * yz - xz) ** 2)
ny = yy / torch.sqrt(1 + yz**2)
nz = zz

nmin = torch.floor(torch.min(torch.tensor([nx, ny, nz])) / cutoff)
return 1 / torch.where(nmin == 0, 1, nmin)
else:
raise ValueError
else:
raise ValueError
Loading

0 comments on commit 442e8c1

Please sign in to comment.