Sling Academy

Using the torch.prod() and torch.cumprod() functions in PyTorch

Last updated: July 08, 2023

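The torch.prod() and torch.cumprod() functions in PyTorch do two related jobs. torch.prod() computes the product of a tensor's elements, either all of them or along one dimension. torch.cumprod() computes the cumulative (running) product along a given dimension.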

torch.prod()

Syntax:

torch.prod(input, *, dtype=None) -> Tensor
torch.prod(input, dim, keepdim=False, *, dtype=None) -> Tensor

Where:

  • input: the input tensor
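  • dim: the single dimension (an int) to reduce. It is only used in the second form. If it is omitted, all elements are multiplied together and a 0-dimensional tensor is returned.
  • keepdim: whether the output keeps the reduced dimension with size 1 (default False). It is only used together with dim.
  • dtype: the desired data type of the output tensor (optional). If it is specified, the input is cast to dtype before the product is computed, which helps prevent overflow.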
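Products grow quickly, so the dtype argument matters in practice. The following minimal sketch uses a hypothetical tensor of forty 10.0 values. Their exact product is 1e40, which exceeds the largest float32 value (about 3.4e38). Passing dtype=torch.float64 casts the input first, so the result stays finite.

```python
import torch

# 40 copies of 10.0 (float32 by default); the exact product 1e40 exceeds float32's max
values = torch.full((40,), 10.0)

p32 = torch.prod(values)                       # computed in float32 -> overflows to inf
p64 = torch.prod(values, dtype=torch.float64)  # input cast to float64 before multiplying

print(p32)        # tensor(inf)
print(p64.dtype)  # torch.float64
```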

Example:

import torch

# Create a 2D tensor of size 3x4
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# Calculate the product of all elements in x
p = torch.prod(x)
print(p)  # tensor(479001600)

# Calculate the product of each row in x
p_row = torch.prod(x, dim=1)
print(p_row)  # tensor([   24,  1680, 11880])

# Calculate the product of each column in x and keep the dimension
p_col = torch.prod(x, dim=0, keepdim=True)
print(p_col)  # tensor([[ 45, 120, 231, 384]])
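In the documented signature, dim is a single int. To get one product per "block" over several dimensions, you can either chain two prod() calls or flatten the trailing dimensions and reduce once. The following is a minimal sketch of both approaches. The 3D tensor x3 is a hypothetical example.

```python
import torch

# Hypothetical tensor of shape (2, 2, 2): two 2x2 blocks
x3 = torch.tensor([[[1, 2], [3, 4]],
                   [[5, 6], [7, 8]]])

# Option 1: chain single-dimension reductions (reduce dim 2, then dim 1)
per_block = torch.prod(torch.prod(x3, dim=2), dim=1)
print(per_block)  # tensor([  24, 1680])

# Option 2: flatten everything after dim 0, then reduce that one dimension
per_block_flat = x3.flatten(start_dim=1).prod(dim=1)
print(torch.equal(per_block, per_block_flat))  # True
```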

torch.cumprod()

Syntax:

torch.cumprod(input, dim, *, dtype=None, out=None) -> Tensor

Where:

  • input: the input tensor
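  • dim: the dimension along which to compute the running product (required). The output has the same shape as input.
  • dtype: the desired data type of the output tensor (optional). If it is specified, the input is cast to dtype before the operation, which helps prevent overflow.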
  • out: the output tensor (optional)
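Along a 1D tensor, element i of the result is the product of elements 0 through i. The sketch below spells this out with a plain-Python loop. The helper cumprod_reference is hypothetical and only meant for illustration. It then checks the loop's result against torch.cumprod().

```python
import torch

def cumprod_reference(values):
    """Hypothetical helper: running product over a Python list (illustration only)."""
    out = []
    running = 1
    for v in values:
        running *= v          # multiply in the next element
        out.append(running)   # record the product so far
    return out

y = torch.tensor([1, 2, 3, 4, 5])
expected = cumprod_reference(y.tolist())  # [1, 2, 6, 24, 120]
actual = torch.cumprod(y, dim=0).tolist()
print(expected == actual)  # True
```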

Example:

import torch

# Create a 1D tensor of size 5
y = torch.tensor([1, 2, 3, 4, 5])

# Calculate the cumulative product of y along dimension 0
c = torch.cumprod(y, dim=0)
print(c)  
# tensor([  1,   2,   6,  24, 120])

# Create a 2D tensor of size 2x3
z = torch.tensor([[1, -2, -3],
                  [4, -5, -6]])

# Calculate the cumulative product of z along dimension 1
c_row = torch.cumprod(z, dim=1)
print(c_row)  
# tensor([[  1,  -2,   6],
#         [  4, -20, 120]])
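A common use of cumprod() is compounding. The sketch below assumes hypothetical per-period growth rates. It shows that the last element of the cumulative product equals the overall torch.prod() result. It also builds an "exclusive" variant, which is the factor in effect before each period. That variant is made by shifting the result right and prepending 1.

```python
import torch

# Hypothetical per-period growth rates: +10%, -5%, +20%
rates = torch.tensor([0.10, -0.05, 0.20])

# Running growth factor after each period: 1.1, 1.1*0.95, 1.1*0.95*1.2
growth = torch.cumprod(1 + rates, dim=0)

# The last running product equals the overall product (up to float rounding)
total = torch.prod(1 + rates)
print(torch.allclose(growth[-1], total))  # True

# Exclusive variant: prepend 1 and drop the last element
exclusive = torch.cat([torch.ones(1), growth[:-1]])
```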

That’s it. Happy coding & enjoy your day!


Series: Working with Tensors in PyTorch
