TensorFlow 2 和 Keras 高级深度学习:1~5(4)

简介: TensorFlow 2 和 Keras 高级深度学习:1~5(4)

TensorFlow 2 和 Keras 高级深度学习:1~5(3)https://developer.aliyun.com/article/1426946

训练过程中将两个小批数据提供给判别器:

  1. x,来自采样数据的实数据(换言之,x ~ p_data),标签为 1.0
  2. x' = g(z),来自生成器的带有标签 0.0 的伪造数据

为了使的损失函数最小,将通过反向传播通过正确识别真实数据D(x)和合成数据1 - D(g(z))来更新判别器参数θ^(D)。 正确识别真实数据等同于D(x) -> 1.0,而正确分类伪造数据则与D(g(z)) -> 0.01 - D(g(z)) -> 1.0相同。 在此等式中,z是生成器用来合成新信号的任意编码或噪声向量。 两者都有助于最小化损失函数。

为了训练生成器,GAN 将判别器和生成器损失的总和视为零和博弈。 生成器损失函数只是判别器损失函数的负数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hwZxxYdy-1681704179662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_012.png)] (Equation 4.1.2)

然后可以将其更恰当地重写为值函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Jkbl4vtg-1681704179662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_013.png)] (Equation 4.1.3)

从生成器的角度来看,应将“公式 4.1.3”最小化。 从判别器的角度来看,值函数应最大化。 因此,生成器训练准则可以写成极大极小问题:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WwT5w5qV-1681704179662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_014.png)] (Equation 4.1.4)

有时,我们会假装合成数据是带有标签 1.0 的真实数据,以此来欺骗判别器。 通过最大化θ^(D),优化器将梯度更新发送到判别器参数,以将该合成数据视为真实数据。 同时,通过将θ^(G)的相关性减至最小,优化器将在上训练生成器的参数,从而欺骗识别器。 但是,实际上,判别器对将合成数据分类为伪造的预测很有信心,并且不会更新 GAN 参数。 此外,梯度更新很小,并且在传播到生成器层时已大大减小。 结果,生成器无法收敛。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-l92TD8vX-1681704179662)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_04.png)]

图 4.1.4:训练生成器就像使用二进制交叉熵损失函数训练网络一样。 来自生成器的虚假数据显示为真实数据

解决方案是按以下形式重新构造生成器的损失函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-L8euS4xk-1681704179663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_017.png)] (Equation 4.1.5)

损失函数只是通过训练生成器,最大程度地提高了判别器认为合成数据是真实数据的机会。 新公式不再是零和,而是纯粹由启发式驱动的。“图 4.1.4”显示了训练过程中的生成器。 在此图中,仅在训练整个对抗网络时才更新生成器参数。 这是因为梯度从判别器向下传递到生成器。 但是,实际上,判别器权重仅在对抗训练期间临时冻结。

在深度学习中,可以使用合适的神经网络架构来实现生成器和判别器。 如果数据或信号是图像,则生成器和判别器网络都将使用 CNN。 对于诸如音频之类的一维序列,两个网络通常都是循环的(RNN,LSTM 或 GRU)。

在本节中,我们了解到 GAN 的原理很简单。 我们还了解了如何通过熟悉的网络层实现 GAN。 GAN 与其他网络的区别在于众所周知,它们很难训练。 只需稍作更改,就可以使网络变得不稳定。 在以下部分中,我们将研究使用深度 CNN 的 GAN 早期成功实现之一。 它称为 DCGAN [3]。

2. 在 Keras 中实现 DCGAN

“图 4.2.1”显示 DCGAN,其中用于生成伪造的 MNIST 图像:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hUVTZtOt-1681704179663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_05.png)]

图 4.2.1:DCGAN 模型

DCGAN 实现以下设计原则:

  • 使用stride > 1和卷积代替MaxPooling2DUpSampling2D。 通过stride > 1,CNN 可以学习如何调整特征映射的大小。
  • 避免使用Dense层。 在所有层中使用 CNN。 Dense层仅用作生成器的第一层以接受z向量。 调整Dense层的输出大小,并成为后续 CNN 层的输入。
  • 使用批量归一化BN),通过将每一层的输入归一化以使均值和单位方差为零,来稳定学习。 生成器输出层和判别器输入层中没有 BN。 在此处要介绍的实现示例中,没有在标识符中使用批量归一化。
  • 整流线性单元ReLU)在生成器的所有层中均使用,但在输出层中则使用tanh激活。 在此处要介绍的实现示例中,在生成器的输出中使用sigmoid代替tanh,因为通常会导致对 MNIST 数字进行更稳定的训练。
  • 在判别器的所有层中使用 Leaky ReLU。 与 ReLU 不同,Leaky ReLU 不会在输入小于零时将所有输出清零,而是生成一个等于alpha x input的小梯度。 在以下示例中,alpha = 0.2

生成器学习从 100 维输入向量([-1.0,1.0]范围内具有均匀分布的 100 维随机噪声)生成伪图像。 判别器将真实图像与伪图像分类,但是在训练对抗网络时无意中指导生成器如何生成真实图像。 在我们的 DCGAN 实现中使用的核大小为 5。这是为了允许它增加卷积的接收场大小和表达能力。

生成器接受由 -1.0 到 1.0 范围内的均匀分布生成的 100 维z向量。 生成器的第一层是7 x 7 x 128 = 6,272单元的密集层。 基于输出图像的预期最终尺寸(28 x 28 x 1,28 是 7 的倍数)和第一个Conv2DTranspose的过滤器数量(等于 128)来计算单元数量。

我们可以将转置的 CNN(Conv2DTranspose)想象成 CNN 的逆过程。 在一个简单的示例中,如果 CNN 将图像转换为特征映射,则转置的 CNN 将生成给定特征映射的图像。 因此,转置的 CNN 在上一章的解码器中和本章的生成器中使用。

在对strides = 2进行两个Conv2DTranspose之后,特征映射的大小将为28 x 28 x n_filter。 每个Conv2DTranspose之前都有批量规范化和 ReLU。 最后一层具有 Sigmoid 激活,可生成28 x 28 x 1假 MNIST 图像。 将每个像素标准化为与[0, 255]灰度级相对应的[0.0, 1.0]。 下面的“列表 4.2.1”显示了tf.keras中生成器网络的实现。 定义了一个函数来生成生成器模型。 由于整个代码的长度,我们将列表限制为正在讨论的特定行。

