0. 前言
在深度学习实际应用中,往往涉及到的神经元网络模型都很大,权重参数众多,因此会导致训练epoch次数很多,训练时间长。
如果每次调整非模型相关的参数(训练数据集、优化函数类型、学习率、迭代次数)都要重新训练一次模型,这显然会浪费大量的训练时间。
而且,对于一些成熟的网络模型,已经有前人做过大量的“预训练”,这时如果能基于前人预训练的结果,训练自己的数据集,明显会事半功倍。
因此,加载与保存权重在深度学习实际使用中有很大的必要。
1. Pytorch框架加载与保存权重的方法
①加载权重的方法: .load_state_dict()方法说明:
.load_state_dict()定义:
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',strict: bool = True):
- state_dict :即要加载的权重,通常是一个文件地址;
- strick: 可以理解为等于"True"时是“精确匹配”,要求要加载的权重与要被加载权重的模型完全匹配。
Pytorch源文件注释:
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
*小注释:meth笔误了,应该是mesh,网格
②保存权重的方法:.save()方法说明:
.save()定义:
def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]], pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
- obj:要保存的权重参数;
- f:保存的文件路径
这里仅说明.save()在保存网络模型权重数据上的作用。实际上.save()还有很多应用,例如:保存整个网络,这里不再赘述。
Pytorch源文件注释:
"""Saves an object to a disk file.
See also: `saving-loading-tensors`
Args:
obj: saved object
f: a file-like object (has to implement write and flush) or a string or
os.PathLike object containing a file name
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
2. 实例问题说明
首先说明本次的实例问题:本次要构建的神经元网络为一个“平方网络”,即网络输出数据为输入数据的平方。
网络模型结构:
输入(1)→全连接层(1×5)→Sigmoid激活函数(5)→全连接层(5×5)→Sigmoid激活函数(5)→全连接层(5×1)→输出(1)
训练数据:
输入数据[1, 2, 3, 4, 5];
输出数据[1, 4, 9, 16, 25]
3. 加载权重数据
直接上代码
import torch class LinearNet(torch.nn.Module): def __init__(self, input_size, output_size): super().__init__() self.net = torch.nn.Sequential( torch.nn.Linear(in_features=input_size, out_features= 5, bias=True), torch.nn.Sigmoid(), torch.nn.Linear(in_features= 5, out_features=5, bias=True), torch.nn.Sigmoid(), torch.nn.Linear(in_features=5, out_features=output_size, bias=True) ) def forward(self,x): return self.net(x) square_net = LinearNet(1,1) square_net.load_state_dict(torch.load('weight.pth')) #直接加载已经训练好的权重 if __name__ == '__main__': print(square_net(torch.tensor([3.16],dtype=torch.float32)))
其中weight.pth是我已经训练好的权重数据路径,这里定义好网络模型后,直接加载权重数据,不必关心这个权重是如何训练来的,更不必关系具体权重的值是多少。测试输入为3.16输出为:
tensor([9.9180], grad_fn=<AddBackward0>)
这里要注意的是:因为上面strict默认为True,即为“精确匹配”,这里新构建的网络模型结构必须和权重来源的网络模型结构相同。
4. 保存权重数据
import torch input = torch.tensor([[1],[2],[3],[4],[5]], dtype=torch.float32) output = torch.tensor([[1],[4],[9],[16],[25]], dtype=torch.float32) class LinearNet(torch.nn.Module): def __init__(self, input_size, output_size): super().__init__() self.net = torch.nn.Sequential( torch.nn.Linear(in_features=input_size, out_features= 5, bias=True), torch.nn.Sigmoid(), torch.nn.Linear(in_features= 5, out_features=5, bias=True), torch.nn.Sigmoid(), torch.nn.Linear(in_features=5, out_features=output_size, bias=True) ) def forward(self,x): return self.net(x) Loss = torch.nn.MSELoss() linear_net = LinearNet(1,1) opt = torch.optim.SGD(linear_net.parameters(), lr= 0.003) for k in range(1000): opt.zero_grad() for i in range(len(input)): train_out = linear_net(input[i]) loss = Loss(train_out, output[i]) loss.backward() opt.step() torch.save(linear_net.state_dict(),'weight.pth') #保存.pth权重文件 for keys,values in linear_net.state_dict().items(): #查看权重名称及值 print(keys) print(values) print('************************************************************************') if __name__ == '__main__': print(linear_net(torch.tensor([3.16],dtype=torch.float32)))
这里可以看到具体的训练过程及相关的训练参数,权重保存在'weight.pth'文件中。
可以通过print查看具体的权重数值:
net.0.weight tensor([[-0.8204], [-1.7341], [-0.6987], [ 0.9370], [-1.5558]]) ************************************************************************ net.0.bias tensor([ 0.9285, 2.1061, 1.0247, -2.9221, 7.1159]) ************************************************************************ net.2.weight tensor([[-1.6075, -1.3072, -1.5342, 2.4527, -3.9922], [-0.7101, -1.5125, -0.6791, 2.0325, -2.3406], [-1.1707, -1.6899, -0.9883, 2.9682, -1.5409], [-1.1992, -2.0559, -0.7610, 2.3890, -1.3782], [-1.1274, -1.7907, -1.0860, 2.3549, -3.6847]]) ************************************************************************ net.2.bias tensor([0.4826, 0.7057, 0.9702, 1.0532, 0.4214]) ************************************************************************ net.4.weight tensor([[7.3601, 4.7667, 6.2473, 5.0187, 7.2028]]) ************************************************************************ net.4.bias tensor([-0.2476]) ************************************************************************