wGAN如何解决GAN已有问题(附代码实现)

简介:

随着柯洁与AlphaGo的比赛结束以后,大家是不是对人工智能的底层奥秘越来越有兴趣?


深度学习已经在图像分类、检测等诸多领域取得了突破性的成绩。但是它也存在一些问题。


首先,它与传统的机器学习方法一样,通常假设训练数据与测试数据服从同样的分布,或者是在训练数据上的预测结果与在测试数据上的预测结果服从同样的分布,而实际上这两者存在一定的偏差。另一个问题是深度学习的模型(比如卷积神经网络)有时候并不能很好地学到训练数据中的一些特征。深度对抗学习(deep adversarial learning)就是为了解决上述问题而被提出的一种方法。


学习的过程可以看做是我们要得到一个模型,为了建模真实的数据分布,生成器学习生成实际的数据样本,而鉴别器学习确定这些样本是否是真实的。如果这个鉴别器的水平很高,而它无法分清它们之间的区别,那么就说明我们需要的模型具有很好的表达或者预测能力。


非监督学习是通往真正人工智能的方向,本文回顾了从传统机器学习,到wGAN的逻辑发展过程。GAN能自己生成特征、问题、评估函数,是近年来深度学习的一个突破。而wGAN解决了GAN已有的问题,“一个月内改变行业”,是深度学习的最新进展。本文让读者对wGAN的历史发展有个清晰的认识,并提供了wGAN的代码实现,是一篇很好的学习wGAN的入门材料。



对抗学习是深度学习中最火的一个领域。网站arxiv-sanity的最近最流行的研究领域列表上,许多都是对抗学习,本文同样也是一篇讲对抗学习的文章。

 

在这篇文章中,我们主要学习以下三个方面的内容:


  • 为什么我们应该关注对抗学习

  • 生成对抗网络GANs(General Adversarial Networks) 和它面临的挑战

  • 能解决这些挑战的Wasserstein GAN和改进的稳定训练Wasserstein GAN的方法,还包括了代码实现。

 

从传统机器学习到深度学习


我在UIUC上“模拟信号与系统”课程的时候,教授在一开始就信誓旦旦地说:“这个课程将是你们上的最重要的课程,抽象是工程里面最重要的概念。” 


康奈尔大学的课程里面也有“解决复杂问题的方法就是抽象,也就是隐藏细节信息。抽象屏蔽掉无用的细节。为了设计一个复杂系统,你必须找出哪些是你想暴露给其他人的,哪些是你想隐藏起来的。暴露给其他人的部分,其他人可以进行设计。暴露的部分就是抽象。”

 

深度神经网络中的每层就是数据的抽象表示,层和层之间有依赖关系,最终形成一个层次结构。每一层都是上一层的一个更高级的抽象。给定一组原始数据和要解决的问题,然后定义一个目标函数来评估网络输出的答案,最终神经网络就能通过学习得到一个最优的解。

 

因此,特征是神经网络自己学习得来的。但是在传统的机器学习中,特征和算法都是人工定义的。

 

现在的数据的特征、结构、模式都是网络自我学习的,而不是像传统机器学习那样人工定义。所以以前无法实现的AI的算法现在可行了,并且在某些方面超过了人类。

 

从深度学习到深度对抗学习


很多年前,我学习过拳击。我的拳击教练不让新手问问题,说新手不知道问什么问题,连问的问题都是错误的,会得到没用的答案,会专注于错误的东西,越学越错。

 

Robert Half说过“会问问题和会解题一样,都需要一定的水平”

 

对抗学习的奇妙之处在于所有的东西都是从数据中学习得到的,包括要解决的问题,最终的答案以及评估答案的标准—目标函数。传统的深度学习中,是由人来决定要解决什么问题,人来决定用什么目标函数做评估。

 

Deep Mind公司用AlphaGo证明了深度对抗学习的厉害之处。在围棋比赛中,AlphaGo可以自己创造新的下法和招数。这开创了围棋的新纪元,突破了过去几千年的一个瓶颈,达到了新的高度。AlphaGo能做到这点是因为它能自己给自己打分,可以随时计算当前的局势的分数,而不用预先人工定义和预编程。这样,AlphaGo自己和自己下了几百万局的比赛。听起来很像对抗学习吧?



AlphaGo不仅仅是暴力破解,而是真正掌握了围棋比赛,学到了围棋的招式。之所以这样,是因为它没有被人类束缚,既没有得到人类先验的输入,也不受我们对问题域理解的局限。无法想象,当我们把这些成果应用到实际生活中,AI会如何改造农业、医疗等等。但是这一定会发生。

 

生成对抗网络GAN


Richard Feynman说“如果要真正理解一个东西,我们必须要能够把它创造出来。”

 

正是这句话激励着我开始学习GANs。GANs的训练过程就是两个神经网络自己在作对抗,通过对抗不断的学习。当然学习是在原始数据的基础上学习。



生成器通过对原始数据的分布进行建模,学习如何生成近似数据;而判别器用来判断数据是生成器生成的数据还是原始的真实的数据。这样生成器就能重新创造出原始数据的近似数据。我们相信为了能够理解一个东西,我们要能重新创造这个东西,所以GAN是非常有价值的,我们的努力也是值得的。


如果我们能成功使得GAN达到纳什均衡(完美的判别器也不能识别数据到底是真实数据还是生成数据),我们就能够把这个成果应用到几乎任何事情上,并且还能够有最好的性能。


存在的问题


