Stacking Tensors in PyTorch: Tutorials & Examples

Updated: July 22, 2023 By: Frienzied Flame

This concise, practical article is about stacking tensors in PyTorch with the torch.stack(), torch.vstack(), and torch.hstack() functions.

torch.stack()

Syntax & Parameters

torch.stack() is a PyTorch function that joins a sequence of tensors along a new dimension. It inserts a new dimension at the requested position and concatenates the tensors along it, so the result has one more dimension than each input. All input tensors must have the same shape.

Syntax:

torch.stack(tensors, dim=0, *, out=None) -> Tensor

Where:

  • tensors: a sequence (such as a list or tuple) of tensors to stack. All of them must have the same shape.
  • dim: an integer that specifies where to insert the new dimension. It must be between 0 and the number of dimensions of the input tensors (inclusive); negative values count from the end, so dim=-1 places the new dimension last.
  • out: an optional tensor to write the result into. If it is omitted, a new tensor is allocated and returned.

The function returns the stacked tensor. Its size along the new dimension dim equals the number of input tensors.
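One way to see what "a new dimension" means: stacking behaves like unsqueezing every input at dim and then joining the results with torch.cat(). The minimal sketch below checks that equivalence for every valid non-negative dim of two small 2-D tensors (built with torch.arange() purely for illustration), and also shows that inputs with different shapes are rejected with a RuntimeError (the exact message may vary between PyTorch versions).

```python
import torch

a = torch.arange(12).reshape(3, 4)
b = torch.arange(12, 24).reshape(3, 4)

shapes = []
for dim in range(a.dim() + 1):  # valid non-negative dims for 2-D inputs: 0, 1, 2
    stacked = torch.stack([a, b], dim=dim)
    # Insert a size-1 axis at `dim` in each input, then concatenate along that axis
    manual = torch.cat([a.unsqueeze(dim), b.unsqueeze(dim)], dim=dim)
    assert torch.equal(stacked, manual)
    shapes.append(tuple(stacked.shape))

print(shapes)  # [(2, 3, 4), (3, 2, 4), (3, 4, 2)]

# Inputs whose shapes differ are rejected
raised = False
try:
    torch.stack([a, torch.zeros(3, 5, dtype=a.dtype)])
except RuntimeError:
    raised = True
print(raised)  # True
```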

Example

Let’s say you have two tensors a and b with shape (3, 4):

import torch

a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
b = torch.tensor([[13, 14, 15, 16],
                  [17, 18, 19, 20],
                  [21, 22, 23, 24]])

You can stack them along the first dimension (dim=0) to get a tensor c with shape (2, 3, 4):

c = torch.stack([a, b], dim=0)
print(c)
# tensor([[[ 1,  2,  3,  4],
#          [ 5,  6,  7,  8],
#          [ 9, 10, 11, 12]],
#
#         [[13, 14, 15, 16],
#          [17, 18, 19, 20],
#          [21, 22, 23, 24]]])

You can see that the tensors a and b are concatenated along a new dimension at the beginning of the tensor c. The first slice of c along dim=0 is equal to a, and the second slice is equal to b. You can also stack them along the second dimension (dim=1) to get a tensor d with shape (3, 2, 4):

d = torch.stack([a, b], dim=1)
print(d)
# tensor([[[ 1,  2,  3,  4],
#          [13, 14, 15, 16]],
#
#         [[ 5,  6,  7,  8],
#          [17, 18, 19, 20]],
#
#         [[ 9, 10, 11, 12],
#          [21, 22, 23, 24]]])

This time the new dimension sits in the middle of d: the slice d[:, 0] is equal to a, and d[:, 1] is equal to b.
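The slicing claims above can be checked directly. This short sketch reuses the same a and b, confirms that indexing the new dimension recovers the inputs, and also tries dim=-1 (equivalent to dim=2 for 2-D inputs), which places the new dimension last so that each element of a is paired with the matching element of b.

```python
import torch

a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
b = torch.tensor([[13, 14, 15, 16],
                  [17, 18, 19, 20],
                  [21, 22, 23, 24]])

c = torch.stack([a, b], dim=0)
d = torch.stack([a, b], dim=1)

# The inputs can be recovered by indexing the new dimension
assert torch.equal(c[0], a) and torch.equal(c[1], b)
assert torch.equal(d[:, 0], a) and torch.equal(d[:, 1], b)

# dim=-1 is the same as dim=2 here: the new axis goes last
e = torch.stack([a, b], dim=-1)
assert torch.equal(e, torch.stack([a, b], dim=2))
print(e.shape)  # torch.Size([3, 4, 2])
print(e[0, 0])  # tensor([ 1, 13])
```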

torch.vstack()

Syntax & Parameters

The torch.vstack() function stacks tensors in sequence vertically (row-wise). This means that it concatenates tensors along the first axis, the dimension that represents the rows of a matrix. For example, if you have two tensors A and B, each with shape (3, 4), then torch.vstack((A, B)) returns a tensor with shape (6, 4), where the first three rows come from A and the last three rows come from B.

Syntax:

torch.vstack(tensors, *, out=None) -> Tensor

Where:

  • tensors: a sequence of tensors to be stacked vertically. Inputs with fewer than two dimensions are first reshaped with torch.atleast_2d() (a 1-D tensor of shape (N,) becomes (1, N)); after that, all tensors must have the same shape in every dimension except the first.
  • out: an optional pre-allocated tensor to store the result. If out is None, a new tensor is allocated and returned.
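The parameter description suggests that torch.vstack() amounts to "promote to at least 2-D, then concatenate along dim 0". The sketch below expresses that with a hypothetical helper, vstack_like (not a PyTorch function), and compares it with the real call on inputs whose first dimensions differ: a (2, 3) matrix and a 1-D tensor that gets promoted to (1, 3).

```python
import torch

def vstack_like(tensors):
    # Hypothetical helper: promote every input to at least 2-D, then join along dim 0
    return torch.cat([torch.atleast_2d(t) for t in tensors], dim=0)

top = torch.tensor([[1, 2, 3],
                    [4, 5, 6]])  # shape (2, 3)
row = torch.tensor([7, 8, 9])    # shape (3,), promoted to (1, 3)

result = torch.vstack((top, row))
assert torch.equal(result, vstack_like((top, row)))
print(result.shape)  # torch.Size([3, 3])
```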

Examples

Here is an example of using torch.vstack() with two 1-D tensors:

import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.vstack((x, y))
print(z)

Output:

tensor([[1, 2, 3],
        [4, 5, 6]])

You can see that x and y are each reshaped into (1, 3) tensors and then stacked vertically to form a (2, 3) tensor.
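For 1-D inputs of equal length, the result above is the same as calling torch.stack() with dim=0. The two functions diverge once the inputs are already 2-D: torch.vstack() grows the existing first axis, while torch.stack() inserts a brand-new one. A small comparison:

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# For equal-length 1-D inputs, both calls produce the same (2, 3) tensor
assert torch.equal(torch.vstack((x, y)), torch.stack((x, y), dim=0))

# For 2-D inputs they differ
m = x.reshape(1, 3)
n = y.reshape(1, 3)
print(torch.vstack((m, n)).shape)  # torch.Size([2, 3])
print(torch.stack((m, n)).shape)   # torch.Size([2, 1, 3])
```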

Another example, this time with 3-D tensors:

import torch
a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])
print(a.shape)
# torch.Size([2, 2, 3])

b = torch.tensor([
    [[13, 14, 15], [16, 17, 18]],
    [[19, 20, 21], [22, 23, 24]]
])
print(b.shape)
# torch.Size([2, 2, 3])

c = torch.vstack([a, b])
print(c.shape)
# torch.Size([4, 2, 3])

print(c)
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6]],
#
#         [[ 7,  8,  9],
#          [10, 11, 12]],
#
#         [[13, 14, 15],
#          [16, 17, 18]],
#
#         [[19, 20, 21],
#          [22, 23, 24]]])
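Because a and b are already at least 2-D, no reshaping happens here, and torch.vstack() reduces to torch.cat() along dim 0. The sketch below rebuilds the same a and b with torch.arange() (an assumption made only to keep it short) and also shows that the first dimension is the only one allowed to differ: a (1, 2, 3) tensor can be appended to the (2, 2, 3) tensor.

```python
import torch

a = torch.arange(1, 13).reshape(2, 2, 3)   # same values as `a` above
b = torch.arange(13, 25).reshape(2, 2, 3)  # same values as `b` above

# For inputs that are already at least 2-D, vstack is concatenation along dim 0
assert torch.equal(torch.vstack([a, b]), torch.cat([a, b], dim=0))

# Only the first dimension may differ between inputs
extra = torch.zeros(1, 2, 3, dtype=a.dtype)
grown = torch.vstack([a, extra])
print(grown.shape)  # torch.Size([3, 2, 3])
```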

torch.hstack()

Syntax & Parameters

torch.hstack() stacks tensors in sequence horizontally (column-wise). It is equivalent to concatenation along the first axis for 1-D tensors and along the second axis for all other tensors.

Syntax:

torch.hstack(tensors, *, out=None) -> Tensor

Parameters:

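  • tensors: a sequence of tensors to concatenate. They must have the same shape in every dimension except the one being concatenated along (dim 0 for 1-D inputs, dim 1 otherwise).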
  • out: an optional output tensor to store the result.
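The rule "first axis for 1-D tensors, second axis otherwise" can be written out as a hypothetical helper, hstack_like (not part of PyTorch), that picks the concatenation dim from the first input. This sketch ignores 0-D inputs for simplicity and checks the helper against torch.hstack() for both 1-D and 2-D inputs.

```python
import torch

def hstack_like(tensors):
    # Hypothetical helper mirroring the documented rule:
    # 1-D inputs are joined along dim 0, everything else along dim 1
    dim = 0 if tensors[0].dim() == 1 else 1
    return torch.cat(tensors, dim=dim)

vectors = (torch.tensor([1, 2]), torch.tensor([3, 4]))
matrices = (torch.ones(2, 1), torch.zeros(2, 3))

assert torch.equal(torch.hstack(vectors), hstack_like(vectors))
assert torch.equal(torch.hstack(matrices), hstack_like(matrices))
print(torch.hstack(matrices).shape)  # torch.Size([2, 4])
```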

Examples

Using torch.hstack() with 1-D tensors:

import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

z = torch.hstack((x, y))
print(z)
# tensor([1, 2, 3, 4, 5, 6])
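Note that the result stays 1-D: torch.hstack() simply places y after x. Compare this with torch.stack() along dim=1, which pairs up elements into a new axis. Since hstack only concatenates, 1-D inputs of different lengths are also accepted, which torch.stack() would reject.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# hstack keeps 1-D inputs 1-D
print(torch.hstack((x, y)).shape)  # torch.Size([6])

# stack with dim=1 instead builds a (3, 2) tensor of pairs
print(torch.stack((x, y), dim=1))
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

# Unlike stack, hstack accepts 1-D inputs of different lengths
print(torch.hstack((x, torch.tensor([7]))))  # tensor([1, 2, 3, 7])
```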

Another example, using the same 3-D tensors as in the torch.vstack() section:

import torch
a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])
print(a.shape)
# torch.Size([2, 2, 3])

b = torch.tensor([
    [[13, 14, 15], [16, 17, 18]],
    [[19, 20, 21], [22, 23, 24]]
])
print(b.shape)
# torch.Size([2, 2, 3])

c = torch.hstack((a, b))
print(c.shape)
# torch.Size([2, 4, 3])

print(c)
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [13, 14, 15],
#          [16, 17, 18]],
#
#         [[ 7,  8,  9],
#          [10, 11, 12],
#          [19, 20, 21],
#          [22, 23, 24]]])
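The examples above cover 1-D and 3-D inputs; the everyday matrix case is where "column-wise" is easiest to see. In this sketch (the tensor names are illustrative only), a column of zeros is appended to the right of a (3, 2) matrix. The row counts must match, and the result equals torch.cat() along dim 1.

```python
import torch

features = torch.tensor([[1, 2],
                         [3, 4],
                         [5, 6]])         # shape (3, 2)
zero_col = torch.tensor([[0], [0], [0]])  # shape (3, 1)

# Columns are appended on the right; dim 0 (the rows) must match
combined = torch.hstack((features, zero_col))
assert torch.equal(combined, torch.cat((features, zero_col), dim=1))
print(combined)
# tensor([[1, 2, 0],
#         [3, 4, 0],
#         [5, 6, 0]])
print(combined.shape)  # torch.Size([3, 3])
```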