新颖训练方法——用迭代投影算法训练神经网络

简介: 本文介绍了一种利用迭代投影算法对神经网络进行训练的方法,首先介绍了交替投影的基础知识,说明投影方法是寻找非凸优化问题解决方案的一种有效方法;之后介绍了差异图的基础知识,将差异图与一些其他算法相结合使得差分映射算法能够收敛于一个好的解决方案;当投影的情况变多时,介绍了分治算法,最后将迭代投影算法应用到神经网络训练中,给出的例子实验结果表明效果不错。

首发地址:https://yq.aliyun.com/articles/72738


作者介绍Jesse Clark

8293cf57e34cba700a3cd9a473017b739d1d3627

研究相位恢复的物理学家、数据科学家,有着丰富的建设网站与设计手机应用的经验,在创业公司有着丰富的经验,对创业有着极大的热情。

 Github: https://github.com/jn2clark

Linkedin: http://www.linkedin.com/in/j3ss3cl4rk

相位恢复(PR)关心的是在给定幅度信息以及受到实空间限制,找到复值函数(通常在傅立叶空间中)的相位[1]

PR是一个非凸优化问题,已经成为大量工作[1,2,3,4,5,6,9]的主题,并且成为结晶学的支柱,是结构生物学的中坚力量

下面显示的是PR重建过程的一个例子,展示了3D弥散数据(傅里叶幅度)重构实空间3D密度纳米晶体[15]

1582d35b16b7b80825386276994866e970e3f9d5

大多PR问题的成功算法是基于投影的方法,这是受到凸优化投影到凸集上的启发[10]。由于基于投影的方法在PR上取得成功探索能否使用类似的方法训练神经网络。

交替投影

738ab9dc7400d4866d941609053aa318dc5958a0

凸集投影(POCS)是找到凸集之间交点的有用方法。上面显示了一个简单的例子,其中两个凸约束集C1(红色)和C2蓝色)。通过简单迭代映射连续地投影每个集合来找到交集:

3cfaf5f2282fa69d39c94bf4b71bc45e4091ebd4

其中P各自的集合上的投影。投影是幂等PP=P并且是距离最小化;

P(x)=y以至于33e2f72835a10956c0b97d8189890d7e54e18db9最小;

当满足下式的时候,能够发现解决方案:

d54977aa9958b20ea6a9af990053242dc8396a0a

当约束集非凸时,很少能得出一般结论。因此,使用简单的交替投影可能导致局部最小值的停滞。下面展示一个例子,其中集合被设置为非凸,找到交集(全局小值)的能力高度依赖于初始猜测

57ebb6657d27c4cebd050bb9571c2f2a74d94e31

尽管集合凸的情况下失去了保障,但投影方法证明是寻找非凸优化问题解决方案的一种有效方法。例子包括数独、n皇后问题图形着色和相位检索[4,10]

差异图

最成功的非凸投影算法之一是差分图(DM)[4,8],可以写成

0f9d77805a8190f99f5a748ed48def79b29f2c69

其中

067d1fc68d787aeb2f3daa4014e24136eebe1bcb

其中y1和y2被称为估计。一旦达到定点

3197703b6a36567a735776b90268658b410227fd

这意味着两个估计等价于解决方案

0f810da3975723b48ed178ac76bdd27f15cde6b5

差异图通过作为泛化或等价特定超参数,关联了PR文献中许多的不同算法[1,3,6]不于上述形式,简单版本差异图经常被使用

b604bfebb15b24b8b92ad4e571fc3b7017e09b51

这种更简单的版本通常表现良好,并减少每次迭代所需的投影数量(投影的顺序也可以切换)。公式中的2P2-I也被称为反射操作,出现在许多投影算法中[9]

同样的非凸问题如下所示,使用差分映射算法被困在局部最小值中,而是能够逃脱搜索更多的解空间,最后收敛一个解决方案。

cf220dd0037aed4ef5be11f9cde3603f1efcdec2

分治算法

差异图先前被定义为两个投影,那么当有两个以上时会发生什么呢?在这种情况下,定义一个新的迭代X,它是n重复连接[10]

