TensorFlow 2 和 Keras 高级深度学习:6~10(3)

简介: TensorFlow 2 和 Keras 高级深度学习:6~10(3)

TensorFlow 2 和 Keras 高级深度学习:6~10(2)https://developer.aliyun.com/article/1426951

我们遵循训练过程,我们可以从上一节中的“算法 7.1.1”中调用。“列表 7.1.5”显示了 CycleGAN 训练。 此训练与原始 GAN 之间的次要区别是有两个要优化的判别器。 但是,只有一种对抗模型需要优化。 对于每 2,000 步,生成器将保存预测的源图像和目标图像。 我们将的批量大小设为 32。我们也尝试了 1 的批量大小,但是输出质量几乎相同,并且需要花费更长的时间进行训练(批量为每个图像 43 ms,在 NVIDIA GTX 1060 上批量大小为 32 时,最大大小为每个图像 1 vs 3.6 ms)

“列表 7.1.5”:cyclegan-7.1.1.py

tf.keras中的 CycleGAN 训练例程:

def train_cyclegan(models,
                   data,
                   params,
                   test_params,
                   test_generator):
    """ Trains the CycleGAN. 
    1) Train the target discriminator
    2) Train the source discriminator
    3) Train the forward and backward cyles of 
        adversarial networks
Arguments:
    models (Models): Source/Target Discriminator/Generator,
        Adversarial Model
    data (tuple): source and target training data
    params (tuple): network parameters
    test_params (tuple): test parameters
    test_generator (function): used for generating 
        predicted target and source images
    """
# the models
    g_source, g_target, d_source, d_target, adv = models
    # network parameters
    batch_size, train_steps, patch, model_name = params
    # train dataset
    source_data, target_data, test_source_data, test_target_data\
            = data
titles, dirs = test_params
# the generator image is saved every 2000 steps
    save_interval = 2000
    target_size = target_data.shape[0]
    source_size = source_data.shape[0]
# whether to use patchgan or not
    if patch > 1:
        d_patch = (patch, patch, 1)
        valid = np.ones((batch_size,) + d_patch)
        fake = np.zeros((batch_size,) + d_patch)
    else:
        valid = np.ones([batch_size, 1])
        fake = np.zeros([batch_size, 1])
valid_fake = np.concatenate((valid, fake))
    start_time = datetime.datetime.now()
for step in range(train_steps):
        # sample a batch of real target data
        rand_indexes = np.random.randint(0,
                                         target_size,
                                         size=batch_size)
        real_target = target_data[rand_indexes]
# sample a batch of real source data
        rand_indexes = np.random.randint(0,
                                         source_size,
                                         size=batch_size)
        real_source = source_data[rand_indexes]
        # generate a batch of fake target data fr real source data
        fake_target = g_target.predict(real_source)
# combine real and fake into one batch
        x = np.concatenate((real_target, fake_target))
        # train the target discriminator using fake/real data
        metrics = d_target.train_on_batch(x, valid_fake)
        log = "%d: [d_target loss: %f]" % (step, metrics[0])
# generate a batch of fake source data fr real target data
        fake_source = g_source.predict(real_target)
        x = np.concatenate((real_source, fake_source))
        # train the source discriminator using fake/real data
        metrics = d_source.train_on_batch(x, valid_fake)
        log = "%s [d_source loss: %f]" % (log, metrics[0])
# train the adversarial network using forward and backward
        # cycles. the generated fake source and target 
        # data attempts to trick the discriminators
        x = [real_source, real_target]
        y = [valid, valid, real_source, real_target]
        metrics = adv.train_on_batch(x, y)
        elapsed_time = datetime.datetime.now() - start_time
        fmt = "%s [adv loss: %f] [time: %s]"
        log = fmt % (log, metrics[0], elapsed_time)
        print(log)
        if (step + 1) % save_interval == 0:
            test_generator((g_source, g_target),
                           (test_source_data, test_target_data),
                           step=step+1,
                           titles=titles,
                           dirs=dirs,
                           show=False)
# save the models after training the generators
    g_source.save(model_name + "-g_source.h5")
    g_target.save(model_name + "-g_target.h5")

最后,在使用 CycleGAN 构建和训练函数之前,我们必须执行一些数据准备。 模块cifar10_utils.pyother_ utils.py加载CIFAR10训练和测试数据。 有关这两个文件的详细信息,请参考源代码。 加载后,将训练图像和测试图像转换为灰度,以生成源数据和测试源数据。

“列表 7.1.6”显示了 CycleGAN 如何用于构建和训练用于灰度图像着色的生成器网络(g_target)。 由于 CycleGAN 是对称的,因此我们还构建并训练了第二个生成器网络(g_source),该网络可以将颜色转换为灰度。 训练了两个 CycleGAN 着色网络。 第一种使用标量输出类似于原始 GAN 的判别器,第二种使用2 x 2 PatchGAN。

“列表 7.1.6”:cyclegan-7.1.1.py

CycleGAN 用于着色:

def graycifar10_cross_colorcifar10(g_models=None):
    """Build and train a CycleGAN that can do
        grayscale <--> color cifar10 images
    """
model_name = 'cyclegan_cifar10'
    batch_size = 32
    train_steps = 100000
    patchgan = True
    kernel_size = 3
    postfix = ('%dp' % kernel_size) \
            if patchgan else ('%d' % kernel_size)
data, shapes = cifar10_utils.load_data()
    source_data, _, test_source_data, test_target_data = data
    titles = ('CIFAR10 predicted source images.',
              'CIFAR10 predicted target images.',
              'CIFAR10 reconstructed source images.',
              'CIFAR10 reconstructed target images.')
    dirs = ('cifar10_source-%s' % postfix, \
            'cifar10_target-%s' % postfix)
# generate predicted target(color) and source(gray) images
    if g_models is not None:
        g_source, g_target = g_models
        other_utils.test_generator((g_source, g_target),
                                   (test_source_data, \
                                           test_target_data),
                                   step=0,
                                   titles=titles,
                                   dirs=dirs,
                                   show=True)
        return
# build the cyclegan for cifar10 colorization
    models = build_cyclegan(shapes,
                            "gray-%s" % postfix,
                            "color-%s" % postfix,
                            kernel_size=kernel_size,
                            patchgan=patchgan)
    # patch size is divided by 2^n since we downscaled the input
    # in the discriminator by 2^n (ie. we use strides=2 n times)
    patch = int(source_data.shape[1] / 2**4) if patchgan else 1
    params = (batch_size, train_steps, patch, model_name)
    test_params = (titles, dirs)
    # train the cyclegan
    train_cyclegan(models,
                   data,
                   params,
                   test_params,
                   other_utils.test_generator)

在的下一部分中,我们将检查 CycleGAN 的生成器输出以进行着色。

CycleGAN 的生成器输出

“图 7.1.13”显示 CycleGAN 的着色结果。 源图像来自测试数据集:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UdAhE17z-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_13.png)]

图 7.1.13:使用不同技术进行着色。 显示的是基本事实,使用自编码器的着色(第 3 章,自编码器),使用带有原始 GAN 判别器的 CycleGAN 进行着色,以及使用带有 PatchGAN 判别器的 CycleGAN 进行着色。 彩色效果最佳。 原始彩色照片可以在该书的 GitHub 存储库中找到。

为了进行比较,我们使用第 3 章,“自编码器”中描述的普通自编码器显示了地面真实情况和着色结果。 通常,所有彩色图像在感觉上都是可接受的。 总体而言,似乎每种着色技术都有自己的优点和缺点。 所有着色方法与天空和车辆的正确颜色不一致。

例如,平面背景(第三行,第二列)中的天空为白色。 自编码器没错,但是 CycleGAN 认为它是浅棕色或蓝色。

