diff --git a/src/utils_plots.py b/src/utils_plots.py index 7aedbdd4..0b22e974 100644 --- a/src/utils_plots.py +++ b/src/utils_plots.py @@ -143,7 +143,7 @@ def plot_classification_report(classification_report, title='Classification repo else: lines = classification_report.split('\n') for line in lines[2 : (len(lines) - 1)]: - t = line.strip().replace('avg / total', 'micro-avg').split() + t = line.strip().replace(' avg', '-avg').split() if len(t) < 2: continue classes.append(t[0]) v = [float(x)*100 for x in t[1: len(t) - 1]]