Skip to content

Commit

Permalink
update base interaction dataset to add n_atoms_first property
Browse files Browse the repository at this point in the history
  • Loading branch information
mcneela committed Mar 8, 2024
1 parent 1443450 commit 802b70b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
43 changes: 43 additions & 0 deletions openqdc/datasets/interaction/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Dict, List, Optional

import numpy as np
from sklearn.utils import Bunch

from openqdc.datasets.base import BaseDataset
from openqdc.utils.atomization_energies import IsolatedAtomEnergyFactory
from openqdc.utils.constants import NB_ATOMIC_FEATURES


Expand Down Expand Up @@ -45,4 +47,45 @@ def data_shapes(self):
"position_idx_range": (-1, 2),
"energies": (-1, len(self.__energy_methods__)),
"forces": (-1, 3, len(self.force_target_names)),
"n_atoms_first": (-1,),
}

@property
def data_types(self):
    """Return the NumPy dtype associated with each stored data key.

    Returns:
        dict: maps the key names ``atomic_inputs``, ``position_idx_range``,
        ``energies``, ``forces`` and ``n_atoms_first`` to a NumPy scalar type.
    """
    f32, i32 = np.float32, np.int32
    # Pairs are listed in the same order as the original literal so that
    # iteration order over the returned dict is unchanged.
    dtype_by_key = [
        ("atomic_inputs", f32),
        ("position_idx_range", i32),
        ("energies", f32),
        ("forces", f32),
        ("n_atoms_first", i32),
    ]
    return dict(dtype_by_key)

def __getitem__(self, idx: int):
    """Return the entry at position ``idx`` as a ``Bunch``.

    The per-atom rows for the entry are sliced out of
    ``self.data["atomic_inputs"]`` using the ``(start, end)`` pair stored in
    ``self.data["position_idx_range"][idx]``. Column 0 of those rows is
    returned as ``atomic_numbers``, column 1 as ``charges`` and the last
    three columns as ``positions``.

    Args:
        idx: index of the entry to fetch.

    Returns:
        Bunch with the fields ``positions``, ``atomic_numbers``, ``charges``,
        ``e0``, ``energies``, ``name``, ``subset``, ``forces`` (``None`` when
        ``self.data`` has no ``"forces"`` key) and ``n_atoms_first``.
    """
    shift = IsolatedAtomEnergyFactory.max_charge
    p_start, p_end = self.data["position_idx_range"][idx]
    # Named ``atomic_inputs`` rather than ``input`` so the builtin ``input``
    # is not shadowed inside this method.
    atomic_inputs = self.data["atomic_inputs"][p_start:p_end]

    z = np.array(atomic_inputs[:, 0], dtype=np.int32)
    c = np.array(atomic_inputs[:, 1], dtype=np.int32)
    positions = np.array(atomic_inputs[:, -3:], dtype=np.float32)
    energies = np.array(self.data["energies"][idx], dtype=np.float32)

    name = self.__smiles_converter__(self.data["name"][idx])
    subset = self.data["subset"][idx]
    n_atoms_first = self.data["n_atoms_first"][idx]

    # Forces are stored per atom, so they use the same slice as the inputs.
    if "forces" in self.data:
        forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32)
    else:
        forces = None

    return Bunch(
        positions=positions,
        atomic_numbers=z,
        charges=c,
        # Charges are offset by ``shift`` before being used as an index.
        e0=self.__isolated_atom_energies__[..., z, c + shift].T,
        energies=energies,
        name=name,
        subset=subset,
        forces=forces,
        n_atoms_first=n_atoms_first,
    )
11 changes: 8 additions & 3 deletions openqdc/datasets/interaction/des370k.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tqdm import tqdm

from openqdc.datasets.interaction import BaseInteractionDataset
from openqdc.utils.io import get_local_cache
from openqdc.utils.molecule import atom_table, molecule_groups


Expand Down Expand Up @@ -66,12 +67,16 @@ class DES370K(BaseInteractionDataset):
]

_filename = "DES370K.csv"
_short_name = "DES370K"
_name = "des370k_interaction"

@classmethod
def _root(cls):
    """Return the local cache directory joined with ``cls._name``."""
    cache_dir = get_local_cache()
    return os.path.join(cache_dir, cls._name)

@classmethod
def _read_raw_entries(cls) -> List[Dict]:
filepath = os.path.join(cls.root, cls._filename)
logger.info(f"Reading {cls._short_name} interaction data from {filepath}")
filepath = os.path.join(cls._root(), cls._filename)
logger.info(f"Reading {cls._name} interaction data from {filepath}")
df = pd.read_csv(filepath)
data = []
for idx, row in tqdm(df.iterrows(), total=df.shape[0]):
Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/interaction/des5m.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DES5M(DES370K):
]

_filename = "DES5M.csv"
_short_name = "DES5M"
_name = "des5m_interaction"

def read_raw_entries(self) -> List[Dict]:
    """Return the raw entries produced by ``DES5M._read_raw_entries``."""
    entries = DES5M._read_raw_entries()
    return entries
1 change: 1 addition & 0 deletions openqdc/datasets/interaction/metcalf.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def read_raw_entries(self) -> List[Dict]:
positions=xyz,
atomic_inputs=atomic_inputs,
name=np.array([""]),
n_atoms_first=np.array([-1]),
)
data.append(item)
return data

0 comments on commit 802b70b

Please sign in to comment.