PyTorch: How to Find the Min and Max in a Tensor

Updated: July 8, 2023 By: Wolf

torch.min() and torch.max()

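In PyTorch, you can use the built-in functions torch.min() and torch.max() to find the minimum and maximum values of a whole tensor or along a given dimension. When called with only a tensor, each returns a 0-dimensional tensor holding the single minimum or maximum value. When called with a dim argument, each returns a named tuple (values, indices) that holds the minimum or maximum along that dimension and the positions where it occurs.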

Example:

import torch

# set the seed for generating random numbers
torch.manual_seed(2023) 

# create a random tensor of shape (3, 4)
a = torch.randn(3, 4) 
print(a)
# tensor([[-1.2075,  0.5493, -0.3856,  0.6910],
#         [-0.7424,  0.1570,  0.0721,  1.1055],
#         [ 0.2218, -0.0794, -1.0846, -1.5421]])

# find the minimum value of the whole tensor
min_a = torch.min(a) 
print(min_a)
# tensor(-1.5421)

# find the maximum value of the whole tensor
max_a = torch.max(a) 
print(max_a)
# tensor(1.1055)

# find the minimum value of each row
min_a_row = torch.min(a, dim=1) 
print(min_a_row)
# torch.return_types.min(
#   values=tensor([-1.2075, -0.7424, -1.5421]),
#   indices=tensor([0, 0, 3])
# )

# find the maximum value of each row
max_a_row = torch.max(a, dim=1) 
print(max_a_row)
# torch.return_types.max(
#   values=tensor([0.6910, 1.1055, 0.2218]),
#   indices=tensor([3, 3, 0])
# )
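The example above reduces only along dim=1 (one result per row). The following is a minimal sketch on a small hand-written tensor, so the output does not depend on the random seed. It reduces along dim=0 (one result per column), unpacks the returned named tuple, and passes keepdim=True so the reduced dimension is kept with size 1. The tensor values and variable names are illustrative assumptions, not part of the original example.

```python
import torch

# A small, fixed tensor so the results are easy to check by hand
t = torch.tensor([[3.0, 7.0, 1.0],
                  [5.0, 2.0, 9.0]])

# Reduce along dim=0: one minimum per column
col_min = torch.min(t, dim=0)
print(col_min.values)   # tensor([3., 2., 1.])
print(col_min.indices)  # tensor([0, 1, 0])

# The result is a named tuple, so it can be unpacked directly
values, indices = torch.max(t, dim=0)
print(values)   # tensor([5., 7., 9.])
print(indices)  # tensor([1, 0, 1])

# keepdim=True keeps the reduced dimension with size 1
row_max = torch.max(t, dim=1, keepdim=True)
print(row_max.values.shape)  # torch.Size([2, 1])

# The tensor method gives the same result as the torch.max() function
print(t.max())  # tensor(9.)
```

Keeping the reduced dimension is handy when you want to broadcast the result back against the original tensor, for example t - row_max.values.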

torch.argmin() and torch.argmax()

PyTorch also provides the functions torch.argmin() and torch.argmax() to get the indices of the minimum and maximum values of a tensor, respectively. They are similar to torch.min() and torch.max(), but they return only the indices, not the values. They can operate on the whole tensor or along a specific dimension: without dim, they return a single index into the flattened tensor; with dim, they return a tensor of indices along that dimension.

A code example is worth more than a thousand words:

import torch

# set the seed for generating random numbers
torch.manual_seed(2023)

# create a random tensor of shape (3, 4)
a = torch.randn(3, 4)
print(a)
# tensor([[-1.2075,  0.5493, -0.3856,  0.6910],
#         [-0.7424,  0.1570,  0.0721,  1.1055],
#         [ 0.2218, -0.0794, -1.0846, -1.5421]])

# find the index of the minimum value of the whole tensor
argmin_a = torch.argmin(a)
print(argmin_a)
# tensor(11)

# find the index of the maximum value of the whole tensor
argmax_a = torch.argmax(a)
print(argmax_a)
# tensor(7)

# find the index of the minimum value of each row
argmin_a_row = torch.argmin(a, dim=1)
print(argmin_a_row)
# tensor([0, 0, 3])

# find the index of the maximum value of each row
argmax_a_row = torch.argmax(a, dim=1)
print(argmax_a_row)
# tensor([3, 3, 0])
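Without dim, the index refers to the flattened tensor, so tensor(11) above points to row 2, column 3 of the 3x4 tensor (11 // 4 = 2, 11 % 4 = 3), which holds -1.5421. The following is a minimal sketch for converting such a flat index back into (row, column) coordinates with integer division and remainder by the number of columns. It assumes a 2-D tensor, and the helper name flat_to_row_col is hypothetical. It also checks that argmax along a dimension agrees with the indices returned by torch.max().

```python
import torch

def flat_to_row_col(flat_index, num_cols):
    # Hypothetical helper: map an index into a flattened 2-D tensor
    # back to (row, column) using integer division and remainder
    return divmod(int(flat_index), num_cols)

t = torch.tensor([[3.0, 7.0, 1.0],
                  [5.0, 2.0, 9.0]])

flat_max = torch.argmax(t)   # index into the flattened tensor
print(flat_max)              # tensor(5)

row, col = flat_to_row_col(flat_max, t.shape[1])
print(row, col)              # 1 2
print(t[row, col])           # tensor(9.)

# Along a dimension, argmax returns the same indices as torch.max(...).indices
print(torch.argmax(t, dim=1))       # tensor([1, 2])
print(torch.max(t, dim=1).indices)  # tensor([1, 2])
```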