【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: MINE方法中主要使用了两种技术:互信息转为神经网络模型技术和使用对偶KL散度计算损失技术。最有价值的是这两种技术的思想,利用互信息转为神经网络模型技术,可应用到更多的提示结构中,同时损失函数也可以根据具体的任务而使用不同的分布度量算法。

同学你好!本文章于2021年末编写,获得广泛的好评!


故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现


Pytorch深度学习·理论篇(2023版)目录地址为:


CSDN独家 | 全网首发 | Pytorch深度学习·理论篇(2023版)目录


本专栏将通过系统的深度学习实例,从可解释性的角度对深度学习的原理进行讲解与分析,通过将深度学习知识与Pytorch的高效结合,帮助各位新入门的读者理解深度学习各个模板之间的关系,这些均是在Pytorch上实现的,可以有效的结合当前各位研究生的研究方向,设计人工智能的各个领域,是经过一年时间打磨的精品专栏!

https://v9999.blog.csdn.net/article/details/127587345

欢迎大家订阅(2023版)理论篇

以下为2021版原文~~~~

815902569f6a467a99304f9ac1482386.png



1 散度在无监督学习中的应用


在神经网络的损失计算中,最大化和最小化两个数据分布间散度的方法,已经成为无监督模型中有效的训练方法之一。


在无监督模型训练中,不但可以使用K散度JS散度,而且可以使用其他度量分布的方法。f-GAN将度量分布的做法总结起来并找出了其中的规律,使用统一的f散度实现了基于度量分布的方法实现基于度量分布方法训练GAN模型的通用框架。


1.1 f-GAN简述


f-GAN是是一套训练GAN的框架总结,它不是具体的GAN方法,它可以在GAN的训练中很容易实现各种散度的应用,即f-GAN是一个生产GAN模型的“工厂”。


它所生产的GAN模型都有一个共同特点:不进行任何先验假设,对要生成的样本分布使用最小化差异的度量方法,尝试解决一般性的数据样本生成问题(常用于无监督训练)。


1.2 基于f散度的变分散度最小化方法(Variational Divergence Minimization,VDM)


变分散度最小化方法是指通过最小化两个数据分布间的变分距离来训练模型中参数,这是f-GAN所使用的通用方法。在f-GAN中,数据分布间的距离使用f散度来度量。


1.2.1 变分散度最小化方法的适用范围


WGAN模型的训练方法、分自编码的训练方法也属于VDM方法。所有符合f-GAN框架的GAN模型都可以使用VDM方法进行训练。VDM方法适用于GAN模型的训练。


1.2.1 f散度


给定两个分布P、Q,p(x)和q(x)分别是x对应的概率函数,则f散度可以表示为;


fd01ca20f4bf4473b25a64d5b9f9ce45.png


f散度相当于一个散度“工厂”,在使用它之前必须为式中的生成函数f(x)指定具体内容。f散度会根据生成函数f(x)对应的具体内容,生成指定的度量算法。


58a693fe8846455a8d7da684286df366.png


f485f78198894ab8a2991c13abbc13b3.png


2 用Fenchel共轭函数实现f-GAN


2.1 .Fenchel共轭函数的定义(Fenchel conjugate)


Fenchel共轭/凸共轭函数,是指对于每个凸函数且满足下半连续的f(x),都有一个共轭函数f*的定义为:


667e67614c1d45aca4758bda8cf9a976.png


式中的f*(t)是关于t的函数,其中t是变量;dom(f)为f(x)的定义域;max即求当横坐标取t时,纵坐标在多条表达式为{xt-f(x)}的直线中取最大那条直线上所对应的点,如图所示。


10026d94e4c64da0831fb56f49f3769a.png


2.2 Fenchel共扼函数的特性


图8-23中有1条粗线和若干条细直线,这些细直线是由随机采样的几个x值所生成的f(x),粗线是生成函数的共轭函数f*。图8-23中的生成函数是f(x)=|x-1|÷2,该函数对应的算法是总变分(Total Variation,TV)算法。TV算法常用于对图像的去噪和复原。


61a071953d444751afe7dee757c65724.png


2.3 将Fenchel共轭函数运用到f散度中


76564c3cb45a488a9b44fe8bd22a8be2.png

7de4e0d27e184768b73888101fac46d1.png


2.4 用f-GAN生成各种GAN


将图8-22中的具体算法代入到式(8-40)中,便可以得到对应的GAN。有趣的是,对于通过f-GAN计算出来的GAN,可以找到好多已知的GAN模型。这种通过规律的视角来反向看待个体的模型,会使我们对GAN的理解更加透彻。举例如下:


36f693819dcb48389c3968b7f328b738.png


