-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into fix/multilabel-confusion-matrix
- Loading branch information
Showing
8 changed files
with
304 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ mlruns/ | |
scratch/ | ||
dataset/ | ||
data/ | ||
plots/ | ||
|
||
docs/source | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
"""Script to show important statist8ics of the dataset, such as ligand classes, ligand class distriutions.""" | ||
|
||
import os | ||
|
||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
from nmrcraft.analysis.plotting import style_setup | ||
from nmrcraft.data.dataset import filename_to_ligands, load_dataset_from_hf | ||
|
||
|
||
def plot_stacked_bars(
    df, group_col, stack_col, output_file, title, rotation_deg
):
    """
    Plot a stacked bar chart of row counts and save it to ``output_file``.

    One bar is drawn per unique value of ``df[group_col]``; each bar is split
    into one segment per unique value of ``df[stack_col]``. Every non-empty
    segment is annotated with its count just below the segment's top edge.

    Args:
        df: DataFrame containing ``group_col`` and ``stack_col``.
        group_col: Column whose unique values become the bars (x axis).
        stack_col: Column whose unique values become the stacked segments.
        output_file: Path the figure is saved to.
        title: Title of the figure.
        rotation_deg: Rotation, in degrees, of the count annotations.
    """
    _, colors, _ = style_setup()
    plt.figure(figsize=(12, 10))
    categories = df[group_col].unique()
    stack_values = df[stack_col].unique()
    # The first two palette colors are the ones used for Mo and W. The
    # palette is indexed cyclically instead of being cut to two entries, so
    # a third (or later) stack value is still drawn rather than silently
    # dropped from the chart.
    palette = sns.color_palette(colors)

    # Height already stacked in each bar; it is both the bottom of the next
    # segment and the base for positioning that segment's annotation.
    stacked = {category: 0 for category in categories}

    for position, stack_value in enumerate(stack_values):
        counts = (
            df.loc[df[stack_col] == stack_value, group_col]
            .value_counts()
            .reindex(categories, fill_value=0)
        )
        bars = plt.bar(
            categories,
            counts,
            bottom=[stacked[cat] for cat in categories],
            color=palette[position % len(palette)],
            label=stack_value,
        )

        for bar, cat, count in zip(bars, categories, counts):
            height = bar.get_height()
            if height > 0:  # Only annotate if the bar's height is not zero
                plt.text(
                    bar.get_x() + bar.get_width() / 2,
                    # 5% of the segment height below the segment's top
                    stacked[cat] + height - 0.05 * height,
                    f"{height}",
                    ha="center",
                    va="top",
                    color="white",
                    fontsize=18,
                    rotation=rotation_deg,
                )
            # Raise the base for the next stack value's segment
            stacked[cat] += count

    plt.xticks(rotation=60, ha="right", fontsize=18)
    plt.xlabel(group_col)
    plt.ylabel("Count")
    plt.title(title)
    plt.legend(title="Metal")

    plt.savefig(output_file, bbox_inches="tight")
    plt.close()
|
||
|
||
def main():
    """
    Plot how the data points are distributed over the E-ligands.

    Two stacked bar charts (stacked by metal) are written to the ``plots``
    directory:

    - ``ligands_distribution_grouped.png``: all E-ligands, with every
      imido-containing ligand collapsed into a single "Imido Group" bar.
    - ``imido_ligands_distribution.png``: only the imido-containing
      E-ligands, shown individually.
    """
    # Load data
    columns = [
        "metal",
        "geometry",
        "E_ligand",
        "X1_ligand",
        "X2_ligand",
        "X3_ligand",
        "X4_ligand",
        "L_ligand",
    ]
    # Explicit copy: a column is added below, and assigning into a
    # column-sliced frame can trigger pandas' SettingWithCopyWarning.
    df = filename_to_ligands(load_dataset_from_hf())[columns].copy()

    output_path = "plots"
    os.makedirs(output_path, exist_ok=True)

    # na=False counts a missing ligand name as "not imido", so the mask is
    # strictly boolean and can be used for indexing even if values are missing.
    is_imido = df["E_ligand"].str.contains("imido", na=False)

    # Modify 'E_ligand' values to group imido-containing ligands
    df["E_ligand_grouped"] = df["E_ligand"].mask(is_imido, "Imido Group")

    # Plotting all ligands with imido grouped
    plot_stacked_bars(
        df,
        "E_ligand_grouped",
        "metal",
        os.path.join(output_path, "ligands_distribution_grouped.png"),
        "Distribution of data points per E-ligand, with imido grouped",
        rotation_deg=0,
    )

    # Plotting only imido ligands
    imido_df = df[is_imido]
    plot_stacked_bars(
        imido_df,
        "E_ligand",
        "metal",
        os.path.join(output_path, "imido_ligands_distribution.png"),
        "Distribution of imido-containing E-ligands",
        rotation_deg=90,
    )


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
"""Script to plot a PCA of the complexes according to their principal components.""" | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.decomposition import PCA | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
from nmrcraft.analysis.plotting import style_setup | ||
from nmrcraft.data.dataset import filename_to_ligands, load_dataset_from_hf | ||
|
||
|
||
def perform_pca(df, features):
    """Standardize the given features and project them onto two principal components.

    Args:
        df: DataFrame containing the feature columns.
        features: Names of the columns of ``df`` used as PCA input.

    Returns:
        DataFrame with the columns ``PC1`` and ``PC2``, one row per row of
        ``df``. It carries ``df``'s index so that boolean masks built from
        ``df`` select the matching rows here as well.
    """
    scaled = StandardScaler().fit_transform(df[features])
    # Reduce to 2 dimensions for plotting
    components = PCA(n_components=2).fit_transform(scaled)
    # Without index=df.index the result would get a fresh 0..n-1 index and
    # would only line up with df when df itself uses the default index.
    return pd.DataFrame(
        data=components, columns=["PC1", "PC2"], index=df.index
    )
