PyTorch is a powerful library for deep learning that is widely used for building and training neural networks. One of the benefits of PyTorch is that it offers simple and efficient functions for saving and loading models. This is particularly useful for tasks such as resuming training, sharing models, or deploying them for inference.
Saving a PyTorch Model
The torch.save() function serializes an object and writes it to disk. The call itself is simple, but understanding what torch.save() actually stores will help you manage your saved models effectively.
The recommended way to save a PyTorch model is to pass two things to torch.save(): the model's state dictionary and a file path, conventionally with a .pt or .pth extension.
import torch
import torch.nn as nn
import torch.optim as optim
# Example model
class SampleModel(nn.Module):
    def __init__(self):
        super(SampleModel, self).__init__()
        self.layer = nn.Linear(10, 2)

    def forward(self, x):
        return self.layer(x)
model = SampleModel()
# Define an optimizer (plain SGD, chosen only as an example)
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Specify the path for saving the model
PATH = "./model.pth"
# Save the model's state_dict (recommended approach)
torch.save(model.state_dict(), PATH)
Here, state_dict is a Python dictionary object that stores all the parameters of the model and its persistent buffers (e.g., the running averages of any batch-norm layers).
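To make the contents of a state_dict concrete, the sketch below prints every entry of a small model. TinyNet is a hypothetical class introduced only for this illustration; it adds a batch-norm layer to the linear layer so that both parameters and buffers appear.

```python
import torch
import torch.nn as nn


# Hypothetical model for illustration: a linear layer followed by batch norm,
# so the state_dict holds both learnable parameters and persistent buffers.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 2)
        self.norm = nn.BatchNorm1d(2)

    def forward(self, x):
        return self.norm(self.layer(x))


model = TinyNet()

# Each key has the form "<attribute name>.<tensor name>"; each value is a tensor.
for name, tensor in model.state_dict().items():
    print(f"{name:<28} {tuple(tensor.shape)}")

# Buffers (e.g. running_mean) are saved in the state_dict but are not
# trainable parameters, so they do not show up in named_parameters().
param_names = {name for name, _ in model.named_parameters()}
buffer_names = set(model.state_dict()) - param_names
print("buffers:", sorted(buffer_names))
```

Because the state_dict is an ordinary mapping from names to tensors, it can be inspected, filtered, or edited like any other dictionary before it is saved.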
Loading a PyTorch Model
The torch.load() function deserializes a saved object from disk. A key point to remember is that loading a state_dict requires a model instance with the same architecture as the one that was saved, so the model class must be defined and instantiated before the weights are loaded.
# Initialize the model again
model = SampleModel()
# Load the weights from the saved file into the model's state_dict
model.load_state_dict(torch.load(PATH))
# Remember to set the model to evaluation mode after loading
model.eval()
In this snippet, after the weights are loaded into the model, it is important to call model.eval() to put the model into evaluation mode. This matters when the model contains dropout or batch normalization layers, which behave differently during training and evaluation; skipping the call can give inconsistent inference results.
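The effect of the mode switch can be seen in a minimal sketch. The two-layer network below is a hypothetical stand-in, not the SampleModel defined earlier; it contains a dropout layer so that training and evaluation modes give visibly different behavior.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model with dropout, used only to show the mode switch.
net = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5))
x = torch.ones(4, 10)

# Training mode: dropout zeroes a fresh random subset of activations on
# every forward pass, so two passes over the same input generally differ.
net.train()
train_out_a = net(x)
train_out_b = net(x)

# Evaluation mode: dropout is a no-op, so the output is deterministic.
net.eval()
with torch.no_grad():
    eval_out_a = net(x)
    eval_out_b = net(x)

print("train mode repeatable:", torch.equal(train_out_a, train_out_b))
print("eval mode repeatable: ", torch.equal(eval_out_a, eval_out_b))
```

Note that model.eval() only changes layer behavior; wrapping inference in torch.no_grad(), as above, is a separate step that turns off gradient tracking.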
Complete Model Saving
If you would like to save not only the state_dict but the entire model object, including its architecture, you can pass the model itself to torch.save(). However, be cautious: the whole module is serialized with Python's pickle, which binds the saved file to the specific class definitions and module layout in use when the model was saved.
# This saves the entire model including the architecture
torch.save(model, "./full_model.pth")
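# To load the full model (the SampleModel class must be importable here).
# On recent PyTorch releases, where weights_only defaults to True, unpickling
# a full model needs weights_only=False; only do this for files you trust.
model = torch.load("./full_model.pth", weights_only=False)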
model.eval()
This method is generally not recommended if you plan to refactor your code and still load older model files. The pickled file does not contain the model class itself, only a reference to where it was defined, so loading can fail once the class is renamed or moved. It offers less flexibility than the state_dict approach.
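The flexibility of the state_dict approach can be illustrated with a short sketch. OldModel and RefactoredModel are hypothetical classes invented for this example: the second renames the layer attribute, and the saved weights are carried over by renaming the dictionary keys. The state_dict is kept in memory here; the same remapping would apply to one read back with torch.load().

```python
import torch
import torch.nn as nn


# Hypothetical "before" version of a model.
class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 2)

    def forward(self, x):
        return self.layer(x)


# Hypothetical "after" version: same architecture, renamed attribute.
class RefactoredModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(10, 2)

    def forward(self, x):
        return self.classifier(x)


old = OldModel()
saved = old.state_dict()  # keys: "layer.weight", "layer.bias"

# Rename the key prefix so it matches the refactored attribute name.
remapped = {
    key.replace("layer.", "classifier.", 1): value
    for key, value in saved.items()
}

new = RefactoredModel()
result = new.load_state_dict(remapped)  # strict=True by default

x = torch.randn(3, 10)
with torch.no_grad():
    outputs_match = torch.allclose(old(x), new(x))

print("missing keys:   ", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
print("outputs match:  ", outputs_match)
```

A fully pickled model offers no comparable hook: once the original class can no longer be found under its old name, the file cannot be unpickled without recreating that name.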
Conclusion
Saving and loading models in PyTorch is straightforward once you know when to use a state dictionary and when to save the entire model. Prefer the state_dict when you want a file that stores only parameter values and keeps working as your code changes. Save the entire model only when you need the pickled object as-is, and accept the trade-offs involved. Whether you're resuming training or deploying for inference, PyTorch’s serialization and deserialization support gives you the tools you need.