Sling Academy

How to Train Your First Model in PyTorch: Step-by-Step

Last updated: December 14, 2024

PyTorch is a powerful open-source deep learning library for building and training machine learning models. Whether you're a seasoned data scientist or new to machine learning, PyTorch offers the flexibility to build everything from simple linear models to complex neural networks. This article walks step by step through training your first model in PyTorch: a small fully connected network that classifies handwritten digits from the MNIST dataset.

1. Install PyTorch

Before you can get started with model training, you need to ensure that PyTorch is installed on your system.

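Install PyTorch together with torchvision, which provides the MNIST dataset and the image transforms used later. If you need a build for a specific CUDA version, use the command generated on the official PyTorch "Get Started" page instead.

# In a terminal (Linux, macOS, or Windows)
pip install torch torchvision
# In a Jupyter notebook cell, prefix the command with "!"
!pip install torch torchvision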

Make sure to verify your installation by importing PyTorch and checking the version:

import torch
print(torch.__version__)

2. Import Necessary Libraries

To start using PyTorch, you need to import several essential libraries:

import torch
import torch.nn as nn                          # layers and loss functions
import torch.optim as optim                    # optimizers such as SGD
from torchvision import datasets, transforms   # MNIST and image preprocessing

The training code below refers to a device variable, so define it here. It selects an NVIDIA GPU through CUDA when one is available and falls back to the CPU otherwise:

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
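The device object only helps if the model and every batch of data end up on it, because PyTorch raises an error when a single operation mixes tensors on different devices. A minimal sketch, using a hypothetical small linear layer and random input in place of the real model and data:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module.to(device) moves the module's parameters and returns the module itself.
layer = nn.Linear(4, 2).to(device)

# Tensor.to(device) returns a tensor on the target device
# (the same tensor if it is already there).
x = torch.randn(3, 4).to(device)

y = layer(x)  # works because parameters and input share a device
print(device, y.device, y.shape)
```

This is the same pattern the training loop follows: move the model once with `.to(device)`, then move `images` and `labels` for every batch.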

3. Define the Dataset

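Here, we'll use MNIST, a standard computer vision benchmark of 28×28 grayscale images of handwritten digits (60,000 for training and 10,000 for testing). First, define a transformation that converts each image to a tensor and normalizes its pixel values. The constants 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training pixels, so the normalized inputs are roughly zero-centered with unit variance.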

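transform = transforms.Compose([
    transforms.ToTensor(),                       # image -> float tensor [1, 28, 28] scaled to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))   # (x - mean) / std
])

# download=True fetches MNIST into ./data on the first run
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Yields shuffled batches of 64 images and their labels
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)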

4. Build a Neural Network Model

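In PyTorch, you define a model by subclassing nn.Module: declare the layers in __init__ and describe how data flows through them in forward. The example below is a small fully connected network that flattens each 28×28 image into a 784-element vector, passes it through a hidden layer of 128 ReLU units, and outputs 10 scores, one per digit class: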

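class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)   # 784 input pixels -> 128 hidden units
        self.fc2 = nn.Linear(128, 10)      # 128 hidden units -> 10 class scores

    def forward(self, x):
        x = torch.flatten(x, 1)            # [N, 1, 28, 28] -> [N, 784]
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)                    # raw logits: no softmax, CrossEntropyLoss expects logits
        return x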
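A quick way to check a model definition is to push a dummy batch through it and inspect the output shape. The sketch below does that with an equivalent architecture written with nn.Sequential (an alternative style, not the article's class) and also counts the trainable parameters:

```python
import torch
import torch.nn as nn

# Same layer sizes as SimpleNN, expressed with nn.Sequential.
model = nn.Sequential(
    nn.Flatten(),              # [N, 1, 28, 28] -> [N, 784]
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),        # one raw score (logit) per digit class
)

dummy = torch.randn(64, 1, 28, 28)  # random stand-in for one MNIST batch
out = model(dummy)
print(out.shape)  # torch.Size([64, 10])

# Weights plus biases of both linear layers.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)   # 784*128 + 128 + 128*10 + 10 = 101770
```

The same check works on SimpleNN(), which has the same 101,770 parameters.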

5. Set the Loss Function and Optimizer

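Next, choose a loss function and an optimizer. CrossEntropyLoss is the standard loss for multi-class classification; it expects the model's raw, unnormalized scores (logits) and integer class labels. Plain stochastic gradient descent (SGD) with a learning rate of 0.01 is a simple starting point for the optimizer: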

model = SimpleNN().to(device)                        # move the parameters to the selected device
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)   # update all model parameters with plain SGD
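Two details are easy to miss here: nn.CrossEntropyLoss internally combines log-softmax with negative log-likelihood, and each optimizer.step() of plain SGD moves every parameter by -lr times its gradient. The sketch below checks both on small hand-picked tensors (the values are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])   # raw scores: 2 samples, 3 classes
targets = torch.tensor([0, 2])              # class indices, not one-hot vectors

criterion = nn.CrossEntropyLoss()           # averages over the batch by default
loss = criterion(logits, targets)

# Same value by hand: mean negative log-probability of the true class.
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(loss, manual))  # True

# One plain SGD step: p <- p - lr * p.grad
w = torch.tensor([1.0, -2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.01)
(w ** 2).sum().backward()   # gradient is 2 * w = [2.0, -4.0]
optimizer.step()
print(w)  # approximately [0.98, -1.96]
```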

6. Train the Model

Now that all the components are in place, you can train the model. For each batch, run a forward pass, compute the loss, backpropagate to obtain the gradients, and update the model's parameters. One full pass over the training set is called an epoch.

n_epochs = 5
for epoch in range(n_epochs):
    model.train()                                  # training mode (matters for layers like dropout)
    total_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()                      # clear gradients from the previous step
        outputs = model(images)                    # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                            # backpropagate
        optimizer.step()                           # update parameters
        total_loss += loss.item()
    # Average batch loss over the epoch
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
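The call to optimizer.zero_grad() matters because PyTorch accumulates gradients: each backward() adds to the .grad already stored on a parameter. This toy sketch, with a single hypothetical scalar parameter, shows the difference:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)    # hypothetical scalar parameter
optimizer = torch.optim.SGD([w], lr=0.1)

def grad_after_two_backward_passes(zero_each_time: bool) -> float:
    optimizer.zero_grad()                    # start from a clean slate
    for _ in range(2):
        if zero_each_time:
            optimizer.zero_grad()
        loss = 3.0 * w                       # d(loss)/dw = 3
        loss.backward()
    return w.grad.item()

accumulated = grad_after_two_backward_passes(zero_each_time=False)
fresh = grad_after_two_backward_passes(zero_each_time=True)
print(accumulated)  # 6.0: the two gradients were summed
print(fresh)        # 3.0: only the latest gradient is kept
```

Without zero_grad() in the training loop, each update would use the sum of gradients from all earlier batches rather than the gradient of the current batch.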

7. Test the Model (Optional)

Once the model has been trained, evaluate it on data it did not see during training to measure how well it generalizes. For MNIST, passing train=False to datasets.MNIST loads the 10,000-image test split; iterate over it with the model in evaluation mode and count the correct predictions to compute accuracy.
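An evaluation loop looks much like the training loop, but with the model in eval mode, gradient tracking disabled, and no optimizer step. The function below is a minimal sketch (the name evaluate is hypothetical); the check at the bottom uses nn.Identity on hand-made logits so it runs without downloading data:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the fraction of samples in `loader` that `model` classifies correctly."""
    model.eval()                      # inference behavior for dropout/batch norm layers
    correct, total = 0, 0
    with torch.no_grad():             # no gradients needed for evaluation
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)   # highest-scoring class per sample
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Toy check: treat the inputs themselves as logits for 2 classes.
logits = torch.tensor([[2.0, 1.0], [0.0, 3.0], [5.0, 0.0], [1.0, 4.0]])
labels = torch.tensor([0, 1, 1, 1])   # the third sample will be misclassified
loader = DataLoader(TensorDataset(logits, labels), batch_size=2)
accuracy = evaluate(nn.Identity(), loader, torch.device("cpu"))
print(accuracy)  # 0.75
```

With the article's objects, you would build `test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)`, wrap it in a DataLoader (shuffling is unnecessary for evaluation), and call `evaluate(model, test_loader, device)`. The training loop already calls model.train() at the start of each epoch, so further training after evaluation switches the model back automatically.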

Training your first model using PyTorch might seem overwhelming at first, but by following clearly defined steps and experimenting, you'll soon be able to leverage the powerful tools PyTorch offers to solve complex problems. Happy coding!

Next Article: Setting Up Optimizers and Loss Functions in PyTorch

Previous Article: The Essentials of Training a PyTorch Model

Series: The First Steps with PyTorch

PyTorch

You May Also Like

  • Addressing "UserWarning: floor_divide is deprecated, and will be removed in a future version" in PyTorch Tensor Arithmetic
  • In-Depth: Convolutional Neural Networks (CNNs) for PyTorch Image Classification
  • Implementing Ensemble Classification Methods with PyTorch
  • Using Quantization-Aware Training in PyTorch to Achieve Efficient Deployment
  • Accelerating Cloud Deployments by Exporting PyTorch Models to ONNX
  • Automated Model Compression in PyTorch with Distiller Framework
  • Transforming PyTorch Models into Edge-Optimized Formats using TVM
  • Deploying PyTorch Models to AWS Lambda for Serverless Inference
  • Scaling Up Production Systems with PyTorch Distributed Model Serving
  • Applying Structured Pruning Techniques in PyTorch to Shrink Overparameterized Models
  • Integrating PyTorch with TensorRT for High-Performance Model Serving
  • Leveraging Neural Architecture Search and PyTorch for Compact Model Design
  • Building End-to-End Model Deployment Pipelines with PyTorch and Docker
  • Implementing Mixed Precision Training in PyTorch to Reduce Memory Footprint
  • Converting PyTorch Models to TorchScript for Production Environments
  • Deploying PyTorch Models to iOS and Android for Real-Time Applications
  • Combining Pruning and Quantization in PyTorch for Extreme Model Compression
  • Using PyTorch’s Dynamic Quantization to Speed Up Transformer Inference
  • Applying Post-Training Quantization in PyTorch for Edge Device Efficiency