生成对抗网络 (GAN)。通过学习图像训练数据集的隐分布(图像的“隐空间”),GAN 可以生成看起来极为真实的新图像。
一个 GAN 由两部分组成:一个“生成器”模型(可将隐空间中的点映射到图像空间中的点)和一个“判别器”模型,后者是一个可以区分真实图像(来自训练数据集)与虚假图像(生成器网络的输出)之间差异的分类器。
GAN 训练循环如下所示:
1.训练判别器
- 在隐空间中对一批随机点采样
- 通过“生成器”模型将这些点转换为虚假图像
- 获取一批真实图像,并将它们与生成的图像组合
- 训练“判别器”模型以对生成的图像与真实图像进行分类
2.训练生成器
- 在隐空间中对随机点采样
- 通过“生成器”网络将这些点转换为虚假图像
- 将这些虚假图像全部标记为“真实图像”(此步骤只使用生成的图像,不需要真实图像)
- 训练“生成器”模型以“欺骗”判别器,使判别器将虚假图像分类为真实图像
定义判别器
# Discriminator: takes a 28x28x1 image and produces a single output value
# (the final Dense layer has no activation).
discriminator = keras.Sequential(
    layers=[
        keras.Input(shape=(28, 28, 1)),
        # Two stride-2 convolutions, each followed by a leaky ReLU.
        layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(filters=128, kernel_size=3, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        # Collapse the spatial dimensions, then project to one output unit.
        layers.GlobalMaxPooling2D(),
        layers.Dense(units=1),
    ],
    name="discriminator",
)
discriminator.summary()
定义生成器
# Length of the latent vector fed to the generator.
latent_dim = 128

# Generator: takes a latent vector and produces a 28x28x1 image whose values
# pass through a sigmoid.
generator = keras.Sequential(
    layers=[
        keras.Input(shape=(latent_dim,)),
        # Project the latent vector and reshape it into a 7x7x128 feature map.
        layers.Dense(units=7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape(target_shape=(7, 7, 128)),
        # Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
        layers.Conv2DTranspose(
            filters=128, kernel_size=4, strides=2, padding="same"
        ),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(
            filters=128, kernel_size=4, strides=2, padding="same"
        ),
        layers.LeakyReLU(alpha=0.2),
        # Final convolution reduces the feature map to a single channel.
        layers.Conv2D(
            filters=1, kernel_size=7, padding="same", activation="sigmoid"
        ),
    ],
    name="generator",
)
单步训练
# One optimizer per network so each is updated independently.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)
# The discriminator's last layer is a Dense(1) without activation, so the
# loss is told to expect logits.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    """Run one discriminator update followed by one generator update.

    Label convention used here: generated images are labelled 1 and real
    images are labelled 0.

    Args:
        real_images: A batch of real images; its leading dimension is the
            batch size.

    Returns:
        A tuple ``(d_loss, g_loss, generated_images)``: the discriminator
        loss, the generator loss, and the fake images produced for the
        discriminator update.
    """
    # Derive the batch size from the input instead of a module-level global,
    # so the fake batch always matches the real batch (including a shorter
    # final batch) and the function does not depend on outside state.
    batch_size = tf.shape(real_images)[0]

    # --- Discriminator update -------------------------------------------
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    generated_images = generator(random_latent_vectors)
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Labels follow the concat order above: ones for fakes, zeros for reals.
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
    )
    # Add random noise to the labels; the shape is taken dynamically so this
    # works even when the batch dimension is not statically known.
    labels += 0.05 * tf.random.uniform(tf.shape(labels))

    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # --- Generator update -----------------------------------------------
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # All zeros, i.e. every fake image is labelled as "real".
    misleading_labels = tf.zeros((batch_size, 1))

    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    # Only the generator's weights are updated in this step.
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))

    return d_loss, g_loss, generated_images
完整代码
""" * Created with PyCharm * 作者: 阿光 * 日期: 2022/1/3 * 时间: 20:49 * 描述: """ import numpy as np import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers discriminator = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.GlobalMaxPooling2D(), layers.Dense(1), ], name="discriminator", ) discriminator.summary() latent_dim = 128 generator = keras.Sequential( [ keras.Input(shape=(latent_dim,)), layers.Dense(7 * 7 * 128), layers.LeakyReLU(alpha=0.2), layers.Reshape((7, 7, 128)), layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"), ], name="generator", ) d_optimizer = keras.optimizers.Adam(learning_rate=0.0003) g_optimizer = keras.optimizers.Adam(learning_rate=0.0004) loss_fn = keras.losses.BinaryCrossentropy(from_logits=True) @tf.function def train_step(real_images): random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim)) generated_images = generator(random_latent_vectors) combined_images = tf.concat([generated_images, real_images], axis=0) labels = tf.concat( [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0 ) labels += 0.05 * tf.random.uniform(labels.shape) with tf.GradientTape() as tape: predictions = discriminator(combined_images) d_loss = loss_fn(labels, predictions) grads = tape.gradient(d_loss, discriminator.trainable_weights) d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights)) random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim)) misleading_labels = tf.zeros((batch_size, 1)) with tf.GradientTape() as tape: predictions = discriminator(generator(random_latent_vectors)) g_loss 
= loss_fn(misleading_labels, predictions) grads = tape.gradient(g_loss, generator.trainable_weights) g_optimizer.apply_gradients(zip(grads, generator.trainable_weights)) return d_loss, g_loss, generated_images import os batch_size = 64 (x_train, _), (x_test, _) = keras.datasets.mnist.load_data() all_digits = np.concatenate([x_train, x_test]) all_digits = all_digits.astype("float32") / 255.0 all_digits = np.reshape(all_digits, (-1, 28, 28, 1)) dataset = tf.data.Dataset.from_tensor_slices(all_digits) dataset = dataset.shuffle(buffer_size=1024).batch(batch_size) epochs = 1 save_dir = "./" for epoch in range(epochs): print("\nStart epoch", epoch) for step, real_images in enumerate(dataset): d_loss, g_loss, generated_images = train_step(real_images) if step % 200 == 0: print("discriminator loss at step %d: %.2f" % (step, d_loss)) print("adversarial loss at step %d: %.2f" % (step, g_loss)) img = tf.keras.preprocessing.image.array_to_img( generated_images[0] * 255.0, scale=False ) img.save(os.path.join(save_dir, "generated_img" + str(step) + ".png")) if step > 10: break