5分钟入门GANS:原理解释和keras代码实现

简介: 5分钟入门GANS:原理解释和keras代码实现

本篇文章包含以下内容

  1. 介绍
  2. 历史
  3. 直观解释
  4. 训练过程
  5. GAN在MNIST数据集上的KERAS实现

介绍

生成式对抗网络通常也称为GANs,用于生成图像而不需要很少或没有输入。GANs允许我们生成由神经网络生成的图像。在我们深入讨论这个理论之前,我想向您展示GANs构建您兴奋感的能力。把马变成斑马(反之亦然)。

640.png

640.png

历史

生成式对抗网络(GANs)是由Ian Goodfellow (GANs的GAN Father)等人于2014年在其题为“生成式对抗网络”的论文中提出的。它是一种可替代的自适应变分编码器(VAEs)学习图像的潜在空间,以生成合成图像。它的目的是创造逼真的人工图像,几乎无法与真实的图像区分。

GAN的直观解释

生成器和鉴别器网络:

生成器网络的目的是将随机图像初始化并解码成一个合成图像。

鉴别器网络的目的是获取这个输入,并预测这个图像是来自真实的数据集还是合成的。

正如我们刚才看到的,这实际上就是GANs,两个相互竞争的对抗网络。

GAN的训练过程

GANS的训练是出了名的困难。在CNN中,我们使用梯度下降来改变权重以减少损失。

然而,在GANs中,每一次重量的变化都会改变整个动态系统的平衡。

在GAN的网络中,我们不是在寻求将损失最小化,而是在我们对立的两个网络之间找到一种平衡。

我们将过程总结如下

  1. 输入随机生成的噪声图像到我们的生成器网络中生成样本图像。
  2. 我们从真实数据中提取一些样本图像,并将其与一些生成的图像混合在一起。
  3. 将这些混合图像输入到我们的鉴别器中,鉴别器将对这个混合集进行训练并相应地更新它的权重。
  4. 然后我们制作更多的假图像,并将它们输入到鉴别器中,但是我们将它们标记为真实的。这样做是为了训练生成器。我们在这个阶段冻结了鉴别器的权值(鉴别器学习停止),并且我们使用来自鉴别器的反馈来更新生成器的权值。这就是我们如何教我们的生成器(制作更好的合成图像)和鉴别器更好地识别赝品的方法。

流程图如下

640.png

对于本文,我们将使用MNIST数据集生成手写数字。GAN的架构是:

640.png

使用KERAS实现GANS

首先,我们加载所有必要的库


importosos.environ["KERAS_BACKEND"] ="tensorflow"importnumpyasnpfromtqdmimporttqdmimportmatplotlib.pyplotaspltfromkeras.layersimportInputfromkeras.modelsimportModel, Sequentialfromkeras.layers.coreimportReshape, Dense, Dropout, Flattenfromkeras.layers.advanced_activationsimportLeakyReLUfromkeras.layers.convolutionalimportConvolution2D, UpSampling2Dfromkeras.layers.normalizationimportBatchNormalizationfromkeras.datasetsimportmnistfromkeras.optimizersimportAdamfromkerasimportbackendasKfromkerasimportinitializersK.set_image_dim_ordering('th')
#Deterministicoutput.
#Tiredofseeingthesameresultseverytime?Removethelinebelow.
np.random.seed(1000)
#Theresultsarealittlebetterwhenthedimensionalityoftherandomvectorisonly10.#Thedimensionalityhasbeenleftat100forconsistencywithotherGANimplementations.
randomDim=100

现在我们加载数据集。这里使用MNIST数据集,所以不需要单独下载和处理。

(X_train, y_train), (X_test, y_test) =mnist.load_data()
X_train= (X_train.astype(np.float32) -127.5)/127.5X_train=X_train.reshape(60000, 784)

接下来,我们定义生成器和鉴别器的结构

#Optimizeradam=Adam(lr=0.0002, beta_1=0.5)#generatorgenerator=Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)#discriminatordiscriminator=Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

现在我们把发生器和鉴别器结合起来同时训练。

#Combinednetworkdiscriminator.trainable=FalseganInput=Input(shape=(randomDim,))
x=generator(ganInput)
ganOutput=discriminator(x)
gan=Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)
dLosses= []
gLosses= []

三个函数,每20个epoch绘制并保存结果,并保存模型。

