Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Statistics abstraction pattern #74

Merged
merged 33 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixes
  • Loading branch information
FNTwin committed Mar 25, 2024
commit 3b4823ec5dda5f4b8d638c22deac1fd6b3d2bfc4
2 changes: 1 addition & 1 deletion openqdc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def datasets():
table = PrettyTable(["Name", "Type of Energy", "Forces", "Level of theory"])
for dataset in AVAILABLE_DATASETS:
empty_dataset = AVAILABLE_DATASETS[dataset].no_init()
has_forces = False if not empty_dataset.__force_methods__ else True
has_forces = False if not empty_dataset.force_mask else True
en_type = "Potential" if dataset in AVAILABLE_POTENTIAL_DATASETS else "Interaction"
table.add_row(
[
Expand Down
41 changes: 32 additions & 9 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import pickle as pkl
from copy import deepcopy
from itertools import compress
from os.path import join as p_join
from typing import Dict, List, Optional, Union

Expand Down Expand Up @@ -32,15 +33,16 @@
from openqdc.utils.regressor import Regressor # noqa
from openqdc.utils.units import get_conversion


class BaseDataset(DatasetPropertyMixIn):
"""
Base class for datasets in the openQDC package.
"""

__energy_methods__ = []
__force_methods__ = []
energy_target_names = []
force_target_names = []
__energy_methods__ = []
__force_mask__ = []
__isolated_atom_energies__ = []

__energy_unit__ = "hartree"
Expand Down Expand Up @@ -129,6 +131,27 @@ def _convert_data(self):
for key in self.data_keys:
self.data[key] = self._convert_on_loading(self.data[key], key)

@property
def __force_methods__(self):
"""
For backward compatibility. To be removed in the future.
"""
return self.force_methods

@property
def energy_methods(self):
return self.__energy_methods__

@property
def force_methods(self):
return list(compress(self.energy_methods, self.force_mask))

@property
def force_mask(self):
if len(self.__class__.__force_mask__) == 0:
self.__class__.__force_mask__ = [False] * len(self.energy_methods)
return self.__class__.__force_mask__

@property
def energy_unit(self):
return self.__energy_unit__
Expand Down Expand Up @@ -196,12 +219,11 @@ def _set_units(self, en, ds):
self.__class__.__fn_forces__ = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)

def _set_isolated_atom_energies(self):
if self.__energy_methods__ is None:
if self.energy_methods is None:
logger.error("No energy methods defined for this dataset.")
f = get_conversion("hartree", self.__energy_unit__)

self.__isolated_atom_energies__ = f(
np.array([IsolatedAtomEnergyFactory.get_matrix(en_method) for en_method in self.__energy_methods__])
np.array([IsolatedAtomEnergyFactory.get_matrix(en_method) for en_method in self.energy_methods])
)

def convert_energy(self, x):
Expand Down Expand Up @@ -280,13 +302,13 @@ def read_preprocess(self, overwrite_local_cache=False):
f"Dataset {self.__name__} with the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.__force_methods__ else 'None'}"
Forces: {self.force_unit if self.force_methods else 'None'}"
)
self.data = {}
for key in self.data_keys:
filename = p_join(self.preprocess_path, f"{key}.mmap")
pull_locally(filename, overwrite=overwrite_local_cache)
self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(self.data_shapes[key])
self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(*self.data_shapes[key])

filename = p_join(self.preprocess_path, "props.pkl")
pull_locally(filename, overwrite=overwrite_local_cache)
Expand Down Expand Up @@ -423,7 +445,8 @@ def get_statistics(self, normalization: str = "formation", return_none: bool = T
"""
Get the statistics of the dataset.
normalization : str, optional
Type of energy, by default "formation", must be one of ["formation", "total", "inter"]
Type of energy, by default "formation", must be one of ["formation", "total",
"residual_regression", "per_atom_formation", "per_atom_residual_regression"]
return_none : bool, optional
Whether to return None if the statistics for the forces are not available, by default True
Otherwise, the statistics for the forces are set to 0.0
Expand All @@ -434,7 +457,7 @@ def get_statistics(self, normalization: str = "formation", return_none: bool = T
if normalization not in POSSIBLE_NORMALIZATION:
raise NormalizationNotAvailableError(normalization)
selected_stats = stats[normalization]
if len(self.__force_methods__) == 0 and not return_none:
if len(self.force_methods) == 0 and not return_none:
selected_stats.update(
{
"forces": {
Expand Down
1 change: 1 addition & 0 deletions openqdc/datasets/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ANI1(BaseDataset):
energy_target_names = [
"ωB97x:6-31G(d) Energy",
]

__energy_unit__ = "hartree"
__distance_unit__ = "bohr"
__forces_unit__ = "hartree/bohr"
Expand Down
3 changes: 2 additions & 1 deletion openqdc/datasets/potential/comp6.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from os.path import join as p_join

from openqdc.datasets.base import BaseDataset, read_qc_archive_h5
from openqdc.datasets.base import BaseDataset
from openqdc.utils import read_qc_archive_h5


class COMP6(BaseDataset):
Expand Down
Loading