I want to split a tensor into a specified number of chunks. Which PyTorch function should I use?

1 Answer


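The torch.chunk() function splits a tensor into a given number of chunks along a given dimension (dim, default 0) and returns them as a tuple of tensors. The chunks are not always the same size: each chunk has ceil(n / chunks) elements along dim, where n is the size of that dimension, and the last chunk holds whatever is left over. Because of this rule, torch.chunk can return fewer chunks than requested; for example, splitting 6 elements into 4 chunks gives three chunks of 2.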
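A small sketch to check the sizing rule described above. The helper expected_chunk_sizes is hypothetical (not part of PyTorch); it just reproduces the ceil(n / chunks) rule so you can compare it with what torch.chunk actually returns on 1-D tensors.

```python
import math
import torch

def expected_chunk_sizes(n, chunks):
    """Hypothetical helper: chunk sizes under the ceil(n / chunks) rule."""
    step = math.ceil(n / chunks)  # size of every chunk except possibly the last
    sizes = []
    remaining = n
    while remaining > 0:
        sizes.append(min(step, remaining))  # last chunk takes what is left
        remaining -= step
    return sizes

for n, chunks in [(6, 3), (5, 2), (6, 4)]:
    t = torch.arange(n)
    actual = [c.numel() for c in torch.chunk(t, chunks)]
    print(n, chunks, actual, expected_chunk_sizes(n, chunks))
# 6 3 [2, 2, 2] [2, 2, 2]
# 5 2 [3, 2] [3, 2]
# 6 4 [2, 2, 2] [2, 2, 2]   <- only 3 chunks, although 4 were requested
```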

For example, if a tensor has five columns and you split it into two chunks along the columns (dim=1), the first chunk will have three columns and the second will have two.
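The five-column case can be checked directly. This is a minimal sketch using a made-up 2-by-5 tensor (any tensor with five columns behaves the same); it also shows that concatenating the chunks along the same dimension gives back the original tensor.

```python
import torch

# Illustrative 2 x 5 tensor: two rows, five columns.
t = torch.arange(10).reshape(2, 5)

left, right = torch.chunk(t, 2, dim=1)  # split the five columns into 2 chunks
print(left.shape)   # torch.Size([2, 3])
print(right.shape)  # torch.Size([2, 2])

# Joining the chunks along the same dim restores the original tensor.
print(torch.equal(torch.cat((left, right), dim=1), t))  # True
```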

Here is an example that splits a 4-by-5 tensor into two chunks, first along dim=0 (the rows) and then along dim=1 (the columns):

>>> import torch
>>> x = torch.tensor([[1, 2, 3, 4, 5], [11, 12, 13, 14, 15], [21, 22, 23, 24, 25], [31, 32, 33, 34, 35]])
>>> x
tensor([[ 1,  2,  3,  4,  5],
        [11, 12, 13, 14, 15],
        [21, 22, 23, 24, 25],
        [31, 32, 33, 34, 35]])
>>> torch.chunk(x, 2, dim=0)  # split the 4 rows into 2 chunks of 2 rows
(tensor([[ 1,  2,  3,  4,  5],
        [11, 12, 13, 14, 15]]), tensor([[21, 22, 23, 24, 25],
        [31, 32, 33, 34, 35]]))
>>> torch.chunk(x, 2, dim=1)  # split the 5 columns into chunks of 3 and 2 columns
(tensor([[ 1,  2,  3],
        [11, 12, 13],
        [21, 22, 23],
        [31, 32, 33]]), tensor([[ 4,  5],
        [14, 15],
        [24, 25],
        [34, 35]]))
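Two related details, sketched below with the same tensor x as in the example. First, per the PyTorch documentation, each chunk is a view of the input, so an in-place change to a chunk also changes x. Second, since torch.chunk can return fewer pieces than requested, torch.tensor_split (available in PyTorch 1.8 and later) may be a better fit if you need exactly k pieces; it spreads the remainder over the first pieces instead. Treat this as an illustration, not the only way to do it.

```python
import torch

x = torch.tensor([[1, 2, 3, 4, 5],
                  [11, 12, 13, 14, 15],
                  [21, 22, 23, 24, 25],
                  [31, 32, 33, 34, 35]])

# Chunks are views: writing into a chunk writes into x as well.
top, bottom = torch.chunk(x, 2, dim=0)
top[0, 0] = 100
print(x[0, 0].item())  # 100

# torch.chunk on 6 elements with 4 requested chunks returns only 3 pieces...
print(len(torch.chunk(torch.arange(6), 4)))  # 3

# ...while torch.tensor_split always returns exactly the requested number.
pieces = torch.tensor_split(torch.arange(6), 4)
print([p.numel() for p in pieces])  # [2, 2, 1, 1]
```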


...