Saving and Loading Models with `torch.save()` and `torch.load()` in PyTorch

Last updated: December 14, 2024

PyTorch is a powerful library for deep learning that is widely used for building and training neural networks. One of the benefits of PyTorch is that it offers simple and efficient functions for saving and loading models. This is particularly useful for tasks such as resuming training, sharing models, or deploying them for inference.

Saving a PyTorch Model

The function torch.save() is used to serialize and save a model to disk. This process is straightforward but having a good understanding of torch.save()'s features will help you manage your saved models effectively.

The usual pattern has two parts: pass the model's state dictionary (model.state_dict()) to torch.save(), and give it a file path, conventionally ending in .pt or .pth.


import torch
import torch.nn as nn
import torch.optim as optim

# Example model
class SampleModel(nn.Module):
    def __init__(self):
        super(SampleModel, self).__init__()
        self.layer = nn.Linear(10, 2)

    def forward(self, x):
        return self.layer(x)

model = SampleModel()

# Define an optimizer (plain SGD, chosen only for illustration)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Specify the path for saving the model
PATH = "./model.pth"

# Save the model's state_dict (recommended approach)
torch.save(model.state_dict(), PATH)

Here, state_dict is a Python dictionary object that stores all the parameters of the model and the persistent buffers (e.g., running averages of any batch-norm layers).
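To make the contents of a state_dict concrete, the following minimal sketch lists each entry and its tensor shape. TinyNet is a hypothetical model introduced only for this illustration; it adds a batch-norm layer so that persistent buffers show up next to the learnable parameters.

```python
import torch
import torch.nn as nn

# Hypothetical model, used only for this illustration
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 2)
        self.norm = nn.BatchNorm1d(2)

    def forward(self, x):
        return self.norm(self.layer(x))

tiny = TinyNet()

# Map every state_dict entry to the shape of the tensor it holds
shapes = {name: tuple(t.shape) for name, t in tiny.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)

# Buffers such as running_mean are saved too, although they are not parameters
param_names = {name for name, _ in tiny.named_parameters()}
buffer_only = [name for name in shapes if name not in param_names]
print("buffers:", buffer_only)
```

Entries such as norm.running_mean appear in the state_dict but not in named_parameters(), which is why saving the state_dict preserves batch-norm statistics as well as weights.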

Loading a PyTorch Model

The torch.load() function is used to load a saved PyTorch model. A key point to remember is that loading a state_dict into a model requires that the model class is instantiated with the same architecture as when it was saved.


# Initialize the model again
model = SampleModel()

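# Load the saved state_dict and copy its weights into the model.
# weights_only=True (available in recent PyTorch releases) restricts
# unpickling to tensors and plain containers, which is all a state_dict needs.
model.load_state_dict(torch.load(PATH, weights_only=True))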

# Remember to set the model to evaluation mode after loading
model.eval()

In this snippet, once the weights are loaded, model.eval() switches the model to evaluation mode. Do this before running inference: layers such as dropout and batch normalization behave differently in evaluation than in training, so skipping the call can give inconsistent predictions. If you intend to continue training instead, call model.train().
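The difference between the two modes is easy to observe on a single dropout layer. The sketch below is an illustration only (the tensor size and the seed are arbitrary choices): in training mode roughly half the inputs are zeroed and the rest are scaled by 1 / (1 - p), while in evaluation mode the layer passes its input through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: elements are randomly zeroed, survivors are scaled by 2
drop.train()
train_out = drop(x)
zeros_in_train = int((train_out == 0).sum())

# Evaluation mode: dropout is a no-op
drop.eval()
eval_out = drop(x)

print("zeroed in train mode:", zeros_in_train > 0)
print("unchanged in eval mode:", torch.equal(eval_out, x))
```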

Complete Model Saving

If you would like to save not only the state_dict but the entire model object, you can pass the model itself to torch.save(). Be cautious with this approach: the model is serialized with Python's pickle module, which ties the saved file to the exact class definition and module layout in use at save time.


# This saves the entire model including the architecture
torch.save(model, "./full_model.pth")

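# To load the full model, the SampleModel class must still be importable.
# PyTorch 2.6+ defaults to weights_only=True, which rejects pickled modules,
# so pass weights_only=False here (only for files from a trusted source).
model = torch.load("./full_model.pth", weights_only=False)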
model.eval()

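This method is generally not recommended if you plan to refactor your code and still access the saved weights. The pickled file does not contain the model's source code, only a reference to its class, so loading fails if the class has been renamed, moved to another module, or is not importable. It offers less flexibility than the state_dict approach.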
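By contrast, a state_dict only depends on parameter names and shapes, not on the class name. The sketch below illustrates this under stated assumptions: OriginalNet and RefactoredNet are hypothetical classes, and the file is written to a temporary directory so the example cleans up after itself.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical "before refactoring" class
class OriginalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 2)

    def forward(self, x):
        return self.layer(x)

# Hypothetical "after refactoring" class: new name, same attribute names
class RefactoredNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 2)

    def forward(self, x):
        return self.layer(x)

original = OriginalNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")
    torch.save(original.state_dict(), path)

    # The keys ("layer.weight", "layer.bias") match, so loading succeeds
    refactored = RefactoredNet()
    result = refactored.load_state_dict(torch.load(path, weights_only=True))

x = torch.randn(4, 10)
same = torch.allclose(original(x), refactored(x))
print("missing keys:", result.missing_keys)
print("same outputs:", same)
```

Renaming the attribute itself (for example self.layer to self.fc) would change the keys and make load_state_dict() report missing and unexpected keys, so the flexibility applies to class names and file layout, not to parameter names.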

Conclusion

Saving and loading models in PyTorch is straightforward once you know when to use a state dictionary and when to save the entire model. Prefer the state_dict approach in most cases: it stores only parameter and buffer values and keeps the saved file independent of how your code is organized. Save the entire model only when you need the pickled object as-is and accept the trade-offs involved. Whether you're transitioning between training sessions or deploying for inference, torch.save() and torch.load() cover the essentials.
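For the case of transitioning between training sessions, the optimizer state usually needs to be saved alongside the model weights. The following is a minimal sketch, not a definitive recipe: the dictionary keys ("epoch", "model_state_dict", "optimizer_state_dict"), the epoch number, and the momentum setting are assumptions chosen for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(10, 2)
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving
loss = net(torch.randn(4, 10)).sum()
loss.backward()
opt.step()

# Bundle everything needed to resume (key names are a convention, not an API)
checkpoint = {
    "epoch": 3,
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": opt.state_dict(),
}

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "checkpoint.pth")
    torch.save(checkpoint, ckpt_path)
    loaded = torch.load(ckpt_path, weights_only=True)

# Rebuild the model and optimizer, then restore their states
resumed_net = nn.Linear(10, 2)
resumed_opt = optim.SGD(resumed_net.parameters(), lr=0.01, momentum=0.9)
resumed_net.load_state_dict(loaded["model_state_dict"])
resumed_opt.load_state_dict(loaded["optimizer_state_dict"])
start_epoch = loaded["epoch"] + 1

resumed_net.train()  # training mode, since the goal is to keep training
print("resume from epoch", start_epoch)
```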
