Fix/multilabel confusion matrix #51
Conversation
Function takes a binarized encoded target array and decodes it back. Useful for the confusion matrix.
The one-dimensional case wasn't working because some places expected lists of lists but only got plain lists. Fixed now.
The dataloader was added to the evaluation to label the confusion matrix. If this is too much mixing-in of the dataloader, we can of course also just pass the decoded y_pred and y_test.
Just noticed there is undefined behavior if the model chooses multiple elements for one row, for example THF and Et2O on the same molecule. It looks like the binarizer somehow chooses one, but I guess this shouldn't be an issue, as it will work out in the statistics.
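That tie-breaking can be reproduced in isolation. A minimal sketch, assuming the encoding uses sklearn's LabelBinarizer (the class names here are illustrative, not the dataset's actual labels):

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer

# Illustrative classes; the real dataloader fits on the dataset's targets.
lb = LabelBinarizer()
lb.fit(["DMSO", "Et2O", "THF"])

# A row where the model set two classes at once (Et2O and THF).
ambiguous_row = np.array([[0, 1, 1]])

# inverse_transform resolves the tie deterministically (argmax takes the
# first maximum), so exactly one label comes back per row.
decoded = lb.inverse_transform(ambiguous_row)
print(decoded)
```

So the decoding never errors out on a multi-hot row; it just silently picks one class, which matches the "it will work out in the statistics" observation above.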
Thanks a lot for the contribution! Checking the code asap
Will do tomorrow 👍🏼
Some setup for adding the multi-dimensional support for the confusion matrix, so that it produces one CM per target from --targets.
This function returns a list of lists of the columns for each target. This is needed to make multiple confusion matrices.
Added support for confusion matrices with one target and a one-dimensional target array, multiple targets with a multi-dimensional target array, and a single target with a multi-dimensional target array.
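A minimal sketch of that column-grouping idea, assuming each target's one-hot encoding occupies a contiguous block of columns (the function name and signature are assumptions, not the PR's actual API):

```python
# Hypothetical helper: given how many classes each target encodes,
# return the list of column indices belonging to each target.
def columns_per_target(n_classes_per_target):
    columns_set, start = [], 0
    for n in n_classes_per_target:
        columns_set.append(list(range(start, start + n)))
        start += n
    return columns_set

# Two targets with 3 and 2 classes respectively:
print(columns_per_target([3, 2]))  # → [[0, 1, 2], [3, 4]]
```

Each inner list can then be used to slice out one target's sub-matrix and its class labels for a per-target confusion matrix.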
Current state
The code has now been refactored quite a bit, especially the training script. It can handle all sorts of dataset mixtures. ROC curves etc. are only generated for binary tasks, and depending on the target dimensionality, different things are executed.
Todo
Fix the algorithm to make the subplots and add that cool color bar, but I feel like this depends on how exactly we want to plot it. The technically painful part, generating all the sub-CMs and bringing along the right labels, is done.
Examples of generated confusion matrices
""" | ||
Plots the confusion matrix. | ||
Parameters: | ||
- cm (array-like): Confusion matrix data. | ||
- classes (list): List of classes for the axis labels. | ||
- title (str): Title of the plot. | ||
- full (bool): If true plots one big, else many smaller. |
I like the description hehe
Alternatively, for the final review round, we could rename it to "Plots single confusion matrix for all targets combined vs one CM per target" or something like that
plt.savefig(path)
plt.close()
elif not full:  # Plot many small cms of each target
Could rewrite this to else:, as full is either True or False
sub_classes = classes[
    slice(columns_set[i][0], columns_set[i][-1] + 1)
]
axs[i].imshow(sub_cm, interpolation="nearest", cmap=plt.cm.Blues)
Super nice! In the final review, we can add our Minecraft color scheme here
because we obtain the cmap from the setup_style() function in nmrcraft/analysis/plotting.py
""" | ||
y_decoded = self.binarized_target_decoder(y) | ||
flat_y_decoded = [y for ys in y_decoded for y in ys] | ||
return flat_y_decoded |
That must have been technically difficult to implement! Thanks for taking care of this. In the final review, we could add a bit more detailed documentation here so that the TAs know what’s going on in this function.
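For that documentation pass, a hedged sketch of what the decoder pair might look like once documented; the method name binarized_target_decoder mirrors the snippet above, but the class internals and decode_flat are assumptions:

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer

class DataLoader:
    """Hypothetical minimal dataloader keeping the binarizer used at encode time."""

    def __init__(self, targets):
        # Fit the same binarizer that encoded the categories.
        self.binarizer = LabelBinarizer().fit(targets)

    def binarized_target_decoder(self, y):
        """Invert the binarization; returns one list of labels per row."""
        return [[label] for label in self.binarizer.inverse_transform(np.asarray(y))]

    def decode_flat(self, y):
        """Decode y, then flatten the per-row lists into one flat label list."""
        y_decoded = self.binarized_target_decoder(y)
        return [label for row in y_decoded for label in row]
```

The flattening step is what feeds plain label lists into the confusion matrix, which is why the decoder returns lists of lists first.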
from typing import Any, Dict, Tuple

from sklearn.base import BaseEstimator
from sklearn.metrics import (
    accuracy_score,
    auc,
    # confusion_matrix,
Thanks for cleaning up the code alongside!
def model_evaluation_nD(
    model: BaseEstimator,
Very nice. Also a note for the final review (I'm creating an issue out of this): add type hints for all functions and classes. The purpose of this is to enforce that a function takes inputs only of a certain type, and also only outputs a certain type.
Instead of

    def confusion_matric_plotter_or_whatever(target, full=True):
        return 1 + 1

write

    def confusion_matric_plotter_or_whatever(target: np.ndarray, full: bool = True) -> int:
        return 1 + 1

This seems a bit useless at first, but it's a game changer when writing code that takes in many different inputs, and it saves some error messages!
X_test: Any,
y_test: Any,
y_labels: Any,
dataloader: dataset.DataLoader,
Here, for example, you have such type hints
print(f"Accuracy: {metrics['accuracy']}") | ||
mlflow.log_artifact(get_cm_path()) |
Thanks for cleaning up! We might need to adapt this to the script @strsamue is writing.
What does it do?
Gives the dataloader the capability to decode target arrays from the test and prediction sets.
Should fix issue #50.
How?
Added an internal function to the dataloader that uses the same label binarizer that was used to encode the different categories.
It then performs the inverse encoding and returns the result for an arbitrary target array.
In the current implementation the dataloader is passed to the evaluation script to do the translation there, but we can of course also translate inside the trainer script and pass the decoded arrays and labels.
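The decode-then-evaluate flow described above can be sketched roughly like this, assuming a LabelBinarizer stored by the dataloader (all names and labels here are illustrative, not the PR's actual API):

```python
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import confusion_matrix

y_test = ["THF", "DMSO", "Et2O", "THF"]
y_pred = ["THF", "Et2O", "Et2O", "THF"]

# The dataloader keeps the binarizer it used for encoding...
lb = LabelBinarizer().fit(y_test)

# ...so encoded arrays coming out of the model can be decoded again.
y_test_dec = lb.inverse_transform(lb.transform(y_test))
y_pred_dec = lb.inverse_transform(lb.transform(y_pred))

# Decoded labels let the confusion matrix carry readable axis labels.
cm = confusion_matrix(y_test_dec, y_pred_dec, labels=lb.classes_)
print(cm)
```

Passing the dataloader (and thus its fitted binarizer) into the evaluation is what makes the round trip above possible without refitting anything.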