PyTorch: Selecting Elements from a Tensor (3 Ways)

Updated: July 23, 2023 By: Khue

This pithy, straightforward article will walk you through three different ways to select elements from a tensor in PyTorch. Without any further ado, let’s get started!

Indexing & Slicing

You can use the square brackets [ ] to index a tensor by specifying the position of the elements you want to select. The example below selects the element at row 1 and column 2 from a tensor of shape (3,4):

import torch 

# set a seed for reproducibility
torch.manual_seed(2023)

# create a random tensor of shape (3, 4)
x = torch.rand(3, 4)
print(x)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.5414, 0.9906, 0.4086, 0.2183],
#         [0.1834, 0.2852, 0.7813, 0.1048]])

# select the element at row 1 and column 2 (note that indexing starts at 0)
e = x[1, 2]
print(e)
# tensor(0.4086)
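The result above is a 0-dimensional tensor. If you need a plain Python number, Tensor.item() converts it, and supplying fewer indices than the tensor has dimensions selects a whole sub-tensor instead of a single element. The sketch below assumes torch.arange in place of torch.rand so that the values are predictable:

```python
import torch

# a deterministic (3, 4) tensor: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
x = torch.arange(12).reshape(3, 4)

e = x[1, 2]            # 0-dimensional tensor holding 6
print(e.dim())         # 0
print(e.item())        # 6 (a plain Python int)

row = x[1]             # a single index on a 2-D tensor selects the whole row
print(row.shape)       # torch.Size([4])
print(row.tolist())    # [4, 5, 6, 7]
```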

You can also use slicing, written with the colon (:) operator, to select a range of elements along a dimension. The start index is included and the stop index is excluded, so x[0:2, 1:3] selects the elements in rows 0 to 1 and columns 1 to 2:

import torch 

# set a seed for reproducibility
torch.manual_seed(2023)

# create a random tensor of shape (3, 4)
x = torch.rand(3, 4)
print(x)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.5414, 0.9906, 0.4086, 0.2183],
#         [0.1834, 0.2852, 0.7813, 0.1048]])

# select the elements from rows 0 to 1 and columns 1 to 2
selected_elements = x[0:2, 1:3]
print(selected_elements)
# tensor([[0.7201, 0.9481],
#         [0.9906, 0.4086]])
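A few more slicing variants may help: omitted bounds default to the start or end of the dimension, a step value skips elements, and a basic slice is a view that shares memory with the original tensor. This is a minimal sketch, again assuming torch.arange instead of random values:

```python
import torch

x = torch.arange(12).reshape(3, 4)

# rows 0 and 1, columns 1 and 2 (the stop index is exclusive)
print(x[0:2, 1:3].tolist())   # [[1, 2], [5, 6]]

# an omitted stop runs to the end of the dimension
print(x[:, 2:].tolist())      # [[2, 3], [6, 7], [10, 11]]

# a step of 2 keeps every other column (columns 0 and 2)
print(x[:, ::2].tolist())     # [[0, 2], [4, 6], [8, 10]]

# basic slices are views: writing through one changes x as well
sub = x[0:2, 1:3]
sub[0, 0] = 100
print(x[0, 1].item())         # 100
```

Because slices are views, call .clone() on the slice if you want to modify it without touching the original tensor.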

You can also use negative indices to count from the end of a dimension. For example, x[-1, -2] selects the element in the last row and the second-to-last column. You can also use a list or a tensor of indices to select multiple elements along a dimension. For instance, x[[0, 2], :] selects the first and third rows of x.
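Here is a short sketch of negative indices and index lists, assuming the same torch.arange tensor for predictable values. Unlike a basic slice, indexing with a list or tensor of indices returns a copy, so changing the result leaves x untouched:

```python
import torch

x = torch.arange(12).reshape(3, 4)

# last row, second-to-last column
print(x[-1, -2].item())          # 10

# a list of indices picks the first and third rows
print(x[[0, 2], :].tolist())     # [[0, 1, 2, 3], [8, 9, 10, 11]]

# a tensor of indices works the same way; here it reorders columns 3 and 0
cols = torch.tensor([3, 0])
print(x[:, cols].tolist())       # [[3, 0], [7, 4], [11, 8]]

# list/tensor indexing copies the data, so x is not modified
picked = x[[0, 2], :]
picked[0, 0] = -1
print(x[0, 0].item())            # 0
```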

Using the torch.select() function

You can use the torch.select() function to slice a tensor along a given dimension at a single index. It returns a view of the original tensor with that dimension removed, so the result has one fewer dimension than the input. The syntax of this function is:

torch.select(input, dim, index) -> Tensor

Where:

  • input: the input tensor that you want to select from.
  • dim: the dimension along which to slice (this dimension is removed from the result).
  • index: the position to select along that dimension.

For example, if you have a tensor x of shape (2, 3, 4), then you can use torch.select(x, 1, 2) to select the third row of each matrix in x and return a new tensor of shape (2, 4):

import torch 

# set a seed for reproducibility
torch.manual_seed(2023)

# create a random tensor of shape (2, 3, 4)
x = torch.rand(2, 3, 4)
print(x)
# tensor([[[0.4290, 0.7201, 0.9481, 0.4797],
#          [0.5414, 0.9906, 0.4086, 0.2183],
#          [0.1834, 0.2852, 0.7813, 0.1048]],

#         [[0.6550, 0.8375, 0.1823, 0.5239],
#          [0.2432, 0.9644, 0.5034, 0.0320],
#          [0.8316, 0.3807, 0.3539, 0.2114]]])

# Select the third row of each matrix in x
result = torch.select(x, 1, 2)
print(result)
# tensor([[0.1834, 0.2852, 0.7813, 0.1048],
#         [0.8316, 0.3807, 0.3539, 0.2114]])

The Tensor.select() method is equivalent to the torch.select() function, but you can call it directly on a Tensor object.
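Selecting with torch.select() is the same as indexing with a single integer in that dimension, for example x.select(1, 2) matches x[:, 2]. The sketch below checks this, assuming torch.arange instead of random values:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

a = torch.select(x, 1, 2)   # function form
b = x.select(1, 2)          # method form
c = x[:, 2]                 # equivalent indexing expression

print(a.shape)                                   # torch.Size([2, 4])
print(a.tolist())                                # [[8, 9, 10, 11], [20, 21, 22, 23]]
print(torch.equal(a, b) and torch.equal(a, c))   # True

# a negative index counts from the end, so -1 is the same as 2 for a size-3 dim
print(torch.equal(x.select(1, -1), a))           # True
```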

Using the torch.index_select() function

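You can use the torch.index_select() function (or the Tensor.index_select() method) to select entries at several indices along a single dimension of a tensor. It returns a new tensor with the same number of dimensions as the input: the size of dimension dim equals the number of indices, the other dimensions keep their original sizes, and the data is copied rather than shared with the input.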

Syntax:

torch.index_select(input, dim, index) -> Tensor

Parameters explained:

  • input: the input tensor that you want to select from.
  • dim: the dimension along which to select.
  • index: a 1-D tensor of integer indices (IntTensor or LongTensor) to select along dim.

Suppose you have a tensor x of shape (3, 4). You can use torch.index_select(x, 0, torch.tensor([0, 2])) to select the first and third rows of x, which returns a new tensor of shape (2, 4):

import torch

# set a seed for reproducibility
torch.manual_seed(2023)

# create a random tensor of shape (3, 4)
x = torch.rand(3, 4)
print(x)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.5414, 0.9906, 0.4086, 0.2183],
#         [0.1834, 0.2852, 0.7813, 0.1048]])

# Select the first and third rows of the tensor
result = torch.index_select(x, 0, torch.tensor([0, 2]))
print(result)
# tensor([[0.4290, 0.7201, 0.9481, 0.4797],
#         [0.1834, 0.2852, 0.7813, 0.1048]])
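torch.index_select() works along any dimension, and the indices may appear in any order or even repeat. The sketch below selects columns instead of rows (dim=1) and checks that the result matches the method form and the equivalent list-style indexing; torch.arange is assumed in place of random values:

```python
import torch

x = torch.arange(12).reshape(3, 4)
idx = torch.tensor([3, 1, 1])   # indices may be in any order and may repeat

cols = torch.index_select(x, 1, idx)   # pick columns 3, 1, 1
print(cols.shape)                      # torch.Size([3, 3])
print(cols.tolist())                   # [[3, 1, 1], [7, 5, 5], [11, 9, 9]]

# same result using the method form and tensor indexing
print(torch.equal(cols, x.index_select(1, idx)))   # True
print(torch.equal(cols, x[:, idx]))                 # True
```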

This tutorial ends here. Happy coding & have a nice day!