038a1bc7b2887968ff1137a27a40ab0b18ca6d0d

然后定义平均和直积投影;

52d66299f71d58223d288321e042b522957f9e61

其中Pll投影,x是加权和;

a6147e595ed94442a7c9173f154a4a1c8eeb4782

那么许多预测的差异图

9e6026186163a00142f9243b725b779ecfda30d4

更新X:

84ad0d1eb3d398afc1c9df8f8218591372302551

这种方法被称为“分治算法”。下面是一个数独拼图的迭代例子,收敛使用差异图与分治算法

2fe9c5ca429d38df178803b77b92f64c6f08bfca

数独有4个约束每行的数字为1到9,每列的数字为1到9,3x3子方格的数字为1到9,最后数字与部分填充的模板一致。代码实现这个例子

用于训练神经网络的投影

对差异图投影及其在非凸优化中的应用有了解,下一步是对神经网络的训练进行预测。下例仅考虑一个分类任务基本思想是寻找一个正确分类数据的权重向量将数据分解成K个子集

719a9e052cabc837fe4a64508976ed5b7ae3c335

定义一个“投影”权重的投影,使得子集中的所有训练数据被正确分类(或者损失为0)。实际上,使用的是子集的梯度下降来实现投影(基本上是过度拟合的点)。目标是获得能正确分类每个数据子集的权重,并且查找这些集合的交集。

结果

