Skip to content

Commit

Permalink
Merge branch 'main' into fix/multilabel-confusion-matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Tiago Würthner committed May 14, 2024
2 parents 478fe40 + 8176e2e commit 215be3f
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mlruns/
scratch/
dataset/
data/
plots/

docs/source

Expand Down
26 changes: 16 additions & 10 deletions nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@
from matplotlib.colors import LinearSegmentedColormap, Normalize
from scipy.stats import gaussian_kde

colors = ["#C28340", "#854F2B", "#61371F", "#8FCA5C", "#70B237", "#477A1E"]
cmap = LinearSegmentedColormap.from_list("custom", colors)

plt.style.use("./style.mplstyle")
plt.rcParams["text.latex.preamble"] = r"\usepackage{sansmathfonts}"
plt.rcParams["axes.prop_cycle"] = cycler(color=colors)
def style_setup():
"""Function to set up matplotlib parameters."""
colors = ["#C28340", "#854F2B", "#61371F", "#8FCA5C", "#70B237", "#477A1E"]
cmap = LinearSegmentedColormap.from_list("custom", colors)

plt.style.use("./style.mplstyle")
plt.rcParams["text.latex.preamble"] = r"\usepackage{sansmathfonts}"
plt.rcParams["axes.prop_cycle"] = cycler(color=colors)

# Use the first color from the custom color cycle
first_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
plt.rcParams["text.usetex"] = False
# Use the first color from the custom color cycle
first_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
plt.rcParams["text.usetex"] = False

return cmap, colors, first_color


