PyTorch torch.permute() function

Updated: July 23, 2023 By: Frienzied Flame

This concise article is about the torch.permute() function in PyTorch.

The fundamentals

The torch.permute() function rearranges the dimensions of a tensor according to a given order. For example, if you have a tensor of shape (2, 3, 4), you can call torch.permute() with the order (2, 0, 1) to change its shape to (4, 2, 3). The last dimension moves to the front, and the other two shift one position to the right. The function returns a view of the original tensor, which means it does not copy the data in memory. Only the way the data is accessed changes: the returned view has its sizes and strides reordered in the same way as the dimensions, while the original tensor is left untouched. Because of this, the result is usually not contiguous in memory.
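The view behavior described above can be checked directly. The short sketch below uses an illustrative tensor built with torch.arange (it is not part of the original examples) and compares shapes, strides, and storage pointers before and after permuting. It also shows that .contiguous() is what actually copies the data into a compact layout.

```python
import torch

# Illustrative tensor (not from the article's examples): values 0..23 in shape (2, 3, 4)
x = torch.arange(24).reshape(2, 3, 4)
y = x.permute(2, 0, 1)

# Sizes and strides are reordered the same way; no data is copied
print(x.shape, x.stride())  # torch.Size([2, 3, 4]) (12, 4, 1)
print(y.shape, y.stride())  # torch.Size([4, 2, 3]) (1, 12, 4)

# Both tensors point at the same underlying memory
print(y.data_ptr() == x.data_ptr())  # True

# Writing through the view changes the original tensor
y[0, 0, 0] = 100
print(x[0, 0, 0].item())  # 100

# The permuted view is not contiguous; .contiguous() makes a compact copy
print(y.is_contiguous())  # False
z = y.contiguous()
print(z.is_contiguous(), z.stride())  # True (6, 3, 1)
print(z.data_ptr() == x.data_ptr())  # False
```

Since a permuted view is non-contiguous, operations that require contiguous memory (such as .view()) may need a .contiguous() call first.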

Syntax:

torch.permute(input, dims) -> Tensor

Where:

  • input: The input tensor that you want to permute. It can be any shape and dtype.
  • dims: A tuple of integers that specifies the desired ordering of dimensions. Its length must equal the number of dimensions of the input tensor, and it must contain each dimension index (counting from zero) exactly once. Element d of the tuple says which input dimension becomes dimension d of the output.
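To make the rules for dims concrete, the sketch below compares the function form with the Tensor.permute() method (which accepts either separate integers or a tuple) and then tries two invalid orders on a 3-D tensor: one that is too short and one that repeats an index. The variable names are illustrative only.

```python
import torch

x = torch.randn(2, 3, 4)

# Function form takes a tuple; the method form accepts a tuple or separate ints
a = torch.permute(x, (2, 0, 1))
b = x.permute(2, 0, 1)
c = x.permute((2, 0, 1))
print(a.shape, torch.equal(a, b), torch.equal(a, c))  # torch.Size([4, 2, 3]) True True

# dims must list every index 0..x.dim()-1 exactly once
rejected = []
for bad_dims in [(0, 1), (0, 0, 1)]:
    try:
        x.permute(*bad_dims)
    except RuntimeError:
        rejected.append(bad_dims)  # wrong length or repeated index
print(rejected)  # [(0, 1), (0, 0, 1)]
```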

For more clarity, see the examples below.

Examples

Basic example

This code changes the shape of a tensor from (2, 3, 4) to (4, 2, 3):

import torch

# Set a random seed for reproducibility
torch.manual_seed(0)

# Create a random tensor of shape (2, 3, 4)
input = torch.randn(2, 3, 4) 

# Print the shape of the input tensor
print(input.shape)
# torch.Size([2, 3, 4])

# print the input tensor
print(input)
# tensor([[[-1.1258, -1.1524, -0.2506, -0.4339],
#          [ 0.8487,  0.6920, -0.3160, -2.1152],
#          [ 0.4681, -0.1577,  1.4437,  0.2660]],

#         [[ 0.1665,  0.8744, -0.1435, -0.1116],
#          [ 0.9318,  1.2590,  2.0050,  0.0537],
#          [ 0.6181, -0.4128, -0.8411, -2.3160]]])

# Permute the tensor to have shape (4, 2, 3)
# (the method call below is equivalent to torch.permute(input, (2, 0, 1)))
output = input.permute(2, 0, 1)

# Print the shape of the output tensor
print(output.shape) 
# torch.Size([4, 2, 3])

# Print the output tensor
print(output)
# tensor([[[-1.1258,  0.8487,  0.4681],
#          [ 0.1665,  0.9318,  0.6181]],

#         [[-1.1524,  0.6920, -0.1577],
#          [ 0.8744,  1.2590, -0.4128]],

#         [[-0.2506, -0.3160,  1.4437],
#          [-0.1435,  2.0050, -0.8411]],

#         [[-0.4339, -2.1152,  0.2660],
#          [-0.1116,  0.0537, -2.3160]]])
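Put differently, output dimension d takes input dimension dims[d], so with the order (2, 0, 1) every element satisfies output[k, i, j] == input[i, j, k]. You can verify this in the printed tensors above: the first row of the output, [-1.1258, 0.8487, 0.4681], is the first column of the first input matrix. The sketch below (variable names x, y, dims, and inverse are illustrative) checks the rule against torch.einsum and shows one way to undo a permutation by computing the inverse order with torch.argsort.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4)
dims = (2, 0, 1)
y = x.permute(*dims)

# Output axis d is input axis dims[d], so y[k, i, j] == x[i, j, k]
print(torch.equal(y[3, 1, 2], x[1, 2, 3]))  # True

# einsum with the same index relabeling produces the same values
print(torch.equal(y, torch.einsum("ijk->kij", x)))  # True

# The inverse permutation (argsort of dims) restores the original layout
inverse = torch.argsort(torch.tensor(dims)).tolist()
print(inverse)  # [1, 2, 0]
print(torch.equal(y.permute(*inverse), x))  # True
```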

Using torch.permute() in a neural network

One real-world scenario is rearranging the dimensions of image data before it enters a convolutional layer. For instance, suppose you have an input tensor of shape (batch_size, height, width, channels), where the channel dimension holds the RGB values. However, nn.Conv2d expects its input to have the shape (batch_size, channels, height, width). In this case, you can call torch.permute() with the order (0, 3, 1, 2) to move the last (channel) dimension to position 1, which shifts height and width one position to the right. Here is a code snippet that illustrates this:

# Import PyTorch library
import torch
import torch.nn as nn

# Define a convolutional layer with 3 input channels and 6 output channels
conv = nn.Conv2d(3, 6, kernel_size=3)

# Create a random input tensor of shape (batch_size, height, width, channels)
input = torch.randn(4, 32, 32, 3)

# Permute the input tensor to have the shape (batch_size, channels, height, width)
input = input.permute(0, 3, 1, 2)
print(input.shape)
# torch.Size([4, 3, 32, 32])

# Apply the convolutional layer to the permuted input
output = conv(input)

# Print the shape of the output tensor
print(output.shape)
# torch.Size([4, 6, 30, 30])
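When the convolution output has to go back to code that expects channels-last data (for example, a plotting or image library working with HWC arrays), the order (0, 2, 3, 1) moves the channels to the end again. torch.movedim() is an alternative when only a single dimension moves. The sketch below reuses the layer configuration from the snippet above; the tensor names images, nchw, features, and nhwc are illustrative.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 6, kernel_size=3)
images = torch.randn(4, 32, 32, 3)  # illustrative channels-last batch

# movedim moves one dimension to a new position; here it matches permute(0, 3, 1, 2)
nchw = torch.movedim(images, -1, 1)
print(torch.equal(nchw, images.permute(0, 3, 1, 2)))  # True

features = conv(nchw)
print(features.shape)  # torch.Size([4, 6, 30, 30])

# Move channels back to the last position for code that expects channels-last data
nhwc = features.permute(0, 2, 3, 1)
print(nhwc.shape)  # torch.Size([4, 30, 30, 6])
```

permute() is the more general tool because it can reorder every dimension at once, while movedim() reads more clearly when a single dimension changes position.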

That’s it. This tutorial ends here. Happy coding & enjoy your day!