前言
论文:https://arxiv.org/pdf/1809.00219v2.pdf
一、ESRGAN的主要介绍
研究目的:通过改进SRGAN(Super-Resolution Generative Adversarial Network)来提高视觉质量。
ESRGAN是基于SRGAN改进而来到,相比于SRGAN它在三个方面进行了改进:
1、改进了网络结构、对抗损失、感知损失
2、引入Residual-in-Residual Dense Block(RRDB)
3、使用激活前的VGG特征来改善感知损失
在开始讲这个ESRGAN的具体实现之前,先来看一下他和他的前辈SRGAN的对比效果:
我们可以从上图看出:
1、ESRGAN在锐度和边缘信息上优于SRGAN,且去除了“伪影”。
2、从PI(图像感知质量指标—perceptual index)和PMSE(根均方误差)两个指标来看,ESRGAN也可以当之无愧地称得上是超分辨率复原任务中的the State-of-the-Art。
二、ESRGAN的主要内容
1.RRDB,对residual blocks的改进
我们可以看出这个残差块是很传统的Conv-BN-relu-Conv-BN的结构,而作者在文章中是这么说到的:
什么意思呢?就是说作者认为SRGAN之所以会产生伪影,就是因为使用了Batch normalization,所以作者做出了去除BN的改进
而且我们再来看,SRGAN的残差块是顺序连接的,而作者可能哎,受denseNet的启发,他就把这些残差块用密集连接的方式连在一起.那么他的生成器里的特征提取部分最终变成了这样子:
既然我们知道他的网络结是这样子设计的,那么他的实现其实就很简单:
class RRDB_Net(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \ mode='CNA', res_scale=1, upsample_mode='upconv'): super(RRDB_Net, self).__init__() n_upscale = int(math.log(upscale, 2)) if upscale == 3: n_upscale = 1 fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \ norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)] LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) if upsample_mode == 'upconv': upsample_block = B.upconv_blcok elif upsample_mode == 'pixelshuffle': upsample_block = B.pixelshuffle_block else: raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode) if upscale == 3: upsampler = upsample_block(nf, nf, 3, act_type=act_type) else: upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\ *upsampler, HR_conv0, HR_conv1) def forward(self, x): x = self.model(x)
2、对损失函数的改进
说到损失函数,在SRGAN的文章里,这个判别器它判断的是你输入的图片是"真的"高清图像,还是"假的"高清图像,而且作者他就提出一种新的思考模式,就是说我的判别器是来估计真实图像相对来说比fake图像更逼真的概率。
怎么来理解这句话呢?
具体而言,作者把标准的判别器换成Relativistic average Discriminator(RaD),所以判别器的损失函数定义为:
对应的生成器的对抗损失函数为:
求MSE的操作是通过对$mini-batch$中的所有数据求平均得到的,
xf
是原始低分辨图像经过生成器以后的图像,由于对抗的损失包含了
xr
和
xf
,所以生成器受益于对抗训练中的生成数据和实际数据的梯度,这种调整会使得网络学习到更尖锐的边缘和更细节的纹理。
3、对感知损失的改进
之前看SRGAN的时候,它是用来一个训练好的VGG16来给出超分辨率复原所需要的特征,作者通过对损失域的研究发现,激活前的特征,这样会克服两个缺点。
1、激活后的特征是非常稀疏的,特别是在很深的网络中。这种稀疏的激活提供的监督效果是很弱的,会造成性能低下;
2、使用激活后的特征会导致重建图像与GT的亮度不一致。
与此同时,作者还在loss函数中加入了
L1=Exi||G(xi)−y||1
,也就是
L1
损失,最终损失函数由三部分组成:
4、网络插值(Network Interpolation)
为了平衡感知质量和PSNR等评价值,作者提出了一个灵活且有效的方法---网络插值。具体而言,作者首先基于PSNR方法训练的得到的网络G_PSNR,然后再用基于GAN的网络G_GAN进行整合。
网络插值与图像插值的比较:
5、具体实现部分
ESRGAN可以实现放大4倍的效果
首先要训练一个基于PSRN指标的模型,如何根据这个模型的权重进行生成器的初始化.
作者使用了Adam作为优化器(β1 = 0.9, β2 = 0.999.)进行交替训练.
生成器有16个residual block和23个RRDB
三、结论
提出了一个ESRGAN模型,实现一贯更好的感知质量比以前的SR方法。该方法在PIRM-SR挑战赛中获得了感知指数的第一名。
已经制定了一个新的架构,包含几个RDDB块没有BN层。此外,采用包括残余缩放和较小初始化的有用技术来促进所提出的深度模型的训练。还介绍了使用相对论GAN作为鉴别器,它学习判断一个图像是否比另一个更真实,指导生成器恢复更详细的纹理。此外,通过使用激活前的特征来增强感知损失,这提供了更强的监督,从而恢复更准确的亮度和逼真的纹理。