From a48c1749c42f8a261b4973e0a8168c75ce3d9258 Mon Sep 17 00:00:00 2001 From: chris-santiago Date: Tue, 12 Sep 2023 08:58:46 -0400 Subject: [PATCH] expose strategy param for calibration curve --- requirements.txt | 2 +- scikitplot/metrics.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3b6d6e6..951b1fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ matplotlib>=1.4.0 -scikit-learn>=0.18 +scikit-learn>=0.21 scipy>=0.9 joblib>=0.10 diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py index 08ec693..93d2b44 100644 --- a/scikitplot/metrics.py +++ b/scikitplot/metrics.py @@ -908,7 +908,7 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis', return ax -def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, +def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, strategy="uniform", title='Calibration plots (Reliability Curves)', ax=None, figsize=None, cmap='nipy_spectral', title_fontsize="large", text_fontsize="medium"): @@ -938,6 +938,12 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, n_bins (int, optional): Number of bins. A bigger number requires more data. + strategy (str, optional): Strategy used to define the widths of the bins. + uniform + The bins have identical widths. + quantile + The bins have the same number of samples and depend on the predicted probabilities in `probas_list`. + title (string, optional): Title of the generated plot. Defaults to "Calibration plots (Reliabilirt Curves)" @@ -1024,7 +1030,7 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, probas = (probas - probas.min()) / (probas.max() - probas.min()) fraction_of_positives, mean_predicted_value = \ - calibration_curve(y_true, probas, n_bins=n_bins) + calibration_curve(y_true, probas, n_bins=n_bins, strategy=strategy) color = plt.cm.get_cmap(cmap)(float(i) / len(probas_list))