完整的代码可在 GitHub 上获得

“列表 4.2.1”:dcgan-mnist-4.2.1.py

def build_generator(inputs, image_size):
    """Build a Generator Model
Stack of BN-ReLU-Conv2DTranpose to generate fake images
    Output activation is sigmoid instead of tanh in [1].
    Sigmoid converges easily.
Arguments:
        inputs (Layer): Input layer of the generator 
            the z-vector)
        image_size (tensor): Target size of one side
            (assuming square image)
Returns:
        generator (Model): Generator Model
    """
image_resize = image_size // 4
    # network parameters 
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]
x = Dense(image_resize * image_resize * layer_filters[0])(inputs)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x)
for filters in layer_filters:
        # first two convolution layers use strides = 2
        # the last two use strides = 1
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='same')(x)
x = Activation('sigmoid')(x)
    generator = Model(inputs, x, name='generator')
    return generator

判别器与相似,是许多基于 CNN 的分类器。 输入是28 x 28 x 1MNIST 图像,分类为真实(1.0)或伪(0.0)。 有四个 CNN 层。 除了最后的卷积,每个Conv2D都使用strides = 2将特征映射下采样两个。 然后每个Conv2D之前都有一个泄漏的 ReLU 层。 最终的过滤器大小为 256,而初始的过滤器大小为 32,并使每个卷积层加倍。 最终的过滤器大小 128 也适用。 但是,我们会发现生成的图像在 256 的情况下看起来更好。最终输出层被展平,并且在通过 Sigmoid 激活层缩放后,单个单元Dense层在 0.0 到 1.0 之间生成预测。 输出被建模为伯努利分布。 因此,使用了二进制交叉熵损失函数。

建立生成器和判别器模型后,通过将生成器和判别器网络连接起来,建立对抗模型。 鉴别网络和对抗网络都使用 RMSprop 优化器。 判别器的学习率是2e-4,而对抗网络的学习率是1e-4。 判别器的 RMSprop 衰减率为6e-8,对抗网络的 RMSprop 衰减率为3e-8

将对手的学习率设置为判别器的一半将使训练更加稳定。 您会从“图 4.1.3”和“图 4.1.4”中回忆起,GAN 训练包含两个部分:判别器训练和生成器训练,这是冻结判别器权重的对抗训练。

“列表 4.2.2”显示了tf.keras中判别器的实现。 定义一个函数来建立鉴别模型。

“列表 4.2.2”:dcgan-mnist-4.2.1.py

def build_discriminator(inputs):
    """Build a Discriminator Model
Stack of LeakyReLU-Conv2D to discriminate real from fake.
    The network does not converge with BN so it is not used here
    unlike in [1] or original paper.
Arguments:
        inputs (Layer): Input layer of the discriminator (the image)
Returns:
        discriminator (Model): Discriminator Model
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
x = inputs
    for filters in layer_filters:
        # first 3 convolution layers use strides = 2
        # last one uses strides = 1
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same')(x)
x = Flatten()(x)
    x = Dense(1)(x)
    x = Activation('sigmoid')(x)
    discriminator = Model(inputs, x, name='discriminator')
    return discriminator

在“列表 4.2.3”中,我们将说明如何构建 GAN 模型。 首先,建立鉴别模型,然后实例化生成器模型。 对抗性模型只是生成器和判别器组合在一起。 在许多 GAN 中,批大小为 64 似乎是最常见的。 网络参数显示在“列表 4.2.3”中。

“列表 4.2.3”:dcgan-mnist-4.2.1.py

建立 DCGAN 模型并调用训练例程的函数:

def build_and_train_models():
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()
# reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
model_name = "dcgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
# build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = build_discriminator(inputs)
    # [1] or original paper uses Adam, 
    # but discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()
# build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = build_generator(inputs, image_size)
    generator.summary()
# build adversarial model
    optimizer = RMSprop(lr=lr * 0.5, decay=decay * 0.5)
    # freeze the weights of discriminator during adversarial training
    discriminator.trainable = False
    # adversarial = generator + discriminator
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()
# train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    train(models, x_train, params)

从“列表 4.2.1”和“列表 4.2.2”中可以看出,DCGAN 模型很简单。 使它们难以构建的原因是,网络中的较小更改设计很容易破坏训练收敛。 例如,如果在判别器中使用批量归一化,或者如果生成器中的strides = 2传输到后面的 CNN 层,则 DCGAN 将无法收敛。

“列表 4.2.4”显示了专用于训练判别器和对抗网络的函数。 由于自定义训练,将不使用常规的fit()函数。 取而代之的是,调用train_on_batch()对给定的数据批量运行单个梯度更新。 然后通过对抗网络训练生成器。 训练首先从数据集中随机选择一批真实图像。 这被标记为实数(1.0)。 然后,生成器将生成一批伪图像。 这被标记为假(0.0)。 这两个批量是连接在一起的,用于训练判别器。

完成此操作后,生成器将生成一批新的伪图像,并将其标记为真实(1.0)。 这批将用于训练对抗网络。 交替训练这两个网络约 40,000 步。 定期将基于特定噪声向量生成的 MNIST 数字保存在文件系统中。 在最后的训练步骤中,网络已收敛。 生成器模型也保存在文件中,因此我们可以轻松地将训练后的模型重新用于未来的 MNIST 数字生成。 但是,仅保存生成器模型,因为这是该 DCGAN 在生成新 MNIST 数字时的有用部分。 例如,我们可以通过执行以下操作来生成新的和随机的 MNIST 数字:

python3 dcgan-mnist-4.2.1.py --generator=dcgan_mnist.h5

“列表 4.2.4”:dcgan-mnist-4.2.1.py

训练判别器和对抗网络的函数:

def train(models, x_train, params):
    """Train the Discriminator and Adversarial Networks
Alternately train Discriminator and Adversarial networks by batch.
    Discriminator is trained first with properly real and fake images.
    Adversarial is trained next with fake images pretending to be real
    Generate sample images per save_interval.
Arguments:
        models (list): Generator, Discriminator, Adversarial models
        x_train (tensor): Train images
        params (list) : Networks parameters
