I. Introduction
Paper: https://arxiv.org/pdf/2102.02808.pdf
Code: http://github.com/swz30/MPRNet
Image restoration has to strike a delicate balance between spatial detail and high-level contextual features. The authors therefore design a multi-stage model: it first uses an encoder-decoder architecture to learn contextual features, and then combines them with a high-resolution branch that preserves local information.
As an analogy: to restore a picture of a snake, the encoder-decoder extracts high-level contextual features that tell the model to "paint" scales on the snake's body rather than feathers or anything else, while the high-resolution branch refines the pattern of those scales.
MPRNet has many details, but its core innovation is the "multi-stage" design. The model has three stages: the first two are encoder-decoder subnetworks that learn contextual features over a large receptive field, and the last stage is a high-resolution branch that builds the textures needed in the final output image. The authors provide projects for three tasks, Deblurring, Denoising and Deraining; all three share the same backbone and differ only in parameter scale (Deblurring > Denoising > Deraining). Below we use the largest one, Deblurring, as the example.
II. Usage
The MPRNet project is split into three sub-projects: Deblurring, Denoising and Deraining. The authors use no exotic libraries and no advanced programming tricks, so the code is well suited for study, and the usage can be explained in a few sentences.
1. Inference
(1) Download the pretrained models: the pretrained models belong in the pretrained_models folder of each of the three sub-projects. The download links are in the README.md inside each pretrained_models folder (accessing them may require a proxy); I have also uploaded them to a netdisk:
Link: https://pan.baidu.com/s/1sxfidMvlU_pIeO5zD1tKZg  Extraction code: faye
(2) Prepare the test images: put the degraded images in the directory samples/input/.
(3) Run demo.py:
# Deblurring
python demo.py --task Deblurring
# Denoising
python demo.py --task Denoising
# Deraining
python demo.py --task Deraining
(4) The results are written to the directory samples/output/.
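If you prefer to call the model from your own script rather than through demo.py, a minimal inference sketch looks roughly like the following. Assumptions: it is run inside the Deblurring sub-project, the checkpoint file name and image paths are placeholders, and the "state_dict" key plus the pad-to-a-multiple-of-8 step follow the conventions used by the repository's demo.py.

# Minimal inference sketch (paths and checkpoint name are placeholders).
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
from MPRNet import MPRNet

device = "cuda" if torch.cuda.is_available() else "cpu"

model = MPRNet().to(device)
ckpt = torch.load("pretrained_models/model_deblurring.pth", map_location=device)
model.load_state_dict(ckpt["state_dict"])
model.eval()

img = to_tensor(Image.open("samples/input/blurry.jpg").convert("RGB")).unsqueeze(0).to(device)

# Pad H and W to a multiple of 8 so the three downsampling levels line up.
h, w = img.shape[2], img.shape[3]
H, W = ((h + 7) // 8) * 8, ((w + 7) // 8) * 8
img_pad = F.pad(img, (0, W - w, 0, H - h), mode="reflect")

with torch.no_grad():
    restored = model(img_pad)[0]          # the Stage 3 output is the final image
restored = torch.clamp(restored[:, :, :h, :w], 0, 1)

to_pil_image(restored.squeeze(0).cpu()).save("samples/output/restored.png")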
2. Training
(1) Download the datasets from the links given in the README.md inside the Dataset folder.
(2) Check whether training.yml needs to be modified, mainly the dataset paths at the end.
(3) Start training:
python train.py
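Because all three stage outputs are later compared against the same ground truth, it helps to see what the training objective looks like. Below is a rough sketch of the multi-stage loss, assuming the CharbonnierLoss and EdgeLoss classes from the Deblurring sub-project's losses.py and the 0.05 edge-loss weight I recall from the repository's train.py; verify both against your local copy.

# Sketch of the multi-stage loss: every stage output is supervised by the same GT.
from losses import CharbonnierLoss, EdgeLoss

criterion_char = CharbonnierLoss()
criterion_edge = EdgeLoss()

def mprnet_loss(restored, target):
    """restored = [stage3_img, stage2_img, stage1_img], all compared to the same GT."""
    loss_char = sum(criterion_char(r, target) for r in restored)
    loss_edge = sum(criterion_edge(r, target) for r in restored)
    return loss_char + 0.05 * loss_edge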
III. MPRNet Architecture
I will describe the model structure following the official code. The way some important modules are grouped may differ from the paper, but the overall structure is the same.
1. Overall structure
The official structure diagram of MPRNet is shown below:
Figure 1
This figure summarizes the overall structure of MPRNet, but many details are not shown; after reading the code I give a more detailed description below. In the following figures the input is fixed to 512x512, we use Deblurring as the example, and batch_size=1.
The overall structure diagram is as follows:
Figure 2
The three Inputs in the figure are all the original image. The model has three stages, and the overall flow is as follows:
1.1 The input image is split multi-patch style into four parts: top-left, top-right, bottom-left and bottom-right (see the splitting sketch after this list);
1.2 Each patch goes through a 3x3 convolution that expands the channel dimension, so that richer feature information can be extracted later;
1.3 A CAB (Channel Attention Block) follows, using an attention mechanism to weight the features of each channel;
1.4 The Encoder encodes image features at three scales, extracting multi-scale contextual features as well as deeper semantic features;
1.5 The deep features are merged: the same-scale features of the four patches are combined into two groups, top and bottom, and sent to the Decoder;
1.6 The Decoder processes the merged features at each scale;
1.7 The input image is split multi-patch style into two parts, top and bottom;
1.8 The top and bottom patches, together with the largest-scale feature maps output by the Stage 1 Decoder, are fed into SAM (Supervised Attention Module); during training, SAM can use the ground truth to provide useful control signals for the restoration process of the current stage;
1.9 The SAM output has two parts: one is the feature map tied to this second input of the original image, which continues through the rest of the pipeline; the other is the Stage 1 output image used during training, so the GT can make the model converge faster and better.
2.1 After Stage 2's channel-expanding convolution and CAB, the features from before and after the Stage 1 Decoder are fed into the Stage 2 Encoder (cross-stage feature fusion);
2.2 A Decoder similar to Stage 1's again produces two outputs: one continues into Stage 3, and the other is used to compute a loss against the GT;
3.1 The original-image input of Stage 3 is no longer split, so that the complete contextual information can be used to restore image details;
3.2 The original image is passed through a convolution that raises the channel dimension;
3.3 The features from before and after the Stage 2 Decoder are fed into Stage 3's ORSNet (Original Resolution Subnetwork); ORSNet uses no downsampling and produces spatially rich, high-resolution features;
3.4 Finally, a convolution reduces the channel dimension to 3, the original input image is added as a residual, and the result is output.
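To make steps 1.1, 1.7 and 3.1 concrete, here is a small shape-only sketch of the multi-patch hierarchy on the 512x512 example input; the slicing mirrors the beginning of MPRNet.forward shown below (a dummy tensor, no learned weights involved).

# Dummy illustration of the multi-patch hierarchy (shapes only).
import torch

x3_img = torch.randn(1, 3, 512, 512)        # Stage 3: the full image
H, W = x3_img.size(2), x3_img.size(3)

# Stage 2: two patches (top / bottom halves)
x2top_img = x3_img[:, :, 0:H // 2, :]        # (1, 3, 256, 512)
x2bot_img = x3_img[:, :, H // 2:H, :]        # (1, 3, 256, 512)

# Stage 1: four patches (left / right halves of each Stage 2 patch)
x1ltop_img = x2top_img[:, :, :, 0:W // 2]    # (1, 3, 256, 256)
x1rtop_img = x2top_img[:, :, :, W // 2:W]    # (1, 3, 256, 256)

print(x2top_img.shape, x1ltop_img.shape)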
Code implementation:
# Location: MPRNet.py
class MPRNet(nn.Module):
    def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False):
        super(MPRNet, self).__init__()

        act = nn.PReLU()
        self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))

        # Cross Stage Feature Fusion (CSFF)
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab)

        self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
        self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)

        self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias)
        self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias)
        self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias)

    def forward(self, x3_img):
        # Original-resolution Image for Stage 3
        H = x3_img.size(2)
        W = x3_img.size(3)

        # Multi-Patch Hierarchy: Split Image into four non-overlapping patches

        # Two Patches for Stage 2
        x2top_img = x3_img[:, :, 0:int(H/2), :]
        x2bot_img = x3_img[:, :, int(H/2):H, :]

        # Four Patches for Stage 1
        x1ltop_img = x2top_img[:, :, :, 0:int(W/2)]
        x1rtop_img = x2top_img[:, :, :, int(W/2):W]
        x1lbot_img = x2bot_img[:, :, :, 0:int(W/2)]
        x1rbot_img = x2bot_img[:, :, :, int(W/2):W]

        ##-------------------------------------------
        ##-------------- Stage 1---------------------
        ##-------------------------------------------
        ## Compute Shallow Features
        x1ltop = self.shallow_feat1(x1ltop_img)
        x1rtop = self.shallow_feat1(x1rtop_img)
        x1lbot = self.shallow_feat1(x1lbot_img)
        x1rbot = self.shallow_feat1(x1rbot_img)

        ## Process features of all 4 patches with Encoder of Stage 1
        feat1_ltop = self.stage1_encoder(x1ltop)
        feat1_rtop = self.stage1_encoder(x1rtop)
        feat1_lbot = self.stage1_encoder(x1lbot)
        feat1_rbot = self.stage1_encoder(x1rbot)

        ## Concat deep features
        feat1_top = [torch.cat((k, v), 3) for k, v in zip(feat1_ltop, feat1_rtop)]
        feat1_bot = [torch.cat((k, v), 3) for k, v in zip(feat1_lbot, feat1_rbot)]

        ## Pass features through Decoder of Stage 1
        res1_top = self.stage1_decoder(feat1_top)
        res1_bot = self.stage1_decoder(feat1_bot)

        ## Apply Supervised Attention Module (SAM)
        x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img)
        x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)

        ## Output image at Stage 1
        stage1_img = torch.cat([stage1_img_top, stage1_img_bot], 2)

        ##-------------------------------------------
        ##-------------- Stage 2---------------------
        ##-------------------------------------------
        ## Compute Shallow Features
        x2top = self.shallow_feat2(x2top_img)
        x2bot = self.shallow_feat2(x2bot_img)

        ## Concatenate SAM features of Stage 1 with shallow features of Stage 2
        x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1))
        x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1))

        ## Process features of both patches with Encoder of Stage 2
        feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top)
        feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)

        ## Concat deep features
        feat2 = [torch.cat((k, v), 2) for k, v in zip(feat2_top, feat2_bot)]

        ## Pass features through Decoder of Stage 2
        res2 = self.stage2_decoder(feat2)

        ## Apply SAM
        x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)

        ##-------------------------------------------
        ##-------------- Stage 3---------------------
        ##-------------------------------------------
        ## Compute Shallow Features
        x3 = self.shallow_feat3(x3_img)

        ## Concatenate SAM features of Stage 2 with shallow features of Stage 3
        x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1))
        x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)

        stage3_img = self.tail(x3_cat)

        return [stage3_img + x3_img, stage2_img, stage1_img]
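As a quick sanity check of the structure above, the model can be run on a dummy 512x512 input; all three stage outputs come back at the full input resolution. A minimal sketch, assuming it is executed inside the Deblurring sub-project so MPRNet.py is importable:

import torch
from MPRNet import MPRNet

model = MPRNet()                   # Deblurring defaults: n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32
x = torch.randn(1, 3, 512, 512)    # batch_size=1, as in the figures above

with torch.no_grad():
    stage3_img, stage2_img, stage1_img = model(x)

# All three outputs are 3-channel images at the input resolution: (1, 3, 512, 512)
print(stage3_img.shape, stage2_img.shape, stage1_img.shape)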
Some module details are still not visible in the figure; I will describe them in detail below.
2. CAB (Channel Attention Block)
As the name suggests, CAB uses an attention mechanism to re-weight the features of each channel; the output feature map has the same shape as the input. The structure diagram is as follows:
Figure 3
As the figure and code show, the input first passes through two convolutions (with an activation in between) to produce a feature map. A channel attention layer then applies global average pooling (GAP) followed by two 1x1 convolutions and a Sigmoid to obtain per-channel weights, which are multiplied with the feature map. Finally the original input is added back through the residual connection, which realizes a channel attention mechanism.
Code implementation:
# Location: MPRNet.py
## Channel Attention Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

##########################################################################
## Channel Attention Block (CAB)
class CAB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(CAB, self).__init__()
        modules_body = []
        modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
        modules_body.append(act)
        modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))

        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res = self.CA(res)
        res += x
        return res
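A quick way to confirm the shape-preserving behaviour described above is to push a dummy feature map through a CAB (n_feat=96 and reduction=4 are the Deblurring settings used in the MPRNet constructor earlier):

import torch
import torch.nn as nn
from MPRNet import CAB

cab = CAB(n_feat=96, kernel_size=3, reduction=4, bias=False, act=nn.PReLU())
x = torch.randn(1, 96, 256, 256)
y = cab(x)
print(x.shape == y.shape)   # True: channel attention re-weights channels, the shape is unchanged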
[Paper Notes] Image restoration with MPRNet: Multi-Stage Progressive Image Restoration, with code walkthrough, Part 2: https://developer.aliyun.com/article/1507775