Sling Academy

Building a Semantic Segmentation Model with PyTorch and U-Net

Last updated: December 14, 2024

Semantic segmentation is a core task in computer vision: every pixel in an image is assigned a class label. In this article, we will build a semantic segmentation model using PyTorch and the U-Net architecture. U-Net was originally designed for biomedical image segmentation and remains a popular choice because it produces precise pixel-level masks even when only a modest amount of labeled training data is available.
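To make "classifying each pixel" concrete: a segmentation network outputs one score per class at every pixel, and the predicted label map is the per-pixel argmax over the class dimension. The tensor below is random stand-in data with an assumed shape of 2 images, 3 classes and 4x4 pixels, not real model output.

```python
import torch

# Hypothetical model output: one raw score (logit) per class for every pixel,
# shape [N, C, H, W] = [2 images, 3 classes, 4, 4].
torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)

# Each pixel gets the class with the highest score -> integer class map [N, H, W].
class_map = logits.argmax(dim=1)

print(class_map.shape)  # torch.Size([2, 4, 4])
print(class_map[0])     # a class index (0, 1 or 2) for each pixel of the first image
```

For a single foreground/background output channel, as used later in this article, the equivalent step is thresholding the sigmoid of the logits instead of taking an argmax.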

Understanding U-Net Architecture

U-Net is a fully convolutional network with a symmetric encoder-decoder structure, which gives it its characteristic U shape. It consists of three main parts: the encoder, the bottleneck, and the decoder. The encoder captures context through a series of convolutional and pooling layers that progressively reduce spatial resolution. The decoder then upsamples back toward the input resolution using up-convolutions (transposed convolutions) and concatenates the result with high-resolution features copied from the encoder path. These skip connections let the decoder recover fine spatial detail that was lost during downsampling.
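The shape changes described above can be traced with a few standalone layers. This is a minimal sketch with randomly initialized layers and an assumed 64-channel, 64x64 feature map; the single 3x3 convolution stands in for the deeper encoder and bottleneck layers.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 64, 64)  # encoder features at full resolution [N, C, H, W]

# Encoder step: 2x2 max pooling halves H and W, so deeper layers see more context.
pooled = nn.MaxPool2d(kernel_size=2)(x)                           # [1, 64, 32, 32]

# Stand-in for the deeper layers: a padded 3x3 conv that widens to 128 channels.
deep = nn.Conv2d(64, 128, kernel_size=3, padding=1)(pooled)       # [1, 128, 32, 32]

# Decoder step: a 2x2 transposed convolution with stride 2 doubles H and W.
up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)(deep)   # [1, 64, 64, 64]

# Skip connection: concatenate with the high-resolution encoder features along channels.
merged = torch.cat([up, x], dim=1)                                 # [1, 128, 64, 64]

for name, t in [("pooled", pooled), ("deep", deep), ("up", up), ("merged", merged)]:
    print(name, tuple(t.shape))
```

The concatenation only works because pooling and the stride-2 up-convolution are exact inverses in spatial size, which is why this style of U-Net expects input heights and widths divisible by 2 at every level.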

Prerequisites

Before we start building our model, make sure Python is installed, then install PyTorch and the other libraries used in this article by running:

pip install torch torchvision numpy matplotlib

Data Preparation

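First, we need to load and preprocess our dataset. Any public segmentation dataset will do, for example one downloaded from Kaggle or another open data source. The code below assumes the data is split into two directories: one containing the RGB input images and one containing single-channel label masks, where each mask has exactly the same filename as its image.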
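Because the dataset class pairs each image with the mask of the same filename, a quick check for unpaired files can save a confusing `FileNotFoundError` mid-training. This is a minimal sketch; `find_unpaired` is a hypothetical helper, and the directory paths in the usage line are placeholders.

```python
import os


def find_unpaired(image_dir, label_dir):
    """Return image filenames that have no mask with the same name in label_dir.

    Assumes each mask shares its image's filename, as the dataset class does.
    """
    images = set(os.listdir(image_dir))
    labels = set(os.listdir(label_dir))
    return sorted(images - labels)


# Usage (placeholder paths):
# missing = find_unpaired("data/images", "data/masks")
# print(f"{len(missing)} images without a mask:", missing[:10])
```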

Let's define a basic PyTorch dataset class:


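import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        # Optional joint transform: a callable taking (image, label) PIL images and
        # returning both, so random crops/flips are applied identically to image and mask.
        self.transform = transform
        # Sorted for a deterministic order; each mask must share its image's filename.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        label_path = os.path.join(self.label_dir, self.images[idx])
        image = Image.open(img_path).convert("RGB")
        label = Image.open(label_path).convert("L")

        if self.transform:
            image, label = self.transform(image, label)

        # Image -> float tensor [3, H, W] scaled to [0, 1].
        image = TF.to_tensor(image)
        # Mask -> float tensor [1, H, W]: 1.0 where the pixel is nonzero (foreground),
        # 0.0 elsewhere. This assumes binary masks, matching BCEWithLogitsLoss below.
        label = (torch.from_numpy(np.array(label)) > 0).float().unsqueeze(0)
        return image, label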

Building the U-Net Model

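Next, let's implement the U-Net model. To keep the code short, we build a minimal version with a single downsampling level: one encoder block, a bottleneck, and one decoder block. The full U-Net repeats the same pattern over four downsampling levels.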


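import torch
import torch.nn as nn

class UNet(nn.Module):
    """Minimal one-level U-Net. Input height and width must be even."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = self.conv_layer(in_channels, 64)       # full-resolution features
        self.pool = nn.MaxPool2d(kernel_size=2)               # halves H and W
        self.bottleneck = self.conv_layer(64, 128)
        self.up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)  # doubles H and W
        self.decoder = self.conv_layer(128, 64)               # 64 upsampled + 64 skip channels
        # 1x1 conv to raw logits; no ReLU here, since BCEWithLogitsLoss needs negative scores too.
        self.head = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def conv_layer(in_channels, out_channels):
        # Two padded 3x3 convolutions, each followed by ReLU; spatial size is preserved.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc_out = self.encoder(x)
        bottleneck = self.bottleneck(self.pool(enc_out))
        up = self.up(bottleneck)
        # Skip connection: concatenate upsampled and encoder features along channels.
        dec_out = self.decoder(torch.cat([up, enc_out], dim=1))
        return self.head(dec_out)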
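The model above uses a single downsampling level. To go deeper, the encoder/decoder pattern can be repeated in a loop. The sketch below is a hypothetical generalization, not part of the original code: `ConfigurableUNet`, `double_conv`, `depth` and `base_channels` are illustrative names. With `depth=4` and `base_channels=64`, the channel widths (64 up to 1024) match the original U-Net paper, although this version pads its convolutions so the output keeps the input's height and width, whereas the original used unpadded convolutions and produced a smaller output.

```python
import torch
import torch.nn as nn


def double_conv(in_ch, out_ch):
    # Two padded 3x3 conv + ReLU layers; H and W are unchanged.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class ConfigurableUNet(nn.Module):
    """Hypothetical U-Net with `depth` downsampling levels.

    Width starts at `base_channels` and doubles per level.
    Input H and W must be divisible by 2 ** depth.
    """

    def __init__(self, in_channels, out_channels, depth=4, base_channels=64):
        super().__init__()
        if depth < 1:
            raise ValueError("depth must be at least 1")
        widths = [base_channels * 2 ** i for i in range(depth + 1)]
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Encoder: one double_conv per level (pooling happens in forward).
        self.downs = nn.ModuleList()
        in_ch = in_channels
        for w in widths[:-1]:
            self.downs.append(double_conv(in_ch, w))
            in_ch = w
        self.bottleneck = double_conv(widths[-2], widths[-1])

        # Decoder: upsample, concatenate the matching skip, then double_conv.
        self.ups = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        for w in reversed(widths[:-1]):
            self.ups.append(nn.ConvTranspose2d(w * 2, w, kernel_size=2, stride=2))
            self.up_convs.append(double_conv(w * 2, w))
        self.head = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)      # saved for the matching decoder level
            x = self.pool(x)
        x = self.bottleneck(x)
        for up, conv, skip in zip(self.ups, self.up_convs, reversed(skips)):
            x = up(x)
            x = conv(torch.cat([x, skip], dim=1))
        return self.head(x)      # raw logits


# Usage: a drop-in replacement for UNet in the training code, e.g.
# model = ConfigurableUNet(in_channels=3, out_channels=1, depth=4, base_channels=64)
```

Raising `depth` widens the receptive field at the cost of memory and a stricter input-size constraint; lowering `base_channels` is the simplest way to shrink the model for small datasets or limited GPUs.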

Training the Model

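With the dataset and model ready, the next step is training. We'll set up a data loader and a training loop, and define the loss function and optimizer. Because the model predicts a single foreground/background channel, we use BCEWithLogitsLoss, which applies a sigmoid to the raw logits and computes binary cross-entropy in one numerically stable step: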


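import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

# Hyperparameters
num_epochs = 25
learning_rate = 0.001
batch_size = 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "data/images" and "data/masks" are placeholder paths. Batching requires every sample
# to have the same (even) size; if yours differ, pass a joint transform that resizes
# image and mask together.
dataset = SegmentationDataset("data/images", "data/masks")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model, optimizer, and loss function
model = UNet(in_channels=3, out_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()  # expects raw logits and float 0/1 targets

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean loss over all batches rather than only the last one.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(dataloader):.4f}")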
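Once trained, the model still outputs raw logits, so predictions need a sigmoid and a threshold to become a binary mask. This is a minimal sketch: `predict_mask` is a hypothetical helper, and 0.5 is a common default threshold rather than a tuned value.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # no gradients needed at inference time
def predict_mask(model: nn.Module, images: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Turn single-channel logits into a binary mask (1.0 = foreground).

    `images` is a float batch [N, 3, H, W] prepared the same way as during training.
    """
    model.eval()                    # switch off training-only behavior such as dropout
    logits = model(images)          # [N, 1, H, W] raw scores
    probs = torch.sigmoid(logits)   # probabilities, as BCEWithLogitsLoss interprets them
    return (probs > threshold).float()


# Usage, after the training loop above:
# masks = predict_mask(model, images.to(device))
```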

Conclusion

Building a semantic segmentation model requires careful attention to the dataset, the model architecture, and the training procedure. PyTorch, combined with architectures like U-Net, provides the tools needed to develop effective semantic segmentation models that can be fine-tuned for a variety of applications. Techniques such as data augmentation and transfer learning can often improve performance further, especially when labeled data is limited.


Series: PyTorch Computer Vision