2.5 f-GAN中判别器的激活函数


565fcf07576f463d8cb27fef0f3beef4.png


21048ebcb9a043b389028fae1e789ea5.png

e9d899bc7bfe49b497044b7686b64a85.png

d23ee5c0d5764cfb98c187c74662529d.png


3 互信息神经估计


互信息神经估计(Mutual Information Neural Estimation,MlNE)是一种基于神经网络估计互信息的方法。它通过BP算法进行训练,对高维度的连续随机变量间的互信息进行估计,可以最大化或者最小化互信息,提升生成模型的对抗训练,突破监督学习分类任务的瓶颈。(参见的论文编号为arX:1801.04062,2018)


3.1 将互信息转化为KL散度


在前面介绍过互信息的公式。它可以表示为两个随机变量XY的边缘分布的乘积相对行太、Y联合概率分布的相对熵,即


c9d2807715604902998471d4d0eae7fd.png


这表明E信息可以通过求KL散度的方法进行计算。


3.2 KL散度的两种对偶表示


KL散度具有不对称性,可以将其转化为具有对偶性的表示方式进行计算,基于散度的对偶表示公式有两种。


91fd6575c31d4304b8780f2e41ab9462.png


其中dual f-divergence表示相对于Donsker-Varadhan表示有更低的下界,会导致估计结果更加宽松和不准确。因此,一般使用Donsker-Varadhan表示。


3.3 神经网络中的KL散度的应用


abc5884170394101b33692c6f585c350.png


4 稳定训练GAN的经验与技巧


4.1 GAN训练失败的分类


GAN模型的训练是神经网络中公认的难题。对于众多训练失败的情况,主要分为两情况:模式丢弃(mode dropping)和模式崩塌(mode collapsing)


  • 模式丢弃是指模型生成的模拟样本中,缺乏多样性的问题,即生成的模拟数据是原始数摆集中的一个子集。刚如,MNST数据分布一共有10个分类,而生成器所生成的模拟数据只有其中某个数字。


  • 模式崩塌:生成器所生成的模拟样本非常模湖,质量很低。


4.2 GAN训练技巧


4.2.1 降低学习率


通常,当使用更大的批次训练横型时,可以设置更高的学习率。但是,当模型发生模式透弃情况时,可以尝试降低模型的学习率,并从头开始训练。


4.2.2 标签平滑


标签平滑可以有效地改善训练中模式崩塌的情况。这种方法也非常容易理解和实现,如奥真实图像的标签设置为1,就将它改成一个低一点的值(如0.9)。这个解决方案阻止判别器过于相信分类标签,即不依赖非常有限的一组特征来判断图像是真还是假。


4.2.3 多尺度梯度


这种技术常用于生成较大(1024像素×1024像素)的模拟图像。该方法的处理方式与传统的用于语义分割的U-Net类似。模型更关注的是多尺度梯度,将真实图片通过下采样方式获得的多尺度图片与生成器的多跳连接部分输出的多尺度向量一起送入判别器,形成MSG-GAN架构。(参见的论文编号为arXv:1903.06048,2019)


4.2.4 更换损失函数


在f-GAN系列的训练方法中,由于散度的度量不同,导致训练不稳定性问题的存在。在这种情况下,可以在模型中使用不同的度量方法作为损失函数,找到更适合的解决方法。


4.2.5 借助互信息估计方法


在训练模型时,还可以使用MNE方法来辅助模型训练。


MINE方法是一个通用的训练方法,可以用于各种模型(自编码神经网络、对抗神经网络)。在GAN的训练过程中,使用MINE方法辅助训练模型会有更好的表现,如图8-27所示。


图8-27左侧是GAN模型生成的结果;右侧是使用MINE辅助训练后的生成结果。可以看到,图中右侧的模拟数据(黄色的点)所覆盖的空间与原始数据(蓝色的点)更一致。


76d8236215d14ef693639721ccdb74e5.png


4.3 MINE方法概述


MINE方法中主要使用了两种技术:互信息转为神经网络模型技术和使用对偶KL散度计算损失技术。最有价值的是这两种技术的思想,利用互信息转为神经网络模型技术,可应用到更多的提示结构中,同时损失函数也可以根据具体的任务而使用不同的分布度量算法。【详见下一节实战】



04d400d2964c4cbb8f9dc6b4e21d4eea.png

