How can I find the indices of all elements which are not zero in a given tensor?


You can use the torch.nonzero() function. It returns a tensor containing the indices of all non-zero elements of a given input tensor. The syntax of the function is as follows:

torch.nonzero(input, *, out=None, as_tuple=False)

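For a 2-D input, such as the one in the example below, each row of the returned tensor is the [row, col] index pair of one non-zero element. The pairs are listed in row-major order, with the last index changing fastest. In general, if input has n dimensions and z non-zero elements, the result has shape (z, n). With as_tuple=True, the function instead returns a tuple of n 1-D tensors, one per dimension, each holding that dimension's indices of the non-zero elements.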
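The following sketch assumes a reasonably recent PyTorch release (one that supports `as_tuple`). It uses a small hand-written tensor to show the (z, n) shape of the default result, how the `as_tuple=True` form can be used directly for indexing, and that the same call works on a 3-D tensor. The tensors `t` and `t3` are made-up illustration data.

```python
import torch

# Hand-written 2x3 tensor with three non-zero entries (illustration data).
t = torch.tensor([[0, 5, 0],
                  [7, 0, 9]])

# Default form: one row per non-zero element, one column per dimension.
idx = torch.nonzero(t)
print(idx.tolist())        # [[0, 1], [1, 0], [1, 2]]
print(tuple(idx.shape))    # (3, 2)

# as_tuple=True: one 1-D index tensor per dimension.
rows, cols = torch.nonzero(t, as_tuple=True)
print(rows.tolist())       # [0, 1, 1]
print(cols.tolist())       # [1, 0, 2]

# The tuple form plugs straight into advanced indexing to fetch the values.
print(t[rows, cols].tolist())  # [5, 7, 9]

# The same call works for any number of dimensions; here a 3-D tensor.
t3 = torch.zeros(2, 2, 2, dtype=torch.int64)
t3[1, 0, 1] = 4
print(torch.nonzero(t3).tolist())  # [[1, 0, 1]]
```

Because the default result has one row per non-zero element, `torch.nonzero(t).shape[0]` also gives the number of non-zero entries.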

Here is an example:

>>> import torch
>>> a = torch.randint(0, 4, (6, 5))  # random 6x5 integer tensor, values in [0, 4)
>>> a
tensor([[1, 3, 3, 2, 3],
        [1, 1, 2, 1, 2],
        [1, 2, 1, 0, 0],
        [3, 3, 3, 1, 0],
        [0, 2, 0, 0, 2],
        [1, 1, 0, 3, 0]])
>>> torch.nonzero(a)
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [2, 0],
        [2, 1],
        [2, 2],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [4, 1],
        [4, 4],
        [5, 0],
        [5, 1],
        [5, 3]])
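The rows above come out in row-major order: every hit in row 0 first, then row 1, and so on, with the column index changing fastest. The plain-Python sketch below reproduces that ordering for the example matrix without PyTorch. `nonzero_pairs` is a hypothetical helper written only to illustrate the ordering; it is not how PyTorch implements the function.

```python
def nonzero_pairs(matrix):
    """Return [row, col] for each non-zero entry, scanning row by row (row-major)."""
    return [[r, c]
            for r, row in enumerate(matrix)
            for c, value in enumerate(row)
            if value != 0]

# Same values as the tensor printed in the example above.
a = [[1, 3, 3, 2, 3],
     [1, 1, 2, 1, 2],
     [1, 2, 1, 0, 0],
     [3, 3, 3, 1, 0],
     [0, 2, 0, 0, 2],
     [1, 1, 0, 3, 0]]

pairs = nonzero_pairs(a)
print(len(pairs))   # 22
print(pairs[-3:])   # [[5, 0], [5, 1], [5, 3]]
```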
 


...