pytorch中有两种方式可以进行维度扩展,第一种是torch.unsqueeze()
,第二种是使用None索引
进行扩展。
例1:在首部添加一个新的维度
这里定义了一个张量维度为【3,5,7】,我们想要在首部增加一个维度,变成【1,3,5,7】,这里我们可以使用None进行扩增维度。
a = torch.randn(3, 5, 7) print(a[None, ...].shape)
torch.Size([1, 3, 5, 7])
None可以理解为占位,该位置代表一个新的维度为1,...
符号代表切片所有维度。
这里还有另外一种写法:
print(a.unsqueeze(dim=0).shape)
例2:跳步扩展维度
现有一个张量为【3,5,7】,我们想把这个张量变成维度为【1,3,1,5,1,7】,如果使用unsqueeze()会比较麻烦,需要调用三次函数,而且每次调用后还需要计算新的新的添加维度位置。
但如果使用None索引机制就会方便许多。
a = a.unsqueeze(1) a = a.unsqueeze(3) a = a.unsqueeze(5)
print(a[None, :, None, :, None, :].shape)
torch.Size([1, 3, 1, 5, 1, 7])