
A Step-by-Step Guide to Data Splitting in PyTorch

Last updated: December 14, 2024

In machine learning, data splitting is a crucial step that ensures your model can generalize to unseen data. PyTorch provides several convenient ways to split a dataset into training, validation, and testing subsets. This article walks you through splitting data in PyTorch step by step, and then shows how to use the splits to train a model.

Understanding the Need for Data Splitting

Data splitting is pivotal because it lets you evaluate the performance of your machine learning models honestly. Typically, datasets are divided into the following subsets, illustrated with a toy example after the list:

  • Training Set: Used to fit the model's weights.
  • Validation Set: Used to tune hyperparameters and monitor for overfitting during training.
  • Test Set: Held out until the end to estimate performance on truly unseen data.
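
To make the idea concrete, here is a minimal sketch of a three-way split on a toy dataset. The toy tensors and the 70/15/15 ratio are arbitrary illustrative choices, and passing fractions to random_split requires a recent PyTorch version (1.13 or later):

import torch
from torch.utils.data import TensorDataset, random_split

# A toy dataset of 100 random samples, purely for illustration
toy_dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

# Recent PyTorch versions let random_split take fractions that sum to 1
train_set, val_set, test_set = random_split(toy_dataset, [0.7, 0.15, 0.15])
print(len(train_set), len(val_set), len(test_set))  # 70 15 15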

Basic Dataset Setup

First, let's set up a sample dataset. We'll use PyTorch's torchvision package to load MNIST for this task.

import torch
from torchvision import datasets, transforms

# Define a transform to convert images to tensors and normalize them
# (0.1307 and 0.3081 are the conventional MNIST mean and std)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the dataset
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

With this setup, you now have the MNIST dataset ready for splitting.
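
Before splitting, it is worth a quick sanity check that the data loaded as expected; the snippet below is purely an illustrative inspection:

# Inspect the dataset: MNIST's training split contains 60,000 images
print(len(dataset))        # 60000
image, label = dataset[0]
print(image.shape, label)  # torch.Size([1, 28, 28]) and an integer class label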

Splitting the Dataset

Next, we'll split the dataset into training, validation, and test sets. PyTorch offers an easy way to do this using torch.utils.data.random_split.

from torch.utils.data import random_split

# Define the size of each dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Split the dataset
dataset_train, dataset_val = random_split(dataset, [train_size, val_size])

Here, the dataset is split into 80% training data and 20% validation data. The test set is usually obtained separately: for MNIST you can download the dedicated test split, or alternatively carve a test set out of dataset_val with a further random split, as sketched below.
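
Both options are sketched below. The seed value 42 and the names dataset_test and dataset_val_final are illustrative choices; passing a seeded generator to random_split simply makes the split reproducible across runs:

# Option 1: use MNIST's dedicated 10,000-image test split
dataset_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Option 2: carve a test set out of the validation split;
# a seeded generator makes the split reproducible
generator = torch.Generator().manual_seed(42)
half = len(dataset_val) // 2
dataset_val_final, dataset_test = random_split(
    dataset_val, [half, len(dataset_val) - half], generator=generator
)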

Creating DataLoaders for the Splits

Once you have your splits, the next step is to create data loaders using torch.utils.data.DataLoader. This will help in iterating over the dataset in batches.

from torch.utils.data import DataLoader

# Define the data loaders
dataloader_train = DataLoader(dataset_train, batch_size=64, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=64, shuffle=False)

By setting shuffle=True for the training data, we ensure the model doesn't learn anything from the order of the examples; keeping shuffle=False for validation makes evaluation deterministic.
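
As a quick, purely illustrative check, you can pull a single batch from the training loader and confirm the tensor shapes match what the model expects:

# Fetch one batch and verify its shape
images, labels = next(iter(dataloader_train))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])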

Integrating the Splits into Model Training

With data loaders set up, you can now integrate them into your model's training and validation loops. Below is a brief example of a training loop utilizing these splits:

# A simple neural network
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc(x)
        return x

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for images, labels in dataloader_train:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader_train)}')

This training loop iterates over the training data, calculates the loss, and updates the model weights. Similarly, you would run a validation loop over dataloader_val at each epoch to assess performance and make sure the model isn't overfitting; a minimal sketch follows.
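
Here is one way such a validation pass could look. It is a minimal sketch that reuses the criterion defined above, switches the model into evaluation mode, and disables gradient tracking:

# Validation pass, run after each training epoch
model.eval()  # use eval behavior for layers like dropout/batchnorm
val_loss = 0.0
correct = 0
with torch.no_grad():  # no gradients needed during evaluation
    for images, labels in dataloader_val:
        outputs = model(images)
        val_loss += criterion(outputs, labels).item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
print(f'Val Loss: {val_loss/len(dataloader_val):.4f}, '
      f'Val Accuracy: {correct/len(dataset_val):.4f}')
model.train()  # switch back before the next training epoch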

Conclusion

Data splitting is a fundamental part of building any robust machine learning model. By splitting your data into training, validation, and test sets in PyTorch, you ensure that your model is evaluated correctly. This guide showed you how to load a dataset, perform the splits, create data loaders, and integrate them into a model training process.