对于第六行第六列,暗海上的船天空阴沉,但自编码器将其涂成蓝色和蓝色,而 CycleGAN 将其涂成蓝色和白色,而没有 PatchGAN。 两种预测在现实世界中都是有意义的。 同时,使用 PatchGAN 对 CycleGAN 的预测与基本事实相似。 在倒数第二行和第二列上,没有方法能够预测汽车的红色。 在动物身上,CycleGAN 的两种口味都具有接近真实情况的颜色。

由于 CycleGAN 是对称的,因此它还能在给定彩色图像的情况下预测灰度图像。“图 7.1.14”显示了两个 CycleGAN 变体执行的颜色到灰度转换。 目标图像来自测试数据集。 除了某些图像的灰度阴影存在细微差异外,这些预测通常是准确的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9v2eWewg-1681704311659)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_14.png)]

图 7.1.14:颜色(来自图 7.1.9)到 CycleGAN 的灰度转换

要训练 CycleGAN 进行着色,命令是:

python3 cyclegan-7.1.1.py -c

读者可以使用带有 PatchGAN 的 CycleGAN 预训练模型来运行图像转换:

python3 cyclegan-7.1.1.py --cifar10_g_source=cyclegan_cifar10-g_source.h5
--cifar10_g_target=cyclegan_cifar10-g_target.h5

在本节中,我们演示了 CycleGAN 在着色上的一种实际应用。 在下一部分中,我们将在更具挑战性的数据集上训练 CycleGAN。 源域 MNIST 与目标域 SVHN 数据集有很大的不同[1]。

MNIST 和 SVHN 数据集上的 CycleGAN

我们现在要解决一个更具挑战性的问题。 假设我们使用 MNIST 灰度数字作为源数据,并且我们想从 SVHN [1]中借鉴样式,这是我们的目标数据。 每个域中的样本数据显示在“图 7.1.15”中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fkmEFbTV-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_15.png)]

图 7.1.15:两个未对齐数据的不同域。 原始彩色照片可以在该书的 GitHub 存储库中找到。

我们可以重用上一节中讨论的 CycleGAN 的所有构建和训练函数,以执行样式迁移。 唯一的区别是,我们必须添加用于加载 MNIST 和 SVHN 数据的例程。 SVHN 数据集可在这个页面中找到。

我们介绍mnist_svhn_utils.py模块来帮助我们完成此任务。“列表 7.1.7”显示了针对跨域迁移的 CycleGAN 的初始化和训练。

CycleGAN 结构与上一部分相同,不同之处在于我们使用的核大小为 5,因为两个域完全不同。

“列表 7.1.7”:cyclegan-7.1.1.py

CycleGAN 用于 MNIST 和 SVHN 之间的跨域样式迁移:

def mnist_cross_svhn(g_models=None):
    """Build and train a CycleGAN that can do mnist <--> svhn
    """
model_name = 'cyclegan_mnist_svhn'
    batch_size = 32
    train_steps = 100000
    patchgan = True
    kernel_size = 5
    postfix = ('%dp' % kernel_size) \
            if patchgan else ('%d' % kernel_size)
data, shapes = mnist_svhn_utils.load_data()
    source_data, _, test_source_data, test_target_data = data
    titles = ('MNIST predicted source images.',
              'SVHN predicted target images.',
              'MNIST reconstructed source images.',
              'SVHN reconstructed target images.')
    dirs = ('mnist_source-%s' \
            % postfix, 'svhn_target-%s' % postfix)
# generate predicted target(svhn) and source(mnist) images
    if g_models is not None:
        g_source, g_target = g_models
        other_utils.test_generator((g_source, g_target),
                                   (test_source_data, \
                                           test_target_data),
                                   step=0,
                                   titles=titles,
                                   dirs=dirs,
                                   show=True)
        return
# build the cyclegan for mnist cross svhn
    models = build_cyclegan(shapes,
                            "mnist-%s" % postfix,
                            "svhn-%s" % postfix,
                            kernel_size=kernel_size,
                            patchgan=patchgan)
    # patch size is divided by 2^n since we downscaled the input
    # in the discriminator by 2^n (ie. we use strides=2 n times)
    patch = int(source_data.shape[1] / 2**4) if patchgan else 1
    params = (batch_size, train_steps, patch, model_name)
    test_params = (titles, dirs)
    # train the cyclegan
    train_cyclegan(models,
                   data,
                   params,
                   test_params,
                   other_utils.test_generator)

将 MNIST 从测试数据集迁移到 SVHN 的结果显示在“图 7.1.16”中。 生成的图像具有样式的 SVHN,但是数字未完全传送。 例如,在第四行上,数字 3、1 和 3 由 CycleGAN 进行样式化。

但是,在第三行中,不带有和带有 PatchGAN 的 CycleGAN 的数字 9、6 和 6 分别设置为 0、6、01、0、65 和 68:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MMIvwZQw-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_16.png)]

图 7.1.16:测试数据从 MNIST 域到 SVHN 的样式迁移。 原始彩色照片可以在该书的 GitHub 存储库中找到。

向后循环的结果为“图 7.1.17”中所示的。 在这种情况下,目标图像来自 SVHN 测试数据集。 生成的图像具有 MNIST 的样式,但是数字没有正确翻译。 例如,在第一行中,对于不带和带有 PatchGAN 的 CycleGAN,数字 5、2 和 210 分别被样式化为 7、7、8、3、3 和 1:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZTSndytm-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_17.png)]

图 7.1.17:测试数据从 SVHN 域到 MNIST 的样式迁移。 原始彩色照片可以在该书的 GitHub 存储库中找到。

在 PatchGAN 的情况下,假设预测的 MNIST 数字被限制为一位,则输出 1 是可以理解的。 有以某种方式正确的预测,例如在第二行中,不使用 PatchGAN 的 CycleGAN 将 SVHN 数字的最后三列 6、3 和 4 转换为 6、3 和 6。 但是,CycleGAN 两种版本的输出始终是个位数且可识别。

从 MNIST 到 SVHN 的转换中出现的问题称为“标签翻转”[8],其中源域中的数字转换为目标域中的另一个数字。 尽管 CycleGAN 的预测是周期一致的,但它们不一定是语义一致的。 在翻译过程中数字的含义会丢失。

为了解决这个问题, Hoffman [8]引入了一种改进的 CycleGAN,称为循环一致性对抗域自适应CyCADA)。 不同之处在于,附加的语义损失项可确保预测不仅周期一致,而且语义一致。

“图 7.1.18”显示 CycleGAN 在正向循环中重建 MNIST 数字。 重建的 MNIST 数字几乎与源 MNIST 数字相同:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gGwbiPvZ-1681704311660)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_18.png)]

图 7.1.18:带有 MNIST 上的 PatchGAN 的 CycleGAN(源)到 SVHN(目标)的前向周期。 重建的源类似于原始源。 原始彩色照片可以在该书的 GitHub 存储库中找到。

“图 7.1.19”显示了 CycleGAN 在向后周期中重构 SVHN 数字的过程:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bvUxt1if-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_07_19.png)]

图 7.1.19:带有 MNIST 上的 PatchGAN 的 CycleGAN 与 SVHN(目标)的反向循环。 重建的目标与原始目标并不完全相似。 原始彩色照片可以在该书的 GitHub 存储库中找到。

在“图 7.1.3”中,CycleGAN 被描述为具有周期一致性。 换句话说,给定源x,CycleGAN 将正向循环中的源重构为x'。 另外,在给定目标y的情况下,CycleGAN 在反向循环中将目标重构为y'