"""
    # the GAN component models
    generator, discriminator, adversarial = models
    # network parameters
    batch_size, latent_size, train_steps, model_name = params
    # the generator image is saved every 500 steps
    save_interval = 500
    # noise vector to see how the generator output evolves during training
    noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
    # number of elements in train dataset
    train_size = x_train.shape[0]
    for i in range(train_steps):
        # train the discriminator for 1 batch
        # 1 batch of real (label=1.0) and fake images (label=0.0)
        # randomly pick real images from dataset
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        # generate fake images from noise using generator 
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        # generate fake images
        fake_images = generator.predict(noise)
        # real + fake images = 1 batch of train data
        x = np.concatenate((real_images, fake_images))
        # label real and fake images
        # real images label is 1.0
        y = np.ones([2 * batch_size, 1])
        # fake images label is 0.0
        y[batch_size:, :] = 0.0
        # train discriminator network, log the loss and accuracy
        loss, acc = discriminator.train_on_batch(x, y)
        log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)
# train the adversarial network for 1 batch
        # 1 batch of fake images with label=1.0
        # since the discriminator weights 
        # are frozen in adversarial network
        # only the generator is trained
        # generate noise using uniform distribution
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        # label fake images as real or 1.0
        y = np.ones([batch_size, 1])
        # train the adversarial network 
        # note that unlike in discriminator training, 
        # we do not save the fake images in a variable
        # the fake images go to the discriminator input of the adversarial
        # for classification
        # log the loss and accuracy
        loss, acc = adversarial.train_on_batch(noise, y)
        log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)
        print(log)
        if (i + 1) % save_interval == 0:
            # plot generator images on a periodic basis
            plot_images(generator,
                        noise_input=noise_input,
                        show=False,
                        step=(i + 1),
                        model_name=model_name)
# save the model after training the generator
    # the trained generator can be reloaded for 
    # future MNIST digit generation
    generator.save(model_name + ".h5")

“图 4.2.2”显示了生成器伪造图像根据训练步骤的演变。 生成器已经以 5,000 步的速度生成了可识别的图像。 非常像拥有一个知道如何绘制数字的智能体。 值得注意的是,某些数字从一种可识别的形式(例如,最后一行的第二列中的 8)变为另一种形式(例如,0)。 当训练收敛时,判别器损失接近 0.5,而对抗性损失接近 1.0,如下所示:

39997: [discriminator loss: 0.423329, acc: 0.796875] [adversarial loss:
0.819355, acc: 0.484375]
39998: [discriminator loss: 0.471747, acc: 0.773438] [adversarial loss:
1.570030, acc: 0.203125]
39999: [discriminator loss: 0.532917, acc: 0.742188] [adversarial loss:
0.824350, acc: 0.453125]

我们可以看到以下结果:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rbWarYGP-1681704179663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_06.png)]

图 4.2.2:DCGAN 生成器在不同训练步骤生成的伪造图像

在本节中,由 DCGAN 生成的伪造图像是随机的。

生成器无法控制哪个特定数字。 没有机制可以请求生成器提供特定的数字。 这个问题可以通过称为 CGAN [4]的 GAN 变体来解决,我们将在下一部分中进行讨论。

3. Conditional GAN

使用与上一节相同的 GAN ,会对生成器和判别器输入都施加一个条件。 条件是数字的一键向量形式。 这与要生成的图像(生成器)或分类为真实或伪造的图像(判别器)相关。 CGAN 模型显示在“图 4.3.1”中。

CGAN 与 DCGAN 相似,除了附加的单热向量输入。 对于生成器,单热标签在Dense层之前与潜向量连接在一起。 对于判别器,添加了新的Dense层。 新层用于处理单热向量并对其进行整形,以使其适合于与后续 CNN 层的另一个输入连接。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-P3WiH4JK-1681704179663)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_07.png)]

图 4.3.1:CGAN 模型与 DCGAN 相似,只不过是单热向量,用于调节生成器和判别器的输出

生成器学习从 100 维输入向量和指定位数生成伪图像。 判别器基于真实和伪图像及其对应的标签将真实图像与伪图像分类。

CGAN 的基础仍然与原始 GAN 原理相同,区别在于判别器和生成器的输入均以“一热”标签y为条件。

通过在“公式 4.1.1”和“公式 4.1.5”中合并此条件,判别器和生成器的损失函数在“公式 4.3.1”和“公式 4.3.2”中显示,分别为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yPzEVRWZ-1681704179664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_018.png)] (Equation 4.3.1)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-i3yrMyfb-1681704179664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_019.png)] (Equation 4.3.2)

给定“图 4.3.2”,将损失函数写为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YWe7l9jF-1681704179664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_020.png)] (Equation 4.3.3)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XS06Aa6K-1681704179664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_021.png)] (Equation 4.3.4)

判别器的新损失函数旨在最大程度地减少预测来自数据集的真实图像和来自生成器的假图像(给定单热点标签)的误差。“图 4.3.2”显示了如何训练判别器。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ti48EU0G-1681704179664)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_08.png)]

图 4.3.2:训练 CGAN 判别器类似于训练 GAN 判别器。 唯一的区别是,所生成的伪造品和数据集的真实图像均以其相应的“一键通”标签作为条件。

生成器的新损失函数可最大程度地减少对以指定的一幅热标签为条件的伪造图像进行鉴别的正确预测。 生成器学习如何在给定单热向量的情况下生成特定的 MNIST 数字,该数字可能使判别器蒙蔽。“图 4.3.3”显示了如何训练生成器。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8dhWGXEc-1681704179665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_09.png)]

图 4.3.3:通过对抗网络训练 CGAN 生成器类似于训练 GAN 生成器。 唯一的区别是,生成的伪造图像以“一热”标签为条件

“列表 4.3.1”突出显示了判别器模型中所需的微小更改。 该代码使用Dense层处理单热点向量,并将其与输入图像连接在一起。 修改了Model实例以用于图像和一键输入向量。

“列表 4.3.1”:cgan-mnist-4.3.1.py

突出显示了 DCGAN 中所做的更改:

def build_discriminator(inputs, labels, image_size):
    """Build a Discriminator Model
Inputs are concatenated after Dense layer.
    Stack of LeakyReLU-Conv2D to discriminate real from fake.
    The network does not converge with BN so it is not used here
    unlike in DCGAN paper.
