Skip to content

Commit

Permalink
Fix bug in get_plots (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
jteijema authored Jun 14, 2023
1 parent 5108928 commit 413fd34
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions asreviewcontrib/makita/templates/script_get_plot.py.template
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ Authors
import argparse
from pathlib import Path

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from asreview import open_state

Expand All @@ -30,31 +29,39 @@ from asreviewcontrib.insights.plot import plot_recall
def get_plot_from_states(states, filename, legend=None):
"""Generate an ASReview plot from state files."""

# sort the states alphabetically
states = sorted(states)

fig, ax = plt.subplots()

labels = []
colors = list(mcolors.TABLEAU_COLORS.values())

for state_file in states:
with open_state(state_file) as state:
# draw the plot
plot_recall(ax, state)

# set the label
# settings for legend "filename"
if legend == "filename":
ax.lines[-2].set_label(state_file.stem)
ax.legend(loc=4, prop={'size': 8})
# settings for legend "settings"
elif legend:
metadata = state.settings_metadata

# settings for legend "model"
if legend == "model":
label = " - ".join(
[metadata["settings"]["model"],
metadata["settings"]["feature_extraction"],
metadata["settings"]["balance_strategy"],
metadata["settings"]["query_strategy"]])

# settings for legend "classifier"
elif legend == "classifier":
label = metadata["settings"]["model"]

# settings for legend from metadata
else:
try:
label = metadata["settings"][legend]
Expand All @@ -63,12 +70,17 @@ def get_plot_from_states(states, filename, legend=None):
f"Legend setting '{legend}' "
"not found in state file settings."
) from exc

# add label to legend if not already present
# (multiple states can have the same label)
if label not in labels:
ax.lines[-2].set_label(label)
labels.append(label)
ax.lines[-2].set_color(colors[labels.index(label) % len(colors)])
ax.legend(loc=4, prop={'size': 8})

# add legend to plot
ax.legend(loc=4, prop={'size': 8})

# save plot
fig.savefig(str(filename))


Expand All @@ -92,10 +104,10 @@ if __name__ == "__main__":
args = parser.parse_args()

# load states
states = Path(args.s).glob("*.asreview")
states = list(Path(args.s).glob("*.asreview"))

# check if states are found
if len(list(states)) == 0:
if len(states) == 0:
raise FileNotFoundError(f"No state files found in {args.s}")

# generate plot and save results
Expand Down

0 comments on commit 413fd34

Please sign in to comment.