重建了许多目标图像。 有些数字显然是相同的,例如最后两列(3 和 4)中的第二行,而有些数字却是相同的但是模糊的,例如前两列列中的第一行(5 和 2)。 尽管样式仍像第二行一样,但在前两列(从 33 和 6 到 1 以及无法识别的数字)中,有些数字会转换为另一数字。

要将 MNIST 的 CycleGAN 训练为 SVHN,命令为:

python3 cyclegan-7.1.1.py -m

鼓励读者使用带有 PatchGAN 的 CycleGAN 预训练模型来运行图像翻译:

python3 cyclegan-7.1.1.py --mnist_svhn_g_source=cyclegan_mnist_svhn-g_ source.h5 --mnist_svhn_g_target=cyclegan_mnist_svhn-g_target.h5

到目前为止,我们只看到了 CycleGAN 的两个实际应用。 两者都在小型数据集上进行了演示,以强调可重复性的概念。 如本章前面所述,CycleGAN 还有许多其他实际应用。 我们在这里介绍的 CycleGAN 可以作为分辨率更高的图像转换的基础。

2. 总结

在本章中,我们讨论了 CycleGAN 作为可用于图像翻译的算法。 在 CycleGAN 中,源数据和目标数据不一定要对齐。 我们展示了两个示例,灰度 ↔ 颜色MNIST ↔ SVHN ,尽管 CycleGAN 可以执行许多其他可能的图像转换 。

在下一章中,我们将着手另一种生成模型,即变分自编码器VAE)。 VAE 具有类似的学习目标–如何生成新图像(数据)。 他们专注于学习建模为高斯分布的潜在向量。 我们将以有条件的 VAE 和解开 VAE 中的潜在表示形式来证明 GAN 解决的问题中的其他相似之处。

3. 参考

  1. Yuval Netzer et al.: Reading Digits in Natural Images with Unsupervised Feature Learning. NIPS workshop on deep learning and unsupervised feature learning. Vol. 2011. No. 2. 2011 (https://www-cs.stanford.edu/~twangcat/papers/nips2011_housenumbers.pdf).
  2. Zhu-Jun-Yan et al.: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks. 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017 (http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf).
  3. Phillip Isola et al.: Image-to-Image Translation with Conditional Adversarial Networks. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017 (http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf).
  4. Mehdi Mirza and Simon Osindero. Conditional Generative Adversarial Nets. arXiv preprint arXiv:1411.1784, 2014 (https://arxiv.org/pdf/1411.1784.pdf).
  5. Xudong Mao et al.: Least Squares Generative Adversarial Networks. 2017 IEEE International Conference on Computer Vision (ICCV). IEEE, 2017 (http://openaccess.thecvf.com/content_ICCV_2017/papers/Mao_Least_Squares_Generative_ICCV_2017_paper.pdf).
  6. Chuan Li and Michael Wand. Precomputed Real-Time Texture Synthesis with Markovian Generative Adversarial Networks. European Conference on Computer Vision. Springer, Cham, 2016 (https://arxiv.org/pdf/1604.04382.pdf).
  7. Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation. International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015 (https://arxiv.org/pdf/1505.04597.pdf).
  8. Judy Hoffman et al.: CyCADA: Cycle-Consistent Adversarial Domain Adaptation. arXiv preprint arXiv:1711.03213, 2017 (https://arxiv.org/pdf/1711.03213.pdf).

八、变分自编码器(VAE)

与我们在之前的章节中讨论过的生成对抗网络GAN)类似,变分自编码器VAE)[1] 属于生成模型家族。 VAE 的生成器能够在导航其连续潜在空间的同时产生有意义的输出。 通过潜向量探索解码器输出的可能属性。

在 GAN 中,重点在于如何得出近似输入分布的模型。 VAE 尝试对可解码的连续潜在空间中的输入分布进行建模。 这是 GAN 与 VAE 相比能够生成更真实信号的可能的潜在原因之一。 例如,在图像生成中,GAN 可以生成看起来更逼真的图像,而相比之下,VAE 生成的图像清晰度较差。

在 VAE 中,重点在于潜在代码的变分推理。 因此,VAE 为潜在变量的学习和有效贝叶斯推理提供了合适的框架。 例如,带有解缠结表示的 VAE 可以将潜在代码重用于迁移学习。

在结构上,VAE 与自编码器相似。 它也由编码器(也称为识别或推理模型)和解码器(也称为生成模型)组成。 VAE 和自编码器都试图在学习潜向量的同时重建输入数据。

但是,与自编码器不同,VAE 的潜在空间是连续的,并且解码器本身被用作生成模型。

在前面各章中讨论的 GAN 讨论中,也可以对 VAE 的解码器进行调整。 例如,在 MNIST 数据集中,我们能够指定一个给定的单热向量产生的数字。 这种有条件的 VAE 类别称为 CVAE [2]。 也可以通过在损失函数中包含正则化超参数来解开 VAE 潜向量。 这称为 β-VAE [5]。 例如,在 MNIST 中,我们能够隔离确定每个数字的粗细或倾斜角度的潜向量。 本章的目的是介绍:

  • VAE 的原理
  • 了解重新参数化技巧,有助于在 VAE 优化中使用随机梯度下降
  • 有条件的 VAE(CVAE)和 β-VAE 的原理
  • 了解如何使用tf.keras实现 VAE

我们将从谈论 VAE 的基本原理开始。

1. VAE 原理

在生成模型中,我们经常对使用神经网络来逼近输入的真实分布感兴趣:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rVHL8wEO-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_003.png)] (Equation 8.1.1)

在前面的等式中,θ表示训练期间确定的参数。 例如,在名人面孔数据集的上下文中,这等效于找到可以绘制面孔的分布。 同样,在 MNIST 数据集中,此分布可以生成可识别的手写数字。

在机器学习中,为了执行特定级别的推理,我们有兴趣寻找P[θ](x, z),这是输入x和潜在变量z之间的联合分布。 潜在变量不是数据集的一部分,而是对可从输入中观察到的某些属性进行编码。 在名人面孔的背景下,这些可能是面部表情,发型,头发颜色,性别等。 在 MNIST 数据集中,潜在变量可以表示数字和书写样式。

P[θ](x, z)实际上是输入数据点及其属性的分布。 P[θ](x)可以从边际分布计算得出:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8ycC5rg0-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_010.png)] (Equation 8.1.2)

换句话说,考虑所有可能的属性,我们最终得到描述输入的分布。 在名人面孔中,如果考虑所有面部表情,发型,头发颜色和性别,将恢复描述名人面孔的分布。 在 MNIST 数据集中,如果考虑所有可能的数字,书写风格等,我们以手写数字的分布来结束。

问题在于“公式 8.1.2”很难处理。 该方程式没有解析形式或有效的估计量。 它的参数无法微分。 因此,通过神经网络进行优化是不可行的。

使用贝叶斯定理,我们可以找到“公式 8.1.2”的替代表达式:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Z7xXxJa2-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_011.png)] (Equation 8.1.3)

P(z)z的先验分布。 它不以任何观察为条件。 如果z是离散的,而P[θ](x | z)是高斯分布,则P[θ](x)是高斯的混合。 如果z是连续的,则P[θ](x)是高斯的无限混合。

实际上,如果我们尝试在没有合适的损失函数的情况下建立一个近似P[θ](x | z)的神经网络,它将忽略z得出一个简单的解P[θ](x | z) = P[θ](x)。 因此,“公式 8.1.3”无法为我们提供P[θ](x)的良好估计。 或者,“公式 8.1.2”也可以表示为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IGZWILHU-1681704311661)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_024.png)] (Equation 8.1.4)

