GAN出一个女朋友(上)

简介: GAN出一个女朋友

GAN(生成对抗网络)


没有女朋友也没关系,现在我们就试试自己GAN出一个。


1. 生成对抗网络简述


首先要提的肯定是2014年Ian J. Goodfellow大佬的关于GAN的论文,应该也算是GAN的开山之作。Generative Adversarial Network 论文链接


生成对抗网络顾名思义包含了生成对抗的思想,按照原始的GAN网络结构中主要包含了两部分,分别是生成模型 G判别模型 D,通常这两部分都是神经网络,其中G负责通过我们输入的数据生成一些新的数据(生成的数据要尽可能接近我们的数据集中的数据);而D则负责作判别,判断输入的数据是来自真实数据集的 “真数据” 还是由G生成的 ”假数据“ ,所以D是一个用来分辨真伪的二分类模型,网络最后一层往往也使用sigmoid激活函数。


举个简单的例子:


G就像现实中制作假币的机器,而D就是验钞机。一开始G和D的技术都不太好,往往G产生的假币假到人肉眼就看出来是假的,而D判断的技术也很差,往往人肉眼都觉得的是假币而它判断确实真币;经过一段时间的改进优化,D可以比人眼更好的分辨真假币,而G制作出来的假币也能骗过人的眼睛;到了最后,G生成的假币已经可以完美的媲美真币,而D这个时候已经无法区分到底是真币还是G生成的假币(准确率50%),这个时候就是最完美的时候,我们的目的也就达到了。


GAN整体示意图如下:


image.png



从图来看,G网络所需要的输入只有一个随机噪音,通过随机噪音来生成数据,而D网络的输入就是我们真实的样本数据以及G所生成的数据,并对它们作真假判断。


结合上面的简述,我们可以体会到生成以及对抗的含义,这个想法非常好两者在对抗竞争中互相提高,但是实际上往往会出现一些意外情况,导致模型无法达到预期效果。


  • 当我们的验钞机非常厉害,而假币制造技术还很差的时候。不管怎么努力做出的假币还是会被分辨出来,这个时候就没有人会去想着制造假币(无利可图),还不如老老实实工作。
  • 当验钞机判别能力相当弱的时候,不管假币制造的多么假,它还是会觉得是真币,尽管这个时候人眼就能看出这是假币;那这个时候假币制造机已经觉得满足了,也不会去进一步增强自己的造假能力。现实中的情况往往是只会生成假的1元,而5元、10元、20元等都不会生成。


上面两种情况其实就是我们训练GAN时通常会遇到的两种情况,都会导致我们的生成器无法达到我们预期想要的结果。此外,对于这样的一种网络概念,怎么对它进行优化训练也是我们需要考虑的,这些就留到下面继续说。


2.具体内容


由上面的简述,我们有了一种新的网络设计思想,通过两种网络的竞争从而得到性能优异的生成模型,并且同时我们还得到了一个能力不错的判定模型。但是当我们把想法实践的时候也遇到了不少问题,上面的D性能太好和太坏都会导致G的训练失败,而且如何训练这两个模型也是我们需要思考的问题,在这里我就展开详细的说说。


2.1 网络如何训练


有了上面零和博弈的思想,我们就会开始设计网络,通常我们的生成器和判定器都是深层神经网络,简单的全连接到复杂点的卷积……;设计好网络后,我们就需要对网络进行训练,也就是更新网络参数,这个时候就需要设置网络的目标函数。


image.png


从论文上看,所给出的目标函数


min⁡Gmax⁡DV(D,G)=Ex∼pdata(x)[log⁡D(x)]+Ez∼pz(z)[log⁡(1−D(G(z)))] \min\limits_{G}\max\limits_{D}V(D,G)=E_{x \sim p_{data}(x)}[\log{D(x)}]+E_{z \sim p_{z}(z)}[\log{(1-D(G(z)))}]GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]


实际训练过程中,我们的损失函数分别为:


  • 对于G
    最小化log⁡(1−D(G(z)))\log{(1-D(G(z)))}log(1D(G(z)))但是在早期可能无法提供足够的梯度来训练G,所以有了替代为最大化log⁡D(G(z))\log{D(G(z))}logD(G(z))
  • 对于D
    −(log⁡D(x)+log⁡(1−D(G(x))))-(\log{D(x)}+\log{(1-D(G(x)))})(logD(x)+log(1D(G(x))))


image.png


有了上面的损失函数,我们就可以对网络权重进行更新,从而训练网络。不过GAN训练起来可并不容易,往往理想很丰满,现实很骨感。


2.2 网络训练时会遇到的问题


再解释我上面提到的两种情况


  1. Vanishing Gradient 梯度消失
    梯度消失是指D判别器训练的判别效果远比G的生成效果好,导致无论G生成什么数据,它都觉得与真实数据分布不一致,这个时候就会导致G的训练一直停滞。


image.png

  • 从论文上看,我们最终可以得到一个最优解D∗(x)=pdata(x)pdata(x)+pg(x)D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}D(x)=pdata(x)+pg(x)pdata(x),带入目标函数,得到简化后的函数min⁡G(2∗JS(Pdata∣∣PG)−log4)\min\limits_{G}(2*JS(P_{data}||P_G)-log4)Gmin(2JS(PdataPG)log4)
    当我们生成的数据与实际数据集数据分布相差过大,几乎没有重叠部分的时候,散度就成了一个常数,这个时候函数的梯度就是0,也就出现了梯度消失的现象。


  • Mode Collapse 模式坍塌