GANs很难优化,并且训练过程不稳定。网络结构必须设计的非常好,生成器和判别器之间必须有个很好的协调,才能使得训练过程收敛。这些问题中,最显著的就是失去样本多样性(mode dropping, 即生成器只从很小一部分的数据集中学习)。还有由于GANs的学习曲线基本没什么意义,因此很难调试。

 

虽然如此,仍然通过GANs得到了最先进的一些成果。但是就是因为这些问题,GANs的应用被限制住了。


解决方法


Alex J. Champandard说“一个月内,传统的训练GANs的方法会被当做黑暗时代的方法”。

 

GANs的训练目标是生成数据和真实数据的分布的距离差的最小化。


最开始使用的是Jensen-Shannon散度。但是,Wasserstein GAN(wGAN)文章在理论和实际两个方面,都证明了最小化推土距离EMD(Earth Mover’s distance)才是解决上述问题的最优方法。当然在实际计算中,由于EMD的计算量过大,因此使用的是EMD的合理的近似值。


为了使得近似值有效,wGAN在判别器(在wGAN中使用了critic一词,和GAN中的discriminator是同一个意思)中使用了权重剪裁(weight clipping)。但是正是权重剪裁导致了上述的问题。

 

后来对wGAN的训练方法进行了改进,它通过在判别器引入梯度惩罚(gradient penalty)使得训练稳定。梯度惩罚只要简单的加到总损失函数中的Wasserstein距离就可以了。



历史上第一次,终于可以训练GAN而几乎不用超参数调优了。其中包括了101层的残差网络和基于离散数据的语言模型。

 

Wasserstein距离的一个优势就是当判别器改进的时候,生成器能收到改进的梯度。但是在使用Jensen-Shannon散度的时候,当判别器改进的时候,产生的梯度消失,生成器无法学习改进。这个也是产生训练不稳定的主要原因。

 

如果想对这个理论有深入理解,我建议读一下下面两个文章:

  • Wasserstein GAN

  • Wasserstein GANs的改进的训练方法

 

随着新的目标函数的引入,我看待GANs的方式也发生了变化:

 

传统的GAN(Jensen-Shannon散度)下,生成器和判别器是竞争关系,如下图。



在wGAN(Wasserstein距离)下,生成器和判别器是协作关系,如下图。



代码实现


结论


对抗学习的网络不受我们对问题域理解的任何限制,没有任何先验知识,网络就是从数据中学习。



原文发布时间为:2017-06-27

本文作者:Michael Dietz

本文来自云栖社区合作伙伴“数据派THU”,了解相关信息可以关注“数据派THU”微信公众号

相关文章
|
2月前
|
机器学习/深度学习 算法 TensorFlow
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
学习率是深度学习中的关键超参数,它影响模型的训练进度和收敛性,过大或过小的学习率都会对网络训练产生负面影响,需要通过适当的设置和调整策略来优化。
556 0
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
|
2月前
|
人工智能 人机交互 智能硬件
从大模型的原理到提示词优化
本文介绍了大语言模型(LLM)的基本概念及其工作原理,重点探讨了AI提示词(Prompt)的重要性和几种有效技巧,包括角色设定、One-shot/Few-shot、任务拆解和思维链。通过实例解析,展示了如何利用这些技巧提升LLM的输出质量和准确性,强调了提供高质量上下文信息对优化LLM表现的关键作用。
88 0
|
7月前
|
计算机视觉 网络架构
【YOLOv8改进】MSBlock : 分层特征融合策略 (论文笔记+引入代码)
YOLO-MS是一个创新的实时目标检测器,通过多尺度构建块(MS-Block)和异构Kernel选择(HKS)协议提升多尺度特征表示能力。它在不依赖预训练权重和大型数据集的情况下,在MS COCO上超越了YOLO-v7和RTMDet,例如YOLO-MS XS版本(4.5M参数,8.7G FLOPs)达到了43%+的AP,比RTMDet高2%+。MS-Block利用分层特征融合和不同大小的卷积,而HKS协议根据网络深度调整Kernel大小,优化多尺度语义信息捕获。此外,YOLO-MS的模块化设计允许其作为即插即用的组件集成到其他YOLO模型中,提升它们的检测性能。
|
7月前
|
前端开发 PyTorch 算法框架/工具
【基础实操】借用torch自带网络进行训练自己的图像数据
【基础实操】借用torch自带网络进行训练自己的图像数据
101 0
【基础实操】借用torch自带网络进行训练自己的图像数据
|
机器学习/深度学习 人工智能 算法
目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头
目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头
937 0
目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头
|
机器学习/深度学习 数据采集 人工智能
头疼!卷积神经网络是什么?CNN结构、训练与优化一文全解
头疼!卷积神经网络是什么?CNN结构、训练与优化一文全解
108 0
|
机器学习/深度学习 人工智能 算法
目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头(一)
目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头(一)
624 0
|
Go 网络架构 计算机视觉
目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头(二)
目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头(二)
1046 0
|
机器学习/深度学习 并行计算 固态存储
YOLO系列 | 一份YOLOX改进的实验报告,并提出更优秀的模型架构组合!
YOLO系列 | 一份YOLOX改进的实验报告,并提出更优秀的模型架构组合!
209 0
|
机器学习/深度学习 编解码 计算机视觉
小目标绝技 | 用最简单的方式完成Yolov5的小目标检测升级!
小目标绝技 | 用最简单的方式完成Yolov5的小目标检测升级!
1185 0