# 【从零开始学习深度学习】18. Pytorch中自定义层的几种方法：nn.Module、ParameterList和ParameterDict

### 1 不含模型参数的自定义层

import torch
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs)
def forward(self, x):
return x - x.mean()

layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

tensor([-2., -1.,  0.,  1.,  2.])

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

y = net(torch.rand(4, 8))
y.mean().item()

0.0

### 2 含模型参数的自定义层

ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表，使用的时候可以用索引来访问某个参数，另外也可以使用appendextend在列表后面新增参数。

class MyDense(nn.Module):
def __init__(self):
super(MyDense, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
self.params.append(nn.Parameter(torch.randn(4, 1)))
def forward(self, x):
for i in range(len(self.params)):
x = torch.mm(x, self.params[i])
return x
net = MyDense()
print(net)

MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x1]
)
)

ParameterDict接收一个Parameter实例的字典作为输入然后得到一个参数字典，然后可以按照字典的规则使用了。例如使用update()新增参数，使用keys()返回所有键值，使用items()返回所有键值对等等

class MyDictDense(nn.Module):
def __init__(self):
super(MyDictDense, self).__init__()
self.params = nn.ParameterDict({
'linear1': nn.Parameter(torch.randn(4, 4)),
'linear2': nn.Parameter(torch.randn(4, 1))
})
self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增
def forward(self, x, choice='linear1'):
net = MyDictDense()
print(net)

MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)

x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))

tensor([[1.5082, 1.5574, 2.1651, 1.2409]], grad_fn=<MmBackward>)
tensor([[ 2.2193, -1.6539]], grad_fn=<MmBackward>)

net = nn.Sequential(
MyDictDense(),
MyListDense(),
)
print(net)
print(net(x))

Sequential(
(0): MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
(1): MyListDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
)
tensor([[-101.2394]], grad_fn=<MmBackward>)

### 总结

• 可以通过Module类自定义神经网络中的层，从而可以被重复调用。

|
20天前
|

33 0
|
1月前
|

40 0
|
5天前
|

19 4
|
6天前
|

14 3
|
14天前
|

32 5
|
17天前
|

PAI DLC与其他深度学习框架如TensorFlow或PyTorch的异同
PAI DLC与其他深度学习框架如TensorFlow或PyTorch的异同
20 1
|
27天前
|

OpenCV与AI深度学习之常用AI名词解释学习
AGI：Artificial General Intelligence (通用人工智能)：是指具备与人类同等或超越人类的智能，能够表现出正常人类所具有的所有智能行为。又被称为强人工智能。
30 2
|
6天前
|

13 0
|
17天前
|

【7月更文挑战第3天】 使用Python实现深度学习模型：迁移学习与领域自适应教程
14 0
|
1月前
|

《PyTorch深度学习实践》--3梯度下降算法
《PyTorch深度学习实践》--3梯度下降算法
13 0