diff --git a/pyproject.toml b/pyproject.toml index dd7d3467..e641d8bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "torch_geometric", "lightning", "tqdm", - "pandas", ] [project.urls] diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index 0b92d673..ebb7da0b 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -9,7 +9,6 @@ from torchmdnet.datasets.memdataset import MemmappedDataset from torch_geometric.data import Data from tqdm import tqdm -import pandas as pd class Ace(MemmappedDataset): @@ -133,7 +132,6 @@ def __init__( paths=None, max_gradient=None, subsample_molecules=1, - index_csv=None, ): assert isinstance(paths, (str, list)) @@ -143,13 +141,8 @@ def __init__( self.paths = paths self.max_gradient = max_gradient self.subsample_molecules = int(subsample_molecules) - if index_csv is not None: - df = pd.read_csv(index_csv, dtype=int, converters={"name": str}) - self.mol_indexes = {mol_id: i for i, mol_id in enumerate(df.name)} props = ["y", "neg_dy", "q", "pq", "dp"] - if index_csv is not None: - props += ["mol_idx"] super().__init__( root, transform, @@ -239,7 +232,7 @@ def _load_confs_2_0(mol, n_atoms): def sample_iter(self, mol_ids=False): assert self.subsample_molecules > 0 - for path in tqdm(self.raw_paths, desc="Files"): + for i_path, path in tqdm(enumerate(self.raw_paths), desc="Files"): h5 = h5py.File(path) assert h5.attrs["layout"] == "Ace" version = h5.attrs["layout_version"] @@ -285,10 +278,9 @@ def sample_iter(self, mol_ids=False): z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp ) if mol_ids: + args["i_path"] = i_path args["mol_id"] = mol_id args["i_conf"] = i_conf - if "mol_idx" in self.properties: - args["mol_idx"] = self.mol_indexes[mol_id] data = Data(**args) diff --git a/torchmdnet/datasets/memdataset.py b/torchmdnet/datasets/memdataset.py index c84d6559..6cca8344 100644 --- a/torchmdnet/datasets/memdataset.py +++ b/torchmdnet/datasets/memdataset.py @@ -86,10 +86,6 @@ def __init__( self.mmaps["dp"] = np.memmap( fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3) ) - if "mol_idx" in self.properties: - self.mmaps["mol_idx"] = np.memmap( - fnames["mol_idx"], mode="r", dtype=np.uint64 - ) assert self.mmaps["idx"][0] == 0 assert self.mmaps["idx"][-1] == len(self.mmaps["z"]) @@ -178,13 +174,6 @@ def process(self): dtype=np.float32, shape=(num_all_confs, 3), ) - if "mol_idx" in self.properties: - mmaps["mol_idx"] = np.memmap( - fnames["mol_idx"] + ".tmp", - mode="w+", - dtype=np.uint64, - shape=(num_all_confs,), - ) print("Storing data...") i_atom = 0 @@ -204,8 +193,6 @@ def process(self): mmaps["pq"][i_atom:i_next_atom] = data.pq if "dp" in self.properties: mmaps["dp"][i_conf] = data.dp - if "mol_idx" in self.properties: - mmaps["mol_idx"][i_conf] = data.mol_idx i_atom = i_next_atom mmaps["idx"][-1] = num_all_atoms @@ -231,8 +218,6 @@ def process(self): os.rename(fnames["pq"] + ".tmp", fnames["pq"]) if "dp" in self.properties: os.rename(fnames["dp"] + ".tmp", fnames["dp"]) - if "mol_idx" in self.properties: - os.rename(fnames["mol_idx"] + ".tmp", fnames["mol_idx"]) def len(self): return len(self.mmaps["idx"]) - 1 @@ -249,7 +234,6 @@ def get(self, idx): - :obj:`q`: Total charge of the molecule. - :obj:`pq`: Partial charges of the atoms. - :obj:`dp`: Dipole moment of the molecule. - - :obj:`mol_idx`: The index of the molecule of the conformer. Args: idx (int): Index of the data object. @@ -272,8 +256,6 @@ def get(self, idx): props["pq"] = pt.tensor(self.mmaps["pq"][atoms]) if "dp" in self.properties: props["dp"] = pt.tensor(self.mmaps["dp"][idx]) - # if "mol_idx" in self.properties: - # props["mol_idx"] = pt.tensor(self.mmaps["mol_idx"][idx], dtype=pt.int64).view(1, 1) return Data(z=z, pos=pos, **props) def __del__(self):