Arguments:
        inputs (Layer): Input layer of the discriminator (the image)
        labels (Layer): Input layer for one-hot vector to condition
            the inputs
        image_size: Target size of one side (assuming square image)
    Returns:
        discriminator (Model): Discriminator Model
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
x = inputs
y = Dense(image_size * image_size)(labels)
    y = Reshape((image_size, image_size, 1))(y)
    x = concatenate([x, y])
for filters in layer_filters:
        # first 3 convolution layers use strides = 2
        # last one uses strides = 1
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding='same')(x)
x = Flatten()(x)
    x = Dense(1)(x)
    x = Activation('sigmoid')(x)
    # input is conditioned by labels
    discriminator = Model([inputs, labels], x, name='discriminator')
    return discriminator

以下“列表 4.3.2”突出显示了代码更改,以在生成器生成器函数中合并条件化单热标签。 对于z向量和单热向量输入,修改了Model实例。

“列表 4.3.2”:cgan-mnist-4.3.1.py

突出显示了 DCGAN 中所做的更改:

def build_generator(inputs, labels, image_size):
    """Build a Generator Model
    Inputs are concatenated before Dense layer.
    Stack of BN-ReLU-Conv2DTranpose to generate fake images.
    Output activation is sigmoid instead of tanh in orig DCGAN.
    Sigmoid converges easily.
Arguments:
        inputs (Layer): Input layer of the generator (the z-vector)
        labels (Layer): Input layer for one-hot vector to condition the inputs
        image_size: Target size of one side (assuming square image)
    Returns:
        generator (Model): Generator Model
    """
    image_resize = image_size // 4
    # network parameters
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]
x = concatenate([inputs, labels], axis=1)
    x = Dense(image_resize * image_resize * layer_filters[0])(x)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x)
for filters in layer_filters:
        # first two convolution layers use strides = 2
        # the last two use strides = 1
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='same')(x)
x = Activation('sigmoid')(x)
    # input is conditioned by labels
    generator = Model([inputs, labels], x, name='generator')
    return generator

“列表 4.3.3”突出显示了在train()函数中所做的更改,以适应判别器和生成器的条件一热向量。 首先对 CGAN 判别器进行训练,以一批真实和伪造的数据为条件,这些数据以其各自的热门标签为条件。 然后,在给定单热标签条件假冒数据为假的情况下,通过训练对抗网络来更新生成器参数。 与 DCGAN 相似,在对抗训练中,判别器权重被冻结。

“列表 4.3.3”:cgan-mnist-4.3.1.py

着重介绍了 DCGAN 中所做的更改:

def train(models, data, params):
    """Train the Discriminator and Adversarial Networks
Alternately train Discriminator and Adversarial networks by batch.
    Discriminator is trained first with properly labelled real and fake images.
    Adversarial is trained next with fake images pretending to be real.
    Discriminator inputs are conditioned by train labels for real images,
    and random labels for fake images.
    Adversarial inputs are conditioned by random labels.
    Generate sample images per save_interval.
Arguments:
        models (list): Generator, Discriminator, Adversarial models
        data (list): x_train, y_train data
        params (list): Network parameters
"""
    # the GAN models
    generator, discriminator, adversarial = models
    # images and labels
    x_train, y_train = data
    # network parameters
    batch_size, latent_size, train_steps, num_labels, model_name = params
    # the generator image is saved every 500 steps
    save_interval = 500
    # noise vector to see how the generator output evolves during training
    noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
    # one-hot label the noise will be conditioned to
    noise_class = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    # number of elements in train dataset
    train_size = x_train.shape[0]
print(model_name,
          "Labels for generated images: ",
          np.argmax(noise_class, axis=1))
for i in range(train_steps):
        # train the discriminator for 1 batch
        # 1 batch of real (label=1.0) and fake images (label=0.0)
        # randomly pick real images from dataset
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        # corresponding one-hot labels of real images
        real_labels = y_train[rand_indexes]
        # generate fake images from noise using generator
        noise = np.random.uniform(-1.0,
                                  1.0,
                                 size=[batch_size, latent_size])
# assign random one-hot labels
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]
        # generate fake images conditioned on fake labels
        fake_images = generator.predict([noise, fake_labels])
        # real + fake images = 1 batch of train data
        x = np.concatenate((real_images, fake_images))
        # real + fake one-hot labels = 1 batch of train one-hot labels
        labels = np.concatenate((real_labels, fake_labels))
        # label real and fake images
        # real images label is 1.0
        y = np.ones([2 * batch_size, 1])
        # fake images label is 0.0
        y[batch_size:, :] = 0.0
        # train discriminator network, log the loss and accuracy
        loss, acc = discriminator.train_on_batch([x, labels], y)
        log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)
        # train the adversarial network for 1 batch
        # 1 batch of fake images conditioned on fake 1-hot labels 
        # w/ label=1.0
        # since the discriminator weights are frozen in 
        # adversarial network only the generator is trained
        # generate noise using uniform distribution        
        noise = np.random.uniform(-1.0,
                                  1.0,
                                  size=[batch_size, latent_size])
        # assign random one-hot labels
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels,batch_size)]
# label fake images as real or 1.0
        y = np.ones([batch_size, 1])
        # train the adversarial network 
        # note that unlike in discriminator training, 
        # we do not save the fake images in a variable
        # the fake images go to the discriminator input of the adversarial
        # for classification
        # log the loss and accuracy
        loss, acc = adversarial.train_on_batch([noise, fake_labels], y)
        log = "%s [adversarial loss: %f, acc: %f]" % (log, loss, acc)
        print(log)
        if (i + 1) % save_interval == 0:
            # plot generator images on a periodic basis
            plot_images(generator,
                        noise_input=noise_input,
                        noise_class=noise_class,
                        show=False,
                        step=(i + 1),
                        model_name=model_name)
# save the model after training the generator
    # the trained generator can be reloaded for 
    # future MNIST digit generation
    generator.save(model_name + ".h5")

“图 4.3.4”显示了当生成器被调整为产生带有以下标签的数字时生成的 MNIST 数字的演变:

[0 1 2 3
4 5 6 7
8 9 0 1
2 3 4 5]

我们可以看到以下结果:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NfdWj3kQ-1681704179665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_04_10.png)]

图 4.3.4:使用标签[0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5]对 CGAN 在不同训练步骤中生成的伪造图像

鼓励您运行经过训练的生成器模型,以查看新的合成 MNIST 数字图像:

python3 cgan-mnist-4.3.1.py --generator=cgan_mnist.h5

或者,也可以请求要生成的特定数字(例如 8):

python3 cgan-mnist-4.3.1.py --generator=cgan_mnist.h5 --digit=8

使用 CGAN,就像有一个智能体,我们可以要求绘制数字,类似于人类如何写数字。 与 DCGAN 相比,CGAN 的主要优势在于我们可以指定希望智能体绘制的数字。

4。结论

本章讨论了 GAN 的一般原理,以便为我们现在要讨论的更高级的主题奠定基础,包括改进的 GAN,解缠的表示 GAN 和跨域 GAN。 我们从了解 GAN 如何由两个网络(称为生成器和判别器)组成的这一章开始。 判别器的作用是区分真实信号和虚假信号。 生成器的目的是欺骗判别器。 生成器通常与判别器结合以形成对抗网络。 生成器是通过训练对抗网络来学习如何生成可欺骗判别器的虚假数据的。

我们还了解了 GAN 的构建方法,但众所周知,其操作起来非常困难。 提出了tf.keras中的两个示例实现。 DCGAN 证明了可以训练 GAN 使用深层 CNN 生成伪造图像。 伪造的图像是 MNIST 数字。 但是,DCGAN 生成器无法控制应绘制的特定数字。 CGAN 通过调节生成器以绘制特定数字来解决此问题。 该病是单热标签的形式。 如果我们要构建可以生成特定类数据的智能体,则 CGAN 很有用。

在下一章中,将介绍 DCGAN 和 CGAN 的改进。 特别是,重点将放在如何稳定 DCGAN 的训练以及如何提高 CGAN 的感知质量上。 这将通过引入新的损失函数和稍有不同的模型架构来完成。

5. 参考

  1. Ian Goodfellow. NIPS 2016 Tutorial: Generative Adversarial Networks. arXiv preprint arXiv:1701.00160, 2016 (https://arxiv.org/pdf/1701.00160.pdf).
  2. Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. arXiv preprint arXiv:1511.06434, 2015 (https://arxiv.org/pdf/1511.06434.pdf).
  3. Mehdi Mirza and Simon Osindero. Conditional Generative Adversarial Nets. arXiv preprint arXiv:1411.1784, 2014 (https://arxiv.org/pdf/1411.1784.pdf).
  4. Tero Karras et al. Progressive Growing of GANs for Improved Quality, Stability, and Variation. ICLR, 2018 (https://arxiv.org/pdf/1710.10196.pdf).
  5. Tero Karras, , Samuli Laine, and Timo Aila. A Style-Based Generator Architecture for Generative Adversarial Networks. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.
  6. Tero Karras et al. Analyzing and Improving the Image Quality of StyleGAN. 2019 (https://arxiv.org/abs/1912.04958).

五、改进的 GAN

自 2014 年引入生成对抗网络GAN)以来,其流行度迅速提高。 GAN 已被证明是有用的生成模型,可以合成看起来真实的新数据。 深度学习中的许多研究论文都遵循提出的措施来解决原始 GAN 的困难和局限性。

正如我们在前几章中讨论的那样,众所周知,GAN 很难训练,并且易于崩溃。 模式损失是一种情况,即使损失函数已经被优化,但生成器仍会产生看起来相同的输出。 在 MNIST 数字的情况下,模式折叠时,生成器可能只产生数字 4 和 9,因为看起来很相似。 Wasserstein GANWGAN)[2]解决了这些问题,认为只需替换基于 Wasserstein 的 GAN 损失函数就可以稳定的训练和避免模式崩溃,也称为陆地移动距离EMD)。

但是,稳定性问题并不是 GAN 的唯一问题。 也越来越需要来提高所生成图像的感知质量。 最小二乘 GANLSGAN)[3]建议同时解决这两个问题。 基本前提是,在训练过程中,Sigmoid 交叉熵损失会导致梯度消失。 这导致较差的图像质量。 最小二乘损失不会导致梯度消失。 与原始 GAN 生成的图像相比,生成的生成图像具有更高的感知质量。

在上一章中,CGAN 介绍了一种调节生成器输出的方法。 例如,如果要获取数字 8,则可以在生成器的输入中包含条件标签。 受 CGAN 的启发,辅助分类器 GANACGAN)[4]提出了一种改进的条件算法,可产生更好的感知质量和输出多样性。

总之,本章的目的是介绍:

  • WGAN 的理论描述
  • 对 LSGAN 原理的理解
  • 对 ACGAN 原理的理解
  • 改进的 GAN 的tf.keras实现 – WGAN,LSGAN 和 ACGAN

让我们从讨论 WGAN 开始。

1. Wasserstein GAN

如前所述,众所周知,GAN 很难训练。 判别器和生成器这两个网络的相反目标很容易导致训练不稳定。 判别器尝试从真实数据中正确分类伪造数据。 同时,生成器将尽最大努力欺骗判别器。 如果判别器的学习速度比生成器快,则生成器参数将无法优化。 另一方面,如果判别器学习较慢,则梯度可能会在到达生成器之前消失。 在最坏的情况下,如果判别器无法收敛,则生成器将无法获得任何有用的反馈。

WGAN 认为 GAN 固有的不稳定性是由于它的损失函数引起的,该函数基于 Jensen-ShannonJS)距离。 在 GAN 中,生成器的目的是学习如何将一种源分布(例如噪声)从转换为估计的目标分布(例如 MNIST 数字)。 使用 GAN 的原始公式,损失函数实际上是使目标分布与其估计值之间的距离最小。 问题是,对于某些分布对,没有平滑的路径可以最小化此 JS 距离。 因此,训练将无法收敛。

在以下部分中,我们将研究三个距离函数,并分析什么可以替代更适合 GAN 优化的 JS 距离函数。

距离函数

可以通过检查其损失函数来了解训练 GAN 的稳定性。 为了更好地理解 GAN 损失函数,我们将回顾两个概率分布之间的公共距离或散度函数。

我们关注的是用于真实数据分配的p_data与用于生成器数据分配的p_g之间的距离。 GAN 的目标是制造p_g -> p_data。“表 5.1.1”显示了散度函数。

