Skip to content

Commit

Permalink
remove the mol_idx property as it was unused and also added extra requirement in pandas (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefdoerr authored Feb 27, 2025
1 parent f960354 commit fdd86d6
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 29 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ dependencies = [
"torch_geometric",
"lightning",
"tqdm",
"pandas",
]

[project.urls]
Expand Down
12 changes: 2 additions & 10 deletions torchmdnet/datasets/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torchmdnet.datasets.memdataset import MemmappedDataset
from torch_geometric.data import Data
from tqdm import tqdm
import pandas as pd


class Ace(MemmappedDataset):
Expand Down Expand Up @@ -133,7 +132,6 @@ def __init__(
paths=None,
max_gradient=None,
subsample_molecules=1,
index_csv=None,
):
assert isinstance(paths, (str, list))

Expand All @@ -143,13 +141,8 @@ def __init__(
self.paths = paths
self.max_gradient = max_gradient
self.subsample_molecules = int(subsample_molecules)
if index_csv is not None:
df = pd.read_csv(index_csv, dtype=int, converters={"name": str})
self.mol_indexes = {mol_id: i for i, mol_id in enumerate(df.name)}

props = ["y", "neg_dy", "q", "pq", "dp"]
if index_csv is not None:
props += ["mol_idx"]
super().__init__(
root,
transform,
Expand Down Expand Up @@ -239,7 +232,7 @@ def _load_confs_2_0(mol, n_atoms):
def sample_iter(self, mol_ids=False):
assert self.subsample_molecules > 0

for path in tqdm(self.raw_paths, desc="Files"):
for i_path, path in tqdm(enumerate(self.raw_paths), desc="Files"):
h5 = h5py.File(path)
assert h5.attrs["layout"] == "Ace"
version = h5.attrs["layout_version"]
Expand Down Expand Up @@ -285,10 +278,9 @@ def sample_iter(self, mol_ids=False):
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
)
if mol_ids:
args["i_path"] = i_path
args["mol_id"] = mol_id
args["i_conf"] = i_conf
if "mol_idx" in self.properties:
args["mol_idx"] = self.mol_indexes[mol_id]

data = Data(**args)

Expand Down
18 changes: 0 additions & 18 deletions torchmdnet/datasets/memdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ def __init__(
self.mmaps["dp"] = np.memmap(
fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3)
)
if "mol_idx" in self.properties:
self.mmaps["mol_idx"] = np.memmap(
fnames["mol_idx"], mode="r", dtype=np.uint64
)

assert self.mmaps["idx"][0] == 0
assert self.mmaps["idx"][-1] == len(self.mmaps["z"])
Expand Down Expand Up @@ -178,13 +174,6 @@ def process(self):
dtype=np.float32,
shape=(num_all_confs, 3),
)
if "mol_idx" in self.properties:
mmaps["mol_idx"] = np.memmap(
fnames["mol_idx"] + ".tmp",
mode="w+",
dtype=np.uint64,
shape=(num_all_confs,),
)

print("Storing data...")
i_atom = 0
Expand All @@ -204,8 +193,6 @@ def process(self):
mmaps["pq"][i_atom:i_next_atom] = data.pq
if "dp" in self.properties:
mmaps["dp"][i_conf] = data.dp
if "mol_idx" in self.properties:
mmaps["mol_idx"][i_conf] = data.mol_idx
i_atom = i_next_atom

mmaps["idx"][-1] = num_all_atoms
Expand All @@ -231,8 +218,6 @@ def process(self):
os.rename(fnames["pq"] + ".tmp", fnames["pq"])
if "dp" in self.properties:
os.rename(fnames["dp"] + ".tmp", fnames["dp"])
if "mol_idx" in self.properties:
os.rename(fnames["mol_idx"] + ".tmp", fnames["mol_idx"])

def len(self):
return len(self.mmaps["idx"]) - 1
Expand All @@ -249,7 +234,6 @@ def get(self, idx):
- :obj:`q`: Total charge of the molecule.
- :obj:`pq`: Partial charges of the atoms.
- :obj:`dp`: Dipole moment of the molecule.
- :obj:`mol_idx`: The index of the molecule of the conformer.
Args:
idx (int): Index of the data object.
Expand All @@ -272,8 +256,6 @@ def get(self, idx):
props["pq"] = pt.tensor(self.mmaps["pq"][atoms])
if "dp" in self.properties:
props["dp"] = pt.tensor(self.mmaps["dp"][idx])
# if "mol_idx" in self.properties:
# props["mol_idx"] = pt.tensor(self.mmaps["mol_idx"][idx], dtype=pt.int64).view(1, 1)
return Data(z=z, pos=pos, **props)

def __del__(self):
Expand Down

0 comments on commit fdd86d6

Please sign in to comment.