接下来我们进入到pytorch的形状操作
介绍: 在搭建网络模型时,掌握对张量形状的操作是非常重要的,因为这直接影响到数据如何在网络各层之间传递和处理。网络层与层之间很多都是以不同的 shape 的方式进行表现和运算,我们需要掌握对张量形状的操作,以便能够更好处理网络各层之间的数据连接,确保数据能够顺利地在网络中流动,接下来我们看看几个常用的函数方法🌹
reshape 函数
💎reshape 函数可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状,在后面的神经网络学习时,会经常使用该函数来调节数据的形状,以适配不同网络层之间的数据传递。
import torch tensor = torch.tensor([[1, 2], [3, 4]]) print("原始张量:") print(tensor) reshaped_tensor = tensor.reshape(1, 4) print("修改后的张量:") print(reshaped_tensor)
当第二个参数为-1时,表示自动计算该维度的大小,以使得张量的元素总数不变,这样我们可以免去思考的时间。
import torch tensor = torch.tensor([[1, 2], [3, 4]]) print("原始张量:") print(tensor) reshaped_tensor = tensor.reshape(1, -1) print("修改后的张量:") print(reshaped_tensor) 原始张量: tensor([[1, 2], [3, 4]]) 修改后的张量: tensor([[1, 2, 3, 4]])
transpose 和 permute 函数
💎transpose 函数可以实现交换张量形状的指定维度,permute 函数可以一次交换更多的维度。
- transpose:transpose用于交换张量的两个维度。它并不改变张量中元素的数量,也不改变每个元素的值,只是改变了元素在张量中的排列顺序。在二维情况下,transpose相当于矩阵的转置,将行变为列,列变为行。在多维情况下,它会按照提供的轴(dimension)参数来重新排列维度。
- reshape:reshape则是改变张量的形状,而不改变任何特定的维度位置。你可以使用reshape将张量从一种形状变换到另一种形状,只要两个形状的元素总数相同。这个过程不涉及元素之间的交换,只是调整了元素在内存中的分布,以适应新的形状。在内部实现上,reshape通常通过修改张量的元数据(如shape和strides属性)来实现,而不需要重新排列数据本身。
- 如果你需要保持张量中元素的相对位置不变,仅调整张量的维度顺序,那么应该使用transpose;如果你需要改变张量的整体形状而不关心维度的顺序,reshape会是正确的选择。
data = torch.tensor(np.random.randint(0, 10, [3, 4, 5])) print('data shape:', data.size()) 交换1和2维度 new_data = torch.transpose(data, 1, 2) print('data shape:', new_data.size()) new_data = torch.transpose(data, 0, 1) new_data = torch.transpose(new_data, 1, 2) print('new_data shape:', new_data.size()) new_data = torch.permute(data, [1, 2, 0]) print('new_data shape:', new_data.size()) data shape: torch.Size([3, 4, 5]) data shape: torch.Size([3, 5, 4]) new_data shape: torch.Size([4, 5, 3]) new_data shape: torch.Size([4, 5, 3])
view 和 contigous 函数
💎view 函数也可以用于修改张量的形状,只能用于存储在整块内存中的张量。在 PyTorch 中,有些张量是由不同的数据块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理,如果张量存储在不连续的内存中,使用view函数会导致错误。在这种情况下,可以使用contiguous函数将张量复制到连续的内存中,然后再使用view函数进行形状修改。
import torch tensor = torch.randn(2, 3, 4) reshaped_tensor = tensor.view(6, 4) contiguous_tensor = tensor.contiguous()
使用 transpose 函数修改形状或者 permute 函数的处理之后,就无法使用 view 函数进行形状操作,这时data.contiguous().view(2, 3)即可。
squeeze 和 unsqueeze 函数
💎squeeze函数用于移除张量中维度为1的轴,而unsqueeze函数则用于在指定位置插入一个新的维度。
torch.squeeze(input, dim=None)
input
: 输入张量。dim
: 可选参数,指定要移除的维度。如果不指定,则移除所有大小为1的维度。
import torch A = torch.tensor([[[1, 2, 3], [4, 5, 6]]]) print(A.shape) # 输出:torch.Size([1, 2, 3]) B = torch.squeeze(A) print(B.shape) # 输出:torch.Size([2, 3]) C = torch.squeeze(A, 0) print(C.shape) # 输出:torch.Size([2, 3]) D = torch.squeeze(A, 1) print(D.shape) # 输出:torch.Size([1, 3])
torch.unsqueeze(input, dim)
input
: 输入张量。dim
: 指定要插入新维度的位置。
import torch A = torch.tensor([1, 2, 3]) print(A.shape) # 输出:torch.Size([3]) B = torch.unsqueeze(A, 0) print(B.shape) # 输出:torch.Size([1, 3]) C = torch.unsqueeze(A, 1) print(C.shape) # 输出:torch.Size([3, 1])
🎰小结
- reshape函数可以在保证张量不变的前提下改变数据维度。
- transpose(转置)函数可以实现交换张量形状的指定维度,permute可以一次交换更多维度。
- view函数也可以用于修改张量的形状,但是他要求被转换的张量内存必须连续,所以一般配合contiguous(连续的)函数使用。
- squeeze(挤压)函数和unsqueeze函数可以用来增加或者减少维度。