找到神经网络的全局最小值到底有多难?

简介: 数天前,机器之心发布了 Simon S. Du 等人的论文《Gradient Descent Finds Global Minima of Deep Neural Networks》引起了大家激烈的讨论。对此论文,读者们褒贬不一。此外,机器之心通过读者留言了解到微软研究院 Zeyuan Allen-Zhu、斯坦福 Yuanzhi Li、德州大学奥斯汀分校 Zhao Song(共同一作)稍微早些时候也发布了一篇类似的论文,但有更好的结果。后经沟通联系,机器之心对微软的这篇论文进行了跟进报道,希望能为读者提供更全面的内容参考,更好的理解两篇论文。

在细致解读微软研究院的这篇论文之前,读者们可以先了解下微软这篇论文与 Simon S. Du 等人论文的对比(详见微软这篇论文的第二页)。


最重要的区别是,Simon Du 等人证明了全连接深度网络(非 ResNet)的收敛时间关于层数 L 的依赖是不超过指数级别 2^O(L) 的,而残差网络(ResNet)是多项式级别 poly(L) 的。Simon Du 等人因此给出理论依据,判断 ResNet 的收敛性更好。本文作者指出,这样的推断是逻辑错误的,因为本文证明了全连接网络也同样在多项式级别 poly(L) 时间内收敛(所以 Simon Du 等人文中的「不超过指数」,其实是和残差网络一样的多项式)。也就是说,ResNet 相对于非 ResNet 的优势,实际上有更深层的原因,而不是像 Simons Du 文章里声称的,是指数和多项式的区别。两篇文章的其它的区别包括,Simon Du 等的人的结果,隐藏了其它的可能指数级别的参数,以及 Simon Du 的结果不能处理最常用的 ReLU 激活函数,等等。


微软的这篇论文是基于 [Li-Liang 2018] 在今年 NIPS 2018 上发表的一个深度延伸。Li 和 Liang 证明了只有一个隐藏层的神经网络,在过参数化的情况下可以找到全局最优。至于多层网络,需要开发更多的理论技术。一个具体的例子是这样的。假设训练数据不退化(相对距离为δ),那么如何证明数据传递到了最后一层,也不会发生退化?这篇论文证明了,只要过参数化(引理 4.5),那么样本传递到最后一层,相对距离依然可以有δ\/2。


下面是对微软研究院这篇论文的技术介绍:


论文:A Convergence Theory for Deep Learning via Over-Parameterization


微信图片_20211130151111.jpg


摘要:深层神经网络(DNN)已在许多领域表现出主导性能;自 AlexNet 以来,实践中使用的神经网络越来越深,越来越宽。然而在理论方面,前人大部分的工作在关注为什么我们可以训练只有一个隐藏层的神经网络。多层网络的理论仍然不明确。


在这项工作中,我们证明了为什么常用的算法,比如随机梯度下降(SGD),可以在多项式时间内找到 DNN 训练的全局最优解。我们只做两个假设:输入数据不退化,和网络过参数化。前者意味着输入不存在两个相同的数据点有矛盾的标签;后者意味着隐藏神经元的数量足够多,也就是关于层数 L,以及样本数量 n 都是多项式级别。

作为一个具体示例,在训练集上从随机初始的权重开始,我们证明了在关于 n 和 L 的多项式时间内,SGD 就可以在分类任务中达到了 100%的准确率,也就是找到全局最优解。我们的理论可适用于最常用的 ReLU 激活函数,适用于任何光滑甚至非凸的损失函数。在网络架构方面,我们的理论至少可以适用于全连接网络,卷积网络(CNN)和残差网络(ResNet)。


神经网络在众多机器学习任务中取得了巨大成功。其中一项实验结果表明,通过随机初始化的一阶方法训练的神经网络,具有非常强的拟合训练数据的能力。从容量的角度来看,拟合训练数据的能力可能并不令人惊讶:现代神经网络总是过参数化,它们具有远多于训练样本总数的参数。因此,理论上,只要数据不退化,总会存在实现零训练误差的参数选择。


然而,从优化的角度来看,一阶方法可以在训练数据上找到全局最优解这事情「非常不简单」。大家常用的神经网络通常配备 ReLU 激活函数,这使得神经网络不仅是非凸的,甚至非光滑。与之相对的是,优化理论中,如何找到非凸、非平滑函数的哪怕是一阶、二阶临界点的收敛性也是不明确的 [Burke, 2005],更不用提全局最优解。那么,实际训练中,随机梯度下降法(SGD)是如何在含有 ReLU 的深度神经网络中,收敛到全局最小值的呢?


1638256314(1).png


细节


这篇文章的细节其实可以由如下两个简单的定理和图片概括(文中 Sec 3.1)。假设损失函数是平方拟合(l_2 regression loss)


文中定理 3(没有马鞍点):在一定条件下(比如 SGD 的移动路径上),神经网络目标函数的梯度模长的平方,大于目标函数值本身,除以一个多项式因子:


1638256367.png


定理 3 说了一件很简单的事情,就是只要没有达到全局最优,那么函数梯度就一定大于零,并且函数越大,梯度的模长就越大。换言之,在 SGD 的移动路径上,只要训练损失 (training loss) 不到 0,就不会出现马鞍点,更不会出现局部最小值。这个结果本身就很特殊,因为大部分的非凸问题不满足这个性质,而过度参数化的神经网络,用 SGD 进行训练,却可以保证得到这个性质!


有了定理 3,就可以证明 SGD 收敛了么?并没有,因为如果 SGD 向梯度的反方向移动,为什么函数值会下降?「函数值会下降」在优化理论中对应了光滑性 (smoothness)。传统的优化理论中有很多关于光滑性的定义,但都需要函数至少二阶可导(可惜 ReLU 激活函数并不存在二阶导数)。这篇文章的另一个精髓,在于证明了过度参数化的神经网络满足以下的一个「半光滑性」。


