一文详解残差网络

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,1000CU*H 3个月
简介: 残差网络(ResNet)源于2016年的论文《Deep Residual Learning for Image Recognition》,旨在解决深层网络中的梯度消失和爆炸问题。通过引入残差块,即在网络中添加跳跃连接,使得信息可以直接跨过多层传递,从而有效解决了网络加深导致的训练困难。ResNet不仅显著提高了模型性能,还促进了深度学习领域的发展。

[TOC]

残差网络

残差网路的由来

​ 残差网络这一思想的起源是在2016年论文《Deep Residual Learning for Image Recognition》中第一次被提出,目前的引用已经高达3w,现代的深度卷积神经网络大多都是使用此网络模型结构。

1.png

​ 在神经网络的发展期间,深度卷积网络是存在一定瓶颈的----即梯度爆炸梯度弥散。由此,为了解决层次过高带来的卷积网络梯度不稳定的问题,残差网络因此诞生

深度卷积网络的瓶颈

​ 理论上,增加网络层数后,网络可以进行更加复杂的特征提取,效果也会更好,这一点从神经网络的求导操作中即可得出,因此也意味着当模型更深时可以取得更好的效果。

​ 但是VGGGoogLeNet等网络(不了解的同学可以先去看看其他相关文章,本文不再赘述)单纯增加层数的时候遇到了一些瓶颈:简单的只增加卷积层,训练的误差不但没有降低,反而越来越高。

2.png

在CIFAR-10、imageNet等数据集上,单纯增加3x3卷积,训练和测试的误差都变大了。但这并不是过拟合的问题,因为56的网络的训练误差同样很高。这主要是深层网络中存在着梯度消失或者爆炸问题,模型的层数越多则会越难训练。

​ 从另一个角度考虑,加入我们把浅层模型直接拿过来,在此基础上增加新的层(就是把网络变得更深),我们此时可以考虑一种极端情况,新增加的层”什么都不学习“(就是F(x) = x),“什么都不学习”我们称为"恒等映射",用数学表达式成为:F(x) = x。其中F(x)是我们希望拟合的一个网络结构(就是我们的网络),或者是一种映射关系。

残差网络

​ 但是神经网络的各种激活函数并没有能够使得网络做到恒等映射残差的网络的初衷便是尽量使得模型做到恒等映射(即什么都不做),这样就不会因为网络的层数问题产生梯度爆炸和梯度弥散

残差网络原理

​ 上文中说了,残差网络就是一个使得卷积网络恒等映射(什么都不做)的激活函数。即F(x) = x。

​ 我们假设有一个网络的输出是H(x),输入为x,layers所做的一系列操作是F(x)。要让此网络做到F(x) = x,也就是做到H(x) = F(x) + x(F(x) = 0),换句话说:残差网络就是输出等于输入的网络(恒等映射)(如图)。

3.png

​ 从另一个角度分析:我们的卷积网络遇到了梯度爆炸和梯度弥散并且导致了网络的错误率随着层数的增加而增加。想降低错误率,我们要想个办法使得我们的卷积网络的错误率至少不会越来越高,然后再去谈降低错误率,而这正是残差网络的作用。

4.png

残差网络的实现

​ 具体实现上,残差网络有集中构建方式,如上图所示。上图中左侧虚线部分的输出和x直接相加,右侧的x经过1x1卷积层(此处的1x1卷积层就是上文中的恒等映射=>第二张图片)再与虚线部分相加。虚线部分包含了两个3x3卷积层。左侧虚线部分能和输入x直接相加的前提是,虚线里的3x3卷积层没有改变x的维度,可以直接相加。因为如果一旦改变了x的维度,一种办法是使用1x1的卷积层直接相加(肯定不会改变x的维度),另一种是用0填充x

import torch
from torch import nn
from torch.nn import functional as F

class Residual(nn.Module):
    """The Residual block of ResNet."""
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

ResNet-34

