Skip to content

Commit

Permalink
interact
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Apr 22, 2024
1 parent 4b15a0e commit 06197f5
Show file tree
Hide file tree
Showing 19 changed files with 987 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/beignet/func/_molecular_dynamics/__dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import dataclasses
from typing import List, Tuple, Type, TypeVar

import optree

T = TypeVar("T")


def _dataclass(cls: Type[T]):
def _set(self: dataclasses.dataclass, **kwargs):
return dataclasses.replace(self, **kwargs)

cls.set = _set

dataclass_cls = dataclasses.dataclass(frozen=True)(cls)

data_fields, metadata_fields = [], []

for name, kind in dataclass_cls.__dataclass_fields__.items():
if not kind.metadata.get("static", False):
data_fields = [*data_fields, name]
else:
metadata_fields = [*metadata_fields, name]

def _iterate_cls(_x) -> List[Tuple]:
data_iterable = []

for k in data_fields:
data_iterable.append(getattr(_x, k))

metadata_iterable = []

for k in metadata_fields:
metadata_iterable.append(getattr(_x, k))

return [data_iterable, metadata_iterable]

def _iterable_to_cls(meta, data):
meta_args = tuple(zip(metadata_fields, meta))
data_args = tuple(zip(data_fields, data))
kwargs = dict(meta_args + data_args)

return dataclass_cls(**kwargs)

optree.register_pytree_node(
dataclass_cls,
_iterate_cls,
_iterable_to_cls,
"prescient.func",
)

return dataclass_cls
22 changes: 22 additions & 0 deletions src/beignet/func/_molecular_dynamics/__safe_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Iterable, Optional, Union

import torch
from torch import Tensor


def _safe_sum(
x: Tensor,
dim: Optional[Union[Iterable[int], int]] = None,
keepdim: bool = False,
):
match x:
case _ if x.is_complex():
promoted_dtype = torch.complex128
case _ if x.is_floating_point():
promoted_dtype = torch.float64
case _:
promoted_dtype = torch.int64

summation = torch.sum(x, dim=dim, dtype=promoted_dtype, keepdim=keepdim)

return summation.to(dtype=x.dtype)
26 changes: 26 additions & 0 deletions src/beignet/func/_molecular_dynamics/__zero_diagonal_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
from torch import Tensor


def _zero_diagonal_mask(x: Tensor) -> Tensor:
"""Sets the diagonal of a matrix to zero."""
if x.shape[0] != x.shape[1]:
raise ValueError(
f"Diagonal mask can only mask square matrices. Found {x.shape[0]}x{x.shape[1]}."
)

if len(x.shape) > 3:
raise ValueError(
f"Diagonal mask can only mask rank-2 or rank-3 tensors. Found {len(x.shape)}."
)

n = x.shape[0]

x = torch.nan_to_num(x)

mask = 1.0 - torch.eye(n, device=x.device, dtype=x.dtype)

if len(x.shape) == 3:
mask = torch.reshape(mask, [n, n, 1])

return x * mask
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Callable

from torch import Tensor


def _angle_interaction(
fn: Callable[..., Tensor],
displacement_fn: Callable[[Tensor, Tensor], Tensor],
):
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import functools
from typing import Callable, Optional

import torch
from torch import Tensor

from ..__safe_sum import _safe_sum
from .__merge_dictionaries import _merge_dictionaries
from .__to_bond_kind_parameters import _to_bond_kind_parameters


def _bond_interaction(
fn: Callable[..., Tensor],
displacement_fn: Callable[[Tensor, Tensor], Tensor],
static_bonds: Optional[Tensor] = None,
static_kinds: Optional[Tensor] = None,
ignore_unused_parameters: bool = False,
**static_kwargs,
) -> Callable[..., Tensor]:
merge_dictionaries_fn = functools.partial(
_merge_dictionaries,
ignore_unused_parameters=ignore_unused_parameters,
)

def mapped_fn(
positions: Tensor,
bonds: Optional[Tensor] = None,
kinds: Optional[Tensor] = None,
**kwargs,
) -> Tensor:
accumulator = torch.tensor(
0.0,
device=positions.device,
dtype=positions.dtype,
)

distance_fn = functools.partial(displacement_fn, **kwargs)

distance_fn = torch.func.vmap(distance_fn, 0, 0)

if bonds is not None:
parameters = merge_dictionaries_fn(static_kwargs, kwargs)

for name, parameter in parameters.items():
if kinds is not None:
parameters[name] = _to_bond_kind_parameters(
parameter,
kinds,
)

interactions = distance_fn(
positions[bonds[:, 0]],
positions[bonds[:, 1]],
)

interactions = _safe_sum(fn(interactions, **parameters))

accumulator = accumulator + interactions

if static_bonds is not None:
parameters = merge_dictionaries_fn(static_kwargs, kwargs)

for name, parameter in parameters.items():
if static_kinds is not None:
parameters[name] = _to_bond_kind_parameters(
parameter,
static_kinds,
)

interactions = distance_fn(
positions[static_bonds[:, 0]],
positions[static_bonds[:, 1]],
)