但是,P[θ](z | x)也很棘手。 VAE 的目标是在给定输入的情况下,找到一种可预测的分布,该分布易于估计P[θ](z | x),即潜在属性z的条件分布的估计。

变分推理

为了使易于处理,VAE 引入了变化推理模型(编码器):

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iVjUbONa-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_030.png)] (Equation 8.1.5)

Q[φ](z | x)提供了P[θ](z | x)的良好估计。 它既参数化又易于处理。 Q[φ](z | x)可以通过优化参数φ由深度神经网络近似。 通常,Q[φ](z | x)被选择为多元高斯:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jroqc1Kl-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_036.png)] (Equation 8.1.6)

均值μ(x)和标准差σ(x)均由编码器神经网络使用输入数据点计算得出。 对角线矩阵表示z的元素是独立的。

在下一节中,我们将求解 VAE 的核心方程。 核心方程式将引导我们找到一种优化算法,该算法将帮助我们确定推理模型的参数。

核心方程

推理模型Q[φ](z | x)从输入x生成潜向量zQ[φ](z | x)似于自编码器模型中的编码器。 另一方面,从潜在代码z重构输入。 P[θ](x | z)的作用类似于自编码器模型中的解码器。 要估计P[θ](x),我们必须确定其与Q[φ](z | x)P[θ](x | z)的关系。

如果Q[φ](z | x)P[θ](z | x)的估计值,则 Kullback-LeiblerKL)的差异决定了这两个条件密度之间的距离:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uHJuQakp-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_052.png)] (Equation 8.1.7)

使用贝叶斯定理:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XwjhoFzr-1681704311662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_053.png)] (Equation 8.1.8)

在“公式 8.1.7”中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zVLcRvhb-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_054.png)] (Equation 8.1.9)

由于log P[θ](x)不依赖于z ~ Q,因此可能会超出预期。 重新排列“公式 8.1.9”并认识到:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lDcfbONN-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_057.png)],其结果是:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5VjPOOmf-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_058.png)] (Equation 8.1.10)

“公式 8.1.10”是 VAE 的核心。 左侧是项P[θ](x),由于Q[φ](z | x)与真实P[θ](z | x)的距离,我们使误差最小化。 我们可以记得,的对数不会更改最大值(或最小值)的位置。 给定提供P[θ](z | x)良好估计的推断模型,D[KL](Q[φ](z | x) || P[θ](z | x)大约为零。

右边的第一项P[θ](x | z)类似于解码器,该解码器从推理模型中抽取样本以重建输入。

第二个项是另一个距离。 这次是在Q[φ](z | x)和先前的P[θ](z)之间。 “公式 8.1.10”的左侧也称为变异下界证据下界ELBO)。 由于 KL 始终为正,因此 ELBO 是log P[θ](x)的下限。 通过优化神经网络的参数φθ来最大化 ELBO 意味着:

  • 在将z中的x属性编码时,D[KL](Q[φ](z | x) || P[θ](z | x) -> 0或推理模型变得更好。
  • 右边的log P[θ](x | z)最大化了“公式 8.1.10”或解码器模型在从潜在向量z重构x方面变得更好。
  • 在下一节中,我们将利用“公式 8.1.10”的结构来确定推理模型(编码器)和解码器的损失函数。

优化

“公式 8.1.10”的右侧具有有关 VAE 的loss函数的两个重要信息。 解码器项E[z~Q] [log P[θ](x | z)]表示生成器从推理模型的输出中提取z个样本,以重建输入。 使最大化是指我们将重构损失L_R降到最低。 如果假设图像(数据)分布为高斯分布,则可以使用 MSE。

如果每个像素(数据)都被认为是伯努利分布,那么损失函数就是二进制互熵。

第二项-D[KL](Q[φ](z | x) || P[θ](z))易于评估。 根据“公式 8.1.6”,Q[φ]是高斯分布。 通常,P[θ](z) = P(z) = N(0, 1)也是平均值为零且标准差等于 1.0 的高斯。 在“公式 8.1.11”中,我们看到 KL 项简化为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pjRfealv-1681704311663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_082.png)] (Equation 8.1.11)

其中Jz的维。 μ[j]σ[j]都是通过推理模型计算的x的函数。 要最大化:-D[KL]σ[j] -> 1μ[j] -> 9P(z) = N(0, 1)的选择源于各向同性单位高斯的性质,在具有适当函数的情况下,它可以变形为任意分布[6]。

根据“公式 8.1.11”,KL 损失L_KL简称为D[KL]

总之,在“公式 8.1.12”中将 VAE loss函数定义为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dQ91SFSk-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_094.png)] (Equation 8.1.12)

在给定编码器和解码器模型的情况下,在我们可以构建和训练 VAE(随机采样块,生成潜在属性)之前,还需要解决一个问题。 在下一节中,我们将讨论此问题以及如何使用重新参数化技巧解决它。

重新参数化技巧

“图 8.1.1”的左侧显示了 VAE 网络。 编码器获取输入x,并估计潜向量z的多元高斯分布的平均值μ和标准差σ。 解码器从潜向量z中提取样本,以将输入重构为x_tilde。 这似乎很简单,直到在反向传播期间发生梯度更新为止:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RHX1ocUI-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_01.png)]

图 8.1.1:带有和不带有重新参数化技巧的 VAE 网络

反向传播梯度将不会通过随机采样块。 尽管具有用于神经网络的随机输入是可以的,但梯度不可能穿过随机层。

解决此问题的方法是将采样处理作为输入,如“图 8.1.1”右侧所示。 然后,将样本计算为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h61Cc07u-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_101.png)] (Equation 8.1.13)

如果εσ以向量格式表示,则εσ是逐元素乘法。 使用“公式 8.1.13”,看起来好像采样直接来自潜在空间一样。 这项技术被称为重新参数化技巧

现在,在输入端发生采样时,可以使用熟悉的优化算法(例如 SGD,Adam 或 RMSProp)来训练 VAE 网络。

在讨论如何在tf.keras中实现 VAE 之前,让我们首先展示如何测试经过训练的解码器。

解码器测试

在训练了 VAE 网络之后,可以丢弃推理模型,包括加法和乘法运算符。 为了生成新的有意义的输出,请从用于生成ε的高斯分布中抽取样本。“图 8.1.2”向我们展示了解码器的测试设置:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-chhoz0du-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_02.png)]

图 8.1.2:解码器测试设置

通过重新参数化技巧解决了 VAE 上的最后一个问题,我们现在可以在tf.keras中实现和训练变分自编码器。

ALAS 与 Keras

VAE 的结构类似于典型的自编码器。 区别主要在于重新参数化技巧中的高斯随机变量的采样。“列表 8.1.1”显示了使用 MLP 实现的编码器,解码器和 VAE。

此代码也已添加到官方 Keras GitHub 存储库中

为便于显示潜在代码,将z的维设置为 2。编码器仅是两层 MLP,第二层生成均值和对数方差。 对数方差的使用是为了简化 KL 损失和重新参数化技巧的计算。 编码器的第三个输出是使用重新参数化技巧进行的z采样。 我们应该注意,在采样函数exp(0.5 log σ²) = sqrt(σ²) = σ中,因为σ > 0假定它是高斯分布的标准差。

解码器也是两层 MLP,它采用z的样本来近似输入。 编码器和解码器均使用大小为 512 的中间尺寸。

VAE 网络只是将编码器和解码器连接在一起。 loss函数是重建损失KL 损失的总和。 在默认的 Adam 优化器上,VAE 网络具有良好的效果。 VAE 网络中的参数总数为 807,700。

VAE MLP 的 Keras 代码具有预训练的权重。 要测试,我们需要运行:

python3 vae-mlp-mnist-8.1.1.py --weights=vae_mlp_mnist.tf

