【Tensorflow+keras】Keras API两种训练GAN网络的方式

简介: 使用Keras API以两种不同方式训练条件生成对抗网络(CGAN)的示例代码:一种是使用train_on_batch方法,另一种是使用tf.GradientTape进行自定义训练循环。

1 第一种 train_on_batch

(1)简介
github:https://github.com/eriklindernoren/Keras-GAN/tree/master/cgan
运行一批样品的单次梯度更新。该方法搭配keras的sequential API使用。
其他网络结构参考Keras API三种搭建神经网络的方式及以mnist举例实现
(2)举例实现

from __future__ import print_function, division
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
class CGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.latent_dim = 100
        optimizer = Adam(0.0002, 0.5)
        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy'])
        # Build the generator
        self.generator = self.build_generator()
        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])
        # For the combined model we will only train the generator
        self.discriminator.trainable = False
        # The discriminator takes generated image as input and determines validity
        # and the label of that image
        valid = self.discriminator([img, label])
        # The combined model  (stacked generator and discriminator)
        # Trains generator to fool discriminator
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
            optimizer=optimizer)
    def build_generator(self):
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
        model.summary()
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        model_input = multiply([noise, label_embedding])
        img = model(model_input)
        return Model([noise, label], img)
    def build_discriminator(self):
        model = Sequential()
        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)
        model_input = multiply([flat_img, label_embedding])
        validity = model(model_input)
        return Model([img, label], validity)
    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()
        # Configure input
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]
            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, 100))
            # Generate a half batch of new images
            gen_imgs = self.generator.predict([noise, labels])
            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            #train_on_batch返回值 为长度为2的列表, d_loss_real[0]为loss, d_loss_real[1]为accuracy
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # ---------------------
            #  Train Generator
            # ---------------------
            # Condition on labels
            sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
            # Train the generator
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)
            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    def sample_images(self, epoch):
        r, c = 2, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        sampled_labels = np.arange(0, 10).reshape(-1, 1)
        gen_imgs = self.generator.predict([noise, sampled_labels])
        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
                axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    cgan = CGAN()
    cgan.train(epochs=1000, batch_size=32, sample_interval=200)

2 第二种 tf.GradientTape()

参考:https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit
(1)搭建网络

from tensorflow.keras import layers
# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)
# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
#训练网络
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)
        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )
        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))
        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}
#测试网络
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
目录
相关文章
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
284 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
3月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
411 55
|
5月前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
97 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
23天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
84 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
1月前
|
机器学习/深度学习 文件存储 异构计算
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
89 18
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
|
8天前
|
机器学习/深度学习 数据可视化 API
DeepSeek生成对抗网络(GAN)的训练与应用
生成对抗网络(GANs)是深度学习的重要技术,能生成逼真的图像、音频和文本数据。通过生成器和判别器的对抗训练,GANs实现高质量数据生成。DeepSeek提供强大工具和API,简化GAN的训练与应用。本文介绍如何使用DeepSeek构建、训练GAN,并通过代码示例帮助掌握相关技巧,涵盖模型定义、训练过程及图像生成等环节。
|
25天前
|
机器学习/深度学习 文件存储 异构计算
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
34 1
|
17天前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习工具和框架详细指南:PyTorch、TensorFlow、Keras
在深度学习的世界中,PyTorch、TensorFlow和Keras是最受欢迎的工具和框架,它们为研究者和开发者提供了强大且易于使用的接口。在本文中,我们将深入探索这三个框架,涵盖如何用它们实现经典深度学习模型,并通过代码实例详细讲解这些工具的使用方法。
|
3月前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
233 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
4月前
|
机器学习/深度学习 自然语言处理 语音技术
Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧
本文介绍了Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧,并通过TensorFlow和PyTorch等库展示了实现神经网络的具体示例,涵盖图像识别、语音识别等多个应用场景。
130 8

热门文章

最新文章