Skip to content

Commit

Permalink
Fix no legend error message and duplicate labels (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
jteijema authored Sep 7, 2023
1 parent a3d4e74 commit f19a7ec
Showing 1 changed file with 49 additions and 47 deletions.
96 changes: 49 additions & 47 deletions asreviewcontrib/makita/templates/script_get_plot.py.template
Original file line number Diff line number Diff line change
Expand Up @@ -26,61 +26,63 @@ from asreview import open_state
from asreviewcontrib.insights.plot import plot_recall


def get_plot_from_states(states, filename, legend=None):
"""Generate an ASReview plot from state files."""
def _set_legend(ax, state, legend_option, label_to_line, state_file):
metadata = state.settings_metadata
label = None

if legend_option == "filename":
label = state_file.stem
elif legend_option == "model":
label = " - ".join(
[metadata["settings"]["model"],
metadata["settings"]["feature_extraction"],
metadata["settings"]["balance_strategy"],
metadata["settings"]["query_strategy"]])
elif legend_option == "classifier":
label = metadata["settings"]["model"]
else:
try:
label = metadata["settings"][legend_option]
except KeyError as err:
raise ValueError(f"Invalid legend setting: '{legend_option}'") from err # noqa: E501

if label:
# add label to line
if label not in label_to_line:
ax.lines[-2].set_label(label)
label_to_line[label] = ax.lines[-2]
# set color of line to the color of the first line with the same label
else:
ax.lines[-2].set_color(label_to_line[label].get_color())
ax.lines[-2].set_label("_no_legend_")

# sort the states alphabetically
states = sorted(states)

def get_plot_from_states(states, filename, legend=None):
"""Generate an ASReview plot from state files.

Arguments
---------
states: list
List of state files.
filename: str
Filename of the plot.
legend: str
Add a legend to the plot, based on the given parameter.
Possible values: "filename", "model", "feature_extraction",
"balance_strategy", "query_strategy", "classifier".
"""
states = sorted(states)
fig, ax = plt.subplots()

labels = []
label_to_line = {}

for state_file in states:
with open_state(state_file) as state:
# draw the plot
plot_recall(ax, state)
if legend:
_set_legend(ax, state, legend, label_to_line, state_file)

# settings for legend "filename"
if legend == "filename":
ax.lines[-2].set_label(state_file.stem)
ax.legend(loc=4, prop={'size': 8})
# settings for legend "settings"
elif legend:
metadata = state.settings_metadata

# settings for legend "model"
if legend == "model":
label = " - ".join(
[metadata["settings"]["model"],
metadata["settings"]["feature_extraction"],
metadata["settings"]["balance_strategy"],
metadata["settings"]["query_strategy"]])

# settings for legend "classifier"
elif legend == "classifier":
label = metadata["settings"]["model"]

# settings for legend from metadata
else:
try:
label = metadata["settings"][legend]
except KeyError as exc:
raise ValueError(
f"Legend setting '{legend}' "
"not found in state file settings."
) from exc

# add label to legend if not already present
# (multiple states can have the same label)
if label not in labels:
ax.lines[-2].set_label(label)
labels.append(label)

# add legend to plot
ax.legend(loc=4, prop={'size': 8})

# save plot
if legend:
ax.legend(loc=4, prop={'size': 8})
fig.savefig(str(filename))


Expand Down

0 comments on commit f19a7ec

Please sign in to comment.