Sling Academy

How to Reshape a Tensor in PyTorch (with Examples)

Last updated: July 14, 2023

Overview

In PyTorch, reshaping a tensor means changing its shape (the number of dimensions and the size of each dimension) while keeping the same data and the same total number of elements. This is useful for arranging data to fit the inputs expected by different operations or models. For example, you may want to reshape a 1D tensor (a vector) into a 2D tensor (a matrix), or vice versa.

How to reshape a tensor?

PyTorch provides the torch.reshape() function for reshaping tensors easily and efficiently. Its syntax is:

reshaped_tensor = torch.reshape(input, shape)

Here, input is the tensor you want to reshape, and shape is a tuple of integers specifying the new shape.
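A minimal sketch of the syntax in use, assuming a small integer tensor built with torch.arange (the names x, y1 and y2 are illustrative only). It also calls the tensor method x.reshape(), which PyTorch provides as an alternative spelling of torch.reshape(); the final check simply confirms that both forms produce identical results here.

```python
import torch

x = torch.arange(6)             # tensor([0, 1, 2, 3, 4, 5]), shape (6,)

# Function form: the new shape is passed as a tuple of ints
y1 = torch.reshape(x, (2, 3))

# Method form on the tensor itself
y2 = x.reshape(2, 3)

print(y1.shape)                 # torch.Size([2, 3])
print(torch.equal(y1, y2))      # True
```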

The key requirement when reshaping a tensor in PyTorch is that the input and output tensors contain the same number of elements. In other words, the product of the sizes of all dimensions in the input tensor must equal the product of the sizes of all dimensions in the new shape. For example, a tensor with shape (2, 3, 4) can be reshaped to (6, 4), because both hold 2 × 3 × 4 = 6 × 4 = 24 elements. It cannot be reshaped to (5, 5), which would require 25 elements; PyTorch raises a RuntimeError if you try.
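The element-count rule can be checked directly. This sketch uses a hypothetical zero-filled tensor t of shape (2, 3, 4), reshapes it to (6, 4), and then catches the RuntimeError that PyTorch raises for the 25-element shape (5, 5). The exact error text may differ between PyTorch versions, so the example only records that an error occurred.

```python
import torch

t = torch.zeros(2, 3, 4)        # 2 * 3 * 4 = 24 elements
print(t.numel())                # 24

ok = torch.reshape(t, (6, 4))   # 6 * 4 = 24, so this succeeds
print(ok.shape)                 # torch.Size([6, 4])

error_message = None
try:
    torch.reshape(t, (5, 5))    # 5 * 5 = 25 != 24, so this fails
except RuntimeError as err:
    error_message = str(err)
print(error_message is not None)  # True
```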

Let's examine a practical example. Suppose you have a tensor named a with shape (8,) and want to reshape it into a matrix with 4 rows and 2 columns. You can do it like this:

import torch

a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
print(a.shape) # torch.Size([8])

b = torch.reshape(a, (4, 2))

print(b.shape) # torch.Size([4, 2])
print(b)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

torch.reshape() returns a view of the input tensor when possible, meaning the reshaped tensor shares the same underlying data as the input, so modifying one also modifies the other. Whether a view is possible depends on the contiguity and strides of the input tensor; when it is not, the function returns a copy with the new shape instead. Because of this, you should not rely on reshape() returning either a view or a copy.
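The sketch below illustrates both cases under current PyTorch behavior (all names are illustrative). It compares data_ptr() values to see whether memory is shared: reshaping a contiguous tensor typically yields a view, while flattening a transposed, non-contiguous tensor forces a copy. It also shows that Tensor.view(), unlike reshape(), refuses such a non-contiguous input by raising a RuntimeError. Treat this as an illustration rather than a guarantee; if your code needs an independent copy, calling .clone() on the result is the explicit way to get one.

```python
import torch

base = torch.arange(6)               # contiguous tensor: [0, 1, 2, 3, 4, 5]

# Case 1: contiguous input, so reshape() can return a view
r1 = base.reshape(2, 3)
print(r1.data_ptr() == base.data_ptr())        # True: same underlying memory

# Case 2: a transpose is non-contiguous, so flattening it needs a copy
transposed = r1.t()                  # shape (3, 2): [[0, 3], [1, 4], [2, 5]]
print(transposed.is_contiguous())    # False
r2 = transposed.reshape(6)
print(r2.data_ptr() == transposed.data_ptr())  # False: new memory was allocated
print(r2)                            # tensor([0, 3, 1, 4, 2, 5])

# view() never copies, so it raises an error instead
try:
    transposed.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)                   # True
```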

You can also pass -1 for a single dimension in the shape argument, letting PyTorch infer that dimension's size from the remaining dimensions and the total number of elements in the input tensor. For example, to flatten a tensor named c with shape (2, 4) into a vector, you can do the following:

import torch

c = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]])
print(c.shape) # torch.Size([2, 4])

d = torch.reshape(c, (-1,))
print(d.shape) # torch.Size([8])
print(d)
# tensor([1, 2, 3, 4, 5, 6, 7, 8])

In this case, PyTorch infers that the -1 dimension must be 8, since the (2, 4) input holds 2 × 4 = 8 elements.
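The -1 placeholder also works inside multi-dimensional shapes. This sketch, assuming a 24-element tensor e of shape (2, 3, 4) built with torch.arange, infers the missing size in two different positions and confirms that passing -1 for more than one dimension is rejected, because the remaining sizes would then be ambiguous.

```python
import torch

e = torch.arange(24).reshape(2, 3, 4)

f = torch.reshape(e, (-1, 4))    # -1 is inferred as 24 / 4 = 6
print(f.shape)                   # torch.Size([6, 4])

g = torch.reshape(e, (2, -1))    # -1 is inferred as 24 / 2 = 12
print(g.shape)                   # torch.Size([2, 12])

# Only one dimension may be -1; two of them leave the shape ambiguous
try:
    torch.reshape(e, (-1, -1))
    two_unknowns_ok = True
except RuntimeError:
    two_unknowns_ok = False
print(two_unknowns_ok)           # False
```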

