diff --git a/src/conformist/prediction_dataset.py b/src/conformist/prediction_dataset.py index dd8968e..3782a5e 100644 --- a/src/conformist/prediction_dataset.py +++ b/src/conformist/prediction_dataset.py @@ -452,7 +452,7 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5): hm2.set_yticks([]) # Add label - hm2.set_xlabel("mean softmax\nfalse positive", + hm2.set_xlabel("mean softmax FP", weight='bold', rotation=90) @@ -470,9 +470,8 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5): # Rotate x labels plt.setp(hm3.get_xticklabels(), rotation=90) - # Set y label - hm3.set_ylabel(f"MEAN PREDICTION SET SIZE, SOFTMAX > {min_softmax_threshold}", - weight='bold', labelpad=labelpad) + # Remove y label + hm3.set_ylabel('') # Position y label to the right of heatmap hm3.yaxis.set_label_position("right") @@ -481,7 +480,9 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5): hm3.set_yticks([]) # Remove x label - hm3.set_xlabel('') + hm3.set_xlabel("mean set size\nsoftmax > 0.5", + weight='bold', + rotation=90) # Remove x ticks hm3.set_xticks([])