From f19a7ec9f73ba92a099dbbdbcee71f71ec4cee27 Mon Sep 17 00:00:00 2001 From: Jelle Teijema Date: Thu, 7 Sep 2023 10:43:02 +0200 Subject: [PATCH] Fix no legend error message and duplicate labels (#45) --- .../templates/script_get_plot.py.template | 96 ++++++++++--------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/asreviewcontrib/makita/templates/script_get_plot.py.template b/asreviewcontrib/makita/templates/script_get_plot.py.template index 52e884d..1108a22 100644 --- a/asreviewcontrib/makita/templates/script_get_plot.py.template +++ b/asreviewcontrib/makita/templates/script_get_plot.py.template @@ -26,61 +26,63 @@ from asreview import open_state from asreviewcontrib.insights.plot import plot_recall -def get_plot_from_states(states, filename, legend=None): - """Generate an ASReview plot from state files.""" +def _set_legend(ax, state, legend_option, label_to_line, state_file): + metadata = state.settings_metadata + label = None + + if legend_option == "filename": + label = state_file.stem + elif legend_option == "model": + label = " - ".join( + [metadata["settings"]["model"], + metadata["settings"]["feature_extraction"], + metadata["settings"]["balance_strategy"], + metadata["settings"]["query_strategy"]]) + elif legend_option == "classifier": + label = metadata["settings"]["model"] + else: + try: + label = metadata["settings"][legend_option] + except KeyError as err: + raise ValueError(f"Invalid legend setting: '{legend_option}'") from err # noqa: E501 + + if label: + # add label to line + if label not in label_to_line: + ax.lines[-2].set_label(label) + label_to_line[label] = ax.lines[-2] + # set color of line to the color of the first line with the same label + else: + ax.lines[-2].set_color(label_to_line[label].get_color()) + ax.lines[-2].set_label("_no_legend_") - # sort the states alphabetically - states = sorted(states) +def get_plot_from_states(states, filename, legend=None): + """Generate an ASReview plot from state files. + + Arguments + --------- + states: list + List of state files. + filename: str + Filename of the plot. + legend: str + Add a legend to the plot, based on the given parameter. + Possible values: "filename", "model", "feature_extraction", + "balance_strategy", "query_strategy", "classifier". + """ + states = sorted(states) fig, ax = plt.subplots() - - labels = [] + label_to_line = {} for state_file in states: with open_state(state_file) as state: - # draw the plot plot_recall(ax, state) + if legend: + _set_legend(ax, state, legend, label_to_line, state_file) - # settings for legend "filename" - if legend == "filename": - ax.lines[-2].set_label(state_file.stem) - ax.legend(loc=4, prop={'size': 8}) - # settings for legend "settings" - elif legend: - metadata = state.settings_metadata - - # settings for legend "model" - if legend == "model": - label = " - ".join( - [metadata["settings"]["model"], - metadata["settings"]["feature_extraction"], - metadata["settings"]["balance_strategy"], - metadata["settings"]["query_strategy"]]) - - # settings for legend "classifier" - elif legend == "classifier": - label = metadata["settings"]["model"] - - # settings for legend from metadata - else: - try: - label = metadata["settings"][legend] - except KeyError as exc: - raise ValueError( - f"Legend setting '{legend}' " - "not found in state file settings." - ) from exc - - # add label to legend if not already present - # (multiple states can have the same label) - if label not in labels: - ax.lines[-2].set_label(label) - labels.append(label) - - # add legend to plot - ax.legend(loc=4, prop={'size': 8}) - - # save plot + if legend: + ax.legend(loc=4, prop={'size': 8}) fig.savefig(str(filename))