​ 基于基础的残差网络,我们可以构建出ResNet网络,论文中给出了基于ImageNet的数据集的ResNet-34和VGG-19的对比图。如下图中最左侧是VGG-19,中间只使用3x3卷积层的网络,右侧使用了残差网络的思想的ResNet-34ResNet-34除了第一层是7x7的卷积层、最后一层是全连接层外,中间全部都是3x3的卷积层

5.png

​ 最右侧是ResNet-34,命名为ResNet-34,是因为网络中7×7卷积层、3×3卷积层和全连接层共34层。在计算这个34层时,论文作者并没有将BatchNorm、ReLU、AvgPool以及Shortcut中的层考虑进去。右侧ResNet-34中的3×3卷积层的颜色不同,共4种颜色。每种颜色表示一个模块,由一组残差基础块组成,只不过残差基础块的数量不同,从上到下依次是[3, 4, 6, 3]个残差基础块。另外,图4右侧,残差基础块中用实线Shortcut表示维度没有变化(可以直接相加);虚线Shortcut表示维度变化了,比如通道数从64变为128,无法直接相加,或者在xx上填充0,或者使用1×1卷积层改变维度。

相关文章
|
机器学习/深度学习 算法 数据挖掘
YOLOv6 | 模型结构与训练策略详细解析
YOLOv6 | 模型结构与训练策略详细解析
2491 0
YOLOv6 | 模型结构与训练策略详细解析
|
关系型数据库 MySQL 数据库
深入探讨MySQL中的幻读现象:原因、影响及解决方案
**导言:** 在数据库领域中,幻读(Phantom Read)是一个常见但容易被忽视的问题。它可能会导致事务的隔离级别无法满足预期,从而引发数据一致性问题。MySQL作为广泛使用的关系型数据库,也不免遇到幻读问题。本文将深入解析MySQL中的幻读现象,探讨其原因、影响以及可能的解决方案。
2429 0
|
9月前
|
机器学习/深度学习 计算机视觉
《深度剖析:残差连接如何攻克深度卷积神经网络的梯度与退化难题》
残差连接通过引入“短路”连接,解决了深度卷积神经网络(CNN)中随层数增加而出现的梯度消失和退化问题。它使网络学习输入与输出之间的残差,而非直接映射,从而加速训练、提高性能,并允许网络学习更复杂的特征。这一设计显著提升了深度学习在图像识别等领域的应用效果。
438 13
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
|
机器学习/深度学习 编解码 自然语言处理
ResNet(残差网络)
【10月更文挑战第1天】
|
机器学习/深度学习 数据采集 数据挖掘
深度学习之地形分类与变化检测
基于深度学习的地形分类与变化检测是遥感领域的一个关键应用,利用深度学习技术从卫星、无人机等地球观测平台获取的遥感数据中自动分析地表特征,并识别地形的变化。这一技术被广泛应用于城市规划、环境监测、灾害预警、土地利用变化分析等领域。
779 2
|
机器学习/深度学习 vr&ar
深度学习笔记(十):深度学习评估指标
关于深度学习评估指标的全面介绍,涵盖了专业术语解释、一级和二级指标,以及各种深度学习模型的性能评估方法。
551 0
深度学习笔记(十):深度学习评估指标
|
域名解析 网络协议
DNS服务工作原理
文章详细介绍了DNS服务的工作原理,包括FQDN的概念、名称解析过程、DNS域名分级策略、根服务器的作用、DNS解析流程中的递归查询和迭代查询,以及为何有时基于IP能访问而基于域名不能访问的原因。
1440 2
DNS服务工作原理
|
机器学习/深度学习 API 算法框架/工具
残差网络(ResNet) -深度学习(Residual Networks (ResNet) – Deep Learning)
残差网络(ResNet) -深度学习(Residual Networks (ResNet) – Deep Learning)
655 0
|
SQL 关系型数据库 MySQL
MySQL模糊查询二三事
在实际应用中,根据需求和实际数据情况,选择合适的模糊查询方法并优化查询模式,是确保查询效率和准确性的关键。复杂的查询模式往往需要详细的测试和调优,以达到最佳的性能与响应时效。
704 4