keras 实现 GAN

简介: 通过 Keras 实现 GAN ,其主要过程如下:GAN训练过程分析正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。

通过 Keras 实现 GAN ,其主要过程如下:

img_4bb140d89c9dbedebbf2639ab0aa8bc8.png
GAN训练过程分析

正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。这里 Generator 输入为随机数。

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

K.set_image_dim_ordering('th')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)

generator = Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)

discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []
gLosses = []

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

def train(epochs=1, batchSize=128):
    batchCount = int(X_train.shape[0] / batchSize)
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            # 用 Generator 生成假数据
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            # 将假数据与真实数据进行混合在一起
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            # 标记所有数据都是假数据
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            # 按真实数据比例,标记前半数据为 0.9 的真实度
            yDis[:batchSize] = 0.9

            # Train discriminator
            # 先训练 Discriminator 让其具有判定能力,同时Generator 也在训练,也能更新参数。
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            # 然后训练 Generator, 注意这里训练 Generator 时候,把 Generator 生成出来的结果置为全真,及按真实数据的方式来进行训练。
            # 先生成相应 batchSize 样本 noise 数据
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            # 生成相应的 Discriminator 输出结果
            yGen = np.ones(batchSize)
            # 将 Discriminator 设置为不可训练的状态
            discriminator.trainable = False
           # 训练整个 GAN 网络即可训练出一个能生成真实样本的 Generator
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

if __name__ == '__main__':
    train(200, 128)

上述,通过 Sequential 创建 GeneratorDiscriminator 网络,然后通过 Model 将其进行结合在一起,从而构成了 GAN 网络的全貌。

参考

GAN-in-keras-on-mnist

目录
相关文章
|
9月前
|
机器学习/深度学习 Python
【Python 机器学习专栏】堆叠(Stacking)集成策略详解
【4月更文挑战第30天】堆叠(Stacking)是机器学习中的集成学习策略,通过多层模型组合提升预测性能。该方法包含基础学习器和元学习器两个阶段:基础学习器使用多种模型(如决策树、SVM、神经网络)学习并产生预测;元学习器则利用这些预测结果作为新特征进行学习,生成最终预测。在Python中实现堆叠集成,需划分数据集、训练基础模型、构建新训练集、训练元学习器。堆叠集成的优势在于提高性能和灵活性,但可能增加计算复杂度和过拟合风险。
1008 0
|
9月前
|
分布式计算 DataWorks Java
DataWorks常见问题之数据集成导出分区表的全量数据如何解决
DataWorks是阿里云提供的一站式大数据开发与管理平台,支持数据集成、数据开发、数据治理等功能;在本汇总中,我们梳理了DataWorks产品在使用过程中经常遇到的问题及解答,以助用户在数据处理和分析工作中提高效率,降低难度。
201 0
|
传感器
电脑鼠标的工作原理是怎样的?底层原理是什么?
电脑鼠标的工作原理是怎样的?底层原理是什么?
822 0
|
关系型数据库 Java 数据库
Failed to configure a DataSource: 'url' attribute is not specified and no embedded datasource could
版权声明:本文为 testcs_dn(微wx笑) 原创文章,非商用自由转载-保持署名-注明出处,谢谢。 https://blog.csdn.net/testcs_dn/article/details/80897402 ...
6785 0
|
SQL 分布式计算 DataWorks
DataWorks开通并导入本地数据
大数据开发治理平台 DataWorks基于MaxCompute/EMR/MC-Hologres等大数据计算引擎,为客户提供专业高效、安全可靠的一站式大数据开发与治理平台,自带阿里巴巴数据中台与数据治理最佳实践,赋能各行业数字化转型。每天阿里巴巴集团内部有数万名数据/算法工程师正在使用DataWorks,承担集团99%数据业务构建。本篇简单介绍下Dataworks的开通以及数据开发使用
1289 0
DataWorks开通并导入本地数据
|
机器学习/深度学习 编解码 算法
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
|
机器学习/深度学习 算法 Python
机器学习中的数学原理——对数似然函数
机器学习中的数学原理——对数似然函数
1053 0
机器学习中的数学原理——对数似然函数
|
安全 Linux 网络安全
VS Code通过跳板机连接服务器进行远程代码开发
VS Code通过跳板机连接服务器进行远程代码开发
1956 0
VS Code通过跳板机连接服务器进行远程代码开发
|
资源调度 运维 监控
SLS机器学习介绍(03):时序异常检测建模
虽然计算机软硬件的快速发展已经极大提高了应用程序的可靠性,但是在大型集群中仍然存在大量的软件错误和硬件故障。系统要求7x24小时不间断运行,因此,对这些系统进行持续监控至关重要。这就要求我们就被从系统中持续采集系统运行日志,业务运行日志的能力,并能快速的分析和监控当前状态曲线的异常,一旦发现异常,能第一时间将信息送到相关人员手中。
22962 0
SLS机器学习介绍(03):时序异常检测建模
|
算法框架/工具 数据挖掘
keras 自定义 metrics
自定义 Metrics 在 keras 中操作的均为 Tensor 对象,因此,需要定义操作 Tensor 的函数来操作所有输出结果,定义好函数之后,直接将其放在 model.
4051 0

热门文章

最新文章