完整的代码可以在以下链接中找到

“列表 8.1.1”:vae-mlp-mnist-8.1.1.py

# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean + sqrt(var)*eps
def sampling(args):
    """Reparameterization trick by sampling 
        fr an isotropic unit Gaussian.
# Arguments:
        args (tensor): mean and log of variance of Q(z|X)
# Returns:
        z (tensor): sampled latent vector
    """
z_mean, z_log_var = args
    # K is the keras backend
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# network parameters
input_shape = (original_dim, )
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50
# VAE model = encoder + decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary 
# with the TensorFlow backend
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var])
# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(original_dim, activation='sigmoid')(x)
# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Load tf model trained weights"
    parser.add_argument("-w", "--weights", help=help_)
    help_ = "Use binary cross entropy instead of mse (default)"
    parser.add_argument("--bce", help=help_, action='store_true')
    args = parser.parse_args()
    models = (encoder, decoder)
    data = (x_test, y_test)
# VAE loss = mse_loss or xent_loss + kl_loss
    if args.bce:
        reconstruction_loss = binary_crossentropy(inputs,
                                                  outputs)
    else:
        reconstruction_loss = mse(inputs, outputs)
    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')

“图 8.1.3”显示了编码器模型,它是一个 MLP,具有两个输出,即潜向量的均值和方差。 lambda 函数实现了重新参数化技巧,将随机潜在代码的采样推送到 VAE 网络之外:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rDkvi3aw-1681704311664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_03.png)]

图 8.1.3:VAE MLP 的编码器模型

“图 8.1.4”显示了解码器模型。 2 维输入来自 lambda 函数。 输出是重构的 MNIST 数字:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MiTE41SX-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_04.png)]

图 8.1.4:VAE MLP 的解码器模型

“图 8.1.5”显示了完整的 VAE 模型。 通过将编码器和解码器模型结合在一起制成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M5hizOkT-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_05.png)]

图 8.1.5:使用 MLP 的 VAE 模型

“图 8.1.6”显示了使用plot_results()在 50 个周期后潜向量的连续空间。 为简单起见,此函数未在此处显示,但可以在vae-mlp-mnist-8.1.1.py的其余代码中找到。 该函数绘制两个图像,即测试数据集标签(“图 8.1.6”)和样本生成的数字(“图 8.1.7”),这两个图像都是z的函数。 这两个图都说明了潜在向量如何确定所生成数字的属性:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Qe9OO9Ve-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_06.png)]

图 8.1.6:MNIST 数字标签作为测试数据集(VAE MLP)的潜在向量平均值的函数。 原始图像可以在该书的 GitHub 存储库中找到。

浏览时,连续空格始终会产生与 MNIST 数字相似的输出。 例如,数字 9 的区域接近数字 7 的区域。从中心附近的 9 移动到左下角会将数字变形为 7。从中心向上移动会将生成的数字从 3 更改为 5,最后变为 0.数字的变形在“图 8.1.7”中更明显,这是解释“图 8.1.6”的另一种方式。

在“图 8.1.7”中,显示生成器输出。 显示了潜在空间中数字的分布。 可以观察到所有数字都被表示。 由于中心附近分布密集,因此变化在中间迅速,在平均值较高的区域则缓慢。 我们需要记住,“图 8.1.7”是“图 8.1.6”的反映。 例如,数字 0 在两个图的左上象限中,而数字 1 在右下象限中。

“图 8.1.7”中存在一些无法识别的数字,尤其是在右上象限中。 从“图 8.1.6”可以看出,该区域大部分是空的,并且远离中心:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ykRPrEwz-1681704311665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_07.png)]

图 8.1.7:根据潜在向量平均值(VAE MLP)生成的数字。 为了便于解释,均值的范围类似于图 8.1.6

在本节中,我们演示了如何在 MLP 中实现 VAE。 我们还解释了导航潜在空间的结果。 在的下一部分中,我们将使用 CNN 实现相同的 VAE。

带有 CNN 的 AE

在原始论文《自编码变分贝叶斯》[1]中,使用 MLP 来实现 VAE 网络,这与我们在上一节中介绍的类似。 在本节中,我们将证明使用 CNN 将显着提高所产生数字的质量,并将参数数量显着减少至 134,165。

“列表 8.1.3”显示了编码器,解码器和 VAE 网络。 该代码也被添加到了官方的 Keras GitHub 存储库中

为简洁起见,不再显示与 MLP VAE 类似的某些代码行。 编码器由两层 CNN 和两层 MLP 组成,以生成潜在代码。 编码器的输出结构与上一节中看到的 MLP 实现类似。 解码器由一层Dense和三层转置的 CNN 组成。

VAE CNN 的 Keras 代码具有预训练的权重。 要测试,我们需要运行:

python3 vae-cnn-mnist-8.1.2.py --weights=vae_cnn_mnist.tf

“列表 8.1.3”:vae-cnn-mnist-8.1.2.py

使用 CNN 层的tf.keras中的 VAE:

# network parameters
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
filters = 16
latent_dim = 2
epochs = 30
# VAE model = encoder + decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for i in range(2):
    filters *= 2
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x)
# shape info needed to build decoder model
shape = K.int_shape(x)
# generate latent vector Q(z|X)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary 
# with the TensorFlow backend
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var])
# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3],
          activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)
for i in range(2):
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    filters //= 2
outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='decoder_output')(x)
# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')

“图 8.1.8”显示了 CNN 编码器模型的两个输出,即潜向量的均值和方差。 lambda 函数实现了重新参数化技巧,将随机潜码的采样推送到 VAE 网络之外:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rTFwOwuA-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_08.png)]

图 8.1.8:VAE CNN 的编码器

“图 8.1.9”显示了 CNN 解码器模型。 2 维输入来自 lambda 函数。 输出是重构的 MNIST 数字:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Wd7Nf6eD-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_09.png)]

图 8.1.9:VAE CNN 的解码器

“图 8.1.10”显示完整的 CNN VAE 模型。 通过将编码器和解码器模型结合在一起制成:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WMfD8AAr-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_10.png)]

图 8.1.10:使用 CNN 的 VAE 模型

对 VAE 进行了 30 个周期的训练。“图 8.1.11”显示了在导航 VAE 的连续潜在空间时数字的分布。 例如,从中间到右边从 2 变为 0:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BIkuhCUR-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_11.png)]

图 8.1.11:MNIST 数字标签作为测试数据集(VAE CNN)的潜在向量平均值的函数。 原始图像可以在该书的 GitHub 存储库中找到。

“图 8.1.12”向我们展示了生成模型的输出。 从质量上讲,与“图 8.1.7”(具有 MLP 实现)相比,模棱两可的位数更少:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5l35MO5C-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_12.png)]

图 8.1.12:根据潜在向量平均值(VAE CNN)生成的数字。 为了便于解释,均值的范围类似于图 8.1.11

前的两节讨论了使用 MLP 或 CNN 的 VAE 的实现。 我们分析了两种实现方式的结果,结果表明 CNN 可以减少参数数量并提高感知质量。 在下一节中,我们将演示如何在 VAE 中实现条件,以便我们可以控制要生成的数字。

2. 条件 VAE(CVAE)

有条件的 VAE [2]与 CGAN 相似。 在 MNIST 数据集的上下文中,如果随机采样潜在空间,则 VAE 无法控制将生成哪个数字。 CVAE 可以通过包含要产生的数字的条件(单标签)来解决此问题。 该条件同时施加在编码器和解码器输入上。

正式地,将“公式 8.1.10”中 VAE 的核心公式修改为包括条件c

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KQcqohNE-1681704311666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_113.png)] (Equation 8.2.1)