interactions = _safe_sum(fn(interactions, **parameters))

accumulator = accumulator + interactions

return accumulator

return mapped_fn
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Callable

from torch import Tensor


def _dihedral_interaction(
fn: Callable[..., Tensor],
displacement_fn: Callable[[Tensor, Tensor], Tensor],
):
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Callable, Dict

from torch import Tensor

from .._partition.__neighbor_list_format import _NeighborListFormat
from .__to_neighbor_list_kind_parameters import (
_to_neighbor_list_kind_parameters,
)
from .__to_neighbor_list_matrix_parameters import (
_to_neighbor_list_matrix_parameters,
)


def _kwargs_to_neighbor_list_parameters(
format: _NeighborListFormat,
indexes: Tensor,
species: Tensor,
kwargs: Dict[str, Tensor],
combinators: Dict[str, Callable],
) -> Dict[str, Tensor]:
parameters = {}

for name, parameter in kwargs.items():
if species is None or (isinstance(parameter, Tensor) and parameter.ndim == 1):
combinator = combinators.get(name, lambda x, y: 0.5 * (x + y))

parameters[name] = _to_neighbor_list_matrix_parameters(
format,
indexes,
parameter,
combinator,
)
else:
if name in combinators:
raise ValueError

parameters[name] = _to_neighbor_list_kind_parameters(
format,
indexes,
species,
parameter,
)

return parameters
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import Callable, Dict, Union

import optree
from optree import PyTree
from torch import Tensor

from .__parameter_tree import _ParameterTree
from .__parameter_tree_kind import _ParameterTreeKind


def _kwargs_to_pair_parameters(
kwargs: Dict[str, Union[_ParameterTree, Tensor, float, PyTree]],
combinators: Dict[str, Callable],
kinds: Tensor | None = None,
) -> Dict[str, Tensor]:
parameters = {}

for name, parameter in kwargs.items():
if kinds is None:

def _combinator_fn(x: Tensor, y: Tensor) -> Tensor:
return (x + y) * 0.5

combinator = combinators.get(name, _combinator_fn)

match parameter:
case _ParameterTree():
match parameter.kind:
case _ParameterTreeKind.BOND | _ParameterTreeKind.SPACE:
parameters[name] = parameter.tree
case _ParameterTreeKind.PARTICLE:

def _particle_fn(_parameter: Tensor) -> Tensor:
return combinator(
_parameter[:, None, ...],
_parameter[None, :, ...],
)

parameters[name] = optree.tree_map(
_particle_fn,
parameter.tree,
)
case _:
message = f"""
parameter `kind` is `{parameter.kind}`. If `kinds` is `None` and a parameter is
an instance of `ParameterTree`, `kind` must be `ParameterTreeKind.BOND`,
`ParameterTreeKind.PARTICLE`, or `ParameterTreeKind.SPACE`.
""".replace("\n", " ")

raise ValueError(message)
case Tensor():
match parameter.ndim:
case 0 | 2:
parameters[name] = parameter
case 1:
parameters[name] = combinator(
parameter[:, None],
parameter[None, :],
)
case _:
message = f"""
parameter `ndim` is `{parameter.ndim}`. If `kinds` is `None` and a parameter is
an instance of `Tensor`, `ndim` must be in `0`, `1`, or `2`.
""".replace("\n", " ")

raise ValueError(message)
case float() | int():
parameters[name] = parameter
case _:
message = f"""
parameter `type` is {type(parameter)}. If `kinds` is `None`, a parameter must
be an instance of `ParameterTree`, `Tensor`, `float`, or `int`.
""".replace("\n", " ")

raise ValueError(message)
else:
if name in combinators:
raise ValueError

match parameter:
case _ParameterTree():
match parameter.kind:
case _ParameterTreeKind.SPACE:
parameters[name] = parameter.tree
case _ParameterTreeKind.KINDS:

def _kinds_fn(_parameter: Tensor) -> Tensor:
return _parameter[kinds]

parameters[name] = optree.tree_map(
_kinds_fn,
parameter.tree,
)
case _:
message = f"""
parameter `kind` is {parameter.kind}. If `kinds` is `None` and a parameter is
an instance of `ParameterTree`, `kind` must be `ParameterTreeKind.SPACE` or
`ParameterTreeKind.KINDS`.
""".replace("\n", " ")

raise ValueError(message)
case Tensor():
match parameter.ndim:
case 0:
parameters[name] = parameter
case 2:
parameters[name] = parameter[kinds]
case _:
message = f"""
parameter `ndim` is `{parameter.ndim}`. If `kinds` is not `None` and a
parameter is an instance of `Tensor`, `ndim` must be in `0`, `1`, or `2`.
""".replace("\n", " ")

raise ValueError(message)
case _:
parameters[name] = parameter

return parameters
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Callable

from torch import Tensor


def _long_range_interaction(
fn: Callable[..., Tensor],
displacement_fn: Callable[[Tensor, Tensor], Tensor],
):
raise NotImplementedError
Loading

0 comments on commit 06197f5

Please sign in to comment.