diff --git a/requirements.txt b/requirements.txt index 3b6d6e6..951b1fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ matplotlib>=1.4.0 -scikit-learn>=0.18 +scikit-learn>=0.21 scipy>=0.9 joblib>=0.10 diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py index 08ec693..93d2b44 100644 --- a/scikitplot/metrics.py +++ b/scikitplot/metrics.py @@ -908,7 +908,7 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis', return ax -def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, +def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, strategy="uniform", title='Calibration plots (Reliability Curves)', ax=None, figsize=None, cmap='nipy_spectral', title_fontsize="large", text_fontsize="medium"): @@ -938,6 +938,12 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, n_bins (int, optional): Number of bins. A bigger number requires more data. + strategy (str, optional): Strategy used to define the widths of the bins. + uniform + The bins have identical widths. + quantile + The bins have the same number of samples and depend on `y_prob`. + title (string, optional): Title of the generated plot. Defaults to "Calibration plots (Reliabilirt Curves)" @@ -1024,7 +1030,7 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10, probas = (probas - probas.min()) / (probas.max() - probas.min()) fraction_of_positives, mean_predicted_value = \ - calibration_curve(y_true, probas, n_bins=n_bins) + calibration_curve(y_true, probas, n_bins=n_bins, strategy=strategy) color = plt.cm.get_cmap(cmap)(float(i) / len(probas_list))