在大多数个最大似然任务中,我们将使用 Kullback-LeiblerKL)散度,或D[KL]损失函数可以衡量我们的神经网络模型预测与真实分布函数之间的距离。 如“公式 5.1.1”所示,由于D[KL](p_data || p_g) ≠ D[KL](p_g || p_data),所以D[KL]不对称。

JSD[JS]是基于D[KL]的差异。 但是,与D[KL]不同,D[JS]是对称的并且是有限的。 在本节中,我们将演示优化 GAN 损失函数等同于优化D[JS]

散度 表达式
Kullback-Leibler(KL)“公式 5.1.1” [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VeBGU1eZ-1681704179665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_003.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Xq3IxItR-1681704179665)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_004.png)]
*詹森·香农(JS)“公式 5.1.2” [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uF13aNLu-1681704179666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_005.png)]
陆地移动距离(EMD)或 Wasserstein 1 “公式 5.1.3” [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WtdS7NId-1681704179666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_006.png)]
其中Π(p_data, p_g)是所有联合分布γ(x, y)的集合,其边际为p_datap_g

表 5.1.1:两个概率分布函数p_datap_g之间的散度函数

EMD 背后的想法是,它是d = ||x - y||传输多少质量γ(x, y),为了让概率分布p_data匹配p_g的度量。 γ(x, y)是所有可能的联合分布Π(p_data, p_g)的空间中的联合分布。 γ(x, y)也被称为运输计划,以反映运输质量以匹配两个概率分布的策略。 给定两个概率分布,有许多可能的运输计划。 大致而言, inf表示成本最低的运输计划。

例如,“图 5.1.1”向我们展示了两个简单的离散分布xy

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mahjD3GA-1681704179666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_01.png)]

图 5.1.1:EMD 是从x传输以匹配目标分布y的质量的加权数量。

在位置i = 1, 2, 3, 4上,x在具有质量m[i], i = 1, 2, 3, 4。同时,位置y[i], i = 1, 2上,y的质量为m[i], i = 1, 2。为了匹配分布y,图中的箭头显示了将每个质量x[i]移动d[i]的最小运输计划。 EMD 计算如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qm42ZQdY-1681704179666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_014.png)] (Equation 5.1.4)

在“图 5.1.1”中,EMD 可解释为移动一堆污物x填充孔y所需的最少工作量。 尽管在此示例中,也可以从图中推导出inf,但在大多数情况下,尤其是在连续分布中,用尽所有可能的运输计划是很棘手的。 我们将在本章中稍后回到这个问题。 同时,我们将向您展示 GAN 损失函数的作用,实际上是如何使 JS 的差异最小化。

GAN 中的距离函数

现在,在上一章的损失函数给定任何生成器的情况下,我们将计算最佳判别器。 我们将回顾上一章中的以下等式:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EDa3H3ex-1681704179666)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_015.png)] (Equation 4.1.1)

除了从噪声分布中采样外,前面的等式也可以表示为从生成器分布中采样:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-90ytS6IO-1681704179667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_016.png)] (Equation 5.1.5)

找出最小的L^(D)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NKoF9FUQ-1681704179667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_018.png)] (Equation 5.1.6)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j7O36Saf-1681704179667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_019.png)] (Equation 5.1.7)

积分内部的项为y -> a log(y) + b log(1 - y)的形式,对于不包括{0, 0}的任何a, b ∈ R^2,在y ∈ [0. 1]a / (a + b)处都有一个已知的最大值。 由于该积分不会更改此表达式的最大值(或L^(D)的最小值)的位置,因此最佳判别器为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jv8SyUwd-1681704179667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_026.png)] (Equation 5.1.8)

因此,给定最佳判别器的损失函数为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-p11dU40J-1681704179667)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_027.png)] (Equation 5.1.9)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zAzdV2RM-1681704179668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_028.png)] (Equation 5.1.10)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2oJsFy2l-1681704179668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_029.png)] (Equation 5.1.11)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UHyX6j9w-1681704179668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_030.png)] (Equation 5.1.12)

我们可以从“公式 5.1.12”观察到,最佳判别器的损失函数为常数减去真实分布p_data和任何生成器分布p_g之间的 JS 散度的两倍。 最小化L^(D*)意味着最大化D[JS](p_data || p_g),否则判别器必须正确地将真实数据中的伪造物分类。

同时,我们可以放心地说,最佳生成器是当生成器分布等于真实数据分布时:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BTrFDo6V-1681704179668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_034.png)] (Equation 5.1.13)

这是有道理的,因为生成器的目的是通过学习真实的数据分布来欺骗判别器。 有效地,我们可以通过最小化D[JS]或通过制作p_g -> p_data来获得最佳生成器。 给定最佳生成器,最佳判别器为D*(x) = 1 / 2L^(D*) = 2log2 = 0.60

问题在于,当两个分布没有重叠时,就没有平滑函数可以帮助缩小它们之间的差距。 训练 GAN 不会因梯度下降而收敛。 例如,假设:

p_data = (x, y) where x = 0, y ~ U(0, 1) (Equation 5.1.14)

p_g = (x, y) where x = θ, y ~ U(0, 1) (Equation 5.1.15)

这两个分布显示在“图 5.1.2”中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gczEkEOy-1681704179668)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_02.png)]

图 5.1.2:没有重叠的两个分布的示例。 对于p_gθ = 0.5

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tfFeAEpL-1681704179669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_044.png)]是均匀分布。 每个距离函数的差异如下:

  • [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cu5m8s6u-1681704179669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_045.png)]
  • [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kfAJSRe5-1681704179669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_046.png)]
  • [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ew0vt6il-1681704179669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_047.png)]
  • [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ieYXLEnd-1681704179669)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_048.png)]

由于D[JS]是一个常数,因此 GAN 将没有足够的梯度来驱动p_g -> p_data。 我们还会发现D[KL]或反向D[KL]也不起作用。 但是,通过W(p_data, p_g),我们可以拥有平滑函数,以便通过梯度下降获得p_g -> p_data。 为了优化 GAN,EMD 或 Wasserstein 1 似乎是一个更具逻辑性的损失函数,因为在两个分布具有极小或没有重叠的情况下,D[JS]会失败。

为了帮助进一步理解,可以在以下位置找到有关距离函数的精彩讨论

在下一节中,我们将重点介绍使用 EMD 或 Wasserstein 1 距离函数来开发替代损失函数,以鼓励稳定训练 GAN。

使用 Wasserstein 损失

