PyTorch Tensor.view() method (with example)

Updated: July 14, 2023 By: Khue

This concise and straight-to-the-point article is about the Tensor.view() method in PyTorch.

The fundamentals

A view of a tensor is a new tensor object that shares the same underlying data (storage) with the original tensor but presents it with a different shape.

The Tensor.view() method reshapes a tensor into a new shape without copying its data. It returns a view of the original tensor, so modifying the new tensor will affect the original tensor and vice versa.
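Here is a small sketch of that shared-memory behavior. The tensor values are arbitrary and chosen only for illustration. It uses Tensor.data_ptr() to compare the memory addresses of the two tensors and Tensor.tolist() to print plain Python lists:

```python
import torch

# A 1-D tensor holding 0, 1, 2, 3, 4, 5
x = torch.arange(6)

# Reinterpret the same six elements as a 2x3 matrix (no copy is made)
y = x.view(2, 3)

# Both tensors start at the same memory address
print(x.data_ptr() == y.data_ptr())  # True

# Writing through the view changes the original...
y[0, 0] = 100
print(x.tolist())  # [100, 1, 2, 3, 4, 5]

# ...and writing through the original changes the view
x[5] = -1
print(y.tolist())  # [[100, 1, 2], [3, 4, -1]]
```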

Syntax:

Tensor.view(*shape) -> Tensor

Where:

  • self: The input tensor that you want to reshape.
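  • *shape: Either a torch.Size object or a sequence of integers that specify the desired shape of the output tensor. The new shape must contain the same total number of elements as the input tensor. You can also pass -1 for at most one dimension; its size is then inferred from the other dimensions and the total number of elements.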
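The snippet below is a quick sketch of the accepted shape arguments, assuming a 12-element tensor built with torch.arange(). The try_view() helper is hypothetical, written only for this illustration; it returns PyTorch's error message instead of raising it:

```python
import torch

x = torch.arange(12)  # 12 elements: 0, 1, ..., 11

# The target shape can be given as separate integers, a tuple, or a torch.Size
a = x.view(3, 4)
b = x.view((3, 4))
c = x.view(torch.Size([3, 4]))
print(a.shape == b.shape == c.shape)  # True

# -1 asks view() to infer one dimension: 12 / (2 * 3) = 2
d = x.view(2, -1, 3)
print(d.shape)  # torch.Size([2, 2, 3])


def try_view(t, *shape):
    """Hypothetical helper: return the error message if t.view(*shape) fails, else None."""
    try:
        t.view(*shape)
    except RuntimeError as err:
        return str(err)
    return None


print(try_view(x, 4, 3))    # None, because 4 * 3 == 12
print(try_view(x, 5, 3))    # error message: 5 * 3 == 15, but x has 12 elements
print(try_view(x, -1, -1))  # error message: only one dimension can be inferred
```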

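However, because Tensor.view() never copies data, it only works when the new shape is compatible with the tensor's existing memory layout (its size and stride). A contiguous tensor, whose elements are stored in one block of memory in row-major order, can be viewed as any shape with the same number of elements. A non-contiguous tensor (for example, the result of Tensor.t(), Tensor.transpose(), or Tensor.permute()) often cannot, and view() then raises a RuntimeError. In that case, call Tensor.contiguous() before calling Tensor.view(). Keep in mind that contiguous() returns a copy when the tensor is not already contiguous, so the result no longer shares memory with the original. You can check whether a tensor is contiguous by calling Tensor.is_contiguous().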
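A minimal sketch of both situations, assuming small torch.arange() tensors: a transposed tensor that view() cannot flatten, and a column slice that is non-contiguous but can still be split into a new shape. The exact wording of the RuntimeError message may differ between PyTorch versions, so it is printed rather than hard-coded:

```python
import torch

x = torch.arange(6).view(2, 3)  # [[0, 1, 2], [3, 4, 5]]
t = x.t()                       # transpose: shape (3, 2), shares x's memory

print(x.is_contiguous(), t.is_contiguous())  # True False
print(t.stride())                            # (1, 3)

# Flattening t would require reading memory out of order, so view() refuses
view_failed = False
try:
    t.view(6)
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

# contiguous() copies t into a fresh row-major buffer; view() then succeeds
flat = t.contiguous().view(6)
print(flat.tolist())                         # [0, 3, 1, 4, 2, 5]

# Because of that copy, flat no longer shares memory with x
flat[0] = 100
print(x[0, 0].item())                        # 0

# Non-contiguous is not automatically fatal: a shape that only splits an
# existing dimension can still be expressed as a view
s = torch.arange(16).view(4, 4)[:, :2]       # column slice, shape (4, 2)
print(s.is_contiguous())                     # False
print(s.view(2, 2, 2).shape)                 # torch.Size([2, 2, 2])
```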

Code example

A code example that demonstrates how to use the Tensor.view() method in practice:

import torch

torch.manual_seed(2023)

# Create a tensor with the shape of 4x4
x = torch.randn(4, 4)
print(x)
# tensor([[ 0.4305, -0.3499,  0.4749,  0.9041],
#         [-0.7021,  1.5963,  0.4228, -0.6940],
#         [ 0.9672,  1.5569, -2.3860,  0.6994],
#         [-1.0325, -2.6043,  0.9337, -0.1050]])

# Create a tensor with the shape of 16
y = x.view(16)
print(y)
# tensor([ 0.4305, -0.3499,  0.4749,  0.9041, -0.7021,  1.5963,  0.4228, -0.6940,
#          0.9672,  1.5569, -2.3860,  0.6994, -1.0325, -2.6043,  0.9337, -0.1050])

# Create a tensor with the shape of 2x8
z = x.view(2, 8)
print(z)
# tensor([[ 0.4305, -0.3499,  0.4749,  0.9041, -0.7021,  1.5963,  0.4228, -0.6940],
#         [ 0.9672,  1.5569, -2.3860,  0.6994, -1.0325, -2.6043,  0.9337, -0.1050]])

# Use -1 to infer the shape
w = x.view(-1, 2)
print(w.shape) # torch.Size([8, 2])
print(w)
# tensor([[ 0.4305, -0.3499],
#         [ 0.4749,  0.9041],
#         [-0.7021,  1.5963],
#         [ 0.4228, -0.6940],
#         [ 0.9672,  1.5569],
#         [-2.3860,  0.6994],
#         [-1.0325, -2.6043],
#         [ 0.9337, -0.1050]])

You can see that y has the same data as x but in a different shape: a 1-D tensor of 16 elements (torch.Size([16]), not 16x1). z has the same data as x in the shape 2x8, and w has the same data as x in the shape 8x2. Because the first dimension of w was given as -1, PyTorch inferred it from the other dimension and the number of elements in x: 16 / 2 = 8.
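The following sketch rebuilds the tensors from the example above to check these claims. It prints the three shapes, spells out the -1 arithmetic, and verifies that y, z, and w all read the elements of x in the same row-major order (the index formulas in the loop are just one way to express that mapping):

```python
import torch

torch.manual_seed(2023)
x = torch.randn(4, 4)
y, z, w = x.view(16), x.view(2, 8), x.view(-1, 2)

# y is 1-D with 16 elements, not a 16x1 matrix
print(y.shape, z.shape, w.shape)
# torch.Size([16]) torch.Size([2, 8]) torch.Size([8, 2])

# The -1 in view(-1, 2) resolves to x.numel() // 2 = 16 // 2 = 8
print(x.numel() // 2)  # 8

# All views read x's elements in the same row-major order:
# element k of y is w[k // 2, k % 2] and z[k // 8, k % 8]
for k in range(16):
    assert y[k].item() == w[k // 2, k % 2].item() == z[k // 8, k % 8].item()

# A 16x1 column is a different view, with an explicit second dimension
col = x.view(16, 1)
print(col.shape)  # torch.Size([16, 1])
```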

The difference between Tensor.view() and torch.reshape()

The difference between the torch.reshape() function and the Tensor.view() method lies in how they handle memory. Tensor.view() never copies: it always returns a view of the original tensor and raises an error if the new shape is not compatible with the tensor's size and stride, which can happen with non-contiguous tensors. torch.reshape() returns a view when the new shape is compatible and a copy otherwise, so it works on any tensor whose number of elements matches the new shape.

You should use torch.reshape() when you want to reshape a tensor without worrying about its contiguity or copying behavior. You should use Tensor.view() when you want to reshape a tensor (typically a contiguous one) and ensure that the result shares the same data with the original, since view() fails with an error instead of silently making a copy.
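Here is a small sketch comparing the two on a contiguous tensor and on its transpose. Note that the PyTorch documentation advises against relying on whether reshape() returns a view or a copy; the data_ptr() comparisons below only illustrate what happens in these particular cases:

```python
import torch

x = torch.arange(6).view(2, 3)  # contiguous: [[0, 1, 2], [3, 4, 5]]

# On a contiguous tensor, reshape() can return a view (shared memory)
r1 = x.reshape(3, 2)
print(r1.data_ptr() == x.data_ptr())  # True

# On a transposed tensor, no view of shape (6,) exists, so reshape() copies
t = x.t()
r2 = t.reshape(6)
print(r2.tolist())                    # [0, 3, 1, 4, 2, 5]
print(r2.data_ptr() == t.data_ptr())  # False

# Writing to the copy leaves the original untouched
r2[0] = 100
print(x[0, 0].item())                 # 0

# view() on the same transposed tensor raises instead of copying
try:
    t.view(6)
except RuntimeError:
    print("view() needs a compatible layout; use reshape() or contiguous()")
```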