与 VAE 相似,“公式 8.2.1”表示如果要最大化输出条件cP[θ](x | c),则必须最小化两个损失项:

  • 给定潜在向量和条件,解码器的重建损失。
  • 给定潜在向量和条件的编码器之间的 KL 损失以及给定条件的先验分布。 与 VAE 相似,我们通常选择P[θ](x | c) = P(x | c) = N(0, 1)

实现 CVAE 需要对 VAE 的代码进行一些修改。 对于 CVAE,使用 VAE CNN 实现是因为它可以形成一个较小的网络,并产生感知上更好的数字。

“列表 8.2.1”突出显示了针对 MNIST 数字的 VAE 原始代码所做的更改。 编码器输入现在是原始输入图像及其单标签的连接。 解码器输入现在是潜在空间采样与其应生成的图像的一键热标签的组合。 参数总数为 174,437。 与 β-VAE 相关的代码将在本章下一节中讨论。

损失函数没有改变。 但是,在训练,测试和结果绘制过程中会提供单热标签。

“列表 8.2.1”:cvae-cnn-mnist-8.2.1.py

tf.keras中使用 CNN 层的 CVAE。 重点介绍了为支持 CVAE 而进行的更改:

# compute the number of labels
num_labels = len(np.unique(y_train))
# network parameters
input_shape = (image_size, image_size, 1)
label_shape = (num_labels, )
batch_size = 128
kernel_size = 3
filters = 16
latent_dim = 2
epochs = 30
# VAE model = encoder + decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
y_labels = Input(shape=label_shape, name='class_labels')
x = Dense(image_size * image_size)(y_labels)
x = Reshape((image_size, image_size, 1))(x)
x = keras.layers.concatenate([inputs, x])
for i in range(2):
    filters *= 2
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x)
# shape info needed to build decoder model
shape = K.int_shape(x)
# generate latent vector Q(z|X)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary 
# with the TensorFlow backend
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var])
# instantiate encoder model
encoder = Model([inputs, y_labels],
                [z_mean, z_log_var, z],
                name='encoder')
# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = concatenate([latent_inputs, y_labels])
x = Dense(shape[1]*shape[2]*shape[3], activation='relu')(x)
x = Reshape((shape[1], shape[2], shape[3]))(x)
for i in range(2):
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    filters //= 2
outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='decoder_output')(x)
# instantiate decoder model
decoder = Model([latent_inputs, y_labels],
                outputs,
                name='decoder')
# instantiate vae model
outputs = decoder([encoder([inputs, y_labels])[2], y_labels])
cvae = Model([inputs, y_labels], outputs, name='cvae')

“图 8.2.1”显示了 CVAE 模型的编码器。 附加输入,即单热向量class_labels形式的条件标签表示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bLvR6Phc-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_13.png)]

图 8.2.1:CVAE CNN 中的编码器。 输入现在包括 VAE 输入和条件标签的连接

“图 8.2.2”显示了 CVAE 模型的解码器。 附加输入,即单热向量class_labels形式的条件标签表示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1AJ1jTMJ-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_14.png)]

图 8.2.2:CVAE CNN 中的解码器。 输入现在包括 z 采样和条件标签的连接

“图 8.2.3”显示了完整的 CVAE 模型,该模型是编码器和解码器结合在一起的。 附加输入,即单热向量class_labels形式的条件标签:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yFcG8oOh-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_15.png)]

图 8.2.3:使用 CNN 的 CVAE 模型。输入现在包含一个 VAE 输入和一个条件标签

在“图 8.2.4”中,每个标记的平均值分布在 30 个周期后显示。 与前面章节中的“图 8.1.6”和“图 8.1.11”不同,每个标签不是集中在一个区域上,而是分布在整个图上。 这是预期的,因为潜在空间中的每个采样都应生成一个特定的数字。 浏览潜在空间会更改该特定数字的属性。 例如,如果指定的数字为 0,则在潜伏空间中导航仍将产生 0,但是诸如倾斜角度,厚度和其他书写样式方面的属性将有所不同。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tB1jiTpc-1681704311667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_16.png)]

图 8.2.4:作为测试数据集(CVAE CNN)的潜在向量平均值的函数的 MNIST 数字标签。 原始图像可以在该书的 GitHub 存储库中找到。

“图 8.2.4”在“图 8.2.5”中更清楚地显示,数字 0 到 5。每个帧都有相同的数字,并且属性在我们浏览时顺畅地变化。 潜在代码:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lKyT6wAx-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_17.png)]

图 8.2.5:根据潜在向量平均值和单热点标签(CVAE CNN)生成的数字 0 至 5。 为了便于解释,均值的范围类似于图 8.2.4。

“图 8.2.6”显示“图 8.2.4”,用于数字 6 至 9:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fr1yGOiU-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_18.png)]

图 8.2.6:根据潜在向量平均值和单热点标签(CVAE CNN)生成的数字 6 至 9。 为了便于解释,均值的范围类似于图 8.2.4。

为了便于比较,潜向量的值范围与“图 8.2.4”中的相同。 使用预训练的权重,可以通过执行以下命令来生成数字(例如 0):

python3 cvae-cnn-mnist-8.2.1.py –bce --weights=cvae_cnn_mnist.tf --digit=0

在“图 8.2.5”和“图 8.2.6”中,可以注意到,每个数字的宽度和圆度(如果适用)随z[0]的变化而变化。 从左到右追踪。 同时,当z[1]从上到下导航时,每个数字的倾斜角度和圆度(如果适用)也会发生变化。 随着我们离开分布中心,数字的图像开始退化。 这是可以预期的,因为潜在空间是一个圆形。

属性中其他明显的变化可能是数字特定的。 例如,数字 1 的水平笔划(手臂)在左上象限中可见。 数字 7 的水平笔划(纵横线)只能在右象限中看到。

在下一节中,我们将发现 CVAE 实际上只是另一种称为 β-VAE 的 VAE 的特例。

3. β-VAE – 具有纠缠的潜在表示形式的 VAE

在“第 6 章”,“非纠缠表示 GAN”中,讨论了潜码非纠缠表示的概念和重要性。 我们可以回想起,一个纠缠的表示是单个潜伏单元对单个生成因子的变化敏感,而相对于其他因子的变化相对不变[3]。 更改潜在代码会导致生成的输出的一个属性发生更改,而其余属性保持不变。

在同一章中,InfoGAN [4]向我们展示了对于 MNIST 数据集,可以控制生成哪个数字以及书写样式的倾斜度和粗细。 观察上一节中的结果,可以注意到,VAE 在本质上使潜向量维解开了一定程度。 例如,查看“图 8.2.6”中的数字 8,从上到下导航z[1]会减小宽度和圆度,同时顺时针旋转数字。 从左至右增加z[0]也会在逆时针旋转数字时减小宽度和圆度。 换句话说,z[1]控制顺时针旋转,而z[0]影响逆时针旋转,并且两者都改变宽度和圆度。

在本节中,我们将演示对 VAE 损失函数的简单修改会迫使潜在代码进一步解开纠缠。 修改为正恒重β > 1,用作 KL 损失的调节器:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZFv9spjc-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_121.png)] (Equation 8.3.1)

VAE 的这种变化称为 β-VAE [5]。 β的隐含效果是更严格的标准差。 换句话说,β强制后验分布中的潜码Q[φ](z | x)独立。

实现 β-VAE 很简单。 例如,对于上一个示例中的 CVAE,所需的修改是kl_loss中的额外beta因子:

kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5 * beta

CVAE 是 β-VAE 的特例,其中β = 1。 其他一切都一样。 但是,确定的值需要一些反复试验。 为了潜在的代码独立性,在重构误差和正则化之间必须有一个仔细的平衡。 解缠最大在β = 9附近。 当中β = 9的值时,β-VAE 仅被迫学习一个解纠缠的表示,而忽略另一个潜在维度。