在使用 EMD 或 Wasserstein 1 之前,还有一个要解决的问题。 耗尽Π(p_data, p_g)的空间来找到γ ~ Π(p_data, p_g)是很棘手的。 提出的解决方案是使用其 Kantorovich-Rubinstein 对偶:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-S39W67fv-1681704179670)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_053.png)] (Equation 5.1.16)

等效地,EMD sup ||f||_L <= 1是所有 K-Lipschitz 函数上的最高值(大约是最大值):f: x -> R。 K-Lipschitz 函数满足以下约束:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Qi76pQsr-1681704179670)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_056.png)] (Equation 5.1.17)

对于所有x[1], x[2] ∈ R。 K-Lipschitz 函数具有有界导数,并且几乎总是连续可微的(例如,f(x) = |x|具有有界导数并且是连续的,但在x = 0时不可微分)。

“公式 5.1.16”可以通过找到 K-Lipschitz 函数{f[w]}, w ∈ W的族来求解:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Gd6BwfgW-1681704179670)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_060.png)] (Equation 5.1.18)

在 GAN 中,可以通过从z-噪声分布采样并用f[w]替换“公式 5.1.18”来重写。 鉴别函数,D[w]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kCuipoDN-1681704179670)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_061.png)] (Equation 5.1.19)

我们使用粗体字母突出显示多维样本的一般性。 最后一个问题是如何找到函数族w ∈ W。 所提出的解决方案是在每次梯度更新时进行的。 判别器w的权重被限制在上下限之间(例如,-0.01 和 0.01):

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yX1iKjls-1681704179670)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_063.png)] (Equation 5.1.20)

w的较小值将判别器约束到紧凑的参数空间,从而确保 Lipschitz 连续性。

我们可以使用“公式 5.1.19”作为我们新的 GAN 损失函数的基础。 EMD 或 Wasserstein 1 是生成器旨在最小化的损失函数,以及判别器试图最大化的损失函数(或最小化-W(p_data, p_g)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ppgquOvf-1681704179671)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_064.png)] (Equation 5.1.21)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kVWzEj23-1681704179671)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_065.png)] (Equation 5.1.22)

在生成器损失函数中,第一项消失了,因为它没有针对实际数据进行直接优化。

“表 5.1.2”显示了 GAN 和 WGAN 的损失函数之间的差异。 为简洁起见,我们简化了L^(D)L^(G)的表示法:

网络 损失函数 公式
GAN [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-87Qz0bjT-1681704179671)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_068.png)] 4.1.1
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-GSNSM1lv-1681704179671)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_069.png)] 4.1.5
WGAN [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g8AWPywh-1681704179671)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_070.png)] 5.1.21
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DXYoVwvN-1681704179672)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_071.png)] 5.1.22
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lOgNWSXI-1681704179672)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_072.png)] 5.1.20

表 5.1.2:GAN 和 WGAN 的损失函数之间的比较

这些损失函数用于训练 WGAN,如“算法 5.1.1”中所示。

算法 5.1.1 WGAN。 参数的值为α = 0.00005c = 0.01m = 64n_critic = 5

要求:α,学习率。c是削波参数。m,批量大小。 n_critic,即每个生成器迭代的评论(鉴别)迭代次数。

要求:w[D],初始判别器(discriminator)参数。 θ[D],初始生成器参数:

  1. θ[D]尚未收敛,执行:
  2. 对于t = 1, ..., n_critic,执行:
  3. 从真实数据中抽样一批{x^(i)} ~ p_data, i = 1, ..., m
  4. 从均匀的噪声分布中采样一批{z^(i)} ~ p_x, i = 1, ..., m
  5. [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tM7IJKlL-1681704179672)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_085.png)]
    计算判别器梯度
  6. [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-GM9Z97tW-1681704179672)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_086.png)]
    更新判别器参数
  7. [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-t2JpiGqA-1681704179672)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_087.png)]
    剪辑判别器权重
  8. end for
  9. 从均匀的噪声分布中采样一批{z^(i)} ~ p_x, i = 1, ..., m
  10. [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ii1Ri2hW-1681704179673)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_089.png)]
计算生成器梯度
  1. [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uVfhsnc3-1681704179673)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_090.png)]
更新生成器参数
  1. end while

“图 5.1.3”展示了 WGAN 模型实际上与 DCGAN 相同,除了伪造的/真实的数据标签和损失函数:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bETyHrN9-1681704179673)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_03.png)]

图 5.1.3:顶部:训练 WGAN 判别器需要来自生成器的虚假数据和来自真实分发的真实数据。 下:训练 WGAN 生成器要求生成器中假冒的真实数据是真实的

与 GAN 相似,WGAN 交替训练判别器和生成器(通过对抗)。 但是,在 WGAN 中,判别器(也称为评论者)在训练生成器进行一次迭代(第 9 至 11 行)之前,先训练n_critic迭代(第 2 至 8 行)。 这与对于判别器和生成器具有相同数量的训练迭代的 GAN 相反。 换句话说,在 GAN 中,n_critic = 1

训练判别器意味着学习判别器的参数(权重和偏差)。 这需要从真实数据中采样一批(第 3 行),并从伪数据中采样一批(第 4 行),然后将采样数据馈送到判别器网络,然后计算判别器参数的梯度(第 5 行)。 判别器参数使用 RMSProp(第 6 行)进行了优化。 第 5 行和第 6 行都是“公式 5.1.21”的优化。

最后,EM 距离优化中的 Lipschitz 约束是通过裁剪判别器参数(第 7 行)来施加的。 第 7 行是“公式 5.1.20”的实现。 在n_critic迭代判别器训练之后,判别器参数被冻结。 生成器训练通过对一批伪造数据进行采样开始(第 9 行)。 采样的数据被标记为实数(1.0),以致愚弄判别器网络。 在第 10 行中计算生成器梯度,并在第 11 行中使用 RMSProp 对其进行优化。第 10 行和第 11 行执行梯度更新以优化“公式 5.1.22”。

训练生成器后,将解冻判别器参数,并开始另一个n_critic判别器训练迭代。 我们应该注意,在判别器训练期间不需要冻结生成器参数,因为生成器仅涉及数据的制造。 类似于 GAN,可以将判别器训练为一个单独的网络。 但是,训练生成器始终需要判别器通过对抗网络参与,因为损失是根据生成器网络的输出计算得出的。

