在pytorch中,保存神经网络用方法:
torch.save(net, 'net.pkl')
提取神经网络用方法:
torch.load('net.pkl')
保存神经网络有两种方式:
1、保存整个网络
torch.save(net, 'net.pkl')
这种方法能最大程度的保留网络的所有信息,缺点是读取网络时速度稍慢
2、保存网络的状态信息
torch.save(net.state_dict(), 'net_params.pkl')
这种方法只保留网络当前的状态信息,保存和读取速度快,保存的pkl文件体积小,缺点是在读取网络时需要自行先构建网络,否则无法还原信息
示例:
import torch import torch.nn.functional as F from torch.autograd import Variable import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) y = x.pow(2) + 0.2 * torch.rand(x.size()) x, y = Variable(x).cuda(), Variable(y).cuda() # 保存网络 def save(): net = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ).cuda() optimizer = torch.optim.SGD(net.parameters(), lr=0.5) loss_func = torch.nn.MSELoss() for t in range(300): prediction = net(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() plt.figure(1, figsize=(10,3)) plt.subplot(131) plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy()) plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5) # 保存整个网络 torch.save(net, 'net.pkl') # 保存网络当前的状态 torch.save(net.state_dict(), 'net_params.pkl') # 提取整个网络 def restore_net(): net = torch.load('net.pkl').cuda() prediction = net(x) plt.figure(1, figsize=(10, 3)) plt.subplot(132) plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy()) plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5) # 提取网络状态 def restore_params(): net = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ).cuda() net.load_state_dict(torch.load('net_params.pkl')) prediction = net(x) plt.figure(1, figsize=(10, 3)) plt.subplot(133) plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy()) plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5) save() restore_net() restore_params() plt.show()
图一为保存的神经网络,图二、三分别为用不同方法提取的神经网络,可以看到,两种提取方式的结果是一致的