在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)

简介:

项目链接:https://github.com/kvmanohar22/ Generative-Models
变分自编码器(VAE)与生成对抗网络(GAN)是复杂分布上无监督学习最具前景的两类方法。
本项目总结了使用变分自编码器(Variational Autoencode,VAE)和生成对抗网络(GAN)对给定数据分布进行建模,并且对比了这些模型的性能。你可能会问:我们已经有了数百万张图像,为什么还要从给定数据分布中生成图像呢?正如 Ian Goodfellow 在 NIPS 2016 教程中指出的那样,实际上有很多应用。我觉得比较有趣的一种是使用 GAN 模拟可能的未来,就像强化学习中使用策略梯度的智能体那样。
本文组织架构:

  • 变分自编码器(VAE)
  • 生成对抗网络(GAN)
  • 训练普通 GAN 的难点
  • 训练细节
  • 在 MNIST 上进行 VAE 和 GAN 对比实验

    • 在无标签的情况下训练 GAN 判别器
    • 在有标签的情况下训练 GAN 判别器
  • 在 CIFAR 上进行 VAE 和 GAN 实验
  • 延伸阅读


VAE

变分自编码器可用于对先验数据分布进行建模。从名字上就可以看出,它包括两部分:编码器和解码器。编码器将数据分布的高级表征映射到数据的低级表征,低级表征叫作本征向量(latent vector)。解码器吸收数据的低级表征,然后输出同样数据的高级表征。
从数学上来讲,让 X 作为编码器的输入,z 作为本征向量,X′作为解码器的输出。
图 1 是 VAE 的可视化图。

1

这与标准自编码器有何不同?关键区别在于我们对本征向量的约束。如果是标准自编码器,那么我们主要关注重建损失(reconstruction loss),即:

2

而在变分自编码器的情况中,我们希望本征向量遵循特定的分布,通常是单位高斯分布(unit Gaussian distribution),使下列损失得到优化:

3

p(z′)∼N(0,I) 中 I 指单位矩阵(identity matrx),q(z∣X) 是本征向量的分布,其中。和由神经网络来计算。KL(A,B) 是分布 B 到 A 的 KL 散度。
由于损失函数中还有其他项,因此存在模型生成图像的精度,同本征向量的分布与单位高斯分布的接近程度之间存在权衡(trade-off)。这两部分由两个超参数λ_1 和λ_2 来控制。

GAN

GAN 是根据给定的先验分布生成数据的另一种方式,包括同时进行的两部分:判别器和生成器。
判别器用于对「真」图像和「伪」图像进行分类,生成器从随机噪声中生成图像(随机噪声通常叫作本征向量或代码,该噪声通常从均匀分布(uniform distribution)或高斯分布中获取)。生成器的任务是生成可以以假乱真的图像,令判别器也无法区分出来。也就是说,生成器和判别器是互相对抗的。判别器非常努力地尝试区分真伪图像,同时生成器尽力生成更加逼真的图像,目的是使判别器将这些图像也分类为「真」图像。
图 2 是 GAN 的典型结构。

4

生成器包括利用代码输出图像的解卷积层。图 3 是生成器的架构图。

5


训练 GAN 的难点

训练 GAN 时我们会遇到一些挑战,我认为其中最大的挑战在于本征向量/代码的采样。代码只是从先验分布中对本征变量的噪声采样。有很多种方法可以克服该挑战,包括:使用 VAE 对本征变量进行编码,学习数据的先验分布。这听起来要好一些,因为编码器能够学习数据分布,现在我们可以从分布中进行采样,而不是生成随机噪声。

训练细节

我们知道两个分布 p(真实分布)和 q(估计分布)之间的交叉熵通过以下公式计算:

6

  • 对于二元分类:


7

  • 对于 GAN,我们假设分布的一半来自真实数据分布,一半来自估计分布,因此:

8

训练 GAN 需要同时优化两个损失函数。
按照极小极大值算法:

9

这里,判别器需要区分图像的真伪,不管图像是否包含真实物体,都没有注意力。当我们在 CIFAR 上检查 GAN 生成的图像时会明显看到这一点。
我们可以重新定义判别器损失目标,使之包含标签。这被证明可以提高主观样本的质量。如:在 MNIST 或 CIFAR-10(两个数据集都有 10 个类别)。
上述 Python 损失函数在 TensorFlow 中的实现:

def VAE_loss(true_images, logits, mean, std):
      """
        Args:
          true_images : batch of input images
          logits      : linear output of the decoder network (the constructed images)
          mean        : mean of the latent code
          std         : standard deviation of the latent code
      """
      imgs_flat    = tf.reshape(true_images, [-1, img_h*img_w*img_d])
      encoder_loss = 0.5 * tf.reduce_sum(tf.square(mean)+tf.square(std)
                     -tf.log(tf.square(std))-1, 1)
      decoder_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                     logits=logits, labels=img_flat), 1)
      return tf.reduce_mean(encoder_loss + decoder_loss)
  def GAN_loss_without_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a column vector)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a column vector)
      """

      true_prob = tf.nn.sigmoid(true_logit)
      fake_prob = tf.nn.sigmoid(fake_logit)
      d_loss = tf.reduce_mean(-tf.log(true_prob)-tf.log(1-fake_prob))
      g_loss = tf.reduce_mean(-tf.log(fake_prob))
      return d_loss, g_loss  
  def GAN_loss_with_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a matrix now)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a matrix now)
      """
      d_true_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.true_logit, dim=1)
      d_fake_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=1-self.labels, logits=self.fake_logit, dim=1)
      g_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.fake_logit, dim=1)

      d_loss = d_true_loss + d_fake_loss      return tf.reduce_mean(d_loss), tf.reduce_mean(g_loss)


