Working with the torch.matmul() function in PyTorch

Updated: July 7, 2023 By: Wolf

Overview

The behavior of the torch.matmul() function

The torch.matmul() function performs a matrix product of two tensors. The behavior depends on the dimensionality of the tensors as follows:

  • If both tensors are 1-dimensional, the dot product (scalar) is returned.
  • If both arguments are 2-dimensional, the matrix-matrix product is returned.
  • If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to the first argument's shape for the purpose of the matrix multiplication, so a tensor of shape (n,) is treated as (1, n). After the multiplication, the prepended dimension is removed.
  • If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
  • If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiplication is returned. If the first argument is 1-dimensional, a 1 is prepended to its shape for the purpose of the batched matrix multiplication and removed afterwards. If the second argument is 1-dimensional, a 1 is appended to its shape for the purpose of the batched matrix multiplication and removed afterwards. The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). For example, if the first argument has shape (j, 1, n, m) and the second has shape (k, m, p), the result has shape (j, k, n, p).
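
The rules above are easiest to see by looking only at the shapes of the results. The snippet below is a minimal sketch; the tensor names (v, m, b) and the dictionary of cases are chosen here purely for illustration:

```python
import torch

v = torch.ones(4)         # 1D tensor
m = torch.ones(3, 4)      # 2D tensor
b = torch.ones(10, 3, 4)  # 3D tensor: a batch of ten (3, 4) matrices

# map each case to the shape that torch.matmul() returns for it
shapes = {
    "1D @ 1D": torch.matmul(v, v).shape,    # dot product, 0-dimensional
    "2D @ 2D": torch.matmul(m, m.T).shape,  # (3, 4) @ (4, 3)
    "1D @ 2D": torch.matmul(v, m.T).shape,  # (4,) treated as (1, 4)
    "2D @ 1D": torch.matmul(m, v).shape,    # matrix-vector product
    "3D @ 1D": torch.matmul(b, v).shape,    # batched matrix-vector product
    "3D @ 2D": torch.matmul(b, m.T).shape,  # batched matrix-matrix product
}

for case, shape in shapes.items():
    print(case, tuple(shape))

# 1D @ 1D ()
# 2D @ 2D (3, 3)
# 1D @ 2D (3,)
# 2D @ 1D (3,)
# 3D @ 1D (10, 3)
# 3D @ 2D (10, 3, 3)
```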

Syntax and parameters

The syntax of the torch.matmul() function is shown below:

torch.matmul(input, other, *, out=None) -> Tensor

Where:

  • input (Tensor) – the first tensor to be multiplied
  • other (Tensor) – the second tensor to be multiplied
  • out (Tensor, optional) – the output tensor
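
The out parameter is the least obvious of the three, so here is a small sketch of how it can be used. It assumes a preallocated buffer (named buf here for illustration) whose shape and dtype already match the result:

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)

# preallocate a buffer with the shape of the expected result
buf = torch.empty(3, 5)

# write the product into the existing buffer instead of a new tensor
torch.matmul(a, b, out=buf)

print(buf.shape)                               # torch.Size([3, 5])
print(torch.allclose(buf, torch.matmul(a, b)))
```

This is mainly useful when the same result buffer is reused many times, for instance inside a loop.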

Alternatives

Alternatives to the torch.matmul() function are listed below:

  • torch.Tensor.matmul() – a method that is called on the input tensor object instead of passing it as an argument
  • torch.Tensor.mm() – a method that only works for 2D tensors and performs a matrix-matrix product
  • torch.mm() – a function that only works for 2D tensors and performs a matrix-matrix product
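
The following sketch compares these alternatives on a pair of 2D tensors. It also uses the @ operator, which is not in the list above but is standard Python syntax for matrix multiplication and works on tensors. The variable names are illustrative only:

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)

r_func = torch.matmul(a, b)   # function form
r_method = a.matmul(b)        # method form
r_op = a @ b                  # operator form (not listed above)
r_mm = torch.mm(a, b)         # 2D-only function
r_mm_method = a.mm(b)         # 2D-only method

# all five calls should agree for 2D inputs
same = all(
    torch.allclose(r_func, r, atol=1e-6)
    for r in (r_method, r_op, r_mm, r_mm_method)
)
print(same)

# torch.mm() does not broadcast: a 3D input is rejected
batch = torch.randn(10, 3, 4)
try:
    torch.mm(batch, b)
    mm_rejects_3d = False
except RuntimeError:
    mm_rejects_3d = True
print(mm_rejects_3d)
```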

Examples

The examples below demonstrate how to use the torch.matmul() function in practice.

Vector x vector

This example shows how to compute the dot product of two 1D tensors using torch.matmul():

import torch

torch.manual_seed(2023)

# create two 1D tensors of size 3
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)

# compute the dot product (scalar) using torch.matmul()
result = torch.matmul(tensor1, tensor2)

# print the result
print(result)

Output:

tensor(-1.3027)
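
As a quick sanity check, the result can be compared with a dot product computed by hand and with torch.dot(). This is only a sketch: the printed values are omitted because they depend on the PyTorch build, and the name manual is used here for illustration:

```python
import torch

torch.manual_seed(2023)

tensor1 = torch.randn(3)
tensor2 = torch.randn(3)

result = torch.matmul(tensor1, tensor2)

# element-wise multiply, then sum: the definition of the dot product
manual = (tensor1 * tensor2).sum()

print(result.dim())                                          # 0, a 0-dimensional tensor
print(torch.allclose(result, manual))
print(torch.allclose(result, torch.dot(tensor1, tensor2)))

# .item() converts the 0-dimensional tensor to a plain Python float
value = result.item()
print(type(value))
```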

Matrix x matrix

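This is perhaps the most common use of torch.matmul(): multiplying two 2D tensors. The number of columns of the first tensor must equal the number of rows of the second, so a (3, 4) tensor multiplied by a (4, 5) tensor yields a (3, 5) tensor.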

import torch

torch.manual_seed(2024)

# create two 2D tensors of size (3, 4) and (4, 5)
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4, 5)

# compute the matrix-matrix product (2D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)

# print the result
print(result)

Output:

tensor([[-1.4659, -0.7207, -6.5537, -0.1978, -4.0800],
        [ 1.2412, -0.4786, -1.0352, -1.7091, -0.7333],
        [-0.7588,  0.2746,  1.7423, -0.3800,  0.1791]])

Matrix x vector

This example shows how to compute the matrix-vector product of a 2D tensor and a 1D tensor with the help of torch.matmul():

import torch

torch.manual_seed(2024)

# create a 2D tensor of size (3, 4)
tensor1 = torch.randn(3, 4)

# create a 1D tensor of size 4
tensor2 = torch.randn(4)

# compute the matrix-vector product (1D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)

# print the result
print(result)

Output:

tensor([-2.9603,  1.0947, -1.4140])
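
The same product can be obtained in two other ways, which may help clarify what torch.matmul() does in this case. The sketch below uses torch.mv() (a dedicated matrix-vector function) and an explicit column vector; the names A and x are chosen for illustration:

```python
import torch

A = torch.randn(3, 4)
x = torch.randn(4)

result = torch.matmul(A, x)

# dedicated matrix-vector product
via_mv = torch.mv(A, x)

# turn x into a (4, 1) column, multiply, then drop the extra dimension
via_column = torch.matmul(A, x.unsqueeze(1)).squeeze(1)

print(result.shape)                                   # torch.Size([3])
print(torch.allclose(result, via_mv, atol=1e-6))
print(torch.allclose(result, via_column, atol=1e-6))
```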

Batched matrix x broadcasted vector

This example shows how to compute the batched matrix-vector product of a 3D tensor and a 1D tensor with torch.matmul(). The same vector is broadcast across the batch, so each of the 10 matrices of size (3, 4) is multiplied by it and the result has shape (10, 3).

import torch

torch.manual_seed(2024)

# create a 3D tensor of size (10, 3, 4)
tensor1 = torch.randn(10, 3, 4)

# create a 1D tensor of size 4
tensor2 = torch.randn(4)

# compute the batched matrix-vector product (2D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)

# print the result
print(result) 

Output:

tensor([[-1.9808,  2.2655,  0.5833],
        [ 0.2252,  0.4640,  0.1878],
        [ 2.9779, -0.3317,  5.9854],
        [ 0.8479,  2.8258,  0.2135],
        [ 1.5425, -0.1028,  1.0848],
        [ 0.2890,  0.2828,  1.4975],
        [-4.0526, -2.9466, -1.2563],
        [-4.8923, -0.9569, -5.5869],
        [ 0.2150,  3.3648,  1.5289],
        [ 3.4914,  3.9322,  1.6943]])
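
One way to convince yourself of what the batched call does is to compare it with an explicit loop over the batch. The sketch below assumes nothing beyond torch.mv() and torch.stack(); the names batch, vec, and looped are illustrative:

```python
import torch

batch = torch.randn(10, 3, 4)
vec = torch.randn(4)

# one call handles the whole batch
result = torch.matmul(batch, vec)

# the same computation, one matrix at a time
looped = torch.stack([torch.mv(mat, vec) for mat in batch])

print(result.shape)                                # torch.Size([10, 3])
print(torch.allclose(result, looped, atol=1e-5))
```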

Batched matrix x batched matrix

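This example shows how to compute the batched matrix-matrix product of two 3D tensors by making use of torch.matmul(). Here both tensors have the same batch size (10), so the matrices are multiplied pair by pair and the result has shape (10, 3, 5). If the batch dimensions differ, they are broadcast, which requires them to be broadcastable.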

import torch

torch.manual_seed(2023)

# create two 3D tensors of size (10, 3, 4) and (10, 4, 5)
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)

# compute the batched matrix-matrix product (3D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)

# print the result
print(result)

Output:

tensor([[[-1.7069e+00, -1.8534e+00, -2.1287e+00, -4.7003e-01,  1.0697e+00],
         [ 3.8941e+00,  9.4840e-01,  3.7984e+00, -9.8924e-02,  4.7634e-02],
         [-6.7208e-01,  2.0294e+00,  1.6539e+00,  4.7087e+00, -1.0837e+00]],

        [[ 4.5204e-02,  3.8919e+00, -5.8897e-03,  3.9597e+00,  1.6740e+00],
         [-2.6450e+00,  1.1118e+00, -7.4549e-01,  1.3357e+00, -1.5274e+00],
         [ 3.2368e+00, -1.7224e+00, -2.5382e+00,  1.9547e-01,  1.4201e+00]],

        [[-4.4614e-01,  2.1000e+00,  2.1122e-01,  4.0859e+00,  1.4184e+00],
         [ 4.5354e-01,  8.9127e-02,  3.6503e-01,  8.5806e-01, -1.3477e+00],
         [-2.4486e-01, -1.0420e+00, -6.5352e-01, -2.0812e+00,  4.1399e-01]],

        [[-6.1078e-01, -2.3098e+00,  1.0434e+00, -1.9856e+00, -2.2091e+00],
         [-3.2604e-02,  4.9101e-01, -2.5692e-01, -3.6945e-01, -3.3364e+00],
         [ 5.4829e-02, -3.6443e-01, -6.1170e-01, -8.1428e-01, -1.8848e+00]],

        [[-2.6326e+00, -3.3785e+00, -2.4368e+00,  1.2138e+00,  1.1739e+00],
         [ 3.8012e+00, -2.7200e-01,  2.4414e+00, -3.1762e+00, -2.8488e+00],
         [ 2.1198e+00, -1.0791e+00,  4.1354e-01,  7.2069e-01, -5.7256e-01]],

        [[-2.1189e+00, -1.2500e+00, -2.3749e+00,  2.1247e-02, -8.0348e-01],
         [-1.3381e+00, -3.2541e+00, -5.0417e+00, -4.2254e+00, -2.5890e+00],
         [ 1.3309e+00,  5.0646e-01, -1.8184e+00,  1.6543e+00, -1.1076e+00]],

        [[-1.3862e-01,  1.7765e+00,  3.9243e+00,  5.4390e+00,  3.1808e+00],
         [ 1.4334e+00,  3.4047e-01, -2.9210e+00, -3.4534e+00, -2.3460e+00],
         [-3.9575e+00,  4.9891e-02,  2.4687e+00,  1.8572e+00,  3.2704e+00]],

        [[ 7.9100e-01, -6.5746e-01, -1.8795e+00,  4.9998e-01, -6.6892e-01],
         [-3.0543e+00,  1.8478e+00,  4.5600e+00, -2.3450e+00, -1.3333e-01],
         [-5.4343e-01,  7.4177e-02,  3.8904e+00, -2.0984e+00, -2.4714e+00]],

        [[-1.0236e+00, -4.9922e+00,  7.3867e+00, -5.2357e+00, -1.7133e+00],
         [ 5.8022e-01, -3.0840e-01, -7.3925e-02, -5.5694e-02, -2.1630e+00],
         [ 4.7603e-01,  2.7360e+00, -6.6990e+00,  3.8872e+00, -2.0109e+00]],

        [[-1.6855e+00,  6.5613e-02,  5.2002e+00,  2.9535e+00, -1.8489e-01],
         [-2.5715e+00, -3.3931e+00,  3.4194e+00,  8.0750e-01, -2.9943e-01],
         [ 2.2488e-01, -2.7440e+00, -2.9225e+00, -2.6996e+00, -8.2288e-01]]])
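
The example above uses equal batch sizes. The sketch below, with shapes picked only for illustration, shows what happens when the batch dimensions differ: compatible ones are broadcast, and incompatible ones raise an error:

```python
import torch

# batch dimensions (2, 1) and (5,) broadcast to (2, 5)
a = torch.randn(2, 1, 3, 4)
b = torch.randn(5, 4, 6)
out = torch.matmul(a, b)
print(out.shape)  # torch.Size([2, 5, 3, 6])

# batch dimensions (10,) and (7,) cannot be broadcast
try:
    torch.matmul(torch.randn(10, 3, 4), torch.randn(7, 4, 5))
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True
print(broadcast_failed)
```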

Batched matrix x broadcasted matrix

The example below illustrates how to compute the batched matrix-matrix product of a 3D tensor and a 2D tensor by utilizing torch.matmul(). The 2D tensor is broadcast across the batch, so each of the 10 matrices of size (3, 4) is multiplied by the same (4, 5) matrix and the result has shape (10, 3, 5).

import torch

torch.manual_seed(2023)

# create a 3D tensor of size (10, 3, 4)
tensor1 = torch.randn(10, 3, 4)

# create a 2D tensor of size (4, 5)
tensor2 = torch.randn(4, 5)

# compute the batched matrix-matrix product (3D tensor) using torch.matmul()
result = torch.matmul(tensor1, tensor2)

# print the result
print(result)

Output:

tensor([[[ 3.5080e-02, -2.4184e+00, -1.2028e+00,  1.4250e+00, -1.6410e+00],
         [-1.9470e-01,  2.1141e+00,  5.6419e-01, -7.0405e-01,  3.1504e+00],
         [ 2.7579e+00, -7.6057e-01,  1.4620e+00,  2.5192e+00, -7.7006e-01]],

        [[-4.5815e-01,  8.5896e-01,  8.3757e-01, -2.8229e+00,  2.2722e-01],
         [-4.4352e-01, -1.5412e+00, -8.6720e-01,  4.4711e-01, -2.9628e+00],
         [ 2.0432e+00,  3.9481e+00,  3.8507e+00, -2.5680e+00,  4.0710e+00]],

        [[ 9.5828e-01,  5.6027e-01,  1.7352e+00, -1.5257e+00, -1.4508e+00],
         [-1.4288e+00, -2.6703e+00, -2.4472e+00,  1.4732e+00, -3.9988e+00],
         [ 5.9590e-01,  1.7817e+00,  1.0970e+00, -6.0984e-01,  3.5765e+00]],

        [[-1.7303e+00,  1.8793e+00, -6.9760e-01, -5.9631e-01,  9.1063e-02],
         [-1.1918e+00,  2.8247e+00,  6.6636e-01, -2.5371e+00,  1.6771e+00],
         [-1.7941e+00,  8.0749e-01, -1.0960e+00, -6.4452e-01, -3.6994e-01]],

        [[ 2.1279e+00, -4.7671e+00, -3.7989e-01,  2.4091e+00, -2.8382e+00],
         [-1.3139e+00,  3.8104e+00,  1.6157e+00, -4.3324e+00,  7.1359e-01],
         [-1.7975e+00,  5.6656e-04, -1.3343e+00, -9.2224e-01, -2.5261e-02]],

        [[ 2.6325e+00, -2.0290e+00,  1.2190e+00,  1.0536e+00,  9.9029e-01],
         [ 2.7471e+00,  2.9210e-01,  2.1788e+00,  6.4283e-02,  4.8397e+00],
         [ 8.6415e-01, -2.7562e+00, -8.5525e-01,  2.2541e+00, -1.5677e+00]],

        [[-3.5942e+00,  2.8794e+00, -1.8659e+00, -1.4797e+00,  1.5315e+00],
         [ 4.3266e+00, -1.4798e+00,  2.7324e+00,  1.4158e+00,  2.3671e+00],
         [-3.5717e+00, -3.1569e+00, -4.7173e+00,  1.7454e+00, -3.6640e+00]],

        [[-4.6379e-01,  1.5428e+00,  6.8418e-01, -1.7160e+00,  5.0970e-01],
         [-4.2143e-01, -3.1792e+00, -2.4101e+00,  3.5827e+00, -4.2032e+00],
         [-2.1132e+00, -5.7218e-01, -2.4433e+00,  1.4951e+00, -2.4490e+00]],

        [[ 1.4307e+00, -6.5002e+00, -2.2543e+00,  4.7914e+00, -5.2939e+00],
         [-1.1652e+00, -3.6335e-01, -1.0990e+00,  3.9657e-01, -2.7056e+00],
         [-1.1471e+00,  1.1457e+00, -1.0032e+00,  7.5988e-01,  1.3065e+00]],

        [[-2.2915e-01, -3.8629e+00, -2.3769e+00,  2.5113e+00, -1.3395e+00],
         [ 2.5257e+00, -3.6131e+00, -4.7740e-02,  2.9718e+00,  1.9309e-01],
         [ 2.2912e+00,  3.0198e+00,  3.6501e+00, -2.0257e+00,  2.9431e+00]]])
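
To make the broadcasting concrete, the batched call can be compared with two equivalent formulations: a loop over the batch, and a single 2D product on the flattened batch. This is a sketch with illustrative names (batch, weight, looped, flattened), not the way torch.matmul() is necessarily implemented internally:

```python
import torch

batch = torch.randn(10, 3, 4)
weight = torch.randn(4, 5)

# one call: weight is broadcast across the 10 matrices
result = torch.matmul(batch, weight)

# explicit loop: multiply every matrix by the same weight
looped = torch.stack([torch.mm(mat, weight) for mat in batch])

# flatten the batch into (30, 4), multiply once, restore the batch shape
flattened = torch.mm(batch.reshape(-1, 4), weight).reshape(10, 3, 5)

print(result.shape)                                   # torch.Size([10, 3, 5])
print(torch.allclose(result, looped, atol=1e-5))
print(torch.allclose(result, flattened, atol=1e-5))
```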