keras 实现 GAN

简介: 通过 Keras 实现 GAN ,其主要过程如下:GAN训练过程分析正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。

通过 Keras 实现 GAN ,其主要过程如下:

img_4bb140d89c9dbedebbf2639ab0aa8bc8.png
GAN训练过程分析

正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。这里 Generator 输入为随机数。

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

K.set_image_dim_ordering('th')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)

generator = Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)

discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []
gLosses = []

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

def train(epochs=1, batchSize=128):
    batchCount = int(X_train.shape[0] / batchSize)
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            # 用 Generator 生成假数据
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            # 将假数据与真实数据进行混合在一起
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            # 标记所有数据都是假数据
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            # 按真实数据比例,标记前半数据为 0.9 的真实度
            yDis[:batchSize] = 0.9

            # Train discriminator
            # 先训练 Discriminator 让其具有判定能力,同时Generator 也在训练,也能更新参数。
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            # 然后训练 Generator, 注意这里训练 Generator 时候,把 Generator 生成出来的结果置为全真,及按真实数据的方式来进行训练。
            # 先生成相应 batchSize 样本 noise 数据
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            # 生成相应的 Discriminator 输出结果
            yGen = np.ones(batchSize)
            # 将 Discriminator 设置为不可训练的状态
            discriminator.trainable = False
           # 训练整个 GAN 网络即可训练出一个能生成真实样本的 Generator
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

if __name__ == '__main__':
    train(200, 128)

上述,通过 Sequential 创建 GeneratorDiscriminator 网络,然后通过 Model 将其进行结合在一起,从而构成了 GAN 网络的全貌。

参考

GAN-in-keras-on-mnist

目录
相关文章
|
存储 搜索推荐 PyTorch
通义千问7B-基于本地知识库问答
上期,我们介绍了通义千问7B模型的微调+部署方式,但在实际使用时,很多开发者还是希望能够结合特定的行业知识来增强模型效果,这时就需要通过外接知识库,让大模型能够返回更精确的结果。
|
安全 Linux 网络安全
VS Code通过跳板机连接服务器进行远程代码开发
VS Code通过跳板机连接服务器进行远程代码开发
1956 0
VS Code通过跳板机连接服务器进行远程代码开发
|
8月前
|
机器学习/深度学习 TensorFlow API
TensorFlow与Keras实战:构建深度学习模型
本文探讨了TensorFlow和其高级API Keras在深度学习中的应用。TensorFlow是Google开发的高性能开源框架,支持分布式计算,而Keras以其用户友好和模块化设计简化了神经网络构建。通过一个手写数字识别的实战案例,展示了如何使用Keras加载MNIST数据集、构建CNN模型、训练及评估模型,并进行预测。案例详述了数据预处理、模型构建、训练过程和预测新图像的步骤,为读者提供TensorFlow和Keras的基础实践指导。
555 59
|
9月前
|
机器学习/深度学习 Python
【Python 机器学习专栏】堆叠(Stacking)集成策略详解
【4月更文挑战第30天】堆叠(Stacking)是机器学习中的集成学习策略,通过多层模型组合提升预测性能。该方法包含基础学习器和元学习器两个阶段:基础学习器使用多种模型(如决策树、SVM、神经网络)学习并产生预测;元学习器则利用这些预测结果作为新特征进行学习,生成最终预测。在Python中实现堆叠集成,需划分数据集、训练基础模型、构建新训练集、训练元学习器。堆叠集成的优势在于提高性能和灵活性,但可能增加计算复杂度和过拟合风险。
1008 0
|
9月前
|
数据采集 索引 Python
Pandas之DataFrame,快速入门,迅速掌握(二)
Pandas之DataFrame,快速入门,迅速掌握(二)
162 0
|
5月前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能物流路径优化
使用Python实现智能物流路径优化
186 1
|
8月前
|
机器学习/深度学习 Python
sigmoid函数
本文探讨了高等数学中的sigmoid函数,它在神经网络中的应用,特别是在二分类问题的输出层。sigmoid函数公式为 $\frac{1}{1 + e^{-x}}$,其导数为 $sigmoid(x)\cdot(1-sigmoid(x))$。文章还展示了sigmoid函数的图像,并提供了一个使用Python绘制函数及其导数的代码示例。
336 2
|
9月前
|
分布式计算 DataWorks Java
DataWorks常见问题之数据集成导出分区表的全量数据如何解决
DataWorks是阿里云提供的一站式大数据开发与管理平台,支持数据集成、数据开发、数据治理等功能;在本汇总中,我们梳理了DataWorks产品在使用过程中经常遇到的问题及解答,以助用户在数据处理和分析工作中提高效率,降低难度。
201 0
|
9月前
|
机器学习/深度学习 SQL 人工智能
机器学习PAI常见问题之训练模型报错如何解决
PAI(平台为智能,Platform for Artificial Intelligence)是阿里云提供的一个全面的人工智能开发平台,旨在为开发者提供机器学习、深度学习等人工智能技术的模型训练、优化和部署服务。以下是PAI平台使用中的一些常见问题及其答案汇总,帮助用户解决在使用过程中遇到的问题。
|
机器学习/深度学习 编解码 算法
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战

热门文章

最新文章