Tensor
Tensor,它可以是0维、一维以及多维的数组,你可以将它看作为神经网络界的Numpy,它与Numpy相似,二者可以共享内存,且之间的转换非常方便。
但它们也不相同,最大的区别就是Numpy会把ndarray放在CPU中进行加速运算,而由Torch产生的Tensor会放在GPU中进行加速运算。
对于Tensor,从接口划分,我们大致可分为2类:
1.torch.function:如torch.sum、torch.add等。
2.tensor.function:如tensor.view、tensor.add等。
而从是否修改自身来划分,会分为如下2类:
1.不修改自身数据,如x.add(y),x的数据不变,返回一个新的Tensor。
2.修改自身数据,如x.add_(y),运算结果存在x中,x被修改。
简单的理解就是方法名带不带下划线的问题。
现在,我们来实现2个数组对应位置相加,看看其效果就近如何:
import torch x = torch.tensor([1, 2]) y = torch.tensor([3, 4]) print(x + y) print(x.add(y)) print(x) print(x.add_(y)) print(x)
运行之后,效果如下:
下面,我们来正式讲解Tensor的使用方式。
创建Tensor
与Numpy一样,创建Tensor也有很多的方法,可以自身的函数进行生成,也可以通过列表或者ndarray进行转换,同样也可以指定维度等。具体方法如下表(数组即张量):
函数 | 意义 |
Tensor(*size) | 直接从参数构造,支持list,Numpy数组 |
eye(row,column) | 创建指定行列的二维Tensor |
linspace(start,end,steps) | 从start到end,均匀切分成steps份 |
logspace(start,end,steps) | 从10^start到10^and,均分成steps份 |
rand/randn(*size) | 生成[0,1)均匀分布/标准正态分布的数据 |
ones(*size) | 生成指定shape全为1的张量 |
zeros(*size) | 生成指定shape全为0的张量 |
ones_like(t) | 返回与t的shape相同的张量,且元素全为1 |
zeros_like(t) | 返回与t的shape相同的张量,且元素全为0 |
arange(start,end,step) | 在区间[start,end)上,以间隔step生成一个序列张量 |
from_Numpy(ndarray) | 从ndarray创建一个Tensor |
这里需要注意Tensor有大写的方法也有小写的方法,具体效果我们先来看看代码:
import torch t1 = torch.tensor(1) t2 = torch.Tensor(1) print("值{0},类型{1}".format(t1, t1.type())) print("值{0},类型{1}".format(t2, t2.type()))
运行之后,效果如下:
可以看到,tensor与Tensor生成的值的类型就不同,而且t2(Tensor)返回一个大小为1的张量,而t1(tensor)返回的就是1这个值。
其他示例如下:
import torch import numpy as np t1 = torch.zeros(1, 2) print(t1) t2 = torch.arange(4) print(t2) t3 = torch.linspace(10, 5, 6) print(t3) nd = np.array([1, 2, 3, 4]) t4 = torch.from_numpy(nd) print(t4)
其他例子基本与上面基本差不多,这里不在赘述。
修改Tensor维度
同样的与Numpy一样,Tensor一样有维度的修改函数,具体的方法如下表所示:
函数 | 意义 |
size() | 返回张量的shape,即维度 |
numel(input) | 计算张量的元素个数 |
view(*shape) | 修改张量的shape,但View返回的对象与源张量共享内存,修改一个,另一个也被修改。Reshape将生成新的张量,而不要求源张量是连续的,View(-1)展平数组 |
resize | 类似与view,但在size超出时,会重新分配内存空间 |
item | 若张量为单元素,则返回Python的标量 |
unsqueeze | 在指定的维度增加一个“1” |
squeeze | 在指定的维度压缩一个“1” |
示例代码如下所示:
import torch t1 = torch.Tensor([[1, 2]]) print(t1) print(t1.size()) print(t1.dim()) print(t1.view(2, 1)) print(t1.view(-1)) print(torch.unsqueeze(t1, 0)) print(t1.numel())
运行之后,效果如下: