中国人工智能学会通讯——最优传输理论在机器学习中的应用 1.1 最优传输理论与 WGAN 模型-阿里云开发者社区

开发者社区> 人工智能> 正文

中国人工智能学会通讯——最优传输理论在机器学习中的应用 1.1 最优传输理论与 WGAN 模型

简介:

image

最优传输理论是连接几何和概率的桥梁, 它用几何的方法为概率分布的建模和衡量概 率分布之间的距离提供了强有力的工具。最 近,最优传输理论的概念和方法日益渗透进 机器学习领域,为机器学习原理的解释提供 了新的视角,为机器学习算法的改进提供了 新的指导方向。

本文介绍最优传输理论的基本概念和原 理,解释如何用最优传输理论的框架来表示 概率分布,度量概率分布间的距离,如何降 维逼近,并进一步解释这些手法在机器学习 中的应用,给出机器学习原理和特点的最优 传输理论阐释。

1.1 最优传输理论与 WGAN 模型

1. 生成对抗网络简介

训练模型生成对抗网络 (GAN, Generative Adversarial Networks)[1] 是一个“自 相矛盾”的系统,就是“以己之矛,攻己之盾”, 在矛盾中发展,使得矛更加锋利,盾更加强 韧。这里的矛被称为判别器(Descriminator), 这里的盾被称为生成器(Generator)。如图 1~3 所示。

生成器 G 一般是将一个随机变量(例如 高斯分布,或者均匀分布),通过参数化的 概率生成模型(通常是用一个深度神经网来 进行参数化),进行概率分布的逆变换采样, 从而得到一个生成的概率分布。如图 2 所示。 判别器 D 也通常采用深度卷积神经网。

image
image

我们的目的是要找出给定的真实数据内 部的统计规律,将其概率分布表示为 Pr。为 此制作了一个随机变量生成器 G,G 能够产生 随机变量,其概率分布是 Pg,我们用 Pg 来尽 量接近 Pr。为了区分真实概率分布 Pr 和生成 概率分布 Pg,又制作了一个判别器 D,D 用 来判别一个样本是来自真实数据,还是来自 G 生成的伪造数据。为了使 GAN 中的判别器尽 可能将真实样本判为正例,将生成样本判为负 例,Goodfellow 设计了如下的损失函数(loss function):

image

这里第一项不依赖于生成器 G。 此式也可用 于定义 GAN 中生成器的损失函数。

矛盾的交锋过程如下:在训练过程中, 判别器 D 和生成器 G 交替学习,最终达到纳 什均衡(零和游戏)。在均衡状态,判别器 无法区分真实样本和生成样本,此时的生成 概率分布 Pg,可以被视作是真实概率分布 Pr 的一个良好逼近。如图 1~3 所示。

GAN 具有非常重要的优越性:当真实 数据的概率分布 Pr 不可计算时,依赖于数 据内在解释的传统生成模型无法被直接应 用。但是 GAN 依然可以被使用,这是因 为 GAN 引入了内部对抗的训练机制,能 够逼近难以计算的概率分布。Yann LeCun 一直积极倡导 GAN,因为 GAN 为无监督 学习提供了一个强有力的算法框架,而无 监督学习被广泛认为是通往人工智能的重 要一环。

原始 GAN 形式具有致命缺陷:判别器 越好,生成器的梯度消失越严重。我们固定 生成器 G 来优化判别器 D。考察任意一个样 本 x,其对判别器损失函数的贡献是

image

在这种情况下(判别器最优),如果 Pr 和 Pg 的支撑集合 (support) 交集为零测度,则生成 器的损失函数恒为 0,梯度消失。

本质上,JS 散度给出了概率分布 Pr 、 Pg 之间的差异程度,亦即概率分布间的度 量。我们可以用其他的度量来替换 JS 散度。Wasserstein 距离就是一个好的选择,因为 即便 Pr 、Pg 的交集为零测度,它们之间的 Wasserstein 距离依然非零。这样我们就得到 了 Wasserstein GAN 的模式 [2-3]。Wasserstein 距离的好处在于,即便 Pr、 Pg 两个分布之间 没有重叠,Wasserstein 距离依然能够度量它们的远近。

为此,我们引入最优传输的几何理论 (Optimal Mass Transportation),这个理论可视 化了 W-GAN 的关键概念,例如概率分布、 概率生成模型(生成器)、Wasserstein 距离。 更为重要的,这套理论中所有的概念、原理 都是透明的。例如,对于概率生成模型,理 论上我们可以用最优传输的框架取代深度神 经网络来构造生成器,从而使得机器学习的 黑箱变得透明。

2. 最优传输理论梗概

image
image

蒙 日-安 培 方 程 解 的 存 在 性、 唯 一 性 等价于经典的凸几何中的亚历山大定理 (Alexandrov Theorem)。

image
image

3. W-GAN 中关键概念可视化

W-GAN 模型中,关键的概念包括概率分 布(概率测度)、概率测度间的最优传输映 射(生成器)、概率测度间的Wasserstein距离。 下面我们详细解释每个概念的含义、所对应 的构造方法和相应的几何意义。

概率分布 GAN 模型中有两个至关重要 的概率分布(probability measure),一个 是真实数据的概率分布 Pr;一个是生成数 据的概率分布 Pg。另外,生成器的输入随 机变量可以是任意标准概率分布,例如高 斯分布、均匀分布等。

概率测度可以看成是一种推广的面积(或 者体积)。我们可以用几何变换随意构造一 个概率测度。如图 5 所示,我们用三维扫描 仪获取一张人脸曲面,那么人脸曲面上的面 积就是一个概率测度。我们缩放变换人脸曲 面,使得总面积等于 π;然后,用保角变换 将人脸曲面映射到平面圆盘。如图 5 所示, 保角变换将人脸曲面上的无穷小圆映到平面 上的无穷小圆,但是,小圆的面积发生了变化。 每对小圆的面积比率定义了平面圆盘上的概 率密度函数。

image
image
image
image
image
image
image
image
image
image
image

4. 小结

image

在 W-GAN 模型中,通常生成器和判别 器是用深度神经网络来实现的。根据最优传 输理论,可以用 Briener 势函数来代替深度 神经网络这个黑箱,从而使得整个系统变得透明。在另一层面上,深度神经网络本质上 是在训练概率分布间的传输映射,因此有可 能隐含地在学习最优传输映射,或者等价地 Brenier 势能函数。对这些问题的深入了解, 将有助于我们看穿黑箱。和图6中的例子类似, 图 12 显示了用最优传输映射计算的曲面保面 积参数化。最优传输理论在任意维空间都成立, 图 13 显示了一个三维体的最优传输例子。

image

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

分享:
人工智能
使用钉钉扫一扫加入圈子
+ 订阅

了解行业+人工智能最先进的技术和实践,参与行业+人工智能实践项目

其他文章