【TensorFlow+Keras】Two Ways to Train a GAN with the Keras API

Summary: Example code for training a conditional generative adversarial network (CGAN) with the Keras API in two different ways: one using the train_on_batch method, the other using a custom training loop built on tf.GradientTape.

1 Method 1: train_on_batch

(1) Overview
GitHub: https://github.com/eriklindernoren/Keras-GAN/tree/master/cgan
train_on_batch performs a single gradient update on one batch of samples. It is typically used with models built via the Keras Sequential API.
For other network architectures, see "Three Ways to Build Neural Networks with the Keras API, Illustrated on MNIST". A minimal sketch of the train_on_batch call itself follows.
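As a quick illustration of the API (a minimal, self-contained sketch; the toy model and random data below are made up purely for demonstration), train_on_batch takes one batch of inputs and targets, applies a single gradient update, and returns the loss together with any compiled metrics:

import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

# Toy binary classifier: 10-dimensional input, one sigmoid output
model = Sequential([Dense(1, activation='sigmoid', input_dim=10)])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# One random batch (demonstration only)
x_batch = np.random.normal(size=(32, 10))
y_batch = np.random.randint(0, 2, size=(32, 1))

# Single gradient update; returns [loss, accuracy] because a metric was compiled
loss, acc = model.train_on_batch(x_batch, y_batch)
print(loss, acc)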
(2) Worked example

from __future__ import print_function, division
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os
class CGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.latent_dim = 100
        optimizer = Adam(0.0002, 0.5)
        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy'])
        # Build the generator
        self.generator = self.build_generator()
        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])
        # For the combined model we will only train the generator
        self.discriminator.trainable = False
        # The discriminator takes generated image as input and determines validity
        # and the label of that image
        valid = self.discriminator([img, label])
        # The combined model  (stacked generator and discriminator)
        # Trains generator to fool discriminator
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
            optimizer=optimizer)
    def build_generator(self):
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
        model.summary()
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        model_input = multiply([noise, label_embedding])
        img = model(model_input)
        return Model([noise, label], img)
    def build_discriminator(self):
        model = Sequential()
        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)
        model_input = multiply([flat_img, label_embedding])
        validity = model(model_input)
        return Model([img, label], validity)
    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()
        # Configure input
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]
            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Generate a half batch of new images
            gen_imgs = self.generator.predict([noise, labels])
            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            # train_on_batch returns a list of length 2: d_loss_real[0] is the loss, d_loss_real[1] is the accuracy
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # ---------------------
            #  Train Generator
            # ---------------------
            # Condition on labels
            sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
            # Train the generator
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)
            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    def sample_images(self, epoch):
        r, c = 2, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.arange(0, 10).reshape(-1, 1)
        gen_imgs = self.generator.predict([noise, sampled_labels])
        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
                axs[i,j].set_title("Digit: %d" % sampled_labels[cnt, 0])
                axs[i,j].axis('off')
                cnt += 1
        os.makedirs("images", exist_ok=True)  # make sure the output directory exists
        fig.savefig("images/%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    cgan = CGAN()
    cgan.train(epochs=1000, batch_size=32, sample_interval=200)
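One detail worth calling out: self.discriminator.trainable = False is set only after the discriminator has been compiled on its own. In Keras, the trainable flag takes effect at compile time, so the standalone discriminator keeps updating its weights in train_on_batch, while the copy inside self.combined (compiled after the flag was flipped) stays frozen. A minimal sketch of this behavior, using a toy model for illustration only:

import numpy as np
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.layers import Dense

d = Sequential([Dense(1, input_dim=4)])
d.compile(loss='mse', optimizer='adam')          # compiled while trainable=True

d.trainable = False                              # flipped afterwards
inp = Input(shape=(4,))
combined = Model(inp, d(inp))
combined.compile(loss='mse', optimizer='adam')   # this model sees d as frozen

x, y = np.ones((2, 4)), np.ones((2, 1))
combined.train_on_batch(x, y)                    # d's weights do not change
d.train_on_batch(x, y)                           # d's weights do change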

2 Method 2: tf.GradientTape()

Reference: https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit
(1) Building the networks
The idea is to subclass keras.Model and override train_step(); fit() then runs this custom GAN update on every batch.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)
# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
# Train the networks: subclass keras.Model and override train_step
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)
        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - an important trick that softens the targets and stabilizes training
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )
        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))
        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}
# Prepare the data and launch training
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
# Launch training; as in the TensorFlow guide, a small subset and a single epoch keep the demo fast
gan.fit(dataset.take(100), epochs=1)
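Once training has run, new digits can be generated by feeding random latent vectors directly to the generator (a minimal sketch; the plotting code here is illustrative):

import matplotlib.pyplot as plt

# Decode a few random latent vectors into images
random_latent_vectors = tf.random.normal(shape=(10, latent_dim))
generated_images = generator(random_latent_vectors, training=False).numpy()

# The generator's sigmoid output is already in [0, 1], so it can be shown directly
fig, axs = plt.subplots(1, 10, figsize=(15, 2))
for i in range(10):
    axs[i].imshow(generated_images[i, :, :, 0], cmap='gray')
    axs[i].axis('off')
plt.show()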