在 MNIST 上进行 VAE 与 GAN 对比实验

1. 不使用标签训练判别器
我在 MNIST 上训练了一个 VAE。代码地址:https://github.com/kvmanohar22/Generative-Models
实验使用了 MNIST 的 28×28 图像,下图中:

  • 左侧:数据分布的 64 张原始图像
  • 中间:VAE 生成的 64 张图像
  • 右侧:GAN 生成的 64 张图像

第 1 次迭代:

10

第 2 次迭代:

11

第 3 次迭代:

12

第 4 次迭代:

13

第 100 次迭代:

14

VAE(125)和 GAN(368)训练的最终结果:

15

根据GAN迭代次数生成的gif图:

16

显然,VAE 生成的图像与 GAN 生成的图像相比,前者更加模糊。这个结果在预料之中,因为 VAE 模型生成的所有输出都是分布平均。为了减少图像的模糊度,我们可以使用 L1 损失来代替 L2 损失。
在第一个实验后,作者还将在近期研究使用标签训练判别器,并在 CIFAR 数据集上测试 VAE 与 GAN 的性能。
使用
下载 MNIST 和 CIFAR 数据集
使用 MNIST 训练 VAE 请运行:

python main.py --train --model vae --dataset mnist

使用 MNIST 训练 GAN 请运行:

python main.py --train --model gan --dataset mnist

想要获取完整的命令行选项,请运行:

python main.py --help

该模型由 generate_frq 决定生成图片的频率,默认值为 1。

GAN 在 MNIST 上的训练结果

MNIST 数据集中的样本图像:

17

上方是 VAE 生成的图像,下方的图展示了 GAN 生成图像的过程:

18

原文发布时间为:2017-10-29
本文来自云栖社区合作伙伴“数据派THU”,了解相关信息可以关注“数据派THU”微信公众号

相关文章
|
2月前
|
数据采集 机器学习/深度学习 大数据
行为检测代码(一):超详细介绍C3D架构训练+测试步骤
这篇文章详细介绍了C3D架构在行为检测领域的应用,包括训练和测试步骤,使用UCF101数据集进行演示。
80 1
行为检测代码(一):超详细介绍C3D架构训练+测试步骤
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
106 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
2月前
|
机器学习/深度学习 人工智能 监控
提升软件质量的关键路径:高效测试策略与实践在软件开发的宇宙中,每一行代码都如同星辰般璀璨,而将这些星辰编织成星系的过程,则依赖于严谨而高效的测试策略。本文将引领读者探索软件测试的奥秘,揭示如何通过精心设计的测试方案,不仅提升软件的性能与稳定性,还能加速产品上市的步伐,最终实现质量与效率的双重飞跃。
在软件工程的浩瀚星海中,测试不仅是发现缺陷的放大镜,更是保障软件质量的坚固防线。本文旨在探讨一种高效且创新的软件测试策略框架,它融合了传统方法的精髓与现代技术的突破,旨在为软件开发团队提供一套系统化、可执行性强的测试指引。我们将从测试规划的起点出发,沿着测试设计、执行、反馈再到持续优化的轨迹,逐步展开论述。每一步都强调实用性与前瞻性相结合,确保测试活动能够紧跟软件开发的步伐,及时适应变化,有效应对各种挑战。
|
1月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
76 5
|
27天前
|
并行计算 算法 测试技术
C语言因高效灵活被广泛应用于软件开发。本文探讨了优化C语言程序性能的策略,涵盖算法优化、代码结构优化、内存管理优化、编译器优化、数据结构优化、并行计算优化及性能测试与分析七个方面
C语言因高效灵活被广泛应用于软件开发。本文探讨了优化C语言程序性能的策略,涵盖算法优化、代码结构优化、内存管理优化、编译器优化、数据结构优化、并行计算优化及性能测试与分析七个方面,旨在通过综合策略提升程序性能,满足实际需求。
61 1
|
1月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
97 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
1月前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
95 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
1月前
|
编解码 人工智能 自然语言处理
迈向多语言医疗大模型:大规模预训练语料、开源模型与全面基准测试
【10月更文挑战第23天】Oryx 是一种新型多模态架构,能够灵活处理各种分辨率的图像和视频数据,无需标准化。其核心创新包括任意分辨率编码和动态压缩器模块,适用于从微小图标到长时间视频的多种应用场景。Oryx 在长上下文检索和空间感知数据方面表现出色,并且已开源,为多模态研究提供了强大工具。然而,选择合适的分辨率和压缩率仍需谨慎,以平衡处理效率和识别精度。论文地址:https://www.nature.com/articles/s41467-024-52417-z
52 2
|
1月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
87 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
2月前
|
测试技术
谈谈【软件测试的基础知识,基础模型】
谈谈【软件测试的基础知识,基础模型】
33 5

热门文章

最新文章