keras 实现 GAN

简介: 通过 Keras 实现 GAN ,其主要过程如下:GAN训练过程分析正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。

通过 Keras 实现 GAN ,其主要过程如下:

img_4bb140d89c9dbedebbf2639ab0aa8bc8.png
GAN训练过程分析

正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。这里 Generator 输入为随机数。

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

K.set_image_dim_ordering('th')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)

generator = Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)

discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []
gLosses = []

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

def train(epochs=1, batchSize=128):
    batchCount = int(X_train.shape[0] / batchSize)
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            # 用 Generator 生成假数据
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            # 将假数据与真实数据进行混合在一起
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            # 标记所有数据都是假数据
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            # 按真实数据比例,标记前半数据为 0.9 的真实度
            yDis[:batchSize] = 0.9

            # Train discriminator
            # 先训练 Discriminator 让其具有判定能力,同时Generator 也在训练,也能更新参数。
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            # 然后训练 Generator, 注意这里训练 Generator 时候,把 Generator 生成出来的结果置为全真,及按真实数据的方式来进行训练。
            # 先生成相应 batchSize 样本 noise 数据
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            # 生成相应的 Discriminator 输出结果
            yGen = np.ones(batchSize)
            # 将 Discriminator 设置为不可训练的状态
            discriminator.trainable = False
           # 训练整个 GAN 网络即可训练出一个能生成真实样本的 Generator
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

if __name__ == '__main__':
    train(200, 128)

上述,通过 Sequential 创建 GeneratorDiscriminator 网络,然后通过 Model 将其进行结合在一起,从而构成了 GAN 网络的全貌。

参考

GAN-in-keras-on-mnist

目录
相关文章
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch CIFAR10图像分类 Swin Transformer篇(一)
Pytorch CIFAR10图像分类 Swin Transformer篇(一)
|
6月前
|
机器学习/深度学习 数据可视化 算法
Pytorch CIFAR10图像分类 Swin Transformer篇(二)
Pytorch CIFAR10图像分类 Swin Transformer篇(二)
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
用TensorFlow实现MNIST神经网络
用TensorFlow实现MNIST神经网络
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow MNIST手写数字识别(神经网络极简版)
TensorFlow MNIST手写数字识别(神经网络极简版)
TensorFlow MNIST手写数字识别(神经网络极简版)
|
机器学习/深度学习 PyTorch 算法框架/工具
|
TensorFlow 算法框架/工具 计算机视觉
TensorFlow 实现VGG16图像分类
TensorFlow 实现VGG16图像分类
TensorFlow 实现VGG16图像分类
|
机器学习/深度学习 人工智能 PyTorch
Pytorch 基于LeNet的手写数字识别
Pytorch 基于LeNet的手写数字识别
175 0
Pytorch 基于LeNet的手写数字识别
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow实现GAN生成对抗网络生成MNIST图像
TensorFlow实现GAN生成对抗网络生成MNIST图像
161 0
|
机器学习/深度学习 数据可视化
TensorFlow-CNN卷积神经网络
TensorFlow-CNN卷积神经网络
171 0
TensorFlow-CNN卷积神经网络