Sling Academy

Visualizing Learning Curves with Scikit-Learn

Last updated: December 17, 2024

When building machine learning models, understanding how performance changes as the model sees more data is crucial. Learning curves are an effective way to visualize how a model improves as the training set grows and how well it generalizes to unseen data. Scikit-Learn, a robust library for machine learning in Python, provides efficient tools to plot these curves. In this article, we will explore how to create learning curves using Scikit-Learn.

Understanding Learning Curves

Learning curves illustrate a model's learning process. Typically, these plots show a model's performance as a function of the number of training samples. They help in diagnosing whether a model is underfitting or overfitting. Usually, two curves are plotted: one for training data and another for validation data.

  • Underfitting: Both training and validation scores plateau at a low level, indicating the model is too simple to capture the underlying pattern.
  • Overfitting: The training score stays high while the validation score remains substantially lower, leaving a persistent gap; the model fits the training data too well but fails to generalize.

Setting Up Scikit-Learn

To plot learning curves, ensure you have Scikit-Learn installed. If you haven't already, you can install it via pip:

pip install scikit-learn

You will also need Matplotlib for plotting and NumPy for data manipulation:

pip install matplotlib numpy

Generating a Learning Curve

Let’s dive into the steps to plot a learning curve for a simple linear regression model using the learning_curve function in Scikit-Learn.

Step 1: Import the Libraries


import matplotlib.pyplot as plt
from sklearn.model_selection import learning_curve
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression

Step 2: Create a Dataset

We will use Scikit-Learn's make_regression function to create a dataset:


X, y = make_regression(n_samples=200, n_features=1, noise=0.1, random_state=42)  # fixed seed for reproducibility
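As a quick, optional sanity check, you can inspect the shapes of the generated arrays. This is a self-contained sketch; the seed 42 is an arbitrary choice for reproducibility:

```python
from sklearn.datasets import make_regression

# Generate a comparable toy dataset with a fixed seed
X, y = make_regression(n_samples=200, n_features=1, noise=0.1, random_state=42)

print(X.shape)  # (200, 1) -- 200 samples, one feature
print(y.shape)  # (200,)   -- one continuous target per sample
```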

Step 3: Calculate Learning Curves

Now, prepare the learning curve data. Note that with 5-fold cross-validation on 200 samples, each training split contains only 160 samples, so the requested train sizes must not exceed 160:


train_sizes, train_scores, validation_scores = learning_curve(
    estimator=LinearRegression(),
    X=X, y=y,
    train_sizes=[40, 80, 120, 160],
    cv=5,
    scoring='neg_mean_squared_error')
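It helps to know the shapes of what learning_curve returns: train_sizes has one entry per requested size, and each score array has one row per size and one column per CV fold. The sketch below is self-contained; the seed and the train sizes (which must not exceed the 160-sample cap of a 5-fold training split) are chosen here for illustration:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import learning_curve

X, y = make_regression(n_samples=200, n_features=1, noise=0.1, random_state=0)

# With cv=5 on 200 samples, each training split holds at most 160 samples,
# so the requested train sizes must stay at or below 160.
train_sizes, train_scores, validation_scores = learning_curve(
    estimator=LinearRegression(),
    X=X, y=y,
    train_sizes=[40, 80, 120, 160],
    cv=5,
    scoring='neg_mean_squared_error')

print(train_sizes.shape)        # (4,)   -- one entry per requested size
print(train_scores.shape)       # (4, 5) -- rows: sizes, columns: CV folds
print(validation_scores.shape)  # (4, 5)
```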

Step 4: Average and Plot the Results

Finally, average the scores and plot them:


train_scores_mean = -train_scores.mean(axis=1)
validation_scores_mean = -validation_scores.mean(axis=1)

plt.plot(train_sizes, train_scores_mean, label='Training error')
plt.plot(train_sizes, validation_scores_mean, label='Validation error')
plt.ylabel('Error')
plt.xlabel('Training set size')
plt.title('Learning curve')
plt.legend()
plt.show()

Interpreting the Learning Curve

Upon executing the above code, you should see a plot detailing both training and validation errors as the size of the training set increases. Here’s how you can analyze the results:

  • If there's a significant gap between the training and validation curves, it may indicate overfitting.
  • If both curves converge but have high error rates, the model is likely underfitting.
  • If they converge with low error rates, it signifies a good model fit.
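These checks can also be made numerically rather than by eye. The sketch below recomputes the curves with a fixed seed (an assumption added here for reproducibility) and reports the final gap between the validation and training error curves:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import learning_curve

X, y = make_regression(n_samples=200, n_features=1, noise=0.1, random_state=0)

train_sizes, train_scores, validation_scores = learning_curve(
    estimator=LinearRegression(),
    X=X, y=y,
    train_sizes=[40, 80, 120, 160],
    cv=5,
    scoring='neg_mean_squared_error')

# Flip the sign of neg_mean_squared_error to get positive errors
train_err = -train_scores.mean(axis=1)
val_err = -validation_scores.mean(axis=1)

# Gap at the largest training size: a large gap hints at overfitting,
# while two high converged errors hint at underfitting.
gap = val_err[-1] - train_err[-1]
print(f"final training error:   {train_err[-1]:.4f}")
print(f"final validation error: {val_err[-1]:.4f}")
print(f"gap: {gap:.4f}")
```

For this nearly linear dataset fit by a linear model, both errors should be small and the gap close to zero, i.e. a good fit.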

Benefits of Using Learning Curves

Learning curves help in:

  • Determining the adequacy of the model (underfitting vs. overfitting).
  • Deciding if additional data could improve model performance.
  • Providing insight into whether adjusting model complexity or improving feature engineering might be beneficial.
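To illustrate the model-complexity point, the sketch below compares the final train/validation gap of the linear model against an intentionally flexible DecisionTreeRegressor. This model, the higher noise level, and the seed are choices made for this illustration, not part of the walkthrough above:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import learning_curve
from sklearn.tree import DecisionTreeRegressor

# Noisier data makes the overfitting signature easier to see
X, y = make_regression(n_samples=200, n_features=1, noise=10.0, random_state=0)

gaps = {}
for name, model in [("linear", LinearRegression()),
                    ("tree", DecisionTreeRegressor(random_state=0))]:
    sizes, tr, va = learning_curve(model, X, y,
                                   train_sizes=[40, 80, 120, 160],
                                   cv=5, scoring='neg_mean_squared_error')
    train_err = -tr.mean(axis=1)[-1]
    val_err = -va.mean(axis=1)[-1]
    gaps[name] = val_err - train_err
    print(f"{name}: train error {train_err:.1f}, "
          f"validation error {val_err:.1f}, gap {gaps[name]:.1f}")

# The unconstrained tree memorizes the training data (near-zero training
# error) but generalizes worse, so its gap is much larger than the
# linear model's -- the classic overfitting signature.
```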

By following these steps, you can effectively plot and interpret learning curves, thereby evaluating your models robustly. Scikit-Learn automates much of this task, making it easier to weave this critical evaluation into your machine learning workflow.


Series: Scikit-Learn Tutorials
