PyTorch Error: mat1 and mat2 shapes cannot be multiplied

Updated: July 7, 2023 By: Wolf

Overview

When working with PyTorch, you might encounter the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied

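This error occurs when you try to perform a matrix multiplication, either directly with torch.matmul() or torch.mm() or indirectly through a layer such as torch.nn.Linear, on two tensors that have incompatible shapes. In recent PyTorch versions, the full message also lists the two offending shapes in parentheses. The number of columns of the first tensor must match the number of rows of the second tensor. For example, you can multiply a tensor of shape (m, n) with a tensor of shape (n, p), but NOT with a tensor of shape (p, n), unless n happens to equal p.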
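The shape rule above can be checked with a small experiment. The following is a minimal sketch; the tensor names and sizes are chosen only for illustration, and the exact wording of the message may vary between PyTorch versions.

```python
import torch

a = torch.randn(2, 3)  # shape (m, n) = (2, 3)
b = torch.randn(3, 4)  # shape (n, p) = (3, 4)
c = torch.randn(4, 3)  # shape (p, n) = (4, 3)

# Inner dimensions match (3 == 3), so this works and yields shape (2, 4).
ok = torch.matmul(a, b)
print(ok.shape)

# Inner dimensions differ (3 != 4), so PyTorch raises a RuntimeError.
error_message = None
try:
    torch.matmul(a, c)
except RuntimeError as exc:
    error_message = str(exc)
    print(error_message)
```

Catching the exception here is only for demonstration; in real code you would normally fix the shapes instead of handling the error.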

To fix this error, check the shapes of your tensors and make sure they are compatible for matrix multiplication. You can use the .shape attribute or the .size() method to get the shape of a tensor. You can also use torch.reshape(), Tensor.view(), or torch.transpose() to change the shape of a tensor if needed.
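Here is a short sketch of how those inspection and reshaping calls behave on a small tensor (the variable names are illustrative only). Note that reshaping and transposing can produce the same shape while arranging the elements differently, so they are not interchangeable.

```python
import torch

t = torch.arange(6).reshape(2, 3)  # values 0..5 laid out as 2 rows, 3 columns

print(t.shape)    # shape via the attribute
print(t.size())   # same information via the method
print(t.size(1))  # size of a single dimension (the number of columns)

reshaped = t.reshape(3, 2)      # same elements, re-read row by row into 3 rows
viewed = t.view(6)              # flat view that shares memory with t
transposed = t.transpose(0, 1)  # swaps the two dimensions, also shape (3, 2)

# Both are (3, 2), but the element order differs.
print(reshaped)
print(transposed)
```

For fixing a matrix multiplication, transposing is usually what you want when the rows and columns are simply swapped, while reshaping is appropriate when the data needs a different grouping.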

Common Cases

Below are some typical examples and common mistakes that lead to the error mat1 and mat2 shapes cannot be multiplied.

You are using two 2D tensors with the same shape

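When you try to multiply two 2D tensors that have the same non-square shape, the error is raised, because an (m, n) tensor cannot be multiplied by another (m, n) tensor when m != n. The solution here is simple: transpose one of the two input tensors. Which one you transpose depends on the result you need, since x @ y.T has shape (m, m) while x.T @ y has shape (n, n).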

Example:

import torch

x = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

y = torch.tensor([
    [7, 8, 9],
    [10, 11, 12]
])

# x is (2, 3) and y.T is (3, 2), so the result is (2, 2)
result = torch.matmul(x, y.T)

print(result)

You are using a 1D tensor as an input to a linear layer

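A linear layer (torch.nn.Linear) multiplies along the last dimension of its input, and the size of that dimension must equal the layer's in_features. The input is usually a 2D tensor, where the first dimension is the batch size and the second dimension is the number of features. If your 1D tensor of shape (n,) actually holds n samples with one feature each, you can use the torch.unsqueeze() function to add a singleton dimension (use unsqueeze(0) instead if it is a single sample with n features). For example, if your input tensor has the shape of (n,), you can do: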

input = input.unsqueeze(-1) # change the shape to (n, 1)
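The sketch below shows both readings of a 1D tensor in context. The layer sizes and variable names are hypothetical and chosen only to make the shapes easy to follow.

```python
import torch
from torch import nn

values = torch.tensor([0.5, 1.5, 2.5])  # shape (3,)

# Reading 1: three samples with one feature each -> shape (3, 1)
per_sample_layer = nn.Linear(in_features=1, out_features=4)  # hypothetical layer
batch = values.unsqueeze(-1)
out_batch = per_sample_layer(batch)  # shape (3, 4)

# Reading 2: one sample with three features -> shape (1, 3)
wide_layer = nn.Linear(in_features=3, out_features=4)  # hypothetical layer
single = values.unsqueeze(0)
out_single = wide_layer(single)  # shape (1, 4)

print(batch.shape, out_batch.shape)
print(single.shape, out_single.shape)
```

Which reading is right depends on what the data represents, so it is worth checking the layer's in_features before choosing where to add the new dimension.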

You are using a 3D or higher-dimensional tensor as an input to a linear layer

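A linear layer is applied to the last dimension only, so an input of shape (b, m, n) fails unless n equals the layer's in_features. If the layer is meant to see all m * n values of each sample, you can use the torch.flatten() function to collapse all dimensions except the first one into one dimension. For example, if your input tensor has the shape of (b, m, n), you can do as follows: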

input = input.flatten(1) # change the shape to (b, m * n)
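As a minimal sketch of flattening before a linear layer, assume a batch of 8 samples of size 4 x 5 and a layer with 10 outputs (all of these sizes are made up for illustration):

```python
import torch
from torch import nn

b, m, n = 8, 4, 5  # hypothetical batch size and per-sample dimensions
features = torch.randn(b, m, n)

# The layer must expect m * n input features to accept the flattened tensor.
layer = nn.Linear(in_features=m * n, out_features=10)

flat = features.flatten(1)  # keep dim 0, merge the rest: shape (8, 20)
out = layer(flat)           # shape (8, 10)

print(flat.shape, out.shape)
```

The same pattern applies after convolutional layers, where the flattened size is channels * height * width.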

You are using two tensors with different batch sizes for matrix multiplication

The batch dimensions must be equal or broadcastable for matrix multiplication. You can use the Tensor.expand() method to expand a tensor along a dimension with size 1 to match another tensor’s size. For example, if your first tensor has the shape of (b1, m, n) and your second tensor has the shape of (b2, n, p), where b1 != b2 but b1 or b2 is 1, you can do it like this:

import torch

# example sizes; exactly one of b1 and b2 is 1
b1, b2, m, n, p = 1, 4, 2, 3, 5
input1 = torch.randn(b1, m, n)
input2 = torch.randn(b2, n, p)

if b1 == 1:
    # expand the first tensor along the batch dimension
    input1 = input1.expand(b2, m, n)
elif b2 == 1:
    # expand the second tensor along the batch dimension
    input2 = input2.expand(b1, n, p)
# now both tensors have the same batch size
output = torch.matmul(input1, input2) # output has a shape of (max(b1, b2), m, p)
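It is worth noting that torch.matmul() broadcasts a batch dimension of size 1 on its own, whereas torch.bmm() requires both batch sizes to be equal, so the explicit expand matters mainly for the latter. The following sketch compares the two; the tensor names and sizes are illustrative assumptions.

```python
import torch

left = torch.randn(1, 2, 3)   # batch size 1
right = torch.randn(4, 3, 5)  # batch size 4

# matmul broadcasts the size-1 batch dimension automatically.
broadcasted = torch.matmul(left, right)

# bmm does not broadcast, so the batch dimension is expanded first.
expanded = torch.bmm(left.expand(4, 2, 3), right)

print(broadcasted.shape)
print(torch.allclose(broadcasted, expanded, atol=1e-5))
```

Batch sizes that differ and are both greater than 1 cannot be broadcast, so in that case the data itself has to be fixed rather than the shape.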