文中定理 4(半光滑性):在一定条件下(比如 SGD 的移动路径上),神经网络目标函数和其一阶近似之间的距离「很小」:


1638256402(1).png


与传统光滑性不同的是,这里不等式的右边有一个关于‖ΔW‖的一阶项,文中说明,这个一阶项会随着神经元数量越来越多,变得越来越小。也就是网络参数越多,会越「光滑」,也就越容易做训练。


为了理解定理 3 和定理 4,我们可以参见文中的图片。当目标函数值在 1.3(并没有到全局最优)的时候,函数的梯度很大(定理 3),并且向梯度方向走,的确可以有效地降低目标函数值(定理 4)。


微信图片_20211130151352.jpg

延伸



本文中关于 SGD 收敛性的证明,停留在了前馈网络(包括 CNN,ResNet 等),那么是否可以扩展到其它的更复杂的深度学习网络呢?例如在自然语言处理应用中广泛使用的递归神经网络(RNN)?作者强调了 RNN 其实是一个比 DNN 更难的问题(第三页)。在两周前,本文的作者将这个问题单独成稿,发表在了 arXiv 上(链接:https://arxiv.org/abs/1810.12065)。


目前提到的多层网络收敛性都是针对训练数据找全局最优。那么如何推广到测试数据集呢?这篇文章并没有涉及,但是在第三页援引了一个同样是这周传到 arXiv 的重要工作 [Allen-Zhu, Li, Liang 2018]:证明了过度参数化的三层神经网络的训练集最优解可以推广到测试集!具体而言,如果数据是由一个未知的三层神经网络产生的,那么使用过参数化的三层神经网络,和 SGD 进行训练,只要多项式级别那么多样本,就可以学出能在测试集上完成比如分类、拟合问题的未知网络。这个结果是对本文的一个很好的补充(链接:https://arxiv.org/abs/1811.04918)。

相关文章
【网站部署】解析二级域名并部署网站(一)
【网站部署】解析二级域名并部署网站(一)
786 0
【网站部署】解析二级域名并部署网站(一)
|
2月前
|
安全 Shell Linux
深入剖析Sudo提权:白帽子的防御视角与审计指南
本文深入解析Linux系统中`sudo`提权的常见手法,从白帽子视角出发,剖析攻击原理并提供实用防御与审计策略,助力加固系统权限安全。
393 1
|
NoSQL 关系型数据库 MySQL
排行榜系统设计:高并发场景下的最佳实践
本文由技术分享者小米带来,详细介绍了如何设计一个高效、稳定且易扩展的排行榜系统。内容涵盖项目背景、技术选型、数据结构设计、基本操作实现、分页显示、持久化与数据恢复,以及高并发下的性能优化策略。通过Redis与MySQL的结合,确保了排行榜的实时性和可靠性。适合对排行榜设计感兴趣的技术人员参考学习。
1739 7
排行榜系统设计:高并发场景下的最佳实践
|
传感器 人工智能 物联网
数字孪生与灾害预测:提升应急响应能力
【10月更文挑战第31天】数字孪生技术通过实时监测、灾害模拟和应急响应优化,显著提升了灾害预测和应急响应能力。本文探讨了其在洪水、地震等自然灾害中的应用,展示了其在提高预警准确性、优化资源配置和提升应急响应效率方面的巨大潜力。
|
数据库 Android开发 开发者
构建高效Android应用:采用Kotlin与Jetpack的实践指南
【4月更文挑战第30天】 随着移动开发技术的不断演进,Android平台提供了多种工具和框架以提升应用性能和开发效率。在本文中,我们将深入探讨如何结合Kotlin语言的简洁性和Android Jetpack组件的强大功能来构建一个既高效又可维护的Android应用。通过分析现代Android应用架构的关键要素,我们将展示如何利用Kotlin的特性以及如何整合Jetpack中的LiveData、ViewModel和Room等组件,以实现响应式编程、数据持久化和生命周期管理。
|
Linux 数据安全/隐私保护
Linux系统忘记密码的三种解决办法
这篇博客介绍了三种在Linux忘记密码时重置登录密码的方法:1) 使用恢复模式,通过控制台界面以管理员权限更改密码;2) 利用Linux Live CD/USB启动,挂载硬盘分区并使用终端更改密码;3) 进入单用户模式,自动以管理员身份登录后重置密码。每个方法都提供了详细步骤,提醒用户在操作前备份重要数据。
|
XML JSON 应用服务中间件
使用Python的requests库发送SOAP请求,错误码415
使用Python的requests库发送SOAP请求,错误码415
464 0
|
存储 网络协议 关系型数据库
微服务架构 | 3.2 Alibaba Nacos 注册中心
Nacos 致力于解决微服务中的统一配置、服务注册与发现等问题。它提供了一组简单易用的特性集,帮助开发者快速实现动态服务发现、服务配置、服务元数据及流量管理;
1220 0
微服务架构 | 3.2 Alibaba Nacos 注册中心
|
机器学习/深度学习 弹性计算 自然语言处理
|
机器学习/深度学习 人工智能 自然语言处理
7.5亿美元做代码转换?一个Facebook TransCoder AI就够了!
代码的迁移和语言转换是一件很困难且昂贵的事情,澳大利亚联邦银行就曾花费5年时间,耗费7.5亿美元将其平台从COBOL转换为Java。而Facebook最近宣称,他们开发的一种神经转换编译器(neural transcompiler),可以将一种高级编程语言(如C ++,Java和Python)转换为另一种,效率飞起!
872 0
7.5亿美元做代码转换?一个Facebook TransCoder AI就够了!