为了测试训练方案(代码,使用标准方法[13]训练一个小型网络,并将其与基于投影的方法进行比较。小型网络使用非常简单的层,大约包含22000个参数; 1个卷积层,8个3x3滤波器2个子采样1个全连接层(激活函数为ReLU),16个节点最后softmax10个输出(MNIST的10类)。使用Glorot uniform[11]初始化权重

下图显示其平均训练和测试损失曲线:

e03afedc47bfb2f7abbc76e75a0e3f06ec99cb0c

训练损失曲线

9682470245de3125f87ee3b2d642e85bacedc693

测试损失函数

从图中可以出效果不错。训练数据被分为大小相同的3组,都被用于投影约束。对于投影而言,需要找到一组最新的权重,使其与先前一组权重的距离最小另外使用梯度下降法进行训练,一旦训练数据的准确度达到99%就终止投影。更新后的权重投影到3组上产生3个新的权重集合,这些集合连接在一起以形成

41b46f89b66cf3bff2e165d3f5ea9ae0f9cf55d7

平均投影可以通过将权重平均得到之后进行复制并连接形成新的向量

1dc98aa873d97c3b500e0ab37fc7f954f4187ad7

根据差异图将这两个投影步骤组合以获得权重的更新方案。除了常规度量外,还可以监视差异图误差来寻找收敛。差异映射误差由下式定义:

a8b67c2bc4dc65fdc7480ba806a2783547f649e2

上式值越低,表明解决方案越。差异图错误达到稳定表明已经找到了一个近似的解决方案。差异图错误通常在稳定突然下降[4],表明找到合适的解决方案。

67c3015a68fe1bacf4378314fa989e94bec8e49e

在上例中,投影是通过训练数据子集上的反复梯度变化定义,本质上是过度拟合的点。在下例中,遍历完一次训练数据终止投影

下面显示的是平均cv测试和训练误差(与上述相同的常规训练相比)

9eb15578194f31947ad2cdaea989b9ab57a6080d

66ba6ca81398eeb0776abdda16df2cdbeeaeb147

478674c51c616a43265b9186c4e12014ab9605d4

从图中可以看到这种方法仍然可行,为什么会这样呢?如果投影操作提前终止,那么想到的一点就是简单地将投影视为一个松弛投影或非最佳投影。凸优化和PR的结果[4,5,7,14]仍然表明,松弛投影或非最佳投影趋于的解决方案。另外,在单遍历投影限制中,可以通过交替投影来恢复传统的基于梯度下降的训练方案(以3组为例)

a5bc188b28b440f651733ed296c1c0f4027fd8f1

最后,常规训练中的参数设置会对网络的结果产生很大的影响,具体参数设置可以查看原文。训练这样的网络并执行提前终止,传统训练方法最终损失和准确度分别为0.0724和97.5%,使用差异图方法的结果分别为0.0628和97.9%。

投影方法扩展

关于投影方法的好处之一是可以轻松实现额外的约束。对于L1正则化而言,可以定义收缩或软阈值操作,如

944752dbcecf0a5c5056b6545c4f54d9e3210974

其他投影可以是卷积核的对称性或权重的直方图约束。

其他注意事项

本文还有很多未回答的问题,并没有深入研究比如最佳集合数是多少投影操作如何工作、近解决方案平均有助于泛化等问题。虽然还有很多问题需要回答,但是使用相位检索和非凸投影方法来重新构建训练得到了一些有趣的结果。

 

本文由北邮@爱可可-爱生活老师推荐,阿里云云栖社区组织翻译。

文章原标题《Training neural networks with iterative projection algorthms》,作者:Jesse Clark,译者:海棠,审阅:tiamo_zn

文章为简译,更为详细的内容,请查看原文

Wechat:269970760 

Email:duanzhch@tju.edu.cn

微信公众号:AI科技时讯

157f33dddfc596ede3681e0a2a0e7068dc288cc1

目录
相关文章
|
18天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
200 55
|
9天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
138 80
|
28天前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
152 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
23天前
|
机器学习/深度学习 数据采集 人工智能
基于Huffman树的层次化Softmax:面向大规模神经网络的高效概率计算方法
层次化Softmax算法通过引入Huffman树结构,将传统Softmax的计算复杂度从线性降至对数级别,显著提升了大规模词汇表的训练效率。该算法不仅优化了计算效率,还在处理大规模离散分布问题上提供了新的思路。文章详细介绍了Huffman树的构建、节点编码、概率计算及基于Gensim的实现方法,并讨论了工程实现中的优化策略与应用实践。
65 15
基于Huffman树的层次化Softmax:面向大规模神经网络的高效概率计算方法
|
2天前
|
机器学习/深度学习 算法
基于遗传优化的双BP神经网络金融序列预测算法matlab仿真
本项目基于遗传优化的双BP神经网络实现金融序列预测,使用MATLAB2022A进行仿真。算法通过两个初始学习率不同的BP神经网络(e1, e2)协同工作,结合遗传算法优化,提高预测精度。实验展示了三个算法的误差对比结果,验证了该方法的有效性。
|
5天前
|
机器学习/深度学习 数据采集 算法
基于PSO粒子群优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
本项目展示了基于PSO优化的CNN-GRU-SAM网络在时间序列预测中的应用。算法通过卷积层、GRU层、自注意力机制层提取特征,结合粒子群优化提升预测准确性。完整程序运行效果无水印,提供Matlab2022a版本代码,含详细中文注释和操作视频。适用于金融市场、气象预报等领域,有效处理非线性数据,提高预测稳定性和效率。
|
3天前
|
算法 网络协议 Python
探秘Win11共享文件夹之Python网络通信算法实现
本文探讨了Win11共享文件夹背后的网络通信算法,重点介绍基于TCP的文件传输机制,并提供Python代码示例。Win11共享文件夹利用SMB协议实现局域网内的文件共享,通过TCP协议确保文件传输的完整性和可靠性。服务器端监听客户端连接请求,接收文件请求并分块发送文件内容;客户端则连接服务器、接收数据并保存为本地文件。文中通过Python代码详细展示了这一过程,帮助读者理解并优化文件共享系统。
|
27天前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
140 30
|
14天前
|
JSON 算法 Java
Nettyの网络聊天室&扩展序列化算法
通过本文的介绍,我们详细讲解了如何使用Netty构建一个简单的网络聊天室,并扩展序列化算法以提高数据传输效率。Netty的高性能和灵活性使其成为实现各种网络应用的理想选择。希望本文能帮助您更好地理解和使用Netty进行网络编程。
34 12
|
15天前
|
域名解析 缓存 网络协议
优化Lua-cURL:减少网络请求延迟的实用方法
优化Lua-cURL:减少网络请求延迟的实用方法

热门文章

最新文章