How to Transpose a Tensor in PyTorch

Updated: July 7, 2023 By: Wolf

In PyTorch, transposing a tensor means swapping two of its dimensions. For example, a tensor of shape (3, 4) becomes (4, 3) after transposing. Transposition is useful when a tensor's shape does not match the shape an operation expects. It lets you align dimensions for operations such as matrix multiplication, and it often comes up alongside reshaping. Transposing can also change the axis order, which may be necessary when working with frameworks or data formats that expect a specific axis order.
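To make the matrix-multiplication case concrete, here is a minimal sketch. The names w and x and their shapes are illustrative choices, not taken from any particular model: multiplying a (2, 4) tensor by a (3, 4) tensor fails because the inner dimensions differ, while transposing the second operand to (4, 3) makes the product valid.

```python
import torch

# Illustrative shapes: w is (3, 4), x is a batch of 2 rows with 4 features each
w = torch.arange(12, dtype=torch.float32).reshape(3, 4)
x = torch.ones(2, 4)

# (2, 4) @ (3, 4) is invalid: the inner dimensions are 4 and 3
failed = False
try:
    x @ w
except RuntimeError as err:
    failed = True
    print("shape mismatch:", err)

# Transposing w to (4, 3) aligns the inner dimensions: (2, 4) @ (4, 3) -> (2, 3)
out = x @ w.transpose(0, 1)
print(w.transpose(0, 1).shape)  # torch.Size([4, 3])
print(out.shape)                # torch.Size([2, 3])
```

Because x is all ones, each entry of out is simply a row sum of w.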

This concise, example-based article will walk you through some ways to transpose a given tensor by using the built-in features of PyTorch.

Using the torch.transpose() function

The torch.transpose(input, dim0, dim1) function returns a version of the input tensor with dimensions dim0 and dim1 swapped. It works for a tensor of any shape, as long as both dimension indices are valid for that tensor. For a regular (strided) tensor, the function does not copy any data: the result is a view that shares the input's underlying storage, so changing one will affect the other. Sparse tensors are the exception, since their transposed result does not share storage with the input.

Example:

import torch

# create an input tensor of shape (2, 3)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# transpose the tensor by swapping dimensions 0 and 1
y = torch.transpose(x, 0, 1)

# print the output tensor of shape (3, 2)
print(y)

Output:

tensor([[1, 4],
        [2, 5],
        [3, 6]])
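The storage sharing described above can be checked directly. This sketch uses the standard Tensor methods data_ptr(), stride(), and is_contiguous() to show that the transposed result points at the same memory, that writing through it modifies the original, and that the view only swaps strides instead of moving data.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.transpose(x, 0, 1)

# y is a view: it starts at the same memory address as x
print(y.data_ptr() == x.data_ptr())  # True

# Writing through the view changes the original tensor
y[0, 1] = 100  # y[0, 1] is the same element as x[1, 0]
print(x.tolist())  # [[1, 2, 3], [100, 5, 6]]

# Only the strides are swapped, so the view is not contiguous
print(x.stride(), y.stride())  # (3, 1) (1, 3)
print(y.is_contiguous())       # False
```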

If you want a copy of the transposed tensor, you can use the clone() method as shown below:

z = torch.transpose(x, 0, 1).clone()
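A minimal sketch of why the copy matters: the cloned tensor can be edited without touching x. The same sketch also shows a related pitfall. Because a transposed view is not contiguous, view() cannot flatten it, while calling contiguous() first (which copies the data into row-major order) makes view() work. reshape() would also handle this case, copying only when needed.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# clone() allocates new memory, so editing z leaves x untouched
z = torch.transpose(x, 0, 1).clone()
z[0, 0] = -1
print(x[0, 0].item())  # 1

# A transposed view has swapped strides, so view() cannot flatten it
y = x.transpose(0, 1)
view_failed = False
try:
    y.view(-1)
except RuntimeError as err:
    view_failed = True
    print("view() failed:", err)

# contiguous() copies the data into row-major order, after which view() works
flat = y.contiguous().view(-1)
print(flat)  # tensor([1, 4, 2, 5, 3, 6])
```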

Using the Tensor.transpose() method

This method is similar to the torch.transpose() function, but it is called on the input tensor object instead of passing it as an argument.

Example:

import torch

# create an input tensor of shape (2, 3, 2)
x = torch.tensor([
    [[1, 2], [3, 4], [5, 6]],
    [[7, 8], [9, 10], [11, 12]]
])

# transpose the tensor by swapping dimensions 0 and 1
y = x.transpose(0, 1)

# print the output tensor of shape (3, 2, 2)
print(y)

Output:

tensor([[[ 1,  2],
         [ 7,  8]],

        [[ 3,  4],
         [ 9, 10]],

        [[ 5,  6],
         [11, 12]]])
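As a quick sanity check, the sketch below rebuilds the same 3D tensor with torch.arange() and confirms that the method and the function give identical results. It also shows that dimension indices may be negative and count from the end, so transpose(-2, -1) swaps the last two dimensions regardless of how many there are.

```python
import torch

# Same values as the example above: shape (2, 3, 2)
x = torch.arange(1, 13).reshape(2, 3, 2)

# The method and the function produce identical results
a = x.transpose(0, 1)
b = torch.transpose(x, 0, 1)
print(torch.equal(a, b))  # True

# Negative indices count from the end: -1 is the last dimension
c = x.transpose(-2, -1)
print(c.shape)                            # torch.Size([2, 2, 3])
print(torch.equal(c, x.transpose(1, 2)))  # True
```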

Using the T attribute or the t() method (for 2D tensors only)

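These are shorthands for transposing a 2D tensor (matrix), that is, swapping dimensions 0 and 1. The t() method expects a tensor with at most 2 dimensions: 0-D and 1-D tensors are returned as is, and higher-dimensional tensors raise an error. Using the T attribute on a tensor that is not 2D is deprecated (in that case it reverses all of the dimensions). Like torch.transpose(), both return a view that shares the input's underlying storage rather than a copy, so changing one will affect the other.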

Example:

import torch

# create an input tensor of shape (2, 3)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# transpose the tensor by accessing the T attribute
y = x.T

# or transpose the tensor by calling the t() method
z = x.t()

# print the output tensors of shape (3, 2)
print(y)
print(z)

Output:

tensor([[1, 4],
        [2, 5],
        [3, 6]])
tensor([[1, 4],
        [2, 5],
        [3, 6]])
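The 2D restriction is easy to see in practice. This sketch assumes PyTorch 1.11 or later, where the mT attribute is available: t() leaves a 1-D tensor unchanged and raises an error on a 3D tensor, while mT transposes the last two dimensions and therefore works on a batch of matrices.

```python
import torch

# t() returns 0-D and 1-D tensors unchanged
v = torch.tensor([1, 2, 3])
print(v.t().shape)  # torch.Size([3])

# A batch of two 3x4 matrices
batch = torch.arange(24).reshape(2, 3, 4)

t_failed = False
try:
    batch.t()
except RuntimeError as err:
    t_failed = True
    print("t() failed:", err)

# mT swaps the last two dimensions, matching transpose(-2, -1)
print(batch.mT.shape)  # torch.Size([2, 4, 3])
print(torch.equal(batch.mT, batch.transpose(-2, -1)))  # True
```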