Skip to content

Commit

Permalink
Feature: dataloader choose complex geometry (#55)
Browse files Browse the repository at this point in the history
* Feature: dataloader choose complex geometry

* Fix: removed tupelization

---------

Co-authored-by: Tiago Würthner <[email protected]>
  • Loading branch information
tiaguinho-code and Tiago Würthner authored May 18, 2024
1 parent 8eba2a7 commit 479ceeb
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions nmrcraft/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def target_label_readabilitizer(readable_labels):
"""
# Trun that class_ into list
human_readable_label_list = list(itertools.chain(*readable_labels))
# Handle Binarized metal stuff and make the two columns become a single one
# Handle Binarized metal stuff and make the two columns become a single one because the metals get turned into a single column by the binarizer
for i in enumerate(human_readable_label_list):
if (
human_readable_label_list[i[0]] == "Mo"
Expand Down Expand Up @@ -200,7 +200,8 @@ def __init__(
data_files="all_no_nan.csv",
feature_columns=None,
target_columns="metal",
target_type="one-hot", # can be "categorical" or "one-hot"
target_type="one-hot", # can be "categorical" or "one-hot",
complex_geometry="all",
test_size=0.3,
random_state=42,
dataset_size=0.01,
Expand All @@ -212,6 +213,8 @@ def __init__(
self.random_state = random_state
self.dataset_size = dataset_size
self.target_type = target_type
self.complex_geometry = complex_geometry

if not testing:
self.dataset = load_dataset_from_hf()
elif testing:
Expand All @@ -220,6 +223,18 @@ def __init__(
def load_data(self):
self.dataset = filename_to_ligands(self.dataset)
self.dataset = self.dataset.sample(frac=self.dataset_size)
if self.complex_geometry == "oct":
self.dataset = self.dataset[
self.dataset["geometry"] == "oct"
] # only load octahedral complexes
elif self.complex_geometry == "spy":
self.dataset = self.dataset[
self.dataset["geometry"] == "spy"
] # only load square pyramidal complexes
elif self.complex_geometry == "tbp":
self.dataset = self.dataset[
self.dataset["geometry"] == "tbp"
] # only load trigonal bipyramidal complexes
if self.target_type == "categorical":
return self.split_and_preprocess_categorical()
elif (
Expand Down

0 comments on commit 479ceeb

Please sign in to comment.