目录
相关文章
|
20天前
|
机器学习/深度学习 运维 算法
基于机器学习的网络安全威胁检测系统优化策略
【4月更文挑战第21天】 随着网络环境的日趋复杂,传统的安全防御机制在应对日益狡猾的网络攻击时显得力不从心。本文提出了一种结合深度学习与行为分析的网络安全威胁检测系统的优化策略,旨在提高对先进持续威胁(APT)和零日攻击的识别能力。通过构建一个多层次特征提取框架,并引入自适应学习算法,该系统能够实时学习网络行为模式,有效区分正常行为与潜在威胁。同时,文中探讨了模型训练过程中的数据增强、对抗性样本生成以及模型蒸馏等技术的应用,以提升模型的泛化能力和鲁棒性。
|
7天前
|
机器学习/深度学习 人工智能 编解码
【AI 生成式】生成对抗网络 (GAN) 的概念
【5月更文挑战第4天】【AI 生成式】生成对抗网络 (GAN) 的概念
【AI 生成式】生成对抗网络 (GAN) 的概念
|
11天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【Python机器学习专栏】PyTorch在深度学习中的应用
【4月更文挑战第30天】PyTorch是流行的开源深度学习框架,基于动态计算图,易于使用且灵活。它支持张量操作、自动求导、优化器和神经网络模块,适合快速实验和模型训练。PyTorch的优势在于易用性、灵活性、社区支持和高性能(利用GPU加速)。通过Python示例展示了如何构建和训练神经网络。作为一个强大且不断发展的工具,PyTorch适用于各种深度学习任务。
|
11天前
|
机器学习/深度学习 PyTorch TensorFlow
【Python机器学习专栏】循环神经网络(RNN)与LSTM详解
【4月更文挑战第30天】本文探讨了处理序列数据的关键模型——循环神经网络(RNN)及其优化版长短期记忆网络(LSTM)。RNN利用循环结构处理序列依赖,但遭遇梯度消失/爆炸问题。LSTM通过门控机制解决了这一问题,有效捕捉长距离依赖。在Python中,可使用深度学习框架如PyTorch实现LSTM。示例代码展示了如何定义和初始化一个简单的LSTM网络结构,强调了RNN和LSTM在序列任务中的应用价值。
|
11天前
|
机器学习/深度学习 PyTorch TensorFlow
【Python机器学习专栏】卷积神经网络(CNN)的原理与应用
【4月更文挑战第30天】本文介绍了卷积神经网络(CNN)的基本原理和结构组成,包括卷积层、激活函数、池化层和全连接层。CNN在图像识别等领域表现出色,其层次结构能逐步提取特征。在Python中,可利用TensorFlow或PyTorch构建CNN模型,示例代码展示了使用TensorFlow Keras API创建简单CNN的过程。CNN作为强大深度学习模型,未来仍有广阔发展空间。
|
11天前
|
机器学习/深度学习 自然语言处理 语音技术
【Python 机器学习专栏】Python 深度学习入门:神经网络基础
【4月更文挑战第30天】本文介绍了Python在深度学习中应用于神经网络的基础知识,包括神经网络概念、基本结构、训练过程,以及Python中的深度学习库TensorFlow和PyTorch。通过示例展示了如何使用Python实现神经网络,并提及优化技巧如正则化和Dropout。最后,概述了神经网络在图像识别、语音识别和自然语言处理等领域的应用,并强调掌握这些知识对深度学习的重要性。随着技术进步,神经网络的应用将持续扩展,期待更多创新。
|
11天前
|
机器学习/深度学习 PyTorch 算法框架/工具
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
|
11天前
|
机器学习/深度学习 数据采集 安全
基于机器学习的网络安全威胁检测系统
【4月更文挑战第30天】 随着网络技术的迅猛发展,网络安全问题日益凸显。传统的安全防御机制在应对复杂多变的网络攻击时显得力不从心。为了提高威胁检测的准确性和效率,本文提出了一种基于机器学习的网络安全威胁检测系统。该系统通过集成多种数据预处理技术和特征选择方法,结合先进的机器学习算法,能够实时识别并响应各类网络威胁。实验结果表明,与传统方法相比,本系统在检测率、误报率以及处理速度上均有显著提升,为网络安全管理提供了一种新的技术手段。
|
12天前
|
机器学习/深度学习 数据采集 监控
构建高效机器学习模型的策略与实践云端防御:融合云计算与网络安全的未来策略
【4月更文挑战第29天】 在数据驱动的时代,构建一个高效的机器学习模型对于解决复杂问题至关重要。本文将探讨一系列策略和最佳实践,旨在提高机器学习模型的性能和泛化能力。我们将从数据处理的重要性入手,进而讨论模型选择、训练技巧、超参数调优以及模型评估方法。通过这些策略的实施,读者将能够构建出更加健壮、准确的模型,并有效地避免过拟合和欠拟合问题。
|
13天前
|
机器学习/深度学习
GAN网络的代码实现(学习ing)
GAN网络的代码实现(学习ing)

热门文章

最新文章