Saving and loading models are crucial parts of any machine learning workflow. PyTorch, a popular deep learning library, offers a simple method to save and load models. This allows for resuming training later, sharing models with others, or deploying on different systems.
Saving a Model in PyTorch
In PyTorch, models are saved using the torch.save() function. Typically, PyTorch models are saved as a .pt or .pth file. This can be done in two main ways: saving the entire model or just the model parameters (state_dict).
Saving Entire Model
To save the entire model, pass the model instance and the file path to the torch.save() function. Here's how you can do it:
import torch
import torch.nn as nn
# Define a sample model
class SampleModel(nn.Module):
    def __init__(self):
        super(SampleModel, self).__init__()
        self.linear = nn.Linear(10, 2)
    def forward(self, x):
        return self.linear(x)
model = SampleModel()
# Save the entire model
torch.save(model, "model.pth")
In this code snippet, we first define a simple neural network model and then save it as "model.pth".
Saving Model State Dictionary
The more common practice is to save the model's state dictionary (state_dict), a dictionary that maps each layer's name to its parameter and buffer tensors. Only those tensors are stored. The model class itself is not, so you re-create the architecture from your code when loading. Here's how you save your model's state:
# Save the model's state dictionary
torch.save(model.state_dict(), "model_state.pth")
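This method is generally preferred because the model architecture is already defined in your code, and the state dictionary is more portable. Saving the entire model pickles the object, which ties the file to the exact class definition and module path that existed when it was saved, so refactoring or moving that class can break loading.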
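To see what "parameters and buffers" means in practice, the sketch below prints the contents of the state dictionary for SampleModel (redefined here so the snippet runs on its own). For a single nn.Linear(10, 2) layer it holds just a weight and a bias tensor, keyed by their attribute path.

```python
import torch
import torch.nn as nn


class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


model = SampleModel()
state = model.state_dict()

# A state_dict is an ordered mapping from parameter/buffer names to tensors.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# linear.weight (2, 10)
# linear.bias (2,)
```

Note that nn.Linear stores its weight as (out_features, in_features), which is why the weight shape is (2, 10).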
Loading a Model in PyTorch
Loading a model in PyTorch requires you to know how it was saved. Depending on whether you saved the entire model or just the state_dict, the loading code will vary.
Loading Entire Model
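If you saved the entire model with torch.save(), you can load it back with torch.load(). The class definition (SampleModel here) must be available when loading, because the file stores a reference to the class rather than its code.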
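# Load the entire model (SampleModel must be defined or importable here)
# weights_only=False is required on PyTorch 2.6+ to unpickle a full model; only use it for trusted files
loaded_model = torch.load("model.pth", weights_only=False)
loaded_model.eval()  # Switch to evaluation mode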
torch.load() returns the model object directly, ready to use. Calling .eval() switches layers such as dropout and batch normalization to their inference behavior, which you want before making predictions.
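A minimal round-trip sketch of saving and reloading the entire model, assuming PyTorch 1.13 or later (where torch.load accepts weights_only). It writes to a temporary directory instead of "model.pth" in the working directory, and checks that the reloaded copy produces the same outputs as the original.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


model = SampleModel()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(model, path)  # pickles the whole module object
    # weights_only=False is needed on PyTorch 2.6+ to unpickle a full model;
    # only do this for files from a trusted source.
    loaded_model = torch.load(path, weights_only=False)

loaded_model.eval()
x = torch.randn(4, 10)
with torch.no_grad():
    same = torch.equal(model(x), loaded_model(x))
print(type(loaded_model).__name__, same)  # SampleModel True
```

Because SampleModel has no dropout or batch normalization, switching modes does not change its outputs here. The comparison only confirms that the weights survived the round trip.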
Loading Model State Dictionary
To load a model using the state dictionary, you'll need to re-instantiate the model class and then load the parameters.
# Instantiate the model
loaded_model = SampleModel()
# Load the state dictionary
loaded_model.load_state_dict(torch.load("model_state.pth"))
After loading the state dictionary, it is crucial to put the model into evaluation mode using .eval() if you are not training the model further.
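To show why .eval() matters, the sketch below uses a hypothetical DropoutModel (SampleModel plus a dropout layer, not part of the original example) and walks through the state_dict workflow: save, re-create the class, load, and switch to evaluation mode. It assumes PyTorch 1.13 or later for weights_only=True and uses a temporary directory for the file. load_state_dict() returns the lists of missing and unexpected keys, which are empty when the architecture matches the saved tensors.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical variant of SampleModel with dropout, so that
# training and evaluation modes actually behave differently.
class DropoutModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=0.5)
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(self.dropout(x))


torch.manual_seed(0)
trained = DropoutModel()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state.pth")
    torch.save(trained.state_dict(), path)

    loaded_model = DropoutModel()  # re-create the architecture in code
    # A state_dict holds only tensors, so the safer weights_only=True works.
    result = loaded_model.load_state_dict(torch.load(path, weights_only=True))

print(result.missing_keys, result.unexpected_keys)  # [] []

x = torch.randn(3, 10)
loaded_model.eval()  # dropout becomes a no-op
trained.eval()
with torch.no_grad():
    out1 = loaded_model(x)
    out2 = loaded_model(x)
    reference = trained(x)
print(torch.equal(out1, out2), torch.equal(out1, reference))  # True True
```

Left in training mode, the dropout layer would randomly zero inputs on every call, so repeated predictions on the same input would generally differ.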
Using GPU vs CPU
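A saved checkpoint remembers the device each tensor was on. If a model trained on a GPU is loaded on a machine without CUDA, torch.load() raises an error unless you pass map_location to remap the tensors onto an available device. Choose the device based on your current environment: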
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SampleModel()
model.load_state_dict(torch.load("model_state.pth", map_location=device))
model.to(device)
The above snippet demonstrates how to load a model onto the available device, making your code adaptable to various environments.
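A common deployment case is forcing everything onto the CPU. The sketch below assumes PyTorch 1.13 or later and a temporary file path. It saves a state_dict from whichever device is available (a GPU if present) and reloads it with map_location="cpu", which also accepts a torch.device, a dict, or a callable.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


# Training side: the model lives on the GPU if one is available.
train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SampleModel().to(train_device)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state.pth")
    torch.save(model.state_dict(), path)  # tensors are stored with their device

    # Deployment side: remap every tensor onto the CPU while loading.
    state = torch.load(path, map_location="cpu", weights_only=True)

cpu_model = SampleModel()
cpu_model.load_state_dict(state)
cpu_model.eval()

print({t.device.type for t in state.values()})  # {'cpu'}
print({p.device.type for p in cpu_model.parameters()})  # {'cpu'}
```

On a CPU-only machine the snippet still runs. map_location then has nothing to remap, which is what makes the same loading code safe in both environments.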
Conclusion
Saving and loading models in PyTorch is a straightforward process. Saving the state dictionary is usually the more robust choice, while saving the entire model is convenient but ties the file to your code's class definitions. Choosing the right approach, and loading with the right device and evaluation settings, ensures that your models can be reloaded and used across different systems or stages of development.