这里总结如何利用Pytorch定义/修改一个模型,以及如何保存载入模型及参数
1 模型的定义
1.1 模型类的重载
我们通过继承torch.nn.Module来实现自己的类,其中__init__
和forward
函数是必须实现的:
__init__
:初始化模型forward
:向前传播,输入转输出
下面举一个简单的例子:
import torch from torch import nn class Net(nn.Module): def __init__(self): self.conv1 = nn.Conv2d(1, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) return x net = Net() print(net)
结果
Net( (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1)) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) )
上面代码,当执行net=Net()
时,__init__
就初始化了两个卷积。之后我们会输出数据x
,当我们执行net(x)
时进行forward正向传播,计算得到最终的结果。
1.2 封装
上一次介绍了基本模型的原理和实现,我们有时候定义的网络就是基本模块的堆叠,如堆叠n次(卷积,激活函数,BN层)。为了使网络结构看起来更加清晰,需要把不同层的结构封装起来。
有三种办法:Sequential
,ModuleList
和ModuleDict
。下面一一介绍并加以区别:
1.2.1 Sequential
一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
1.下面是顺序添加神经网络模块的代码:
import torch.nn as nn model = nn.Sequential( nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ) print(model)
上述将神经网络模块顺序添加到Sequential容器中,结果如下
Sequential( (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) (1): ReLU() (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) (3): ReLU() )
我们传入的神经网络模块名称自动生成
2.下面是传入以神经网络模块为元素的有序字典的代码
import torch.nn as nn from collections import OrderedDict model = nn.Sequential( OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ]) ) print(model)
结果如下:
Sequential( (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) (relu1): ReLU() (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) (relu2): ReLU() )
神经网络模块的名称是我们自定义的
注意:上面两种方式是不能混用的
1.2.2 ModuleList
Module非常像列表,可以进行append等操作,只是对象变成了神经网络模块。但不同于Sequential可以直接进行正向传播,ModuleList必须通过遍历取得列表中的神经网络模块进行正向传播。
下面举一个例子:
from torch import nn import torch from torch import nn import torch class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.model = nn.ModuleList([nn.Conv2d(10, 10, 2) for i in range(3)]) self.model.append(nn.Conv2d(10, 1, 2)) def forward(self, x): # ModuleList can act as an iterable, or be indexed using ints for i, l in enumerate(self.model): x = self.model[i](x) return x x = torch.full([2, 10, 5, 5], 1.0) model = MyModule() res = model(x) print(res.shape) print(model.model)
结果如下:
torch.Size([2, 1, 1, 1]) ModuleList( (0-2): 3 x Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) (3): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1)) )
模块名称是自动生成的。
1.2.3 ModuleDict
ModuleDict和字典比较像,包含字典的一些方法:
- 1. clear(): 清空ModuleDict
- 2. items(): 返回可迭代的键值对(key-value pairs)
- 3. keys(): 返回字典的键(key)
- 4. values(): 返回字典的值(value)
- 5. pop(): 返回一对键值,并从字典中删除
与ModuleList相似,也不能直接进行正向传播,该方法生成的模块名称是自定义的。
下面举一个例子:
from torch import nn import torch class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.model = nn.ModuleDict({ "con1": nn.Conv2d(10, 10, 2), "con2": nn.Conv2d(10, 10, 2), "con3": nn.Conv2d(10, 10, 2), "con4": nn.Conv2d(10, 10, 2), }) def forward(self, x): # ModuleList can act as an iterable, or be indexed using ints for model in list(self.model.values()): x = model(x) return x x = torch.full([2, 10, 5, 5], 1.0) model = MyModule() res = model(x) print(res.shape) print(model.model)
结果
torch.Size([2, 10, 1, 1]) ModuleDict( (con1): Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) (con2): Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) (con3): Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) (con4): Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) )
2 模型的修改
假设我们定义了如下模型:
from torch import nn import torch from collections import OrderedDict class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.model1 = nn.Sequential( nn.Conv2d(10, 10 ,2), ) self.model2 = nn.Sequential(OrderedDict([ ("Sequential_OrderedDict", nn.Conv2d(10, 10, 2)) ])) self.model3 = nn.ModuleList([ nn.Conv2d(10, 10, 2) ]) self.model4 = nn.ModuleDict({ "ModuleDict": nn.Conv2d(10, 1, 2), }) def forward(self, x): # ModuleList can act as an iterable, or be indexed using ints x = self.model1(x) x = self.model2(x) x = self.model3[0](x) x = self.model4["ModuleDict"](x) return x x = torch.full([2, 10, 5, 5], 1.0) model = MyModule() res = model(x) for i in range(4): eval(f"print(model.model{i+1})".format(i))
结果如下:
Sequential( # model1 (0): Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) ) Sequential( # model2 (Sequential_OrderedDict): Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) ) ModuleList( # model3 (0): Conv2d(10, 10, kernel_size=(2, 2), stride=(1, 1)) ) ModuleDict( # model4 (ModuleDict): Conv2d(10, 1, kernel_size=(2, 2), stride=(1, 1)) )
其中model1到model4用了上面介绍过的不同方法定义,但是一致的是每个神经网络模块前都有一个索引,如model1的序号0,model2的字符串"Sequential_OrderedDict",model3的序号0,model4的字符串"ModuleDict"。我们就可以利用这些索引来改变其中的结构:
比如我们执行:
model.model4["ModuleDict"] = nn.Conv2d(10, 100, 2)
那么再输出模型结构时,model4就会变成:
ModuleDict( (ModuleDict): Conv2d(10, 100, kernel_size=(2, 2), stride=(1, 1)) )
3 模型保存和加载
保存分为两类,一是保存整个模型,二是保存模型权重。
3.1 模型的保存和加载
保存用到torch.save(model, save_dir)函数,其中model是自定义的模型对象,save_dir是保存路径。
以第2小节的代码为例,保存该模型代码如下:
save_dir = "" # 保存路径,自定义 torch.save(model, save_dir)
而加载该模型应该是:
torch.load(save_path) # save_path是保存的模型的路径
3.2 权重的保存和加载
仍以第2节的代码为例,这里就需要先用model.state_dict()
获取参数:
paramters = model.state_dict()
然后用torch.save保存
:
torch.save(paramters, save_path) # save_path是保存路径,自定义
而加载模型,仍然用torch.load
:
paramters = torch.load(save_path) # save_path是权重文件的保存路径
只是后面,我们需要用load_state_dict
将参数赋予模型:
model.load_state_dict(paramters)