PyTorch provides a large number of functions and methods for operating on Tensors. Internally, operations such as reshaping, concatenating, and splitting are implemented by adjusting metadata (pointers and strides) over the underlying storage, so you can perform them quickly and conveniently without worrying about the Tensor's physical layout in memory or managing pointers yourself. The methods nelement, ndimension, and size (or the shape attribute) report the number of elements, the number of axes, and the shape of a tensor.
The test code is as follows:
import torch

a = torch.rand(1, 2, 3, 4, 5)
print("number of elements:", a.nelement())  # 120
print("number of axes:", a.ndimension())    # 5
print("shape:", a.shape)                    # torch.Size([1, 2, 3, 4, 5])
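As a quick aside, these methods have common aliases: to my knowledge numel() is equivalent to nelement(), dim() to ndimension(), and size() returns the same torch.Size object as the shape attribute (size can also take an axis index). A minimal check using the same a as above:

print(a.numel())   # 120, same as a.nelement()
print(a.dim())     # 5, same as a.ndimension()
print(a.size())    # torch.Size([1, 2, 3, 4, 5]), same as a.shape
print(a.size(2))   # size of a single axis: 3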
In PyTorch, both reshape and view can be used to change a Tensor's shape. The difference is that view requires the Tensor's underlying memory to be contiguous and raises an error otherwise, while reshape has no such requirement. In addition, view always returns a view that shares storage with the original tensor, whereas reshape may return either a view or a copy, and which one you get is not guaranteed.
The code is as follows:
import torch

a = torch.rand(1, 2, 3, 4, 5)  # 120 elements
b = a.view(2 * 3, 4 * 5)
print(b.shape)                 # torch.Size([6, 20])
c = a.reshape(-1)              # -1 lets PyTorch infer the size
print(c.shape)                 # torch.Size([120])
d = a.reshape(2 * 3, -1)
print(d.shape)                 # torch.Size([6, 20])
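The contiguity difference is easy to see in practice. The sketch below is illustrative only (the tensors x and y are made up for this example): transposing a matrix produces a non-contiguous view, on which view fails but reshape still works.

x = torch.rand(3, 4)
y = x.t()                             # the transpose is a non-contiguous view
print(y.is_contiguous())              # False
try:
    y.view(12)                        # view requires contiguous memory
except RuntimeError as e:
    print("view failed:", e)
print(y.reshape(12).shape)            # torch.Size([12]); reshape copies if needed
print(y.contiguous().view(12).shape)  # making it contiguous first also works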
squeeze and unsqueeze remove and add axes: squeeze removes axes whose size is 1, and unsqueeze inserts a new axis of size 1 at a given position.
b = torch.squeeze(a)  # a has shape (1, 2, 3, 4, 5); the size-1 axis is removed
print(b.shape)        # torch.Size([2, 3, 4, 5])
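For completeness, here is a small sketch of unsqueeze, which the text above mentions but the snippet does not show; x is an arbitrary example tensor:

x = torch.rand(2, 3)
print(x.unsqueeze(0).shape)   # torch.Size([1, 2, 3]): size-1 axis in front
print(x.unsqueeze(-1).shape)  # torch.Size([2, 3, 1]): size-1 axis at the end
y = x.unsqueeze(0)
print(y.squeeze(0).shape)     # torch.Size([2, 3]): squeeze a specific axis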
t and transpose are used to transpose matrices. torch.t is a shorthand that only accepts tensors of at most two dimensions and is equivalent to transpose(input, 0, 1); torch.transpose swaps two given axes and also works on higher-dimensional tensors.
For higher-dimensional Tensors, the permute method can reorder all axes at once.
a = torch.tensor([[2]])
b = torch.tensor([[2, 3]])
print(torch.transpose(a, 1, 0))  # tensor([[2]])
print(torch.t(a))                # same result
print(torch.transpose(b, 1, 0))  # tensor([[2], [3]])
print(torch.t(b))                # same result
############
a = torch.rand((1, 224, 224, 3))  # NHWC layout
print(a.shape)                    # torch.Size([1, 224, 224, 3])
b = a.permute(0, 3, 1, 2)         # reorder to NCHW
print(b.shape)                    # torch.Size([1, 3, 224, 224])
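To illustrate the point about higher dimensions, a brief sketch (shapes chosen arbitrarily): transpose swaps exactly two axes on a tensor of any rank, and permute with the corresponding index order gives the same result.

x = torch.rand(2, 3, 4)
print(torch.transpose(x, 0, 2).shape)  # torch.Size([4, 3, 2]): swap axes 0 and 2
print(x.permute(2, 0, 1).shape)        # torch.Size([4, 2, 3]): arbitrary reorder
print(torch.equal(x.transpose(0, 2), x.permute(2, 1, 0)))  # True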
PyTorch provides cat and stack for joining tensors. cat concatenates along an existing axis dim: the tensors may differ in size along that axis but must match along all other axes. stack joins along a new axis and therefore requires all input tensors to have exactly the same shape.
a = torch.randn(2, 3)
b = torch.randn(3, 3)
c = torch.cat((a, b))           # dim defaults to 0; sizes along dim 0 may differ
d = torch.cat((b, b, b), dim=1)
print(c.shape)                  # torch.Size([5, 3])
print(d.shape)                  # torch.Size([3, 9])
c = torch.stack((b, b), dim=1)  # new axis at position 1
d = torch.stack((b, b), dim=0)  # new axis at position 0
print(c.shape)                  # torch.Size([3, 2, 3])
print(d.shape)                  # torch.Size([2, 3, 3])
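One way to relate the two operations: stacking along a new axis should be equivalent to unsqueezing each tensor at that position and then concatenating along it. A quick sketch reusing the b from above (the name c2 is made up for this example):

s = torch.stack((b, b), dim=0)                           # torch.Size([2, 3, 3])
c2 = torch.cat((b.unsqueeze(0), b.unsqueeze(0)), dim=0)  # same shape and values
print(torch.equal(s, c2))                                # True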
Besides joining, split and chunk are used to break a tensor apart. The difference is that split takes the size of each piece, either an int (every piece that size, except possibly a smaller last one) or a list of sizes, while chunk takes the number of pieces to produce.
a = torch.randn(10, 3)
for x in torch.split(a, [1, 2, 3, 4], dim=0):  # explicit sizes 1, 2, 3, 4
    print(x.shape)
for x in torch.split(a, 4, dim=0):             # pieces of size 4, 4, 2
    print(x.shape)
for x in torch.chunk(a, 4, dim=0):             # 4 pieces: sizes 3, 3, 3, 1
    print(x.shape)
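It is also worth noting that, as far as I know, split and chunk do not copy data: each piece is a view into the original tensor, so in-place changes to a piece show up in the source. A small sketch with a made-up tensor x:

x = torch.zeros(4, 3)
first, second = torch.split(x, 2, dim=0)  # two (2, 3) views of x
first.fill_(1.0)                          # in-place write through the view
print(x[0])                               # tensor([1., 1., 1.]): x changed too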
Finally, the complete test code is as follows:
import torch

a = torch.rand(1, 2, 3, 4, 5)
print("number of elements:", a.nelement())
print("number of axes:", a.ndimension())
print("shape:", a.shape)
###########
b = a.view(2 * 3, 4 * 5)
print(b.shape)
c = a.reshape(-1)
print(c.shape)
d = a.reshape(2 * 3, -1)
print(d.shape)
#############
b = torch.squeeze(a)
print(b.shape)
#############
a = torch.tensor([[2]])
b = torch.tensor([[2, 3]])
print(torch.transpose(a, 1, 0))
print(torch.t(a))
print(torch.transpose(b, 1, 0))
print(torch.t(b))
############
a = torch.rand((1, 224, 224, 3))
print(a.shape)
b = a.permute(0, 3, 1, 2)
print(b.shape)
################
a = torch.randn(2, 3)
b = torch.randn(3, 3)
c = torch.cat((a, b))
d = torch.cat((b, b, b), dim=1)
print(c.shape)
print(d.shape)
c = torch.stack((b, b), dim=1)
d = torch.stack((b, b), dim=0)
print(c.shape)
print(d.shape)
############
a = torch.randn(10, 3)
for x in torch.split(a, [1, 2, 3, 4], dim=0):
    print(x.shape)
for x in torch.split(a, 4, dim=0):
    print(x.shape)
for x in torch.chunk(a, 4, dim=0):
    print(x.shape)