|
||
|
||
def plot_pca(df, pca_df, category, title, filter_condition=None, suffix=""):
    """Generate and save a PCA scatter plot colored by ``category``.

    Args:
        df: DataFrame holding the ``category`` column. Its rows are expected
            to correspond, position by position, to the rows of ``pca_df``.
        pca_df: DataFrame with the columns ``PC1`` and ``PC2``.
        category: Column of ``df`` used to color and label the points.
        title: Title of the plot.
        filter_condition: Optional callable that receives ``df`` and returns
            a boolean mask of the rows to plot. All rows are plotted if None.
        suffix: Text appended to the output file name.

    The figure is saved as ``plots/pca_{category}{suffix}.png``; the
    ``plots`` directory is created if it does not exist yet.
    """
    import os  # local import: only needed to create the output directory

    cmap, _, _ = style_setup()
    fig, ax = plt.subplots()

    if filter_condition is not None:
        # Apply the mask positionally so the selection does not depend on
        # df and pca_df sharing the same index labels.
        keep = np.asarray(filter_condition(df), dtype=bool)
        df_filtered = df[keep]
        pca_df_filtered = pca_df[keep]
    else:
        df_filtered = df
        pca_df_filtered = pca_df

    categories = df_filtered[category].unique()
    # One color per category, evenly spaced over the colormap
    colors = cmap(np.linspace(0, 1, len(categories)))

    for c, color in zip(categories, colors):
        in_category = (df_filtered[category] == c).to_numpy()
        ax.scatter(
            pca_df_filtered.loc[in_category, "PC1"],
            pca_df_filtered.loc[in_category, "PC2"],
            s=50,
            label=c,
            color=color,
        )

    ax.set_xlabel("Principal Component 1")
    ax.set_ylabel("Principal Component 2")
    ax.set_ylim(-5, 10)
    ax.set_xlim(-4, 10)
    ax.set_title(title)

    # E_ligand plots only get a legend for the "_without_imido" variant;
    # metal and geometry plots always get one.
    show_legend = category in ("metal", "geometry") or (
        category == "E_ligand" and suffix == "_without_imido"
    )
    if show_legend:
        ax.legend(
            title=category,
            loc="upper center",
            bbox_to_anchor=(0.5, -0.15),
            fancybox=True,
            shadow=True,
            ncol=3,
        )

    # Saving fails with FileNotFoundError if the target directory is missing
    os.makedirs("plots", exist_ok=True)
    fig.savefig(f"plots/pca_{category}{suffix}.png", bbox_inches="tight")
    plt.close(fig)
|
||
|
||
def main():
    """Run a PCA on the shielding-tensor features and save the PCA plots.

    Plots are produced colored by metal, by geometry and by E-ligand; the
    E-ligand plot is additionally produced without and with only the
    imido-containing ligands.
    """
    features = [
        "M_sigma11_ppm",
        "M_sigma22_ppm",
        "M_sigma33_ppm",
        "M_sigmaiso_ppm",
        "E_sigma11_ppm",
        "E_sigma22_ppm",
        "E_sigma33_ppm",
        "E_sigmaiso_ppm",
    ]

    # Keep only the label columns and the PCA input features
    dataset = filename_to_ligands(load_dataset_from_hf())
    df = dataset[["metal", "geometry", "E_ligand", *features]]

    pca_df = perform_pca(df, features)

    def contains_imido(frame):
        """Boolean mask of the rows whose E-ligand name contains 'imido'."""
        return frame["E_ligand"].str.contains("imido", na=False)

    def lacks_imido(frame):
        """Boolean mask of the rows whose E-ligand name has no 'imido'."""
        return ~frame["E_ligand"].str.contains("imido", na=False)

    plot_pca(df, pca_df, "metal", "PCA Plot Colored by Metal")
    plot_pca(df, pca_df, "geometry", "PCA Plot Colored by Geometry")

    # Standard plot for E_ligand
    plot_pca(df, pca_df, "E_ligand", "PCA Plot Colored by E-Ligand")

    # Plot without 'imido' entries
    plot_pca(
        df,
        pca_df,
        "E_ligand",
        "PCA Plot Colored by E-Ligand (Without Imido)",
        filter_condition=lacks_imido,
        suffix="_without_imido",
    )

    # Plot only 'imido' entries
    plot_pca(
        df,
        pca_df,
        "E_ligand",
        "PCA Plot Colored by E-Ligand (Imido Only)",
        filter_condition=contains_imido,
        suffix="_imido_only",
    )


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters