1.1 原理
残差结构如下:
作用如下:深度学习网络一个最大的问题是梯度消失和爆炸,以往解决方案则是数据的初始化和正则化,但是这样虽然解决了梯度的问题,深度加深了,却带来了另外的问题,就是网络性能的退化问题,深度加深了,错误率却上升了,而残差用来设计解决退化问题,其同时也解决了梯度问题,更使得网络的性能也提升了。
网络退化:当网络达到一定深度后,模型性能会暂时陷入一个瓶颈很难增加,当网络继续加深后,模型在测试集上的性能反而会下降!这其实就是深度学习退化
1.2 代码
class ResidualBlock(torch.nn.Module): def __init__(self,channels): super(ResidualBlock,self).__init__() self.channels = channels self.conv1 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1) self.conv2 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1) def forward(self, x): y = F.relu(self.conv1(x)) y = self.conv2(y) return F.relu(x+y)