-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_ablation_results.py
35 lines (27 loc) · 1.21 KB
/
plot_ablation_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
from matplotlib import pyplot as plt
# Plot ablation curves: score vs. proportion of training instances used.
# Proportion of training instances kept in each ablation run (x-axis values).
ablation_percentages = np.array([0, 0.25, 0.5, 0.75, 1.0])
# NOTE(review): the original note here read "the 2nd and the second to last are
# swapped" -- it is unclear whether the values below already account for that;
# confirm against the raw experiment results before publishing.
vision_aug_results = np.array([27.15, 34.72, 35.24, 36.64, 36.81])
cdf_aug_results = np.array([34.09, 34.89, 36.61, 36.81, 36.81])
# Reference line: the same score at every proportion.
human_results = np.array([19.17, 19.17, 19.17, 19.17, 19.17])
y_ticks = [17, 20, 23, 26, 29, 32, 35, 38]
output_path = "human.pdf"

fig = plt.figure()
ax = fig.add_subplot(111)  # noqa: WPS432
ax.plot(ablation_percentages, human_results, "-.g", label="DTC")
ax.plot(ablation_percentages, vision_aug_results, "-bs", label="Visual Aug")
# Secondary y-axis on the right; twinx() shares the x-axis with ``ax``.
ax2 = ax.twinx()
ax2.plot(ablation_percentages, cdf_aug_results, "--ro", label="CDF Aug")  # type: ignore[attr-defined]

# Twin axes autoscale their y-limits independently. Give both the same ticks
# AND the same explicit limits (covering the ticks and every plotted series),
# so the left and right scales are guaranteed to line up and the curves stay
# comparable by height.
all_results = np.concatenate([human_results, vision_aug_results, cdf_aug_results])
y_limits = (
    min(min(y_ticks), all_results.min()),
    max(max(y_ticks), all_results.max()),
)
for y_axes in (ax, ax2):
    y_axes.set_yticks(y_ticks)
    y_axes.set_ylim(y_limits)
ax.set_ylabel("Vision Augmentations")
ax2.set_ylabel("CDF Augmentations")
ax.grid()

# Target ``ax`` explicitly: after twinx() the pyplot "current axes" is ``ax2``,
# so plt.xticks / plt.title would implicitly act on the twin instead.
ax.set_xticks(ablation_percentages)
ax.set_xticklabels([str(p) for p in ablation_percentages])
ax.set_xlabel("Proportion of train instances")
ax.set_title("Performance curves when ablating augmentations")
# fig.legend gathers the labelled lines from both axes into a single legend.
fig.legend(loc="upper center", bbox_to_anchor=(0.5, 0.425), fancybox=True, ncol=3)

fig.savefig(output_path, transparent=True)
# Release the figure; pyplot otherwise keeps every open figure alive.
plt.close(fig)