PyTorch: How to compare 2 tensors

Updated: April 14, 2023 By: Khue

Exact equality comparison

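Exact equality comparison means checking whether two PyTorch tensors have the same shape and exactly the same values. The built-in function torch.equal() handles this: it returns True if both tensors have the same size and the same elements, and False otherwise (tensors of different shapes simply compare as False). torch.equal() is not a dtype check, so if the dtype must match as well, also compare tensor_one.dtype == tensor_two.dtype. Tensors containing NaN are never equal, because NaN is not equal to itself.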

Example:

import torch

# Create some tensors
x = torch.tensor([
    [1, 2],
    [3, 4],
])

y = torch.tensor([
    [1, 2],
    [3, 4],
])

z = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
])

# Compare x and y
print(torch.equal(x, y))

# Compare x and z
print(torch.equal(x, z))

Output:

True
False
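Because torch.equal() looks at shape and values, a check that must also reject tensors with different dtypes has to add that condition explicitly. The sketch below does this in a small helper; the name strict_equal is hypothetical and not part of PyTorch. It compares dtypes first, so tensors of different dtypes never reach torch.equal(). It also shows that two separate NaN tensors are not equal.

```python
import torch

def strict_equal(t1: torch.Tensor, t2: torch.Tensor) -> bool:
    """Hypothetical helper: True only if shape, dtype, and values all match."""
    # Check dtype first; short-circuiting keeps mismatched dtypes away from torch.equal()
    return t1.dtype == t2.dtype and torch.equal(t1, t2)

x = torch.tensor([[1, 2], [3, 4]])   # int64
y = torch.tensor([[1, 2], [3, 4]])   # int64
x_float = x.float()                  # same values, but dtype float32

print(strict_equal(x, y))        # True
print(strict_equal(x, x_float))  # False: dtypes differ

# NaN is never equal to NaN, so these two tensors are not equal
n1 = torch.tensor([float("nan")])
n2 = torch.tensor([float("nan")])
print(torch.equal(n1, n2))       # False
```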

Shape and dtype comparison

Shape and dtype comparison means checking whether two PyTorch tensors have the same shape and dtype, regardless of their values. Compare tensor_one.shape == tensor_two.shape and tensor_one.dtype == tensor_two.dtype; each expression returns a plain Python bool.

Example:

import torch

a = torch.tensor([[1., 2.], [3., 4.]])  # float32, shape (2, 2)
b = torch.tensor([[1., 2.], [3., 4.]])  # float32, shape (2, 2)
c = torch.tensor([[1., 2.]])            # float32, shape (1, 2)
d = torch.tensor([[1, 2], [3, 4]])      # int64, shape (2, 2)

print(a.shape == b.shape) # True
print(a.shape == c.shape) # False
print(a.dtype == b.dtype) # True
print(a.dtype == d.dtype) # False
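Since only shape and dtype are compared, tensors holding completely different values still pass this check. A small sketch that wraps both conditions in one function; the name same_shape_and_dtype is hypothetical, and the comments assume PyTorch's default float dtype (float32) has not been changed.

```python
import torch

def same_shape_and_dtype(t1: torch.Tensor, t2: torch.Tensor) -> bool:
    """Hypothetical helper: compares shape and dtype, ignores values."""
    return t1.shape == t2.shape and t1.dtype == t2.dtype

zeros = torch.zeros(2, 3)                    # float32 by default
ones = torch.ones(2, 3)                      # float32, different values
ints = torch.ones(2, 3, dtype=torch.int64)   # same shape, different dtype

print(same_shape_and_dtype(zeros, ones))  # True
print(same_shape_and_dtype(zeros, ints))  # False

# torch.Size is a tuple subclass, so it can be compared with a plain tuple
print(zeros.shape == (2, 3))              # True
```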

Approximate equality comparison

Use this kind of comparison when two tensors only need to match at every position up to some tolerance, which is the usual situation with floating-point results. torch.allclose(input, other) returns True if every element satisfies |input - other| <= atol + rtol * |other|, and False otherwise. You can adjust the relative tolerance rtol (default 1e-05) and the absolute tolerance atol (default 1e-08) as keyword arguments.

Example:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0001, 2.0002, 3.0003])
c = torch.tensor([1.01, 2.02, 3.03])

print(torch.allclose(a, b, rtol=0.001))  # True
print(torch.allclose(a, c))  # False
print(torch.allclose(a, c, atol=0.03))  # True
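To see what the tolerances actually do, the sketch below rebuilds the allclose condition by hand, uses torch.isclose() to show which individual positions pass, and demonstrates the equal_nan flag. The manual version is only an illustration; it does not reproduce allclose()'s special handling of NaN and infinity. The atol value of 0.015 is an arbitrary choice for the demo.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
c = torch.tensor([1.01, 2.02, 3.03])
rtol, atol = 1e-05, 0.015

# Manual version of the allclose condition: |a - c| <= atol + rtol * |c|
manual = torch.all(torch.abs(a - c) <= atol + rtol * torch.abs(c)).item()
print(manual)                                       # False
print(torch.allclose(a, c, rtol=rtol, atol=atol))   # False

# torch.isclose() applies the same test element-wise
print(torch.isclose(a, c, rtol=rtol, atol=atol).tolist())  # [True, False, False]

# NaN is not considered close to NaN unless equal_nan=True
n = torch.tensor([1.0, float("nan")])
m = torch.tensor([1.0, float("nan")])
print(torch.allclose(n, m))                   # False
print(torch.allclose(n, m, equal_nan=True))   # True
```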

Element-wise comparison

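Use this type of comparison when you want to compare two tensors element by element and get a tensor of booleans back. torch.eq(tensor_one, tensor_two) returns a boolean tensor that is True wherever the elements are equal; the inputs must have the same shape or be broadcastable to one. The == operator (tensor_one == tensor_two) gives the same result.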

Example:

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])
c = torch.tensor([4, 5, 6])

print(torch.eq(a, b))  
# Output: tensor([ True, False,  True])

print(torch.eq(a, c)) 
# Output: tensor([False, False, False])
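The boolean tensor from torch.eq() is often reduced to a single answer. A short sketch using standard tensor methods: .all() and .any() collapse the mask to one value, .sum() counts matching positions, and broadcasting lets you compare against a scalar. The related functions torch.ne() and torch.gt() are included to show that other element-wise comparisons also return boolean tensors.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])

mask = a == b                     # same result as torch.eq(a, b)
print(mask.tolist())              # [True, False, True]
print(mask.all().item())          # False: not every position matches
print(mask.any().item())          # True: at least one position matches
print(mask.sum().item())          # 2 matching positions

# Broadcasting: compare every element with a scalar
print(torch.eq(a, 2).tolist())    # [False, True, False]

# Other element-wise comparisons also return boolean tensors
print(torch.ne(a, b).tolist())    # [False, True, False]
print(torch.gt(b, a).tolist())    # [False, True, False]
```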