I. Introduction
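Everyone is familiar with torch.Tensor(): all data manipulated in torch is a Tensor. Storage, by contrast, rarely comes up in everyday use, yet it matters a great deal, because a Tensor's actual data lives in its Storage. Below I give a brief, code-driven introduction to Storage.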
Official documentation: Package Reference - torch.Storage - 《PyTorch中文文档》 (PyTorch Chinese documentation)
Where the Storage class lives: torch.Storage()
Official description:
A torch.Storage is a contiguous, one-dimensional array of a single data type.
Every torch.Tensor has a corresponding storage of the same data type.
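As an analogy only, Python's standard array module (not PyTorch) has the two traits named in this definition. Every element shares one type code, here 'f', a C float that is usually 32-bit like torch.float32. The elements also sit in one flat buffer. The sketch below flattens the 3x3 data used later in row-major order, which is the order a Tensor's Storage prints its elements in.

```python
from array import array

# Analogy only: array.array is not torch.Storage, but it is also a flat,
# single-type buffer. Typecode 'f' is a C float (typically 4 bytes).
data1 = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Flatten row by row (row-major): the nesting disappears, only 9 slots remain.
flat = array('f', (x for row in data1 for x in row))

print(len(flat))                  # number of elements
print(flat.itemsize)              # bytes per element
print(len(flat) * flat.itemsize)  # total buffer size in bytes
```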
II. The actual data lives in the Storage
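We can create a new Tensor with torch.Tensor() and specify its shape. A Tensor consists of two parts: a header (metadata) area and a storage area (the Storage). The header records the tensor's shape (size), stride, data type (dtype), and similar information such as its offset into the storage. The actual data is kept in the storage area.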
The code is as follows:
import torch

data1 = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
tensor_data1 = torch.Tensor(data1)
print("tensor_data1.size():", tensor_data1.size())
print("tensor_data1.dtype:", tensor_data1.dtype)
print("tensor_data1.storage():", tensor_data1.storage())
# Output:
# tensor_data1.size(): torch.Size([3, 3])
# tensor_data1.dtype: torch.float32
# tensor_data1.storage(): 1.0
# 2.0
# 3.0
# 4.0
# 5.0
# 6.0
# 7.0
# 8.0
# 9.0
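As shown, the Tensor carries attributes such as size and dtype, while the actual data lives in its Storage. In PyTorch 2.x, Tensor.storage() still works but emits a deprecation warning, because TypedStorage is deprecated. The recommended replacement is tensor.untyped_storage().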
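To make the header/storage split concrete, here is a minimal pure-Python sketch. It shows how the header fields (size, stride, and an offset) locate an element inside the flat storage. contiguous_strides and flat_index are hypothetical helper names for illustration, not PyTorch functions. For the 3x3 tensor above, PyTorch's own tensor_data1.stride() likewise reports (3, 1).

```python
# Hypothetical helpers (not PyTorch API) modelling how a Tensor header's
# size/stride/offset map a multi-dimensional index into the flat Storage.

def contiguous_strides(size):
    """Row-major strides: the last dim steps by 1, each earlier dim by the product of later sizes."""
    strides = []
    step = 1
    for dim in reversed(size):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def flat_index(index, stride, offset=0):
    """Storage position of an index: offset + sum(i_k * stride_k)."""
    return offset + sum(i * s for i, s in zip(index, stride))


storage = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]  # what tensor_data1.storage() holds
size = (3, 3)
stride = contiguous_strides(size)

print(stride)                               # (3, 1)
print(storage[flat_index((1, 2), stride)])  # row 1, column 2 -> 6.0
```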
III. Storage is a contiguous one-dimensional array
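Whatever a Tensor's shape, its torch.Storage is always a contiguous one-dimensional array of a single data type. We can also create a Storage object directly. To compute gradients or run forward and backward passes, however, the Storage still has to be converted into a Tensor.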
The code is as follows:
import torch

data1 = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
storage_data1 = torch.Storage(data1)
# Storage -> Tensor: the result is a 1-D tensor with 9 elements;
# call .view(3, 3) on it to get back the 3x3 shape
tensor_data1 = torch.Tensor(storage_data1)
print("storage_data1:", storage_data1)
# storage_data1: 1.0
# 2.0
# 3.0
# 4.0
# 5.0
# 6.0
# 7.0
# 8.0
# 9.0
# [torch.storage._TypedStorage(dtype=torch.float32, device=cpu) of size 9]
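A transpose shows most clearly that the Storage stays one-dimensional whatever the logical shape. The pure-Python sketch below uses a hypothetical helper, read. It swaps the strides from (3, 1) to (1, 3), which is what PyTorch's t() does for a contiguous 3x3 tensor (the result is then reported as non-contiguous). The logical matrix is transposed, while the flat storage list is never touched.

```python
# Toy model (hypothetical helper, not PyTorch code): a transpose only swaps
# size/stride in the header; the flat storage keeps its original order.

def read(storage, size, stride, offset=0):
    """Materialize the logical 2-D matrix described by (size, stride, offset)."""
    rows, cols = size
    return [[storage[offset + r * stride[0] + c * stride[1]] for c in range(cols)]
            for r in range(rows)]


storage = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

original = read(storage, (3, 3), (3, 1))    # row-major header
transposed = read(storage, (3, 3), (1, 3))  # strides swapped, same storage

print(original)    # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
print(transposed)  # [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]]
print(storage)     # unchanged flat order
```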
IV. Every Tensor has a corresponding Storage
Tensors come in the following types:
class DoubleTensor(Tensor): ...
class FloatTensor(Tensor): ...
class LongTensor(Tensor): ...
class IntTensor(Tensor): ...
class ShortTensor(Tensor): ...
class HalfTensor(Tensor): ...
class CharTensor(Tensor): ...
class ByteTensor(Tensor): ...
class BoolTensor(Tensor): ...
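Each Tensor type has a Storage of the matching type: FloatTensor pairs with FloatStorage, IntTensor with IntStorage, and so on. A Tensor created with torch.Tensor() is a FloatTensor by default, unless the default tensor type has been changed.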
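The pairing between Tensor types and Storage types can be summarized as a lookup table. The names are PyTorch's legacy type names, and the byte sizes follow from each dtype. The helper storage_nbytes is hypothetical. It only shows that a Storage's size in bytes is its element count times the element size.

```python
# Tensor type -> (matching Storage type, dtype, bytes per element).
TYPE_TABLE = {
    "DoubleTensor": ("DoubleStorage", "torch.float64", 8),
    "FloatTensor":  ("FloatStorage",  "torch.float32", 4),
    "LongTensor":   ("LongStorage",   "torch.int64",   8),
    "IntTensor":    ("IntStorage",    "torch.int32",   4),
    "ShortTensor":  ("ShortStorage",  "torch.int16",   2),
    "HalfTensor":   ("HalfStorage",   "torch.float16", 2),
    "CharTensor":   ("CharStorage",   "torch.int8",    1),
    "ByteTensor":   ("ByteStorage",   "torch.uint8",   1),
    "BoolTensor":   ("BoolStorage",   "torch.bool",    1),
}


def storage_nbytes(tensor_type, numel):
    """Hypothetical helper: bytes a flat Storage needs for numel elements."""
    _, _, itemsize = TYPE_TABLE[tensor_type]
    return numel * itemsize


# torch.Tensor(...) defaults to FloatTensor, so a 3x3 tensor needs 9 * 4 bytes.
print(storage_nbytes("FloatTensor", 9))   # 36
print(storage_nbytes("DoubleTensor", 9))  # 72
```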
import torch

data1 = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
tensor_data1 = torch.Tensor(data1)
tensor_data331 = tensor_data1.view(3, 3, 1)
# Two distinct Python objects, so their ids differ
print("id(tensor_data1) == id(tensor_data331):", id(tensor_data1) == id(tensor_data331))
# Compare storages by the address of their data, not by id(): every .storage()
# call returns a new temporary wrapper whose id may be reused once it is freed
print("same storage:", tensor_data1.storage().data_ptr() == tensor_data331.storage().data_ptr())
# id(tensor_data1) == id(tensor_data331): False
# same storage: True
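As shown, I reshaped tensor_data1 and assigned the result to another variable, yet the data itself did not change: both variables refer to the same underlying storage.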
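This sharing can be mimicked with a toy model, assuming a view is nothing more than a second header over the same storage. ToyView is a hypothetical class and does not reflect how PyTorch is implemented internally. In real PyTorch, writing to tensor_data331 likewise changes tensor_data1, because the two share one Storage.

```python
class ToyView:
    """Hypothetical view: a header (size, stride) over a shared flat storage."""

    def __init__(self, storage, size, stride):
        self.storage = storage  # the same list object, never copied
        self.size = size
        self.stride = stride

    def _pos(self, index):
        # Storage position = sum(index_k * stride_k); offset assumed to be 0
        return sum(i * s for i, s in zip(index, self.stride))

    def __getitem__(self, index):
        return self.storage[self._pos(index)]

    def __setitem__(self, index, value):
        self.storage[self._pos(index)] = value


storage = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
view_33 = ToyView(storage, (3, 3), (3, 1))
view_331 = ToyView(storage, (3, 3, 1), (3, 1, 1))  # mimics .view(3, 3, 1)

view_331[1, 2, 0] = 60.0                    # write through the 3-D view
print(view_33[1, 2])                        # 60.0: the 2-D view sees the change
print(view_33.storage is view_331.storage)  # True: one shared storage
```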
Changing the data type deserves a closer look, because a dtype conversion does not reuse the original Storage. In the example below, .int() creates a new tensor with its own int32 Storage and truncates the decimals. Converting that result back with .float() does not recover them. tensor_data1.float() keeps its decimals only because tensor_data1 is already float32, so .float() returns tensor_data1 itself rather than a copy.

Comparing id(x.storage()) values is also misleading. Each .storage() call builds a temporary Python wrapper, and once the first wrapper is freed the second one can receive the same id. The check can therefore print True even for tensors that do not share data. Comparing data_ptr() gives the real answer:
import torch

data1 = [[1.1, 2.1, 3.1], [4.1, 5.1, 6.1], [7.1, 8.1, 9.1]]
tensor_data1 = torch.Tensor(data1)
tensor_data_int = tensor_data1.int()        # new int32 tensor with its own Storage; decimals truncated
tensor_data_f = tensor_data1.float()        # already float32, so this returns tensor_data1 itself
tensor_data_back = tensor_data_int.float()  # int -> float: the lost decimals do not come back
print("tensor_data1:", tensor_data1.storage())
print("tensor_data_int:", tensor_data_int.storage())
print("tensor_data_back:", tensor_data_back.storage())
print("tensor_data_f is tensor_data1:", tensor_data_f is tensor_data1)
# Different dtypes mean different Storages at different addresses
print("int and float share storage:", tensor_data_int.storage().data_ptr() == tensor_data1.storage().data_ptr())
# tensor_data1: 1.100000023841858
# 2.0999999046325684
# 3.0999999046325684
# 4.099999904632568
# 5.099999904632568
# 6.099999904632568
# 7.099999904632568
# 8.100000381469727
# 9.100000381469727
# [torch.storage._TypedStorage(dtype=torch.float32, device=cpu) of size 9]
# tensor_data_int: 1
# 2
# 3
# 4
# 5
# 6
# 7
# 8
# 9
# [torch.storage._TypedStorage(dtype=torch.int32, device=cpu) of size 9]
# tensor_data_back: 1.0
# 2.0
# 3.0
# 4.0
# 5.0
# 6.0
# 7.0
# 8.0
# 9.0
# [torch.storage._TypedStorage(dtype=torch.float32, device=cpu) of size 9]
# tensor_data_f is tensor_data1: True
# int and float share storage: False
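The conversion behavior has a close analogue in Python's standard array module, shown here as an analogy only (not PyTorch). Converting to an integer type allocates a new buffer and truncates toward zero, as .int() does, and converting back cannot restore the decimals. Here buffer_info()[0], the buffer's memory address, plays the role of data_ptr(). Unlike id() on a temporary object, it is meaningful to compare while both objects are alive.

```python
from array import array

# Analogy only (array.array, not torch): a type conversion allocates a new
# buffer, so the int copy is independent of the float original.
floats = array('f', [1.1, 2.1, 3.1])
ints = array('i', (int(x) for x in floats))    # like .int(): truncates toward zero
back = array('f', (float(x) for x in ints))    # like .int().float(): decimals are gone

print(list(ints))  # [1, 2, 3]
print(list(back))  # [1.0, 2.0, 3.0]

# buffer_info()[0] is the buffer's address, a rough analogue of data_ptr()
print(floats.buffer_info()[0] == ints.buffer_info()[0])  # False: separate buffers
```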