How to Split Your Dataset into Training and Test Sets in PyTorch

Last updated: December 14, 2024

When working with machine learning models, it is crucial to split your dataset into training and test sets. By splitting the data, you can train your model on one dataset and then test its performance on a separate dataset, providing an unbiased evaluation. In this guide, we'll explore how to execute such a split using PyTorch, a popular open-source machine learning library in Python.

Why Split a Dataset?

Splitting a dataset lets you measure how well a trained model performs on unseen data. This evaluation helps you detect both overfitting and underfitting. Overfitting occurs when the model fits the training data so closely that it performs poorly on unseen data, while underfitting occurs when a model is too simple to capture the underlying patterns in the data.

Prerequisites

Before diving into dataset splitting, make sure you have Python and PyTorch installed on your computer. You can install PyTorch using pip:

pip install torch torchvision

Understanding Your Data

Before splitting, you first need to load your dataset. Data comes in many formats, such as CSV files, image folders, or custom objects. For this example, we'll work with a basic tensor dataset; PyTorch makes it easy to wrap tensors using torch.utils.data.TensorDataset and other utilities.
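
If your data lives in a CSV file, one way to get it into a TensorDataset is sketched below. This is a minimal sketch, assuming a hypothetical my_data.csv with a header row, numeric feature columns, and the label in the last column; adjust it to your file's layout:

import numpy as np
import torch
from torch.utils.data import TensorDataset

# Hypothetical file: one sample per row, label in the last column
raw = np.loadtxt("my_data.csv", delimiter=",", skiprows=1)  # skip the header row
features = torch.tensor(raw[:, :-1], dtype=torch.float32)
labels = torch.tensor(raw[:, -1], dtype=torch.long)
dataset = TensorDataset(features, labels)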

Code Example: Splitting the Dataset in PyTorch

Let's start with a working code example demonstrating how to split a dataset:

import torch
from torch.utils.data import DataLoader, random_split, TensorDataset

# Example dataset
data = torch.arange(1000, dtype=torch.float32).view(-1, 10)  # 100 samples, each with 10 features
targets = torch.randint(0, 2, (100,))  # 100 binary labels

dataset = TensorDataset(data, targets)

# Define the split ratio
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Split the dataset
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create a DataLoader for each split
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)  # evaluation doesn't need shuffling

Step-by-Step Explanation

1. Load the Dataset

In our example, we created synthetic data using torch.arange() and torch.randint(). Generally, you'll load your dataset from a source and convert it into tensors suitable for PyTorch.

2. Specify the Split Ratio

You define how much of the dataset will be used for training and how much for testing. A typical approach is to use 80% of the data for training and the remaining 20% for testing, as implemented here: train_size = int(0.8 * len(dataset)).
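
In practice, many projects also carve out a validation set for hyperparameter tuning. Here is a minimal sketch of a three-way split built on the same random_split() utility (the 80/10/10 ratios are just an illustration):

# Hypothetical 80/10/10 split; the remainder goes to the test set to avoid rounding gaps
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)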

3. Perform the Split

Use torch.utils.data.random_split() to split the dataset at random. Random assignment helps keep the training and test sets roughly representative of the full dataset, although it offers no hard guarantee for small or imbalanced datasets (see the section below). If you need the same split across runs, you can seed it, as shown in the sketch that follows.
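
For reproducible experiments, random_split() accepts an optional generator argument. A minimal sketch with a fixed seed:

# Seed the split so it is identical across runs
generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=generator
)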

4. Create DataLoaders

A DataLoader is a PyTorch utility that serves data in mini-batches, an essential ingredient for efficient model training. Batching keeps memory usage manageable and helps keep the GPU supplied with work during training.
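
To make the batching concrete, here is a sketch of iterating over train_loader; the shapes assume the example dataset above:

for batch_features, batch_labels in train_loader:
    # batch_features has shape (8, 10) and batch_labels has shape (8,),
    # except possibly the last batch, which may be smaller
    pass  # replace with your forward pass, loss computation, etc.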

Handling Imbalanced Datasets

In real-world cases, datasets are often imbalanced: some classes appear far more frequently than others. When splitting, it's important that both sets keep a similar class distribution. PyTorch's random_split() does not support stratified splitting directly, but you can implement it with a custom script or use scikit-learn's train_test_split() with its stratify option, as sketched below.
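
For illustration, here is one possible sketch that stratifies on indices with scikit-learn's train_test_split() and wraps the results in PyTorch's Subset. It assumes the dataset and targets tensors from the earlier example and requires scikit-learn to be installed:

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Split index lists rather than the tensors themselves
indices = list(range(len(dataset)))
train_idx, test_idx = train_test_split(
    indices,
    test_size=0.2,
    stratify=targets.numpy(),  # preserve class proportions in both sets
    random_state=42,
)

train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)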

Conclusion

Splitting your dataset into training and test sets is a fundamental step in developing robust machine learning models. PyTorch provides simple utilities that take you from data preparation to training and evaluation. Adopting the practices covered here will help you build models that generalize well to new, unseen data.
