Skip to content

Commit

Permalink
Fix/multilabel confusion matrix (#51)
Browse files Browse the repository at this point in the history
* Init: init fixing environment + scratch script

* Added binarized_target_decoder to dataloader

The function takes a binarized (encoded) target array and decodes it back.
Useful for the confusion matrix.

* Confusion Matrix working for all except metals

* Fixed problem with metals flag

The one dimensional stuff wasn't working due to some places expecting
lists of lists and only getting lists. Fixed now

* Chore: get rid of ported testing scripts

* Chore: restore model_evaluation as much as possible

* Chore: get rid of TODO comments fixed by dataloader

* Fix: added the dataloader y_labels to the evaluation + cleanup

Dataloader was added to evaluation to label the confusion matrix. If
this is too much mixing of the dataloader we can also of course just
pass the decoded y_pred and y_test.

* Refactor train_script and multi dim support for evaluation and plotting

Some setup for adding the multi dimensional support for the Confusion
matrix the way it produces multiple cms for each target of the
--targets.

* Feature: get_target_columns_separated() added to dataloader

This function returns a list of lists containing the column indices of each
target. This is needed to make multiple confusion matrices.

* Add support for multiple cm types

Added confusion matrix support for three cases: one target with a
one-dimensional target array; multiple targets with a multidimensional
target array; and a single target with a multidimensional target array.

* Hotfix: remove to_csv("yeet")

---------

Co-authored-by: TiagoW <[email protected]>
  • Loading branch information
tiaguinho-code and TiagoW authored May 14, 2024
1 parent 8176e2e commit 8eba2a7
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 55 deletions.
57 changes: 44 additions & 13 deletions nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,60 @@ def plot_predicted_vs_ground_truth_density(
plt.show()


def plot_confusion_matrix(
    cm, classes, title, path, full=True, columns_set=False
):
    """
    Plots the confusion matrix and saves the figure to ``path``.

    Parameters:
    - cm (array-like): Confusion matrix data. When ``full`` is false it is
      indexed with a pair of slices, so it must support 2-D slicing.
    - classes (list): List of classes for the axis labels.
    - title (str): Title of the plot (only used when ``full`` is true).
    - path: Destination handed to ``plt.savefig``.
    - full (bool): If true plots one big matrix, else one smaller matrix
      per entry of ``columns_set``.
    - columns_set (list of lists): For each target, the indices of its
      rows/columns in ``cm``. Only the first and last index of each entry
      are used, as an inclusive range. Required when ``full`` is false.

    Returns:
    None

    Raises:
    ValueError: If ``full`` is false and ``columns_set`` is empty or unset.
    """
    _, _, _ = style_setup()
    if full:  # Plot one big cm
        plt.figure(figsize=(10, 8))
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)
        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.savefig(path)
        plt.close()
        return

    # Plot many small cms, one per target
    if not columns_set:
        raise ValueError(
            "columns_set must list the column indices of each target "
            "when full is False"
        )

    # Each target occupies a contiguous, inclusive index range
    spans = [slice(columns[0], columns[-1] + 1) for columns in columns_set]

    # squeeze=False keeps axs two-dimensional even for a single target;
    # otherwise plt.subplots returns a bare Axes that cannot be indexed.
    _, axs = plt.subplots(
        nrows=len(spans), figsize=(10, 8 * len(spans)), squeeze=False
    )
    for i, span in enumerate(spans):
        ax = axs[i, 0]
        sub_classes = classes[span]
        ax.imshow(cm[span, span], interpolation="nearest", cmap=plt.cm.Blues)
        ax.set_title(f"Confusion Matrix {i+1}")
        tick_marks = np.arange(len(sub_classes))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(sub_classes, rotation=45)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(sub_classes)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def plot_roc_curve(fpr, tpr, roc_auc, title, path):
Expand Down
94 changes: 92 additions & 2 deletions nmrcraft/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ def target_label_readabilitizer_categorical(target_labels):
return good_labels


def column_length_to_indices(column_lengths):
    """Turn a sequence of column counts into consecutive index lists.

    Each entry of ``column_lengths`` claims that many consecutive column
    indices, starting where the previous entry stopped.

    Example: [1, 2, 3] -> [[0], [1, 2], [3, 4, 5]]
    """
    indices = []
    start_index = 0
    for length in column_lengths:
        # range() already yields the single index when length == 1,
        # so no special case is needed.
        indices.append(list(range(start_index, start_index + length)))
        start_index += length
    return indices


class DataLoader:
def __init__(
self,
Expand Down Expand Up @@ -225,6 +237,80 @@ def preprocess_features(self, X):
X_scaled = scaler.fit_transform(X)
return X_scaled, scaler

def get_target_columns_separated(self):
    """Return, per target, the list of column indices belonging to it.

    The index lists are built from self.target_column_numbers. When
    "metal" is one of self.target_columns, its list gets one additional
    index (its first index + 1) and every index of the targets after it is
    shifted up by one.

    For example, column counts [1, 4] with metal as the first target give
    [[0, 1], [2, 3, 4, 5]].
    """
    column_groups = column_length_to_indices(self.target_column_numbers)
    if "metal" not in self.target_columns:
        return column_groups

    metal_pos = self.target_columns.index("metal")
    for pos, group in enumerate(column_groups):
        if pos == metal_pos:
            # The metal target claims one extra index
            group.append(group[0] + 1)
        elif pos > metal_pos:
            # Later targets make room for that extra index
            column_groups[pos] = [idx + 1 for idx in group]
    return column_groups

def more_than_one_target(self):
"""Function returns true if more than one target is specified"""
return len(self.target_columns) > 1

def binarized_target_decoder(self, y):
"""
function takes in the target (y) array and transforms it back to decoded form.
For this function to be run the one-hot-preprocesser already has to have been run beforehand.
"""
y_column_indices = column_length_to_indices(self.target_column_numbers)
ys = []
ys_decoded = []
# Split up compressed array into the categories
# If one-dimensional y (f.e only metals)
if isinstance(y[0], np.int64):
ys = list(y[:]) # copy list
ys = np.array(list(map(list, [[x] for x in ys])))
ys_decoded = self.encoders[0].inverse_transform(ys)
ys_decoded_properly_rotated = np.array(
list(map(list, [[x] for x in ys_decoded]))
)
# Decode the binarized metal using the original binarizer
# If multidimensional y
if not isinstance(y[0], np.int64):
for i in range(len(y_column_indices)):
ys.append(y[:, y_column_indices[i]])
# Decode the binarized categries using the original binarizers
for i in range(len(ys)):
ys_decoded.append(self.encoders[i].inverse_transform(ys[i]))
ys_decoded_properly_rotated = [
list(x) if i == 0 else x
for i, x in enumerate(map(list, zip(*ys_decoded)))
]
return ys_decoded_properly_rotated

def confusion_matrix_data_adapter(self, y):
"""
Takes in binary encoded target array and returns decoded flat list.
Especially designed to work with confusion matrix.
"""
y_decoded = self.binarized_target_decoder(y)
flat_y_decoded = [y for ys in y_decoded for y in ys]
return flat_y_decoded

def confusion_matrix_label_adapter(self, y_labels):
    """
    Return a copy of the labels with every "Mo W" entry split in two.

    Each "Mo W" label is replaced by the two labels "W" and "Mo" (in that
    order); all other labels are kept as they are. The input is not
    modified.

    Args:
        y_labels: Sequence of target labels.

    Returns:
        list: New list of labels with the combined metal label expanded.
    """
    adapted_labels = []
    # Build a new list instead of inserting into the one being indexed, so
    # that no entry is skipped once the list has grown.
    for label in y_labels:
        if label == "Mo W":
            adapted_labels.extend(["W", "Mo"])
        else:
            adapted_labels.append(label)
    return adapted_labels

def split_and_preprocess_categorical(self):
"""
Split data into training and test sets, then apply normalization.
Expand Down Expand Up @@ -324,10 +410,16 @@ def split_and_preprocess_one_hot(self):
self.target_unique_labels = target_unique_labels
ys = []
readable_labels = []
self.encoders = []
self.target_column_numbers = []
for i in range(len(target_unique_labels)):
LBiner = LabelBinarizer()
ys.append(LBiner.fit_transform(y_labels[i]))
readable_labels.append(LBiner.classes_)
self.encoders.append(LBiner) # save encoder for later decoding
self.target_column_numbers.append(
len(ys[i][0])
) # save column numbers for later decoding
y = np.concatenate(list(ys), axis=1)

# Get NMR and structural Features, one-hot-encode and combine
Expand All @@ -342,8 +434,6 @@ def split_and_preprocess_one_hot(self):
X_Structural_Features_enc = one_hot.transform(
X_Structural_Features
).toarray()
# X = [X_NMR, X_Structural_Features_enc]
# print(X)

# Split the datasets
(
Expand Down
78 changes: 73 additions & 5 deletions nmrcraft/evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
import os
from typing import Any, Dict, Tuple

from sklearn.base import BaseEstimator
from sklearn.metrics import (
accuracy_score,
auc,
# confusion_matrix,
multilabel_confusion_matrix,
confusion_matrix,
f1_score,
roc_curve,
)

from nmrcraft.data import dataset


def model_evaluation(
model: BaseEstimator, X_test: Any, y_test: Any
model: BaseEstimator,
X_test: Any,
y_test: Any,
y_labels: Any,
dataloader: dataset.DataLoader,
) -> Tuple[Dict[str, float], Any, Any, Any]:
"""
Evaluate the performance of the trained machine learning model.
Evaluate the performance of the trained machine learning model for 1D targets.
Args:
model (BaseEstimator): The trained machine learning model.
X_test (Any): The input features for testing.
y_test (Any): The true labels for testing.
y_labels (Any): Label for the columns of the target.
dataloader (DataLoader): Dataloader to decode the target arrays.
Returns:
Tuple[Dict[str, float], Any, Any, Any]: A tuple containing:
Expand All @@ -30,12 +38,18 @@ def model_evaluation(
- The true positive rate.
"""
y_pred = model.predict(X_test)

score = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average="weighted")
cm = multilabel_confusion_matrix(y_test, y_pred)
fpr, tpr, thresholds = roc_curve(y_test, model.predict_proba(X_test)[:, 1])
roc_auc = auc(fpr, tpr)

y_test_cm = dataloader.confusion_matrix_data_adapter(y_test)
y_pred_cm = dataloader.confusion_matrix_data_adapter(y_pred)
y_labels_cm = dataloader.confusion_matrix_label_adapter(y_labels)
cm = confusion_matrix(
y_pred=y_pred_cm, y_true=y_test_cm, labels=y_labels_cm
)
return (
{
"accuracy": score,
Expand All @@ -46,3 +60,57 @@ def model_evaluation(
fpr,
tpr,
)


def model_evaluation_nD(
    model: BaseEstimator,
    X_test: Any,
    y_test: Any,
    y_labels: Any,
    dataloader: dataset.DataLoader,
) -> Tuple[Dict[str, float], Any]:
    """
    Evaluate the performance of the trained machine learning model for 2D+ Targets.

    Accuracy, F1 score and the confusion matrix are all computed on the
    decoded, flattened targets returned by the dataloader adapters.

    Args:
        model (BaseEstimator): The trained machine learning model.
        X_test (Any): The input features for testing.
        y_test (Any): The true labels for testing.
        y_labels (Any): Label for the columns of the target.
        dataloader (DataLoader): Dataloader to decode the target arrays.

    Returns:
        Tuple[Dict[str, float], Any]: A tuple containing:
            - A dictionary with evaluation metrics (accuracy, f1_score).
            - The confusion matrix.
    """
    y_pred = model.predict(X_test)

    # Decode the binarized targets into flat label lists
    y_test_cm = dataloader.confusion_matrix_data_adapter(y_test)
    y_pred_cm = dataloader.confusion_matrix_data_adapter(y_pred)
    y_labels_cm = dataloader.confusion_matrix_label_adapter(y_labels)

    score = accuracy_score(y_test_cm, y_pred_cm)
    f1 = f1_score(y_test_cm, y_pred_cm, average="weighted")
    cm = confusion_matrix(
        y_pred=y_pred_cm, y_true=y_test_cm, labels=y_labels_cm
    )
    return (
        {
            "accuracy": score,
            "f1_score": f1,
        },
        cm,
    )


def get_cm_path():
    """Return the path of the confusion matrix image inside "scratch/".

    The directory is created when it does not exist yet.
    """
    fig_path = "scratch/"
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(fig_path, exist_ok=True)
    return os.path.join(fig_path, "cm.png")


def get_roc_path():
    """Return the path of the ROC curve image inside "scratch/".

    The directory is created when it does not exist yet.
    """
    fig_path = "scratch/"
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(fig_path, exist_ok=True)
    return os.path.join(fig_path, "roc.png")
Loading

0 comments on commit 8eba2a7

Please sign in to comment.