Pandas: Extract test/train/validation sets from a DataFrame

Last updated: February 20, 2024

Introduction

Pandas is a powerful library in the Python ecosystem that makes it easy to manipulate and analyze data. When building machine learning models, a common task is to split your dataset into training, validation, and test sets. This tutorial will guide you through multiple ways of achieving this in Pandas, from basic methods to more advanced approaches.

Setting Up Your Environment

Before we dive into the examples, ensure you have Pandas, NumPy, and scikit-learn installed:

pip install pandas numpy
pip install scikit-learn  # Imported as sklearn; used in the later examples

Also, let’s import Pandas and create a simple DataFrame:

import pandas as pd

# Create a simple DataFrame
df = pd.DataFrame(
    {
        "features_1": range(100),
        "features_2": range(100, 200),
        "label": [1 if x % 2 == 0 else 0 for x in range(100)],
    }
)

Basic Method: Manual Split

The most basic way to split a DataFrame is to shuffle its rows and cut them yourself. This method gives you full control over how the splits are made, but you are responsible for details such as reproducibility and class balance. Here’s how you can do it:

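import numpy as np

train, validate, test = np.split(df.sample(frac=1, random_state=42), [int(.6*len(df)), int(.8*len(df))])

This code first shuffles the DataFrame using sample(frac=1), with random_state fixed so the shuffle is reproducible. It then cuts the shuffled rows at the 60% and 80% marks, giving training (60%), validation (20%), and test (20%) sets based on the total number of rows. Newer Pandas versions may emit a FutureWarning here, because np.split relies on DataFrame.swapaxes, which is deprecated.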
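The same shuffle-and-cut idea can be sketched with Pandas alone, which avoids passing a DataFrame through NumPy. This is a minimal sketch rather than a definitive implementation: the names shuffled, train_end, and val_end are illustrative, and the seed 42 is an arbitrary choice.

```python
import pandas as pd

# Same toy DataFrame as in the setup section
df = pd.DataFrame(
    {
        "features_1": range(100),
        "features_2": range(100, 200),
        "label": [1 if x % 2 == 0 else 0 for x in range(100)],
    }
)

# Shuffle once with a fixed seed so the split is reproducible
shuffled = df.sample(frac=1, random_state=42)

# Row positions where the validation and test sets begin
train_end = int(0.6 * len(shuffled))
val_end = int(0.8 * len(shuffled))

# Positional slices of the shuffled rows: 60% / 20% / 20%
train = shuffled.iloc[:train_end]
validate = shuffled.iloc[train_end:val_end]
test = shuffled.iloc[val_end:]

print(len(train), len(validate), len(test))  # 60 20 20
```

Because the three slices come from one shuffled copy, no row can land in more than one set. Note that this split is purely random and does not balance the label column across the sets.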

Using sklearn’s train_test_split

A more conventional method is to use train_test_split from sklearn’s model_selection module. This method provides an easy way to split data into two sets. To create three splits (train, validation, and test), we’ll need to call it twice:

from sklearn.model_selection import train_test_split

# Split into train and temp (temp will be split into validation and test)
X_train, X_temp, y_train, y_temp = train_test_split(
    df.drop("label", axis=1), df["label"], test_size=0.4, random_state=42
)

# Split temp into validation and test
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

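This approach is straightforward: the first call holds out 40% of the rows, and the second call divides that held-out portion in half, giving a 60/20/20 split. The random_state argument makes the shuffle reproducible. train_test_split also accepts a stratify argument for preserving class proportions, which the code above does not use.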
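To also stratify the two-step split, the label column can be passed through the stratify parameter of train_test_split. The following is a sketch under the same toy data as above; the variable names X and y and the final loop that prints each split's size and share of positive labels are illustrative additions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Same toy DataFrame as in the setup section
df = pd.DataFrame(
    {
        "features_1": range(100),
        "features_2": range(100, 200),
        "label": [1 if x % 2 == 0 else 0 for x in range(100)],
    }
)

X = df.drop("label", axis=1)
y = df["label"]

# First split: 60% train, 40% temp, keeping the label proportions of y
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, stratify=y, random_state=42
)

# Second split: divide temp in half, stratifying on the temp labels
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42
)

# Report the size and the fraction of positive labels in each split
for name, labels in [("train", y_train), ("val", y_val), ("test", y_test)]:
    print(name, len(labels), labels.mean())
```

The second call stratifies on y_temp rather than y, because only the rows left in the temporary set are being divided at that point.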

Advanced Approach: Stratified Splits with sklearn

For datasets with imbalanced classes or when we want each set to contain approximately the same percentage of samples of each class as the complete set, stratified splits come in handy. We use StratifiedShuffleSplit to achieve this:

from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

for train_index, test_index in sss.split(df.drop("label", axis=1), df["label"]):
    X_train, X_test = (
        df.drop("label", axis=1).iloc[train_index],
        df.drop("label", axis=1).iloc[test_index],
    )
    y_train, y_test = df["label"].iloc[train_index], df["label"].iloc[test_index]

# For Validation Set
validation_size = 0.25  # 25% of the remaining 80%, i.e. 20% of the full dataset
train_index, val_index = train_test_split(
    range(len(X_train)),
    test_size=validation_size,
    stratify=y_train,  # keep class proportions in the validation split too
    random_state=42,
)
X_train, X_val = X_train.iloc[train_index], X_train.iloc[val_index]
y_train, y_val = y_train.iloc[train_index], y_train.iloc[val_index]

This strategy ensures that your training, validation, and test sets are more representative of the overall distribution of your data.
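One way to check that claim is to compare the class shares in each split against the full dataset. The sketch below uses a hypothetical imbalanced DataFrame (df_imb, with 20% positive labels) instead of the balanced one from the setup section, and a hypothetical helper named stratified_pair. Both are assumptions made for illustration.

```python
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical imbalanced data: every fifth row is a positive (20%)
df_imb = pd.DataFrame(
    {
        "features_1": range(100),
        "label": [1 if x % 5 == 0 else 0 for x in range(100)],
    }
)


def stratified_pair(frame, test_size, seed=42):
    """Return (rest, held_out), preserving label proportions in both parts."""
    sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    rest_idx, held_idx = next(sss.split(frame, frame["label"]))
    return frame.iloc[rest_idx], frame.iloc[held_idx]


# Hold out 20% for testing, then 25% of the remainder for validation
train_val, test = stratified_pair(df_imb, test_size=0.2)
train, val = stratified_pair(train_val, test_size=0.25)

# One column per split, one row per class; columns should look alike
parts = [("full", df_imb), ("train", train), ("val", val), ("test", test)]
proportions = pd.DataFrame(
    {name: part["label"].value_counts(normalize=True) for name, part in parts}
)
print(proportions)
```

If the columns of proportions differ noticeably from the full column, the split was probably not stratified at one of the two steps.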

Conclusion

In this tutorial, we explored various methods to split a DataFrame into training, validation, and test sets using Pandas and sklearn. Whether you require simple random splits or stratified splits for imbalanced data, the approaches outlined above will help you prepare your data for modeling. Understanding and properly preparing your data is crucial in building effective machine learning models.