模式坍塌则是我们GAN训练过程中会出现的另一种问题,当D的判别性能跟不上G的生成效果时,G生成的图片都会被判定为真实的,这个时候D(G(x))=1D(G(x))=1D(G(x))=1,G也就失去了训练的方向。往往这种情况下就会导致只能很好的生成某一类数据,或者生成的数据重复率过高。具体如下图


image.png


  1. 生成的图像中出现了多个相似图片,失去了生成模型的多样性。


3.GAN的发展


基于零和博弈策略的GAN作为生成模型,相比之前的生成模型不再需要知道数据的分布,而是先生成数据然后往我们所给出的数据分布上去靠拢。GAN往往用于图像、视频的生成,也有用于文字的生成。从它的原理我们可以看出,它对于不同的输入noise,就会生成不同的数据;另外,它生成的数据往往是整体的,而不能依赖于某一个数据去控制另一个数据。


对于图像的生成,将全连接神经网络改为了卷积神经网络,并且优化激活函数等,比如DCGAN;对于风格迁移的CycleGAN;为了使GAN训练更稳定提出的WGAN和WGAN-GP以及Conditional GAN;最后是复杂的StyleGAN,也是我目前需要研究学习的,下一篇应该就是关于StyleGANStyleGAN2的博客


4.GAN的代码实践


4.1 基于GAN的mnist数据生成


来源:pytorch-exercise/pt01_generative_adversarial_network.py at main · Baileyswu/pytorch-exercise


import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from bokeh.io import show, output_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor())
batch_size = 50
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)
# 定义模型
class GAN(nn.Module):
    def __init__(self):
        super().__init__()
        # 生成器
        self.G = nn.Sequential(
            nn.Linear(64, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 784),
            nn.Tanh()
        )
        # 判定器
        self.D = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, z):
        return self.G(z)
    # 将输入噪音z传入G,然后G生成的数据给D进行判断
    def score(self, z):
        fake_imgs = self.G(z)
        fake_score = self.D(fake_imgs)
        return fake_score
lrate = 0.0001
epochs = 300
# 模型实例化
model = GAN().cuda()
# 损失函数
criterion = nn.BCELoss()
# 优化器
optim_G = torch.optim.Adam(model.G.parameters(), lr = lrate)
optim_D = torch.optim.Adam(model.D.parameters(), lr = lrate)
# 模型可视化
model
# 绘制图片
def list_img(i, img, title):
    img = img.reshape(28, 28)
    plt.subplot(2, 5, i+1)
    plt.imshow(img)
    plt.title('%s' % (title))
# 可视化生成的数据
def generate_test(inputs, title=''):
    plt.figure(figsize=(15, 6))
    imgs = model(inputs)
    imgs = (imgs + 1) / 2
    imgs.clamp(0, 1)
    for i in range(len(inputs)):
        list_img(i, imgs[i].cpu().detach().numpy(), title)
    plt.show()
result_d = []
result_g = []
test_inputs = torch.randn(5, 64).cuda()
# 开始训练
for e in range(epochs):
    for i, (inputs, _) in enumerate(train_loader):
        inputs = inputs.view(-1, 28*28).cuda()
        real_labels = torch.ones(batch_size, 1).cuda()
        fake_labels = torch.zeros(batch_size, 1).cuda()
        # 对D进行训练
        real_score = model.D(inputs)
        loss_d_real = criterion(real_score, real_labels)
        fake_score = model.score(torch.randn(batch_size, 64).cuda())
        loss_d_fake = criterion(fake_score, fake_labels)
        optim_D.zero_grad() 
        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optim_D.step()
        fake_score = model.score(torch.randn(batch_size, 64).cuda())
        loss_g = criterion(fake_score, real_labels)
        optim_G.zero_grad()
        loss_g.backward()
        optim_G.step()
        if i % 100 == 0:
            result_d.append(float(loss_d))
            result_g.append(float(loss_g))
    if e % 30 == 0:
        generate_test(test_inputs, str(e))
fig = figure()
fig.line(range(len(result_d)), result_d, legend_label='D loss', line_width=1.5)
fig.line(range(len(result_g)), result_g, legend_label='G loss', line_color="green")
show(fig)
# 测试生成器效果
new_data=model.G(torch.randn(1,64).cuda())
img=new_data.cpu().detach().numpy().reshape(28,28)
plt.imshow(img)
plt.show()
复制代码


训练图片过程


image.png

image.png

损失曲线

image.png

测试生成图片

image.png

目录
相关文章
|
3月前
|
机器学习/深度学习 数据采集 人工智能
GAN的主要介绍
【10月更文挑战第6天】
|
3月前
|
编解码 自然语言处理 算法
生成对抗网络的应用有哪些
【10月更文挑战第14天】生成对抗网络的应用有哪些
|
6月前
|
机器学习/深度学习 算法
生成对抗网络
生成对抗网络
|
机器学习/深度学习 决策智能 计算机视觉
理解GAN生成对抗网络
理解GAN生成对抗网络
|
8月前
|
机器学习/深度学习 编解码 数据处理
GAN介绍
GAN介绍
94 0
|
机器学习/深度学习 人工智能 文字识别
生成对抗网络(一)
生成对抗网络(一)
105 0
|
机器学习/深度学习 人工智能 开发者
生成对抗网络(二)
生成对抗网络(二)
129 0
|
机器学习/深度学习
从零使用GAN(生成对抗网络)进行图像生成
本项目使用 DCGAN 模型,在自建数据集上进行实验。
322 0
从零使用GAN(生成对抗网络)进行图像生成
|
机器学习/深度学习 编解码 运维
GAN的详细介绍及其应用(全面且完整)
GAN的详细介绍及其应用(全面且完整)
GAN的详细介绍及其应用(全面且完整)