-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Henry Isaacson
committed
Apr 22, 2024
1 parent
4b15a0e
commit 06197f5
Showing
19 changed files
with
987 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import dataclasses | ||
from typing import List, Tuple, Type, TypeVar | ||
|
||
import optree | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
def _dataclass(cls: Type[T]): | ||
def _set(self: dataclasses.dataclass, **kwargs): | ||
return dataclasses.replace(self, **kwargs) | ||
|
||
cls.set = _set | ||
|
||
dataclass_cls = dataclasses.dataclass(frozen=True)(cls) | ||
|
||
data_fields, metadata_fields = [], [] | ||
|
||
for name, kind in dataclass_cls.__dataclass_fields__.items(): | ||
if not kind.metadata.get("static", False): | ||
data_fields = [*data_fields, name] | ||
else: | ||
metadata_fields = [*metadata_fields, name] | ||
|
||
def _iterate_cls(_x) -> List[Tuple]: | ||
data_iterable = [] | ||
|
||
for k in data_fields: | ||
data_iterable.append(getattr(_x, k)) | ||
|
||
metadata_iterable = [] | ||
|
||
for k in metadata_fields: | ||
metadata_iterable.append(getattr(_x, k)) | ||
|
||
return [data_iterable, metadata_iterable] | ||
|
||
def _iterable_to_cls(meta, data): | ||
meta_args = tuple(zip(metadata_fields, meta)) | ||
data_args = tuple(zip(data_fields, data)) | ||
kwargs = dict(meta_args + data_args) | ||
|
||
return dataclass_cls(**kwargs) | ||
|
||
optree.register_pytree_node( | ||
dataclass_cls, | ||
_iterate_cls, | ||
_iterable_to_cls, | ||
"prescient.func", | ||
) | ||
|
||
return dataclass_cls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import Iterable, Optional, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _safe_sum( | ||
x: Tensor, | ||
dim: Optional[Union[Iterable[int], int]] = None, | ||
keepdim: bool = False, | ||
): | ||
match x: | ||
case _ if x.is_complex(): | ||
promoted_dtype = torch.complex128 | ||
case _ if x.is_floating_point(): | ||
promoted_dtype = torch.float64 | ||
case _: | ||
promoted_dtype = torch.int64 | ||
|
||
summation = torch.sum(x, dim=dim, dtype=promoted_dtype, keepdim=keepdim) | ||
|
||
return summation.to(dtype=x.dtype) |
26 changes: 26 additions & 0 deletions
26
src/beignet/func/_molecular_dynamics/__zero_diagonal_mask.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _zero_diagonal_mask(x: Tensor) -> Tensor: | ||
"""Sets the diagonal of a matrix to zero.""" | ||
if x.shape[0] != x.shape[1]: | ||
raise ValueError( | ||
f"Diagonal mask can only mask square matrices. Found {x.shape[0]}x{x.shape[1]}." | ||
) | ||
|
||
if len(x.shape) > 3: | ||
raise ValueError( | ||
f"Diagonal mask can only mask rank-2 or rank-3 tensors. Found {len(x.shape)}." | ||
) | ||
|
||
n = x.shape[0] | ||
|
||
x = torch.nan_to_num(x) | ||
|
||
mask = 1.0 - torch.eye(n, device=x.device, dtype=x.dtype) | ||
|
||
if len(x.shape) == 3: | ||
mask = torch.reshape(mask, [n, n, 1]) | ||
|
||
return x * mask |
10 changes: 10 additions & 0 deletions
10
src/beignet/func/_molecular_dynamics/_interact/__angle_interaction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from typing import Callable | ||
|
||
from torch import Tensor | ||
|
||
|
||
def _angle_interaction( | ||
fn: Callable[..., Tensor], | ||
displacement_fn: Callable[[Tensor, Tensor], Tensor], | ||
): | ||
raise NotImplementedError |
81 changes: 81 additions & 0 deletions
81
src/beignet/func/_molecular_dynamics/_interact/__bond_interaction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import functools | ||
from typing import Callable, Optional | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from ..__safe_sum import _safe_sum | ||
from .__merge_dictionaries import _merge_dictionaries | ||
from .__to_bond_kind_parameters import _to_bond_kind_parameters | ||
|
||
|
||
def _bond_interaction( | ||
fn: Callable[..., Tensor], | ||
displacement_fn: Callable[[Tensor, Tensor], Tensor], | ||
static_bonds: Optional[Tensor] = None, | ||
static_kinds: Optional[Tensor] = None, | ||
ignore_unused_parameters: bool = False, | ||
**static_kwargs, | ||
) -> Callable[..., Tensor]: | ||
merge_dictionaries_fn = functools.partial( | ||
_merge_dictionaries, | ||
ignore_unused_parameters=ignore_unused_parameters, | ||
) | ||
|
||
def mapped_fn( | ||
positions: Tensor, | ||
bonds: Optional[Tensor] = None, | ||
kinds: Optional[Tensor] = None, | ||
**kwargs, | ||
) -> Tensor: | ||
accumulator = torch.tensor( | ||
0.0, | ||
device=positions.device, | ||
dtype=positions.dtype, | ||
) | ||
|
||
distance_fn = functools.partial(displacement_fn, **kwargs) | ||
|
||
distance_fn = torch.func.vmap(distance_fn, 0, 0) | ||
|
||
if bonds is not None: | ||
parameters = merge_dictionaries_fn(static_kwargs, kwargs) | ||
|
||
for name, parameter in parameters.items(): | ||
if kinds is not None: | ||
parameters[name] = _to_bond_kind_parameters( | ||
parameter, | ||
kinds, | ||
) | ||
|
||
interactions = distance_fn( | ||
positions[bonds[:, 0]], | ||
positions[bonds[:, 1]], | ||
) | ||
|
||
interactions = _safe_sum(fn(interactions, **parameters)) | ||
|
||
accumulator = accumulator + interactions | ||
|
||
if static_bonds is not None: | ||
parameters = merge_dictionaries_fn(static_kwargs, kwargs) | ||
|
||
for name, parameter in parameters.items(): | ||
if static_kinds is not None: | ||
parameters[name] = _to_bond_kind_parameters( | ||
parameter, | ||
static_kinds, | ||
) | ||
|
||
interactions = distance_fn( | ||
positions[static_bonds[:, 0]], | ||
positions[static_bonds[:, 1]], | ||
) | ||
|
||
interactions = _safe_sum(fn(interactions, **parameters)) | ||
|
||
accumulator = accumulator + interactions | ||
|
||
return accumulator | ||
|
||
return mapped_fn |
10 changes: 10 additions & 0 deletions
10
src/beignet/func/_molecular_dynamics/_interact/__dihedral_interaction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from typing import Callable | ||
|
||
from torch import Tensor | ||
|
||
|
||
def _dihedral_interaction( | ||
fn: Callable[..., Tensor], | ||
displacement_fn: Callable[[Tensor, Tensor], Tensor], | ||
): | ||
raise NotImplementedError |
44 changes: 44 additions & 0 deletions
44
src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_neighbor_list_parameters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Callable, Dict | ||
|
||
from torch import Tensor | ||
|
||
from .._partition.__neighbor_list_format import _NeighborListFormat | ||
from .__to_neighbor_list_kind_parameters import ( | ||
_to_neighbor_list_kind_parameters, | ||
) | ||
from .__to_neighbor_list_matrix_parameters import ( | ||
_to_neighbor_list_matrix_parameters, | ||
) | ||
|
||
|
||
def _kwargs_to_neighbor_list_parameters( | ||
format: _NeighborListFormat, | ||
indexes: Tensor, | ||
species: Tensor, | ||
kwargs: Dict[str, Tensor], | ||
combinators: Dict[str, Callable], | ||
) -> Dict[str, Tensor]: | ||
parameters = {} | ||
|
||
for name, parameter in kwargs.items(): | ||
if species is None or (isinstance(parameter, Tensor) and parameter.ndim == 1): | ||
combinator = combinators.get(name, lambda x, y: 0.5 * (x + y)) | ||
|
||
parameters[name] = _to_neighbor_list_matrix_parameters( | ||
format, | ||
indexes, | ||
parameter, | ||
combinator, | ||
) | ||
else: | ||
if name in combinators: | ||
raise ValueError | ||
|
||
parameters[name] = _to_neighbor_list_kind_parameters( | ||
format, | ||
indexes, | ||
species, | ||
parameter, | ||
) | ||
|
||
return parameters |
118 changes: 118 additions & 0 deletions
118
src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_pair_parameters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from typing import Callable, Dict, Union | ||
|
||
import optree | ||
from optree import PyTree | ||
from torch import Tensor | ||
|
||
from .__parameter_tree import _ParameterTree | ||
from .__parameter_tree_kind import _ParameterTreeKind | ||
|
||
|
||
def _kwargs_to_pair_parameters( | ||
kwargs: Dict[str, Union[_ParameterTree, Tensor, float, PyTree]], | ||
combinators: Dict[str, Callable], | ||
kinds: Tensor | None = None, | ||
) -> Dict[str, Tensor]: | ||
parameters = {} | ||
|
||
for name, parameter in kwargs.items(): | ||
if kinds is None: | ||
|
||
def _combinator_fn(x: Tensor, y: Tensor) -> Tensor: | ||
return (x + y) * 0.5 | ||
|
||
combinator = combinators.get(name, _combinator_fn) | ||
|
||
match parameter: | ||
case _ParameterTree(): | ||
match parameter.kind: | ||
case _ParameterTreeKind.BOND | _ParameterTreeKind.SPACE: | ||
parameters[name] = parameter.tree | ||
case _ParameterTreeKind.PARTICLE: | ||
|
||
def _particle_fn(_parameter: Tensor) -> Tensor: | ||
return combinator( | ||
_parameter[:, None, ...], | ||
_parameter[None, :, ...], | ||
) | ||
|
||
parameters[name] = optree.tree_map( | ||
_particle_fn, | ||
parameter.tree, | ||
) | ||
case _: | ||
message = f""" | ||
parameter `kind` is `{parameter.kind}`. If `kinds` is `None` and a parameter is | ||
an instance of `ParameterTree`, `kind` must be `ParameterTreeKind.BOND`, | ||
`ParameterTreeKind.PARTICLE`, or `ParameterTreeKind.SPACE`. | ||
""".replace("\n", " ") | ||
|
||
raise ValueError(message) | ||
case Tensor(): | ||
match parameter.ndim: | ||
case 0 | 2: | ||
parameters[name] = parameter | ||
case 1: | ||
parameters[name] = combinator( | ||
parameter[:, None], | ||
parameter[None, :], | ||
) | ||
case _: | ||
message = f""" | ||
parameter `ndim` is `{parameter.ndim}`. If `kinds` is `None` and a parameter is | ||
an instance of `Tensor`, `ndim` must be in `0`, `1`, or `2`. | ||
""".replace("\n", " ") | ||
|
||
raise ValueError(message) | ||
case float() | int(): | ||
parameters[name] = parameter | ||
case _: | ||
message = f""" | ||
parameter `type` is {type(parameter)}. If `kinds` is `None`, a parameter must | ||
be an instance of `ParameterTree`, `Tensor`, `float`, or `int`. | ||
""".replace("\n", " ") | ||
|
||
raise ValueError(message) | ||
else: | ||
if name in combinators: | ||
raise ValueError | ||
|
||
match parameter: | ||
case _ParameterTree(): | ||
match parameter.kind: | ||
case _ParameterTreeKind.SPACE: | ||
parameters[name] = parameter.tree | ||
case _ParameterTreeKind.KINDS: | ||
|
||
def _kinds_fn(_parameter: Tensor) -> Tensor: | ||
return _parameter[kinds] | ||
|
||
parameters[name] = optree.tree_map( | ||
_kinds_fn, | ||
parameter.tree, | ||
) | ||
case _: | ||
message = f""" | ||
parameter `kind` is {parameter.kind}. If `kinds` is `None` and a parameter is | ||
an instance of `ParameterTree`, `kind` must be `ParameterTreeKind.SPACE` or | ||
`ParameterTreeKind.KINDS`. | ||
""".replace("\n", " ") | ||
|
||
raise ValueError(message) | ||
case Tensor(): | ||
match parameter.ndim: | ||
case 0: | ||
parameters[name] = parameter | ||
case 2: | ||
parameters[name] = parameter[kinds] | ||
case _: | ||
message = f""" | ||
parameter `ndim` is `{parameter.ndim}`. If `kinds` is not `None` and a | ||
parameter is an instance of `Tensor`, `ndim` must be in `0`, `1`, or `2`. | ||
""".replace("\n", " ") | ||
|
||
raise ValueError(message) | ||
case _: | ||
parameters[name] = parameter | ||
|
||
return parameters |
10 changes: 10 additions & 0 deletions
10
src/beignet/func/_molecular_dynamics/_interact/__long_range_interaction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from typing import Callable | ||
|
||
from torch import Tensor | ||
|
||
|
||
def _long_range_interaction( | ||
fn: Callable[..., Tensor], | ||
displacement_fn: Callable[[Tensor, Tensor], Tensor], | ||
): | ||
raise NotImplementedError |
Oops, something went wrong.