Sling Academy
Stacking Tensors in PyTorch: Tutorials & Examples

Last updated: July 22, 2023

This concise, practical article is about stacking tensors in PyTorch with the torch.stack(), torch.vstack(), and torch.hstack() functions.

torch.stack()

Syntax & Parameters

torch.stack() is a PyTorch function that joins a sequence of tensors along a new dimension. Unlike torch.cat(), which joins tensors along an existing dimension, it inserts a new dimension at the position given by dim, so the result has one more dimension than the inputs. All input tensors must have the same shape.

Syntax:

torch.stack(tensors, dim=0, *, out=None) -> Tensor

Where:

  • tensors: a sequence of tensors to stack. They must all have the same shape.
  • dim: an integer that specifies the position of the new dimension. It must be between 0 and the number of dimensions of the input tensors (inclusive). Negative values count from the end. The default is 0.
  • out: an optional, keyword-only tensor in which to store the output. It should have the shape and dtype of the expected result.

The function returns a tensor with one more dimension than the inputs, in which the input tensors are laid out along the new dimension.
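As a quick check of the constraints listed above, the following sketch stacks two placeholder tensors along every valid dim and then tries to stack tensors of different shapes. The tensor values and variable names are illustrative only.

```python
import torch

a = torch.zeros(3, 4)
b = torch.zeros(3, 4)

# For 2-D inputs, dim may be 0, 1 or 2; negative values count from the end.
shapes = {dim: tuple(torch.stack([a, b], dim=dim).shape) for dim in (0, 1, 2, -1)}
print(shapes)
# {0: (2, 3, 4), 1: (3, 2, 4), 2: (3, 4, 2), -1: (3, 4, 2)}

# Tensors with different shapes cannot be stacked.
mismatched = torch.zeros(3, 5)
try:
    torch.stack([a, mismatched])
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("stack failed:", err)
```

The size of the new dimension always equals the number of tensors in the sequence, which is 2 here.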

Example

Let’s say you have two tensors a and b with shape (3, 4):

import torch

a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
b = torch.tensor([[13, 14, 15, 16],
                  [17, 18, 19, 20],
                  [21, 22, 23, 24]])

You can stack them along the first dimension (dim=0) to get a tensor c with shape (2, 3, 4):

c = torch.stack([a, b], dim=0)
print(c)
# tensor([[[ 1,  2,  3,  4],
#          [ 5,  6,  7,  8],
#          [ 9, 10, 11, 12]],
#
#         [[13, 14, 15, 16],
#          [17, 18, 19, 20],
#          [21, 22, 23, 24]]])

You can see that the tensors a and b are joined along a new dimension at the beginning of the tensor c: c[0] is equal to a, and c[1] is equal to b. You can also stack them along the second dimension (dim=1) to get a tensor d with shape (3, 2, 4):

d = torch.stack([a, b], dim=1)
print(d)
# tensor([[[ 1,  2,  3,  4],
#          [13, 14, 15, 16]],
#
#         [[ 5,  6,  7,  8],
#          [17, 18, 19, 20]],
#
#         [[ 9, 10, 11, 12],
#          [21, 22, 23, 24]]])

You can see that the tensors a and b are joined along a new dimension in the middle of the tensor d. Each slice of d along dim=0 pairs one row of a with the matching row of b, so d[:, 0] is equal to a and d[:, 1] is equal to b.
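The relationship between the inputs and the stacked result can be checked directly. The sketch below rebuilds a and b with torch.arange() (same values as in the example) and also shows that stacking gives the same result as adding a size-1 dimension to each input with unsqueeze() and then calling torch.cat().

```python
import torch

a = torch.arange(1, 13).reshape(3, 4)    # same values as a above
b = torch.arange(13, 25).reshape(3, 4)   # same values as b above

d = torch.stack([a, b], dim=1)

# Indexing the new dimension recovers the inputs.
recovers_inputs = torch.equal(d[:, 0], a) and torch.equal(d[:, 1], b)

# stack == add a size-1 dimension to each input, then concatenate along it.
d_via_cat = torch.cat([a.unsqueeze(1), b.unsqueeze(1)], dim=1)
same_as_cat = torch.equal(d, d_via_cat)

print(d.shape, recovers_inputs, same_as_cat)
# torch.Size([3, 2, 4]) True True
```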

torch.vstack()

Syntax & Parameters

The torch.vstack() function stacks tensors in sequence vertically (row-wise). It concatenates tensors along the first axis, which is the dimension that represents the rows of a matrix, and it does not add a new dimension to inputs that are already at least 2-D. For example, if you have two tensors A and B, each with shape (3, 4), then torch.vstack((A, B)) returns a tensor with shape (6, 4), where the first three rows are from A and the last three rows are from B.

Syntax:

torch.vstack(tensors, *, out=None) -> Tensor

Where:

  • tensors: a sequence of tensors to be stacked vertically. The tensors must have the same shape along all dimensions except the first dimension. They can be 1-D or higher dimensional tensors, but they will be reshaped to be at least 2-D by using torch.atleast_2d().
  • out: an optional argument to specify a pre-allocated output tensor to store the result. The output tensor must have the correct shape and dtype to hold the stacked tensors. If out is None, a new tensor will be allocated and returned.
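The out argument can be sketched as follows. The buffer name buf is illustrative, and the example assumes the buffer is created with the shape and dtype the result is expected to have.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# Pre-allocate a buffer with the shape and dtype of the expected result.
buf = torch.empty((2, 3), dtype=x.dtype)

result = torch.vstack((x, y), out=buf)

print(buf)
# tensor([[1, 2, 3],
#         [4, 5, 6]])
print(result.data_ptr() == buf.data_ptr())  # True: the result lives in buf's memory
```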

Examples

Here is an example of using torch.vstack() with two 1-D tensors:

import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.vstack((x, y))
print(z)

Output:

tensor([[1, 2, 3],
        [4, 5, 6]])

You can see that x and y are reshaped to be (1, 3) tensors and then stacked vertically to form a (2, 3) tensor.
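The reshaping step can be reproduced by hand. This sketch promotes each input with torch.atleast_2d() and concatenates along dim 0, then compares the result with torch.vstack().

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

z = torch.vstack((x, y))

# Promote each 1-D tensor to shape (1, 3), then concatenate along dim 0.
z_manual = torch.cat([torch.atleast_2d(x), torch.atleast_2d(y)], dim=0)

print(torch.atleast_2d(x).shape)   # torch.Size([1, 3])
print(torch.equal(z, z_manual))    # True

# For 1-D inputs of equal length, this matches torch.stack along dim 0.
print(torch.equal(z, torch.stack((x, y), dim=0)))  # True
```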

Here is another example that demonstrates how the function works with tensors that aren’t 1-D:

import torch
a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])
print(a.shape)
# torch.Size([2, 2, 3])

b = torch.tensor([
    [[13, 14, 15], [16, 17, 18]],
    [[19, 20, 21], [22, 23, 24]]
])
print(b.shape)
# torch.Size([2, 2, 3])

c = torch.vstack([a, b])
print(c.shape)
# torch.Size([4, 2, 3])

print(c)
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6]],

#         [[ 7,  8,  9],
#          [10, 11, 12]],

#         [[13, 14, 15],
#          [16, 17, 18]],

#         [[19, 20, 21],
#          [22, 23, 24]]])
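Because torch.vstack() concatenates along an existing dimension, the inputs may differ in size along the first dimension. The following sketch uses illustrative tensors with different row counts and contrasts the behavior with torch.stack(), which requires identical shapes.

```python
import torch

top = torch.ones(2, 3)
bottom = torch.zeros(1, 3)   # different number of rows, same number of columns

combined = torch.vstack((top, bottom))
print(combined.shape)  # torch.Size([3, 3])

# torch.stack rejects these inputs because their shapes differ.
try:
    torch.stack((top, bottom))
    stack_ok = True
except RuntimeError:
    stack_ok = False
print(stack_ok)  # False
```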

torch.hstack()

Syntax & Parameters

torch.hstack() is a function that stacks tensors in sequence horizontally (column-wise). It is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors. Like torch.vstack(), it joins along an existing dimension rather than creating a new one.

Syntax:

torch.hstack(tensors, *, out=None) -> Tensor

Parameters:

  • tensors: a sequence of tensors to concatenate. Their shapes must match in every dimension except the one being concatenated.
  • out: an optional, keyword-only output tensor to store the result.

Examples

Using torch.hstack() with 1-D tensors:

import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

z = torch.hstack((x, y))
print(z)
# tensor([1, 2, 3, 4, 5, 6])
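The equivalence with concatenation can be verified with torch.cat(). In this sketch, the 2-D tensors m and n are illustrative inputs with the same number of rows but different numbers of columns.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
# 1-D inputs: same as concatenating along dim 0.
same_1d = torch.equal(torch.hstack((x, y)), torch.cat((x, y), dim=0))

m = torch.tensor([[1, 2], [3, 4]])
n = torch.tensor([[5], [6]])          # shape (2, 1): same rows, fewer columns
wide = torch.hstack((m, n))
# 2-D inputs: same as concatenating along dim 1.
same_2d = torch.equal(wide, torch.cat((m, n), dim=1))

print(wide)
# tensor([[1, 2, 5],
#         [3, 4, 6]])
print(same_1d, same_2d)  # True True
```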

Another example, this time with 3-D tensors:

import torch
a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])
print(a.shape)
# torch.Size([2, 2, 3])

b = torch.tensor([
    [[13, 14, 15], [16, 17, 18]],
    [[19, 20, 21], [22, 23, 24]]
])
print(b.shape)
# torch.Size([2, 2, 3])

c = torch.hstack((a, b))
print(c.shape)
# torch.Size([2, 4, 3])

print(c)
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [13, 14, 15],
#          [16, 17, 18]],

#         [[ 7,  8,  9],
#          [10, 11, 12],
#          [19, 20, 21],
#          [22, 23, 24]]])
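To compare the three functions side by side, this short sketch applies each of them to the same pair of (3, 4) tensors. The inputs are placeholders, and only the resulting shapes matter.

```python
import torch

a = torch.zeros(3, 4)
b = torch.ones(3, 4)

stacked = torch.stack((a, b))     # new leading dimension
vertical = torch.vstack((a, b))   # rows appended
horizontal = torch.hstack((a, b)) # columns appended

print(stacked.shape)     # torch.Size([2, 3, 4])
print(vertical.shape)    # torch.Size([6, 4])
print(horizontal.shape)  # torch.Size([3, 8])
```

Only torch.stack() changes the number of dimensions. The other two keep it the same for inputs that are already 2-D or higher.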
