介绍
生成对抗网络(Generative Adversarial Networks,简称GANs)是一种深度学习模型,由Ian Goodfellow于2014年提出。GAN由两个神经网络组成,一个生成器(Generator)和一个判别器(Discriminator),它们通过对抗的方式共同训练。
判别器(Discriminator)
判别器的任务是区分输入数据是来自真实数据集还是生成器产生的假数据。它的目标是在面对真实数据时输出高概率,面对生成器产生的数据时输出低概率。
生成器(Generator)
生成器的目的是产生与真实数据尽可能相似的假数据。在训练的开始,生成器通常会输出随机噪声,但是随着训练的进行,它会逐渐学习到如何生成越来越真实的数据。生成器不直接访问真实数据,而是通过判别器给出的反馈来优化它的输出。
对抗训练
GAN的训练过程可以被看作是一个博弈游戏,其中生成器尝试“欺骗”判别器,而判别器则尝试不被欺骗。这个过程可以用以下步骤概括:
- 训练判别器:用真实数据和生成器产生的假数据训练判别器,目标是准确地区分这两者。
- 训练生成器:固定判别器,更新生成器的参数使得判别器更可能将假数据误判为真实数据。
损失函数
GANs的训练通常涉及到最小化一些损失函数,这些函数会衡量判别器和生成器的性能。最常用的一个损失函数是交叉熵损失函数,但也有其他的变体和改进方法。
## 伪代码
for epoch in range(num_epochs):
for batch_data in data_loader:
# 更新判别器
real_images = batch_data.to(device)
z = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(z).detach()
d_loss_real = discriminator(real_images)
d_loss_fake = discriminator(fake_images)
# 判别器损失
d_loss = -(torch.mean(d_loss_real) - torch.mean(d_loss_fake))
discriminator.zero_grad()
d_loss.backward()
discriminator_optimizer.step()
# 更新生成器
z = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(z)
g_loss = -torch.mean(discriminator(fake_images))
generator.zero_grad()
g_loss.backward()
generator_optimizer.step()
历史
- 最早概念提出
2014年,Ian Goodfellow等在论文“Generative Adversarial Nets”中首次提出了GAN的概念。提出了一种判别器和生成器对抗并互相提升的框架。
https://arxiv.org/abs/1406.2661
2.CNN条件GAN
2015年,Mirza等人提出使用卷积神经网络(CNN)作为GAN的判别器和生成器,使其可以处理图像数据。
https://arxiv.org/abs/1411.1784
3.DCGAN
2015年,Radford等人提出了DCGAN,使用CNN并提供了许多改进训练GAN的指导,成为第一个有效训练GAN的框架。
https://arxiv.org/abs/1511.06434
4.理论分析
2016年,Arjovsky等人提出Wasserstein GAN,从理论上分析了GAN训练的不稳定性,提出了Wasserstein距离改善模型稳定性。
https://arxiv.org/abs/1701.07875
5.条件GAN
2014年,Mirza等人提出条件GAN,可以控制GAN的生成结果。
https://arxiv.org/abs/1606.03657
6.进一步扩展
此后陆续出现了InfoGAN、CycleGAN等框架,GAN的应用范围不断扩展。2017年CycleGAN实现无配对图像转换。
https://arxiv.org/abs/1703.10593
7.高分辨率GAN
2018年,Brock提出BigGAN,首次实现高分辨率、高质量的图像生成。
https://arxiv.org/abs/1809.11096
8.自监督GAN
2020年,Chen提出基于对比学习的自监督GAN SimGAN,无需数据标注即可训练。
https://arxiv.org/abs/2010.08895
9.大模型GAN
近年来,大模型GAN如Nvidia的StyleGAN等可以生成更逼真的图片。GAN技术仍在不断发展中。
https://arxiv.org/abs/1912.04958
应用场景
- 生成逼真的人脸或物体图片
- 风格迁移,如将日常照片转换为名画风格
- 图像超分辨率,即从低分辨率图像生成高分辨率版本
- 数据增强,为数据集生成新的样本
缺点
GANs的训练是复杂的,并且常常面临模式崩溃(mode collapse)等问题,模式崩溃指的是生成器开始生成非常相似或重复的样本,而没有多样性。