def plot_predicted_vs_ground_truth(
Expand All @@ -29,7 +33,7 @@ def plot_predicted_vs_ground_truth(
Returns:
None
"""

_, _, first_color = style_setup()
# Creating the plot
plt.figure(figsize=(10, 8))
plt.scatter(y_test, y_pred, color=first_color, edgecolor="k", alpha=0.6)
Expand All @@ -53,7 +57,7 @@ def plot_predicted_vs_ground_truth_density(
Returns:
None
"""

cmap, _, _ = style_setup()
# Calculate the point densities
values = np.vstack([y_test, y_pred])
kernel = gaussian_kde(values)(values)
Expand Down Expand Up @@ -94,6 +98,7 @@ def plot_confusion_matrix(
Returns:
None
"""
_, _, _ = style_setup()
if full: # Plot one big cm
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
Expand Down Expand Up @@ -146,6 +151,7 @@ def plot_roc_curve(fpr, tpr, roc_auc, title, path):
Returns:
None
"""
_, _, _ = style_setup()
plt.figure(figsize=(10, 8))
plt.plot(
fpr,
Expand Down
4 changes: 1 addition & 3 deletions nmrcraft/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ def __init__(
self.dataset = load_dummy_dataset_locally()

def load_data(self):
self.dataset = filename_to_ligands(
self.dataset
) # Assuming filename_to_ligands is defined elsewhere
self.dataset = filename_to_ligands(self.dataset)
self.dataset = self.dataset.sample(frac=self.dataset_size)
if self.target_type == "categorical":
return self.split_and_preprocess_categorical()
Expand Down
23 changes: 22 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ huggingface-hub = "^0.22.2"
mlflow = "^2.12.1"
argparse = "^1.4.0"
hyperopt = "^0.2.7"
seaborn = "^0.13.2"

[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
Expand Down
121 changes: 121 additions & 0 deletions scripts/analysis/dataset_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Script to show important statist8ics of the dataset, such as ligand classes, ligand class distriutions."""

import os

import matplotlib.pyplot as plt
import seaborn as sns

from nmrcraft.analysis.plotting import style_setup
from nmrcraft.data.dataset import filename_to_ligands, load_dataset_from_hf


def plot_stacked_bars(
    df, group_col, stack_col, output_file, title, rotation_deg
):
    """
    Plot a stacked bar chart of row counts and save it to ``output_file``.

    One bar is drawn per unique value of ``df[group_col]``; each bar is
    stacked by the unique values of ``df[stack_col]``. Every non-empty
    segment is annotated with its count just below the segment's top edge.

    Args:
        df: DataFrame containing the columns ``group_col`` and ``stack_col``.
        group_col: Column whose unique values become the x-axis categories.
        stack_col: Column whose unique values become the stacked segments.
        output_file: Path the figure is written to with ``plt.savefig``.
        title: Title of the plot.
        rotation_deg: Rotation, in degrees, of the count annotations.

    Returns:
        None
    """
    _, colors, _ = style_setup()
    plt.figure(figsize=(12, 10))
    categories = df[group_col].unique()

    # Use the whole palette and wrap around when there are more stack values
    # than colours. Zipping against a palette truncated to two colours would
    # silently drop every stack value after the second one.
    palette = sns.color_palette(colors)

    # Running top of the stack for each category. It serves both as the
    # ``bottom`` of the next segment and as the base for the annotation.
    stack_tops = {category: 0 for category in categories}

    for idx, stack_value in enumerate(df[stack_col].unique()):
        stack_data = df[df[stack_col] == stack_value]
        counts = (
            stack_data[group_col]
            .value_counts()
            .reindex(categories, fill_value=0)
        )
        bars = plt.bar(
            categories,
            counts,
            bottom=[stack_tops[cat] for cat in categories],
            color=palette[idx % len(palette)],
            label=stack_value,
        )

        # Annotate each segment, then raise the stack top for the next one
        for bar, cat in zip(bars, categories):
            height = bar.get_height()
            if height > 0:  # Only annotate if the bar's height is not zero
                plt.text(
                    bar.get_x() + bar.get_width() / 2,
                    stack_tops[cat] + height - 0.05 * height,
                    f"{height}",
                    ha="center",
                    va="top",
                    color="white",
                    fontsize=18,
                    rotation=rotation_deg,
                )
            stack_tops[cat] += height

    plt.xticks(rotation=60, ha="right", fontsize=18)
    plt.xlabel(group_col)
    plt.ylabel("Count")
    plt.title(title)
    plt.legend(title="Metal")

    plt.savefig(output_file, bbox_inches="tight")
    plt.close()


def main():
    """Load the dataset and write the E-ligand distribution plots.

    Two stacked bar charts are saved into the ``plots`` directory (created if
    missing): one over all E-ligands with the imido-containing ones merged
    into a single "Imido Group" category, and one over the imido-containing
    E-ligands only.
    """
    # Load data
    df = filename_to_ligands(load_dataset_from_hf())
    # Take an explicit copy: a new column is added below, and assigning into
    # a column subset of another frame can raise SettingWithCopyWarning.
    df = df[
        [
            "metal",
            "geometry",
            "E_ligand",
            "X1_ligand",
            "X2_ligand",
            "X3_ligand",
            "X4_ligand",
            "L_ligand",
        ]
    ].copy()

    output_path = "plots"
    os.makedirs(output_path, exist_ok=True)

    # Compute the imido mask once and reuse it for both plots. ``na=False``
    # counts missing E_ligand entries as "not imido" instead of failing on
    # them; ``regex=False`` keeps this a plain substring test.
    is_imido = df["E_ligand"].str.contains("imido", na=False, regex=False)

    # Modify 'E_ligand' values to group imido-containing ligands
    df["E_ligand_grouped"] = df["E_ligand"].mask(is_imido, "Imido Group")

    # Plotting all ligands with imido grouped
    plot_stacked_bars(
        df,
        "E_ligand_grouped",
        "metal",
        os.path.join(output_path, "ligands_distribution_grouped.png"),
        "Distribution of data points per E-ligand, with imido grouped",
        rotation_deg=0,
    )

    # Plotting only imido ligands
    imido_df = df[is_imido]
    plot_stacked_bars(
        imido_df,
        "E_ligand",
        "metal",
        os.path.join(output_path, "imido_ligands_distribution.png"),
        "Distribution of imido-containing E-ligands",
        rotation_deg=90,
    )


# Run the statistics plots only when this file is executed as a script.
if __name__ == "__main__":
    main()
142 changes: 142 additions & 0 deletions scripts/analysis/pca_ligand_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Script to plot a PCA of the complexes according to their principal components."""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from nmrcraft.analysis.plotting import style_setup
from nmrcraft.data.dataset import filename_to_ligands, load_dataset_from_hf


def perform_pca(df, features):
    """Perform PCA on specified features and return principal components.

    The selected columns are standardised with ``StandardScaler`` and then
    projected onto the first two principal components.

    Args:
        df: DataFrame holding the feature columns.
        features: List of column names in ``df`` to run the PCA on.

    Returns:
        DataFrame with columns ``"PC1"`` and ``"PC2"``, one row per row of
        ``df`` and carrying the same index as ``df``.
    """
    df_scaled = StandardScaler().fit_transform(df[features])
    pca = PCA(n_components=2)  # Reduce to 2 dimensions for plotting
    principal_components = pca.fit_transform(df_scaled)
    # Reuse df's index so that boolean masks built from df select the matching
    # rows here. A fresh RangeIndex would misalign them whenever df's index is
    # not 0..n-1.
    return pd.DataFrame(
        data=principal_components,
        columns=["PC1", "PC2"],
        index=df.index,
    )


def plot_pca(df, pca_df, category, title, filter_condition=None, suffix=""):
    """Generate and save PCA plots colored by categories, with optional filtering.

    Args:
        df: DataFrame holding the ``category`` column.
        pca_df: DataFrame with ``"PC1"`` and ``"PC2"`` columns for the same
            rows as ``df``. NOTE: the boolean masks built from ``df`` are
            applied to ``pca_df`` as well, so both frames must share an index.
        category: Column of ``df`` used to color the points.
        title: Title of the plot.
        filter_condition: Optional callable that takes ``df`` and returns a
            boolean mask of the rows to plot. ``None`` plots every row.
        suffix: String appended to the output file name.

    Returns:
        None. The figure is written to ``plots/pca_{category}{suffix}.png``.
    """
    # Local import so this function stays self-contained.
    import os

    cmap, _, _ = style_setup()
    fig, ax = plt.subplots()

    if filter_condition is not None:
        keep = filter_condition(df)
        df_filtered = df[keep]
        pca_df_filtered = pca_df[keep]
    else:
        df_filtered = df
        pca_df_filtered = pca_df

    categories = df_filtered[category].unique()
    # One evenly spaced colour from the custom colormap per category
    colors = cmap(np.linspace(0, 1, len(categories)))

    for c, color in zip(categories, colors):
        indices_to_keep = df_filtered[category] == c
        ax.scatter(
            pca_df_filtered.loc[indices_to_keep, "PC1"],
            pca_df_filtered.loc[indices_to_keep, "PC2"],
            s=50,
            label=c,
            color=color,
        )

    ax.set_xlabel("Principal Component 1")
    ax.set_ylabel("Principal Component 2")
    ax.set_ylim(-5, 10)
    ax.set_xlim(-4, 10)
    ax.set_title(title)

    # A legend is drawn for the metal and geometry plots, and for the
    # E_ligand plot only in its "_without_imido" variant.
    show_legend = category in ("metal", "geometry") or (
        category == "E_ligand" and suffix == "_without_imido"
    )
    if show_legend:
        ax.legend(
            title=category,
            loc="upper center",
            bbox_to_anchor=(0.5, -0.15),
            fancybox=True,
            shadow=True,
            ncol=3,
        )

    # Make sure the output directory exists; savefig does not create it.
    os.makedirs("plots", exist_ok=True)
    fig.savefig(f"plots/pca_{category}{suffix}.png", bbox_inches="tight")
    plt.close(fig)


def main():
    """Run the PCA on the ``*_sigma*_ppm`` columns and write all PCA plots.

    Plots are produced colored by metal, by geometry and by E-ligand, plus two
    E-ligand variants: one without the imido entries and one with only them.
    """
    label_columns = ["metal", "geometry", "E_ligand"]
    features = [
        "M_sigma11_ppm",
        "M_sigma22_ppm",
        "M_sigma33_ppm",
        "M_sigmaiso_ppm",
        "E_sigma11_ppm",
        "E_sigma22_ppm",
        "E_sigma33_ppm",
        "E_sigmaiso_ppm",
    ]

    # Keep only the label columns and the PCA input columns, in that order
    df = filename_to_ligands(load_dataset_from_hf())
    df = df[label_columns + features]

    pca_df = perform_pca(df, features)

    def has_imido(frame):
        """Return a boolean mask of rows whose E_ligand contains 'imido'."""
        return frame["E_ligand"].str.contains("imido", na=False)

    # Unfiltered plots, one per label column
    unfiltered_plots = (
        ("metal", "PCA Plot Colored by Metal"),
        ("geometry", "PCA Plot Colored by Geometry"),
        ("E_ligand", "PCA Plot Colored by E-Ligand"),
    )
    for category, title in unfiltered_plots:
        plot_pca(df, pca_df, category, title)

    # Plot without 'imido' entries
    plot_pca(
        df,
        pca_df,
        "E_ligand",
        "PCA Plot Colored by E-Ligand (Without Imido)",
        filter_condition=lambda frame: ~has_imido(frame),
        suffix="_without_imido",
    )

    # Plot only 'imido' entries
    plot_pca(
        df,
        pca_df,
        "E_ligand",
        "PCA Plot Colored by E-Ligand (Imido Only)",
        filter_condition=has_imido,
        suffix="_imido_only",
    )


# Run the PCA plots only when this file is executed as a script.
if __name__ == "__main__":
    main()
3 changes: 0 additions & 3 deletions style.mplstyle
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
text.usetex : True
font.family : sans-serif
font.sans-serif : Arial
font.size : 30
axes.titlesize : 24
axes.labelsize : 20
Expand Down

0 comments on commit 215be3f

Please sign in to comment.