From 802b70b5d7028e4d2e7bfa92867ec1429f535b92 Mon Sep 17 00:00:00 2001 From: mcneela Date: Fri, 8 Mar 2024 12:23:29 -0500 Subject: [PATCH] update base interaction dataset to add n_atoms_first property --- openqdc/datasets/interaction/base.py | 43 +++++++++++++++++++++++++ openqdc/datasets/interaction/des370k.py | 11 +++++-- openqdc/datasets/interaction/des5m.py | 2 +- openqdc/datasets/interaction/metcalf.py | 1 + 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/openqdc/datasets/interaction/base.py b/openqdc/datasets/interaction/base.py index 71c8e84..27c2f88 100644 --- a/openqdc/datasets/interaction/base.py +++ b/openqdc/datasets/interaction/base.py @@ -1,8 +1,10 @@ from typing import Dict, List, Optional import numpy as np +from sklearn.utils import Bunch from openqdc.datasets.base import BaseDataset +from openqdc.utils.atomization_energies import IsolatedAtomEnergyFactory from openqdc.utils.constants import NB_ATOMIC_FEATURES @@ -45,4 +47,45 @@ def data_shapes(self): "position_idx_range": (-1, 2), "energies": (-1, len(self.__energy_methods__)), "forces": (-1, 3, len(self.force_target_names)), + "n_atoms_first": (-1,), } + + @property + def data_types(self): + return { + "atomic_inputs": np.float32, + "position_idx_range": np.int32, + "energies": np.float32, + "forces": np.float32, + "n_atoms_first": np.int32, + } + + def __getitem__(self, idx: int): + shift = IsolatedAtomEnergyFactory.max_charge + p_start, p_end = self.data["position_idx_range"][idx] + input = self.data["atomic_inputs"][p_start:p_end] + z, c, positions, energies = ( + np.array(input[:, 0], dtype=np.int32), + np.array(input[:, 1], dtype=np.int32), + np.array(input[:, -3:], dtype=np.float32), + np.array(self.data["energies"][idx], dtype=np.float32), + ) + name = self.__smiles_converter__(self.data["name"][idx]) + subset = self.data["subset"][idx] + n_atoms_first = self.data["n_atoms_first"][idx] + + if "forces" in self.data: + forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32) + else: + forces = None + return Bunch( + positions=positions, + atomic_numbers=z, + charges=c, + e0=self.__isolated_atom_energies__[..., z, c + shift].T, + energies=energies, + name=name, + subset=subset, + forces=forces, + n_atoms_first=n_atoms_first, + ) diff --git a/openqdc/datasets/interaction/des370k.py b/openqdc/datasets/interaction/des370k.py index e97710c..382b84c 100644 --- a/openqdc/datasets/interaction/des370k.py +++ b/openqdc/datasets/interaction/des370k.py @@ -7,6 +7,7 @@ from tqdm import tqdm from openqdc.datasets.interaction import BaseInteractionDataset +from openqdc.utils.io import get_local_cache from openqdc.utils.molecule import atom_table, molecule_groups @@ -66,12 +67,16 @@ class DES370K(BaseInteractionDataset): ] _filename = "DES370K.csv" - _short_name = "DES370K" + _name = "des370k_interaction" + + @classmethod + def _root(cls): + return os.path.join(get_local_cache(), cls._name) @classmethod def _read_raw_entries(cls) -> List[Dict]: - filepath = os.path.join(cls.root, cls._filename) - logger.info(f"Reading {cls._short_name} interaction data from {filepath}") + filepath = os.path.join(cls._root(), cls._filename) + logger.info(f"Reading {cls._name} interaction data from {filepath}") df = pd.read_csv(filepath) data = [] for idx, row in tqdm(df.iterrows(), total=df.shape[0]): diff --git a/openqdc/datasets/interaction/des5m.py b/openqdc/datasets/interaction/des5m.py index ea0d929..5b027f4 100644 --- a/openqdc/datasets/interaction/des5m.py +++ b/openqdc/datasets/interaction/des5m.py @@ -50,7 +50,7 @@ class DES5M(DES370K): ] _filename = "DES5M.csv" - _short_name = "DES5M" + _name = "des5m_interaction" def read_raw_entries(self) -> List[Dict]: return DES5M._read_raw_entries() diff --git a/openqdc/datasets/interaction/metcalf.py b/openqdc/datasets/interaction/metcalf.py index 5e1cd73..c9921da 100644 --- a/openqdc/datasets/interaction/metcalf.py +++ b/openqdc/datasets/interaction/metcalf.py @@ -74,6 +74,7 @@ def read_raw_entries(self) -> List[Dict]: positions=xyz, atomic_inputs=atomic_inputs, name=np.array([""]), + n_atoms_first=np.array([-1]), ) data.append(item) return data