# [PyTorch] How to split a tensor into a given number of chunks

I want to split a tensor into a specified number of chunks. Which PyTorch function should I use?

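The torch.chunk() function splits a tensor into "k" chunks along a given dimension and returns a tuple of tensors. The chunks are not always the same size; that depends on the size of the tensor along the given dimension and on the number of chunks. When the size is not evenly divisible, every chunk has the same size except the last one, which is smaller, and in some cases the function returns fewer chunks than requested.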

For example, if a tensor has five columns and you split it into two chunks along the column dimension, the first chunk will have three columns and the second will have two.
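To make the sizing rule concrete, here is a small sketch. It assumes a 4x5 tensor built with torch.arange (my own illustrative data, not part of the question) and simply prints the width of each chunk. As far as I know, each chunk holds ceil(size / chunks) elements along the chosen dimension, which is why asking for four chunks of five columns yields only three.

```python
import torch

# Illustrative data (an assumption for this sketch): 4 rows, 5 columns.
x = torch.arange(20).reshape(4, 5)

# Two chunks along the column dimension: ceil(5 / 2) = 3 columns, then the remaining 2.
two = torch.chunk(x, 2, dim=1)
two_widths = [c.shape[1] for c in two]
print(two_widths)  # [3, 2]

# Four chunks requested: each holds ceil(5 / 4) = 2 columns,
# so the columns run out after three chunks.
four = torch.chunk(x, 4, dim=1)
four_widths = [c.shape[1] for c in four]
print(len(four), four_widths)  # 3 [2, 2, 1]
```

If you rely on getting exactly "k" pieces back, it is worth checking the length of the returned tuple.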

Here is an example that splits a 4x5 tensor into two chunks, first along the rows (dim=0) and then along the columns (dim=1):

    >>> import torch
    >>> x = torch.tensor([[1, 2, 3, 4, 5], [11, 12, 13, 14, 15], [21, 22, 23, 24, 25], [31, 32, 33, 34, 35]])
    >>> x
    tensor([[ 1,  2,  3,  4,  5],
            [11, 12, 13, 14, 15],
            [21, 22, 23, 24, 25],
            [31, 32, 33, 34, 35]])
    >>> torch.chunk(x, 2, dim=0)
    (tensor([[ 1,  2,  3,  4,  5],
            [11, 12, 13, 14, 15]]), tensor([[21, 22, 23, 24, 25],
            [31, 32, 33, 34, 35]]))
    >>> torch.chunk(x, 2, dim=1)
    (tensor([[ 1,  2,  3],
            [11, 12, 13],
            [21, 22, 23],
            [31, 32, 33]]), tensor([[ 4,  5],
            [14, 15],
            [24, 25],
            [34, 35]]))
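As a follow-up, the sketch below checks the shapes of the chunks from the example, reassembles the tensor with torch.cat, and shows torch.tensor_split, which (in the PyTorch versions that provide it, 1.8 and later as far as I know) should always return exactly the requested number of pieces. The variable names are my own and only illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3, 4, 5],
                  [11, 12, 13, 14, 15],
                  [21, 22, 23, 24, 25],
                  [31, 32, 33, 34, 35]])

# Unpack the two chunks along each dimension.
rows_a, rows_b = torch.chunk(x, 2, dim=0)
cols_a, cols_b = torch.chunk(x, 2, dim=1)

row_shapes = [tuple(rows_a.shape), tuple(rows_b.shape)]  # two rows each
col_shapes = [tuple(cols_a.shape), tuple(cols_b.shape)]  # three columns, then two

# Concatenating the chunks along the same dimension restores the original tensor.
rebuilt = torch.cat((cols_a, cols_b), dim=1)
same = torch.equal(rebuilt, x)

# tensor_split spreads the remainder over the first pieces instead of
# returning fewer pieces, so four pieces come back even for five columns.
exact = torch.tensor_split(x, 4, dim=1)
exact_widths = [p.shape[1] for p in exact]

print(row_shapes, col_shapes, same, exact_widths)
```

Which one to use depends on what matters more: torch.chunk keeps all chunks except the last the same size, while torch.tensor_split guarantees the piece count.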