Chore/47 plotting functions #59

Closed
wants to merge 22 commits into from
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,8 +1,8 @@
mlruns/
scratch/
dataset/
data/
plots/
data/

docs/source

84 changes: 80 additions & 4 deletions nmrcraft/analysis/plotting.py
@@ -1,3 +1,4 @@
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
@@ -15,11 +16,13 @@ def style_setup():
plt.rcParams["text.latex.preamble"] = r"\usepackage{sansmathfonts}"
plt.rcParams["axes.prop_cycle"] = cycler(color=colors)

# Use the first color from the custom color cycle
first_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
all_colors = [
plt.rcParams["axes.prop_cycle"].by_key()["color"][i]
for i in range(len(colors))
]
plt.rcParams["text.usetex"] = False

return cmap, colors, first_color
return cmap, colors, all_colors


def plot_predicted_vs_ground_truth(
@@ -33,7 +36,8 @@ def plot_predicted_vs_ground_truth(
Returns:
None
"""
_, _, first_color = style_setup()
_, _, colors = style_setup()
first_color = colors[0]
# Creating the plot
plt.figure(figsize=(10, 8))
plt.scatter(y_test, y_pred, color=first_color, edgecolor="k", alpha=0.6)
@@ -167,3 +171,75 @@ def plot_roc_curve(fpr, tpr, roc_auc, title, path):
plt.legend(loc="lower right")
plt.savefig(path)
plt.close()


def plot_with_without_ligands_bar(df):
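"""Plot grouped accuracy bars comparing runs with and without ligand information.

Expects a results DataFrame with the columns used below: 'target',
'accuracy_mean', 'accuracy_lower_bd', 'accuracy_upper_bd' and the boolean
'nmr_tensor_input_only'.
"""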
categories = df["target"].unique()
_, _, colors = style_setup()
first_color = colors[0]
second_color = colors[1]

# Extract data

x_pos = np.arange(len(categories))
bar_width = 0.35

# Initialize plot
fig, ax = plt.subplots()

# Loop through each category and plot bars
for i, category in enumerate(categories):
subset = df[df["target"] == category]

# Means and error bars
means = subset["accuracy_mean"].values
errors = [
subset["accuracy_mean"].values
- subset["accuracy_lower_bd"].values,
subset["accuracy_upper_bd"].values
- subset["accuracy_mean"].values,
]

# Bar locations for the group
bar_positions = x_pos[i] + np.array([-bar_width / 2, bar_width / 2])

# Determine bar colors based on 'nmr_tensor_input_only' field
bar_colors = [
first_color if x else second_color
for x in subset["nmr_tensor_input_only"]
]

# Plotting the bars
ax.bar(
bar_positions,
means,
yerr=np.array(errors),
color=bar_colors,
align="center",
ecolor="black",
capsize=5,
width=bar_width,
)

# Labeling and aesthetics
ax.set_ylabel("Accuracy / %")
ax.set_xlabel("Target(s)")
ax.set_xticks(x_pos)
ax.set_xticklabels(categories)
ax.set_title("Accuracy Measurements with Error Bars")

handles = [
mpatches.Patch(color=first_color, label="With Ligand Info"),
mpatches.Patch(color=second_color, label="Without Ligand Info"),
]
ax.legend(handles=handles, loc="best", fontsize=20)
plt.tight_layout()
plt.savefig("plots/exp3_incorporate_ligand_info.png")
print("Saved to plots/exp3_incorporate_ligand_info.png")


if __name__ == "__main__":
import pandas as pd

df = pd.read_csv("dataset/path_to_results.csv")
plot_with_without_ligands_bar(df)
76 changes: 76 additions & 0 deletions nmrcraft/data/data_utils.py
@@ -0,0 +1,76 @@
"""Load and preprocess data."""

import os

import pandas as pd
from datasets import load_dataset


class DatasetLoadError(FileNotFoundError):
"""Exeption raised when the Dataloader could not find data/dataset.csv,
even after trying to generate it from huggingface"""

def __init__(self, t):
super().__init__(f"Could not load raw Dataset '{t}'")


class InvalidTargetError(ValueError):
"""Exception raised when the specified model name is not found."""

def __init__(self, t):
super().__init__(f"Invalid target '{t}'")


def filename_to_ligands(dataset: pd.DataFrame):
"""
Extract ligands from the filename and add as columns to the dataset.
Assumes that filenames are structured in a specific way that can be parsed into ligands.
"""
filename_parts = dataset["file_name"].str.split("_", expand=True)
dataset["metal"] = filename_parts.get(0)
dataset["geometry"] = filename_parts.get(1)
dataset["E_ligand"] = filename_parts.get(2)
dataset["X1_ligand"] = filename_parts.get(3)
dataset["X2_ligand"] = filename_parts.get(4)
dataset["X3_ligand"] = filename_parts.get(5)
dataset["X4_ligand"] = filename_parts.get(6)
dataset["L_ligand"] = filename_parts.get(7).fillna(
"none"
) # Fill missing L_ligand with 'none'
return dataset


def load_dummy_dataset_locally(dataset_path: str = "tests/data.csv"):
"""Load the small local dummy dataset (used in tests) from a CSV file."""
dataset = pd.read_csv(dataset_path)
return dataset


def load_dataset_from_hf(
dataset_name: str = "NMRcraft/nmrcraft", data_files: str = "all_no_nan.csv"
):
"""Load the dataset.

This function loads the dataset using the specified dataset name and data files.
It assumes that you have logged into the Hugging Face CLI prior to calling this function.

Args:
dataset_name (str, optional): The name of the dataset. Defaults to "NMRcraft/nmrcraft".
data_files (str, optional): The name of the data file. Defaults to 'all_no_nan.csv'.

Returns:
pandas.DataFrame: The loaded dataset as a pandas DataFrame.
"""
# Create the dataset directory if needed
if not os.path.isdir("dataset"):
os.mkdir("dataset")
# Check if hf dataset is already downloaded, else download it and then load it
if not os.path.isfile("dataset/dataset.csv"):
dataset = load_dataset(dataset_name, data_files=data_files)[
"train"
].to_pandas()
dataset.to_csv("dataset/dataset.csv")
if os.path.isfile("dataset/dataset.csv"):
dataset = pd.read_csv("dataset/dataset.csv")
else:
raise DatasetLoadError("dataset/dataset.csv")
return dataset
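
A minimal usage sketch (not part of the diff), assuming a prior Hugging Face CLI login as noted in the docstring and the filename convention parsed by filename_to_ligands:

# Illustrative only: load the raw dataset and derive the ligand columns.
from nmrcraft.data.data_utils import filename_to_ligands, load_dataset_from_hf

dataset = load_dataset_from_hf()  # downloads to dataset/dataset.csv on first call
dataset = filename_to_ligands(dataset)  # adds metal, geometry and ligand columns
print(dataset[["metal", "geometry", "E_ligand", "L_ligand"]].head())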