
How to Flatten a Tensor in PyTorch (2 Ways)

Last updated: July 14, 2023

Flattening a tensor in PyTorch means reshaping it into a one-dimensional tensor (1D tensor). This concise, example-based article will show you a couple of different ways to do so.

Using the torch.flatten() function

This approach uses the torch.flatten() function, which is usually the best choice for the task. It is simple and flexible: it can flatten the whole tensor or only a range of dimensions, and it works whether or not the input tensor is contiguous (PyTorch returns a copy only when one is needed).
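
For instance, torch.flatten() happily accepts a non-contiguous (e.g., transposed) tensor. A minimal sketch, added here for illustration and not part of the original example:

import torch

# A transposed tensor is non-contiguous, but torch.flatten() still works
x = torch.tensor([[1, 2], [3, 4], [5, 6]]).transpose(0, 1)
print(x.is_contiguous())
# False

print(torch.flatten(x))
# tensor([1, 3, 5, 2, 4, 6])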

Syntax:

torch.flatten(input, start_dim=0, end_dim=-1) -> Tensor

Where:

  • input: The input tensor that you want to flatten. It can have any shape and data type.
  • start_dim: The first dimension to flatten, counting from zero. The default value is zero, which means to flatten from the first dimension. You can also use negative values to count from the end, such as -1 for the last dimension.
  • end_dim: The last dimension to flatten, counting from zero. The default value is -1, which means to flatten until the last dimension. You can also use negative values to count from the end, such as -2 for the second last dimension.

Example:

import torch

# Create a 2x2x4 tensor
x = torch.tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8]],
                  [[9, 10, 11, 12],
                   [13, 14, 15, 16]]])

# Flatten the whole tensor
y = torch.flatten(x)
print(y)
# tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16])
print(y.shape)
# torch.Size([16])

# Flatten only the last two dimensions
z = torch.flatten(x, start_dim=1)
print(z)
# tensor([[ 1,  2,  3,  4,  5,  6,  7,  8],
#         [ 9, 10, 11, 12, 13, 14, 15, 16]])
print(z.shape)
# torch.Size([2, 8])
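
The example above only uses start_dim, but end_dim and negative dimension indices behave as described earlier. Here is a small additional sketch (not part of the original example) on a tensor of the same shape:

import torch

# A 2x2x4 tensor, same shape as in the example above
x = torch.arange(16).reshape(2, 2, 4)

# Flatten everything up to the second-to-last dimension (end_dim=-2)
w = torch.flatten(x, start_dim=0, end_dim=-2)
print(w.shape)
# torch.Size([4, 4])

# Negative start_dim counts from the end, so start_dim=-2 is equivalent to start_dim=1 here
v = torch.flatten(x, start_dim=-2)
print(v.shape)
# torch.Size([2, 8])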

Using the Tensor.view() method

This approach is fast and memory-efficient, as it does not copy the data. The downside is that it only works on contiguous tensors and requires you to specify the desired shape explicitly. There are two main steps:

  1. Check if the input tensor is contiguous by calling Tensor.is_contiguous(). If not, make it contiguous by calling Tensor.contiguous().
  2. Call Tensor.view() on the tensor with -1 as the argument to infer the shape of the output tensor.

The syntax of the Tensor.view() method is as follows:

Tensor.view(*shape) -> Tensor

Where:

  • self: The input tensor that you want to reshape.
  • *shape: Either a torch.Size object or a sequence of integers that determine the desired shape of the output. You can also use -1 to infer the size of a dimension from the other dimensions.

Example:

import torch

# Create a non-contiguous (transposed) tensor
x = torch.tensor([[1, 2], [3, 4], [5, 6]]).transpose(0, 1)
print(x)
# tensor([[1, 3, 5],
#         [2, 4, 6]])

print(x.is_contiguous())
# False

# Make it contiguous
x = x.contiguous()
print(x.is_contiguous())
# True

# Reshape it into a one-dimensional tensor
y = x.view(-1)
print(y)
# tensor([1, 3, 5, 2, 4, 6])
print(y.shape)
# torch.Size([6])
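
To illustrate the memory-sharing behavior mentioned above (Tensor.view() does not copy the data), here is a brief sketch added for clarity: modifying the flattened view also modifies the tensor it was created from.

import torch

x = torch.zeros(2, 3)
y = x.view(-1)  # y shares the same underlying storage as x

y[0] = 7
print(x)
# tensor([[7., 0., 0.],
#         [0., 0., 0.]])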

You can find more details about the Tensor.view() method in the previous article in this series, PyTorch Tensor.view() method (with example). Happy coding & have a nice day!
