PyTorch RuntimeError: mean(): could not infer output dtype

Updated: July 8, 2023 By: Khue

When working with PyTorch and using the torch.mean() function (or the torch.Tensor.mean() method), you might encounter the following error:

RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long

This error means that you are calling torch.mean() on a tensor whose data type is Long (torch.int64), which is an integer type. torch.mean() only accepts floating point or complex inputs, because the mean of integers is generally not an integer (the mean of 1 and 2, for example, is 1.5). By default the function returns a result with the same data type as its input, so with an integer input it cannot infer a valid output data type.
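To see where the "Long" in the message comes from, here is a minimal sketch that reproduces the error. It assumes a recent PyTorch release; the exact wording of the message may vary between versions, so the check only looks for the common phrase.

```python
import torch

# Python int literals give an int64 ("Long") tensor by default
t = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(t.dtype)  # torch.int64

message = None
try:
    t.mean()  # integer input: no floating point output dtype to infer
except RuntimeError as err:
    message = str(err)

print(message)
```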

To fix the error, either convert your input tensor to a floating point or complex data type, or pass the dtype argument to torch.mean() to set the data type of the output tensor (the input is cast to that type before the mean is computed), like this:

import torch

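# Integer literals give this tensor the dtype torch.int64 (Long)
t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

# dtype=torch.float32 casts the input before averaging; prints tensor(3.5000)
mean = torch.mean(t, dtype=torch.float32)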
print(mean)
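The same dtype argument also works when you take the mean along a single dimension. The sketch below uses dim=1 (one mean per row) purely for illustration, on the same Long tensor as above.

```python
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

# dim=1 averages across each row; dtype casts the Long input to float32 first
row_means = torch.mean(t, dim=1, dtype=torch.float32)
print(row_means)        # tensor([2., 5.])
print(row_means.dtype)  # torch.float32
```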

If you prefer to call the torch.Tensor.mean() method on your tensor object instead, do it like this:

import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

mean = t.mean(dtype=torch.float32)
print(mean)
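The other fix is to change the data type of the tensor itself. Here is a minimal sketch of two common ways: create the tensor as floating point from the start (via the dtype argument of torch.tensor()), or convert an existing integer tensor with .float() or .to(). The choice of float32 versus float64 here is only an example; pick whichever precision your code needs.

```python
import torch

# Option 1: create the tensor with a floating point dtype from the start
a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
print(a.mean())  # tensor(3.5000)

# Option 2: convert an existing Long tensor before calling mean()
t = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = t.float().mean()            # .float() is shorthand for .to(torch.float32)
c = t.to(torch.float64).mean()  # float64 if you need more precision
print(b, c)
```

Both .float() and .to() return a new tensor here, so t itself keeps its Long dtype.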

That’s it. Happy coding & have a nice day!