# 深度学习界明星：生成对抗网络与Improving GAN

+关注继续查看

## 1何为生成对抗网络

### 2. 对抗模型

class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.dis = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)

def forward(self, x):
x = self.dis(x)
return x


class generator(nn.Module):
def __init__(self， input_size):
super(generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(input_size, 256),
nn.ReLU(True),
nn.Linear(256, 256),
nn.ReLU(True),
nn.Linear(256, 784),
nn.Tanh()
)

def forward(self, x):
x = self.gen(x)
return x


criterion = nn.BCELoss() # Binary Cross Entropy


img = img.view(num_img, -1)
real_img = Variable(img).cuda()
real_label = Variable(torch.ones(num_img)).cuda()
fake_label = Variable(torch.zeros(num_img)).cuda()

# compute loss of real_img
real_out = D(real_img)
d_loss_real = criterion(real_out, real_label)
real_scores = real_out

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()
fake_img = G(z)
fake_out = D(fake_img)
d_loss_fake = criterion(fake_out, fake_label)
fake_scores = fake_out

# bp and optimize
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
d_optimizer.step()


# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到随机噪声
fake_img = G(z) # 生成假的图片
output = D(fake_img) # 经过判别器得到结果
g_loss = criterion(output, real_label) # 得到假的图片与真实图片label的loss

# bp and optimize
g_loss.backward() # 反向传播
g_optimizer.step() # 更新生成网络的参数


class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14
)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7
)
self.fc = nn.Sequential(
nn.Linear(64*7*7, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 1),
nn.Sigmoid()
)

def forward(self, x):
'''
x: batch, width, height, channel=1
'''
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x

class generator(nn.Module):
def __init__(self, input_size, num_feature):
super(generator, self).__init__()
self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56
self.br = nn.Sequential(
nn.BatchNorm2d(1),
nn.ReLU(True)
)
self.downsample1 = nn.Sequential(
nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56
nn.BatchNorm2d(50),
nn.ReLU(True)
)
self.downsample2 = nn.Sequential(
nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56
nn.BatchNorm2d(25),
nn.ReLU(True)
)
self.downsample3 = nn.Sequential(
nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28
nn.Tanh()
)

def forward(self, x):
x = self.fc(x)
x = x.view(x.size(0), 1, 56, 56)
x = self.br(x)
x = self.downsample1(x)
x = self.downsample2(x)
x = self.downsample3(x)
return x


## 2 Improving GAN

### 1 Wasserstein GAN

Wasserstein GAN 是GAN 的一种变式，我们知道GAN 的训练是非常麻烦的，需要很多训练技巧，而且在不同的数据集上，由于数据的分布会发生变化，也需要重新调整参数，不仅需要小心地平衡生成器和判别器的训练进程，同时生成的样本还缺乏多样性。除此之外最大的问题是没办法衡量这个生成器到底好不好，因为没办法通过判别器的loss 去判断这个事情。虽然DC GAN 依靠对生成器和判别器的结构进行枚举，最终找到了一个比较好的网络设置，但还是没有从根本上解决训练的问题。

WGAN 的出现，彻底解决了下面这些难点：

（1）彻底解决了训练不稳定的问题，不再需要设计参数去平衡判别器和生成器；

（2）基本解决了collapse mode 的问题，确保了生成样本的多样性；

（3）训练中有一个向交叉熵、准确率的数值指标来衡量训练的进程，数值越小代表GAN 训练得越好，同时也就代表着生成的图片质量越高；

（4）不需要精心设计网络结构，用简单的多层感知器就能够取得比较好的效果。

#### ②Wasserstein 距离

W 距离与JS Divergence 相比有什么好处呢？最大的好处就是不管两种分布是否有重叠，它都是连续变换的而不是突变的，可以用下面这个例子来说明一下，如图4所示。

#### ③WGAN

W 距离有很好的优越性，把它拿来作为两种分布的度量优化生成器，但是W 距离里面有一个是没办法求解的。作者Martin 在论文附录里面通过定理将这个问题转变成了一个新的问题，有着如下形式：

（1）判别器最后一层去掉sigmoid；

（2）生成器和判别器的loss 不取log；

（3）每次更新判别器的参数之后把它们的绝对值裁剪到不超过一个固定常数的数；

### 2 Improving WGAN

WGAN 的提出成功地解决了GAN 的很多问题，最后需要满足一阶Lipschitz 连续性条件，所以在训练的时候加了一个限制——权重裁剪。

想及时获得更多精彩文章，可在微信中搜索“博文视点”或者扫描下方二维码并关注。

Spark-SparkSQL深入学习系列十（转自OopsOutOfMemory）
/** Spark SQL源码分析系列文章*/     前面讲到了Spark SQL In-Memory Columnar Storage的存储结构是基于列存储的。
890 0
Spark-SparkSQL深入学习系列七（转自OopsOutOfMemory）
/** Spark SQL源码分析系列文章*/   接上一篇文章Spark SQL Catalyst源码分析之Physical Plan，本文将介绍Physical Plan的toRDD的具体实现细节：   我们都知道一段sql，真正的执行是当你调用它的collect()方法才会执行Spark Job，最后计算得到RDD。
918 0
Spark-SparkSQL深入学习系列六（转自OopsOutOfMemory）
/** Spark SQL源码分析系列文章*/   前面几篇文章主要介绍的是Spark sql包里的的spark sql执行流程，以及Catalyst包内的SqlParser，Analyzer和Optimizer，最后要介绍一下Catalyst里最后的一个Plan了，即Physical Plan。
1101 0
Spark-SparkSQL深入学习系列五（转自OopsOutOfMemory）
/** Spark SQL源码分析系列文章*/   前几篇文章介绍了Spark SQL的Catalyst的核心运行流程、SqlParser，和Analyzer 以及核心类库TreeNode，本文将详细讲解Spark SQL的Optimizer的优化思想以及Optimizer在Catalyst里的表现方式，并加上自己的实践，对Optimizer有一个直观的认识。
991 0

0 0
【推荐系统】浪潮之巅——深度学习推荐系列模型

0 0
04、SpringCloud之Feign组件学习笔记（二）
04、SpringCloud之Feign组件学习笔记（二）
0 0
goroutine 的引出 | 学习笔记

0 0
Spark 集群搭建_Spark 集群结构|学习笔记

0 0

0 0
+关注