Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TorchScript compatibility for Ensemble #312

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,15 @@ def test_gradients(model_name):
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


def test_ensemble():
@mark.parametrize("check_script", [True, False])
def test_ensemble(check_script):
ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
model = load_model(ckpts[0])
ensemble_model = load_model(ckpts, return_std=True)

if check_script:
ensemble_model = torch.jit.script(load_model(ckpts))
else:
ensemble_model = load_model(ckpts)
z, pos, batch = create_example_batch(n_atoms=5)

pred, deriv = model(z, pos, batch)
Expand All @@ -271,7 +275,11 @@ def test_ensemble():
with zipfile.ZipFile(ensemble_zip, "w") as zipf:
for i, ckpt in enumerate(ckpts):
zipf.write(ckpt, f"model_{i}.ckpt")
ensemble_model = load_model(ensemble_zip, return_std=True)

if check_script:
ensemble_model = torch.jit.script(load_model(ckpts))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo here: this should be `ensemble_zip`, not `ckpts`.

else:
ensemble_model = load_model(ensemble_zip)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
Expand Down
88 changes: 58 additions & 30 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import re
import tempfile
from typing import Optional, List, Tuple, Dict
from typing import Optional, List, Tuple, Dict, Union
import torch
from torch.autograd import grad
from torch import nn, Tensor
Expand Down Expand Up @@ -142,7 +142,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
return model


def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs):
def load_ensemble(filepath, args=None, device="cpu", **kwargs):
"""Load an ensemble of models from a list of checkpoint files or a zip file.

Args:
Expand All @@ -153,7 +153,6 @@ def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs)

args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False.
**kwargs: Extra keyword arguments for the model, will be passed to :py:mod:`load_model`.

Returns:
Expand All @@ -179,11 +178,10 @@ def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs)
)
return Ensemble(
model_list,
return_std=return_std,
)


def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
def load_model(filepath, args=None, device="cpu", **kwargs):
"""Load a model from a checkpoint file.

If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned.
Expand All @@ -196,7 +194,6 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):

args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
**kwargs: Extra keyword arguments for the model.

Returns:
Expand All @@ -205,7 +202,7 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip")
if isEnsemble:
return load_ensemble(
filepath, args=args, device=device, return_std=return_std, **kwargs
filepath, args=args, device=device, **kwargs
)
assert isinstance(filepath, str)
ckpt = torch.load(filepath, map_location="cpu")
Expand Down Expand Up @@ -490,47 +487,78 @@ def forward(
class Ensemble(torch.nn.ModuleList):
"""Average predictions over an ensemble of TorchMD-Net models.

This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with.
This module behaves similarly to a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with.

Args:
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
"""

def __init__(self, modules: List[nn.Module], return_std: bool = False):
def __init__(self, modules: List[nn.Module]):
for module in modules:
assert isinstance(module, TorchMD_Net)
super().__init__(modules)
self.return_std = return_std

def forward(
self,
*args,
**kwargs,
):
"""Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:
"""
Compute the output of the ensemble of models.

The predictions are the average over all models in the ensemble.

This function optionally supports periodic boundary conditions with
arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
certain requirements:

.. code:: python

a[1] = a[2] = b[2] = 0
a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff
a[0] >= 2*b[0]
a[0] >= 2*c[0]
b[1] >= 2*c[1]


These requirements correspond to a particular rotation of the system and
reduced form of the vectors, as well as the requirement that the cutoff be
no larger than half the box width.

Args:
*args: Positional arguments to forward to the models.
**kwargs: Keyword arguments to forward to the models.
Returns:
Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy).
z (Tensor): Atomic numbers of the atoms in the molecule. Shape: (N,).
pos (Tensor): Atomic positions in the molecule. Shape: (N, 3).
batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape: (N,).
box (Tensor, optional): Box vectors. Shape (3, 3).
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.

Returns:
Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The mean output of the models, the mean negative derivatives, the std of the outputs, and the std of the negative derivatives.

"""

y = []
neg_dy = []
for model in self:
res = model(*args, **kwargs)
res = model(z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args)
y.append(res[0])
neg_dy.append(res[1])
y = torch.stack(y)
neg_dy = torch.stack(neg_dy)
y_mean = torch.mean(y, axis=0)
neg_dy_mean = torch.mean(neg_dy, axis=0)
y_std = torch.std(y, axis=0)
neg_dy_std = torch.std(neg_dy, axis=0)

if self.return_std:
return y_mean, neg_dy_mean, y_std, neg_dy_std
else:
return y_mean, neg_dy_mean
y_mean = torch.mean(y, dim=0)
neg_dy_mean = torch.mean(neg_dy, dim=0)
y_std = torch.std(y, dim=0)
neg_dy_std = torch.std(neg_dy, dim=0)

return y_mean, neg_dy_mean, y_std, neg_dy_std

Loading