
Visualizing Calibration Curves with Scikit-Learn's `CalibrationDisplay`

Last updated: December 17, 2024

Calibration curves are instrumental in understanding the predictions of classification models. They help visualize how well a model's predicted probabilities align with the actual outcomes, indicating the reliability of probability estimates. Scikit-Learn, one of the most popular machine learning libraries in Python, provides convenient tools for plotting these curves, particularly through the CalibrationDisplay class. This article walks you through using the CalibrationDisplay to generate and visualize calibration curves effectively.

Understanding Calibration Curves

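A calibration curve (also called a reliability diagram) groups predictions into bins and plots, for each bin, the mean predicted probability on the x-axis against the observed fraction of positive samples on the y-axis. A perfectly calibrated model has all of its points on the diagonal line y=x, meaning the predicted probabilities match the observed frequencies. For example, among all samples for which a model predicts a probability of 0.7, about 70% should actually belong to the positive class. Points below the diagonal indicate over-confident predictions, and points above it indicate under-confident ones.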
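To make the binning idea concrete, here is a minimal pure-Python sketch of how the points of such a curve can be computed with equal-width bins. The function name reliability_bins and the toy data are made up for illustration; scikit-learn's own calibration_curve follows the same idea but may differ in details such as how bin edges are handled.

```python
def reliability_bins(y_true, y_prob, n_bins=5):
    """Return (mean predicted probability, fraction of positives) per non-empty bin."""
    prob_sums = [0.0] * n_bins
    positives = [0] * n_bins
    counts = [0] * n_bins
    for label, p in zip(y_true, y_prob):
        # Equal-width bins on [0, 1]; p == 1.0 is placed in the last bin
        idx = min(int(p * n_bins), n_bins - 1)
        prob_sums[idx] += p
        positives[idx] += label
        counts[idx] += 1
    return [
        (prob_sums[i] / counts[i], positives[i] / counts[i])
        for i in range(n_bins)
        if counts[i] > 0
    ]


# Toy data: ten samples predicted at 0.2 (2 positives), ten at 0.7 (7 positives)
y_prob = [0.2] * 10 + [0.7] * 10
y_true = [1] * 2 + [0] * 8 + [1] * 7 + [0] * 3

for mean_pred, frac_pos in reliability_bins(y_true, y_prob):
    print(f"mean predicted = {mean_pred:.2f}, fraction of positives = {frac_pos:.2f}")
```

In this toy data both points fall on the diagonal, which is what perfect calibration looks like.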

Setting Up the Environment

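To begin, make sure Scikit-Learn and Matplotlib are installed in your Python environment. CalibrationDisplay was added in Scikit-Learn 1.0 and uses Matplotlib for plotting, so both packages are required. You can install them via pip if you haven't already:

pip install scikit-learn matplotlib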

Next, prepare some example data to illustrate the process. We'll use the breast cancer dataset from Scikit-Learn's datasets module for demonstration purposes.

Code Example: Basic Calibration Display

Let's create a basic calibration display using a logistic regression model.

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibrationDisplay
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Load dataset
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit a logistic regression model
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Predicted probabilities
prob_pos = model.predict_proba(X_test)[:, 1]

# Calibration curve
disp = CalibrationDisplay.from_predictions(y_test, prob_pos, name="Logistic Regression")
plt.title("Calibration Curve")
plt.show()

In this snippet, we train a logistic regression model on the breast cancer dataset and compute the predicted probabilities on a test set. The CalibrationDisplay.from_predictions method is used to plot the calibration curve.

Advanced Calibration Display Usage

The CalibrationDisplay class also supports more advanced usage, such as drawing the curves of several models on the same axes. It does not plot probability histograms itself, but because it draws on ordinary Matplotlib axes, its output can be combined with a standard Matplotlib histogram.

Let's consider an example where we compare multiple models:

from sklearn.ensemble import RandomForestClassifier

# Fit a random forest model (fixed seed for reproducible results)
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)
prob_pos_rf = rf.predict_proba(X_test)[:, 1]

# Plot both calibration curves on the same axes; without a shared ax,
# each from_predictions call would create its own figure
fig, ax = plt.subplots()
disp_lr = CalibrationDisplay.from_predictions(y_test, prob_pos, name="Logistic Regression", ax=ax)
disp_rf = CalibrationDisplay.from_predictions(y_test, prob_pos_rf, name="Random Forest", ax=ax)

ax.set_title("Comparison of Calibration Curves")
ax.legend(loc="best")
plt.show()

In this code block, we have two models: logistic regression and random forest. We calculate their predicted probabilities and create a comparative calibration plot, which is extremely useful for visually assessing which model provides better probability estimates.

Customizing the Calibration Display

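The CalibrationDisplay can be customized to suit various needs. For instance, you may want to change the line color, the binning strategy, or the number of bins. The class has no built-in histogram option, but you can draw a histogram of the predicted probabilities on a second axes next to the curve.

Here's an example adjusting some settings:

fig, (ax_curve, ax_hist) = plt.subplots(2, 1, figsize=(6, 8))
disp = CalibrationDisplay.from_predictions(
    y_test, prob_pos, name="Logistic Regression",
    strategy="quantile", n_bins=15, ax=ax_curve, color="tab:red"
)
ax_curve.set_title("Customized Calibration Curve")
# CalibrationDisplay has no histogram option, so draw one on a second axes
ax_hist.hist(prob_pos, range=(0, 1), bins=15)
ax_hist.set_xlabel("Predicted probability")
plt.tight_layout()
plt.show()

In this customization, we change the strategy to 'quantile', so the bin edges are chosen to give each bin roughly the same number of samples. This helps when predictions cluster near 0 or 1 and equal-width bins would leave some bins almost empty. We also increase the number of bins from the default of 5 to 15 and set the line color through an extra keyword argument, which is passed on to Matplotlib's plot function. The histogram is drawn with a plain Matplotlib call on the second axes.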
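The difference between the two binning strategies is easier to see in numbers. The following NumPy sketch builds equal-width and quantile bin edges for a hypothetical, strongly skewed set of predicted probabilities (the synthetic beta-distributed scores are an assumption for illustration, not data from the example above) and counts how many samples land in each bin. It mirrors the idea behind the two strategies rather than reproducing scikit-learn's exact implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical scores: half the predictions near 0, half near 1
y_prob = np.concatenate([rng.beta(1, 8, 500), rng.beta(8, 1, 500)])

n_bins = 5
# "uniform": equal-width bins on [0, 1]
uniform_edges = np.linspace(0.0, 1.0, n_bins + 1)
# "quantile": edges placed at evenly spaced percentiles of the scores
quantile_edges = np.percentile(y_prob, np.linspace(0, 100, n_bins + 1))

uniform_counts, _ = np.histogram(y_prob, bins=uniform_edges)
quantile_counts, _ = np.histogram(y_prob, bins=quantile_edges)

print("uniform bin counts: ", uniform_counts)
print("quantile bin counts:", quantile_counts)
```

With skewed scores like these, the equal-width bins in the middle of the range hold few samples, so their points on the curve are noisy, while the quantile bins stay evenly filled.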

Conclusion

Visualizing calibration curves shows how far a classifier's predicted probabilities can be trusted. The CalibrationDisplay from Scikit-Learn makes it straightforward to plot these curves, allowing data scientists to check the reliability of a model's probability outputs and to compare models side by side. Calibration matters most in applications where decisions depend on the probability itself rather than only on the predicted class.
