I have a 2D tensor and want to select some of its elements. Which function should I use?

1 Answer


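You can use torch.take(input, index), passing the indices of the elements you want as an integer (LongTensor) index tensor. take() treats the input tensor as if it were flattened into a 1D tensor in row-major order, so each index is a position in that flat view. The returned tensor has the same shape as the index tensor.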
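To make the "treated as 1D" behaviour concrete, here is a small sketch. It assumes an illustrative tensor built with torch.arange(24).reshape(6, 4), so every value equals its own flat position and the output is easy to check; it is not from the original answer.

```python
import torch

# Illustrative tensor: a[r, c] == 4*r + c, i.e. each value equals its flat (row-major) position
a = torch.arange(24).reshape(6, 4)
idx = torch.tensor([1, 5, 6, 8])

taken = torch.take(a, idx)
print(taken)                                  # tensor([1, 5, 6, 8])

# Equivalent: flatten a in row-major order, then index the flat tensor
print(torch.equal(taken, a.flatten()[idx]))   # True
```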

Here are some examples:

When the indices are a 1D tensor:

>>> import torch
>>> a=torch.randn(6,4)
>>> a
tensor([[-0.3410, -2.3171,  0.2685, -1.4083],
        [-0.1782,  0.4501,  0.4013, -0.4777],
        [-0.8800, -0.8078, -1.0272,  0.0961],
        [-1.2799, -0.5404, -1.3871, -1.5463],
        [-0.3515, -0.0466, -1.5026,  0.6122],
        [ 0.7668, -1.1009, -0.5753, -0.0123]])
>>> i=torch.tensor([1, 5, 6, 8])
>>> torch.take(a,i)
tensor([-2.3171,  0.4501,  0.4013, -0.8800])
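The flat indices 1, 5, 6 and 8 above correspond to the (row, column) positions (0, 1), (1, 1), (1, 2) and (2, 0) of the 6x4 tensor. If you start from row/column coordinates, a minimal sketch of the conversion is flat = row * number_of_columns + col. The rows/cols tensors and the arange-based tensor below are illustrative assumptions.

```python
import torch

# Illustrative 6x4 tensor with predictable values (a[r, c] == 4*r + c)
a = torch.arange(24).reshape(6, 4)

# Hypothetical (row, col) coordinates of the elements to select
rows = torch.tensor([0, 1, 1, 2])
cols = torch.tensor([1, 1, 2, 0])

# Row-major linear index: row * number_of_columns + col
flat = rows * a.size(1) + cols
print(flat)                                   # tensor([1, 5, 6, 8])

selected = torch.take(a, flat)
# Same elements as indexing a directly with the coordinate tensors
print(torch.equal(selected, a[rows, cols]))   # True
```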

When the indices are a 2D tensor (continuing the same session, so a is the tensor from the first example):

>>> i=torch.tensor([[1,2],[3,4]])
>>> torch.take(a,i)
tensor([[-2.3171,  0.2685],
        [-1.4083, -0.1782]])
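The output shape always follows the index tensor, whatever its number of dimensions. A short sketch to check this, again assuming an illustrative arange-based tensor and two hypothetical index shapes:

```python
import torch

# Illustrative 6x4 tensor whose values equal their flat positions
a = torch.arange(24).reshape(6, 4)

# Index tensors of shape (2, 2) and (1, 2, 1); the result of take() has the same shape
for idx in (torch.tensor([[1, 2], [3, 4]]), torch.tensor([[[0], [23]]])):
    out = torch.take(a, idx)
    print(out.shape == idx.shape, out.tolist())
# True [[1, 2], [3, 4]]
# True [[[0], [23]]]
```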


...