#PlotthelossfromeachbatchdefplotLoss(epoch):
plt.figure(figsize=(10, 8))
plt.plot(dLosses, label='Discriminitive loss')
plt.plot(gLosses, label='Generative loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('images/gan_loss_epoch_%d.png'%epoch)
#CreateawallofgeneratedMNISTimagesdefplotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
noise=np.random.normal(0, 1, size=[examples, randomDim])
generatedImages=generator.predict(noise)
generatedImages=generatedImages.reshape(examples, 28, 28)
plt.figure(figsize=figsize)
foriinrange(generatedImages.shape[0]):
plt.subplot(dim[0], dim[1], i+1)
plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
plt.axis('off')
plt.tight_layout()
plt.savefig('images/gan_generated_image_epoch_%d.png'%epoch)
#Savethegeneratoranddiscriminatornetworks (andweights) forlaterusedefsaveModels(epoch):
generator.save('models/gan_generator_epoch_%d.h5'%epoch)
discriminator.save('models/gan_discriminator_epoch_%d.h5'%epoch)

训练函数

deftrain(epochs=1, batchSize=128):
batchCount=X_train.shape[0] /batchSizeprint'Epochs:', epochsprint'Batch size:', batchSizeprint'Batches per epoch:', batchCountforeinxrange(1, epochs+1):
print'-'*15, 'Epoch %d'%e, '-'*15for_intqdm(xrange(batchCount)):
#Getarandomsetofinputnoiseandimagesnoise=np.random.normal(0, 1, size=[batchSize, randomDim])
imageBatch=X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]
#GeneratefakeMNISTimagesgeneratedImages=generator.predict(noise)
#printnp.shape(imageBatch), np.shape(generatedImages)
X=np.concatenate([imageBatch, generatedImages])
#LabelsforgeneratedandrealdatayDis=np.zeros(2*batchSize)
#One-sidedlabelsmoothingyDis[:batchSize] =0.9#Traindiscriminatordiscriminator.trainable=Truedloss=discriminator.train_on_batch(X, yDis)
#Traingeneratornoise=np.random.normal(0, 1, size=[batchSize, randomDim])
yGen=np.ones(batchSize)
discriminator.trainable=Falsegloss=gan.train_on_batch(noise, yGen)
#StorelossofmostrecentbatchfromthisepochdLosses.append(dloss)
gLosses.append(gloss)
ife==1ore%20==0:
plotGeneratedImages(e)
saveModels(e)
#PlotlossesfromeveryepochplotLoss(e)

至此一个简单的GAN已经完成了,完整的代码在这里找到

https://github.com/bhaveshgoyal27/mediumblogs/blob/master/Keras_MNIST_GAN.py

目录
相关文章
|
8月前
|
机器学习/深度学习 算法 数据挖掘
【Python机器学习】层次聚类AGNES、二分K-Means算法的讲解及实战演示(图文解释 附源码)
【Python机器学习】层次聚类AGNES、二分K-Means算法的讲解及实战演示(图文解释 附源码)
246 0
|
存储 机器学习/深度学习 算法
线性回归 梯度下降原理与基于Python的底层代码实现
梯度下降是一种常用的优化算法,可以用来求解许包括线性回归在内的许多机器学习中的问题。前面讲解了直接使用公式求解θ \thetaθ (最小二乘法的求解推导与基于Python的底层代码实现),但是对于复杂的函数来说,可能较难求出对应的公式,因此需要使用梯度下降。
|
机器学习/深度学习 自然语言处理 算法
【机器学习实战】10分钟学会Python怎么用NN神经网络进行分类(十一)
【机器学习实战】10分钟学会Python怎么用NN神经网络进行分类(十一)
152 0
|
机器学习/深度学习 传感器 自然语言处理
【机器学习实战】10分钟学会Python怎么用GBM梯度提升机进行预测(十四)
【机器学习实战】10分钟学会Python怎么用GBM梯度提升机进行预测(十四)
689 0
|
机器学习/深度学习 自然语言处理 算法
机器学习|TF-IDF算法(原理及代码实现)
TFIDF算法的原理及其代码实现。
|
索引
【Pytorch--代码技巧】各种论文代码常见技巧
博主在阅读论文原代码的时候常常看见一些没有见过的代码技巧,特此将这些内容进行汇总
180 0
|
人工智能 TensorFlow 算法框架/工具
Tensorflow反卷积(DeConv)实现原理+手写python代码实现反卷积(DeConv)
Tensorflow反卷积(DeConv)实现原理+手写python代码实现反卷积(DeConv)
|
人工智能 移动开发 TensorFlow
Tensorflow卷积实现原理+手写python代码实现卷积
Tensorflow卷积实现原理+手写python代码实现卷积
|
人工智能 TensorFlow 算法框架/工具
MobileNet原理+手写python代码实现MobileNet
MobileNet原理+手写python代码实现MobileNet
|
机器学习/深度学习 算法 Python
DL:神经网络算法简介之Affine 层的简介、使用方法、代码实现之详细攻略
DL:神经网络算法简介之Affine 层的简介、使用方法、代码实现之详细攻略
DL:神经网络算法简介之Affine 层的简介、使用方法、代码实现之详细攻略