Using numpy.moveaxis() function (5 examples)

Updated: February 29, 2024 By: Guest Contributor

Introduction

The numpy.moveaxis() function is a powerful tool in Python’s NumPy library that allows you to rearrange the axes of an array. It is particularly useful for data manipulation and transformation in scientific computing and deep learning, where different libraries often expect the same data with its axes in different orders. In this tutorial, we will explore numpy.moveaxis() through five progressively more complex examples.

Understanding numpy.moveaxis()

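numpy.moveaxis() takes an array and moves one or more of its axes to new positions; the axes that are not moved keep their original relative order. The general syntax is:

numpy.moveaxis(a, source, destination)

Here, a is the input array, source is the original position of each axis to move (an int or a sequence of ints), and destination is the target position of each of those axes (an int or a sequence of the same length). Negative positions count from the end, so -1 means the last axis. The function returns a view of a with its axes rearranged; no data is copied.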
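
A quick sketch of the argument forms described above. The array a is made up for illustration; the printed shapes follow from the rule that unmoved axes keep their relative order, and the last lines check that the result shares memory with the input.

```python
import numpy as np

# Illustrative 3D array with shape (2, 3, 4)
a = np.arange(24).reshape(2, 3, 4)

# Single axis: move axis 0 to the end (-1 refers to the last position)
single = np.moveaxis(a, 0, -1)
print(single.shape)  # (3, 4, 2)

# Several axes: source[i] is moved to destination[i]
multi = np.moveaxis(a, [0, 1], [-1, -2])
print(multi.shape)  # (4, 3, 2)

# The result is a view: writing through it changes the original array
print(np.shares_memory(a, single))  # True
single[0, 0, 0] = 100
print(a[0, 0, 0])  # 100
```

Because the result is a view, modify it with care if the original array is still needed unchanged; call .copy() on the result when an independent array is required.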

Example 1: Basic Axis Movement

import numpy as np

# Create a 3x2 array
arr = np.array([[1, 2], [3, 4], [5, 6]])
print("Original array:\n", arr)

# Move the first axis (0) to the last position (axis 1 of a 2D array)
result = np.moveaxis(arr, 0, 1)
print("After moveaxis:\n", result)

Output:

Original array:
 [[1 2]
 [3 4]
 [5 6]]
After moveaxis:
 [[1 3 5]
 [2 4 6]]
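
With only two axes, moving axis 0 to position 1 swaps them, so the result should match the ordinary transpose. A small check using the same array:

```python
import numpy as np

arr = np.array([[1, 2], [3, 4], [5, 6]])

moved = np.moveaxis(arr, 0, 1)
# With two axes, moving one axis past the other is a transpose
print(np.array_equal(moved, arr.T))  # True
# -1 names the last axis, so this is the same call for a 2D array
print(np.array_equal(moved, np.moveaxis(arr, 0, -1)))  # True
print(moved.shape)  # (2, 3)
```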

Example 2: Moving Multiple Axes

import numpy as np

# 3D array
arr = np.array([[[1, 2, 3], [4, 5, 6]],
                [[7, 8, 9], [10, 11, 12]]])
print("3D array before moveaxis:\n", arr)

# Move axis 0 to position 2 and axis 2 to position 1;
# the unmoved axis 1 fills the remaining position 0
result = np.moveaxis(arr, [0, 2], [2, 1])
print("3D array after moveaxis:\n", result)

Output:

3D array before moveaxis:
 [[[ 1  2  3]
  [ 4  5  6]]

 [[ 7  8  9]
  [10 11 12]]]
3D array after moveaxis:
 [[[ 1  7]
  [ 2  8]
  [ 3  9]]

 [[ 4 10]
  [ 5 11]
  [ 6 12]]]
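
To see where the shape (2, 3, 2) comes from: axis 0 goes to position 2, axis 2 goes to position 1, and the only unmoved axis (axis 1) takes the remaining position 0. The call should therefore be equivalent to np.transpose(arr, (1, 2, 0)), with result[i, j, k] equal to arr[k, i, j]. The loop below is only a brute-force check of that mapping:

```python
import numpy as np

arr = np.array([[[1, 2, 3], [4, 5, 6]],
                [[7, 8, 9], [10, 11, 12]]])
result = np.moveaxis(arr, [0, 2], [2, 1])

# Same permutation expressed with transpose: new axes are old axes (1, 2, 0)
print(np.array_equal(result, np.transpose(arr, (1, 2, 0))))  # True

# Brute-force check of the element mapping result[i, j, k] == arr[k, i, j]
ok = all(
    result[i, j, k] == arr[k, i, j]
    for i in range(result.shape[0])
    for j in range(result.shape[1])
    for k in range(result.shape[2])
)
print(result.shape, ok)  # (2, 3, 2) True
```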

Example 3: Using moveaxis with Images

Working with images in machine learning often requires rearranging their axes because different libraries expect different axis orders. For example, images are commonly loaded with shape (height, width, channels), while PyTorch expects (channels, height, width):

import numpy as np

# Assuming an image with shape (256, 256, 3)
image = np.random.rand(256, 256, 3)
print("Original image shape:", image.shape)

result = np.moveaxis(image, -1, 0)
print("After moveaxis, image shape:", result.shape)

Output:

Original image shape: (256, 256, 3)
After moveaxis, image shape: (3, 256, 256)
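
The reverse conversion, from (channels, height, width) back to (height, width, channels), moves the first axis to the end. The same idea extends to a batch of images; the batch size of 8 below is arbitrary and chosen only for illustration:

```python
import numpy as np

image = np.random.rand(256, 256, 3)   # (height, width, channels)
chw = np.moveaxis(image, -1, 0)       # (channels, height, width)

# Move the channel axis back to the end
hwc = np.moveaxis(chw, 0, -1)
print(hwc.shape)                      # (256, 256, 3)
print(np.array_equal(hwc, image))     # True

# Illustrative batch: (batch, height, width, channels) -> (batch, channels, height, width)
batch = np.random.rand(8, 256, 256, 3)
nchw = np.moveaxis(batch, -1, 1)
print(nchw.shape)                     # (8, 3, 256, 256)
```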

Example 4: Complex Rearrangements in Multidimensional Arrays

import numpy as np

arr = np.zeros((1, 2, 3, 4, 5))
print("Original array shape:", arr.shape)

# Move axis 0 to position 3 and axis 3 to position 0;
# axes 1, 2 and 4 stay where they are
result = np.moveaxis(arr, [0, 3], [3, 0])
print("After moveaxis, new shape:", result.shape)

Output:

Original array shape: (1, 2, 3, 4, 5)
After moveaxis, new shape: (4, 2, 3, 1, 5)
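
Because axes 0 and 3 simply trade places here, np.swapaxes(arr, 0, 3) should give the same result. More generally, the output axis order can be worked out by hand: drop the moved axes from the list of positions, then insert each one at its destination, lowest destination first. The helper moveaxis_order below is a hypothetical illustration of that rule, not part of NumPy, and unlike np.moveaxis it does not validate its input. The array is filled with distinct values so the comparisons are meaningful:

```python
import numpy as np


def moveaxis_order(ndim, source, destination):
    """Hypothetical helper: the axis order np.moveaxis would produce.

    Assumes valid, non-repeating axes; np.moveaxis itself checks this.
    """
    # Accept a single int or a sequence of ints
    if isinstance(source, int):
        source = [source]
    if isinstance(destination, int):
        destination = [destination]
    # Normalize negative positions (e.g. -1 -> ndim - 1)
    source = [s % ndim for s in source]
    destination = [d % ndim for d in destination]
    # Axes that are not moved keep their relative order
    order = [ax for ax in range(ndim) if ax not in source]
    # Insert each moved axis at its destination, lowest destination first
    for dest, src in sorted(zip(destination, source)):
        order.insert(dest, src)
    return order


# Same shape as in Example 4, but with distinct values
arr = np.arange(120).reshape(1, 2, 3, 4, 5)
order = moveaxis_order(arr.ndim, [0, 3], [3, 0])
print(order)  # [3, 1, 2, 0, 4]

result = np.moveaxis(arr, [0, 3], [3, 0])
print(np.array_equal(result, np.transpose(arr, order)))  # True
print(np.array_equal(result, np.swapaxes(arr, 0, 3)))    # True
```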

Example 5: Stacking and Then Rearranging Axes

Sometimes you might need to combine arrays along a new axis and then rearrange the axes of the result. Here’s how:

import numpy as np

# Two 1D arrays of length 3
arr1 = np.array([1, 2, 3])
arr2 = np.array([4, 5, 6])

# Stack along a new leading axis to make a 2x3 array
combined = np.stack((arr1, arr2), axis=0)
print("Combined array before moveaxis:", combined.shape)

# Move axis 0 to position 1 (for a 2D array, this is a transpose)
result = np.moveaxis(combined, 0, 1)
print("Combined array after moveaxis:", result.shape)

Output:

Combined array before moveaxis: (2, 3)
Combined array after moveaxis: (3, 2)
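
Stacking along axis 0 and then moving that axis to position 1 should give the same array as stacking along axis 1 directly:

```python
import numpy as np

arr1 = np.array([1, 2, 3])
arr2 = np.array([4, 5, 6])

# Stack along axis 0, then move the new axis to position 1
via_moveaxis = np.moveaxis(np.stack((arr1, arr2), axis=0), 0, 1)
# Stack along axis 1 directly
direct = np.stack((arr1, arr2), axis=1)

print(direct)
# [[1 4]
#  [2 5]
#  [3 6]]
print(np.array_equal(via_moveaxis, direct))  # True
```

When the target layout is known at stacking time, passing axis= to np.stack skips the extra step; moveaxis remains the tool of choice when an array arrives from elsewhere with its axes in the wrong order.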

Conclusion

The numpy.moveaxis() function is a versatile tool for rearranging array dimensions. It changes only the order of the axes, not the data, and it returns a view rather than a copy, so it is a cheap way to bring an array into the axis layout a library or algorithm expects. Mastering it will make your data preprocessing and manipulation tasks in Python simpler.