From 06197f511a3f75cb88f0cd514a77db8a9e46d84f Mon Sep 17 00:00:00 2001 From: Henry Isaacson Date: Mon, 22 Apr 2024 11:23:12 -0400 Subject: [PATCH] interact --- .../func/_molecular_dynamics/__dataclass.py | 52 ++++++ .../func/_molecular_dynamics/__safe_sum.py | 22 +++ .../__zero_diagonal_mask.py | 26 +++ .../_interact/__angle_interaction.py | 10 ++ .../_interact/__bond_interaction.py | 81 +++++++++ .../_interact/__dihedral_interaction.py | 10 ++ .../__kwargs_to_neighbor_list_parameters.py | 44 +++++ .../_interact/__kwargs_to_pair_parameters.py | 118 +++++++++++++ .../_interact/__long_range_interaction.py | 10 ++ .../_interact/__map_product.py | 28 +++ .../_interact/__merge_dictionaries.py | 20 +++ .../_interact/__mesh_interaction.py | 10 ++ .../_interact/__neighbor_list_interaction.py | 120 +++++++++++++ .../_interact/__pair_interaction.py | 166 ++++++++++++++++++ .../_interact/__parameter_tree.py | 12 ++ .../_interact/__parameter_tree_kind.py | 8 + .../_interact/__to_bond_kind_parameters.py | 42 +++++ .../__to_neighbor_list_kind_parameters.py | 85 +++++++++ .../__to_neighbor_list_matrix_parameters.py | 123 +++++++++++++ 19 files changed, 987 insertions(+) create mode 100644 src/beignet/func/_molecular_dynamics/__dataclass.py create mode 100644 src/beignet/func/_molecular_dynamics/__safe_sum.py create mode 100644 src/beignet/func/_molecular_dynamics/__zero_diagonal_mask.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__angle_interaction.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__bond_interaction.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__dihedral_interaction.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_neighbor_list_parameters.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_pair_parameters.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__long_range_interaction.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__map_product.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__merge_dictionaries.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__mesh_interaction.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__neighbor_list_interaction.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__pair_interaction.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__parameter_tree.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__parameter_tree_kind.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__to_bond_kind_parameters.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_kind_parameters.py create mode 100644 src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_matrix_parameters.py diff --git a/src/beignet/func/_molecular_dynamics/__dataclass.py b/src/beignet/func/_molecular_dynamics/__dataclass.py new file mode 100644 index 0000000000..5a6af65cf2 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/__dataclass.py @@ -0,0 +1,52 @@ +import dataclasses +from typing import List, Tuple, Type, TypeVar + +import optree + +T = TypeVar("T") + + +def _dataclass(cls: Type[T]): + def _set(self: dataclasses.dataclass, **kwargs): + return dataclasses.replace(self, **kwargs) + + cls.set = _set + + dataclass_cls = dataclasses.dataclass(frozen=True)(cls) + + data_fields, metadata_fields = [], [] + + for name, kind in dataclass_cls.__dataclass_fields__.items(): + if not kind.metadata.get("static", False): + data_fields = [*data_fields, name] + else: + metadata_fields = [*metadata_fields, name] + + def _iterate_cls(_x) -> List[Tuple]: + data_iterable = [] + + for k in data_fields: + data_iterable.append(getattr(_x, k)) + + metadata_iterable = [] + + for k in metadata_fields: + metadata_iterable.append(getattr(_x, k)) + + return [data_iterable, metadata_iterable] + + def _iterable_to_cls(meta, data): + meta_args = tuple(zip(metadata_fields, meta)) + data_args = tuple(zip(data_fields, data)) + kwargs = dict(meta_args + data_args) + + return dataclass_cls(**kwargs) + + optree.register_pytree_node( + dataclass_cls, + _iterate_cls, + _iterable_to_cls, + "prescient.func", + ) + + return dataclass_cls diff --git a/src/beignet/func/_molecular_dynamics/__safe_sum.py b/src/beignet/func/_molecular_dynamics/__safe_sum.py new file mode 100644 index 0000000000..5fe4f8c77f --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/__safe_sum.py @@ -0,0 +1,22 @@ +from typing import Iterable, Optional, Union + +import torch +from torch import Tensor + + +def _safe_sum( + x: Tensor, + dim: Optional[Union[Iterable[int], int]] = None, + keepdim: bool = False, +): + match x: + case _ if x.is_complex(): + promoted_dtype = torch.complex128 + case _ if x.is_floating_point(): + promoted_dtype = torch.float64 + case _: + promoted_dtype = torch.int64 + + summation = torch.sum(x, dim=dim, dtype=promoted_dtype, keepdim=keepdim) + + return summation.to(dtype=x.dtype) diff --git a/src/beignet/func/_molecular_dynamics/__zero_diagonal_mask.py b/src/beignet/func/_molecular_dynamics/__zero_diagonal_mask.py new file mode 100644 index 0000000000..0742a72aa0 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/__zero_diagonal_mask.py @@ -0,0 +1,26 @@ +import torch +from torch import Tensor + + +def _zero_diagonal_mask(x: Tensor) -> Tensor: + """Sets the diagonal of a matrix to zero.""" + if x.shape[0] != x.shape[1]: + raise ValueError( + f"Diagonal mask can only mask square matrices. Found {x.shape[0]}x{x.shape[1]}." + ) + + if len(x.shape) > 3: + raise ValueError( + f"Diagonal mask can only mask rank-2 or rank-3 tensors. Found {len(x.shape)}." + ) + + n = x.shape[0] + + x = torch.nan_to_num(x) + + mask = 1.0 - torch.eye(n, device=x.device, dtype=x.dtype) + + if len(x.shape) == 3: + mask = torch.reshape(mask, [n, n, 1]) + + return x * mask diff --git a/src/beignet/func/_molecular_dynamics/_interact/__angle_interaction.py b/src/beignet/func/_molecular_dynamics/_interact/__angle_interaction.py new file mode 100644 index 0000000000..89d4b157c9 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__angle_interaction.py @@ -0,0 +1,10 @@ +from typing import Callable + +from torch import Tensor + + +def _angle_interaction( + fn: Callable[..., Tensor], + displacement_fn: Callable[[Tensor, Tensor], Tensor], +): + raise NotImplementedError diff --git a/src/beignet/func/_molecular_dynamics/_interact/__bond_interaction.py b/src/beignet/func/_molecular_dynamics/_interact/__bond_interaction.py new file mode 100644 index 0000000000..4c438b7f53 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__bond_interaction.py @@ -0,0 +1,81 @@ +import functools +from typing import Callable, Optional + +import torch +from torch import Tensor + +from ..__safe_sum import _safe_sum +from .__merge_dictionaries import _merge_dictionaries +from .__to_bond_kind_parameters import _to_bond_kind_parameters + + +def _bond_interaction( + fn: Callable[..., Tensor], + displacement_fn: Callable[[Tensor, Tensor], Tensor], + static_bonds: Optional[Tensor] = None, + static_kinds: Optional[Tensor] = None, + ignore_unused_parameters: bool = False, + **static_kwargs, +) -> Callable[..., Tensor]: + merge_dictionaries_fn = functools.partial( + _merge_dictionaries, + ignore_unused_parameters=ignore_unused_parameters, + ) + + def mapped_fn( + positions: Tensor, + bonds: Optional[Tensor] = None, + kinds: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + accumulator = torch.tensor( + 0.0, + device=positions.device, + dtype=positions.dtype, + ) + + distance_fn = functools.partial(displacement_fn, **kwargs) + + distance_fn = torch.func.vmap(distance_fn, 0, 0) + + if bonds is not None: + parameters = merge_dictionaries_fn(static_kwargs, kwargs) + + for name, parameter in parameters.items(): + if kinds is not None: + parameters[name] = _to_bond_kind_parameters( + parameter, + kinds, + ) + + interactions = distance_fn( + positions[bonds[:, 0]], + positions[bonds[:, 1]], + ) + + interactions = _safe_sum(fn(interactions, **parameters)) + + accumulator = accumulator + interactions + + if static_bonds is not None: + parameters = merge_dictionaries_fn(static_kwargs, kwargs) + + for name, parameter in parameters.items(): + if static_kinds is not None: + parameters[name] = _to_bond_kind_parameters( + parameter, + static_kinds, + ) + + interactions = distance_fn( + positions[static_bonds[:, 0]], + positions[static_bonds[:, 1]], + ) + + interactions = _safe_sum(fn(interactions, **parameters)) + + accumulator = accumulator + interactions + + return accumulator + + return mapped_fn diff --git a/src/beignet/func/_molecular_dynamics/_interact/__dihedral_interaction.py b/src/beignet/func/_molecular_dynamics/_interact/__dihedral_interaction.py new file mode 100644 index 0000000000..4148399cf7 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__dihedral_interaction.py @@ -0,0 +1,10 @@ +from typing import Callable + +from torch import Tensor + + +def _dihedral_interaction( + fn: Callable[..., Tensor], + displacement_fn: Callable[[Tensor, Tensor], Tensor], +): + raise NotImplementedError diff --git a/src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_neighbor_list_parameters.py b/src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_neighbor_list_parameters.py new file mode 100644 index 0000000000..affb6e2f80 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_neighbor_list_parameters.py @@ -0,0 +1,44 @@ +from typing import Callable, Dict + +from torch import Tensor + +from .._partition.__neighbor_list_format import _NeighborListFormat +from .__to_neighbor_list_kind_parameters import ( + _to_neighbor_list_kind_parameters, +) +from .__to_neighbor_list_matrix_parameters import ( + _to_neighbor_list_matrix_parameters, +) + + +def _kwargs_to_neighbor_list_parameters( + format: _NeighborListFormat, + indexes: Tensor, + species: Tensor, + kwargs: Dict[str, Tensor], + combinators: Dict[str, Callable], +) -> Dict[str, Tensor]: + parameters = {} + + for name, parameter in kwargs.items(): + if species is None or (isinstance(parameter, Tensor) and parameter.ndim == 1): + combinator = combinators.get(name, lambda x, y: 0.5 * (x + y)) + + parameters[name] = _to_neighbor_list_matrix_parameters( + format, + indexes, + parameter, + combinator, + ) + else: + if name in combinators: + raise ValueError + + parameters[name] = _to_neighbor_list_kind_parameters( + format, + indexes, + species, + parameter, + ) + + return parameters diff --git a/src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_pair_parameters.py b/src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_pair_parameters.py new file mode 100644 index 0000000000..826b35dc59 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__kwargs_to_pair_parameters.py @@ -0,0 +1,118 @@ +from typing import Callable, Dict, Union + +import optree +from optree import PyTree +from torch import Tensor + +from .__parameter_tree import _ParameterTree +from .__parameter_tree_kind import _ParameterTreeKind + + +def _kwargs_to_pair_parameters( + kwargs: Dict[str, Union[_ParameterTree, Tensor, float, PyTree]], + combinators: Dict[str, Callable], + kinds: Tensor | None = None, +) -> Dict[str, Tensor]: + parameters = {} + + for name, parameter in kwargs.items(): + if kinds is None: + + def _combinator_fn(x: Tensor, y: Tensor) -> Tensor: + return (x + y) * 0.5 + + combinator = combinators.get(name, _combinator_fn) + + match parameter: + case _ParameterTree(): + match parameter.kind: + case _ParameterTreeKind.BOND | _ParameterTreeKind.SPACE: + parameters[name] = parameter.tree + case _ParameterTreeKind.PARTICLE: + + def _particle_fn(_parameter: Tensor) -> Tensor: + return combinator( + _parameter[:, None, ...], + _parameter[None, :, ...], + ) + + parameters[name] = optree.tree_map( + _particle_fn, + parameter.tree, + ) + case _: + message = f""" +parameter `kind` is `{parameter.kind}`. If `kinds` is `None` and a parameter is +an instance of `ParameterTree`, `kind` must be `ParameterTreeKind.BOND`, +`ParameterTreeKind.PARTICLE`, or `ParameterTreeKind.SPACE`. + """.replace("\n", " ") + + raise ValueError(message) + case Tensor(): + match parameter.ndim: + case 0 | 2: + parameters[name] = parameter + case 1: + parameters[name] = combinator( + parameter[:, None], + parameter[None, :], + ) + case _: + message = f""" +parameter `ndim` is `{parameter.ndim}`. If `kinds` is `None` and a parameter is +an instance of `Tensor`, `ndim` must be in `0`, `1`, or `2`. + """.replace("\n", " ") + + raise ValueError(message) + case float() | int(): + parameters[name] = parameter + case _: + message = f""" +parameter `type` is {type(parameter)}. If `kinds` is `None`, a parameter must +be an instance of `ParameterTree`, `Tensor`, `float`, or `int`. + """.replace("\n", " ") + + raise ValueError(message) + else: + if name in combinators: + raise ValueError + + match parameter: + case _ParameterTree(): + match parameter.kind: + case _ParameterTreeKind.SPACE: + parameters[name] = parameter.tree + case _ParameterTreeKind.KINDS: + + def _kinds_fn(_parameter: Tensor) -> Tensor: + return _parameter[kinds] + + parameters[name] = optree.tree_map( + _kinds_fn, + parameter.tree, + ) + case _: + message = f""" +parameter `kind` is {parameter.kind}. If `kinds` is `None` and a parameter is +an instance of `ParameterTree`, `kind` must be `ParameterTreeKind.SPACE` or +`ParameterTreeKind.KINDS`. + """.replace("\n", " ") + + raise ValueError(message) + case Tensor(): + match parameter.ndim: + case 0: + parameters[name] = parameter + case 2: + parameters[name] = parameter[kinds] + case _: + message = f""" +parameter `ndim` is `{parameter.ndim}`. If `kinds` is not `None` and a +parameter is an instance of `Tensor`, `ndim` must be in `0`, `1`, or `2`. + """.replace("\n", " ") + + raise ValueError(message) + case _: + parameters[name] = parameter + + return parameters diff --git a/src/beignet/func/_molecular_dynamics/_interact/__long_range_interaction.py b/src/beignet/func/_molecular_dynamics/_interact/__long_range_interaction.py new file mode 100644 index 0000000000..f714d0050a --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__long_range_interaction.py @@ -0,0 +1,10 @@ +from typing import Callable + +from torch import Tensor + + +def _long_range_interaction( + fn: Callable[..., Tensor], + displacement_fn: Callable[[Tensor, Tensor], Tensor], +): + raise NotImplementedError diff --git a/src/beignet/func/_molecular_dynamics/_interact/__map_product.py b/src/beignet/func/_molecular_dynamics/_interact/__map_product.py new file mode 100644 index 0000000000..027a8cdc44 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__map_product.py @@ -0,0 +1,28 @@ +from typing import Callable + +import torch +from torch import Tensor + + +def _map_product( + fn: Callable[[Tensor, Tensor], Tensor], +) -> Callable[[Tensor, Tensor], Tensor]: + """ + + Parameters + ---------- + fn + + Returns + ------- + + """ + return torch.func.vmap( + torch.func.vmap( + fn, + (0, None), + 0, + ), + (None, 0), + 0, + ) diff --git a/src/beignet/func/_molecular_dynamics/_interact/__merge_dictionaries.py b/src/beignet/func/_molecular_dynamics/_interact/__merge_dictionaries.py new file mode 100644 index 0000000000..dced9de4c0 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__merge_dictionaries.py @@ -0,0 +1,20 @@ +from typing import Dict + + +def _merge_dictionaries( + this: Dict, + that: Dict, + ignore_unused_parameters: bool = False, +): + if not ignore_unused_parameters: + return {**this, **that} + + merged_dictionaries = dict(this) + + for this_key in merged_dictionaries.keys(): + that_value = that.get(this_key) + + if that_value is not None: + merged_dictionaries[this_key] = that_value + + return merged_dictionaries diff --git a/src/beignet/func/_molecular_dynamics/_interact/__mesh_interaction.py b/src/beignet/func/_molecular_dynamics/_interact/__mesh_interaction.py new file mode 100644 index 0000000000..a2470cab52 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__mesh_interaction.py @@ -0,0 +1,10 @@ +from typing import Callable + +from torch import Tensor + + +def _mesh_interaction( + fn: Callable[..., Tensor], + displacement_fn: Callable[[Tensor, Tensor], Tensor], +): + raise NotImplementedError diff --git a/src/beignet/func/_molecular_dynamics/_interact/__neighbor_list_interaction.py b/src/beignet/func/_molecular_dynamics/_interact/__neighbor_list_interaction.py new file mode 100644 index 0000000000..33c03ad45d --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__neighbor_list_interaction.py @@ -0,0 +1,120 @@ +import functools +from typing import Callable, Optional, Tuple + +import torch +from torch import Tensor + +from ..__safe_sum import _safe_sum +from .._partition.__is_neighbor_list_sparse import _is_neighbor_list_sparse +from .._partition.__map_bond import _map_bond +from .._partition.__map_neighbor import _map_neighbor +from .._partition.__neighbor_list import _NeighborList +from .._partition.__neighbor_list_format import _NeighborListFormat +from .._partition.__segment_sum import _segment_sum +from .__kwargs_to_neighbor_list_parameters import ( + _kwargs_to_neighbor_list_parameters, +) +from .__merge_dictionaries import _merge_dictionaries + + +def _neighbor_list_interaction( + fn: Callable[..., Tensor], + displacement_fn: Callable[[Tensor, Tensor], Tensor], + kinds: Tensor | None = None, + dim: Optional[Tuple[int, ...]] = None, + ignore_unused_parameters: bool = False, + **kwargs, +) -> Callable[..., Tensor]: + parameters, combinators = {}, {} + + for name, parameter in kwargs.items(): + if isinstance(parameter, Callable): + combinators[name] = parameter + elif isinstance(parameter, tuple) and isinstance(parameter[0], Callable): + assert len(parameter) == 2 + + combinators[name], parameters[name] = parameter[0], parameter[1] + else: + parameters[name] = parameter + + merged_dictionaries = functools.partial( + _merge_dictionaries, + ignore_unused_parameters=ignore_unused_parameters, + ) + + def mapped_fn( + positions: Tensor, + neighbor_list: _NeighborList, + **dynamic_kwargs, + ) -> Tensor: + distance_fn = functools.partial(displacement_fn, **dynamic_kwargs) + + _kinds = dynamic_kwargs.get("kinds", kinds) + + normalization = 2.0 + + if _is_neighbor_list_sparse(neighbor_list.format): + distances = _map_bond(distance_fn)( + positions[neighbor_list.indexes[0]], + positions[neighbor_list.indexes[1]], + ) + + mask = torch.less(neighbor_list.indexes[0], positions.shape[0]) + + if neighbor_list.format is _NeighborListFormat.ORDERED_SPARSE: + normalization = 1.0 + else: + distances = _map_neighbor(distance_fn)( + positions, + positions[neighbor_list.indexes], + ) + + mask = torch.less(neighbor_list.indexes, positions.shape[0]) + + out = fn( + distances, + **_kwargs_to_neighbor_list_parameters( + neighbor_list.format, + neighbor_list.indexes, + _kinds, + merged_dictionaries( + parameters, + dynamic_kwargs, + ), + combinators, + ), + ) + + if out.ndim > mask.ndim: + mask = torch.reshape( + mask, + [*mask.shape, *([1] * (out.ndim - mask.ndim))], + ) + + out = torch.multiply(out, mask) + + if dim is None: + return torch.divide(_safe_sum(out), normalization) + + if 0 in dim and 1 not in dim: + raise ValueError + + if not _is_neighbor_list_sparse(neighbor_list.format): + return torch.divide(_safe_sum(out, dim=dim), normalization) + + if 0 in dim: + return _safe_sum(out, dim=[0, *[a - 1 for a in dim if a > 1]]) + + if neighbor_list.format is _NeighborListFormat.ORDERED_SPARSE: + raise ValueError + + return torch.divide( + _segment_sum( + _safe_sum(out, dim=[a - 1 for a in dim if a > 1]), + neighbor_list.indexes[0], + positions.shape[0], + ), + normalization, + ) + + return mapped_fn diff --git a/src/beignet/func/_molecular_dynamics/_interact/__pair_interaction.py b/src/beignet/func/_molecular_dynamics/_interact/__pair_interaction.py new file mode 100644 index 0000000000..9d7a55e5fa --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__pair_interaction.py @@ -0,0 +1,166 @@ +import functools +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import Tensor + +from ..__safe_sum import _safe_sum +from ..__zero_diagonal_mask import _zero_diagonal_mask +from .__kwargs_to_pair_parameters import ( + _kwargs_to_pair_parameters, +) +from .__map_product import _map_product +from .__merge_dictionaries import _merge_dictionaries + + +def _pair_interaction( + fn: Callable[..., Tensor], + displacement_fn: Callable[[Tensor, Tensor], Tensor], + kinds: Optional[Union[int, Tensor]] = None, + dim: Optional[Tuple[int, ...]] = None, + keepdim: bool = False, + ignore_unused_parameters: bool = False, + **kwargs, +) -> Callable[..., Tensor]: + parameters, combinators = {}, {} + + for name, parameter in kwargs.items(): + if isinstance(parameter, Callable): + combinators[name] = parameter + elif isinstance(parameter, tuple) and isinstance(parameter[0], Callable): + assert len(parameter) == 2 + + combinators[name], parameters[name] = parameter[0], parameter[1] + else: + parameters[name] = parameter + + merge_dicts = functools.partial( + _merge_dictionaries, + ignore_unused_parameters=ignore_unused_parameters, + ) + + if kinds is None: + + def mapped_fn(_position: Tensor, **_dynamic_kwargs) -> Tensor: + distance_fn = functools.partial(displacement_fn, **_dynamic_kwargs) + + distances = _map_product(distance_fn)(_position, _position) + + dictionaries = merge_dicts(parameters, _dynamic_kwargs) + + to_parameters = _kwargs_to_pair_parameters( + dictionaries, + combinators, + ) + + u = fn(distances, **to_parameters) + + u = _zero_diagonal_mask(u) + + u = _safe_sum(u, dim=dim, keepdim=keepdim) + + return u * 0.5 + + return mapped_fn + + if isinstance(kinds, Tensor): + if not isinstance(kinds, Tensor) or kinds.is_floating_point(): + raise ValueError + + kinds_count = int(torch.max(kinds)) + + if dim is not None or keepdim: + raise ValueError + + def mapped_fn(_position: Tensor, **_dynamic_kwargs): + u = torch.tensor(0.0, dtype=torch.float32) + + distance_fn = functools.partial(displacement_fn, **_dynamic_kwargs) + + distance_fn = _map_product(distance_fn) + + for m in range(kinds_count + 1): + for n in range(m, kinds_count + 1): + distance = distance_fn( + _position[kinds == m], + _position[kinds == n], + ) + + _kwargs = merge_dicts(parameters, _dynamic_kwargs) + + s_kwargs = _kwargs_to_pair_parameters(_kwargs, combinators, (m, n)) + + u = fn(distance, **s_kwargs) + + if m == n: + u = _zero_diagonal_mask(u) + + u = _safe_sum(u) + + u = u + u * 0.5 + else: + y = _safe_sum(u) + + u = u + y + + return u + + return mapped_fn + + if isinstance(kinds, int): + kinds_count = kinds + + def mapped_fn(_position: Tensor, _kinds: Tensor, **_dynamic_kwargs): + if not isinstance(_kinds, Tensor) or _kinds.is_floating_point(): + raise ValueError + + u = torch.tensor(0.0, dtype=torch.float32) + + n = _position.shape[0] + + distance_fn = functools.partial(displacement_fn, **_dynamic_kwargs) + + distance_fn = _map_product(distance_fn) + + _kwargs = merge_dicts(parameters, _dynamic_kwargs) + + distance = distance_fn(_position, _position) + + for m in range(kinds_count): + for n in range(kinds_count): + a = torch.reshape( + _kinds == m, + [ + n, + ], + ) + b = torch.reshape( + _kinds == n, + [ + n, + ], + ) + + a = a.to(dtype=_position.dtype)[:, None] + b = b.to(dtype=_position.dtype)[None, :] + + mask = a * b + + if m == n: + mask = _zero_diagonal_mask(mask) * mask + + to_parameters = _kwargs_to_pair_parameters( + _kwargs, combinators, (m, n) + ) + + y = fn(distance, **to_parameters) * mask + + y = _safe_sum(y, dim=dim, keepdim=keepdim) + + u = u + y + + return u / 2.0 + + return mapped_fn + + raise ValueError diff --git a/src/beignet/func/_molecular_dynamics/_interact/__parameter_tree.py b/src/beignet/func/_molecular_dynamics/_interact/__parameter_tree.py new file mode 100644 index 0000000000..811ad80b10 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__parameter_tree.py @@ -0,0 +1,12 @@ +import dataclasses + +from optree import PyTree + +from ..__dataclass import _dataclass +from .__parameter_tree_kind import _ParameterTreeKind + + +@_dataclass +class _ParameterTree: + tree: PyTree + kind: _ParameterTreeKind = dataclasses.field(metadata={"static": True}) diff --git a/src/beignet/func/_molecular_dynamics/_interact/__parameter_tree_kind.py b/src/beignet/func/_molecular_dynamics/_interact/__parameter_tree_kind.py new file mode 100644 index 0000000000..8c08b29f70 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__parameter_tree_kind.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class _ParameterTreeKind(Enum): + BOND = 0 + KINDS = 1 + PARTICLE = 2 + SPACE = 3 diff --git a/src/beignet/func/_molecular_dynamics/_interact/__to_bond_kind_parameters.py b/src/beignet/func/_molecular_dynamics/_interact/__to_bond_kind_parameters.py new file mode 100644 index 0000000000..77788ddd1b --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__to_bond_kind_parameters.py @@ -0,0 +1,42 @@ +from typing import Dict + +import optree +from torch import Tensor + +from .__parameter_tree import _ParameterTree +from .__parameter_tree_kind import _ParameterTreeKind + + +def _to_bond_kind_parameters( + parameter: Tensor | _ParameterTree, + kinds: Tensor, +) -> Tensor | _ParameterTree: + assert isinstance(kinds, Tensor) + + assert len(kinds.shape) == 1 + + match parameter: + case Tensor(): + match parameter.shape: + case 0: + return parameter + case 1: + return parameter[kinds] + case _: + raise ValueError + case _ParameterTree(): + if parameter.kind is _ParameterTreeKind.BOND: + + def _fn(_parameter: Dict) -> Tensor: + return _parameter[kinds] + + return optree.tree_map(_fn, parameter.tree) + + if parameter.kind is _ParameterTreeKind.SPACE: + return parameter.tree + + raise ValueError + case float() | int(): + return parameter + case _: + raise NotImplementedError diff --git a/src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_kind_parameters.py b/src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_kind_parameters.py new file mode 100644 index 0000000000..c70e07c54e --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_kind_parameters.py @@ -0,0 +1,85 @@ +import functools + +import optree +import torch +from optree import PyTree +from torch import Tensor + +from .._partition.__is_neighbor_list_sparse import _is_neighbor_list_sparse +from .._partition.__map_bond import _map_bond +from .._partition.__neighbor_list_format import _NeighborListFormat +from .__parameter_tree import _ParameterTree +from .__parameter_tree_kind import _ParameterTreeKind + + +def _to_neighbor_list_kind_parameters( + format: _NeighborListFormat, + indexes: Tensor, + kinds: Tensor, + parameters: _ParameterTree | Tensor | float, +) -> PyTree | _ParameterTree | Tensor | float: + fn = functools.partial( + lambda p, a, b: p[a, b], + parameters, + ) + + match parameters: + case parameters if isinstance(parameters, Tensor): + match parameters.shape: + case 0: + return parameters + case 2: + if _is_neighbor_list_sparse(format): + return _map_bond( + fn, + )( + kinds[indexes[0]], + kinds[indexes[1]], + ) + + return torch.func.vmap( + torch.func.vmap( + fn, + in_dims=(None, 0), + ), + )(kinds, kinds[indexes]) + case _: + raise ValueError + case parameters if isinstance(parameters, _ParameterTree): + match parameters.kind: + case _ParameterTreeKind.KINDS: + if _is_neighbor_list_sparse(format): + return optree.tree_map( + lambda parameter: _map_bond( + functools.partial( + fn, + parameter, + ), + )( + kinds[indexes[0]], + kinds[indexes[1]], + ), + parameters.tree, + ) + + return optree.tree_map( + lambda parameter: torch.func.vmap( + torch.func.vmap( + functools.partial( + fn, + parameter, + ), + (None, 0), + ) + )( + kinds, + kinds[indexes], + ), + parameters.tree, + ) + case _ParameterTreeKind.SPACE: + return parameters.tree + case _: + raise ValueError + + return parameters diff --git a/src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_matrix_parameters.py b/src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_matrix_parameters.py new file mode 100644 index 0000000000..476892ea43 --- /dev/null +++ b/src/beignet/func/_molecular_dynamics/_interact/__to_neighbor_list_matrix_parameters.py @@ -0,0 +1,123 @@ +import functools +from typing import Callable + +import optree +import torch +from optree import PyTree +from torch import Tensor + +from .._partition.__is_neighbor_list_sparse import _is_neighbor_list_sparse +from .._partition.__map_bond import _map_bond +from .._partition.__map_neighbor import _map_neighbor +from .._partition.__neighbor_list_format import _NeighborListFormat +from .__parameter_tree import _ParameterTree +from .__parameter_tree_kind import _ParameterTreeKind + + +def _to_neighbor_list_matrix_parameters( + format: _NeighborListFormat, + indexes: Tensor, + parameters: _ParameterTree | Tensor | float, + combinator: Callable[[Tensor, Tensor], Tensor], +) -> PyTree | _ParameterTree | Tensor | float: + match parameters: + case parameters if isinstance(parameters, Tensor): + match parameters.ndim: + case 0: + return parameters + case 1: + if _is_neighbor_list_sparse(format): + return _map_bond( + combinator, + )( + parameters[indexes[0]], + parameters[indexes[1]], + ) + + return combinator( + parameters[:, None], + parameters[indexes], + ) + case 2: + if _is_neighbor_list_sparse(format): + return _map_bond( + lambda a, b: parameters[a, b], + )( + indexes[0], + indexes[1], + ) + + return torch.func.vmap( + torch.func.vmap( + lambda a, b: parameters[a, b], + (None, 0), + ), + )( + torch.arange(indexes.shape[0], dtype=torch.int32), + indexes, + ) + case _: + raise ValueError + case parameters if isinstance(parameters, _ParameterTree): + match parameters.kind: + case _ParameterTreeKind.BOND: + if _is_neighbor_list_sparse(format): + return optree.tree_map( + lambda parameter: _map_bond( + functools.partial( + lambda p, a, b: p[a, b], + parameter, + ), + )( + indexes[0], + indexes[1], + ), + parameters.tree, + ) + + return optree.tree_map( + lambda parameter: torch.func.vmap( + torch.func.vmap( + functools.partial( + lambda p, a, b: p[a, b], + parameter, + ), + (None, 0), + ), + )( + torch.arange(indexes.shape[0], dtype=torch.int32), + indexes, + ), + parameters.tree, + ) + case _ParameterTreeKind.PARTICLE: + if _is_neighbor_list_sparse(format): + return optree.tree_map( + lambda parameter: _map_bond( + combinator, + )( + parameter[indexes[0]], + parameter[indexes[1]], + ), + parameters.tree, + ) + + return optree.tree_map( + lambda parameter: _map_neighbor( + combinator, + )( + parameter, + parameter[indexes], + ), + parameters.tree, + ) + case _ParameterTreeKind.SPACE: + return parameters.tree + case _: + raise ValueError + case parameters if isinstance(parameters, float): + return parameters + case parameters if isinstance(parameters, int): + return parameters + case _: + raise ValueError