Skip to content

Commit

Permalink
refactored des370k and des5m
Browse files Browse the repository at this point in the history
  • Loading branch information
mcneela committed Mar 8, 2024
1 parent 07f70b8 commit 1443450
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 55 deletions.
17 changes: 12 additions & 5 deletions openqdc/datasets/interaction/des370k.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,14 @@ class DES370K(BaseInteractionDataset):
"sapt_delta_HF",
]

def read_raw_entries(self) -> List[Dict]:
self.filepath = os.path.join(self.root, "DES370K.csv")
logger.info(f"Reading DES370K interaction data from {self.filepath}")
df = pd.read_csv(self.filepath)
_filename = "DES370K.csv"
_short_name = "DES370K"

@classmethod
def _read_raw_entries(cls) -> List[Dict]:
filepath = os.path.join(cls.root, cls._filename)
logger.info(f"Reading {cls._short_name} interaction data from {filepath}")
df = pd.read_csv(filepath)
data = []
for idx, row in tqdm(df.iterrows(), total=df.shape[0]):
smiles0, smiles1 = row["smiles0"], row["smiles1"]
Expand All @@ -84,7 +88,7 @@ def read_raw_entries(self) -> List[Dict]:

atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32)

energies = np.array(row[self.energy_target_names].values).astype(np.float32)[None, :]
energies = np.array(row[cls.energy_target_names].values).astype(np.float32)[None, :]

name = np.array([smiles0 + "." + smiles1])

Expand All @@ -108,3 +112,6 @@ def read_raw_entries(self) -> List[Dict]:
)
data.append(item)
return data

def read_raw_entries(self) -> List[Dict]:
return DES370K._read_raw_entries()
54 changes: 4 additions & 50 deletions openqdc/datasets/interaction/des5m.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import os
from typing import Dict, List

import numpy as np
import pandas as pd
from loguru import logger
from tqdm import tqdm

from openqdc.datasets.interaction import DES370K
from openqdc.utils.molecule import atom_table, molecule_groups


class DES5M(DES370K):
Expand Down Expand Up @@ -56,47 +49,8 @@ class DES5M(DES370K):
"sapt_delta_HF",
]

def read_raw_entries(self) -> List[Dict]:
self.filepath = os.path.join(self.root, "DES5M.csv")
logger.info(f"Reading DES5M interaction data from {self.filepath}")
df = pd.read_csv(self.filepath)
data = []
for idx, row in tqdm(df.iterrows(), total=df.shape[0]):
smiles0, smiles1 = row["smiles0"], row["smiles1"]
charge0, charge1 = row["charge0"], row["charge1"]
natoms0, natoms1 = row["natoms0"], row["natoms1"]
pos = np.array(list(map(float, row["xyz"].split()))).reshape(-1, 3)

elements = row["elements"].split()

atomic_nums = np.expand_dims(np.array([atom_table.GetAtomicNumber(x) for x in elements]), axis=1)

charges = np.expand_dims(np.array([charge0] * natoms0 + [charge1] * natoms1), axis=1)
_filename = "DES5M.csv"
_short_name = "DES5M"

atomic_inputs = np.concatenate((atomic_nums, charges, pos), axis=-1, dtype=np.float32)

energies = np.array(row[self.energy_target_names].values).astype(np.float32)[None, :]

name = np.array([smiles0 + "." + smiles1])

subsets = []
# for smiles in [canon_smiles0, canon_smiles1]:
for smiles in [smiles0, smiles1]:
found = False
for functional_group, smiles_set in molecule_groups.items():
if smiles in smiles_set:
subsets.append(functional_group)
found = True
if not found:
logger.info(f"molecule group lookup failed for {smiles}")

item = dict(
energies=energies,
subset=np.array([subsets]),
n_atoms=np.array([natoms0 + natoms1], dtype=np.int32),
n_atoms_first=np.array([natoms0], dtype=np.int32),
atomic_inputs=atomic_inputs,
name=name,
)
data.append(item)
return data
def read_raw_entries(self) -> List[Dict]:
return DES5M._read_raw_entries()

0 comments on commit 1443450

Please sign in to comment.