A Step-by-Step Tutorial on Fine-Tuning Classification Models in PyTorch

Last updated: December 14, 2024

Fine-tuning a pre-trained classification model in PyTorch is a practical way to apply transfer learning: instead of training from scratch, you start from weights learned on a large dataset and adapt them to your own data. Because many pre-trained models and datasets are publicly available, this can significantly cut the time needed to develop a model. In this article, you’ll learn how to fine-tune a classification model in PyTorch using a simple step-by-step approach.

Prerequisites

  • Basic understanding of Python and PyTorch.
  • PyTorch and necessary libraries installed (numpy, torchvision, etc.).
  • A dataset with labeled images for classification.

Step 1: Load a Pre-trained Model

To start, we will choose a pre-trained model from torchvision, PyTorch’s computer vision library. It offers a variety of architectures such as ResNet, VGG, and AlexNet, each with weights trained on ImageNet. For this tutorial, let’s use ResNet-18.

import torch
from torchvision import models

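# Load a ResNet-18 model with weights pre-trained on ImageNet
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

Passing weights=models.ResNet18_Weights.DEFAULT loads the model with weights trained on ImageNet. Older tutorials pass pretrained=True instead; torchvision has deprecated that argument since version 0.13 in favor of weights.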

Step 2: Modify the Model's Classifier

The final layer of the pre-trained model produces a fixed number of outputs (1000 scores, one per ImageNet class), so we need to replace it with a layer sized for the number of classes in our dataset.

import torch.nn as nn

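# Placeholder value: set this to the number of categories in your dataset
num_classes = 10

# Replace the fully connected layer so it outputs one score per class
model.fc = nn.Linear(model.fc.in_features, num_classes)

Here, num_classes must equal the number of categories in your dataset; the value 10 above is only a placeholder.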

Step 3: Prepare the Dataset

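Next, we need to load our dataset and apply the necessary transformations, such as resizing and normalization. We’ll use torchvision’s transforms for preprocessing, ImageFolder to read the images from disk, and PyTorch’s DataLoader to batch and shuffle them.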

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transforms for the training data
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset
dataset = datasets.ImageFolder('/path/to/dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

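It’s crucial to match the input size and normalization statistics that the pre-trained model was trained with. The mean and std values above are the ImageNet channel statistics used by torchvision’s pre-trained models. Note that ImageFolder expects one subdirectory per class and derives the labels from the directory names.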

Step 4: Define the Loss Function and Optimizer

Now that we have our model and dataset ready, we need to define a loss function and optimizer. The choice of optimizer can vary, but torch.optim.SGD is a good start.

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Feel free to experiment with other optimizers like Adam to see which works best in your case.
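
As an illustration of swapping the optimizer, this sketch builds Adam for a small stand-in model and runs one update on random data. The stand-in nn.Linear layer and the learning rate of 1e-4 are assumptions made for the example, not recommendations; with the real network you would pass model.parameters() in the same way.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the fine-tuned network
stand_in_model = nn.Linear(512, 5)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(stand_in_model.parameters(), lr=1e-4)

# One optimization step on random data: the interface is the same as with SGD
inputs = torch.randn(8, 512)
labels = torch.randint(0, 5, (8,))
optimizer.zero_grad()
loss = criterion(stand_in_model(inputs), labels)
loss.backward()
optimizer.step()

print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])
```

Because only the optimizer construction changes, the rest of the training code can stay as it is.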

Step 5: Train the Model

The next step is to train the model. For fine-tuning, a lower learning rate is typically used to retain the features learned from the previous tasks.

num_epochs = 5  # Example value; adjust for your dataset

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for inputs, labels in dataloader:
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the average loss per batch for this epoch
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(dataloader):.4f}')

This loop iterates over the dataset for a number of epochs. For each batch, it clears the old gradients, computes the loss, backpropagates it, and updates the model weights. At the end of each epoch, it prints the average loss per batch.
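
The loop runs on whichever device the model and tensors live on, which is the CPU by default. A minimal sketch of the same loop with explicit device placement follows. The synthetic data and the small nn.Linear stand-in model are assumptions that keep the example self-contained.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic stand-ins (assumptions): 64 feature vectors, 3 classes
features = torch.randn(64, 16)
targets = torch.randint(0, 3, (64,))
dataloader = DataLoader(TensorDataset(features, targets), batch_size=16, shuffle=True)

model = nn.Linear(16, 3).to(device)  # move the parameters to the chosen device
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 3
epoch_losses = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        # Each batch must be on the same device as the model
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_losses.append(running_loss / len(dataloader))
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_losses[-1]:.4f}")
```

With the ResNet-18 model from this tutorial, the same two additions apply: call model.to(device) once before training and move each batch inside the loop.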

Step 6: Evaluate the Model

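After training, it's important to evaluate the model on a separate validation dataset that was not used for training. Accuracy on held-out data shows how well the model generalizes and reveals overfitting that the training loss alone would hide.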

model.eval()  # Set the model to evaluation mode
correct = 0
total = 0

with torch.no_grad():  # Gradients are not needed for evaluation
    for inputs, labels in val_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)  # Index of the highest score per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')

Here, val_loader is a DataLoader over validation images and labels that the model did not see during training. It should apply the same resizing and normalization as the training data.
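
One common way to obtain such a loader, sketched here under the assumption that you hold out 20% of a single dataset, is torch.utils.data.random_split. The TensorDataset below is a synthetic stand-in for the ImageFolder dataset from Step 3, and the split ratio and seed are illustrative choices.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for the ImageFolder dataset (assumption for illustration)
full_dataset = TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 5, (100,)))

val_size = len(full_dataset) // 5  # hold out 20% for validation
train_size = len(full_dataset) - val_size

# A seeded generator makes the split reproducible across runs
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(full_dataset, [train_size, val_size], generator=generator)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)  # order does not matter for evaluation

print(len(train_set), len(val_set))  # 80 20
```

If your data already comes in separate training and validation folders, you can instead create a second ImageFolder for the validation directory and wrap it in its own DataLoader.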

Conclusion

Fine-tuning a classification model in PyTorch is a straightforward process that builds on existing pre-trained models. By following the steps outlined here (loading the model, replacing its classifier, preparing the data, training, and evaluating), you can adapt a pre-trained model to your own dataset efficiently. With practice, this technique can reduce development time and often yields better results on image classification tasks than training from scratch on a small dataset.
