keras 实现 GAN

简介: 通过 Keras 实现 GAN ,其主要过程如下:GAN训练过程分析正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。

通过 Keras 实现 GAN ,其主要过程如下:

img_4bb140d89c9dbedebbf2639ab0aa8bc8.png
GAN训练过程分析

正如上图所示,通过调节 Generator 和 Discriminator 交替训练来达到不断达到真实数据的拟合过程。这里 Generator 输入为随机数。

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

K.set_image_dim_ordering('th')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)

generator = Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)

discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []
gLosses = []

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

def train(epochs=1, batchSize=128):
    batchCount = int(X_train.shape[0] / batchSize)
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            # 用 Generator 生成假数据
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            # 将假数据与真实数据进行混合在一起
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            # 标记所有数据都是假数据
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            # 按真实数据比例,标记前半数据为 0.9 的真实度
            yDis[:batchSize] = 0.9

            # Train discriminator
            # 先训练 Discriminator 让其具有判定能力,同时Generator 也在训练,也能更新参数。
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            # 然后训练 Generator, 注意这里训练 Generator 时候,把 Generator 生成出来的结果置为全真,及按真实数据的方式来进行训练。
            # 先生成相应 batchSize 样本 noise 数据
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            # 生成相应的 Discriminator 输出结果
            yGen = np.ones(batchSize)
            # 将 Discriminator 设置为不可训练的状态
            discriminator.trainable = False
           # 训练整个 GAN 网络即可训练出一个能生成真实样本的 Generator
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

if __name__ == '__main__':
    train(200, 128)

上述,通过 Sequential 创建 GeneratorDiscriminator 网络,然后通过 Model 将其进行结合在一起,从而构成了 GAN 网络的全貌。

参考

GAN-in-keras-on-mnist

目录
相关文章
|
负载均衡 安全 网络协议
|
开发框架 测试技术 Android开发
移动应用开发之旅:从新手到专家
本文将带领读者踏上移动应用开发的旅程,从基础概念的铺垫到高级技术的深入,逐步揭示如何构建一个成功的移动应用。文章不仅涵盖移动操作系统的核心知识,还提供实用的开发技巧和最佳实践,旨在帮助初学者快速入门,并引导有一定基础的开发者进一步提升技能。通过阅读本文,你将了解移动应用开发的各个阶段,包括设计、编码、测试和发布,以及如何应对市场变化和技术更新的挑战。无论你是刚开始探索移动应用开发的世界,还是希望扩展你的技术栈,这篇文章都将为你提供宝贵的指导和灵感。让我们开始这段激动人心的旅程吧!
372 6
|
存储 缓存 算法
C语言在实现高效算法方面的特点与优势,包括高效性、灵活性、可移植性和底层访问能力
本文探讨了C语言在实现高效算法方面的特点与优势,包括高效性、灵活性、可移植性和底层访问能力。文章还分析了数据结构的选择与优化、算法设计的优化策略、内存管理和代码优化技巧,并通过实际案例展示了C语言在排序和图遍历算法中的高效实现。
390 2
|
缓存 前端开发 JavaScript
ES6 全部特性详解
ES6 是 JavaScript 语言的一个重要升级,它引入了大量新的功能,极大地增强了 JavaScript 的表达力和可读性。通过了解和掌握这些特性,开发者可以编写出更加简洁、高效、优雅的代码,并轻松应对大型项目的复杂性。
343 7
|
Web App开发 JavaScript 前端开发
使用Node.js和Express框架构建Web服务器
使用Node.js和Express框架构建Web服务器
|
网络协议 算法 网络性能优化
计算机网络 第五章 网络层(习题)
计算机网络 第五章 网络层(习题)
625 1
|
机器学习/深度学习 存储 SQL
AllData数据中台核心菜单十:指标体系
杭州奥零数据科技有限公司成立于2023年,专注于数据中台业务,维护开源项目AllData并提供商业版解决方案。AllData提供数据集成、存储、开发、治理及BI展示等一站式服务,支持AI大模型应用,助力企业高效利用数据价值。
|
机器学习/深度学习 API TensorFlow
TensorFlow的高级API:tf.keras深度解析
【4月更文挑战第17天】本文深入解析了TensorFlow的高级API `tf.keras`,包括顺序模型和函数式API的模型构建,以及模型编译、训练、评估和预测的步骤。`tf.keras`结合了Keras的易用性和TensorFlow的性能,支持回调函数、模型保存与加载等高级特性,助力提升深度学习开发效率。
|
Web App开发 JavaScript 前端开发
从浏览器原理出发聊聊Chrome插件
本文从浏览器架构演进、插件运行机制、插件基本介绍和一些常见的插件实现思路几个方向聊聊Chrome插件。