# [Python] How to select specific rows/columns from a PyTorch tensor

How can I select some rows/columns from a tensor using indices?


Tensors support the same kind of indexing as NumPy arrays, so you can select the desired rows/columns by indexing the tensor with a list of indices.

Here are some examples:

```python
>>> import torch
>>> a = torch.randn(6, 4)
>>> a
tensor([[-0.0457, -0.4924, -0.7026,  0.0567],
        [-0.5104, -0.1395, -0.3003,  0.8491],
        [ 2.2846,  0.5619, -0.1806,  0.9625],
        [ 0.7884,  1.1767,  2.0025, -0.0589],
        [-0.1579,  0.8199, -0.5279,  0.2966],
        [ 0.0946, -0.7405,  0.4907,  1.3673]])
```

To select rows [0, 2, 3], index the tensor with a list of those row numbers:

```python
>>> idx = [0, 2, 3]
>>> a[idx]    # select rows
tensor([[-0.0457, -0.4924, -0.7026,  0.0567],
        [ 2.2846,  0.5619, -0.1806,  0.9625],
        [ 0.7884,  1.1767,  2.0025, -0.0589]])
```
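One detail worth knowing: indexing with a list, as in `a[idx]`, is what NumPy and PyTorch call advanced (integer-array) indexing, and it returns a new tensor holding a copy of the data, whereas a true slice such as `a[0:2]` returns a view that shares memory with `a`. The sketch below illustrates this; it uses `torch.arange` instead of `torch.randn` so the values are predictable, and the names `rows` and `head` are only illustrative.

```python
import torch

# Deterministic 6x4 tensor: row i holds 4*i .. 4*i + 3
a = torch.arange(24).reshape(6, 4)

rows = a[[0, 2, 3]]   # integer-array (advanced) indexing: returns a copy
head = a[0:2]         # basic slicing: returns a view of a

rows[0, 0] = -1       # modifies only the copy; a is unchanged
head[0, 1] = -2       # writes through the view into a

# a[0] is now [0, -2, 2, 3]; rows[0] is [-1, 1, 2, 3]
print(a[0])
print(rows[0])
```

If you need to modify the selected rows in place in the original tensor, assign through the index instead, e.g. `a[[0, 2, 3]] = 0`.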

To select columns [0, 2, 3], keep all rows with `:` and index the second dimension:

```python
>>> a[:, idx]    # select columns
tensor([[-0.0457, -0.7026,  0.0567],
        [-0.5104, -0.3003,  0.8491],
        [ 2.2846, -0.1806,  0.9625],
        [ 0.7884,  2.0025, -0.0589],
        [-0.1579, -0.5279,  0.2966],
        [ 0.0946,  0.4907,  1.3673]])
```
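The index list does not have to be sorted or unique, and negative indices count from the end. A small sketch on a deterministic `torch.arange` tensor (the variable names are illustrative):

```python
import torch

a = torch.arange(24).reshape(6, 4)  # row i is [4i, 4i+1, 4i+2, 4i+3]

cols = a[:, [0, 2, 3]]   # keep all rows, pick columns 0, 2 and 3
swapped = a[:, [3, 0]]   # indices may be given in any order...
repeated = a[:, [1, 1]]  # ...and may repeat
last_col = a[:, [-1]]    # negative index counts from the end; a list keeps the dim

print(cols.shape)      # 6 rows, 3 columns
print(last_col.shape)  # 6 rows, 1 column (a[:, -1] would give a 1-D tensor)
```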

To select rows [0, 2, 3] and columns [1, 3], index in two steps. You can select the columns first or the rows first; both give the same result:

```python
>>> r = [0, 2, 3]
>>> c = [1, 3]
>>> a[:, c][r, :]
tensor([[-0.4924,  0.0567],
        [ 0.5619,  0.9625],
        [ 1.1767, -0.0589]])
>>> a[r, :][:, c]
tensor([[-0.4924,  0.0567],
        [ 0.5619,  0.9625],
        [ 1.1767, -0.0589]])
```
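Why two steps? If you write `a[r, c]` with two index lists, PyTorch does not build a sub-matrix; it pairs the indices element-wise, so with the lists above (lengths 3 and 2) it raises an error because the shapes cannot be broadcast. A single-step alternative is to reshape the row indices into a column so that the two index tensors broadcast into a grid of (row, column) pairs, similar to NumPy's `np.ix_`. A minimal sketch, using a deterministic `torch.arange` tensor as a stand-in for `a` (the names `two_step`, `one_step`, and `pairs` are illustrative):

```python
import torch

a = torch.arange(24).reshape(6, 4)
r = torch.tensor([0, 2, 3])
c = torch.tensor([1, 3])

# Two-step selection as above: rows first, then columns
two_step = a[r][:, c]

# One-step equivalent: r[:, None] has shape (3, 1) and c has shape (2,),
# so they broadcast to a (3, 2) grid of (row, col) index pairs
one_step = a[r[:, None], c]

# Two 1-D index tensors of the SAME length are paired element-wise:
# this picks a[0, 1] and a[2, 3], not a 2x2 sub-matrix
pairs = a[torch.tensor([0, 2]), torch.tensor([1, 3])]

print(two_step)
print(pairs)
```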

You can also use the `torch.index_select()` function, which selects entries along a given dimension. Its signature is:

```
torch.index_select(input, dim, index, *, out=None) → Tensor
```

Here is an example that selects rows [0, 2, 3] or columns [0, 2, 3]. Note that `index` must be a 1-D integer tensor; a plain Python list is not accepted:

```python
>>> ix = torch.tensor([0, 2, 3])
>>> torch.index_select(a, 0, ix)    # select rows
tensor([[-0.0457, -0.4924, -0.7026,  0.0567],
        [ 2.2846,  0.5619, -0.1806,  0.9625],
        [ 0.7884,  1.1767,  2.0025, -0.0589]])
>>> torch.index_select(a, 1, ix)    # select columns
tensor([[-0.0457, -0.7026,  0.0567],
        [-0.5104, -0.3003,  0.8491],
        [ 2.2846, -0.1806,  0.9625],
        [ 0.7884,  2.0025, -0.0589],
        [-0.1579, -0.5279,  0.2966],
        [ 0.0946,  0.4907,  1.3673]])
```
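`index_select` can also be chained to select rows and columns together, and it is available as a tensor method (`Tensor.index_select`). A minimal sketch on a deterministic `torch.arange` tensor (the names `ri`, `ci`, and `sub` are illustrative):

```python
import torch

a = torch.arange(24).reshape(6, 4)
ri = torch.tensor([0, 2, 3])  # 1-D integer tensor of row indices
ci = torch.tensor([1, 3])     # 1-D integer tensor of column indices

# Select along dim 0 (rows), then along dim 1 (columns)
sub = a.index_select(0, ri).index_select(1, ci)

# Gives the same result as the two-step indexing approach
same = torch.equal(sub, a[ri][:, ci])

print(sub)
print(same)
```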