torch.split()
是 PyTorch 中用于将张量拆分为多个张量的函数。它的语法如下:
torch.split(tensor, split_size_or_sections, dim=0)
其中,tensor
是需要拆分的张量,split_size_or_sections
可以是一个整数,表示在指定维度上平均拆分成几份;也可以是一个整数列表,表示在指定维度上按照列表中给定的大小拆分成多份。dim
是指定拆分维度的参数,默认为 0。
举个例子,假设有一个形状为 (6, 8)
的张量 x
:
import torch x = torch.randn(6, 8)
现在,我们想将它在第 0 维上平均拆分成 3 份,可以使用以下代码:
out = torch.split(x, 2, dim=0)
这将返回一个元组,包含 3 个大小为 (2, 8)
的张量。
如果我们想在第 1 维上按照列表 [2, 2, 4]
的大小拆分成 3 份,可以使用以下代码:
out = torch.split(x, [2, 2, 4], dim=1)
这将返回一个元组,包含 3 个大小分别为 (6, 2)
、(6, 2)
和 (6, 4)
的张量。
需要注意的是,当给定的拆分大小不等于张量在指定维度上的大小时,torch.split()
会引发一个异常。