Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work in progress: Multiclass possible now #62

Merged
merged 25 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
cb69197
Work in progress: Multiclass possible now
May 23, 2024
99353f2
Merge branch 'main' of https://github.com/mlederbauer/NMRcraft into f…
May 23, 2024
d86a2b8
fix: standard scaling
May 23, 2024
3a17d15
feat: simplified data loader
May 23, 2024
0eb6bad
chore: refactor name to dataloader
May 23, 2024
c486683
feat: add more columns to results df
May 24, 2024
ae94e2f
move default args of dataloader to classifier
May 24, 2024
b6508db
test multiple targets
May 24, 2024
7246f14
feat: removed data folder from gitignore
May 25, 2024
1c83a3b
fix: change absolute path of data file
May 26, 2024
4aec0dd
feat: barebone baslines script
May 26, 2024
3bdb20a
fix: testing for now lol
May 26, 2024
4af718e
feat: functional multiclass models
kbiniek May 26, 2024
c5ee67a
resolve conflicts
kbiniek May 26, 2024
49c820f
Merge pull request #68 from mlederbauer/chore/minimal-dataloader
kbiniek May 26, 2024
b32dcb5
feat: working baseline
May 26, 2024
168e266
Merge pull request #69 from mlederbauer/feat/baselines
kbiniek May 26, 2024
0be3994
feat: new evaluation
May 26, 2024
ebe51ff
fix: targets instead of y_labels
May 26, 2024
1d0a7a4
feat: add default parameters to DataLoader
kbiniek May 26, 2024
01de0b0
Merge pull request #70 from mlederbauer/feat/unify-evaluation
kbiniek May 26, 2024
8598603
feat: fix confusion matrix plot and bootstrap
kbiniek May 26, 2024
9fddfe5
feat: add multioutput models
kbiniek May 27, 2024
2b77d23
Feat: Added statistics for the bootstrapped Metrics
May 27, 2024
8ac7432
Chore/47 plotting functions (#71)
mlederbauer May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mlruns/
scratch/
dataset/
data/
plots/
data/

docs/source

Expand Down
133 changes: 97 additions & 36 deletions nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""Functions to plot."""

import matplotlib.patches as mpatches
import os
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
Expand All @@ -15,11 +19,13 @@ def style_setup():
plt.rcParams["text.latex.preamble"] = r"\usepackage{sansmathfonts}"
plt.rcParams["axes.prop_cycle"] = cycler(color=colors)

# Use the first color from the custom color cycle
first_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
all_colors = [
plt.rcParams["axes.prop_cycle"].by_key()["color"][i]
for i in range(len(colors))
]
plt.rcParams["text.usetex"] = False

return cmap, colors, first_color
return cmap, colors, all_colors


def plot_predicted_vs_ground_truth(
Expand All @@ -33,7 +39,8 @@ def plot_predicted_vs_ground_truth(
Returns:
None
"""
_, _, first_color = style_setup()
_, _, colors = style_setup()
first_color = colors[0]
# Creating the plot
plt.figure(figsize=(10, 8))
plt.scatter(y_test, y_pred, color=first_color, edgecolor="k", alpha=0.6)
Expand Down Expand Up @@ -85,7 +92,7 @@ def plot_predicted_vs_ground_truth_density(


def plot_confusion_matrix(
cm, classes, title, path, full=True, columns_set=False
cm_list, y_labels, model_name, dataset_size, folder_path: str = "plots/"
):
"""
Plots the confusion matrix.
Expand All @@ -98,45 +105,27 @@ def plot_confusion_matrix(
Returns:
None
"""
_, _, _ = style_setup()
if full: # Plot one big cm
if not os.path.exists(folder_path):
os.makedirs(folder_path)
# _, _, _ = style_setup()
for target in y_labels:
file_path = os.path.join(
folder_path,
f"ConfusionMatrix_{model_name}_{dataset_size}_{target}.png",
)
cm = cm_list[target]
classes = y_labels[target]
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
plt.title(title)
plt.title(f"{target} Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.xticks(tick_marks, classes, rotation=90)
plt.yticks(tick_marks, classes)
plt.tight_layout()
plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.savefig(path)
plt.close()

elif not full: # Plot many small cms of each target
cms = []
for columns in columns_set: # Make list of confusion matrices
cms.append(
cm[
slice(columns[0], columns[-1] + 1),
slice(columns[0], columns[-1] + 1),
]
)
fig, axs = plt.subplots(nrows=len(cms), figsize=(10, 8 * len(cms)))
for i, sub_cm in enumerate(cms):
sub_classes = classes[
slice(columns_set[i][0], columns_set[i][-1] + 1)
]
axs[i].imshow(sub_cm, interpolation="nearest", cmap=plt.cm.Blues)
axs[i].set_title(f"Confusion Matrix {i+1}")
tick_marks = np.arange(len(sub_classes))
axs[i].set_xticks(tick_marks)
axs[i].set_xticklabels(sub_classes, rotation=45)
axs[i].set_yticks(tick_marks)
axs[i].set_yticklabels(sub_classes)
plt.tight_layout()
print(cm)
plt.savefig(path)
plt.savefig(file_path)
plt.close()


Expand Down Expand Up @@ -167,3 +156,75 @@ def plot_roc_curve(fpr, tpr, roc_auc, title, path):
plt.legend(loc="lower right")
plt.savefig(path)
plt.close()


def plot_with_without_ligands_bar(df):
categories = df["target"].unique()
_, _, colors = style_setup()
first_color = colors[0]
second_color = colors[1]

# Extract data

x_pos = np.arange(len(categories))
bar_width = 0.35

# Initialize plot
fig, ax = plt.subplots()

# Loop through each category and plot bars
for i, category in enumerate(categories):
subset = df[df["target"] == category]

# Means and error bars
means = subset["accuracy_mean"].values
errors = [
subset["accuracy_mean"].values
- subset["accuracy_lower_bd"].values,
subset["accuracy_upper_bd"].values
- subset["accuracy_mean"].values,
]

# Bar locations for the group
bar_positions = x_pos[i] + np.array([-bar_width / 2, bar_width / 2])

# Determine bar colors based on 'nmr_tensor_input_only' field
bar_colors = [
first_color if x else second_color
for x in subset["nmr_tensor_input_only"]
]

# Plotting the bars
ax.bar(
bar_positions,
means,
yerr=np.array(errors),
color=bar_colors,
align="center",
ecolor="black",
capsize=5,
width=bar_width,
)

# Labeling and aesthetics
ax.set_ylabel("Accuracy / %")
ax.set_xlabel("Target(s)")
ax.set_xticks(x_pos)
ax.set_xticklabels(categories)
ax.set_title("Accuracy Measurements with Error Bars")

handles = [
mpatches.Patch(color=first_color, label="With Ligand Info"),
mpatches.Patch(color=second_color, label="Without Ligand Info"),
]
ax.legend(handles=handles, loc="best", fontsize=20)
plt.tight_layout()
plt.savefig("plots/exp3_incorporate_ligand_info.png")
print("Saved to plots/exp3_incorporate_ligand_info.png")


if __name__ == "main":
import pandas as pd

df = pd.read_csv("dataset/path_to_results.csv")
plot_with_without_ligands_bar(df)
76 changes: 76 additions & 0 deletions nmrcraft/data/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Load and preprocess data."""

import os

import pandas as pd
from datasets import load_dataset


class DatasetLoadError(FileNotFoundError):
"""Exeption raised when the Dataloader could not find data/dataset.csv,
even after trying to generate it from huggingface"""

def __init__(self, t):
super().__init__(f"Could not load raw Dataset '{t}'")


class InvalidTargetError(ValueError):
"""Exception raised when the specified model name is not found."""

def __init__(self, t):
super().__init__(f"Invalid target '{t}'")


def filename_to_ligands(dataset: pd.DataFrame):
"""
Extract ligands from the filename and add as columns to the dataset.
Assumes that filenames are structured in a specific way that can be parsed into ligands.
"""
filename_parts = dataset["file_name"].str.split("_", expand=True)
dataset["metal"] = filename_parts.get(0)
dataset["geometry"] = filename_parts.get(1)
dataset["E_ligand"] = filename_parts.get(2)
dataset["X1_ligand"] = filename_parts.get(3)
dataset["X2_ligand"] = filename_parts.get(4)
dataset["X3_ligand"] = filename_parts.get(5)
dataset["X4_ligand"] = filename_parts.get(6)
dataset["L_ligand"] = filename_parts.get(7).fillna(
"none"
) # Fill missing L_ligand with 'none'
return dataset


def load_dummy_dataset_locally(datset_path: str = "tests/data.csv"):
dataset = pd.read_csv(datset_path)
return dataset


def load_dataset_from_hf(
dataset_name: str = "NMRcraft/nmrcraft", data_files: str = "all_no_nan.csv"
):
"""Load the dataset.

This function loads the dataset using the specified dataset name and data files.
It assumes that you have logged into the Hugging Face CLI prior to calling this function.

Args:
dataset_name (str, optional): The name of the dataset. Defaults to "NMRcraft/nmrcraft".
data_files (str, optional): The name of the data file. Defaults to 'all_no_nan.csv'.

Returns:
pandas.DataFrame: The loaded dataset as a pandas DataFrame.
"""
# Create data dir if needed
if not os.path.isdir("dataset"):
os.mkdir("dataset")
# Check if hf dataset is already downloaded, else download it and then load it
if not os.path.isfile("dataset/dataset.csv"):
dataset = load_dataset(dataset_name, data_files=data_files)[
"train"
].to_pandas()
dataset.to_csv("dataset/dataset.csv")
if os.path.isfile("dataset/dataset.csv"):
dataset = pd.read_csv("dataset/dataset.csv")
elif not os.path.isfile("dataset/dataset.csv"):
raise DatasetLoadError(FileNotFoundError)
return dataset
Loading
Loading