“图 8.3.1”和“图 8.3.2”显示 β-VAE 的潜向量平均值,其中β = 9β = 10

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lrcWwni7-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_19.png)]

图 8.3.1:MNIST 数字标签与测试数据集的潜在向量平均值的函数(β-VAE,β = 9)。 原始图像可以在该书的 GitHub 存储库中找到。

β = 9时,与 CVAE 相比,分布具有较小的标准差。 在β = 10的情况下,仅学习了潜在代码。 分布实际上缩小为一个维度,编码器和解码器忽略了第一潜码z[0]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M0IexeCU-1681704311668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_20.png)]

图 8.3.2:MNIST 数字标签与测试数据集的潜向量平均值的函数(β-VAE 和β = 10

原始图像可以在该书的 GitHub 存储库中找到

这些观察结果反映在“图 8.3.3”中。 具有β = 9的 β-VAE 具有两个实际上独立的潜在代码。 z[0]确定书写样式的倾斜度,而z[1]指定数字的宽度和圆度(如果适用)。 对于中β = 10的 β-VAE,z[0]被静音。 z[0]的增加不会显着改变数字。z[1]确定书写样式的倾斜角度和宽度:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-esbGFOdP-1681704311669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_08_21.png)]

图 8.3.3:根据潜在向量平均值和单热点标签(β-VAE,β = 1, 9, 10)生成的数字 0 至 3。 为了便于解释,均值的范围类似于图 8.3.1。

β-VAE 的tf.keras代码具有预训练的权重。 要使用β = 9生成数字 0 来测试 β-VAE,我们需要运行以下命令:

python3 cvae-cnn-mnist-8.2.1.py --beta=9 --bce --weights=beta-cvae_cnn_mnist.tf --digit=0

总而言之,我们已经证明与 GAN 相比,在 β-VAE 上更容易实现解缠表示学习。 我们所需要做的就是调整单个超参数。

4. 总结

在本章中,我们介绍了 VAE 的原理。 正如我们从 VAE 原理中学到的那样,从两次尝试从潜在空间创建合成输出的角度来看,它们都与 GAN 相似。 但是,可以注意到,与 GAN 相比,VAE 网络更简单,更容易训练。 越来越清楚的是 CVAE 和 β-VAE 在概念上分别类似于条件 GAN 和解缠表示 GAN。

VAE 具有消除潜在向量纠缠的内在机制。 因此,构建 β-VAE 很简单。 但是,我们应该注意,可解释和解开的代码对于构建智能体很重要。

在下一章中,我们将专注于强化学习。 在没有任何先验数据的情况下,智能体通过与周围的世界进行交互来学习。 我们将讨论如何为智能体的正确行为提供奖励,并为错误的行为提供惩罚。

5. 参考

  1. Diederik P. Kingma and Max Welling. Auto-encoding Variational Bayes. arXiv preprint arXiv:1312.6114, 2013 (https://arxiv.org/pdf/1312.6114.pdf).
  2. Kihyuk Sohn, Honglak Lee, and Xinchen Yan. Learning Structured Output Representation Using Deep Conditional Generative Models. Advances in Neural Information Processing Systems, 2015 (http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models.pdf).
  3. Yoshua Bengio, Aaron Courville, and Pascal Vincent. Representation Learning.
  4. A Review and New Perspectives. IEEE transactions on Pattern Analysis and Machine Intelligence 35.8, 2013: 1798-1828 (https://arxiv.org/pdf/1206.5538.pdf).
  5. Xi Chen et al.: Infogan: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets. Advances in Neural Information Processing Systems, 2016 (http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf).
  6. I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and A. Lerchner. -VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. ICLR, 2017 (https://openreview.net/pdf?id=Sy2fzU9gl).
  7. Carl Doersch. Tutorial on variational autoencoders. arXiv preprint arXiv:1606.05908, 2016 (https://arxiv.org/pdf/1606.05908.pdf).

九、深度强化学习

强化学习RL)是智能体程序用于决策的框架。 智能体不一定是软件实体,例如您在视频游戏中可能看到的那样。 相反,它可以体现在诸如机器人或自动驾驶汽车之类的硬件中。 内在的智能体可能是充分理解和利用 RL 的最佳方法,因为物理实体与现实世界进行交互并接收响应。

该智能体位于环境中。 环境具有状态,可以部分或完全观察到。 该智能体具有一组操作,可用于与环境交互。 动作的结果将环境转换为新状态。 执行动作后,会收到相应的标量奖励

智能体的目标是通过学习策略来最大化累积的未来奖励,该策略将决定在特定状态下应采取的行动。

RL 与人类心理学有很强的相似性。 人类通过体验世界来学习。 错误的行为会导致某种形式的惩罚,将来应避免使用,而正确的行为应得到奖励并应予以鼓励。 这种与人类心理学的强相似之处使许多研究人员相信 RL 可以将引向真正的人工智能AI)。

RL 已经存在了几十年。 但是,除了简单的世界模型之外,RL 还在努力扩展规模。 这是,其中深度学习DL)开始发挥作用。 它解决了这个可扩展性问题,从而开启了深度强化学习DRL)的时代。 在本章中,我们的重点是 DRL。 DRL 中值得注意的例子之一是 DeepMind 在智能体上的工作,这些智能体能够在不同的视频游戏上超越最佳的人类表现。

在本章中,我们将讨论 RL 和 DRL。

总之,本章的目的是介绍:

  • RL 的原理
  • RL 技术,Q 学习
  • 高级主题,包括深度 Q 网络DQN)和双重 Q 学习DDQN
  • 关于如何使用tf.keras在 Python 和 DRL 上实现 RL 的说明

让我们从 RL 的基本原理开始。

1. 强化学习原理(RL)

“图 9.1.1”显示了用于描述 RL 的感知动作学习循环。 环境是苏打水可以坐在地板上。 智能体是一个移动机器人,其目标是拾取苏打水。 它观察周围的环境,并通过车载摄像头跟踪汽水罐的位置。 观察结果以一种状态的形式进行了汇总,机器人将使用该状态来决定要采取的动作。 所采取的动作可能与低级控制有关,例如每个车轮的旋转角度/速度,手臂的每个关节的旋转角度/速度以及抓手是打开还是关闭。

可替代地,动作可以是高级控制动作,诸如向前/向后移动机器人,以特定角度转向以及抓取/释放。 将夹持器从汽水中移开的任何动作都会得到负回报。 缩小抓取器位置和苏打之间的缝隙的任何动作都会获得积极的回报。 当机械臂成功捡起汽水罐时,它会收到丰厚的回报。 RL 的目标是学习最佳策略,该策略可帮助机器人决定在给定状态下采取哪种行动以最大化累积的折扣奖励:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bgGcFrjL-1681704311669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_09_01.png)]

图 9.1.1:RL 中的感知-动作-学习循环

形式上,RL 问题可以描述为 Markov 决策过程MDP)。

为简单起见,我们将假定为确定性环境,在该环境中,给定状态下的某个动作将始终导致已知的下一个状态和奖励。 在本章的后面部分,我们将研究如何考虑随机性。 在时间步t时:

  • 环境处于状态空间S的状态下,状态s[0],该状态可以是离散的也可以是连续的。 起始状态为s[0],而终止状态为s[T]
  • 智能体通过遵循策略π(a[t] | s[t])从操作空间A采取操作,即s[a]A可以是离散的或连续的。
  • 环境使用状态转换动态T(s[t + 1] | s[t], a[t])转换为新状态,s[t + 1]。 下一个状态仅取决于当前状态和操作。 智能体不知道T
  • 智能体使用奖励函数接收标量奖励,r[t + 1] = R(s[t], a[t]),以及r: A x S -> R。 奖励仅取决于当前状态和操作。 智能体不知道R
  • 将来的奖励折扣为γ^k,其中γ ∈ [0, 1]k是未来的时间步长。
  • 地平线H是完成从s[0]s[T]的一集所需的时间步长T

