在pytorch中已经帮助我们写好了保存模型的方法,一般有两种方式,方法一是保存整个模型,方法二是只保存模型参数状态。
其实对于原生python也有保存模型的方式,可以使用dump、load来保存。
class Linear(nn.Module): def __init__(self): super().__init__() self.w1 = nn.Parameter(torch.randn(3, 4)) self.b1 = nn.Parameter(torch.randn(1, 3)) self.w2 = nn.Parameter(torch.randn(3, 2)) self.b2 = nn.Parameter(torch.randn(1, 2)) def forward(self, x): x = F.linear(x, self.w1, self.b1) return F.linear(x, self.w2, self.b2)
方法一
对于方法一是将整个模型全部保存,包括模型参数及模型的结构都会保存,所以模型较重,读写速度较慢,而且这种方式容易出错,虽然方法使用简单,但是不推荐使用。
model = Linear() torch.save(model, 'model.pth') new_model = torch.load('model.pth')
方法二
方法二是只保存模型对应的参数,创建好一个新的模型,我们只需要将参数读取到新的模型中即可,这种方法尤为推荐,只不过相对方法一写起来会复杂一点。
torch.save(model.state_dict(), 'model.pth') new_model = Linear() new_model.load_state_dict(torch.load('model.pth'))
注意一点,二者读取参数时,方法一不需要实例化新模型,直接读取model.pth就会返回一个模型,因为方法一会保存模型的结构信息。
对于方法二需要提前定义模型,因为方法二只保存参数,我们需要先实例化一个模型,然后把读取的模型参数加载到模型中。