Creating Custom Training Loops in PyTorch

Last updated: December 14, 2024

PyTorch is a popular open-source machine-learning library used for a variety of deep-learning applications. While PyTorch provides high-level APIs such as torch.nn.Module and torch.optim, which simplify training models, there are instances where you might want more control over the training process. This is when creating custom training loops becomes essential.

In this article, we'll walk through setting up a custom training loop in PyTorch. This approach offers fine-grained control that higher-level abstractions may not, such as implementing custom training schedules, managing resources, and handling model-specific requirements.

Prerequisites

Before diving into custom training loops, ensure you have the following:

  • Python and PyTorch installed. You can get PyTorch from its official website.
  • A basic understanding of neural networks, PyTorch tensors, and datasets.
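
Once both are installed, a quick way to confirm the setup is to print the PyTorch version and check whether a CUDA-capable GPU is visible:

import torch

print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # True if a CUDA-capable GPU can be used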

Setting up the Dataset

To demonstrate custom training loops, we'll use PyTorch's dataset handling utilities. Here is sample code to set up a simple dataset using torchvision.datasets:

import torch
from torchvision import datasets, transforms

# Transforming dataset - converting to tensor and normalizing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='data', train=True, 
                               transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=64, shuffle=True)
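
Before moving on, it can help to pull a single batch from the loader and confirm its shape. This check is optional:

# Fetch one batch to verify the shapes produced by the DataLoader
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])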

Creating a Model

Let’s create a simple fully connected neural network using torch.nn.Module:

import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
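
Before wiring the model into a training loop, you can sanity-check the forward pass with a random, batch-shaped tensor. This is optional, but it catches shape mistakes early:

# Optional sanity check: run a dummy batch through an untrained instance
net = SimpleNet()
dummy_input = torch.randn(4, 1, 28, 28)  # four fake 28x28 grayscale images
print(net(dummy_input).shape)  # torch.Size([4, 10]) -- log-probabilities over 10 classes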

With the dataset and model in place, let's move on to creating the custom training loop.

Writing a Custom Training Loop

The custom training loop iterates over the dataset, computes the loss, and explicitly performs backpropagation and the parameter update:

# Initialize model, loss, and optimizer
model = SimpleNet()
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training loop
epochs = 5
for epoch in range(epochs):
    for images, labels in train_loader:
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

Here's the breakdown of the above code:

  • We initialize the model, loss function, and optimizer. SGD is used for simplicity, though more sophisticated optimizers are available in PyTorch.
  • The main loop iterates over each epoch, and within each epoch, we iterate over all batches of the data.
  • For each batch, the gradients are reset using optimizer.zero_grad().
  • We perform the forward pass to obtain predictions from our model and compute the loss.
  • The backward() method calculates the gradients, which are then used to update the weights with optimizer.step().
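
The loop above only reports the training loss of the last batch in each epoch. In practice you will usually also evaluate on held-out data. The sketch below assumes a test_loader built the same way as train_loader but with train=False (it is not defined in the snippets above):

# Evaluation sketch -- test_loader is assumed to wrap datasets.MNIST(..., train=False)
model.eval()                  # switch layers such as dropout/batchnorm to eval mode
correct, total = 0, 0
with torch.no_grad():         # gradients are not needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        predictions = outputs.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {correct / total:.4f}")
model.train()                 # switch back before any further training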

Extending the Loop

This custom loop is flexible, allowing you to integrate callbacks, custom logging, or advanced features such as gradient clipping (a sketch of which follows the early-stopping example below). Consider this snippet for early stopping:

best_loss = float('inf')
patience = 2
trials = 0

for epoch in range(epochs):
    epoch_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Early stopping logic
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        trials = 0
    else:
        trials += 1
        if trials >= patience:
            print("Early stopping due to no improvement!")
            break

In this example, simple early stopping halts training if the accumulated epoch loss does not improve for a specified number of consecutive epochs (the patience).
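
Gradient clipping, mentioned above, slots into the same loop: it is applied after loss.backward() and before optimizer.step(). Here is a minimal sketch using PyTorch's built-in utility (the max_norm of 1.0 is an arbitrary illustration; a suitable threshold depends on the model and task):

# Gradient clipping sketch -- cap the global gradient norm before each optimizer step
for epoch in range(epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        # Rescale gradients so their combined norm does not exceed 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()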

Conclusion

Creating custom training loops in PyTorch offers researchers and developers the flexibility needed in more complex scenarios where predefined utilities may not fulfill specific requirements. With PyTorch's dynamic computation graph, experimenting with, testing, and deploying deep learning models becomes a seamless experience.

Next Article: PyTorch Workflow for Complex Projects

Previous Article: Advanced PyTorch Techniques for Model Training

Series: The First Steps with PyTorch