与 GAN 不同,在 WGAN 中,将实际数据标记为 1.0,而将伪数据标记为 -1.0,作为计算第 5 行中的梯度的一种解决方法。第 5-6 和 10-11 行执行梯度更新以优化“公式 5.1.21”和“5.1.22”。 第 5 行和第 10 行中的每一项均建模为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EyG3VrWl-1681704179673)(https://gitcode.net/apachecn/apachecn-dl-zh/-/raw/master/docs/adv-dl-tf2-keras/img/B14853_05_091.png)] (Equation 5.1.23)

对于真实数据,其中y_label = 1.0,对于假数据,y_label= -1.0。 为了简化符号,我们删除了上标(i)。 对于判别器,当使用实际数据进行训练时,WGAN 增加y_pred = D[w](x)以最小化损失函数。

使用伪造数据进行训练时,WGAN 会降低y_pred = D[w](g(z))以最大程度地减少损失函数。 对于生成器,当在训练过程中将伪数据标记为真实数据时,WGAN 增加y_pred = D[w](g(z))以最小化损失函数。 请注意,y_label除了其符号外,对损失函数没有直接贡献。 在tf.keras中,“公式 5.1.23”实现为:

def wasserstein_loss(y_label, y_pred):
    return -K.mean(y_label * y_pred)

本节最重要的部分是用于稳定训练 GAN 的新损失函数。 它基于 EMD 或 Wasserstein1。“算法 5.1.1”形式化了 WGAN 的完整训练算法,包括损失函数。 在下一节中,将介绍tf.keras中训练算法的实现。

使用 Keras 的 WGAN 实现

为了在tf.keras中实现 WGAN,我们可以重用 GAN 的 DCGAN 实现,这是我们在上一一章中介绍的。 DCGAN 构建器和工具函数在lib文件夹的gan.py中作为模块实现。

函数包括:

  • generator():生成器模型构建器
  • discriminator():判别器模型构建器
  • train():DCGAN 训练师
  • plot_images():通用生成器输出绘图仪
  • test_generator():通用的生成器测试工具

如“列表 5.1.1”所示,我们可以通过简单地调用以下命令来构建一个判别器:

discriminator = gan.discriminator(inputs, activation='linear')

WGAN 使用线性输出激活。 对于生成器,我们执行:

generator = gan.generator(inputs, image_size)

tf.keras中的整体网络模型类似于 DCGAN 的“图 4.2.1”中看到的模型。

“列表 5.1.1”突出显示了 RMSprop 优化器和 Wasserstein 损失函数的使用。 在训练期间使用“算法 5.1.1”中的超参数。

完整的代码可在 GitHub 上获得

“列表 5.1.1”:wgan-mnist-5.1.2.py

def build_and_train_models():
    """Load the dataset, build WGAN discriminator,
    generator, and adversarial models.
    Call the WGAN train routine.
    """
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()
# reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
model_name = "wgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    # hyper parameters from WGAN paper [2]
    n_critic = 5
    clip_value = 0.01
    batch_size = 64
    lr = 5e-5
    train_steps = 40000
    input_shape = (image_size, image_size, 1)
# build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    # WGAN uses linear activation in paper [2]
    discriminator = gan.discriminator(inputs, activation='linear')
    optimizer = RMSprop(lr=lr)
    # WGAN discriminator uses wassertein loss
    discriminator.compile(loss=wasserstein_loss,
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()
# build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary()
# build adversarial model = generator + discriminator
    # freeze the weights of discriminator during adversarial training
    discriminator.trainable = False
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    adversarial.compile(loss=wasserstein_loss,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()
# train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size,
              latent_size,
              n_critic,
              clip_value,
              train_steps,
              model_name)
    train(models, x_train, params)

“列表 5.1.2”是紧跟“算法 5.1.1”的训练函数。 但是,在判别器的训练中有一个小的调整。 与其在单个合并的真实数据和虚假数据中组合训练权重,不如先训练一批真实数据,然后再训练一批虚假数据。 这种调整将防止梯度消失,因为真实和伪造数据标签中的符号相反,并且由于裁剪而导致的权重较小。

TensorFlow 2 和 Keras 高级深度学习:1~5(5)https://developer.aliyun.com/article/1426948

相关文章
|
算法框架/工具 机器学习/深度学习 算法
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(三)(2)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(三)
30 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(三)(2)
|
10天前
|
机器学习/深度学习 数据可视化 测试技术
深度学习:Keras使用神经网络进行简单文本分类分析新闻组数据
深度学习:Keras使用神经网络进行简单文本分类分析新闻组数据
21 0
|
10天前
|
机器学习/深度学习 运维 监控
TensorFlow分布式训练:加速深度学习模型训练
【4月更文挑战第17天】TensorFlow分布式训练加速深度学习模型训练,通过数据并行和模型并行利用多机器资源,减少训练时间。优化策略包括配置计算资源、优化数据划分和减少通信开销。实际应用需关注调试监控、系统稳定性和容错性,以应对分布式训练挑战。
|
10天前
|
机器学习/深度学习 API TensorFlow
TensorFlow的高级API:tf.keras深度解析
【4月更文挑战第17天】本文深入解析了TensorFlow的高级API `tf.keras`,包括顺序模型和函数式API的模型构建,以及模型编译、训练、评估和预测的步骤。`tf.keras`结合了Keras的易用性和TensorFlow的性能,支持回调函数、模型保存与加载等高级特性,助力提升深度学习开发效率。
|
10天前
|
机器学习/深度学习 API 算法框架/工具
R语言深度学习:用keras神经网络回归模型预测时间序列数据
R语言深度学习:用keras神经网络回归模型预测时间序列数据
18 0
|
10天前
|
机器学习/深度学习 数据采集 TensorFlow
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
29 0
|
10天前
|
机器学习/深度学习 自然语言处理 算法框架/工具
用于NLP的Python:使用Keras进行深度学习文本生成
用于NLP的Python:使用Keras进行深度学习文本生成
20 2
|
12天前
|
机器学习/深度学习 人工智能 算法框架/工具
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)(4)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)
27 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)(4)
|
12天前
|
机器学习/深度学习 算法框架/工具 TensorFlow
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(4)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
45 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(4)
|
机器学习/深度学习 算法 算法框架/工具
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(3)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
13 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(3)