Sling Academy

PyTorch: Find the Sum and Mean of a Tensor

Last updated: July 08, 2023

In PyTorch, you can find the sum and mean of a tensor with the torch.sum() and torch.mean() functions, respectively. Both can reduce the whole tensor or only a specific dimension, which you select with the dim argument. Reducing the whole tensor returns a 0-dimensional tensor that holds a single value. Reducing along a dimension returns a tensor of values, one per remaining position.

Example:

```python
import torch

# set the seed for generating random numbers
torch.manual_seed(2023)

# create a random tensor of shape (3, 4)
a = torch.randn(3, 4)
print(a)
# tensor([[-1.2075,  0.5493, -0.3856,  0.6910],
#         [-0.7424,  0.1570,  0.0721,  1.1055],
#         [ 0.2218, -0.0794, -1.0846, -1.5421]])

# find the sum of all elements in the tensor
sum_a = torch.sum(a)
print(sum_a)
# tensor(-2.2448)

# find the mean of all elements in the tensor
mean_a = torch.mean(a)
print(mean_a)
# tensor(-0.1871)

# find the sum of each row in the tensor
sum_a_row = torch.sum(a, dim=1)
print(sum_a_row)
# tensor([-0.3528,  0.5922, -2.4843])

# find the mean of each row in the tensor
mean_a_row = torch.mean(a, dim=1)
print(mean_a_row)
# tensor([-0.0882,  0.1480, -0.6211])
```
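
The dim argument is not limited to rows. dim=0 reduces over rows to give one value per column, a tuple such as dim=(0, 1) reduces over several dimensions at once, and keepdim=True keeps the reduced dimension with size 1. That last option is handy for broadcasting the result back against the original tensor. A whole-tensor reduction is a 0-dimensional tensor, and .item() turns it into a plain Python number. The minimal sketch below uses a small hand-written tensor, which is not from the original article, so you can check every result by hand:

```python
import torch

# a small tensor whose sums and means are easy to verify by hand
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# dim=0 reduces over the rows, giving one value per column
col_sum = torch.sum(x, dim=0)
print(col_sum)  # tensor([5., 7., 9.])

# keepdim=True keeps the reduced dimension with size 1: shape (2, 1)
row_mean = torch.mean(x, dim=1, keepdim=True)
print(row_mean.shape)  # torch.Size([2, 1])

# a tuple of dims reduces over all of them at once
total = torch.sum(x, dim=(0, 1))
print(total)  # tensor(21.)

# a full reduction is a 0-dim tensor; .item() returns a Python float
print(torch.mean(x).dim())   # 0
print(torch.mean(x).item())  # 3.5
```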

You can also pass a dtype argument to these functions. The input tensor is cast to that data type before the operation runs, so the result has that data type as well.

Example:

```python
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

# calculate sum of all elements in tensor
sum_float = torch.sum(t, dtype=torch.float32)
print(sum_float)
# tensor(21.)

# calculate mean of all elements in tensor
mean_float = torch.mean(t, dtype=torch.float32)
print(mean_float)
# tensor(3.5000)
```
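
For torch.mean(), dtype is more than a convenience. The function only works on floating-point or complex inputs, so calling it on an integer tensor like t without dtype raises a RuntimeError. torch.sum() accepts integers and keeps an integer result. The sketch below compares the options. The exact error message may vary between PyTorch versions, so it is only printed here, not reproduced:

```python
import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])  # integer tensor (dtype torch.int64)

# sum works on integers and keeps an integer result
print(torch.sum(t))        # tensor(21)
print(torch.sum(t).dtype)  # torch.int64

# mean on an integer tensor without dtype fails
try:
    torch.mean(t)
except RuntimeError as err:
    print("RuntimeError:", err)

# two ways to get a floating-point mean
print(torch.mean(t, dtype=torch.float32))  # tensor(3.5000)
print(t.float().mean())                    # tensor(3.5000)
```

Passing dtype keeps the call in one step. Converting with .float() first is equivalent here and may read more clearly when you need the converted tensor again.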

You can also use the torch.Tensor.sum() and torch.Tensor.mean() methods to get the job done. These methods are equivalent to the module-level functions torch.sum() and torch.mean(). They accept the same arguments, except for the input tensor, which is the object the method is called on, and they return the same results. The only difference is how they are invoked: a.sum(dim=1) instead of torch.sum(a, dim=1).
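
A quick check that the method form gives the same results as the functions, reusing the seeded random tensor from the first example. Method calls also chain naturally. The absolute-value chain at the end is only an illustration, not something from the original article:

```python
import torch

torch.manual_seed(2023)
a = torch.randn(3, 4)

# method calls on the tensor itself ...
row_sums = a.sum(dim=1)
row_means = a.mean(dim=1)

# ... give exactly the same results as the module-level functions
print(torch.equal(row_sums, torch.sum(a, dim=1)))    # True
print(torch.equal(row_means, torch.mean(a, dim=1)))  # True

# methods chain naturally, e.g. the mean of the absolute values
print(a.abs().mean())
```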