该环境可以是完全或部分可观察的。 后者也称为部分可观察的 MDPPOMDP。 在大多数情况下,完全观察环境是不现实的。 为了提高的可观察性,当前的观测值也考虑了过去的观测值。 状态包括对环境的足够观察,以使策略决定采取哪种措施。 回忆“图 9.1.1”,这可能是汽水罐相对于机器人抓手的三维位置,如机器人摄像头所估计的那样。

每当环境转换到新状态时,智能体都会收到标量奖励r[t + 1]。 在“图 9.1.1”中,每当机器人靠近汽水罐时,奖励可能为 +1;当机器人离汽水罐更远时,奖励为 -1;当机器人关闭夹具并成功捡起苏打时,奖励为 +100。 能够。 智能体的目标是学习一种最佳策略π*,该策略可使所有状态的收益最大化:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-f85hflJU-1681704311669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_012.png)] (Equation 9.1.1)

回报定义为折扣累积奖励R[t] = Σ γ^t r[t+k], k = 0, ..., T。 从“公式 9.1.1”可以看出,与通常的γ^k < 1.0相比,与立即获得的奖励相比,未来的奖励权重较低。 在极端情况下,当γ = 0时,仅立即获得奖励很重要。 当γ = 1时,将来的奖励与立即奖励的权重相同。

遵循任意策略π,可以将回报解释为对给定状态值的度量:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-byZLyWAj-1681704311676)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_019.png)] (Equation 9.1.2)

换句话说,RL 问题是智能体的目标,是学习使所有状态s最大化的最优策略V^π

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CTBNheSJ-1681704311676)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_021.png)] (Equation 9.1.3)

最优策略的值函数就是V*。 在“图 9.1.1”中,最佳策略是生成最短动作序列的一种,该动作序列使机器人越来越靠近苏打罐,直到被取走为止。 状态越接近目标状态,其值越高。 可以将导致目标(或最终状态)的事件序列建模为策略的轨迹部署

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-klKvBKtv-1681704311676)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_023.png)] (Equation 9.1.4)

如果 MDP 是偶发的,则当智能体到达终端状态s[T]时,状态将重置为s[0]。 如果T是有限的,则我们的水平范围是有限的。 否则,视野是无限的。 在“图 9.1.1”中,如果 MDP 是情景剧集,则在收集苏打罐后,机器人可能会寻找另一个苏打罐来拾取,并且 RL 问题重发。

因此,RL 的主要目标是找到一种使每个状态的值最大化的策略。 在下一部分中,我们将介绍可用于最大化值函数的策略学习算法。

2. Q 值

如果 RL 问题是找到π*,则智能体如何通过与环境交互来学习?“公式 9.1.3”并未明确指出尝试进行的操作以及计算收益的后续状态。 在 RL 中,使用 Q 值更容易学习π*

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tAIkVX5O-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_026.png)] (Equation 9.2.1)

哪里:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kltRYhYh-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_027.png)] (Equation 9.2.2)

换句话说,不是找到使所有状态的值最大化的策略,而是“公式 9.2.1”寻找使所有状态的质量(Q)值最大化的操作。 在找到 Q 值函数之后,分别由“公式 9.2.2”和“公式 9.1.3”确定V*,因此确定了π*

如果对于每个动作,都可以观察到奖励和下一状态,则可以制定以下迭代或反复试验算法来学习 Q 值:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eH928w5D-1681704311677)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/14853_09_030.png)] (Equation 9.2.3)

为了简化符号,s'a'分别是下一个状态和动作。 “公式 9.2.3”被称为贝尔曼方程,它是 Q 学习算法的核心。 Q 学习尝试根据当前状态和作用来近似返回值或值的一阶展开(“公式 9.1.2”)。 从对环境动态的零知识中,智能体尝试执行操作a,观察以奖励r和下一个状态s'的形式发生的情况。 max[a'] Q(s', a')选择下一个逻辑动作,该动作将为下一个状态提供最大 Q 值。 有了“公式 9.2.3”中的所有项,该当前状态-动作对的 Q 值就会更新。 迭代地执行更新将最终使智能体能够学习 Q 值函数。

Q 学习是一种脱离策略 RL 算法。 它学习了如何通过不直接从策略中抽取经验来改进策略。 换句话说,Q 值的获取与智能体所使用的基础策略无关。 当 Q 值函数收敛时,才使用“公式 9.2.1”确定最佳策略。

在为提供有关如何使用 Q 学习的示例之前,请注意,智能体必须在不断利用其到目前为止所学知识的同时不断探索其环境。 这是 RL 中的问题之一-在探索开发之间找到适当的平衡。 通常,在学习开始时,动作是随机的(探索)。 随着学习的进行,智能体会利用 Q 值(利用)。 例如,一开始,90% 的动作是随机的,而 10% 的动作则来自 Q 值函数。 在每个剧集的结尾,这逐渐减少。 最终,该动作是 10% 随机的,并且是 Q 值函数的 90%。

在下一节中,我们将给出有关在简单的确定性环境中如何使用 Q 学习的具体示例。

TensorFlow 2 和 Keras 高级深度学习:6~10(4)https://developer.aliyun.com/article/1426953

相关文章
|
6天前
|
机器学习/深度学习 数据可视化 测试技术
深度学习:Keras使用神经网络进行简单文本分类分析新闻组数据
深度学习:Keras使用神经网络进行简单文本分类分析新闻组数据
19 0
|
6天前
|
机器学习/深度学习 运维 监控
TensorFlow分布式训练:加速深度学习模型训练
【4月更文挑战第17天】TensorFlow分布式训练加速深度学习模型训练,通过数据并行和模型并行利用多机器资源,减少训练时间。优化策略包括配置计算资源、优化数据划分和减少通信开销。实际应用需关注调试监控、系统稳定性和容错性,以应对分布式训练挑战。
|
6天前
|
机器学习/深度学习 API TensorFlow
TensorFlow的高级API:tf.keras深度解析
【4月更文挑战第17天】本文深入解析了TensorFlow的高级API `tf.keras`,包括顺序模型和函数式API的模型构建,以及模型编译、训练、评估和预测的步骤。`tf.keras`结合了Keras的易用性和TensorFlow的性能,支持回调函数、模型保存与加载等高级特性,助力提升深度学习开发效率。
|
6天前
|
机器学习/深度学习 API 算法框架/工具
R语言深度学习:用keras神经网络回归模型预测时间序列数据
R语言深度学习:用keras神经网络回归模型预测时间序列数据
16 0
|
6天前
|
机器学习/深度学习 数据采集 TensorFlow
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
23 0
|
6天前
|
机器学习/深度学习 自然语言处理 算法框架/工具
用于NLP的Python:使用Keras进行深度学习文本生成
用于NLP的Python:使用Keras进行深度学习文本生成
17 2
|
8天前
|
机器学习/深度学习 人工智能 算法框架/工具
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)(4)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)
22 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)(4)
|
8天前
|
机器学习/深度学习 算法框架/工具 TensorFlow
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(4)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
40 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(4)
|
机器学习/深度学习 算法 算法框架/工具
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(3)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
9 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(3)
|
8天前
|
机器学习/深度学习 算法框架/工具 自然语言处理
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(1)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
28 0

热门文章

最新文章