[Paper Notes] MPRNet for image restoration: Multi-Stage Progressive Image Restoration, with code walkthrough (Part 1): https://developer.aliyun.com/article/1507772
3. Stage1 Encoder
The encoders of Stage1 and Stage2 differ slightly, so they are described separately. The Stage1 encoder has one input and three outputs at different scales; it extracts features at three scales in preparation for the scale-fusion steps that follow. It contains several CAB blocks, which help it capture channel-wise features, and downsampling is done with a plain DownSample module. The structure is as follows:
Figure 4
Code:
# Location: MPRNet.py
class Encoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super(Encoder, self).__init__()

        self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]

        self.encoder_level1 = nn.Sequential(*self.encoder_level1)
        self.encoder_level2 = nn.Sequential(*self.encoder_level2)
        self.encoder_level3 = nn.Sequential(*self.encoder_level3)

        self.down12 = DownSample(n_feat, scale_unetfeats)
        self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats)

        # Cross Stage Feature Fusion (CSFF)
        if csff:
            self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)

            self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        # Level 1: full resolution, n_feat channels
        enc1 = self.encoder_level1(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])

        # Level 2: 1/2 resolution, n_feat + scale_unetfeats channels
        x = self.down12(enc1)
        enc2 = self.encoder_level2(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])

        # Level 3: 1/4 resolution, n_feat + 2*scale_unetfeats channels
        x = self.down23(enc2)
        enc3 = self.encoder_level3(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])

        return [enc1, enc2, enc3]
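A minimal sketch of the three-scale outputs (assuming the repo's MPRNet.py is importable and using the deblurring defaults n_feat=96, scale_unetfeats=48, reduction=4; input size is arbitrary):

# Hypothetical shape check, not part of the repo
import torch
import torch.nn as nn
from MPRNet import Encoder

enc = Encoder(n_feat=96, kernel_size=3, reduction=4, act=nn.PReLU(), bias=False,
              scale_unetfeats=48, csff=False)

x = torch.randn(1, 96, 128, 128)    # features from the shallow feature extractor
enc1, enc2, enc3 = enc(x)
print(enc1.shape)   # torch.Size([1, 96, 128, 128])   full resolution
print(enc2.shape)   # torch.Size([1, 144, 64, 64])    1/2 resolution, 96+48 channels
print(enc3.shape)   # torch.Size([1, 192, 32, 32])    1/4 resolution, 96+2*48 channels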
4. Stage2 Encoder
The Stage2 encoder has three inputs: the output of the previous step and the features from before and after the Stage1 decoder. The main path (the vertical column on the left of the figure) is identical to the Stage1 encoder. Each of the two extra inputs comes in three scales; every scale passes through a convolution layer (1×1 in the code), and feature maps of the same scale are then fused and output. The structure is as follows:
Figure 5
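In the code there is no separate Stage2 encoder class: the same Encoder is instantiated with csff=True, and the Stage1 features are fed in through the encoder_outs/decoder_outs arguments. A minimal calling sketch (feat1 and res1 are assumed to hold the Stage1 encoder and decoder outputs; hyper-parameters follow the deblurring defaults):

# Sketch: Stage2 encoder reusing Stage1 features via CSFF
stage2_encoder = Encoder(n_feat=96, kernel_size=3, reduction=4, act=nn.PReLU(),
                         bias=False, scale_unetfeats=48, csff=True)

# feat1 = [enc1, enc2, enc3] from the Stage1 encoder
# res1  = [dec1, dec2, dec3] from the Stage1 decoder
# x2_cat: (1, 96, H, W) features entering Stage2 (shallow features fused with the SAM output)
feat2 = stage2_encoder(x2_cat, encoder_outs=feat1, decoder_outs=res1)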
5. Decoder
The decoders of the two stages share the same structure, so they are described together. The decoder takes three inputs at different scales and extracts features with CAB blocks; small-scale features are enlarged by upsampling while convolutions reduce their channel count, so the small-scale feature maps end up with the same shape as the large-scale ones, and feature fusion is then done through residual (skip) connections. The structure is as follows:
Figure 6
Code:
# Location: MPRNet.py
class Decoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super(Decoder, self).__init__()

        self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]

        self.decoder_level1 = nn.Sequential(*self.decoder_level1)
        self.decoder_level2 = nn.Sequential(*self.decoder_level2)
        self.decoder_level3 = nn.Sequential(*self.decoder_level3)

        # CABs applied to the encoder skip connections before fusion
        self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act)

        # upsample, reduce channels, then add the skip feature
        self.up21 = SkipUpSample(n_feat, scale_unetfeats)
        self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats)

    def forward(self, outs):
        enc1, enc2, enc3 = outs
        dec3 = self.decoder_level3(enc3)

        x = self.up32(dec3, self.skip_attn2(enc2))
        dec2 = self.decoder_level2(x)

        x = self.up21(dec2, self.skip_attn1(enc1))
        dec1 = self.decoder_level1(x)

        return [dec1, dec2, dec3]
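Continuing the Encoder sketch above (same assumptions and defaults), the decoder consumes the three encoder outputs and returns feature maps with matching shapes:

# Hypothetical shape check for the Decoder, not part of the repo
dec = Decoder(n_feat=96, kernel_size=3, reduction=4, act=nn.PReLU(),
              bias=False, scale_unetfeats=48)

dec1, dec2, dec3 = dec([enc1, enc2, enc3])
print(dec1.shape)   # torch.Size([1, 96, 128, 128])   same shape as enc1
print(dec2.shape)   # torch.Size([1, 144, 64, 64])
print(dec3.shape)   # torch.Size([1, 192, 32, 32])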
6. SAM (Supervised Attention Module)
SAM sits between the two stages and takes two inputs, the features from the previous step and the original image, which improves feature extraction. As a supervised attention module, SAM uses attention maps to aggressively filter the useful features passed across stages. It has two outputs: a feature map refined by the attention mechanism, which feeds the subsequent pipeline, and a 3-channel image that serves as a supervised output during training. The structure is as follows:
Figure 7
Code:
# Location: MPRNet.py
## Supervised Attention Module
class SAM(nn.Module):
    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        x1 = self.conv1(x)
        img = self.conv2(x) + x_img           # 3-channel restored image (supervised output)
        x2 = torch.sigmoid(self.conv3(img))   # attention map derived from the restored image
        x1 = x1 * x2                          # re-weight the features with the attention map
        x1 = x1 + x                           # residual connection
        return x1, img
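A minimal calling sketch (dec1 is assumed to be the level-1 output of the Stage1 decoder from the sketch above, and x_img the degraded input image; kernel_size=1 as the repo uses for sam12):

# Sketch: SAM between Stage1 and Stage2, not part of the repo
sam = SAM(n_feat=96, kernel_size=1, bias=False)

x_img = torch.randn(1, 3, 128, 128)      # degraded input image
feats, stage1_img = sam(dec1, x_img)
print(feats.shape)        # torch.Size([1, 96, 128, 128])  attention-refined features -> Stage2
print(stage1_img.shape)   # torch.Size([1, 3, 128, 128])   supervised 3-channel output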
7. ORSNet (Original Resolution Subnetwork)
To preserve the fine details of the input image, the last stage of the model introduces an original-resolution subnetwork (ORSNet: Original Resolution Subnetwork). ORSNet uses no downsampling operations and produces spatially rich, high-resolution features. It consists of several Original Resolution Blocks (ORB) and forms the final stage of the model. The structure is as follows:
Figure 8
As the figure shows, there are three inputs: the output of the previous step and the features from before and after the Stage2 decoder. Each of the latter two inputs comes in three scales; the channel count of all three scales is first brought to 96 and then to 128, the smaller scales are upsampled to the same size as the largest one, and the results are fused into the main branch. Before each fusion the main branch passes through an ORB (Original Resolution Block) module.
An ORB is a chain of CABs (followed by a convolution in the code) wrapped in one long residual connection. The structure is as follows:
Figure 9
Code:
# Location: MPRNet.py
## Original Resolution Block (ORB)
class ORB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab):
        super(ORB, self).__init__()
        modules_body = []
        modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x            # long residual connection
        return res

##########################################################################
class ORSNet(nn.Module):
    def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab):
        super(ORSNet, self).__init__()

        self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)

        # bring the smaller-scale Stage2 features back to the original resolution
        self.up_enc1 = UpSample(n_feat, scale_unetfeats)
        self.up_dec1 = UpSample(n_feat, scale_unetfeats)

        self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))
        self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats))

        # 1x1 convolutions that change the channel count before fusion
        self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)

        self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs, decoder_outs):
        x = self.orb1(x)
        x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0])

        x = self.orb2(x)
        x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1]))

        x = self.orb3(x)
        x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2]))

        return x
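A rough sketch of how the last stage is wired in MPRNet.forward (feat2/res2 are assumed to be the Stage2 encoder/decoder outputs, x3_cat the fusion of the Stage3 shallow features with the SAM features, and tail the final convolution back to 3 channels; names follow the repo but this is only an outline, not the full forward pass):

# Sketch: final stage of MPRNet
orsnet = ORSNet(n_feat=96, scale_orsnetfeats=32, kernel_size=3, reduction=4,
                act=nn.PReLU(), bias=False, scale_unetfeats=48, num_cab=8)

# x3_cat: (1, 96+32, H, W) features entering the last stage
x3 = orsnet(x3_cat, feat2, res2)    # (1, 128, H, W), still at the original resolution
stage3_img = tail(x3)               # 3-channel residual image added to the input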
IV. Loss Functions
MPRNet mainly uses two loss functions, CharbonnierLoss and EdgeLoss. The overall loss is as follows:
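Reconstructed in LaTeX ($X_S$ is the output of stage $S$, $Y$ the ground-truth image, and $\lambda$ the edge-loss weight):

$$\mathcal{L} = \sum_{S=1}^{3}\Big[\mathcal{L}_{char}(X_S, Y) + \lambda\,\mathcal{L}_{edge}(X_S, Y)\Big]$$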
The summation is there because during training all three stages produce an output, and each is compared against the GT to compute a loss (the three outputs in Figure 2). Also, the model does not predict the restored image $X_S$ directly; instead it predicts a residual image $R_S$ and adds the degraded input image $I$ to obtain $X_S = I + R_S$.
For the deblurring and deraining tasks, CharbonnierLoss and EdgeLoss are combined as a weighted sum with a 1:0.05 ratio; denoising uses only CharbonnierLoss. My guess is that this is because the noise there follows some distribution (e.g. Gaussian or Poisson) and does not create sharp edge differences, so EdgeLoss is not needed for denoising.
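For reference, a minimal sketch of how this weighted sum is typically computed in the training script (following the structure of the repo's Deblurring/train.py; `restored` is the list of three stage outputs and `target` the GT):

# Sketch of the loss combination during training (1 : 0.05 weighting)
criterion_char = CharbonnierLoss()
criterion_edge = EdgeLoss()

restored = model_restoration(input_)    # list of three outputs, one per stage
loss_char = sum(criterion_char(restored[j], target) for j in range(len(restored)))
loss_edge = sum(criterion_edge(restored[j], target) for j in range(len(restored)))
loss = loss_char + 0.05 * loss_edge     # Denoising would use loss_char only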
The two losses are briefly introduced below.
1. CharbonnierLoss
The formula is as follows:
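Reconstructed in LaTeX ($\varepsilon$ is a small constant, set to $10^{-3}$ in the code):

$$\mathcal{L}_{char} = \sqrt{\lVert X_S - Y \rVert^{2} + \varepsilon^{2}}$$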
Because of the constant $\varepsilon$, CharbonnierLoss stays smooth near zero and its gradient remains well-defined, so training does not run into the non-smoothness of a plain L1 loss. Away from zero the curve approximates the L1 loss, so compared with an L2 loss it is insensitive to outliers and does not overly amplify large errors.
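As a one-line check, for a single error value $d = x - y$ the gradient of the loss is

$$\frac{\partial}{\partial d}\sqrt{d^{2}+\varepsilon^{2}} = \frac{d}{\sqrt{d^{2}+\varepsilon^{2}}},$$

which is bounded by 1 in magnitude (L1-like behaviour, so large errors are not amplified the way an L2 loss would amplify them) and stays well-defined at $d = 0$ thanks to $\varepsilon$.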
Code:
# Location: losses.py
class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
        return loss
2. EdgeLoss
An L1 or L2 loss focuses on the image as a whole and does not account well for salient features, yet salient structure and texture are highly correlated with human subjective perception and cannot be ignored.
EdgeLoss mainly measures differences in texture: it captures high-frequency texture and structure information well and improves the detail of the restored image. The formula is as follows:
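Reconstructed in LaTeX ($\Delta$ denotes the Laplacian operator):

$$\mathcal{L}_{edge} = \sqrt{\lVert \Delta(X_S) - \Delta(Y) \rVert^{2} + \varepsilon^{2}}$$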
Here $\Delta$ denotes the Laplacian edge-detection operator (implemented with a kernel), and $\Delta(X_S)$ denotes applying edge detection to $X_S$; the rest of the formula is analogous to CharbonnierLoss.
Code:
# Location: losses.py
class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        # 5x5 Gaussian kernel, one copy per RGB channel (depthwise filtering)
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        filtered = self.conv_gauss(current)      # filter
        down = filtered[:, :, ::2, ::2]          # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4    # upsample
        filtered = self.conv_gauss(new_filter)   # filter
        diff = current - filtered                # Laplacian: image minus its low-frequency part
        return diff

    def forward(self, x, y):
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss
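A quick sanity check of the two losses on random tensors (a sketch; imports assume the repo's losses.py is on the path):

# Sketch: using the two losses together
import torch
from losses import CharbonnierLoss, EdgeLoss

# EdgeLoss puts its Gaussian kernel on the GPU when one is available,
# so the inputs are moved to the same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pred   = torch.rand(1, 3, 64, 64, device=device)
target = torch.rand(1, 3, 64, 64, device=device)

char_loss = CharbonnierLoss()(pred, target)
edge_loss = EdgeLoss()(pred, target)
total = char_loss + 0.05 * edge_loss    # the 1 : 0.05 weighting used for deblurring/deraining
print(char_loss.item(), edge_loss.item(), total.item())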