# [PyTorch] How to select some elements from a given tensor

I have a 2D tensor. I want to select some elements from it. Which function should I use for it?

## 1 Answer

You can use `torch.take()`, passing the tensor and a tensor of indices. `take()` treats the input tensor as if it were flattened to 1D (row by row), so each index is a position in that flattened tensor. The returned tensor has the same shape as the indices.
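Conceptually, this is the same as flattening the tensor and indexing the flattened version. The sketch below spells that out with a hypothetical helper `take_via_flatten` (not part of PyTorch) and checks it against `torch.take`. It assumes a contiguous input and uses `torch.arange` instead of `torch.randn` so the values are predictable.

```python
import torch

def take_via_flatten(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mimic torch.take by indexing the flattened tensor."""
    flat = t.reshape(-1)  # row-by-row 1D view of a contiguous tensor
    return flat[index]    # advanced indexing keeps the shape of `index`

a = torch.arange(24.0).reshape(6, 4)  # values 0.0 .. 23.0, row by row
idx = torch.tensor([[1, 5, 6], [8, 9, 23]])

expected = torch.take(a, idx)
result = take_via_flatten(a, idx)

print(result.shape)                   # torch.Size([2, 3]), same as idx
print(torch.equal(result, expected))  # True
```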

Here are some examples:

When the indices are a 1D tensor:

```python
>>> import torch
>>> a = torch.randn(6, 4)
>>> a
tensor([[-0.3410, -2.3171,  0.2685, -1.4083],
        [-0.1782,  0.4501,  0.4013, -0.4777],
        [-0.8800, -0.8078, -1.0272,  0.0961],
        [-1.2799, -0.5404, -1.3871, -1.5463],
        [-0.3515, -0.0466, -1.5026,  0.6122],
        [ 0.7668, -1.1009, -0.5753, -0.0123]])
>>> i = torch.tensor([1, 5, 6, 8])
>>> torch.take(a, i)
tensor([-2.3171,  0.4501,  0.4013, -0.8800])
```
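Since the flat indices count through the rows one after another, a flat index `k` into a tensor with `n` columns refers to row `k // n` and column `k % n`. In the example above, `a` has 4 columns, so index 8 is row 2, column 0, which is where -0.8800 sits. The sketch below recovers the (row, column) pairs and checks that ordinary indexing with them picks the same elements. It assumes a contiguous 2D tensor; `torch.div(..., rounding_mode="floor")` is used for the integer division.

```python
import torch

a = torch.randn(6, 4)
i = torch.tensor([1, 5, 6, 8])

ncols = a.size(1)
rows = torch.div(i, ncols, rounding_mode="floor")  # flat index -> row
cols = i % ncols                                    # flat index -> column

print(rows.tolist(), cols.tolist())                 # [0, 1, 1, 2] [1, 1, 2, 0]
print(torch.equal(torch.take(a, i), a[rows, cols]))  # True
```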

When the indices are a 2D tensor (using the same `a` as above):

```python
>>> i = torch.tensor([[1, 2], [3, 4]])
>>> torch.take(a, i)
tensor([[-2.3171,  0.2685],
        [-1.4083, -0.1782]])
```
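A 2D index such as `[[1, 2], [3, 4]]` still holds flat positions; only its shape carries over to the result. If you know the (row, column) positions you want instead, one option is to convert them to flat indices with `row * ncols + col` (row-by-row order, assuming a contiguous tensor) and lay them out in the shape you want back. The positions below, the four corners of a 6x4 tensor, are only an illustration.

```python
import torch

a = torch.arange(24.0).reshape(6, 4)  # predictable values 0.0 .. 23.0
ncols = a.size(1)

# (row, column) positions we want, arranged as a 2x2 grid
rows = torch.tensor([[0, 0], [5, 5]])
cols = torch.tensor([[0, 3], [0, 3]])

flat = rows * ncols + cols     # flat indices: [[0, 3], [20, 23]]
corners = torch.take(a, flat)  # result has the same 2x2 shape as `flat`

print(corners.tolist())        # [[0.0, 3.0], [20.0, 23.0]]
```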
