Understanding numpy.rollaxis() function (4 examples)

Updated: February 29, 2024 By: Guest Contributor

Introduction

NumPy, the cornerstone of scientific computing with Python, provides many functions for manipulating multi-dimensional arrays. One of them, numpy.rollaxis(), is less well known than numpy.reshape() or numpy.transpose(), but it is useful for reordering the axes of an array, whether to line up dimensions for broadcasting or to make a particular axis easy to iterate over. In this tutorial, we will explore numpy.rollaxis() through four progressively more advanced examples.

Understanding numpy.rollaxis()

The numpy.rollaxis() function moves the specified axis to the position given by start, leaving the remaining axes in their original relative order. This can sound a bit abstract at first, so let's dive into practical examples to clarify how it works.

Syntax:

numpy.rollaxis(a, axis, start=0)

Parameters:

  • a: array_like. Input array.
  • axis: int. The axis to be rolled. The positions of the other axes do not change relative to one another.
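  • start: int, optional. When start <= axis, the axis is rolled back until it lies at this position. When start > axis, the axis is rolled until it lies just before this position. A negative start is counted from the end by adding a.ndim (so -1 means a.ndim - 1); valid values run from -a.ndim to a.ndim. The default, 0, moves the axis to the front.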
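One way to see what start does is to rebuild the reordering by hand: take the list of axis indices, remove axis, and re-insert it at start (one slot earlier if the axis came from before start, because removing it shifted the later indices left). The helper below, rollaxis_sketch, is a hypothetical illustration written from the documented behaviour rather than NumPy's own code; it simply passes the resulting permutation to transpose.

```python
import numpy as np


def rollaxis_sketch(a, axis, start=0):
    """Hypothetical re-implementation of np.rollaxis via transpose (illustration only)."""
    a = np.asarray(a)
    n = a.ndim
    # Normalize negative indices: axis lives in [-n, n), start in [-n, n]
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    if not (0 <= axis < n) or not (0 <= start <= n):
        raise ValueError("axis or start out of range")
    # Build the permutation: drop `axis`, then re-insert it
    axes = list(range(n))
    axes.remove(axis)
    # If the axis came from before `start`, removing it shifted `start` left by one
    insert_at = start - 1 if axis < start else start
    axes.insert(insert_at, axis)
    return a.transpose(axes)


a = np.ones((3, 4, 5, 6))
print(rollaxis_sketch(a, 3, 1).shape)   # (3, 6, 4, 5)
print(rollaxis_sketch(a, 2).shape)      # (5, 3, 4, 6)
print(rollaxis_sketch(a, 1, 4).shape)   # (3, 5, 6, 4)
print(rollaxis_sketch(a, 3, -1).shape)  # (3, 4, 5, 6): start -1 becomes 3, axis already there
```

You can compare each shape against np.rollaxis(a, axis, start).shape directly to confirm the two agree.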

Returns:

  • res: ndarray. For NumPy 1.10.0 and later, a view of a is always returned. In earlier versions, a view was returned only if the order of the axes changed; otherwise the input array itself was returned.
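Because the result is a view, no data is copied, and writes through the rolled array show up in the original. A quick check with np.shares_memory illustrates this; call .copy() when you need an independent array.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
r = np.rollaxis(a, 2)          # shape (4, 2, 3), no data copied

print(r.shape)                 # (4, 2, 3)
print(np.shares_memory(a, r))  # True: r is a view on a's buffer

# Writing through the view changes the original array: r[k, i, j] is a[i, j, k]
r[0, 1, 2] = -1
print(a[1, 2, 0])              # -1

# An explicit copy breaks the link
print(np.shares_memory(a, r.copy()))  # False
```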

Example 1: Basic Axis Reordering

import numpy as np

# Creating a 3-dimensional array
array = np.array([[[1, 2, 3], [4, 5, 6]],
                 [[7, 8, 9], [10, 11, 12]]])

# Use rollaxis to move the last axis to the front
rolled = np.rollaxis(array, axis=2, start=0)
print("Original shape:", array.shape)
print("After rollaxis:", rolled.shape)

Output:

Original shape: (2, 2, 3)
After rollaxis: (3, 2, 2)

This simple example demonstrates the basic usage of numpy.rollaxis(). Moving the last axis (axis=2) of the 3D array to the front changes the shape from (2, 2, 3) to (3, 2, 2), a cyclic rotation of the axes in which the two remaining axes keep their original order.
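A useful way to read the result of Example 1: the element at [i, j, k] of the original array ends up at [k, i, j] of the rolled one. The same reordering can also be written with np.moveaxis(), which NumPy's documentation recommends over rollaxis() for new code, or with an explicit transpose; the comparison below is a small sanity check of that equivalence for this particular array.

```python
import numpy as np

array = np.array([[[1, 2, 3], [4, 5, 6]],
                  [[7, 8, 9], [10, 11, 12]]])
rolled = np.rollaxis(array, axis=2, start=0)

# Element [i, j, k] of the original ends up at [k, i, j]
print(rolled[2, 1, 0] == array[1, 0, 2])  # True (both are 9)

# The same reordering written with moveaxis and with an explicit transpose
print(np.array_equal(rolled, np.moveaxis(array, 2, 0)))  # True
print(np.array_equal(rolled, array.transpose(2, 0, 1)))  # True

# The first slice along the new leading axis equals array[:, :, 0], i.e. [[1, 4], [7, 10]]
print(rolled[0])
```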

Example 2: Advanced Data Manipulation

Moving towards more complex scenarios, consider temporal data stored in a three-axis array, where you need to move the time axis to the front for a specific analysis.

import numpy as np

# Suppose you have a data array with shape (days, hours, features)
data = np.random.rand(10, 24, 5)  # Random data for illustration

# Reordering the time dimension (hours) to the first position
reordered = np.rollaxis(data, 1, 0)
print("New shape:", reordered.shape)

Output:

New shape: (24, 10, 5)

This moves the 'hours' axis to the leading position, so iterating over the reordered array yields one (days, features) slice per hour, which is convenient for analyses that process each hour across all days.
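As a sketch of the kind of per-hour analysis this enables, the loop below averages each feature over all days, one hour at a time. The per-hour average is just an illustrative choice; in practice a vectorized reduction such as data.mean(axis=0) computes the same numbers without a Python loop, and the last line checks that.

```python
import numpy as np

data = np.random.rand(10, 24, 5)      # (days, hours, features), random for illustration
reordered = np.rollaxis(data, 1, 0)   # (hours, days, features)

# Iterating over the first axis now yields one (days, features) block per hour
hourly_means = []
for hour_block in reordered:
    # Average each feature over the 10 days for this hour
    hourly_means.append(hour_block.mean(axis=0))
hourly_means = np.stack(hourly_means)  # shape (24, 5)

print(hourly_means.shape)  # (24, 5)

# Cross-check: reducing over the days axis of the original array gives the same numbers
print(np.allclose(hourly_means, data.mean(axis=0)))  # True
```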

Example 3: Multi-Dimensional Data Broadcasting for Custom Operations

Consider a dataset of measurements with dimensions (conditions, experiments, samples), where you want to standardize each condition-experiment pair separately: subtract the mean and divide by the standard deviation, both computed along the samples axis. Rolling the samples axis to the front puts the conditions and experiments axes last, so the per-pair statistics, of shape (conditions, experiments), broadcast correctly against the data.

import numpy as np

def normalize_across_samples(data):
    # data shape: (conditions, experiments, samples)
    # Goal: Normalize across the 'samples' axis, treating each condition and experiment individually
    
    # Step 1: Roll 'samples' axis to the front
    rolled_data = np.rollaxis(data, 2, 0)
    
    # Step 2: Apply normalization
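    # Example normalization: subtract mean and divide by std deviation
    mean = np.mean(rolled_data, axis=0)  # shape (conditions, experiments)
    std = np.std(rolled_data, axis=0)    # shape (conditions, experiments)
    # (samples, conditions, experiments) broadcasts against (conditions, experiments)
    normalized_data = (rolled_data - mean) / std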
    
    # Step 3: Roll 'samples' axis back to the end (start=3 places it just before position 3)
    unrolled_data = np.rollaxis(normalized_data, 0, 3)
    
    return unrolled_data

# Simulate some data: 4 conditions, 10 experiments, 100 samples each
data = np.random.rand(4, 10, 100)

# Normalize across samples
normalized_data = normalize_across_samples(data)

print(f"Original shape: {data.shape}")
print(f"Normalized shape: {normalized_data.shape}")

Output:

Original shape: (4, 10, 100)
Normalized shape: (4, 10, 100)
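To make the broadcasting issue concrete: reducing over the samples axis of the original (4, 10, 100) array without rolling yields a (4, 10) result, whose trailing dimension does not line up with 100. The sketch below shows that error and then an alternative that avoids rolling altogether by passing keepdims=True, which keeps the reduced axis with length 1. Both approaches standardize each (condition, experiment) series; which one reads better is a matter of taste.

```python
import numpy as np

data = np.random.rand(4, 10, 100)   # (conditions, experiments, samples)

# Reducing over the last axis without keepdims drops it: shape (4, 10)
mean_2d = data.mean(axis=2)
print(mean_2d.shape)                 # (4, 10)

# (4, 10, 100) minus (4, 10) fails: trailing dimensions 100 and 10 do not match
try:
    data - mean_2d
except ValueError as err:
    print("broadcast error:", err)

# keepdims=True keeps a length-1 samples axis, so broadcasting works without rolling
mean = data.mean(axis=2, keepdims=True)  # (4, 10, 1)
std = data.std(axis=2, keepdims=True)    # (4, 10, 1)
normalized = (data - mean) / std

print(normalized.shape)                           # (4, 10, 100)
print(np.allclose(normalized.mean(axis=2), 0.0))  # True
print(np.allclose(normalized.std(axis=2), 1.0))   # True
```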

Example 4: 4D Image Stack Manipulation

Imagine working with a 4D array that represents a stack of RGB images, with dimensions (images, rows, columns, channels), where the processing has to happen per channel across all images. You can roll the channels axis to the front so that a filter or transformation can be applied to each color channel individually across all images.

import numpy as np

def apply_color_transformation(image_stack):
    # image_stack shape: (images, rows, columns, channels)
    # Roll the 'channels' axis to the front
    rolled_images = np.rollaxis(image_stack, 3, 0)
    
    # Placeholder transformation: invert the color channels.
    # Inversion is elementwise and does not need the channels-first layout;
    # in practice this step could be any per-channel operation that does.
    transformed_images = 255 - rolled_images
    
    # Roll the 'channels' axis back to its original position
    unrolled_images = np.rollaxis(transformed_images, 0, 4)
    
    return unrolled_images

# Simulating a stack of 5 RGB images of size 100x100
image_stack = np.random.randint(0, 256, (5, 100, 100, 3), dtype=np.uint8)

# Apply the color transformation
transformed_stack = apply_color_transformation(image_stack)

print(f"Original shape: {image_stack.shape}")
print(f"Transformed shape: {transformed_stack.shape}")

Output:

Original shape: (5, 100, 100, 3)
Transformed shape: (5, 100, 100, 3)
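The inversion above is elementwise, so it would work without any rolling. Rolling pays off when the operation differs per channel. The sketch below applies a different gain to each channel; the gains list holds hypothetical illustrative values, and the arithmetic is done in float64 and clipped back to the uint8 range to avoid overflow.

```python
import numpy as np

image_stack = np.random.randint(0, 256, (5, 100, 100, 3), dtype=np.uint8)

# Hypothetical per-channel gains for R, G, B (illustrative values only)
gains = [1.2, 1.0, 0.8]

# Channels first: iterating yields one (images, rows, columns) block per channel
rolled = np.rollaxis(image_stack, 3, 0).astype(np.float64)
for channel_block, gain in zip(rolled, gains):
    channel_block *= gain  # in-place update of a view into `rolled`

# Clip to the valid 8-bit range, convert back, and restore channels-last order
adjusted = np.rollaxis(np.clip(rolled, 0, 255).astype(np.uint8), 0, 4)
print(adjusted.shape)  # (5, 100, 100, 3)
```

Since the channels are already the last axis here, image_stack * np.array(gains) would broadcast directly as well; the rolled loop is mainly useful when the per-channel step is a function you want to call once per channel.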

Conclusion

Throughout this tutorial, we've worked from basic to more advanced examples of numpy.rollaxis() in data manipulation tasks. Moving an axis to the front makes it easy to iterate over that axis and to line up dimensions for broadcasting, and rolling it back restores the original layout. Note that NumPy's documentation keeps numpy.rollaxis() for backward compatibility and recommends numpy.moveaxis() for new code, since it expresses the same reorderings with explicit source and destination positions. Used alongside other NumPy tools, either function can make axis-reordering code shorter and easier to read.