[Deep Learning] Lab 17: Generating Handwritten Digit Samples with a GAN


Generating Handwritten Digit Samples with a GAN

Generative Adversarial Networks

A GAN (Generative Adversarial Network) is a deep learning architecture made up of a generator network and a discriminator network that are trained against each other. GANs were first proposed by Ian Goodfellow in 2014 and have attracted broad attention from both academia and industry ever since.


The core idea of a GAN is to have the generator produce samples from noise while the discriminator scores how closely those samples resemble real data. The generator maps a noise vector to a sample; the discriminator takes a sample as input and judges whether it is real or generated. The two networks improve each other through adversarial training: the generator tries to produce samples that fool the discriminator, while the discriminator tries to tell real data from generated data, and this competition steadily raises sample quality.
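
Formally, this competition is usually written as the minimax objective from Goodfellow's 2014 paper (standard notation, not something taken from this article's code): the discriminator $D$ maximizes the value function while the generator $G$ minimizes it,

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

where $p_{\text{data}}$ is the real-data distribution and $p_z$ is the noise prior (a 100-dimensional standard normal in the code below).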


GANs have a wide range of applications, notably image generation, video generation, and natural language processing. In image generation, GANs can produce images in many styles, such as portraits, animals, and food. In video generation, they can produce realistic video sequences, including human motion and natural scenery. In natural language processing, they can generate realistic dialogue and articles.


Training a GAN is more delicate than training most other deep learning models. The generator and discriminator must stay in balance: the generator's samples should be good enough to fool the discriminator, while the discriminator must stay accurate enough to tell real samples from generated ones. Because GAN training is prone to instability and mode collapse, some tuning and stabilization tricks are usually needed in practice.
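
One concrete example of such a trick is one-sided label smoothing: train the discriminator against soft "real" labels (e.g. 0.9) instead of hard 1s so it does not become overconfident. This is not used in the experiment below; it is only a minimal sketch of the idea, with the discriminator calls shown as comments:

# One-sided label smoothing: a common GAN stabilization trick
# (hypothetical sketch, not part of the experiment code below)
import numpy as np

batch_size = 128
valid = np.full((batch_size, 1), 0.9)   # soft "real" labels instead of all 1s
fake = np.zeros((batch_size, 1))        # fake labels stay at 0 (one-sided)
# The discriminator would then be trained against these labels, e.g.:
# d_loss_real = discriminator.train_on_batch(real_imgs, valid)
# d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)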


A series of GAN variants has emerged over the years, such as Conditional GAN (CGAN), CycleGAN, and Pix2Pix. These variants target different application scenarios, but they all share the same core idea, adjusting and extending the basic GAN framework.


GANs have received wide attention in both academia and industry, and many practical applications have strong demand for them. At the same time, GAN research still faces open problems and challenges, such as training stability and sample diversity. It is safe to expect that GANs will continue to receive broad attention and see wide application.


Program Design

# Import the required libraries
from __future__ import print_function, division 
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
class GAN():
    def __init__(self):
        # 28 rows by 28 columns: the shape of MNIST images
        # a single channel, i.e. grayscale
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28*28*1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # Adam optimizer (learning rate 0.0002, beta_1 0.5)
        optimizer = Adam(0.0002, 0.5)
        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        # Build the generator
        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # Freeze the discriminator while the generator is trained through the
        # combined model (the Keras UserWarning about trainable weights in the
        # log below stems from this standard pattern and is expected)
        self.discriminator.trainable = False
        # The discriminator scores the generated (fake) images
        validity = self.discriminator(img)
        # The combined model: noise -> generator -> discriminator -> validity
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    # Define the generator
    def build_generator(self):
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # Fully connected output layer with 28*28*1 neurons; tanh keeps values in [-1, 1]
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        # Reshape the flat output into an image
        model.add(Reshape(self.img_shape))
        noise = Input(shape=(self.latent_dim,))
        # The generator maps a 100-dimensional random vector to a 28x28x1 image
        img = model(noise)
        return Model(noise, img)
    # Define the discriminator
    def build_discriminator(self):
        model = Sequential()
        # Flatten the input image
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # Output the probability that the input image is real
        model.add(Dense(1, activation='sigmoid'))
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)
    # Define the training loop
    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the data (the labels are not needed)
        (X_train, _), (_, _) = mnist.load_data()
        # Normalize: map pixel values from [0, 255] to [-1, 1]
        # to match the tanh output of the generator
        X_train = X_train / 127.5 - 1
        X_train = np.expand_dims(X_train, axis=3)
        # Create the labels: 1 for real images, 0 for fake ones
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        # Each step trains the discriminator first, then the generator
        for epoch in range(epochs):
            # ---- Train the discriminator ----
            # Randomly pick batch_size indices into the training set
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            # Take one batch of real images
            imgs = X_train[idx]
            # Sample batch_size 100-dimensional vectors from a standard normal distribution
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Use the generator's predict method to turn the noise into fake images
            gen_imgs = self.generator.predict(noise)
            # Train the discriminator on real images with all-ones labels
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            # Train the discriminator on fake images with all-zeros labels
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # Average the two losses as the total discriminator loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # ---- Train the generator ----
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # The generator is trained through the combined model against "real" labels:
            # its loss falls when the discriminator is fooled into outputting 1 for fakes
            g_loss = self.combined.train_on_batch(noise, valid)
            # A high discriminator accuracy means the generated images are easy to spot;
            # around 50% accuracy the fakes have become hard to tell from real images
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            # Save a grid of sample images every sample_interval steps
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    # Define a function that samples and saves generated images
    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # Rescale from [-1, 1] back to [0, 1] for display
        gen_imgs = 0.5 * gen_imgs + 0.5
        # Plot the generated digits on a 5x5 grid
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    # Create the output directory for sample images if it does not exist
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=10000, batch_size=256, sample_interval=200)
   Using TensorFlow backend.
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
   Instructions for updating:
   Use tf.where in 2.0, which has the same broadcast rule as np.where
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
   /home/nlp/anaconda3/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
     'Discrepancy between trainable weights and collected trainable'
   0 [D loss: 0.986130, acc.: 26.17%] [G loss: 0.834596]
   /home/nlp/anaconda3/lib/python3.7/site-packages/keras/engine/training.py:297: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
     'Discrepancy between trainable weights and collected trainable'
   1 [D loss: 0.403944, acc.: 83.98%] [G loss: 0.796327]
   2 [D loss: 0.347891, acc.: 83.01%] [G loss: 0.777482]
   3 [D loss: 0.344294, acc.: 81.45%] [G loss: 0.784850]
   4 [D loss: 0.340509, acc.: 82.42%] [G loss: 0.815181]
   5 [D loss: 0.323516, acc.: 86.52%] [G loss: 0.901500]
   6 [D loss: 0.292972, acc.: 93.75%] [G loss: 0.991136]
   7 [D loss: 0.257421, acc.: 97.27%] [G loss: 1.111775]
   8 [D loss: 0.231006, acc.: 98.05%] [G loss: 1.239357]
   9 [D loss: 0.194000, acc.: 99.80%] [G loss: 1.371341]
   10 [D loss: 0.173448, acc.: 100.00%] [G loss: 1.501673]
   11 [D loss: 0.154554, acc.: 100.00%] [G loss: 1.620853]
   12 [D loss: 0.142011, acc.: 99.61%] [G loss: 1.732671]
   13 [D loss: 0.124580, acc.: 99.80%] [G loss: 1.827322]
   14 [D loss: 0.116470, acc.: 99.80%] [G loss: 1.972561]
   15 [D loss: 0.105582, acc.: 100.00%] [G loss: 2.067226]
   16 [D loss: 0.093254, acc.: 100.00%] [G loss: 2.198446]
   17 [D loss: 0.087950, acc.: 100.00%] [G loss: 2.304677]
   18 [D loss: 0.073583, acc.: 100.00%] [G loss: 2.355863]
   19 [D loss: 0.072164, acc.: 100.00%] [G loss: 2.464585]
   20 [D loss: 0.065558, acc.: 99.80%] [G loss: 2.534361]
   21 [D loss: 0.059140, acc.: 100.00%] [G loss: 2.626909]
   22 [D loss: 0.057848, acc.: 100.00%] [G loss: 2.673893]
   23 [D loss: 0.052325, acc.: 100.00%] [G loss: 2.714813]
   24 [D loss: 0.052922, acc.: 100.00%] [G loss: 2.763450]
   25 [D loss: 0.046035, acc.: 100.00%] [G loss: 2.853940]
   26 [D loss: 0.049457, acc.: 100.00%] [G loss: 2.869173]
   27 [D loss: 0.042687, acc.: 100.00%] [G loss: 2.941574]
   28 [D loss: 0.039089, acc.: 100.00%] [G loss: 2.948203]
   29 [D loss: 0.036347, acc.: 100.00%] [G loss: 2.968413]
   30 [D loss: 0.038200, acc.: 100.00%] [G loss: 3.048651]
   31 [D loss: 0.039299, acc.: 100.00%] [G loss: 3.102673]
   32 [D loss: 0.033043, acc.: 100.00%] [G loss: 3.050264]
   33 [D loss: 0.035250, acc.: 100.00%] [G loss: 3.078978]
   34 [D loss: 0.037255, acc.: 100.00%] [G loss: 3.131599]
   35 [D loss: 0.033308, acc.: 100.00%] [G loss: 3.127816]
   36 [D loss: 0.035622, acc.: 100.00%] [G loss: 3.157865]
   37 [D loss: 0.038046, acc.: 100.00%] [G loss: 3.272691]
   38 [D loss: 0.037665, acc.: 100.00%] [G loss: 3.304567]
   39 [D loss: 0.029662, acc.: 100.00%] [G loss: 3.323656]
   40 [D loss: 0.031073, acc.: 100.00%] [G loss: 3.342812]
   41 [D loss: 0.031860, acc.: 100.00%] [G loss: 3.330144]
   42 [D loss: 0.033744, acc.: 100.00%] [G loss: 3.365006]
   43 [D loss: 0.030133, acc.: 100.00%] [G loss: 3.361420]
   44 [D loss: 0.032508, acc.: 100.00%] [G loss: 3.456270]
   45 [D loss: 0.030021, acc.: 100.00%] [G loss: 3.498577]
   46 [D loss: 0.029159, acc.: 100.00%] [G loss: 3.499414]
   47 [D loss: 0.031974, acc.: 100.00%] [G loss: 3.484164]
   48 [D loss: 0.033442, acc.: 99.80%] [G loss: 3.459633]
   49 [D loss: 0.030912, acc.: 100.00%] [G loss: 3.481130]
   50 [D loss: 0.033645, acc.: 100.00%] [G loss: 3.492231]
   51 [D loss: 0.034441, acc.: 100.00%] [G loss: 3.489124]
   52 [D loss: 0.034330, acc.: 100.00%] [G loss: 3.506902]
   53 [D loss: 0.034518, acc.: 100.00%] [G loss: 3.520910]
   54 [D loss: 0.030822, acc.: 100.00%] [G loss: 3.618950]
   55 [D loss: 0.034566, acc.: 99.80%] [G loss: 3.538144]
   56 [D loss: 0.032794, acc.: 100.00%] [G loss: 3.566177]
   57 [D loss: 0.037374, acc.: 99.61%] [G loss: 3.600816]
   58 [D loss: 0.037127, acc.: 100.00%] [G loss: 3.521185]
   59 [D loss: 0.039322, acc.: 100.00%] [G loss: 3.531039]
   60 [D loss: 0.030453, acc.: 100.00%] [G loss: 3.616879]
   61 [D loss: 0.044332, acc.: 99.02%] [G loss: 3.628755]
   62 [D loss: 0.037772, acc.: 99.80%] [G loss: 3.723062]
   63 [D loss: 0.041130, acc.: 99.61%] [G loss: 3.533709]
   64 [D loss: 0.044611, acc.: 99.41%] [G loss: 3.657721]
   65 [D loss: 0.037362, acc.: 99.61%] [G loss: 3.582735]
   66 [D loss: 0.050663, acc.: 99.02%] [G loss: 3.555587]
   67 [D loss: 0.039863, acc.: 99.41%] [G loss: 3.611456]
   68 [D loss: 0.051172, acc.: 99.02%] [G loss: 3.540278]
   69 [D loss: 0.052263, acc.: 98.63%] [G loss: 3.612799]
   70 [D loss: 0.056154, acc.: 99.41%] [G loss: 3.557292]
   71 [D loss: 0.055386, acc.: 99.22%] [G loss: 3.744767]
   72 [D loss: 0.096904, acc.: 97.66%] [G loss: 3.443518]
   73 [D loss: 0.070626, acc.: 98.05%] [G loss: 3.833835]
   74 [D loss: 0.180408, acc.: 93.55%] [G loss: 3.301687]
   75 [D loss: 0.074523, acc.: 98.44%] [G loss: 3.776305]
   76 [D loss: 0.057483, acc.: 99.02%] [G loss: 3.714150]
   77 [D loss: 0.141995, acc.: 95.12%] [G loss: 3.380850]
   78 [D loss: 0.067733, acc.: 98.63%] [G loss: 3.779586]
   79 [D loss: 0.303615, acc.: 87.89%] [G loss: 2.848376]
   80 [D loss: 0.145237, acc.: 94.14%] [G loss: 3.108039]
   81 [D loss: 0.046822, acc.: 99.22%] [G loss: 3.635069]
   82 [D loss: 0.108516, acc.: 96.48%] [G loss: 3.235212]
   83 [D loss: 0.105234, acc.: 96.48%] [G loss: 3.336948]
   84 [D loss: 0.233112, acc.: 90.82%] [G loss: 2.740180]
   85 [D loss: 0.118313, acc.: 94.92%] [G loss: 3.181991]
   86 [D loss: 0.300344, acc.: 87.30%] [G loss: 2.879515]
   87 [D loss: 0.106900, acc.: 96.48%] [G loss: 3.189476]
   88 [D loss: 0.381278, acc.: 84.38%] [G loss: 2.337953]
   89 [D loss: 0.252046, acc.: 88.28%] [G loss: 2.707138]
   90 [D loss: 0.087314, acc.: 97.07%] [G loss: 3.401120]
   91 [D loss: 0.260525, acc.: 90.62%] [G loss: 2.520348]
   92 [D loss: 0.148098, acc.: 93.36%] [G loss: 2.991073]
   93 [D loss: 0.141315, acc.: 96.09%] [G loss: 2.805464]
   94 [D loss: 0.288812, acc.: 89.45%] [G loss: 2.549888]
   95 [D loss: 0.143633, acc.: 94.14%] [G loss: 2.978777]
   96 [D loss: 0.584615, acc.: 78.32%] [G loss: 2.050247]
   97 [D loss: 0.328917, acc.: 83.01%] [G loss: 2.579935]
   98 [D loss: 0.111224, acc.: 97.66%] [G loss: 3.526271]
   99 [D loss: 0.702403, acc.: 68.95%] [G loss: 1.994847]
   100 [D loss: 0.335197, acc.: 84.96%] [G loss: 2.110721]
   101 [D loss: 0.147330, acc.: 93.55%] [G loss: 2.962312]
   102 [D loss: 0.091300, acc.: 98.44%] [G loss: 3.025173]
   103 [D loss: 0.304929, acc.: 87.70%] [G loss: 2.458197]
   104 [D loss: 0.199925, acc.: 90.43%] [G loss: 2.897576]
   105 [D loss: 0.335472, acc.: 87.30%] [G loss: 2.198746]
   106 [D loss: 0.235486, acc.: 88.09%] [G loss: 2.742341]
   107 [D loss: 0.346595, acc.: 84.77%] [G loss: 2.340909]
   108 [D loss: 0.211129, acc.: 91.60%] [G loss: 2.801579]
   109 [D loss: 0.361250, acc.: 84.96%] [G loss: 2.304583]
   110 [D loss: 0.183040, acc.: 93.16%] [G loss: 2.763792]
   111 [D loss: 0.365892, acc.: 82.62%] [G loss: 2.418060]
   112 [D loss: 0.197837, acc.: 92.19%] [G loss: 2.826400]
   113 [D loss: 0.413041, acc.: 81.05%] [G loss: 2.408184]
   114 [D loss: 0.198854, acc.: 91.80%] [G loss: 2.784730]
   115 [D loss: 0.395174, acc.: 81.45%] [G loss: 2.115457]
   116 [D loss: 0.189158, acc.: 90.04%] [G loss: 2.603389]
   117 [D loss: 0.237316, acc.: 92.97%] [G loss: 2.648600]
   118 [D loss: 0.285941, acc.: 87.89%] [G loss: 2.370326]
   119 [D loss: 0.208490, acc.: 90.43%] [G loss: 2.849175]
   120 [D loss: 0.454702, acc.: 80.08%] [G loss: 1.897220]
   121 [D loss: 0.217595, acc.: 89.06%] [G loss: 2.498424]
   122 [D loss: 0.173055, acc.: 94.92%] [G loss: 2.664538]
   123 [D loss: 0.262918, acc.: 90.82%] [G loss: 2.133595]
   124 [D loss: 0.190525, acc.: 91.02%] [G loss: 2.840866]
   125 [D loss: 0.292295, acc.: 87.11%] [G loss: 2.199357]
   126 [D loss: 0.215348, acc.: 88.87%] [G loss: 2.739654]
   127 [D loss: 0.365445, acc.: 84.96%] [G loss: 2.162226]
   128 [D loss: 0.200284, acc.: 89.65%] [G loss: 2.871504]
   129 [D loss: 0.450811, acc.: 79.10%] [G loss: 1.971582]
   130 [D loss: 0.200712, acc.: 90.82%] [G loss: 2.715580]
   131 [D loss: 0.310609, acc.: 85.94%] [G loss: 2.443402]
   132 [D loss: 0.234690, acc.: 89.65%] [G loss: 2.654381]
   133 [D loss: 0.449007, acc.: 79.30%] [G loss: 1.873044]
   134 [D loss: 0.233484, acc.: 87.89%] [G loss: 2.710910]
   135 [D loss: 0.274398, acc.: 87.70%] [G loss: 2.632379]
   136 [D loss: 0.295981, acc.: 87.11%] [G loss: 2.511465]
   137 [D loss: 0.247948, acc.: 89.65%] [G loss: 2.698283]
   138 [D loss: 0.490601, acc.: 75.20%] [G loss: 2.161157]
   139 [D loss: 0.215320, acc.: 90.43%] [G loss: 2.841792]
   140 [D loss: 0.564996, acc.: 73.63%] [G loss: 1.618642]
   141 [D loss: 0.270847, acc.: 86.52%] [G loss: 2.598266]
   142 [D loss: 0.210049, acc.: 93.75%] [G loss: 3.058943]
   143 [D loss: 0.462835, acc.: 76.17%] [G loss: 2.219012]
   144 [D loss: 0.213740, acc.: 89.84%] [G loss: 2.845963]
   145 [D loss: 0.518464, acc.: 73.63%] [G loss: 1.735387]
   146 [D loss: 0.273846, acc.: 87.11%] [G loss: 2.634973]
   ...
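
Once training has finished, the generator can be used on its own; the discriminator is only needed during training. The snippet below is a minimal follow-up sketch, not part of the original experiment: it assumes the run above has completed and uses Keras's standard save/load_model calls to persist and reuse the trained generator.

# Hypothetical follow-up: persist the trained generator and reuse it later
# (assumes `gan` is the trained GAN instance from the script above)
import numpy as np
from keras.models import load_model

gan.generator.save("generator.h5")          # save architecture + weights to HDF5

generator = load_model("generator.h5")      # reload in a fresh session
noise = np.random.normal(0, 1, (16, 100))   # 100 must match GAN.latent_dim
digits = generator.predict(noise)           # shape (16, 28, 28, 1), values in [-1, 1]
digits = 0.5 * digits + 0.5                 # rescale to [0, 1] for display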

